From cc5697f9685479191c012cb919c97f5eedd1e041 Mon Sep 17 00:00:00 2001 From: Chris Forbes Date: Wed, 30 Jan 2019 11:54:08 -0800 Subject: [PATCH] Squashed 'third_party/SPIRV-Tools/' content from commit d14db341b git-subtree-dir: third_party/SPIRV-Tools git-subtree-split: d14db341b834cfb3c574a258c331b3a6b1c2cbc5 --- .appveyor.yml | 89 + .clang-format | 6 + .gitignore | 25 + .gn | 20 + Android.mk | 343 + BUILD.gn | 801 + CHANGES | 741 + CMakeLists.txt | 294 + CODE_OF_CONDUCT.md | 1 + CONTRIBUTING.md | 192 + DEPS | 174 + LICENSE | 202 + PRESUBMIT.py | 40 + README.md | 535 + android_test/Android.mk | 12 + android_test/jni/Application.mk | 5 + android_test/test.cpp | 22 + build_overrides/build.gni | 46 + build_overrides/gtest.gni | 25 + build_overrides/spirv_tools.gni | 25 + cmake/SPIRV-Tools-shared.pc.in | 12 + cmake/SPIRV-Tools.pc.in | 12 + cmake/setup_build.cmake | 20 + cmake/write_pkg_config.cmake | 31 + codereview.settings | 2 + examples/CMakeLists.txt | 35 + examples/cpp-interface/CMakeLists.txt | 19 + examples/cpp-interface/main.cpp | 64 + external/CMakeLists.txt | 104 + include/spirv-tools/instrument.hpp | 135 + include/spirv-tools/libspirv.h | 680 + include/spirv-tools/libspirv.hpp | 245 + include/spirv-tools/linker.hpp | 97 + include/spirv-tools/optimizer.hpp | 731 + kokoro/android/build.sh | 51 + kokoro/android/continuous.cfg | 17 + kokoro/android/presubmit.cfg | 16 + kokoro/check-format/build.sh | 41 + kokoro/check-format/presubmit_check_format.cfg | 16 + kokoro/img/linux.png | Bin 0 -> 17369 bytes kokoro/img/macos.png | Bin 0 -> 15464 bytes kokoro/img/windows.png | Bin 0 -> 16653 bytes kokoro/linux-clang-debug/build.sh | 24 + kokoro/linux-clang-debug/continuous.cfg | 16 + kokoro/linux-clang-debug/presubmit.cfg | 16 + kokoro/linux-clang-release/build.sh | 24 + kokoro/linux-clang-release/continuous.cfg | 16 + kokoro/linux-clang-release/presubmit.cfg | 16 + kokoro/linux-gcc-debug/build.sh | 24 + kokoro/linux-gcc-debug/continuous.cfg | 16 + 
kokoro/linux-gcc-debug/presubmit.cfg | 17 + kokoro/linux-gcc-release/build.sh | 24 + kokoro/linux-gcc-release/continuous.cfg | 16 + kokoro/linux-gcc-release/presubmit.cfg | 16 + kokoro/macos-clang-debug/build.sh | 25 + kokoro/macos-clang-debug/continuous.cfg | 16 + kokoro/macos-clang-debug/presubmit.cfg | 16 + kokoro/macos-clang-release/build.sh | 25 + kokoro/macos-clang-release/continuous.cfg | 16 + kokoro/macos-clang-release/presubmit.cfg | 16 + kokoro/ndk-build/build.sh | 56 + kokoro/ndk-build/continuous.cfg | 17 + kokoro/ndk-build/presubmit.cfg | 16 + kokoro/scripts/linux/build.sh | 100 + kokoro/scripts/macos/build.sh | 53 + kokoro/scripts/windows/build.bat | 95 + kokoro/shaderc-smoketest/build.sh | 71 + kokoro/shaderc-smoketest/continuous.cfg | 17 + kokoro/shaderc-smoketest/presubmit.cfg | 17 + kokoro/windows-msvc-2013-release/build.bat | 24 + kokoro/windows-msvc-2013-release/continuous.cfg | 16 + kokoro/windows-msvc-2013-release/presubmit.cfg | 16 + kokoro/windows-msvc-2015-release/build.bat | 24 + kokoro/windows-msvc-2015-release/continuous.cfg | 16 + kokoro/windows-msvc-2015-release/presubmit.cfg | 16 + kokoro/windows-msvc-2017-debug/build.bat | 23 + kokoro/windows-msvc-2017-debug/continuous.cfg | 16 + kokoro/windows-msvc-2017-debug/presubmit.cfg | 16 + kokoro/windows-msvc-2017-release/build.bat | 24 + kokoro/windows-msvc-2017-release/continuous.cfg | 16 + kokoro/windows-msvc-2017-release/presubmit.cfg | 16 + projects.md | 82 + source/CMakeLists.txt | 374 + source/assembly_grammar.cpp | 263 + source/assembly_grammar.h | 138 + source/binary.cpp | 795 + source/binary.h | 36 + source/cfa.h | 347 + source/comp/CMakeLists.txt | 52 + source/comp/bit_stream.cpp | 348 + source/comp/bit_stream.h | 280 + source/comp/huffman_codec.h | 389 + source/comp/markv.cpp | 112 + source/comp/markv.h | 74 + source/comp/markv_codec.cpp | 793 + source/comp/markv_codec.h | 337 + source/comp/markv_decoder.cpp | 925 ++ source/comp/markv_decoder.h | 175 + source/comp/markv_encoder.cpp 
| 486 + source/comp/markv_encoder.h | 167 + source/comp/markv_logger.h | 93 + source/comp/markv_model.h | 232 + source/comp/move_to_front.cpp | 456 + source/comp/move_to_front.h | 384 + source/diagnostic.cpp | 193 + source/diagnostic.h | 79 + source/disassemble.cpp | 478 + source/disassemble.h | 38 + source/enum_set.h | 173 + source/enum_string_mapping.cpp | 29 + source/enum_string_mapping.h | 36 + source/ext_inst.cpp | 160 + source/ext_inst.h | 40 + source/extensions.cpp | 44 + source/extensions.h | 40 + source/extinst.debuginfo.grammar.json | 568 + source/extinst.spv-amd-gcn-shader.grammar.json | 26 + source/extinst.spv-amd-shader-ballot.grammar.json | 41 + ...d-shader-explicit-vertex-parameter.grammar.json | 14 + ...inst.spv-amd-shader-trinary-minmax.grammar.json | 95 + source/id_descriptor.cpp | 78 + source/id_descriptor.h | 63 + source/instruction.h | 49 + source/latest_version_glsl_std_450_header.h | 20 + source/latest_version_opencl_std_header.h | 20 + source/latest_version_spirv_header.h | 20 + source/libspirv.cpp | 131 + source/link/CMakeLists.txt | 36 + source/link/linker.cpp | 769 + source/macro.h | 25 + source/name_mapper.cpp | 331 + source/name_mapper.h | 122 + source/opcode.cpp | 605 + source/opcode.h | 136 + source/operand.cpp | 492 + source/operand.h | 144 + source/opt/CMakeLists.txt | 225 + source/opt/aggressive_dead_code_elim_pass.cpp | 813 ++ source/opt/aggressive_dead_code_elim_pass.h | 200 + source/opt/basic_block.cpp | 266 + source/opt/basic_block.h | 331 + source/opt/block_merge_pass.cpp | 146 + source/opt/block_merge_pass.h | 70 + source/opt/build_module.cpp | 79 + source/opt/build_module.h | 46 + source/opt/ccp_pass.cpp | 328 + source/opt/ccp_pass.h | 113 + source/opt/cfg.cpp | 317 + source/opt/cfg.h | 177 + source/opt/cfg_cleanup_pass.cpp | 39 + source/opt/cfg_cleanup_pass.h | 41 + source/opt/code_sink.cpp | 316 + source/opt/code_sink.h | 107 + source/opt/combine_access_chains.cpp | 290 + source/opt/combine_access_chains.h | 83 + 
source/opt/common_uniform_elim_pass.cpp | 592 + source/opt/common_uniform_elim_pass.h | 213 + source/opt/compact_ids_pass.cpp | 70 + source/opt/compact_ids_pass.h | 42 + source/opt/composite.cpp | 52 + source/opt/composite.h | 51 + source/opt/const_folding_rules.cpp | 849 ++ source/opt/const_folding_rules.h | 80 + source/opt/constants.cpp | 374 + source/opt/constants.h | 696 + source/opt/copy_prop_arrays.cpp | 872 ++ source/opt/copy_prop_arrays.h | 241 + source/opt/dead_branch_elim_pass.cpp | 542 + source/opt/dead_branch_elim_pass.h | 156 + source/opt/dead_insert_elim_pass.cpp | 263 + source/opt/dead_insert_elim_pass.h | 90 + source/opt/dead_variable_elimination.cpp | 111 + source/opt/dead_variable_elimination.h | 56 + source/opt/decoration_manager.cpp | 527 + source/opt/decoration_manager.h | 192 + source/opt/def_use_manager.cpp | 299 + source/opt/def_use_manager.h | 256 + source/opt/dominator_analysis.cpp | 68 + source/opt/dominator_analysis.h | 138 + source/opt/dominator_tree.cpp | 394 + source/opt/dominator_tree.h | 305 + source/opt/eliminate_dead_constant_pass.cpp | 104 + source/opt/eliminate_dead_constant_pass.h | 35 + source/opt/eliminate_dead_functions_pass.cpp | 56 + source/opt/eliminate_dead_functions_pass.h | 44 + source/opt/feature_manager.cpp | 67 + source/opt/feature_manager.h | 78 + source/opt/flatten_decoration_pass.cpp | 165 + source/opt/flatten_decoration_pass.h | 35 + source/opt/fold.cpp | 706 + source/opt/fold.h | 171 + .../fold_spec_constant_op_and_composite_pass.cpp | 385 + .../opt/fold_spec_constant_op_and_composite_pass.h | 84 + source/opt/folding_rules.cpp | 2243 +++ source/opt/folding_rules.h | 79 + source/opt/freeze_spec_constant_value_pass.cpp | 53 + source/opt/freeze_spec_constant_value_pass.h | 35 + source/opt/function.cpp | 183 + source/opt/function.h | 203 + source/opt/if_conversion.cpp | 285 + source/opt/if_conversion.h | 89 + source/opt/inline_exhaustive_pass.cpp | 85 + source/opt/inline_exhaustive_pass.h | 53 + 
source/opt/inline_opaque_pass.cpp | 120 + source/opt/inline_opaque_pass.h | 60 + source/opt/inline_pass.cpp | 750 + source/opt/inline_pass.h | 172 + source/opt/inst_bindless_check_pass.cpp | 263 + source/opt/inst_bindless_check_pass.h | 93 + source/opt/instruction.cpp | 753 + source/opt/instruction.h | 733 + source/opt/instruction_list.cpp | 36 + source/opt/instruction_list.h | 130 + source/opt/instrument_pass.cpp | 712 + source/opt/instrument_pass.h | 357 + source/opt/ir_builder.h | 542 + source/opt/ir_context.cpp | 850 ++ source/opt/ir_context.h | 992 ++ source/opt/ir_loader.cpp | 163 + source/opt/ir_loader.h | 86 + source/opt/iterator.h | 358 + source/opt/licm_pass.cpp | 140 + source/opt/licm_pass.h | 72 + source/opt/local_access_chain_convert_pass.cpp | 353 + source/opt/local_access_chain_convert_pass.h | 131 + source/opt/local_redundancy_elimination.cpp | 67 + source/opt/local_redundancy_elimination.h | 68 + source/opt/local_single_block_elim_pass.cpp | 262 + source/opt/local_single_block_elim_pass.h | 107 + source/opt/local_single_store_elim_pass.cpp | 248 + source/opt/local_single_store_elim_pass.h | 103 + source/opt/local_ssa_elim_pass.cpp | 112 + source/opt/local_ssa_elim_pass.h | 72 + source/opt/log.h | 231 + source/opt/loop_dependence.cpp | 1675 +++ source/opt/loop_dependence.h | 558 + source/opt/loop_dependence_helpers.cpp | 541 + source/opt/loop_descriptor.cpp | 1012 ++ source/opt/loop_descriptor.h | 574 + source/opt/loop_fission.cpp | 513 + source/opt/loop_fission.h | 78 + source/opt/loop_fusion.cpp | 730 + source/opt/loop_fusion.h | 114 + source/opt/loop_fusion_pass.cpp | 69 + source/opt/loop_fusion_pass.h | 51 + source/opt/loop_peeling.cpp | 1086 ++ source/opt/loop_peeling.h | 336 + source/opt/loop_unroller.cpp | 1092 ++ source/opt/loop_unroller.h | 49 + source/opt/loop_unswitch_pass.cpp | 620 + source/opt/loop_unswitch_pass.h | 43 + source/opt/loop_utils.cpp | 694 + source/opt/loop_utils.h | 182 + source/opt/mem_pass.cpp | 495 + 
source/opt/mem_pass.h | 163 + source/opt/merge_return_pass.cpp | 816 ++ source/opt/merge_return_pass.h | 339 + source/opt/module.cpp | 186 + source/opt/module.h | 484 + source/opt/null_pass.h | 34 + source/opt/optimizer.cpp | 802 + source/opt/pass.cpp | 56 + source/opt/pass.h | 147 + source/opt/pass_manager.cpp | 71 + source/opt/pass_manager.h | 131 + source/opt/passes.h | 72 + source/opt/pch_source_opt.cpp | 15 + source/opt/pch_source_opt.h | 32 + source/opt/private_to_local_pass.cpp | 186 + source/opt/private_to_local_pass.h | 73 + source/opt/process_lines_pass.cpp | 157 + source/opt/process_lines_pass.h | 87 + source/opt/propagator.cpp | 291 + source/opt/propagator.h | 317 + source/opt/reduce_load_size.cpp | 182 + source/opt/reduce_load_size.h | 65 + source/opt/redundancy_elimination.cpp | 56 + source/opt/redundancy_elimination.h | 56 + source/opt/reflect.h | 66 + source/opt/register_pressure.cpp | 576 + source/opt/register_pressure.h | 196 + source/opt/remove_duplicates_pass.cpp | 196 + source/opt/remove_duplicates_pass.h | 67 + source/opt/replace_invalid_opc.cpp | 207 + source/opt/replace_invalid_opc.h | 67 + source/opt/scalar_analysis.cpp | 988 ++ source/opt/scalar_analysis.h | 314 + source/opt/scalar_analysis_nodes.h | 345 + source/opt/scalar_analysis_simplification.cpp | 539 + source/opt/scalar_replacement_pass.cpp | 809 ++ source/opt/scalar_replacement_pass.h | 238 + .../opt/set_spec_constant_default_value_pass.cpp | 368 + source/opt/set_spec_constant_default_value_pass.h | 114 + source/opt/simplification_pass.cpp | 120 + source/opt/simplification_pass.h | 50 + source/opt/ssa_rewrite_pass.cpp | 595 + source/opt/ssa_rewrite_pass.h | 304 + source/opt/strength_reduction_pass.cpp | 200 + source/opt/strength_reduction_pass.h | 65 + source/opt/strip_debug_info_pass.cpp | 36 + source/opt/strip_debug_info_pass.h | 35 + source/opt/strip_reflect_info_pass.cpp | 82 + source/opt/strip_reflect_info_pass.h | 44 + source/opt/struct_cfg_analysis.cpp | 128 + 
source/opt/struct_cfg_analysis.h | 101 + source/opt/tree_iterator.h | 246 + source/opt/type_manager.cpp | 935 ++ source/opt/type_manager.h | 218 + source/opt/types.cpp | 637 + source/opt/types.h | 608 + source/opt/unify_const_pass.cpp | 177 + source/opt/unify_const_pass.h | 35 + source/opt/upgrade_memory_model.cpp | 640 + source/opt/upgrade_memory_model.h | 134 + source/opt/value_number_table.cpp | 227 + source/opt/value_number_table.h | 91 + source/opt/vector_dce.cpp | 397 + source/opt/vector_dce.h | 151 + source/opt/workaround1209.cpp | 69 + source/opt/workaround1209.h | 41 + source/parsed_operand.cpp | 74 + source/parsed_operand.h | 33 + source/pch_source.cpp | 15 + source/pch_source.h | 15 + source/print.cpp | 125 + source/print.h | 75 + source/reduce/CMakeLists.txt | 74 + .../change_operand_reduction_opportunity.cpp | 32 + .../reduce/change_operand_reduction_opportunity.h | 56 + ...ange_operand_to_undef_reduction_opportunity.cpp | 41 + ...change_operand_to_undef_reduction_opportunity.h | 53 + source/reduce/operand_to_const_reduction_pass.cpp | 83 + source/reduce/operand_to_const_reduction_pass.h | 48 + .../operand_to_dominating_id_reduction_pass.cpp | 114 + .../operand_to_dominating_id_reduction_pass.h | 59 + source/reduce/operand_to_undef_reduction_pass.cpp | 94 + source/reduce/operand_to_undef_reduction_pass.h | 45 + source/reduce/pch_source_reduce.cpp | 15 + source/reduce/pch_source_reduce.h | 23 + source/reduce/reducer.cpp | 154 + source/reduce/reducer.h | 97 + source/reduce/reduction_opportunity.cpp | 27 + source/reduce/reduction_opportunity.h | 47 + source/reduce/reduction_pass.cpp | 86 + source/reduce/reduction_pass.h | 73 + source/reduce/reduction_util.cpp | 44 + source/reduce/reduction_util.h | 33 + .../remove_instruction_reduction_opportunity.cpp | 29 + .../remove_instruction_reduction_opportunity.h | 46 + .../remove_opname_instruction_reduction_pass.cpp | 44 + .../remove_opname_instruction_reduction_pass.h | 48 + 
...ove_unreferenced_instruction_reduction_pass.cpp | 60 + ...emove_unreferenced_instruction_reduction_pass.h | 50 + ...red_loop_to_selection_reduction_opportunity.cpp | 360 + ...tured_loop_to_selection_reduction_opportunity.h | 116 + ...structured_loop_to_selection_reduction_pass.cpp | 95 + .../structured_loop_to_selection_reduction_pass.h | 61 + source/software_version.cpp | 27 + source/spirv_constant.h | 100 + source/spirv_definition.h | 33 + source/spirv_endian.cpp | 77 + source/spirv_endian.h | 37 + source/spirv_optimizer_options.cpp | 41 + source/spirv_optimizer_options.h | 40 + source/spirv_reducer_options.cpp | 31 + source/spirv_reducer_options.h | 35 + source/spirv_target_env.cpp | 250 + source/spirv_target_env.h | 36 + source/spirv_validator_options.cpp | 106 + source/spirv_validator_options.h | 57 + source/table.cpp | 63 + source/table.h | 132 + source/text.cpp | 811 ++ source/text.h | 53 + source/text_handler.cpp | 397 + source/text_handler.h | 264 + source/util/bit_vector.cpp | 82 + source/util/bit_vector.h | 119 + source/util/bitutils.h | 96 + source/util/hex_float.h | 1150 ++ source/util/ilist.h | 365 + source/util/ilist_node.h | 265 + source/util/make_unique.h | 30 + source/util/parse_number.cpp | 217 + source/util/parse_number.h | 252 + source/util/small_vector.h | 466 + source/util/string_utils.cpp | 58 + source/util/string_utils.h | 48 + source/util/timer.cpp | 102 + source/util/timer.h | 392 + source/val/basic_block.cpp | 149 + source/val/basic_block.h | 247 + source/val/construct.cpp | 131 + source/val/construct.h | 151 + source/val/decoration.h | 89 + source/val/function.cpp | 387 + source/val/function.h | 360 + source/val/instruction.cpp | 45 + source/val/instruction.h | 140 + source/val/validate.cpp | 531 + source/val/validate.h | 234 + source/val/validate_adjacency.cpp | 118 + source/val/validate_annotation.cpp | 263 + source/val/validate_arithmetics.cpp | 453 + source/val/validate_atomics.cpp | 229 + source/val/validate_barriers.cpp | 138 + 
source/val/validate_bitwise.cpp | 219 + source/val/validate_builtins.cpp | 2675 ++++ source/val/validate_capability.cpp | 335 + source/val/validate_cfg.cpp | 781 + source/val/validate_composites.cpp | 521 + source/val/validate_constants.cpp | 407 + source/val/validate_conversion.cpp | 462 + source/val/validate_datarules.cpp | 267 + source/val/validate_debug.cpp | 81 + source/val/validate_decorations.cpp | 1276 ++ source/val/validate_derivatives.cpp | 98 + source/val/validate_execution_limitations.cpp | 61 + source/val/validate_extensions.cpp | 2029 +++ source/val/validate_function.cpp | 275 + source/val/validate_id.cpp | 222 + source/val/validate_image.cpp | 1789 +++ source/val/validate_instruction.cpp | 535 + source/val/validate_interfaces.cpp | 114 + source/val/validate_layout.cpp | 202 + source/val/validate_literals.cpp | 99 + source/val/validate_logicals.cpp | 265 + source/val/validate_memory.cpp | 1186 ++ source/val/validate_memory_semantics.cpp | 241 + source/val/validate_memory_semantics.h | 28 + source/val/validate_mode_setting.cpp | 435 + source/val/validate_non_uniform.cpp | 77 + source/val/validate_primitives.cpp | 75 + source/val/validate_scopes.cpp | 204 + source/val/validate_scopes.h | 30 + source/val/validate_type.cpp | 393 + source/val/validation_state.cpp | 1025 ++ source/val/validation_state.h | 716 + syntax.md | 238 + test/CMakeLists.txt | 214 + test/assembly_context_test.cpp | 77 + test/assembly_format_test.cpp | 37 + test/binary_destroy_test.cpp | 44 + test/binary_endianness_test.cpp | 54 + test/binary_header_get_test.cpp | 84 + test/binary_parse_test.cpp | 892 ++ test/binary_strnlen_s_test.cpp | 32 + test/binary_to_text.literal_test.cpp | 76 + test/binary_to_text_test.cpp | 561 + test/bit_stream.cpp | 1025 ++ test/c_interface_test.cpp | 299 + test/comment_test.cpp | 50 + test/comp/CMakeLists.txt | 29 + test/comp/markv_codec_test.cpp | 829 ++ test/cpp_interface_test.cpp | 328 + test/diagnostic_test.cpp | 150 + test/enum_set_test.cpp | 290 + 
test/enum_string_mapping_test.cpp | 195 + test/ext_inst.debuginfo_test.cpp | 812 ++ test/ext_inst.glsl_test.cpp | 204 + test/ext_inst.opencl_test.cpp | 373 + test/fix_word_test.cpp | 64 + test/fuzzers/BUILD.gn | 137 + test/fuzzers/corpora/spv/simple.spv | Bin 0 -> 728 bytes test/fuzzers/spvtools_binary_parser_fuzzer.cpp | 44 + test/fuzzers/spvtools_opt_legalization_fuzzer.cpp | 38 + test/fuzzers/spvtools_opt_performance_fuzzer.cpp | 38 + test/fuzzers/spvtools_opt_size_fuzzer.cpp | 38 + test/fuzzers/spvtools_val_fuzzer.cpp | 36 + test/generator_magic_number_test.cpp | 62 + test/hex_float_test.cpp | 1331 ++ test/huffman_codec.cpp | 317 + test/immediate_int_test.cpp | 291 + test/libspirv_macros_test.cpp | 25 + test/link/CMakeLists.txt | 27 + test/link/binary_version_test.cpp | 60 + test/link/entry_points_test.cpp | 94 + test/link/global_values_amount_test.cpp | 153 + test/link/ids_limit_test.cpp | 72 + test/link/linker_fixture.h | 125 + test/link/matching_imports_to_exports_test.cpp | 403 + test/link/memory_model_test.cpp | 74 + test/link/partial_linkage_test.cpp | 89 + test/link/unique_ids_test.cpp | 142 + test/log_test.cpp | 53 + test/move_to_front_test.cpp | 828 ++ test/name_mapper_test.cpp | 347 + test/named_id_test.cpp | 87 + test/opcode_make_test.cpp | 44 + test/opcode_require_capabilities_test.cpp | 78 + test/opcode_split_test.cpp | 30 + test/opcode_table_get_test.cpp | 39 + test/operand-class-test-coverage.csv | 43 + test/operand_capabilities_test.cpp | 736 + test/operand_pattern_test.cpp | 270 + test/operand_test.cpp | 75 + test/opt/CMakeLists.txt | 92 + test/opt/aggressive_dead_code_elim_test.cpp | 6225 ++++++++ test/opt/assembly_builder.h | 266 + test/opt/assembly_builder_test.cpp | 283 + test/opt/block_merge_test.cpp | 751 + test/opt/ccp_test.cpp | 901 ++ test/opt/cfg_cleanup_test.cpp | 456 + test/opt/code_sink_test.cpp | 533 + test/opt/combine_access_chains_test.cpp | 773 + test/opt/common_uniform_elim_test.cpp | 1394 ++ test/opt/compact_ids_test.cpp | 
279 + test/opt/constant_manager_test.cpp | 88 + test/opt/copy_prop_array_test.cpp | 1576 ++ test/opt/dead_branch_elim_test.cpp | 2878 ++++ test/opt/dead_insert_elim_test.cpp | 571 + test/opt/dead_variable_elim_test.cpp | 298 + test/opt/decoration_manager_test.cpp | 1283 ++ test/opt/def_use_test.cpp | 1719 +++ test/opt/dominator_tree/CMakeLists.txt | 31 + test/opt/dominator_tree/common_dominators.cpp | 151 + test/opt/dominator_tree/generated.cpp | 900 ++ test/opt/dominator_tree/nested_ifs.cpp | 153 + test/opt/dominator_tree/nested_ifs_post.cpp | 156 + test/opt/dominator_tree/nested_loops.cpp | 433 + .../nested_loops_with_unreachables.cpp | 848 ++ test/opt/dominator_tree/pch_test_opt_dom.cpp | 15 + test/opt/dominator_tree/pch_test_opt_dom.h | 25 + test/opt/dominator_tree/post.cpp | 207 + test/opt/dominator_tree/simple.cpp | 177 + .../opt/dominator_tree/switch_case_fallthrough.cpp | 163 + test/opt/dominator_tree/unreachable_for.cpp | 121 + test/opt/dominator_tree/unreachable_for_post.cpp | 118 + test/opt/eliminate_dead_const_test.cpp | 847 ++ test/opt/eliminate_dead_functions_test.cpp | 209 + test/opt/feature_manager_test.cpp | 142 + test/opt/flatten_decoration_test.cpp | 239 + test/opt/fold_spec_const_op_composite_test.cpp | 1394 ++ test/opt/fold_test.cpp | 6198 ++++++++ test/opt/freeze_spec_const_test.cpp | 133 + test/opt/function_test.cpp | 173 + test/opt/function_utils.h | 55 + test/opt/if_conversion_test.cpp | 511 + test/opt/inline_opaque_test.cpp | 412 + test/opt/inline_test.cpp | 3133 ++++ test/opt/insert_extract_elim_test.cpp | 900 ++ test/opt/inst_bindless_check_test.cpp | 1857 +++ test/opt/instruction_list_test.cpp | 115 + test/opt/instruction_test.cpp | 1135 ++ test/opt/ir_builder.cpp | 439 + test/opt/ir_context_test.cpp | 669 + test/opt/ir_loader_test.cpp | 451 + test/opt/iterator_test.cpp | 267 + test/opt/line_debug_info_test.cpp | 113 + test/opt/local_access_chain_convert_test.cpp | 714 + test/opt/local_redundancy_elimination_test.cpp | 159 + 
test/opt/local_single_block_elim.cpp | 1074 ++ test/opt/local_single_store_elim_test.cpp | 857 ++ test/opt/local_ssa_elim_test.cpp | 1834 +++ test/opt/loop_optimizations/CMakeLists.txt | 41 + .../opt/loop_optimizations/dependence_analysis.cpp | 4205 ++++++ .../dependence_analysis_helpers.cpp | 3017 ++++ .../loop_optimizations/fusion_compatibility.cpp | 1785 +++ test/opt/loop_optimizations/fusion_illegal.cpp | 1592 ++ test/opt/loop_optimizations/fusion_legal.cpp | 4578 ++++++ test/opt/loop_optimizations/fusion_pass.cpp | 717 + .../loop_optimizations/hoist_all_loop_types.cpp | 285 + .../hoist_double_nested_loops.cpp | 162 + .../hoist_from_independent_loops.cpp | 201 + test/opt/loop_optimizations/hoist_simple_case.cpp | 126 + .../hoist_single_nested_loops.cpp | 209 + .../loop_optimizations/hoist_without_preheader.cpp | 197 + test/opt/loop_optimizations/lcssa.cpp | 607 + test/opt/loop_optimizations/loop_descriptions.cpp | 384 + test/opt/loop_optimizations/loop_fission.cpp | 3491 +++++ test/opt/loop_optimizations/nested_loops.cpp | 795 + test/opt/loop_optimizations/pch_test_opt_loop.cpp | 15 + test/opt/loop_optimizations/pch_test_opt_loop.h | 25 + test/opt/loop_optimizations/peeling.cpp | 1186 ++ test/opt/loop_optimizations/peeling_pass.cpp | 1099 ++ test/opt/loop_optimizations/unroll_assumptions.cpp | 1448 ++ test/opt/loop_optimizations/unroll_simple.cpp | 3003 ++++ test/opt/loop_optimizations/unswitch.cpp | 967 ++ test/opt/module_test.cpp | 233 + test/opt/module_utils.h | 34 + test/opt/optimizer_test.cpp | 227 + test/opt/pass_fixture.h | 249 + test/opt/pass_manager_test.cpp | 191 + test/opt/pass_merge_return_test.cpp | 1373 ++ test/opt/pass_remove_duplicates_test.cpp | 646 + test/opt/pass_utils.cpp | 102 + test/opt/pass_utils.h | 84 + test/opt/pch_test_opt.cpp | 15 + test/opt/pch_test_opt.h | 25 + test/opt/private_to_local_test.cpp | 313 + test/opt/process_lines_test.cpp | 695 + test/opt/propagator_test.cpp | 219 + test/opt/reduce_load_size_test.cpp | 354 + 
test/opt/redundancy_elimination_test.cpp | 277 + test/opt/register_liveness.cpp | 1282 ++ test/opt/replace_invalid_opc_test.cpp | 566 + test/opt/scalar_analysis.cpp | 1221 ++ test/opt/scalar_replacement_test.cpp | 1625 +++ test/opt/set_spec_const_default_value_test.cpp | 1077 ++ test/opt/simplification_test.cpp | 207 + test/opt/strength_reduction_test.cpp | 438 + test/opt/strip_debug_info_test.cpp | 107 + test/opt/strip_reflect_info_test.cpp | 90 + test/opt/struct_cfg_analysis_test.cpp | 466 + test/opt/type_manager_test.cpp | 1148 ++ test/opt/types_test.cpp | 345 + test/opt/unify_const_test.cpp | 990 ++ test/opt/upgrade_memory_model_test.cpp | 1716 +++ test/opt/utils_test.cpp | 110 + test/opt/value_table_test.cpp | 591 + test/opt/vector_dce_test.cpp | 1231 ++ test/opt/workaround1209_test.cpp | 423 + test/parse_number_test.cpp | 970 ++ test/pch_test.cpp | 15 + test/pch_test.h | 18 + test/preserve_numeric_ids_test.cpp | 159 + test/reduce/CMakeLists.txt | 28 + .../operand_to_constant_reduction_pass_test.cpp | 156 + ...perand_to_dominating_id_reduction_pass_test.cpp | 196 + .../operand_to_undef_reduction_pass_test.cpp | 226 + test/reduce/reduce_test_util.cpp | 72 + test/reduce/reduce_test_util.h | 82 + test/reduce/reducer_test.cpp | 310 + ...move_opname_instruction_reduction_pass_test.cpp | 216 + ...nreferenced_instruction_reduction_pass_test.cpp | 230 + ...tured_loop_to_selection_reduction_pass_test.cpp | 3440 +++++ test/reduce/validation_during_reduction_test.cpp | 376 + test/scripts/test_compact_ids.py | 102 + test/software_version_test.cpp | 67 + test/stats/CMakeLists.txt | 27 + test/stats/stats_aggregate_test.cpp | 438 + test/stats/stats_analyzer_test.cpp | 174 + test/string_utils_test.cpp | 191 + test/target_env_test.cpp | 106 + test/test_fixture.h | 198 + test/text_advance_test.cpp | 134 + test/text_destroy_test.cpp | 75 + test/text_literal_test.cpp | 412 + test/text_start_new_inst_test.cpp | 75 + test/text_to_binary.annotation_test.cpp | 510 + 
test/text_to_binary.barrier_test.cpp | 170 + test/text_to_binary.constant_test.cpp | 830 ++ test/text_to_binary.control_flow_test.cpp | 394 + test/text_to_binary.debug_test.cpp | 214 + test/text_to_binary.device_side_enqueue_test.cpp | 112 + test/text_to_binary.extension_test.cpp | 857 ++ test/text_to_binary.function_test.cpp | 81 + test/text_to_binary.group_test.cpp | 76 + test/text_to_binary.image_test.cpp | 276 + test/text_to_binary.literal_test.cpp | 125 + test/text_to_binary.memory_test.cpp | 111 + test/text_to_binary.misc_test.cpp | 58 + test/text_to_binary.mode_setting_test.cpp | 302 + test/text_to_binary.pipe_storage_test.cpp | 126 + test/text_to_binary.reserved_sampling_test.cpp | 63 + test/text_to_binary.subgroup_dispatch_test.cpp | 122 + test/text_to_binary.type_declaration_test.cpp | 293 + test/text_to_binary_test.cpp | 269 + test/text_word_get_test.cpp | 254 + test/timer_test.cpp | 142 + test/tools/CMakeLists.txt | 18 + test/tools/expect.py | 677 + test/tools/expect_nosetest.py | 80 + test/tools/opt/CMakeLists.txt | 25 + test/tools/opt/flags.py | 333 + test/tools/opt/oconfig.py | 73 + test/tools/placeholder.py | 213 + test/tools/spirv_test_framework.py | 375 + test/tools/spirv_test_framework_nosetest.py | 155 + test/unit_spirv.cpp | 55 + test/unit_spirv.h | 234 + test/util/CMakeLists.txt | 20 + test/util/bit_vector_test.cpp | 164 + test/util/ilist_test.cpp | 325 + test/util/small_vector_test.cpp | 598 + test/val/CMakeLists.txt | 80 + test/val/pch_test_val.cpp | 15 + test/val/pch_test_val.h | 19 + test/val/val_adjacency_test.cpp | 470 + test/val/val_arithmetics_test.cpp | 1278 ++ test/val/val_atomics_test.cpp | 1933 +++ test/val/val_barriers_test.cpp | 1284 ++ test/val/val_bitwise_test.cpp | 549 + test/val/val_builtins_test.cpp | 2283 +++ test/val/val_capability_test.cpp | 2407 +++ test/val/val_cfg_test.cpp | 2067 +++ test/val/val_composites_test.cpp | 1496 ++ test/val/val_constants_test.cpp | 299 + test/val/val_conversion_test.cpp | 1419 ++ 
test/val/val_data_test.cpp | 939 ++ test/val/val_decoration_test.cpp | 5307 +++++++ test/val/val_derivatives_test.cpp | 156 + test/val/val_explicit_reserved_test.cpp | 122 + test/val/val_ext_inst_test.cpp | 5819 ++++++++ test/val/val_extensions_test.cpp | 345 + test/val/val_fixtures.h | 159 + test/val/val_id_test.cpp | 6403 ++++++++ test/val/val_image_test.cpp | 4457 ++++++ test/val/val_interfaces_test.cpp | 169 + test/val/val_layout_test.cpp | 706 + test/val/val_limits_test.cpp | 774 + test/val/val_literals_test.cpp | 142 + test/val/val_logicals_test.cpp | 954 ++ test/val/val_memory_test.cpp | 2573 ++++ test/val/val_modes_test.cpp | 801 + test/val/val_non_uniform_test.cpp | 293 + test/val/val_primitives_test.cpp | 321 + test/val/val_ssa_test.cpp | 1449 ++ test/val/val_state_test.cpp | 139 + test/val/val_storage_test.cpp | 191 + test/val/val_type_unique_test.cpp | 269 + test/val/val_validation_state_test.cpp | 359 + test/val/val_version_test.cpp | 294 + test/val/val_webgpu_test.cpp | 281 + tools/CMakeLists.txt | 84 + tools/as/as.cpp | 154 + tools/cfg/bin_to_dot.cpp | 187 + tools/cfg/bin_to_dot.h | 28 + tools/cfg/cfg.cpp | 127 + tools/comp/markv.cpp | 385 + tools/comp/markv_model_factory.cpp | 50 + tools/comp/markv_model_factory.h | 37 + tools/comp/markv_model_shader.cpp | 84 + tools/comp/markv_model_shader.h | 47 + tools/comp/markv_model_shader_default_autogen.inc | 14519 +++++++++++++++++++ tools/dis/dis.cpp | 209 + tools/emacs/50spirv-tools.el | 40 + tools/emacs/CMakeLists.txt | 48 + tools/io.h | 82 + tools/lesspipe/CMakeLists.txt | 28 + tools/lesspipe/spirv-lesspipe.sh | 27 + tools/link/linker.cpp | 159 + tools/opt/opt.cpp | 699 + tools/reduce/reduce.cpp | 242 + tools/stats/spirv_stats.cpp | 165 + tools/stats/spirv_stats.h | 93 + tools/stats/stats.cpp | 173 + tools/stats/stats_analyzer.cpp | 235 + tools/stats/stats_analyzer.h | 58 + tools/util/cli_consumer.cpp | 45 + tools/util/cli_consumer.h | 31 + tools/val/val.cpp | 182 + utils/check_code_format.sh | 37 + 
utils/check_copyright.py | 222 + utils/check_symbol_exports.py | 93 + utils/fixup_fuzz_result.py | 25 + utils/generate_grammar_tables.py | 749 + utils/generate_language_headers.py | 188 + utils/generate_registry_tables.py | 72 + utils/generate_vim_syntax.py | 207 + utils/update_build_version.py | 150 + 751 files changed, 297690 insertions(+) create mode 100644 .appveyor.yml create mode 100644 .clang-format create mode 100644 .gitignore create mode 100644 .gn create mode 100644 Android.mk create mode 100644 BUILD.gn create mode 100644 CHANGES create mode 100644 CMakeLists.txt create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 DEPS create mode 100644 LICENSE create mode 100644 PRESUBMIT.py create mode 100644 README.md create mode 100644 android_test/Android.mk create mode 100644 android_test/jni/Application.mk create mode 100644 android_test/test.cpp create mode 100644 build_overrides/build.gni create mode 100644 build_overrides/gtest.gni create mode 100644 build_overrides/spirv_tools.gni create mode 100644 cmake/SPIRV-Tools-shared.pc.in create mode 100644 cmake/SPIRV-Tools.pc.in create mode 100644 cmake/setup_build.cmake create mode 100644 cmake/write_pkg_config.cmake create mode 100644 codereview.settings create mode 100644 examples/CMakeLists.txt create mode 100644 examples/cpp-interface/CMakeLists.txt create mode 100644 examples/cpp-interface/main.cpp create mode 100644 external/CMakeLists.txt create mode 100644 include/spirv-tools/instrument.hpp create mode 100644 include/spirv-tools/libspirv.h create mode 100644 include/spirv-tools/libspirv.hpp create mode 100644 include/spirv-tools/linker.hpp create mode 100644 include/spirv-tools/optimizer.hpp create mode 100644 kokoro/android/build.sh create mode 100644 kokoro/android/continuous.cfg create mode 100644 kokoro/android/presubmit.cfg create mode 100644 kokoro/check-format/build.sh create mode 100644 kokoro/check-format/presubmit_check_format.cfg create mode 100644 
kokoro/img/linux.png create mode 100644 kokoro/img/macos.png create mode 100644 kokoro/img/windows.png create mode 100644 kokoro/linux-clang-debug/build.sh create mode 100644 kokoro/linux-clang-debug/continuous.cfg create mode 100644 kokoro/linux-clang-debug/presubmit.cfg create mode 100644 kokoro/linux-clang-release/build.sh create mode 100644 kokoro/linux-clang-release/continuous.cfg create mode 100644 kokoro/linux-clang-release/presubmit.cfg create mode 100644 kokoro/linux-gcc-debug/build.sh create mode 100644 kokoro/linux-gcc-debug/continuous.cfg create mode 100644 kokoro/linux-gcc-debug/presubmit.cfg create mode 100644 kokoro/linux-gcc-release/build.sh create mode 100644 kokoro/linux-gcc-release/continuous.cfg create mode 100644 kokoro/linux-gcc-release/presubmit.cfg create mode 100644 kokoro/macos-clang-debug/build.sh create mode 100644 kokoro/macos-clang-debug/continuous.cfg create mode 100644 kokoro/macos-clang-debug/presubmit.cfg create mode 100644 kokoro/macos-clang-release/build.sh create mode 100644 kokoro/macos-clang-release/continuous.cfg create mode 100644 kokoro/macos-clang-release/presubmit.cfg create mode 100644 kokoro/ndk-build/build.sh create mode 100644 kokoro/ndk-build/continuous.cfg create mode 100644 kokoro/ndk-build/presubmit.cfg create mode 100644 kokoro/scripts/linux/build.sh create mode 100644 kokoro/scripts/macos/build.sh create mode 100644 kokoro/scripts/windows/build.bat create mode 100644 kokoro/shaderc-smoketest/build.sh create mode 100644 kokoro/shaderc-smoketest/continuous.cfg create mode 100644 kokoro/shaderc-smoketest/presubmit.cfg create mode 100644 kokoro/windows-msvc-2013-release/build.bat create mode 100644 kokoro/windows-msvc-2013-release/continuous.cfg create mode 100644 kokoro/windows-msvc-2013-release/presubmit.cfg create mode 100644 kokoro/windows-msvc-2015-release/build.bat create mode 100644 kokoro/windows-msvc-2015-release/continuous.cfg create mode 100644 kokoro/windows-msvc-2015-release/presubmit.cfg create mode 
100644 kokoro/windows-msvc-2017-debug/build.bat create mode 100644 kokoro/windows-msvc-2017-debug/continuous.cfg create mode 100644 kokoro/windows-msvc-2017-debug/presubmit.cfg create mode 100644 kokoro/windows-msvc-2017-release/build.bat create mode 100644 kokoro/windows-msvc-2017-release/continuous.cfg create mode 100644 kokoro/windows-msvc-2017-release/presubmit.cfg create mode 100644 projects.md create mode 100644 source/CMakeLists.txt create mode 100644 source/assembly_grammar.cpp create mode 100644 source/assembly_grammar.h create mode 100644 source/binary.cpp create mode 100644 source/binary.h create mode 100644 source/cfa.h create mode 100644 source/comp/CMakeLists.txt create mode 100644 source/comp/bit_stream.cpp create mode 100644 source/comp/bit_stream.h create mode 100644 source/comp/huffman_codec.h create mode 100644 source/comp/markv.cpp create mode 100644 source/comp/markv.h create mode 100644 source/comp/markv_codec.cpp create mode 100644 source/comp/markv_codec.h create mode 100644 source/comp/markv_decoder.cpp create mode 100644 source/comp/markv_decoder.h create mode 100644 source/comp/markv_encoder.cpp create mode 100644 source/comp/markv_encoder.h create mode 100644 source/comp/markv_logger.h create mode 100644 source/comp/markv_model.h create mode 100644 source/comp/move_to_front.cpp create mode 100644 source/comp/move_to_front.h create mode 100644 source/diagnostic.cpp create mode 100644 source/diagnostic.h create mode 100644 source/disassemble.cpp create mode 100644 source/disassemble.h create mode 100644 source/enum_set.h create mode 100644 source/enum_string_mapping.cpp create mode 100644 source/enum_string_mapping.h create mode 100644 source/ext_inst.cpp create mode 100644 source/ext_inst.h create mode 100644 source/extensions.cpp create mode 100644 source/extensions.h create mode 100644 source/extinst.debuginfo.grammar.json create mode 100644 source/extinst.spv-amd-gcn-shader.grammar.json create mode 100644 
source/extinst.spv-amd-shader-ballot.grammar.json create mode 100644 source/extinst.spv-amd-shader-explicit-vertex-parameter.grammar.json create mode 100644 source/extinst.spv-amd-shader-trinary-minmax.grammar.json create mode 100644 source/id_descriptor.cpp create mode 100644 source/id_descriptor.h create mode 100644 source/instruction.h create mode 100644 source/latest_version_glsl_std_450_header.h create mode 100644 source/latest_version_opencl_std_header.h create mode 100644 source/latest_version_spirv_header.h create mode 100644 source/libspirv.cpp create mode 100644 source/link/CMakeLists.txt create mode 100644 source/link/linker.cpp create mode 100644 source/macro.h create mode 100644 source/name_mapper.cpp create mode 100644 source/name_mapper.h create mode 100644 source/opcode.cpp create mode 100644 source/opcode.h create mode 100644 source/operand.cpp create mode 100644 source/operand.h create mode 100644 source/opt/CMakeLists.txt create mode 100644 source/opt/aggressive_dead_code_elim_pass.cpp create mode 100644 source/opt/aggressive_dead_code_elim_pass.h create mode 100644 source/opt/basic_block.cpp create mode 100644 source/opt/basic_block.h create mode 100644 source/opt/block_merge_pass.cpp create mode 100644 source/opt/block_merge_pass.h create mode 100644 source/opt/build_module.cpp create mode 100644 source/opt/build_module.h create mode 100644 source/opt/ccp_pass.cpp create mode 100644 source/opt/ccp_pass.h create mode 100644 source/opt/cfg.cpp create mode 100644 source/opt/cfg.h create mode 100644 source/opt/cfg_cleanup_pass.cpp create mode 100644 source/opt/cfg_cleanup_pass.h create mode 100644 source/opt/code_sink.cpp create mode 100644 source/opt/code_sink.h create mode 100644 source/opt/combine_access_chains.cpp create mode 100644 source/opt/combine_access_chains.h create mode 100644 source/opt/common_uniform_elim_pass.cpp create mode 100644 source/opt/common_uniform_elim_pass.h create mode 100644 source/opt/compact_ids_pass.cpp create mode 
100644 source/opt/compact_ids_pass.h create mode 100644 source/opt/composite.cpp create mode 100644 source/opt/composite.h create mode 100644 source/opt/const_folding_rules.cpp create mode 100644 source/opt/const_folding_rules.h create mode 100644 source/opt/constants.cpp create mode 100644 source/opt/constants.h create mode 100644 source/opt/copy_prop_arrays.cpp create mode 100644 source/opt/copy_prop_arrays.h create mode 100644 source/opt/dead_branch_elim_pass.cpp create mode 100644 source/opt/dead_branch_elim_pass.h create mode 100644 source/opt/dead_insert_elim_pass.cpp create mode 100644 source/opt/dead_insert_elim_pass.h create mode 100644 source/opt/dead_variable_elimination.cpp create mode 100644 source/opt/dead_variable_elimination.h create mode 100644 source/opt/decoration_manager.cpp create mode 100644 source/opt/decoration_manager.h create mode 100644 source/opt/def_use_manager.cpp create mode 100644 source/opt/def_use_manager.h create mode 100644 source/opt/dominator_analysis.cpp create mode 100644 source/opt/dominator_analysis.h create mode 100644 source/opt/dominator_tree.cpp create mode 100644 source/opt/dominator_tree.h create mode 100644 source/opt/eliminate_dead_constant_pass.cpp create mode 100644 source/opt/eliminate_dead_constant_pass.h create mode 100644 source/opt/eliminate_dead_functions_pass.cpp create mode 100644 source/opt/eliminate_dead_functions_pass.h create mode 100644 source/opt/feature_manager.cpp create mode 100644 source/opt/feature_manager.h create mode 100644 source/opt/flatten_decoration_pass.cpp create mode 100644 source/opt/flatten_decoration_pass.h create mode 100644 source/opt/fold.cpp create mode 100644 source/opt/fold.h create mode 100644 source/opt/fold_spec_constant_op_and_composite_pass.cpp create mode 100644 source/opt/fold_spec_constant_op_and_composite_pass.h create mode 100644 source/opt/folding_rules.cpp create mode 100644 source/opt/folding_rules.h create mode 100644 
source/opt/freeze_spec_constant_value_pass.cpp create mode 100644 source/opt/freeze_spec_constant_value_pass.h create mode 100644 source/opt/function.cpp create mode 100644 source/opt/function.h create mode 100644 source/opt/if_conversion.cpp create mode 100644 source/opt/if_conversion.h create mode 100644 source/opt/inline_exhaustive_pass.cpp create mode 100644 source/opt/inline_exhaustive_pass.h create mode 100644 source/opt/inline_opaque_pass.cpp create mode 100644 source/opt/inline_opaque_pass.h create mode 100644 source/opt/inline_pass.cpp create mode 100644 source/opt/inline_pass.h create mode 100644 source/opt/inst_bindless_check_pass.cpp create mode 100644 source/opt/inst_bindless_check_pass.h create mode 100644 source/opt/instruction.cpp create mode 100644 source/opt/instruction.h create mode 100644 source/opt/instruction_list.cpp create mode 100644 source/opt/instruction_list.h create mode 100644 source/opt/instrument_pass.cpp create mode 100644 source/opt/instrument_pass.h create mode 100644 source/opt/ir_builder.h create mode 100644 source/opt/ir_context.cpp create mode 100644 source/opt/ir_context.h create mode 100644 source/opt/ir_loader.cpp create mode 100644 source/opt/ir_loader.h create mode 100644 source/opt/iterator.h create mode 100644 source/opt/licm_pass.cpp create mode 100644 source/opt/licm_pass.h create mode 100644 source/opt/local_access_chain_convert_pass.cpp create mode 100644 source/opt/local_access_chain_convert_pass.h create mode 100644 source/opt/local_redundancy_elimination.cpp create mode 100644 source/opt/local_redundancy_elimination.h create mode 100644 source/opt/local_single_block_elim_pass.cpp create mode 100644 source/opt/local_single_block_elim_pass.h create mode 100644 source/opt/local_single_store_elim_pass.cpp create mode 100644 source/opt/local_single_store_elim_pass.h create mode 100644 source/opt/local_ssa_elim_pass.cpp create mode 100644 source/opt/local_ssa_elim_pass.h create mode 100644 source/opt/log.h create mode 
100644 source/opt/loop_dependence.cpp create mode 100644 source/opt/loop_dependence.h create mode 100644 source/opt/loop_dependence_helpers.cpp create mode 100644 source/opt/loop_descriptor.cpp create mode 100644 source/opt/loop_descriptor.h create mode 100644 source/opt/loop_fission.cpp create mode 100644 source/opt/loop_fission.h create mode 100644 source/opt/loop_fusion.cpp create mode 100644 source/opt/loop_fusion.h create mode 100644 source/opt/loop_fusion_pass.cpp create mode 100644 source/opt/loop_fusion_pass.h create mode 100644 source/opt/loop_peeling.cpp create mode 100644 source/opt/loop_peeling.h create mode 100644 source/opt/loop_unroller.cpp create mode 100644 source/opt/loop_unroller.h create mode 100644 source/opt/loop_unswitch_pass.cpp create mode 100644 source/opt/loop_unswitch_pass.h create mode 100644 source/opt/loop_utils.cpp create mode 100644 source/opt/loop_utils.h create mode 100644 source/opt/mem_pass.cpp create mode 100644 source/opt/mem_pass.h create mode 100644 source/opt/merge_return_pass.cpp create mode 100644 source/opt/merge_return_pass.h create mode 100644 source/opt/module.cpp create mode 100644 source/opt/module.h create mode 100644 source/opt/null_pass.h create mode 100644 source/opt/optimizer.cpp create mode 100644 source/opt/pass.cpp create mode 100644 source/opt/pass.h create mode 100644 source/opt/pass_manager.cpp create mode 100644 source/opt/pass_manager.h create mode 100644 source/opt/passes.h create mode 100644 source/opt/pch_source_opt.cpp create mode 100644 source/opt/pch_source_opt.h create mode 100644 source/opt/private_to_local_pass.cpp create mode 100644 source/opt/private_to_local_pass.h create mode 100644 source/opt/process_lines_pass.cpp create mode 100644 source/opt/process_lines_pass.h create mode 100644 source/opt/propagator.cpp create mode 100644 source/opt/propagator.h create mode 100644 source/opt/reduce_load_size.cpp create mode 100644 source/opt/reduce_load_size.h create mode 100644 
source/opt/redundancy_elimination.cpp create mode 100644 source/opt/redundancy_elimination.h create mode 100644 source/opt/reflect.h create mode 100644 source/opt/register_pressure.cpp create mode 100644 source/opt/register_pressure.h create mode 100644 source/opt/remove_duplicates_pass.cpp create mode 100644 source/opt/remove_duplicates_pass.h create mode 100644 source/opt/replace_invalid_opc.cpp create mode 100644 source/opt/replace_invalid_opc.h create mode 100644 source/opt/scalar_analysis.cpp create mode 100644 source/opt/scalar_analysis.h create mode 100644 source/opt/scalar_analysis_nodes.h create mode 100644 source/opt/scalar_analysis_simplification.cpp create mode 100644 source/opt/scalar_replacement_pass.cpp create mode 100644 source/opt/scalar_replacement_pass.h create mode 100644 source/opt/set_spec_constant_default_value_pass.cpp create mode 100644 source/opt/set_spec_constant_default_value_pass.h create mode 100644 source/opt/simplification_pass.cpp create mode 100644 source/opt/simplification_pass.h create mode 100644 source/opt/ssa_rewrite_pass.cpp create mode 100644 source/opt/ssa_rewrite_pass.h create mode 100644 source/opt/strength_reduction_pass.cpp create mode 100644 source/opt/strength_reduction_pass.h create mode 100644 source/opt/strip_debug_info_pass.cpp create mode 100644 source/opt/strip_debug_info_pass.h create mode 100644 source/opt/strip_reflect_info_pass.cpp create mode 100644 source/opt/strip_reflect_info_pass.h create mode 100644 source/opt/struct_cfg_analysis.cpp create mode 100644 source/opt/struct_cfg_analysis.h create mode 100644 source/opt/tree_iterator.h create mode 100644 source/opt/type_manager.cpp create mode 100644 source/opt/type_manager.h create mode 100644 source/opt/types.cpp create mode 100644 source/opt/types.h create mode 100644 source/opt/unify_const_pass.cpp create mode 100644 source/opt/unify_const_pass.h create mode 100644 source/opt/upgrade_memory_model.cpp create mode 100644 source/opt/upgrade_memory_model.h 
create mode 100644 source/opt/value_number_table.cpp create mode 100644 source/opt/value_number_table.h create mode 100644 source/opt/vector_dce.cpp create mode 100644 source/opt/vector_dce.h create mode 100644 source/opt/workaround1209.cpp create mode 100644 source/opt/workaround1209.h create mode 100644 source/parsed_operand.cpp create mode 100644 source/parsed_operand.h create mode 100644 source/pch_source.cpp create mode 100644 source/pch_source.h create mode 100644 source/print.cpp create mode 100644 source/print.h create mode 100644 source/reduce/CMakeLists.txt create mode 100644 source/reduce/change_operand_reduction_opportunity.cpp create mode 100644 source/reduce/change_operand_reduction_opportunity.h create mode 100644 source/reduce/change_operand_to_undef_reduction_opportunity.cpp create mode 100644 source/reduce/change_operand_to_undef_reduction_opportunity.h create mode 100644 source/reduce/operand_to_const_reduction_pass.cpp create mode 100644 source/reduce/operand_to_const_reduction_pass.h create mode 100644 source/reduce/operand_to_dominating_id_reduction_pass.cpp create mode 100644 source/reduce/operand_to_dominating_id_reduction_pass.h create mode 100644 source/reduce/operand_to_undef_reduction_pass.cpp create mode 100644 source/reduce/operand_to_undef_reduction_pass.h create mode 100644 source/reduce/pch_source_reduce.cpp create mode 100644 source/reduce/pch_source_reduce.h create mode 100644 source/reduce/reducer.cpp create mode 100644 source/reduce/reducer.h create mode 100644 source/reduce/reduction_opportunity.cpp create mode 100644 source/reduce/reduction_opportunity.h create mode 100644 source/reduce/reduction_pass.cpp create mode 100644 source/reduce/reduction_pass.h create mode 100644 source/reduce/reduction_util.cpp create mode 100644 source/reduce/reduction_util.h create mode 100644 source/reduce/remove_instruction_reduction_opportunity.cpp create mode 100644 source/reduce/remove_instruction_reduction_opportunity.h create mode 100644 
source/reduce/remove_opname_instruction_reduction_pass.cpp create mode 100644 source/reduce/remove_opname_instruction_reduction_pass.h create mode 100644 source/reduce/remove_unreferenced_instruction_reduction_pass.cpp create mode 100644 source/reduce/remove_unreferenced_instruction_reduction_pass.h create mode 100644 source/reduce/structured_loop_to_selection_reduction_opportunity.cpp create mode 100644 source/reduce/structured_loop_to_selection_reduction_opportunity.h create mode 100644 source/reduce/structured_loop_to_selection_reduction_pass.cpp create mode 100644 source/reduce/structured_loop_to_selection_reduction_pass.h create mode 100644 source/software_version.cpp create mode 100644 source/spirv_constant.h create mode 100644 source/spirv_definition.h create mode 100644 source/spirv_endian.cpp create mode 100644 source/spirv_endian.h create mode 100644 source/spirv_optimizer_options.cpp create mode 100644 source/spirv_optimizer_options.h create mode 100644 source/spirv_reducer_options.cpp create mode 100644 source/spirv_reducer_options.h create mode 100644 source/spirv_target_env.cpp create mode 100644 source/spirv_target_env.h create mode 100644 source/spirv_validator_options.cpp create mode 100644 source/spirv_validator_options.h create mode 100644 source/table.cpp create mode 100644 source/table.h create mode 100644 source/text.cpp create mode 100644 source/text.h create mode 100644 source/text_handler.cpp create mode 100644 source/text_handler.h create mode 100644 source/util/bit_vector.cpp create mode 100644 source/util/bit_vector.h create mode 100644 source/util/bitutils.h create mode 100644 source/util/hex_float.h create mode 100644 source/util/ilist.h create mode 100644 source/util/ilist_node.h create mode 100644 source/util/make_unique.h create mode 100644 source/util/parse_number.cpp create mode 100644 source/util/parse_number.h create mode 100644 source/util/small_vector.h create mode 100644 source/util/string_utils.cpp create mode 100644 
source/util/string_utils.h create mode 100644 source/util/timer.cpp create mode 100644 source/util/timer.h create mode 100644 source/val/basic_block.cpp create mode 100644 source/val/basic_block.h create mode 100644 source/val/construct.cpp create mode 100644 source/val/construct.h create mode 100644 source/val/decoration.h create mode 100644 source/val/function.cpp create mode 100644 source/val/function.h create mode 100644 source/val/instruction.cpp create mode 100644 source/val/instruction.h create mode 100644 source/val/validate.cpp create mode 100644 source/val/validate.h create mode 100644 source/val/validate_adjacency.cpp create mode 100644 source/val/validate_annotation.cpp create mode 100644 source/val/validate_arithmetics.cpp create mode 100644 source/val/validate_atomics.cpp create mode 100644 source/val/validate_barriers.cpp create mode 100644 source/val/validate_bitwise.cpp create mode 100644 source/val/validate_builtins.cpp create mode 100644 source/val/validate_capability.cpp create mode 100644 source/val/validate_cfg.cpp create mode 100644 source/val/validate_composites.cpp create mode 100644 source/val/validate_constants.cpp create mode 100644 source/val/validate_conversion.cpp create mode 100644 source/val/validate_datarules.cpp create mode 100644 source/val/validate_debug.cpp create mode 100644 source/val/validate_decorations.cpp create mode 100644 source/val/validate_derivatives.cpp create mode 100644 source/val/validate_execution_limitations.cpp create mode 100644 source/val/validate_extensions.cpp create mode 100644 source/val/validate_function.cpp create mode 100644 source/val/validate_id.cpp create mode 100644 source/val/validate_image.cpp create mode 100644 source/val/validate_instruction.cpp create mode 100644 source/val/validate_interfaces.cpp create mode 100644 source/val/validate_layout.cpp create mode 100644 source/val/validate_literals.cpp create mode 100644 source/val/validate_logicals.cpp create mode 100644 
source/val/validate_memory.cpp create mode 100644 source/val/validate_memory_semantics.cpp create mode 100644 source/val/validate_memory_semantics.h create mode 100644 source/val/validate_mode_setting.cpp create mode 100644 source/val/validate_non_uniform.cpp create mode 100644 source/val/validate_primitives.cpp create mode 100644 source/val/validate_scopes.cpp create mode 100644 source/val/validate_scopes.h create mode 100644 source/val/validate_type.cpp create mode 100644 source/val/validation_state.cpp create mode 100644 source/val/validation_state.h create mode 100644 syntax.md create mode 100644 test/CMakeLists.txt create mode 100644 test/assembly_context_test.cpp create mode 100644 test/assembly_format_test.cpp create mode 100644 test/binary_destroy_test.cpp create mode 100644 test/binary_endianness_test.cpp create mode 100644 test/binary_header_get_test.cpp create mode 100644 test/binary_parse_test.cpp create mode 100644 test/binary_strnlen_s_test.cpp create mode 100644 test/binary_to_text.literal_test.cpp create mode 100644 test/binary_to_text_test.cpp create mode 100644 test/bit_stream.cpp create mode 100644 test/c_interface_test.cpp create mode 100644 test/comment_test.cpp create mode 100644 test/comp/CMakeLists.txt create mode 100644 test/comp/markv_codec_test.cpp create mode 100644 test/cpp_interface_test.cpp create mode 100644 test/diagnostic_test.cpp create mode 100644 test/enum_set_test.cpp create mode 100644 test/enum_string_mapping_test.cpp create mode 100644 test/ext_inst.debuginfo_test.cpp create mode 100644 test/ext_inst.glsl_test.cpp create mode 100644 test/ext_inst.opencl_test.cpp create mode 100644 test/fix_word_test.cpp create mode 100644 test/fuzzers/BUILD.gn create mode 100644 test/fuzzers/corpora/spv/simple.spv create mode 100644 test/fuzzers/spvtools_binary_parser_fuzzer.cpp create mode 100644 test/fuzzers/spvtools_opt_legalization_fuzzer.cpp create mode 100644 test/fuzzers/spvtools_opt_performance_fuzzer.cpp create mode 100644 
test/fuzzers/spvtools_opt_size_fuzzer.cpp create mode 100644 test/fuzzers/spvtools_val_fuzzer.cpp create mode 100644 test/generator_magic_number_test.cpp create mode 100644 test/hex_float_test.cpp create mode 100644 test/huffman_codec.cpp create mode 100644 test/immediate_int_test.cpp create mode 100644 test/libspirv_macros_test.cpp create mode 100644 test/link/CMakeLists.txt create mode 100644 test/link/binary_version_test.cpp create mode 100644 test/link/entry_points_test.cpp create mode 100644 test/link/global_values_amount_test.cpp create mode 100644 test/link/ids_limit_test.cpp create mode 100644 test/link/linker_fixture.h create mode 100644 test/link/matching_imports_to_exports_test.cpp create mode 100644 test/link/memory_model_test.cpp create mode 100644 test/link/partial_linkage_test.cpp create mode 100644 test/link/unique_ids_test.cpp create mode 100644 test/log_test.cpp create mode 100644 test/move_to_front_test.cpp create mode 100644 test/name_mapper_test.cpp create mode 100644 test/named_id_test.cpp create mode 100644 test/opcode_make_test.cpp create mode 100644 test/opcode_require_capabilities_test.cpp create mode 100644 test/opcode_split_test.cpp create mode 100644 test/opcode_table_get_test.cpp create mode 100644 test/operand-class-test-coverage.csv create mode 100644 test/operand_capabilities_test.cpp create mode 100644 test/operand_pattern_test.cpp create mode 100644 test/operand_test.cpp create mode 100644 test/opt/CMakeLists.txt create mode 100644 test/opt/aggressive_dead_code_elim_test.cpp create mode 100644 test/opt/assembly_builder.h create mode 100644 test/opt/assembly_builder_test.cpp create mode 100644 test/opt/block_merge_test.cpp create mode 100644 test/opt/ccp_test.cpp create mode 100644 test/opt/cfg_cleanup_test.cpp create mode 100644 test/opt/code_sink_test.cpp create mode 100644 test/opt/combine_access_chains_test.cpp create mode 100644 test/opt/common_uniform_elim_test.cpp create mode 100644 test/opt/compact_ids_test.cpp create mode 
100644 test/opt/constant_manager_test.cpp create mode 100644 test/opt/copy_prop_array_test.cpp create mode 100644 test/opt/dead_branch_elim_test.cpp create mode 100644 test/opt/dead_insert_elim_test.cpp create mode 100644 test/opt/dead_variable_elim_test.cpp create mode 100644 test/opt/decoration_manager_test.cpp create mode 100644 test/opt/def_use_test.cpp create mode 100644 test/opt/dominator_tree/CMakeLists.txt create mode 100644 test/opt/dominator_tree/common_dominators.cpp create mode 100644 test/opt/dominator_tree/generated.cpp create mode 100644 test/opt/dominator_tree/nested_ifs.cpp create mode 100644 test/opt/dominator_tree/nested_ifs_post.cpp create mode 100644 test/opt/dominator_tree/nested_loops.cpp create mode 100644 test/opt/dominator_tree/nested_loops_with_unreachables.cpp create mode 100644 test/opt/dominator_tree/pch_test_opt_dom.cpp create mode 100644 test/opt/dominator_tree/pch_test_opt_dom.h create mode 100644 test/opt/dominator_tree/post.cpp create mode 100644 test/opt/dominator_tree/simple.cpp create mode 100644 test/opt/dominator_tree/switch_case_fallthrough.cpp create mode 100644 test/opt/dominator_tree/unreachable_for.cpp create mode 100644 test/opt/dominator_tree/unreachable_for_post.cpp create mode 100644 test/opt/eliminate_dead_const_test.cpp create mode 100644 test/opt/eliminate_dead_functions_test.cpp create mode 100644 test/opt/feature_manager_test.cpp create mode 100644 test/opt/flatten_decoration_test.cpp create mode 100644 test/opt/fold_spec_const_op_composite_test.cpp create mode 100644 test/opt/fold_test.cpp create mode 100644 test/opt/freeze_spec_const_test.cpp create mode 100644 test/opt/function_test.cpp create mode 100644 test/opt/function_utils.h create mode 100644 test/opt/if_conversion_test.cpp create mode 100644 test/opt/inline_opaque_test.cpp create mode 100644 test/opt/inline_test.cpp create mode 100644 test/opt/insert_extract_elim_test.cpp create mode 100644 test/opt/inst_bindless_check_test.cpp create mode 100644 
test/opt/instruction_list_test.cpp create mode 100644 test/opt/instruction_test.cpp create mode 100644 test/opt/ir_builder.cpp create mode 100644 test/opt/ir_context_test.cpp create mode 100644 test/opt/ir_loader_test.cpp create mode 100644 test/opt/iterator_test.cpp create mode 100644 test/opt/line_debug_info_test.cpp create mode 100644 test/opt/local_access_chain_convert_test.cpp create mode 100644 test/opt/local_redundancy_elimination_test.cpp create mode 100644 test/opt/local_single_block_elim.cpp create mode 100644 test/opt/local_single_store_elim_test.cpp create mode 100644 test/opt/local_ssa_elim_test.cpp create mode 100644 test/opt/loop_optimizations/CMakeLists.txt create mode 100644 test/opt/loop_optimizations/dependence_analysis.cpp create mode 100644 test/opt/loop_optimizations/dependence_analysis_helpers.cpp create mode 100644 test/opt/loop_optimizations/fusion_compatibility.cpp create mode 100644 test/opt/loop_optimizations/fusion_illegal.cpp create mode 100644 test/opt/loop_optimizations/fusion_legal.cpp create mode 100644 test/opt/loop_optimizations/fusion_pass.cpp create mode 100644 test/opt/loop_optimizations/hoist_all_loop_types.cpp create mode 100644 test/opt/loop_optimizations/hoist_double_nested_loops.cpp create mode 100644 test/opt/loop_optimizations/hoist_from_independent_loops.cpp create mode 100644 test/opt/loop_optimizations/hoist_simple_case.cpp create mode 100644 test/opt/loop_optimizations/hoist_single_nested_loops.cpp create mode 100644 test/opt/loop_optimizations/hoist_without_preheader.cpp create mode 100644 test/opt/loop_optimizations/lcssa.cpp create mode 100644 test/opt/loop_optimizations/loop_descriptions.cpp create mode 100644 test/opt/loop_optimizations/loop_fission.cpp create mode 100644 test/opt/loop_optimizations/nested_loops.cpp create mode 100644 test/opt/loop_optimizations/pch_test_opt_loop.cpp create mode 100644 test/opt/loop_optimizations/pch_test_opt_loop.h create mode 100644 test/opt/loop_optimizations/peeling.cpp 
create mode 100644 test/opt/loop_optimizations/peeling_pass.cpp create mode 100644 test/opt/loop_optimizations/unroll_assumptions.cpp create mode 100644 test/opt/loop_optimizations/unroll_simple.cpp create mode 100644 test/opt/loop_optimizations/unswitch.cpp create mode 100644 test/opt/module_test.cpp create mode 100644 test/opt/module_utils.h create mode 100644 test/opt/optimizer_test.cpp create mode 100644 test/opt/pass_fixture.h create mode 100644 test/opt/pass_manager_test.cpp create mode 100644 test/opt/pass_merge_return_test.cpp create mode 100644 test/opt/pass_remove_duplicates_test.cpp create mode 100644 test/opt/pass_utils.cpp create mode 100644 test/opt/pass_utils.h create mode 100644 test/opt/pch_test_opt.cpp create mode 100644 test/opt/pch_test_opt.h create mode 100644 test/opt/private_to_local_test.cpp create mode 100644 test/opt/process_lines_test.cpp create mode 100644 test/opt/propagator_test.cpp create mode 100644 test/opt/reduce_load_size_test.cpp create mode 100644 test/opt/redundancy_elimination_test.cpp create mode 100644 test/opt/register_liveness.cpp create mode 100644 test/opt/replace_invalid_opc_test.cpp create mode 100644 test/opt/scalar_analysis.cpp create mode 100644 test/opt/scalar_replacement_test.cpp create mode 100644 test/opt/set_spec_const_default_value_test.cpp create mode 100644 test/opt/simplification_test.cpp create mode 100644 test/opt/strength_reduction_test.cpp create mode 100644 test/opt/strip_debug_info_test.cpp create mode 100644 test/opt/strip_reflect_info_test.cpp create mode 100644 test/opt/struct_cfg_analysis_test.cpp create mode 100644 test/opt/type_manager_test.cpp create mode 100644 test/opt/types_test.cpp create mode 100644 test/opt/unify_const_test.cpp create mode 100644 test/opt/upgrade_memory_model_test.cpp create mode 100644 test/opt/utils_test.cpp create mode 100644 test/opt/value_table_test.cpp create mode 100644 test/opt/vector_dce_test.cpp create mode 100644 test/opt/workaround1209_test.cpp create mode 
100644 test/parse_number_test.cpp create mode 100644 test/pch_test.cpp create mode 100644 test/pch_test.h create mode 100644 test/preserve_numeric_ids_test.cpp create mode 100644 test/reduce/CMakeLists.txt create mode 100644 test/reduce/operand_to_constant_reduction_pass_test.cpp create mode 100644 test/reduce/operand_to_dominating_id_reduction_pass_test.cpp create mode 100644 test/reduce/operand_to_undef_reduction_pass_test.cpp create mode 100644 test/reduce/reduce_test_util.cpp create mode 100644 test/reduce/reduce_test_util.h create mode 100644 test/reduce/reducer_test.cpp create mode 100644 test/reduce/remove_opname_instruction_reduction_pass_test.cpp create mode 100644 test/reduce/remove_unreferenced_instruction_reduction_pass_test.cpp create mode 100644 test/reduce/structured_loop_to_selection_reduction_pass_test.cpp create mode 100644 test/reduce/validation_during_reduction_test.cpp create mode 100644 test/scripts/test_compact_ids.py create mode 100644 test/software_version_test.cpp create mode 100644 test/stats/CMakeLists.txt create mode 100644 test/stats/stats_aggregate_test.cpp create mode 100644 test/stats/stats_analyzer_test.cpp create mode 100644 test/string_utils_test.cpp create mode 100644 test/target_env_test.cpp create mode 100644 test/test_fixture.h create mode 100644 test/text_advance_test.cpp create mode 100644 test/text_destroy_test.cpp create mode 100644 test/text_literal_test.cpp create mode 100644 test/text_start_new_inst_test.cpp create mode 100644 test/text_to_binary.annotation_test.cpp create mode 100644 test/text_to_binary.barrier_test.cpp create mode 100644 test/text_to_binary.constant_test.cpp create mode 100644 test/text_to_binary.control_flow_test.cpp create mode 100644 test/text_to_binary.debug_test.cpp create mode 100644 test/text_to_binary.device_side_enqueue_test.cpp create mode 100644 test/text_to_binary.extension_test.cpp create mode 100644 test/text_to_binary.function_test.cpp create mode 100644 
test/text_to_binary.group_test.cpp create mode 100644 test/text_to_binary.image_test.cpp create mode 100644 test/text_to_binary.literal_test.cpp create mode 100644 test/text_to_binary.memory_test.cpp create mode 100644 test/text_to_binary.misc_test.cpp create mode 100644 test/text_to_binary.mode_setting_test.cpp create mode 100644 test/text_to_binary.pipe_storage_test.cpp create mode 100644 test/text_to_binary.reserved_sampling_test.cpp create mode 100644 test/text_to_binary.subgroup_dispatch_test.cpp create mode 100644 test/text_to_binary.type_declaration_test.cpp create mode 100644 test/text_to_binary_test.cpp create mode 100644 test/text_word_get_test.cpp create mode 100644 test/timer_test.cpp create mode 100644 test/tools/CMakeLists.txt create mode 100755 test/tools/expect.py create mode 100755 test/tools/expect_nosetest.py create mode 100644 test/tools/opt/CMakeLists.txt create mode 100644 test/tools/opt/flags.py create mode 100644 test/tools/opt/oconfig.py create mode 100755 test/tools/placeholder.py create mode 100755 test/tools/spirv_test_framework.py create mode 100755 test/tools/spirv_test_framework_nosetest.py create mode 100644 test/unit_spirv.cpp create mode 100644 test/unit_spirv.h create mode 100644 test/util/CMakeLists.txt create mode 100644 test/util/bit_vector_test.cpp create mode 100644 test/util/ilist_test.cpp create mode 100644 test/util/small_vector_test.cpp create mode 100644 test/val/CMakeLists.txt create mode 100644 test/val/pch_test_val.cpp create mode 100644 test/val/pch_test_val.h create mode 100644 test/val/val_adjacency_test.cpp create mode 100644 test/val/val_arithmetics_test.cpp create mode 100644 test/val/val_atomics_test.cpp create mode 100644 test/val/val_barriers_test.cpp create mode 100644 test/val/val_bitwise_test.cpp create mode 100644 test/val/val_builtins_test.cpp create mode 100644 test/val/val_capability_test.cpp create mode 100644 test/val/val_cfg_test.cpp create mode 100644 test/val/val_composites_test.cpp create mode 
100644 test/val/val_constants_test.cpp create mode 100644 test/val/val_conversion_test.cpp create mode 100644 test/val/val_data_test.cpp create mode 100644 test/val/val_decoration_test.cpp create mode 100644 test/val/val_derivatives_test.cpp create mode 100644 test/val/val_explicit_reserved_test.cpp create mode 100644 test/val/val_ext_inst_test.cpp create mode 100644 test/val/val_extensions_test.cpp create mode 100644 test/val/val_fixtures.h create mode 100644 test/val/val_id_test.cpp create mode 100644 test/val/val_image_test.cpp create mode 100644 test/val/val_interfaces_test.cpp create mode 100644 test/val/val_layout_test.cpp create mode 100644 test/val/val_limits_test.cpp create mode 100644 test/val/val_literals_test.cpp create mode 100644 test/val/val_logicals_test.cpp create mode 100644 test/val/val_memory_test.cpp create mode 100644 test/val/val_modes_test.cpp create mode 100644 test/val/val_non_uniform_test.cpp create mode 100644 test/val/val_primitives_test.cpp create mode 100644 test/val/val_ssa_test.cpp create mode 100644 test/val/val_state_test.cpp create mode 100644 test/val/val_storage_test.cpp create mode 100644 test/val/val_type_unique_test.cpp create mode 100644 test/val/val_validation_state_test.cpp create mode 100644 test/val/val_version_test.cpp create mode 100644 test/val/val_webgpu_test.cpp create mode 100644 tools/CMakeLists.txt create mode 100644 tools/as/as.cpp create mode 100644 tools/cfg/bin_to_dot.cpp create mode 100644 tools/cfg/bin_to_dot.h create mode 100644 tools/cfg/cfg.cpp create mode 100644 tools/comp/markv.cpp create mode 100644 tools/comp/markv_model_factory.cpp create mode 100644 tools/comp/markv_model_factory.h create mode 100644 tools/comp/markv_model_shader.cpp create mode 100644 tools/comp/markv_model_shader.h create mode 100644 tools/comp/markv_model_shader_default_autogen.inc create mode 100644 tools/dis/dis.cpp create mode 100644 tools/emacs/50spirv-tools.el create mode 100644 tools/emacs/CMakeLists.txt create mode 
100644 tools/io.h create mode 100644 tools/lesspipe/CMakeLists.txt create mode 100644 tools/lesspipe/spirv-lesspipe.sh create mode 100644 tools/link/linker.cpp create mode 100644 tools/opt/opt.cpp create mode 100644 tools/reduce/reduce.cpp create mode 100644 tools/stats/spirv_stats.cpp create mode 100644 tools/stats/spirv_stats.h create mode 100644 tools/stats/stats.cpp create mode 100644 tools/stats/stats_analyzer.cpp create mode 100644 tools/stats/stats_analyzer.h create mode 100644 tools/util/cli_consumer.cpp create mode 100644 tools/util/cli_consumer.h create mode 100644 tools/val/val.cpp create mode 100755 utils/check_code_format.sh create mode 100755 utils/check_copyright.py create mode 100755 utils/check_symbol_exports.py create mode 100755 utils/fixup_fuzz_result.py create mode 100755 utils/generate_grammar_tables.py create mode 100755 utils/generate_language_headers.py create mode 100755 utils/generate_registry_tables.py create mode 100755 utils/generate_vim_syntax.py create mode 100755 utils/update_build_version.py diff --git a/.appveyor.yml b/.appveyor.yml new file mode 100644 index 000000000..a50c7c2a9 --- /dev/null +++ b/.appveyor.yml @@ -0,0 +1,89 @@ +# Windows Build Configuration for AppVeyor +# http://www.appveyor.com/docs/appveyor-yml + +# version format +version: "{build}" + +# The most recent compiler gives the most interesting new results. +# Put it first so we get its feedback first. +os: + - Visual Studio 2017 + #- Visual Studio 2013 + +platform: + - x64 + +configuration: + - Debug + #- Release + +branches: + only: + - master + +# Travis advances the master-tot tag to current top of the tree after +# each push into the master branch, because it relies on that tag to +# upload build artifacts to the master-tot release. This will cause +# double testing for each push on Appveyor: one for the push, one for +# the tag advance. Disable testing tags. 
+skip_tags: true + +clone_depth: 1 + +matrix: + fast_finish: true # Show final status immediately if a test fails. + #exclude: + # - os: Visual Studio 2013 + # configuration: Debug + +# scripts that run after cloning repository +install: + # Install ninja + - set NINJA_URL="https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-win.zip" + - appveyor DownloadFile %NINJA_URL% -FileName ninja.zip + - 7z x ninja.zip -oC:\ninja > nul + - set PATH=C:\ninja;%PATH% + +before_build: + - git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers.git external/spirv-headers + - git clone --depth=1 https://github.com/google/googletest.git external/googletest + - git clone --depth=1 https://github.com/google/effcee.git external/effcee + - git clone --depth=1 https://github.com/google/re2.git external/re2 + # Set path and environment variables for the current Visual Studio version + - if "%APPVEYOR_BUILD_WORKER_IMAGE%"=="Visual Studio 2013" (call "C:\Program Files (x86)\Microsoft Visual Studio 12.0\VC\vcvarsall.bat" x86_amd64) + - if "%APPVEYOR_BUILD_WORKER_IMAGE%"=="Visual Studio 2017" (call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsall.bat" x86_amd64) + +build: + parallel: true # enable MSBuild parallel builds + verbosity: minimal + +build_script: + - mkdir build && cd build + - cmake -GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%CONFIGURATION% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF .. 
+ - ninja install + +test_script: + - ctest -C %CONFIGURATION% --output-on-failure --timeout 300 + +after_test: + # Zip build artifacts for uploading and deploying + - cd install + - 7z a SPIRV-Tools-master-windows-"%PLATFORM%"-"%CONFIGURATION%".zip *\* + +artifacts: + - path: build\install\*.zip + name: artifacts-zip + +deploy: + - provider: GitHub + auth_token: + secure: TMfcScKzzFIm1YgeV/PwCRXFDCw8Xm0wY2Vb2FU6WKlbzb5eUITTpr6I5vHPnAxS + release: master-tot + description: "Continuous build of the latest master branch by Appveyor and Travis CI" + artifact: artifacts-zip + draft: false + prerelease: false + force_update: true + on: + branch: master + APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2017 diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..1d43c4ed4 --- /dev/null +++ b/.clang-format @@ -0,0 +1,6 @@ +--- +Language: Cpp +BasedOnStyle: Google +DerivePointerAlignment: false +SortIncludes: true +... diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..059b18ed7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +.clang_complete +.ycm_extra_conf.py* +*.pyc +compile_commands.json +/build/ +/buildtools/ +/external/googletest +/external/SPIRV-Headers +/external/spirv-headers +/external/effcee +/external/re2 +/out +/TAGS +/third_party/llvm-build/ +/testing +/tools/clang/ +/utils/clang-format-diff.py + +# Vim +[._]*.s[a-w][a-z] +*~ + +# C-Lion +.idea +cmake-build-debug \ No newline at end of file diff --git a/.gn b/.gn new file mode 100644 index 000000000..377a97ab8 --- /dev/null +++ b/.gn @@ -0,0 +1,20 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +buildconfig = "//build/config/BUILDCONFIG.gn" + +default_args = { + clang_use_chrome_plugins = false + use_custom_libcxx = false +} diff --git a/Android.mk b/Android.mk new file mode 100644 index 000000000..179cf2614 --- /dev/null +++ b/Android.mk @@ -0,0 +1,343 @@ +LOCAL_PATH := $(call my-dir) +SPVTOOLS_OUT_PATH=$(if $(call host-path-is-absolute,$(TARGET_OUT)),$(TARGET_OUT),$(abspath $(TARGET_OUT))) + +ifeq ($(SPVHEADERS_LOCAL_PATH),) + SPVHEADERS_LOCAL_PATH := $(LOCAL_PATH)/external/spirv-headers +endif + +SPVTOOLS_SRC_FILES := \ + source/assembly_grammar.cpp \ + source/binary.cpp \ + source/diagnostic.cpp \ + source/disassemble.cpp \ + source/ext_inst.cpp \ + source/enum_string_mapping.cpp \ + source/extensions.cpp \ + source/id_descriptor.cpp \ + source/libspirv.cpp \ + source/name_mapper.cpp \ + source/opcode.cpp \ + source/operand.cpp \ + source/parsed_operand.cpp \ + source/print.cpp \ + source/software_version.cpp \ + source/spirv_endian.cpp \ + source/spirv_optimizer_options.cpp \ + source/spirv_target_env.cpp \ + source/spirv_validator_options.cpp \ + source/table.cpp \ + source/text.cpp \ + source/text_handler.cpp \ + source/util/bit_vector.cpp \ + source/util/parse_number.cpp \ + source/util/string_utils.cpp \ + source/util/timer.cpp \ + source/val/basic_block.cpp \ + source/val/construct.cpp \ + source/val/function.cpp \ + source/val/instruction.cpp \ + source/val/validation_state.cpp \ + source/val/validate.cpp \ + source/val/validate_adjacency.cpp \ + source/val/validate_annotation.cpp \ + source/val/validate_arithmetics.cpp \ + 
source/val/validate_atomics.cpp \ + source/val/validate_barriers.cpp \ + source/val/validate_bitwise.cpp \ + source/val/validate_builtins.cpp \ + source/val/validate_capability.cpp \ + source/val/validate_cfg.cpp \ + source/val/validate_composites.cpp \ + source/val/validate_constants.cpp \ + source/val/validate_conversion.cpp \ + source/val/validate_datarules.cpp \ + source/val/validate_debug.cpp \ + source/val/validate_decorations.cpp \ + source/val/validate_derivatives.cpp \ + source/val/validate_extensions.cpp \ + source/val/validate_execution_limitations.cpp \ + source/val/validate_function.cpp \ + source/val/validate_id.cpp \ + source/val/validate_image.cpp \ + source/val/validate_interfaces.cpp \ + source/val/validate_instruction.cpp \ + source/val/validate_memory.cpp \ + source/val/validate_memory_semantics.cpp \ + source/val/validate_mode_setting.cpp \ + source/val/validate_layout.cpp \ + source/val/validate_literals.cpp \ + source/val/validate_logicals.cpp \ + source/val/validate_non_uniform.cpp \ + source/val/validate_primitives.cpp \ + source/val/validate_scopes.cpp \ + source/val/validate_type.cpp + +SPVTOOLS_OPT_SRC_FILES := \ + source/opt/aggressive_dead_code_elim_pass.cpp \ + source/opt/basic_block.cpp \ + source/opt/block_merge_pass.cpp \ + source/opt/build_module.cpp \ + source/opt/cfg.cpp \ + source/opt/cfg_cleanup_pass.cpp \ + source/opt/ccp_pass.cpp \ + source/opt/code_sink.cpp \ + source/opt/combine_access_chains.cpp \ + source/opt/common_uniform_elim_pass.cpp \ + source/opt/compact_ids_pass.cpp \ + source/opt/composite.cpp \ + source/opt/const_folding_rules.cpp \ + source/opt/constants.cpp \ + source/opt/copy_prop_arrays.cpp \ + source/opt/dead_branch_elim_pass.cpp \ + source/opt/dead_insert_elim_pass.cpp \ + source/opt/dead_variable_elimination.cpp \ + source/opt/decoration_manager.cpp \ + source/opt/def_use_manager.cpp \ + source/opt/dominator_analysis.cpp \ + source/opt/dominator_tree.cpp \ + source/opt/eliminate_dead_constant_pass.cpp \ + 
source/opt/eliminate_dead_functions_pass.cpp \ + source/opt/feature_manager.cpp \ + source/opt/flatten_decoration_pass.cpp \ + source/opt/fold.cpp \ + source/opt/folding_rules.cpp \ + source/opt/fold_spec_constant_op_and_composite_pass.cpp \ + source/opt/freeze_spec_constant_value_pass.cpp \ + source/opt/function.cpp \ + source/opt/if_conversion.cpp \ + source/opt/inline_pass.cpp \ + source/opt/inline_exhaustive_pass.cpp \ + source/opt/inline_opaque_pass.cpp \ + source/opt/inst_bindless_check_pass.cpp \ + source/opt/instruction.cpp \ + source/opt/instruction_list.cpp \ + source/opt/instrument_pass.cpp \ + source/opt/ir_context.cpp \ + source/opt/ir_loader.cpp \ + source/opt/licm_pass.cpp \ + source/opt/local_access_chain_convert_pass.cpp \ + source/opt/local_redundancy_elimination.cpp \ + source/opt/local_single_block_elim_pass.cpp \ + source/opt/local_single_store_elim_pass.cpp \ + source/opt/local_ssa_elim_pass.cpp \ + source/opt/loop_dependence.cpp \ + source/opt/loop_dependence_helpers.cpp \ + source/opt/loop_descriptor.cpp \ + source/opt/loop_fission.cpp \ + source/opt/loop_fusion.cpp \ + source/opt/loop_fusion_pass.cpp \ + source/opt/loop_peeling.cpp \ + source/opt/loop_unroller.cpp \ + source/opt/loop_unswitch_pass.cpp \ + source/opt/loop_utils.cpp \ + source/opt/mem_pass.cpp \ + source/opt/merge_return_pass.cpp \ + source/opt/module.cpp \ + source/opt/optimizer.cpp \ + source/opt/pass.cpp \ + source/opt/pass_manager.cpp \ + source/opt/private_to_local_pass.cpp \ + source/opt/process_lines_pass.cpp \ + source/opt/propagator.cpp \ + source/opt/reduce_load_size.cpp \ + source/opt/redundancy_elimination.cpp \ + source/opt/register_pressure.cpp \ + source/opt/remove_duplicates_pass.cpp \ + source/opt/replace_invalid_opc.cpp \ + source/opt/scalar_analysis.cpp \ + source/opt/scalar_analysis_simplification.cpp \ + source/opt/scalar_replacement_pass.cpp \ + source/opt/set_spec_constant_default_value_pass.cpp \ + source/opt/simplification_pass.cpp \ + 
source/opt/ssa_rewrite_pass.cpp \ + source/opt/strength_reduction_pass.cpp \ + source/opt/strip_debug_info_pass.cpp \ + source/opt/strip_reflect_info_pass.cpp \ + source/opt/struct_cfg_analysis.cpp \ + source/opt/type_manager.cpp \ + source/opt/types.cpp \ + source/opt/unify_const_pass.cpp \ + source/opt/upgrade_memory_model.cpp \ + source/opt/value_number_table.cpp \ + source/opt/vector_dce.cpp \ + source/opt/workaround1209.cpp + +# Locations of grammar files. +# +# TODO(dneto): Build a single set of tables that embeds versioning differences on +# a per-item basis. That must happen before SPIR-V 1.4, etc. +# https://github.com/KhronosGroup/SPIRV-Tools/issues/1195 +SPV_CORE10_GRAMMAR=$(SPVHEADERS_LOCAL_PATH)/include/spirv/1.0/spirv.core.grammar.json +SPV_CORE11_GRAMMAR=$(SPVHEADERS_LOCAL_PATH)/include/spirv/1.1/spirv.core.grammar.json +SPV_CORE12_GRAMMAR=$(SPVHEADERS_LOCAL_PATH)/include/spirv/1.2/spirv.core.grammar.json +SPV_COREUNIFIED1_GRAMMAR=$(SPVHEADERS_LOCAL_PATH)/include/spirv/unified1/spirv.core.grammar.json +SPV_CORELATEST_GRAMMAR=$(SPV_COREUNIFIED1_GRAMMAR) +SPV_GLSL_GRAMMAR=$(SPVHEADERS_LOCAL_PATH)/include/spirv/1.2/extinst.glsl.std.450.grammar.json +SPV_OPENCL_GRAMMAR=$(SPVHEADERS_LOCAL_PATH)/include/spirv/1.2/extinst.opencl.std.100.grammar.json +# TODO(dneto): I expect the DebugInfo grammar file to eventually migrate to SPIRV-Headers +SPV_DEBUGINFO_GRAMMAR=$(LOCAL_PATH)/source/extinst.debuginfo.grammar.json + +define gen_spvtools_grammar_tables +$(call generate-file-dir,$(1)/core.insts-1.0.inc) +$(1)/core.insts-1.0.inc $(1)/operand.kinds-1.0.inc $(1)/glsl.std.450.insts.inc $(1)/opencl.std.insts.inc: \ + $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + $(SPV_CORE10_GRAMMAR) \ + $(SPV_GLSL_GRAMMAR) \ + $(SPV_OPENCL_GRAMMAR) \ + $(SPV_DEBUGINFO_GRAMMAR) + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + --spirv-core-grammar=$(SPV_CORE10_GRAMMAR) \ + --extinst-glsl-grammar=$(SPV_GLSL_GRAMMAR) \ + 
--extinst-opencl-grammar=$(SPV_OPENCL_GRAMMAR) \ + --extinst-debuginfo-grammar=$(SPV_DEBUGINFO_GRAMMAR) \ + --core-insts-output=$(1)/core.insts-1.0.inc \ + --glsl-insts-output=$(1)/glsl.std.450.insts.inc \ + --opencl-insts-output=$(1)/opencl.std.insts.inc \ + --operand-kinds-output=$(1)/operand.kinds-1.0.inc + @echo "[$(TARGET_ARCH_ABI)] Grammar v1.0 : instructions & operands <= grammar JSON files" +$(1)/core.insts-1.1.inc $(1)/operand.kinds-1.1.inc: \ + $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + $(SPV_CORE11_GRAMMAR) \ + $(SPV_DEBUGINFO_GRAMMAR) + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + --spirv-core-grammar=$(SPV_CORE11_GRAMMAR) \ + --extinst-debuginfo-grammar=$(SPV_DEBUGINFO_GRAMMAR) \ + --core-insts-output=$(1)/core.insts-1.1.inc \ + --operand-kinds-output=$(1)/operand.kinds-1.1.inc + @echo "[$(TARGET_ARCH_ABI)] Grammar v1.1 : instructions & operands <= grammar JSON files" +$(1)/core.insts-1.2.inc $(1)/operand.kinds-1.2.inc: \ + $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + $(SPV_CORE12_GRAMMAR) \ + $(SPV_DEBUGINFO_GRAMMAR) + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + --spirv-core-grammar=$(SPV_CORE12_GRAMMAR) \ + --extinst-debuginfo-grammar=$(SPV_DEBUGINFO_GRAMMAR) \ + --core-insts-output=$(1)/core.insts-1.2.inc \ + --operand-kinds-output=$(1)/operand.kinds-1.2.inc + @echo "[$(TARGET_ARCH_ABI)] Grammar v1.2 : instructions & operands <= grammar JSON files" +$(1)/core.insts-unified1.inc $(1)/operand.kinds-unified1.inc: \ + $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + $(SPV_COREUNIFIED1_GRAMMAR) \ + $(SPV_DEBUGINFO_GRAMMAR) + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + --spirv-core-grammar=$(SPV_COREUNIFIED1_GRAMMAR) \ + --extinst-debuginfo-grammar=$(SPV_DEBUGINFO_GRAMMAR) \ + --core-insts-output=$(1)/core.insts-unified1.inc \ + --operand-kinds-output=$(1)/operand.kinds-unified1.inc + @echo "[$(TARGET_ARCH_ABI)] Grammar v1.3 (from unified1) : instructions & operands <= 
grammar JSON files" +$(LOCAL_PATH)/source/opcode.cpp: $(1)/core.insts-1.0.inc $(1)/core.insts-1.1.inc $(1)/core.insts-1.2.inc $(1)/core.insts-unified1.inc +$(LOCAL_PATH)/source/operand.cpp: $(1)/operand.kinds-1.0.inc $(1)/operand.kinds-1.1.inc $(1)/operand.kinds-1.2.inc $(1)/operand.kinds-unified1.inc +$(LOCAL_PATH)/source/ext_inst.cpp: \ + $(1)/glsl.std.450.insts.inc \ + $(1)/opencl.std.insts.inc \ + $(1)/debuginfo.insts.inc \ + $(1)/spv-amd-gcn-shader.insts.inc \ + $(1)/spv-amd-shader-ballot.insts.inc \ + $(1)/spv-amd-shader-explicit-vertex-parameter.insts.inc \ + $(1)/spv-amd-shader-trinary-minmax.insts.inc +endef +$(eval $(call gen_spvtools_grammar_tables,$(SPVTOOLS_OUT_PATH))) + + +define gen_spvtools_lang_headers +# Generate language-specific headers. So far we only generate C headers +# $1 is the output directory. +# $2 is the base name of the header file, e.g. "DebugInfo". +# $3 is the grammar file containing token definitions. +$(call generate-file-dir,$(1)/$(2).h) +$(1)/$(2).h : \ + $(LOCAL_PATH)/utils/generate_language_headers.py \ + $(3) + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/generate_language_headers.py \ + --extinst-name=$(2) \ + --extinst-grammar=$(3) \ + --extinst-output-base=$(1)/$(2) + @echo "[$(TARGET_ARCH_ABI)] Generate language specific header for $(2): headers <= grammar" +$(LOCAL_PATH)/source/ext_inst.cpp: $(1)/$(2).h +endef +# We generate language-specific headers for DebugInfo +$(eval $(call gen_spvtools_lang_headers,$(SPVTOOLS_OUT_PATH),DebugInfo,$(SPV_DEBUGINFO_GRAMMAR))) + + +define gen_spvtools_vendor_tables +$(call generate-file-dir,$(1)/$(2).insts.inc) +$(1)/$(2).insts.inc : \ + $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + $(LOCAL_PATH)/source/extinst.$(2).grammar.json + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + --extinst-vendor-grammar=$(LOCAL_PATH)/source/extinst.$(2).grammar.json \ + --vendor-insts-output=$(1)/$(2).insts.inc + @echo "[$(TARGET_ARCH_ABI)] Vendor extended instruction set: $(2) tables <= 
grammar" +$(LOCAL_PATH)/source/ext_inst.cpp: $(1)/$(2).insts.inc +endef +# Vendor extended instruction sets, with grammars from SPIRV-Tools source tree. +SPV_NONSTANDARD_EXTINST_GRAMMARS=$(foreach F,$(wildcard $(LOCAL_PATH)/source/extinst.*.grammar.json),$(patsubst extinst.%.grammar.json,%,$(notdir $F))) +$(foreach E,$(SPV_NONSTANDARD_EXTINST_GRAMMARS),$(eval $(call gen_spvtools_vendor_tables,$(SPVTOOLS_OUT_PATH),$E))) + +define gen_spvtools_enum_string_mapping +$(call generate-file-dir,$(1)/extension_enum.inc.inc) +$(1)/extension_enum.inc $(1)/enum_string_mapping.inc: \ + $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + $(SPV_CORELATEST_GRAMMAR) + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/generate_grammar_tables.py \ + --spirv-core-grammar=$(SPV_CORELATEST_GRAMMAR) \ + --extinst-debuginfo-grammar=$(SPV_DEBUGINFO_GRAMMAR) \ + --extension-enum-output=$(1)/extension_enum.inc \ + --enum-string-mapping-output=$(1)/enum_string_mapping.inc + @echo "[$(TARGET_ARCH_ABI)] Generate enum<->string mapping <= grammar JSON files" +# Generated header extension_enum.inc is transitively included by table.h, which is +# used pervasively. Capture the pervasive dependency. 
+$(foreach F,$(SPVTOOLS_SRC_FILES) $(SPVTOOLS_OPT_SRC_FILES),$(LOCAL_PATH)/$F ) \ + : $(1)/extension_enum.inc +$(LOCAL_PATH)/source/enum_string_mapping.cpp: $(1)/enum_string_mapping.inc +endef +$(eval $(call gen_spvtools_enum_string_mapping,$(SPVTOOLS_OUT_PATH))) + +define gen_spvtools_build_version_inc +$(call generate-file-dir,$(1)/dummy_filename) +$(1)/build-version.inc: \ + $(LOCAL_PATH)/utils/update_build_version.py \ + $(LOCAL_PATH)/CHANGES + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/update_build_version.py \ + $(LOCAL_PATH) $(1)/build-version.inc + @echo "[$(TARGET_ARCH_ABI)] Generate : build-version.inc <= CHANGES" +$(LOCAL_PATH)/source/software_version.cpp: $(1)/build-version.inc +endef +$(eval $(call gen_spvtools_build_version_inc,$(SPVTOOLS_OUT_PATH))) + +define gen_spvtools_generators_inc +$(call generate-file-dir,$(1)/dummy_filename) +$(1)/generators.inc: \ + $(LOCAL_PATH)/utils/generate_registry_tables.py \ + $(SPVHEADERS_LOCAL_PATH)/include/spirv/spir-v.xml + @$(HOST_PYTHON) $(LOCAL_PATH)/utils/generate_registry_tables.py \ + --xml=$(SPVHEADERS_LOCAL_PATH)/include/spirv/spir-v.xml \ + --generator-output=$(1)/generators.inc + @echo "[$(TARGET_ARCH_ABI)] Generate : generators.inc <= spir-v.xml" +$(LOCAL_PATH)/source/opcode.cpp: $(1)/generators.inc +endef +$(eval $(call gen_spvtools_generators_inc,$(SPVTOOLS_OUT_PATH))) + +include $(CLEAR_VARS) +LOCAL_MODULE := SPIRV-Tools +LOCAL_C_INCLUDES := \ + $(LOCAL_PATH)/include \ + $(SPVHEADERS_LOCAL_PATH)/include \ + $(SPVTOOLS_OUT_PATH) +LOCAL_EXPORT_C_INCLUDES := \ + $(LOCAL_PATH)/include +LOCAL_CXXFLAGS:=-std=c++11 -fno-exceptions -fno-rtti -Werror +LOCAL_SRC_FILES:= $(SPVTOOLS_SRC_FILES) +include $(BUILD_STATIC_LIBRARY) + +include $(CLEAR_VARS) +LOCAL_MODULE := SPIRV-Tools-opt +LOCAL_C_INCLUDES := \ + $(LOCAL_PATH)/include \ + $(LOCAL_PATH)/source \ + $(SPVHEADERS_LOCAL_PATH)/include \ + $(SPVTOOLS_OUT_PATH) +LOCAL_CXXFLAGS:=-std=c++11 -fno-exceptions -fno-rtti -Werror +LOCAL_STATIC_LIBRARIES:=SPIRV-Tools 
+LOCAL_SRC_FILES:= $(SPVTOOLS_OPT_SRC_FILES) +include $(BUILD_STATIC_LIBRARY) diff --git a/BUILD.gn b/BUILD.gn new file mode 100644 index 000000000..69e3dc812 --- /dev/null +++ b/BUILD.gn @@ -0,0 +1,801 @@ +# Copyright 2018 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import("//build_overrides/spirv_tools.gni") + +import("//testing/test.gni") +import("//build_overrides/build.gni") + +spirv_headers = spirv_tools_spirv_headers_dir + +template("spvtools_core_tables") { + assert(defined(invoker.version), "Need version in $target_name generation.") + + action("spvtools_core_tables_" + target_name) { + script = "utils/generate_grammar_tables.py" + + version = invoker.version + + core_json_file = + "${spirv_headers}/include/spirv/$version/spirv.core.grammar.json" + core_insts_file = "${target_gen_dir}/core.insts-$version.inc" + operand_kinds_file = "${target_gen_dir}/operand.kinds-$version.inc" + extinst_file = "source/extinst.debuginfo.grammar.json" + + sources = [ + core_json_file, + ] + outputs = [ + core_insts_file, + operand_kinds_file, + ] + args = [ + "--spirv-core-grammar", + rebase_path(core_json_file, root_build_dir), + "--core-insts-output", + rebase_path(core_insts_file, root_build_dir), + "--extinst-debuginfo-grammar", + rebase_path(extinst_file, root_build_dir), + "--operand-kinds-output", + rebase_path(operand_kinds_file, root_build_dir), + ] + } +} + +template("spvtools_core_enums") { + assert(defined(invoker.version), 
"Need version in $target_name generation.") + + action("spvtools_core_enums_" + target_name) { + script = "utils/generate_grammar_tables.py" + + version = invoker.version + + core_json_file = + "${spirv_headers}/include/spirv/$version/spirv.core.grammar.json" + debug_insts_file = "source/extinst.debuginfo.grammar.json" + extension_enum_file = "${target_gen_dir}/extension_enum.inc" + extension_map_file = "${target_gen_dir}/enum_string_mapping.inc" + + args = [ + "--spirv-core-grammar", + rebase_path(core_json_file, root_build_dir), + "--extinst-debuginfo-grammar", + rebase_path(debug_insts_file, root_build_dir), + "--extension-enum-output", + rebase_path(extension_enum_file, root_build_dir), + "--enum-string-mapping-output", + rebase_path(extension_map_file, root_build_dir), + ] + inputs = [ + core_json_file, + ] + outputs = [ + extension_enum_file, + extension_map_file, + ] + } +} + +template("spvtools_glsl_tables") { + assert(defined(invoker.version), "Need version in $target_name generation.") + + action("spvtools_glsl_tables_" + target_name) { + script = "utils/generate_grammar_tables.py" + + version = invoker.version + + core_json_file = + "${spirv_headers}/include/spirv/$version/spirv.core.grammar.json" + glsl_json_file = "${spirv_headers}/include/spirv/${version}/extinst.glsl.std.450.grammar.json" + glsl_insts_file = "${target_gen_dir}/glsl.std.450.insts.inc" + debug_insts_file = "source/extinst.debuginfo.grammar.json" + + args = [ + "--spirv-core-grammar", + rebase_path(core_json_file, root_build_dir), + "--extinst-glsl-grammar", + rebase_path(glsl_json_file, root_build_dir), + "--glsl-insts-output", + rebase_path(glsl_insts_file, root_build_dir), + "--extinst-debuginfo-grammar", + rebase_path(debug_insts_file, root_build_dir), + ] + inputs = [ + core_json_file, + glsl_json_file, + ] + outputs = [ + glsl_insts_file, + ] + } +} + +template("spvtools_opencl_tables") { + assert(defined(invoker.version), "Need version in $target_name generation.") + + 
action("spvtools_opencl_tables_" + target_name) { + script = "utils/generate_grammar_tables.py" + + version = invoker.version + + core_json_file = + "${spirv_headers}/include/spirv/$version/spirv.core.grammar.json" + opengl_json_file = "${spirv_headers}/include/spirv/${version}/extinst.opencl.std.100.grammar.json" + opencl_insts_file = "${target_gen_dir}/opencl.std.insts.inc" + debug_insts_file = "source/extinst.debuginfo.grammar.json" + + args = [ + "--spirv-core-grammar", + rebase_path(core_json_file, root_build_dir), + "--extinst-opencl-grammar", + rebase_path(opengl_json_file, root_build_dir), + "--opencl-insts-output", + rebase_path(opencl_insts_file, root_build_dir), + "--extinst-debuginfo-grammar", + rebase_path(debug_insts_file, root_build_dir), + ] + inputs = [ + core_json_file, + opengl_json_file, + ] + outputs = [ + opencl_insts_file, + ] + } +} + +template("spvtools_language_header") { + assert(defined(invoker.name), "Need name in $target_name generation.") + + action("spvtools_language_header_" + target_name) { + script = "utils/generate_language_headers.py" + + name = invoker.name + extinst_output_base = "${target_gen_dir}/${name}" + debug_insts_file = "source/extinst.debuginfo.grammar.json" + + args = [ + "--extinst-name", + "${name}", + "--extinst-grammar", + rebase_path(debug_insts_file, root_build_dir), + "--extinst-output-base", + rebase_path(extinst_output_base, root_build_dir), + ] + inputs = [ + debug_insts_file, + ] + outputs = [ + "${extinst_output_base}.h", + ] + } +} + +template("spvtools_vendor_table") { + assert(defined(invoker.name), "Need name in $target_name generation.") + + action("spvtools_vendor_tables_" + target_name) { + script = "utils/generate_grammar_tables.py" + + name = invoker.name + extinst_vendor_grammar = "source/extinst.${name}.grammar.json" + extinst_file = "${target_gen_dir}/${name}.insts.inc" + + args = [ + "--extinst-vendor-grammar", + rebase_path(extinst_vendor_grammar, root_build_dir), + "--vendor-insts-output", 
+ rebase_path(extinst_file, root_build_dir), + ] + inputs = [ + extinst_vendor_grammar, + ] + outputs = [ + extinst_file, + ] + } +} + +action("spvtools_generators_inc") { + script = "utils/generate_registry_tables.py" + + # TODO(dsinclair): Make work for chrome + xml_file = "${spirv_headers}/include/spirv/spir-v.xml" + inc_file = "${target_gen_dir}/generators.inc" + + sources = [ + xml_file, + ] + outputs = [ + inc_file, + ] + args = [ + "--xml", + rebase_path(xml_file, root_build_dir), + "--generator", + rebase_path(inc_file, root_build_dir), + ] +} + +action("spvtools_build_version") { + script = "utils/update_build_version.py" + + src_dir = "." + inc_file = "${target_gen_dir}/build-version.inc" + + outputs = [ + inc_file, + ] + args = [ + rebase_path(src_dir, root_build_dir), + rebase_path(inc_file, root_build_dir), + ] +} + +spvtools_core_tables("unified1") { + version = "unified1" +} +spvtools_core_enums("unified1") { + version = "unified1" +} +spvtools_glsl_tables("glsl1-0") { + version = "1.0" +} +spvtools_opencl_tables("opencl1-0") { + version = "1.0" +} +spvtools_language_header("unified1") { + name = "DebugInfo" +} + +spvtools_vendor_tables = [ + "spv-amd-shader-explicit-vertex-parameter", + "spv-amd-shader-trinary-minmax", + "spv-amd-gcn-shader", + "spv-amd-shader-ballot", + "debuginfo", +] + +foreach(table, spvtools_vendor_tables) { + spvtools_vendor_table(table) { + name = table + } +} + +config("spvtools_public_config") { + include_dirs = [ "include" ] +} + +config("spvtools_internal_config") { + include_dirs = [ + ".", + "$target_gen_dir", + "${spirv_headers}/include", + ] + + configs = [ ":spvtools_public_config" ] + + if (is_clang) { + cflags = [ "-Wno-implicit-fallthrough" ] + } +} + +source_set("spvtools_headers") { + sources = [ + "include/spirv-tools/libspirv.h", + "include/spirv-tools/libspirv.hpp", + "include/spirv-tools/linker.hpp", + "include/spirv-tools/optimizer.hpp", + "include/spirv-tools/instrument.hpp", + ] + + public_configs = [ 
":spvtools_public_config" ] +} + +static_library("spvtools") { + deps = [ + ":spvtools_core_enums_unified1", + ":spvtools_core_tables_unified1", + ":spvtools_generators_inc", + ":spvtools_glsl_tables_glsl1-0", + ":spvtools_language_header_unified1", + ":spvtools_opencl_tables_opencl1-0", + ] + foreach(target_name, spvtools_vendor_tables) { + deps += [ ":spvtools_vendor_tables_$target_name" ] + } + + sources = [ + "source/assembly_grammar.cpp", + "source/assembly_grammar.h", + "source/binary.cpp", + "source/binary.h", + "source/diagnostic.cpp", + "source/diagnostic.h", + "source/disassemble.cpp", + "source/enum_set.h", + "source/enum_string_mapping.cpp", + "source/ext_inst.cpp", + "source/ext_inst.h", + "source/extensions.cpp", + "source/extensions.h", + "source/instruction.h", + "source/libspirv.cpp", + "source/macro.h", + "source/name_mapper.cpp", + "source/name_mapper.h", + "source/opcode.cpp", + "source/opcode.h", + "source/operand.cpp", + "source/operand.h", + "source/parsed_operand.cpp", + "source/parsed_operand.h", + "source/print.cpp", + "source/print.h", + "source/spirv_constant.h", + "source/spirv_definition.h", + "source/spirv_endian.cpp", + "source/spirv_endian.h", + "source/spirv_optimizer_options.cpp", + "source/spirv_optimizer_options.h", + "source/spirv_target_env.cpp", + "source/spirv_target_env.h", + "source/spirv_validator_options.cpp", + "source/spirv_validator_options.h", + "source/table.cpp", + "source/table.h", + "source/text.cpp", + "source/text.h", + "source/text_handler.cpp", + "source/text_handler.h", + "source/util/bit_vector.cpp", + "source/util/bit_vector.h", + "source/util/bitutils.h", + "source/util/hex_float.h", + "source/util/ilist.h", + "source/util/ilist_node.h", + "source/util/make_unique.h", + "source/util/parse_number.cpp", + "source/util/parse_number.h", + "source/util/small_vector.h", + "source/util/string_utils.cpp", + "source/util/string_utils.h", + "source/util/timer.cpp", + "source/util/timer.h", + ] + + public_deps = [ + 
":spvtools_headers", + ] + + configs -= [ "//build/config/compiler:chromium_code" ] + configs += [ + "//build/config/compiler:no_chromium_code", + ":spvtools_internal_config", + ] +} + +static_library("spvtools_val") { + sources = [ + "source/val/basic_block.cpp", + "source/val/construct.cpp", + "source/val/function.cpp", + "source/val/instruction.cpp", + "source/val/validate.cpp", + "source/val/validate.h", + "source/val/validate_adjacency.cpp", + "source/val/validate_annotation.cpp", + "source/val/validate_arithmetics.cpp", + "source/val/validate_atomics.cpp", + "source/val/validate_barriers.cpp", + "source/val/validate_bitwise.cpp", + "source/val/validate_builtins.cpp", + "source/val/validate_capability.cpp", + "source/val/validate_cfg.cpp", + "source/val/validate_composites.cpp", + "source/val/validate_constants.cpp", + "source/val/validate_conversion.cpp", + "source/val/validate_datarules.cpp", + "source/val/validate_debug.cpp", + "source/val/validate_decorations.cpp", + "source/val/validate_derivatives.cpp", + "source/val/validate_execution_limitations.cpp", + "source/val/validate_extensions.cpp", + "source/val/validate_function.cpp", + "source/val/validate_id.cpp", + "source/val/validate_image.cpp", + "source/val/validate_instruction.cpp", + "source/val/validate_interfaces.cpp", + "source/val/validate_layout.cpp", + "source/val/validate_literals.cpp", + "source/val/validate_logicals.cpp", + "source/val/validate_memory.cpp", + "source/val/validate_memory_semantics.cpp", + "source/val/validate_mode_setting.cpp", + "source/val/validate_non_uniform.cpp", + "source/val/validate_primitives.cpp", + "source/val/validate_scopes.cpp", + "source/val/validate_type.cpp", + "source/val/validation_state.cpp", + ] + + deps = [ + ":spvtools", + ] + public_deps = [ + ":spvtools_headers", + ] + + configs -= [ "//build/config/compiler:chromium_code" ] + configs += [ + "//build/config/compiler:no_chromium_code", + ":spvtools_internal_config", + ] +} + 
+static_library("spvtools_opt") { + sources = [ + "source/opt/aggressive_dead_code_elim_pass.cpp", + "source/opt/aggressive_dead_code_elim_pass.h", + "source/opt/basic_block.cpp", + "source/opt/basic_block.h", + "source/opt/block_merge_pass.cpp", + "source/opt/block_merge_pass.h", + "source/opt/build_module.cpp", + "source/opt/build_module.h", + "source/opt/ccp_pass.cpp", + "source/opt/ccp_pass.h", + "source/opt/cfg.cpp", + "source/opt/cfg.h", + "source/opt/cfg_cleanup_pass.cpp", + "source/opt/cfg_cleanup_pass.h", + "source/opt/code_sink.cpp", + "source/opt/code_sink.h", + "source/opt/combine_access_chains.cpp", + "source/opt/combine_access_chains.h", + "source/opt/common_uniform_elim_pass.cpp", + "source/opt/common_uniform_elim_pass.h", + "source/opt/compact_ids_pass.cpp", + "source/opt/compact_ids_pass.h", + "source/opt/composite.cpp", + "source/opt/composite.h", + "source/opt/const_folding_rules.cpp", + "source/opt/const_folding_rules.h", + "source/opt/constants.cpp", + "source/opt/constants.h", + "source/opt/copy_prop_arrays.cpp", + "source/opt/copy_prop_arrays.h", + "source/opt/dead_branch_elim_pass.cpp", + "source/opt/dead_branch_elim_pass.h", + "source/opt/dead_insert_elim_pass.cpp", + "source/opt/dead_insert_elim_pass.h", + "source/opt/dead_variable_elimination.cpp", + "source/opt/dead_variable_elimination.h", + "source/opt/decoration_manager.cpp", + "source/opt/decoration_manager.h", + "source/opt/def_use_manager.cpp", + "source/opt/def_use_manager.h", + "source/opt/dominator_analysis.cpp", + "source/opt/dominator_analysis.h", + "source/opt/dominator_tree.cpp", + "source/opt/dominator_tree.h", + "source/opt/eliminate_dead_constant_pass.cpp", + "source/opt/eliminate_dead_constant_pass.h", + "source/opt/eliminate_dead_functions_pass.cpp", + "source/opt/eliminate_dead_functions_pass.h", + "source/opt/feature_manager.cpp", + "source/opt/feature_manager.h", + "source/opt/flatten_decoration_pass.cpp", + "source/opt/flatten_decoration_pass.h", + 
"source/opt/fold.cpp", + "source/opt/fold.h", + "source/opt/fold_spec_constant_op_and_composite_pass.cpp", + "source/opt/fold_spec_constant_op_and_composite_pass.h", + "source/opt/folding_rules.cpp", + "source/opt/folding_rules.h", + "source/opt/freeze_spec_constant_value_pass.cpp", + "source/opt/freeze_spec_constant_value_pass.h", + "source/opt/function.cpp", + "source/opt/function.h", + "source/opt/if_conversion.cpp", + "source/opt/if_conversion.h", + "source/opt/inline_exhaustive_pass.cpp", + "source/opt/inline_exhaustive_pass.h", + "source/opt/inline_opaque_pass.cpp", + "source/opt/inline_opaque_pass.h", + "source/opt/inline_pass.cpp", + "source/opt/inline_pass.h", + "source/opt/inst_bindless_check_pass.cpp", + "source/opt/inst_bindless_check_pass.h", + "source/opt/instruction.cpp", + "source/opt/instruction.h", + "source/opt/instruction_list.cpp", + "source/opt/instruction_list.h", + "source/opt/instrument_pass.cpp", + "source/opt/instrument_pass.h", + "source/opt/ir_builder.h", + "source/opt/ir_context.cpp", + "source/opt/ir_context.h", + "source/opt/ir_loader.cpp", + "source/opt/ir_loader.h", + "source/opt/iterator.h", + "source/opt/licm_pass.cpp", + "source/opt/licm_pass.h", + "source/opt/local_access_chain_convert_pass.cpp", + "source/opt/local_access_chain_convert_pass.h", + "source/opt/local_redundancy_elimination.cpp", + "source/opt/local_redundancy_elimination.h", + "source/opt/local_single_block_elim_pass.cpp", + "source/opt/local_single_block_elim_pass.h", + "source/opt/local_single_store_elim_pass.cpp", + "source/opt/local_single_store_elim_pass.h", + "source/opt/local_ssa_elim_pass.cpp", + "source/opt/local_ssa_elim_pass.h", + "source/opt/log.h", + "source/opt/loop_dependence.cpp", + "source/opt/loop_dependence.h", + "source/opt/loop_dependence_helpers.cpp", + "source/opt/loop_descriptor.cpp", + "source/opt/loop_descriptor.h", + "source/opt/loop_fission.cpp", + "source/opt/loop_fission.h", + "source/opt/loop_fusion.cpp", + 
"source/opt/loop_fusion.h", + "source/opt/loop_fusion_pass.cpp", + "source/opt/loop_fusion_pass.h", + "source/opt/loop_peeling.cpp", + "source/opt/loop_peeling.h", + "source/opt/loop_unroller.cpp", + "source/opt/loop_unroller.h", + "source/opt/loop_unswitch_pass.cpp", + "source/opt/loop_unswitch_pass.h", + "source/opt/loop_utils.cpp", + "source/opt/loop_utils.h", + "source/opt/mem_pass.cpp", + "source/opt/mem_pass.h", + "source/opt/merge_return_pass.cpp", + "source/opt/merge_return_pass.h", + "source/opt/module.cpp", + "source/opt/module.h", + "source/opt/null_pass.h", + "source/opt/optimizer.cpp", + "source/opt/pass.cpp", + "source/opt/pass.h", + "source/opt/pass_manager.cpp", + "source/opt/pass_manager.h", + "source/opt/passes.h", + "source/opt/private_to_local_pass.cpp", + "source/opt/private_to_local_pass.h", + "source/opt/process_lines_pass.cpp", + "source/opt/process_lines_pass.h", + "source/opt/propagator.cpp", + "source/opt/propagator.h", + "source/opt/reduce_load_size.cpp", + "source/opt/reduce_load_size.h", + "source/opt/redundancy_elimination.cpp", + "source/opt/redundancy_elimination.h", + "source/opt/reflect.h", + "source/opt/register_pressure.cpp", + "source/opt/register_pressure.h", + "source/opt/remove_duplicates_pass.cpp", + "source/opt/remove_duplicates_pass.h", + "source/opt/replace_invalid_opc.cpp", + "source/opt/replace_invalid_opc.h", + "source/opt/scalar_analysis.cpp", + "source/opt/scalar_analysis.h", + "source/opt/scalar_analysis_nodes.h", + "source/opt/scalar_analysis_simplification.cpp", + "source/opt/scalar_replacement_pass.cpp", + "source/opt/scalar_replacement_pass.h", + "source/opt/set_spec_constant_default_value_pass.cpp", + "source/opt/set_spec_constant_default_value_pass.h", + "source/opt/simplification_pass.cpp", + "source/opt/simplification_pass.h", + "source/opt/ssa_rewrite_pass.cpp", + "source/opt/ssa_rewrite_pass.h", + "source/opt/strength_reduction_pass.cpp", + "source/opt/strength_reduction_pass.h", + 
"source/opt/strip_debug_info_pass.cpp", + "source/opt/strip_debug_info_pass.h", + "source/opt/strip_reflect_info_pass.cpp", + "source/opt/strip_reflect_info_pass.h", + "source/opt/struct_cfg_analysis.cpp", + "source/opt/struct_cfg_analysis.h", + "source/opt/tree_iterator.h", + "source/opt/type_manager.cpp", + "source/opt/type_manager.h", + "source/opt/types.cpp", + "source/opt/types.h", + "source/opt/unify_const_pass.cpp", + "source/opt/unify_const_pass.h", + "source/opt/upgrade_memory_model.cpp", + "source/opt/upgrade_memory_model.h", + "source/opt/value_number_table.cpp", + "source/opt/value_number_table.h", + "source/opt/vector_dce.cpp", + "source/opt/vector_dce.h", + "source/opt/workaround1209.cpp", + "source/opt/workaround1209.h", + ] + + deps = [ + ":spvtools", + ] + public_deps = [ + ":spvtools_headers", + ] + + configs -= [ "//build/config/compiler:chromium_code" ] + configs += [ + "//build/config/compiler:no_chromium_code", + ":spvtools_internal_config", + ] +} + +group("SPIRV-Tools") { + deps = [ + ":spvtools", + ":spvtools_opt", + ":spvtools_val", + ] +} + +if (!build_with_chromium) { + googletest_dir = spirv_tools_googletest_dir + + config("gtest_config") { + include_dirs = [ + "${googletest_dir}/googletest", + "${googletest_dir}/googletest/include", + ] + } + + static_library("gtest") { + testonly = true + sources = [ + "${googletest_dir}/googletest/src/gtest-all.cc", + ] + public_configs = [ ":gtest_config" ] + } + + config("gmock_config") { + include_dirs = [ + "${googletest_dir}/googlemock", + "${googletest_dir}/googlemock/include", + "${googletest_dir}/googletest/include", + ] + if (is_clang) { + # TODO: Can remove this if/when the issue is fixed. 
+ # https://github.com/google/googletest/issues/533 + cflags = [ "-Wno-inconsistent-missing-override" ] + } + } + + static_library("gmock") { + testonly = true + sources = [ + "${googletest_dir}/googlemock/src/gmock-all.cc", + ] + public_configs = [ ":gmock_config" ] + } +} + +test("spvtools_test") { + sources = [ + "test/assembly_context_test.cpp", + "test/assembly_format_test.cpp", + "test/binary_destroy_test.cpp", + "test/binary_endianness_test.cpp", + "test/binary_header_get_test.cpp", + "test/binary_parse_test.cpp", + "test/binary_strnlen_s_test.cpp", + "test/binary_to_text.literal_test.cpp", + "test/binary_to_text_test.cpp", + "test/comment_test.cpp", + "test/enum_set_test.cpp", + "test/enum_string_mapping_test.cpp", + "test/ext_inst.debuginfo_test.cpp", + "test/ext_inst.glsl_test.cpp", + "test/ext_inst.opencl_test.cpp", + "test/fix_word_test.cpp", + "test/generator_magic_number_test.cpp", + "test/hex_float_test.cpp", + "test/immediate_int_test.cpp", + "test/libspirv_macros_test.cpp", + "test/name_mapper_test.cpp", + "test/named_id_test.cpp", + "test/opcode_make_test.cpp", + "test/opcode_require_capabilities_test.cpp", + "test/opcode_split_test.cpp", + "test/opcode_table_get_test.cpp", + "test/operand_capabilities_test.cpp", + "test/operand_pattern_test.cpp", + "test/operand_test.cpp", + "test/target_env_test.cpp", + "test/test_fixture.h", + "test/text_advance_test.cpp", + "test/text_destroy_test.cpp", + "test/text_literal_test.cpp", + "test/text_start_new_inst_test.cpp", + "test/text_to_binary.annotation_test.cpp", + "test/text_to_binary.barrier_test.cpp", + "test/text_to_binary.constant_test.cpp", + "test/text_to_binary.control_flow_test.cpp", + "test/text_to_binary.debug_test.cpp", + "test/text_to_binary.device_side_enqueue_test.cpp", + "test/text_to_binary.extension_test.cpp", + "test/text_to_binary.function_test.cpp", + "test/text_to_binary.group_test.cpp", + "test/text_to_binary.image_test.cpp", + "test/text_to_binary.literal_test.cpp", + 
"test/text_to_binary.memory_test.cpp", + "test/text_to_binary.misc_test.cpp", + "test/text_to_binary.mode_setting_test.cpp", + "test/text_to_binary.pipe_storage_test.cpp", + "test/text_to_binary.reserved_sampling_test.cpp", + "test/text_to_binary.subgroup_dispatch_test.cpp", + "test/text_to_binary.type_declaration_test.cpp", + "test/text_to_binary_test.cpp", + "test/text_word_get_test.cpp", + "test/unit_spirv.cpp", + "test/unit_spirv.h", + ] + + deps = [ + ":spvtools", + ":spvtools_language_header_unified1", + ":spvtools_val", + ] + + if (build_with_chromium) { + deps += [ + "//testing/gmock", + "//testing/gtest", + "//testing/gtest:gtest_main", + ] + } else { + deps += [ + ":gmock", + ":gtest", + ] + sources += [ "${googletest_dir}/googletest/src/gtest_main.cc" ] + } + + if (is_clang) { + cflags_cc = [ "-Wno-self-assign" ] + } + + configs += [ ":spvtools_internal_config" ] +} + +if (spirv_tools_standalone) { + group("fuzzers") { + testonly = true + deps = [ + "test/fuzzers", + ] + } +} + +executable("spirv-as") { + sources = [ + "source/software_version.cpp", + "tools/as/as.cpp", + ] + deps = [ + ":spvtools", + ":spvtools_build_version", + ] + configs += [ ":spvtools_internal_config" ] +} diff --git a/CHANGES b/CHANGES new file mode 100644 index 000000000..d4c6dffe2 --- /dev/null +++ b/CHANGES @@ -0,0 +1,741 @@ +Revision history for SPIRV-Tools + +v2019.2-dev 2019-01-07 + - Start v2019.2-dev + +v2019.1 2019-01-07 + - General: + - Created a new tool called spirv-reduce. + - Add cmake option to turn off SPIRV_TIMER_ENABLED (#2103) + - New optimization pass to update the memory model from GLSL450 to VulkanKHR. + - Recognize OpTypeAccelerationStructureNV as a type instruction and ray tracing storage classes. + - Fix GCC8 build. + - Add --target-env flag to spirv-opt. + - Add --webgpu-mode flag to run optimizations for webgpu. + - The output disassembled line number stead of byte offset in validation errors. 
(#2091) + - Optimizer + - Added the instrumentation passes for bindless validation. + - Added passes to help preserve OpLine information (#2027) + - Add basic support for EXT_fragment_invocation_density (#2100) + - Fix invalid OpPhi generated by merge-return. (#2172) + - Constant and type manager have been turned into analysies. (#2251) + Fixes: + - #2018: Don't inline functions with a return in a structured CFG contstruct. + - #2047: Fix bug in folding when volatile stores are present. + - #2053: Fix check for when folding floating pointer values is allowed. + - #2130: Don't inline recursive functions. + - #2202: Handle multiple edges between two basic blocks in SSA-rewriter. + - #2205: Don't unswitch a latch condition during loop unswitch. + - #2245: Don't fold branch in loop unswitch. Run dead branch elimination to fold them. + - #2204: Fix eliminate common uniform to place OpPhi instructions correctly. + - #2247: Fix type mismatches caused by scalar replacement. + - #2248: Fix missing OpPhi after merge return. + - #2211: After merge return, fix invalid continue target. + - #2210: Fix loop invariant code motion to not place code between merge instruction and branch. + - #2258: Handle CompositeInsert with no indices in VDCE. + - #2261: Have replace load size handle extact with no index. + - Validator + - Changed the naming convention of outputing ids with names in diagnostic messages. + - Added validation rules for UniformConstant variables in Vulkan. + - #1949: Validate uniform variable type in Vulkan + - Ensure for OpVariable that result type and storage class operand agree (#2052) + - Validator: Support VK_EXT_scalar_block_layout + - Added Vulkan memory model semantics validation + - Added validation checkes spefic to WebGPU environment. + - Add support for VK_EXT_Transform_feedback capabilities (#2088) + - Add validation for OpArrayLength. 
(#2117) + - Ensure that function parameter's type is not void (#2118) + - Validate pointer variables (#2111) + - Add check for QueueFamilyKHMR memory scope (#2144) + - Validate PushConstants annotation and type (#2140) + - Allow Float16/Int8 for Vulkan 1.0 (#2153) + - Check binding annotations in resource variables (#2151, #2167) + - Validate OpForwardPointer (#2156) + - Validate operation for OpSpecConstantOp (#2260) + Fixes: + - #2049: Allow InstanceId for NV ray tracing + - Reduce + - Initial commit wit a few passes to reduce test cases. + - Validation is run after each reduction step. + Fixes: + + +v2018.6 2018-11-07 + - General: + - Added support for the Nvidia Turing and ray tracing extensions. + - Make C++11 the CXX standard in CMakeLists.txt. + - Enabled a parallel build for MSVC. + - Enable pre-compiled headers for MSVC. + - Added a code of conduct. + - EFFCEE and RE2 are now required when build the tests. + - Optimizer + - Unrolling loops marked for unrolling in the legalization passes. + - Improved the compile time of loop unrolling. + - Changee merge-return to create a dummy loop around the function. + - Small improvement to merge-blocks to allow it to merge more often. + - Enforce an upper bound for the ids, and add option to set it. + - #1966: Report error if there are unreachable block before running merge return + Fixes: + - #1917: Allow 0 (meaning unlimited) as a parameter to --scalar-replacement + - #1915: Improve handling of group decorations. + - #1942: Fix incorrect uses of the constant manager. Avoids type mismatches in generated code. + - #1997: Fix dead branch elimination when there is a loop in folded selection. + - #1991: Fixes legality check in if-conversion. + - #1987: Add nullptr check to array copy propagation. + - #1984: Better handling of OpUnreachable in ADCE. + - #1983: Run merge return on reachable functions only. + - #1956: Handled atomic operations in ADCE. + - #1963: Fold integer divisions by 0 to 0. 
+ - #2019: Handle MemberDecorateStringGOOGLE in ADCE and strip reflect. + - Validator + - Added validation for OpGroupNonUniformBallotBitCount. + - Added validation for the Vulkan memory model. + - Added support for VK_KHR_shader_atddomic_int64. + - Added validation for execution modes. + - Added validation for runtime array layouts. + - Added validation for 8-bit storage. + - Added validation of OpPhi instructions with pointer result type. + - Added checks for the Vulkan memory model. + - Validate MakeTexelAvailableKHR and MakeTexelVisibleKHR + - Allow atomic function pointer for OpenCL. + - FPRounding mode checks were implemented. + - Added validation for the id bound with an option to set the max id bound. + Fixes: + - #1882: Improve the validation of decorations to reduce memory usage. + - #1891: Fix an potential infinite loop in dead-branch-elimination. + - #1405: Validate the storage class of boolean objects. + - #1880: Identify arrays of type void as invalid. + - #487: Validate OpImageTexelPointer. + - #1922: Validate OpPhi instructions are at the start of a block correctly. + - #1923: Validate function scope variable are at the start of the entry block. + +v2018.5 2018-09-07 + - General: + - Support SPV_KHR_vulkan_memory_model + - Update Dim capabilities, to match SPIR-V 1.3 Rev 4 + - Automated build bots no run tests for the VS2013 case + - Support Chromium GN build + - Use Kokoro bots: + - Disable Travis-CI bots + - Disable AppVeyor VisualStudio Release builds. 
Keep VS 2017 Debug build + - Don't check export symbols on OSX (Darwin): some installations don't have 'objdump' + - Reorganize source files and namespaces + - Fixes for ClangTidy, and whitespace (passes 'git cl presumit --all -uf') + - Fix unused param compile warnings/errors when Effcee not present + - Avoid including time headers when timer functionality is disabled + - Avoid too-stringent warnings flags for Clang on Windows + - Internal refactoring + - Add hooks for automated fuzzing + - Add testing of command line executables + - #1688: Use binary mode on stdin; fixes "spirv-dis . versioning, with "-dev" indicating + work in progress. The intent is to more easly report + and summarize functionality when SPIRV-Tools is incorporated + in downstream projects. + + - Summary of functionality (See the README.md for more): + - Supports SPIR-V 1.1 Rev 1 + - Supports SPIR-V 1.0 Rev 5 + - Supports GLSL std450 extended instructions 1.0 Rev 3 + - Supports OpenCL extended instructions 1.0 Rev 2 + - Assembler, disassembler are complete + - Supports floating point widths of 16, 32, 64 bits + - Supports integer widths up to 64 bits + - Validator is incomplete + - Checks capability requirements in most cases + - Checks module layout constraints + - Checks ID use-definition ordering constraints, + ignoring control flow + - Checks some control flow graph rules + - Optimizer is introduced, with few available transforms. + - Supported on Linux, OSX, Android, Windows + + - Fixes bugs: + - #143: OpenCL pow and pown arguments diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..a5ecb90a1 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,294 @@ +# Copyright (c) 2015-2016 The Khronos Group Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 2.8.12) +if (POLICY CMP0048) + cmake_policy(SET CMP0048 NEW) +endif() +if (POLICY CMP0054) + # Avoid dereferencing variables or interpret keywords that have been + # quoted or bracketed. + # https://cmake.org/cmake/help/v3.1/policy/CMP0054.html + cmake_policy(SET CMP0054 NEW) +endif() +set_property(GLOBAL PROPERTY USE_FOLDERS ON) + +project(spirv-tools) +enable_testing() +set(SPIRV_TOOLS "SPIRV-Tools") + +include(GNUInstallDirs) +include(cmake/setup_build.cmake) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_CXX_STANDARD 11) + +option(SPIRV_ALLOW_TIMERS "Allow timers via clock_gettime on supported platforms" ON) + +if("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") + add_definitions(-DSPIRV_LINUX) + set(SPIRV_TIMER_ENABLED ${SPIRV_ALLOW_TIMERS}) +elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Windows") + add_definitions(-DSPIRV_WINDOWS) +elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "CYGWIN") + add_definitions(-DSPIRV_WINDOWS) +elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin") + add_definitions(-DSPIRV_MAC) +elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Android") + add_definitions(-DSPIRV_ANDROID) + set(SPIRV_TIMER_ENABLED ${SPIRV_ALLOW_TIMERS}) +elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "FreeBSD") + add_definitions(-DSPIRV_FREEBSD) +else() + message(FATAL_ERROR "Your platform '${CMAKE_SYSTEM_NAME}' is not supported!") +endif() + +if (${SPIRV_TIMER_ENABLED}) + add_definitions(-DSPIRV_TIMER_ENABLED) +endif() + +if ("${CMAKE_BUILD_TYPE}" STREQUAL "") + message(STATUS "No build type selected, default to Debug") + set(CMAKE_BUILD_TYPE "Debug") 
+endif() + +option(SKIP_SPIRV_TOOLS_INSTALL "Skip installation" ${SKIP_SPIRV_TOOLS_INSTALL}) +if(NOT ${SKIP_SPIRV_TOOLS_INSTALL}) + set(ENABLE_SPIRV_TOOLS_INSTALL ON) +endif() + +option(SPIRV_BUILD_COMPRESSION "Build SPIR-V compressing codec" OFF) + +option(SPIRV_WERROR "Enable error on warning" ON) +if(("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR (("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") AND (NOT CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC"))) + set(COMPILER_IS_LIKE_GNU TRUE) +endif() +if(${COMPILER_IS_LIKE_GNU}) + set(SPIRV_WARNINGS -Wall -Wextra -Wnon-virtual-dtor -Wno-missing-field-initializers) + + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + set(SPIRV_WARNINGS ${SPIRV_WARNINGS} -Wno-self-assign) + endif() + + option(SPIRV_WARN_EVERYTHING "Enable -Weverything" ${SPIRV_WARN_EVERYTHING}) + if(${SPIRV_WARN_EVERYTHING}) + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(SPIRV_WARNINGS ${SPIRV_WARNINGS} + -Weverything -Wno-c++98-compat -Wno-c++98-compat-pedantic -Wno-padded) + elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(SPIRV_WARNINGS ${SPIRV_WARNINGS} -Wpedantic -pedantic-errors) + else() + message(STATUS "Unknown compiler ${CMAKE_CXX_COMPILER_ID}, " + "so SPIRV_WARN_EVERYTHING has no effect") + endif() + endif() + + if(${SPIRV_WERROR}) + set(SPIRV_WARNINGS ${SPIRV_WARNINGS} -Werror) + endif() +elseif(MSVC) + set(SPIRV_WARNINGS -D_CRT_SECURE_NO_WARNINGS -D_SCL_SECURE_NO_WARNINGS /wd4800) + + if(${SPIRV_WERROR}) + set(SPIRV_WARNINGS ${SPIRV_WARNINGS} /WX) + endif() +endif() + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/) + +option(SPIRV_COLOR_TERMINAL "Enable color terminal output" ON) +if(${SPIRV_COLOR_TERMINAL}) + add_definitions(-DSPIRV_COLOR_TERMINAL) +endif() + +option(SPIRV_LOG_DEBUG "Enable excessive debug output" OFF) +if(${SPIRV_LOG_DEBUG}) + add_definitions(-DSPIRV_LOG_DEBUG) +endif() + +if (DEFINED SPIRV_TOOLS_EXTRA_DEFINITIONS) + add_definitions(${SPIRV_TOOLS_EXTRA_DEFINITIONS}) +endif() + 
+function(spvtools_default_compile_options TARGET) + target_compile_options(${TARGET} PRIVATE ${SPIRV_WARNINGS}) + + if (${COMPILER_IS_LIKE_GNU}) + target_compile_options(${TARGET} PRIVATE + -std=c++11 -fno-exceptions -fno-rtti) + target_compile_options(${TARGET} PRIVATE + -Wall -Wextra -Wno-long-long -Wshadow -Wundef -Wconversion + -Wno-sign-conversion) + # For good call stacks in profiles, keep the frame pointers. + if(NOT "${SPIRV_PERF}" STREQUAL "") + target_compile_options(${TARGET} PRIVATE -fno-omit-frame-pointer) + endif() + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(SPIRV_USE_SANITIZER "" CACHE STRING + "Use the clang sanitizer [address|memory|thread|...]") + if(NOT "${SPIRV_USE_SANITIZER}" STREQUAL "") + target_compile_options(${TARGET} PRIVATE + -fsanitize=${SPIRV_USE_SANITIZER}) + endif() + target_compile_options(${TARGET} PRIVATE + -ftemplate-depth=1024) + else() + target_compile_options(${TARGET} PRIVATE + -Wno-missing-field-initializers) + endif() + endif() + + if (MSVC) + # Specify /EHs for exception handling. This makes using SPIRV-Tools as + # dependencies in other projects easier. + target_compile_options(${TARGET} PRIVATE /EHs) + endif() + + # For MinGW cross compile, statically link to the C++ runtime. + # But it still depends on MSVCRT.dll. + if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + if (${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") + set_target_properties(${TARGET} PROPERTIES + LINK_FLAGS -static -static-libgcc -static-libstdc++) + endif() + endif() +endfunction() + +if(NOT COMMAND find_host_package) + macro(find_host_package) + find_package(${ARGN}) + endmacro() +endif() +if(NOT COMMAND find_host_program) + macro(find_host_program) + find_program(${ARGN}) + endmacro() +endif() + +find_host_package(PythonInterp) + +# Check for symbol exports on Linux. +# At the moment, this check will fail on the OSX build machines for the Android NDK. +# It appears they don't have objdump. 
+if("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux")
+  macro(spvtools_check_symbol_exports TARGET)  # Adds a CTest case that runs utils/check_symbol_exports.py on TARGET's built file.
+    if (NOT "${SPIRV_SKIP_TESTS}")
+      add_test(NAME spirv-tools-symbol-exports-${TARGET}
+               COMMAND ${PYTHON_EXECUTABLE}
+               ${spirv-tools_SOURCE_DIR}/utils/check_symbol_exports.py "$<TARGET_FILE:${TARGET}>")  # Generator expression: expands to the full path of TARGET's output file.
+    endif()
+  endmacro()
+else()
+  macro(spvtools_check_symbol_exports TARGET)  # Same macro name on other hosts; it only reports that the check is skipped.
+    if (NOT "${SPIRV_SKIP_TESTS}")
+      message("Skipping symbol exports test for ${TARGET}")
+    endif()
+  endmacro()
+endif()
+
+# Defaults to OFF if the user didn't set it.
+option(SPIRV_SKIP_EXECUTABLES
+  "Skip building the executable and tests along with the library"
+  ${SPIRV_SKIP_EXECUTABLES})
+option(SPIRV_SKIP_TESTS
+  "Skip building tests along with the library" ${SPIRV_SKIP_TESTS})
+if ("${SPIRV_SKIP_EXECUTABLES}")
+  set(SPIRV_SKIP_TESTS ON)  # Skipping executables also skips tests, as the option text above states.
+endif()
+
+# Defaults to ON. The checks can be time consuming.
+# Turn off if they take too long.
+option(SPIRV_CHECK_CONTEXT "In a debug build, check if the IR context is in a valid state." ON)
+if (${SPIRV_CHECK_CONTEXT})
+  add_definitions(-DSPIRV_CHECK_CONTEXT)
+endif()
+
+# Precompiled header macro. Parameters are source file list and filename for pch cpp file.
+macro(spvtools_pch SRCS PCHPREFIX) + if(MSVC AND CMAKE_GENERATOR MATCHES "^Visual Studio") + set(PCH_NAME "$(IntDir)\\${PCHPREFIX}.pch") + # make source files use/depend on PCH_NAME + set_source_files_properties(${${SRCS}} PROPERTIES COMPILE_FLAGS "/Yu${PCHPREFIX}.h /FI${PCHPREFIX}.h /Fp${PCH_NAME} /Zm300" OBJECT_DEPENDS "${PCH_NAME}") + # make PCHPREFIX.cpp file compile and generate PCH_NAME + set_source_files_properties("${PCHPREFIX}.cpp" PROPERTIES COMPILE_FLAGS "/Yc${PCHPREFIX}.h /Fp${PCH_NAME} /Zm300" OBJECT_OUTPUTS "${PCH_NAME}") + list(APPEND ${SRCS} "${PCHPREFIX}.cpp") + endif() +endmacro(spvtools_pch) + +add_subdirectory(external) + +add_subdirectory(source) +add_subdirectory(tools) + +add_subdirectory(test) +add_subdirectory(examples) + +if(ENABLE_SPIRV_TOOLS_INSTALL) + install( + FILES + ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/libspirv.h + ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/libspirv.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/optimizer.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/linker.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/instrument.hpp + DESTINATION + ${CMAKE_INSTALL_INCLUDEDIR}/spirv-tools/) +endif(ENABLE_SPIRV_TOOLS_INSTALL) + +if (NOT "${SPIRV_SKIP_TESTS}") + add_test(NAME spirv-tools-copyrights + COMMAND ${PYTHON_EXECUTABLE} utils/check_copyright.py + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +endif() + +set(SPIRV_LIBRARIES "-lSPIRV-Tools -lSPIRV-Tools-link -lSPIRV-Tools-opt") +set(SPIRV_SHARED_LIBRARIES "-lSPIRV-Tools-shared") +if(SPIRV_BUILD_COMPRESSION) + set(SPIRV_LIBRARIES "${SPIRV_LIBRARIES} -lSPIRV-Tools-comp") +endif(SPIRV_BUILD_COMPRESSION) + +# Build pkg-config file +# Use a first-class target so it's regenerated when relevant files are updated. 
+add_custom_target(spirv-tools-pkg-config ALL + COMMAND ${CMAKE_COMMAND} + -DCHANGES_FILE=${CMAKE_CURRENT_SOURCE_DIR}/CHANGES + -DTEMPLATE_FILE=${CMAKE_CURRENT_SOURCE_DIR}/cmake/SPIRV-Tools.pc.in + -DOUT_FILE=${CMAKE_CURRENT_BINARY_DIR}/SPIRV-Tools.pc + -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX} + -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} + -DCMAKE_INSTALL_INCLUDEDIR=${CMAKE_INSTALL_INCLUDEDIR} + -DSPIRV_LIBRARIES=${SPIRV_LIBRARIES} + -P ${CMAKE_CURRENT_SOURCE_DIR}/cmake/write_pkg_config.cmake + DEPENDS "CHANGES" "cmake/SPIRV-Tools.pc.in" "cmake/write_pkg_config.cmake") +add_custom_target(spirv-tools-shared-pkg-config ALL + COMMAND ${CMAKE_COMMAND} + -DCHANGES_FILE=${CMAKE_CURRENT_SOURCE_DIR}/CHANGES + -DTEMPLATE_FILE=${CMAKE_CURRENT_SOURCE_DIR}/cmake/SPIRV-Tools-shared.pc.in + -DOUT_FILE=${CMAKE_CURRENT_BINARY_DIR}/SPIRV-Tools-shared.pc + -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX} + -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} + -DCMAKE_INSTALL_INCLUDEDIR=${CMAKE_INSTALL_INCLUDEDIR} + -DSPIRV_SHARED_LIBRARIES=${SPIRV_SHARED_LIBRARIES} + -P ${CMAKE_CURRENT_SOURCE_DIR}/cmake/write_pkg_config.cmake + DEPENDS "CHANGES" "cmake/SPIRV-Tools-shared.pc.in" "cmake/write_pkg_config.cmake") + +# Install pkg-config file +if (ENABLE_SPIRV_TOOLS_INSTALL) + install( + FILES + ${CMAKE_CURRENT_BINARY_DIR}/SPIRV-Tools.pc + ${CMAKE_CURRENT_BINARY_DIR}/SPIRV-Tools-shared.pc + DESTINATION + ${CMAKE_INSTALL_LIBDIR}/pkgconfig) +endif() diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..a11610bd3 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1 @@ +A reminder that this issue tracker is managed by the Khronos Group. Interactions here should follow the Khronos Code of Conduct (https://www.khronos.org/developers/code-of-conduct), which prohibits aggressive or derogatory language. Please keep the discussion friendly and civil. 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..93a5610ee --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,192 @@ +# Contributing to SPIR-V Tools + +## For users: Reporting bugs and requesting features + +We organize known future work in GitHub projects. See [Tracking SPIRV-Tools work +with GitHub +projects](https://github.com/KhronosGroup/SPIRV-Tools/blob/master/projects.md) +for more. + +To report a new bug or request a new feature, please file a GitHub issue. Please +ensure the bug has not already been reported by searching +[issues](https://github.com/KhronosGroup/SPIRV-Tools/issues) and +[projects](https://github.com/KhronosGroup/SPIRV-Tools/projects). If the bug has +not already been reported open a new one +[here](https://github.com/KhronosGroup/SPIRV-Tools/issues/new). + +When opening a new issue for a bug, make sure you provide the following: + +* A clear and descriptive title. + * We want a title that will make it easy for people to remember what the + issue is about. Simply using "Segfault in spirv-opt" is not helpful + because there could be (but hopefully aren't) multiple bugs with + segmentation faults with different causes. +* A test case that exposes the bug, with the steps and commands to reproduce + it. + * The easier it is for a developer to reproduce the problem, the quicker a + fix can be found and verified. It will also make it easier for someone + to possibly realize the bug is related to another issue. + +For feature requests, we use +[issues](https://github.com/KhronosGroup/SPIRV-Tools/issues) as well. Please +create a new issue, as with bugs. In the issue provide + +* A description of the problem that needs to be solved. +* Examples that demonstrate the problem. + +## For developers: Contributing a patch + +Before we can use your code, you must sign the [Khronos Open Source Contributor +License Agreement](https://cla-assistant.io/KhronosGroup/SPIRV-Tools) (CLA), +which you can do online. 
The CLA is necessary mainly because you own the +copyright to your changes, even after your contribution becomes part of our +codebase, so we need your permission to use and distribute your code. We also +need to be sure of various other things -- for instance that you'll tell us if +you know that your code infringes on other people's patents. You don't have to +sign the CLA until after you've submitted your code for review and a member has +approved it, but you must do it before we can put your code into our codebase. + +See +[README.md](https://github.com/KhronosGroup/SPIRV-Tools/blob/master/README.md) +for instructions on how to get, build, and test the source. Once you have made +your changes: + +* Ensure the code follows the [Google C++ Style + Guide](https://google.github.io/styleguide/cppguide.html). Running + `clang-format -style=file -i [modified-files]` can help. +* Create a pull request (PR) with your patch. +* Make sure the PR description clearly identifies the problem, explains the + solution, and references the issue if applicable. +* If your patch completely fixes bug 1234, the commit message should say + `Fixes https://github.com/KhronosGroup/SPIRV-Tools/issues/1234` + When you do this, the issue will be closed automatically when the commit + goes into master. Also, this helps us update the [CHANGES](CHANGES) file. +* Watch the continuous builds to make sure they pass. +* Request a code review. + +The reviewer can either approve your PR or request changes. If changes are +requested: + +* Please add new commits to your branch, instead of amending your commit. + Adding new commits makes it easier for the reviewer to see what has changed + since the last review. +* Once you are ready for another round of reviews, add a comment at the + bottom, such as "Ready for review" or "Please take a look" (or "PTAL"). This + explicit handoff is useful when responding with multiple small commits.
+ +After the PR has been reviewed it is the job of the reviewer to merge the PR. +Instructions for this are given below. + +## For maintainers: Reviewing a PR + +The formal code reviews are done on GitHub. Reviewers are to look for all of the +usual things: + +* Coding style follows the [Google C++ Style + Guide](https://google.github.io/styleguide/cppguide.html) +* Identify potential functional problems. +* Identify code duplication. +* Ensure the unit tests have enough coverage. +* Ensure continuous integration (CI) bots run on the PR. If not run (in the + case of PRs by external contributors), add the "kokoro:run" label to the + pull request which will trigger running all CI jobs. + +When looking for functional problems, there are some common problems reviewers +should pay particular attention to: + +* Does the code work for both Shader (Vulkan and OpenGL) and Kernel (OpenCL) + scenarios? The respective SPIR-V dialects are slightly different. +* Changes are made to a container while iterating through it. You have to be + careful that iterators are not invalidated or that elements are not skipped. +* C++11 and VS2013. We generally assume that we have a C++11 compliant + compiler. However, on Windows, we still support Visual Studio 2013, which is + not fully C++11 compliant. See + [here](https://msdn.microsoft.com/en-us/library/hh567368.aspx). In + particular, note that it does not provide default move-constructors or + move-assignments for classes. In general, r-value references do not work the + way you might assume they do. +* For SPIR-V transforms: The module is changed, but the analyses are not + updated. For example, a new instruction is added, but the def-use manager is + not updated. Later on, it is possible that the def-use manager will be used, + and give wrong results. + +## For maintainers: Merging a PR + +We intend to maintain a linear history on the GitHub master branch, and the +build and its tests should pass at each commit in that history. 
A linear +always-working history is easier to understand and to bisect in case we want to +find which commit introduced a bug. + +### Initial merge setup + +The following steps should be done exactly once (when you are about to merge a +PR for the first time): + +* It is assumed that upstream points to + [git@github.com](mailto:git@github.com):KhronosGroup/SPIRV-Tools.git or + https://github.com/KhronosGroup/SPIRV-Tools.git. + +* Find out the local name for the main github repo in your git configuration. + For example, in this configuration, it is labeled `upstream`. + + ``` + git remote -v + [ ... ] + upstream https://github.com/KhronosGroup/SPIRV-Tools.git (fetch) + upstream https://github.com/KhronosGroup/SPIRV-Tools.git (push) + ``` + +* Make sure that the `upstream` remote is set to fetch from the `refs/pull` + namespace: + + ``` + git config --get-all remote.upstream.fetch + +refs/heads/*:refs/remotes/upstream/* + +refs/pull/*/head:refs/remotes/upstream/pr/* + ``` + +* If the line `+refs/pull/*/head:refs/remotes/upstream/pr/*` is not present in + your configuration, you can add it with the command: + + ``` + git config --local --add remote.upstream.fetch '+refs/pull/*/head:refs/remotes/upstream/pr/*' + ``` + +### Merge workflow + +The following steps should be done for every PR that you intend to merge: + +* Make sure your local copy of the master branch is up to date: + + ``` + git checkout master + git pull + ``` + +* Fetch all pull requests refs: + + ``` + git fetch upstream + ``` + +* Checkout the particular pull request you are going to review: + + ``` + git checkout pr/1048 + ``` + +* Rebase the PR on top of the master branch. If there are conflicts, send it + back to the author and ask them to rebase. During the interactive rebase be + sure to squash all of the commits down to a single commit. 
+ + ``` + git rebase -i master + ``` + +* **Build and test the PR.** + +* If all of the tests pass, push the commit `git push upstream HEAD:master` + +* Close the PR and add a comment saying it was push using the commit that you + just pushed. See https://github.com/KhronosGroup/SPIRV-Tools/pull/935 as an + example. diff --git a/DEPS b/DEPS new file mode 100644 index 000000000..5668c6602 --- /dev/null +++ b/DEPS @@ -0,0 +1,174 @@ +use_relative_paths = True + +vars = { + 'chromium_git': 'https://chromium.googlesource.com', + 'github': 'https://github.com', + + 'build_revision': '037f38ae0fe5e11b4f7c33b750fd7a1e9634a606', + 'buildtools_revision': 'ab7b6a7b350dd15804c87c20ce78982811fdd76f', + 'clang_revision': 'abe5e4f9dc0f1df848c7a0efa05256253e77a7b7', + 'effcee_revision': '04b624799f5a9dbaf3fa1dbed2ba9dce2fc8dcf2', + 'googletest_revision': '98a0d007d7092b72eea0e501bb9ad17908a1a036', + 'testing_revision': '340252637e2e7c72c0901dcbeeacfff419e19b59', + 're2_revision': '6cf8ccd82dbaab2668e9b13596c68183c9ecd13f', + 'spirv_headers_revision': '79b6681aadcb53c27d1052e5f8a0e82a981dbf2f', +} + +deps = { + "build": + Var('chromium_git') + "/chromium/src/build.git@" + Var('build_revision'), + + 'buildtools': + Var('chromium_git') + '/chromium/buildtools.git@' + + Var('buildtools_revision'), + + 'external/spirv-headers': + Var('github') + '/KhronosGroup/SPIRV-Headers.git@' + + Var('spirv_headers_revision'), + + 'external/googletest': + Var('github') + '/google/googletest.git@' + Var('googletest_revision'), + + 'external/effcee': + Var('github') + '/google/effcee.git@' + Var('effcee_revision'), + + 'external/re2': + Var('github') + '/google/re2.git@' + Var('re2_revision'), + + 'testing': + Var('chromium_git') + '/chromium/src/testing@' + + Var('testing_revision'), + + 'tools/clang': + Var('chromium_git') + '/chromium/src/tools/clang@' + Var('clang_revision') +} + +recursedeps = [ + # buildtools provides clang_format, libc++, and libc++api + 'buildtools', +] + +hooks = [ + { + 
'name': 'gn_win', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=win32', + '--no_auth', + '--bucket', 'chromium-gn', + '-s', 'SPIRV-Tools/buildtools/win/gn.exe.sha1', + ], + }, + { + 'name': 'gn_mac', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=darwin', + '--no_auth', + '--bucket', 'chromium-gn', + '-s', 'SPIRV-Tools/buildtools/mac/gn.sha1', + ], + }, + { + 'name': 'gn_linux64', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=linux*', + '--no_auth', + '--bucket', 'chromium-gn', + '-s', 'SPIRV-Tools/buildtools/linux64/gn.sha1', + ], + }, + # Pull clang-format binaries using checked-in hashes. + { + 'name': 'clang_format_win', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=win32', + '--no_auth', + '--bucket', 'chromium-clang-format', + '-s', 'SPIRV-Tools/buildtools/win/clang-format.exe.sha1', + ], + }, + { + 'name': 'clang_format_mac', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=darwin', + '--no_auth', + '--bucket', 'chromium-clang-format', + '-s', 'SPIRV-Tools/buildtools/mac/clang-format.sha1', + ], + }, + { + 'name': 'clang_format_linux', + 'pattern': '.', + 'action': [ 'download_from_google_storage', + '--no_resume', + '--platform=linux*', + '--no_auth', + '--bucket', 'chromium-clang-format', + '-s', 'SPIRV-Tools/buildtools/linux64/clang-format.sha1', + ], + }, + { + # Pull clang + 'name': 'clang', + 'pattern': '.', + 'action': ['python', + 'SPIRV-Tools/tools/clang/scripts/update.py' + ], + }, + { + 'name': 'sysroot_arm', + 'pattern': '.', + 'condition': 'checkout_linux and checkout_arm', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=arm'], + }, + { + 'name': 'sysroot_arm64', + 'pattern': '.', + 'condition': 'checkout_linux and checkout_arm64', + 'action': ['python', 
'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=arm64'], + }, + { + 'name': 'sysroot_x86', + 'pattern': '.', + 'condition': 'checkout_linux and (checkout_x86 or checkout_x64)', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=x86'], + }, + { + 'name': 'sysroot_mips', + 'pattern': '.', + 'condition': 'checkout_linux and checkout_mips', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=mips'], + }, + { + 'name': 'sysroot_x64', + 'pattern': '.', + 'condition': 'checkout_linux and checkout_x64', + 'action': ['python', 'SPIRV-Tools/build/linux/sysroot_scripts/install-sysroot.py', + '--arch=x64'], + }, + { + # Update the Windows toolchain if necessary. + 'name': 'win_toolchain', + 'pattern': '.', + 'condition': 'checkout_win', + 'action': ['python', 'SPIRV-Tools/build/vs_toolchain.py', 'update', '--force'], + }, + { + # Update the Mac toolchain if necessary. + 'name': 'mac_toolchain', + 'pattern': '.', + 'action': ['python', 'SPIRV-Tools/build/mac_toolchain.py'], + }, +] diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/PRESUBMIT.py b/PRESUBMIT.py new file mode 100644 index 000000000..dd3117f22 --- /dev/null +++ b/PRESUBMIT.py @@ -0,0 +1,40 @@ +# Copyright (c) 2018 The Khronos Group Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Presubmit script for SPIRV-Tools. + +See http://dev.chromium.org/developers/how-tos/depottools/presubmit-scripts +for more details about the presubmit API built into depot_tools. 
+""" + +LINT_FILTERS = [ + "-build/storage_class", + "-readability/casting", + "-readability/fn_size", + "-readability/todo", + "-runtime/explicit", + "-runtime/int", + "-runtime/printf", + "-runtime/references", + "-runtime/string", +] + + +def CheckChangeOnUpload(input_api, output_api): + results = [] + results += input_api.canned_checks.CheckPatchFormatted(input_api, output_api) + results += input_api.canned_checks.CheckChangeLintsClean( + input_api, output_api, None, LINT_FILTERS) + + return results diff --git a/README.md b/README.md new file mode 100644 index 000000000..534664f77 --- /dev/null +++ b/README.md @@ -0,0 +1,535 @@ +# SPIR-V Tools + +[![Build status](https://ci.appveyor.com/api/projects/status/gpue87cesrx3pi0d/branch/master?svg=true)](https://ci.appveyor.com/project/Khronoswebmaster/spirv-tools/branch/master) +Linux![Linux Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_linux_release.svg) +MacOS![MacOS Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_macos_release.svg) +Windows![Windows Build Status](https://storage.googleapis.com/spirv-tools/badges/build_status_windows_release.svg) + +## Overview + +The SPIR-V Tools project provides an API and commands for processing SPIR-V +modules. + +The project includes an assembler, binary module parser, disassembler, +validator, and optimizer for SPIR-V. Except for the optimizer, all are based +on a common static library. The library contains all of the implementation +details, and is used in the standalone tools whilst also enabling integration +into other code bases directly. The optimizer implementation resides in its +own library, which depends on the core library. + +The interfaces have stabilized: +We don't anticipate making a breaking change for existing features. + +SPIR-V is defined by the Khronos Group Inc. +See the [SPIR-V Registry][spirv-registry] for the SPIR-V specification, +headers, and XML registry. 
+ +## Versioning SPIRV-Tools + +See [`CHANGES`](CHANGES) for a high level summary of recent changes, by version. + +SPIRV-Tools project version numbers are of the form `v`*year*`.`*index* and with +an optional `-dev` suffix to indicate work in progress. For example, the +following versions are ordered from oldest to newest: + +* `v2016.0` +* `v2016.1-dev` +* `v2016.1` +* `v2016.2-dev` +* `v2016.2` + +Use the `--version` option on each command line tool to see the software +version. An API call reports the software version as a C-style string. + +## Supported features + +### Assembler, binary parser, and disassembler + +* Support for SPIR-V 1.0, 1.1, 1.2, and 1.3 + * Based on SPIR-V syntax described by JSON grammar files in the + [SPIRV-Headers][spirv-headers] repository. +* Support for extended instruction sets: + * GLSL std450 version 1.0 Rev 3 + * OpenCL version 1.0 Rev 2 +* Assembler only does basic syntax checking. No cross validation of + IDs or types is performed, except to check literal arguments to + `OpConstant`, `OpSpecConstant`, and `OpSwitch`. + +See [`syntax.md`](syntax.md) for the assembly language syntax. + +### Validator + +The validator checks validation rules described by the SPIR-V specification. + +Khronos recommends that tools that create or transform SPIR-V modules use the +validator to ensure their outputs are valid, and that tools that consume SPIR-V +modules optionally use the validator to protect themselves from bad inputs. +This is especially encouraged for debug and development scenarios. + +The validator has one-sided error: it will only return an error when it has +implemented a rule check and the module violates that rule. + +The validator is incomplete. +See the [CHANGES](CHANGES) file for reports on completed work, and +the [Validator +sub-project](https://github.com/KhronosGroup/SPIRV-Tools/projects/1) for planned +and in-progress work. + +*Note*: The validator checks some Universal Limits, from section 2.17 of the SPIR-V spec.
+The validator will fail on a module that exceeds those minimum upper bound limits. +It is [future work](https://github.com/KhronosGroup/SPIRV-Tools/projects/1#card-1052403) +to parameterize the validator to allow larger +limits accepted by a more than minimally capable SPIR-V consumer. + + +### Optimizer + +*Note:* The optimizer is still under development. + +Currently supported optimizations: +* General + * Strip debug info +* Specialization Constants + * Set spec constant default value + * Freeze spec constant + * Fold `OpSpecConstantOp` and `OpSpecConstantComposite` + * Unify constants + * Eliminate dead constant +* Code Reduction + * Inline all function calls exhaustively + * Convert local access chains to inserts/extracts + * Eliminate local load/store in single block + * Eliminate local load/store with single store + * Eliminate local load/store with multiple stores + * Eliminate local extract from insert + * Eliminate dead instructions (aggressive) + * Eliminate dead branches + * Merge single successor / single predecessor block pairs + * Eliminate common uniform loads + * Remove duplicates: Capabilities, extended instruction imports, types, and + decorations. + +For the latest list with detailed documentation, please refer to +[`include/spirv-tools/optimizer.hpp`](include/spirv-tools/optimizer.hpp). + +For suggestions on using the code reduction options, please refer to this [white paper](https://www.lunarg.com/shader-compiler-technologies/white-paper-spirv-opt/). + + +### Linker + +*Note:* The linker is still under development. + +Current features: +* Combine multiple SPIR-V binary modules together. +* Combine into a library (exports are retained) or an executable (no symbols + are exported). + +See the [CHANGES](CHANGES) file for reports on completed work, and the [General +sub-project](https://github.com/KhronosGroup/SPIRV-Tools/projects/2) for +planned and in-progress work. 
+ +### Extras + +* [Utility filters](#utility-filters) +* Build target `spirv-tools-vimsyntax` generates file `spvasm.vim`. + Copy that file into your `$HOME/.vim/syntax` directory to get SPIR-V assembly syntax + highlighting in Vim. This build target is not built by default. + +## Contributing + +The SPIR-V Tools project is maintained by members of The Khronos Group Inc., +and is hosted at https://github.com/KhronosGroup/SPIRV-Tools. + +Consider joining the `public_spirv_tools_dev@khronos.org` mailing list, via +[https://www.khronos.org/spir/spirv-tools-mailing-list/](https://www.khronos.org/spir/spirv-tools-mailing-list/). +The mailing list is used to discuss development plans for SPIRV-Tools as an open source project. +Once discussion is resolved, +specific work is tracked via issues and sometimes in one of the +[projects][spirv-tools-projects]. + +(To provide feedback on the SPIR-V _specification_, file an issue on the +[SPIRV-Headers][spirv-headers] GitHub repository.) + +See [`projects.md`](projects.md) to see how we use the +[GitHub Project +feature](https://help.github.com/articles/tracking-the-progress-of-your-work-with-projects/) +to organize planned and in-progress work. + +Contributions via merge request are welcome. Changes should: +* Be provided under the [Apache 2.0](#license). +* You'll be prompted with a one-time "click-through" + [Khronos Open Source Contributor License Agreement][spirv-tools-cla] + (CLA) dialog as part of submitting your pull request or + other contribution to GitHub. +* Include tests to cover updated functionality. +* C++ code should follow the [Google C++ Style Guide][cpp-style-guide]. +* Code should be formatted with `clang-format`. + [kokoro/check-format/build.sh](kokoro/check-format/build.sh) + shows how to download it. Note that we currently use + `clang-format version 5.0.0` for SPIRV-Tools. Settings are defined by + the included [.clang-format](.clang-format) file.
+ +We intend to maintain a linear history on the GitHub `master` branch. + +### Source code organization + +* `example`: demo code of using SPIRV-Tools APIs +* `external/googletest`: Intended location for the + [googletest][googletest] sources, not provided +* `external/effcee`: Location of [Effcee][effcee] sources, if the `effcee` library + is not already configured by an enclosing project. +* `external/re2`: Location of [RE2][re2] sources, if the `re2` library is not already + configured by an enclosing project. + (The Effcee project already requires RE2.) +* `include/`: API clients should add this directory to the include search path +* `external/spirv-headers`: Intended location for + [SPIR-V headers][spirv-headers], not provided +* `include/spirv-tools/libspirv.h`: C API public interface +* `source/`: API implementation +* `test/`: Tests, using the [googletest][googletest] framework +* `tools/`: Command line executables + +Example of getting sources, assuming SPIRV-Tools is configured as a standalone project: + + git clone https://github.com/KhronosGroup/SPIRV-Tools.git spirv-tools + git clone https://github.com/KhronosGroup/SPIRV-Headers.git spirv-tools/external/spirv-headers + git clone https://github.com/google/googletest.git spirv-tools/external/googletest + git clone https://github.com/google/effcee.git spirv-tools/external/effcee + git clone https://github.com/google/re2.git spirv-tools/external/re2 + +### Tests + +The project contains a number of tests, used to drive development +and ensure correctness. The tests are written using the +[googletest][googletest] framework. The `googletest` +source is not provided with this project. There are two ways to enable +tests: +* If SPIR-V Tools is configured as part of an enclosing project, then the + enclosing project should configure `googletest` before configuring SPIR-V Tools. 
+* If SPIR-V Tools is configured as a standalone project, then download the + `googletest` source into the `<spirv-dir>/external/googletest` directory before + configuring and building the project. + +*Note*: You must use a version of googletest that includes +[a fix][googletest-pull-612] for [googletest issue 610][googletest-issue-610]. +The fix is included on the googletest master branch any time after 2015-11-10. +In particular, googletest must be newer than version 1.7.0. + +### Dependency on Effcee + +Some tests depend on the [Effcee][effcee] library for stateful matching. +Effcee itself depends on [RE2][re2]. + +* If SPIRV-Tools is configured as part of a larger project that already uses + Effcee, then that project should include Effcee before SPIRV-Tools. +* Otherwise, SPIRV-Tools expects Effcee sources to appear in `external/effcee` + and RE2 sources to appear in `external/re2`. + + +## Build + +Instead of building manually, you can also download the binaries for your +platform directly from the [master-tot release][master-tot-release] on GitHub. +Those binaries are automatically uploaded by the buildbots after successful +testing and they always reflect the current top of the tree of the master +branch. + +The project uses [CMake][cmake] to generate platform-specific build +configurations. Assume that `<spirv-dir>` is the root directory of the checked +out code: + +```sh +cd <spirv-dir> +git clone https://github.com/KhronosGroup/SPIRV-Headers.git external/spirv-headers +git clone https://github.com/google/effcee.git external/effcee +git clone https://github.com/google/re2.git external/re2 +git clone https://github.com/google/googletest.git external/googletest # optional + +mkdir build && cd build +cmake [-G <platform-generator>] <spirv-dir> +``` + +Once the build files have been generated, build using your preferred +development environment. + +### CMake options + +The following CMake options are supported: + +* `SPIRV_COLOR_TERMINAL={ON|OFF}`, default `ON` - Enables color console output.
+* `SPIRV_SKIP_TESTS={ON|OFF}`, default `OFF`- Build only the library and + the command line tools. This will prevent the tests from being built. +* `SPIRV_SKIP_EXECUTABLES={ON|OFF}`, default `OFF`- Build only the library, not + the command line tools and tests. +* `SPIRV_BUILD_COMPRESSION={ON|OFF}`, default `OFF`- Build SPIR-V compressing + codec. +* `SPIRV_USE_SANITIZER=`, default is no sanitizing - On UNIX + platforms with an appropriate version of `clang` this option enables the use + of the sanitizers documented [here][clang-sanitizers]. + This should only be used with a debug build. +* `SPIRV_WARN_EVERYTHING={ON|OFF}`, default `OFF` - On UNIX platforms enable + more strict warnings. The code might not compile with this option enabled. + For Clang, enables `-Weverything`. For GCC, enables `-Wpedantic`. + See [`CMakeLists.txt`](CMakeLists.txt) for details. +* `SPIRV_WERROR={ON|OFF}`, default `ON` - Forces a compilation error on any + warnings encountered by enabling the compiler-specific compiler front-end + option. No compiler front-end options are enabled when this option is OFF. + +Additionally, you can pass additional C preprocessor definitions to SPIRV-Tools +via setting `SPIRV_TOOLS_EXTRA_DEFINITIONS`. For example, by setting it to +`/D_ITERATOR_DEBUG_LEVEL=0` on Windows, you can disable checked iterators and +iterator debugging. + +### Android + +SPIR-V Tools supports building static libraries `libSPIRV-Tools.a` and +`libSPIRV-Tools-opt.a` for Android: + +``` +cd + +export ANDROID_NDK=/path/to/your/ndk + +mkdir build && cd build +mkdir libs +mkdir app + +$ANDROID_NDK/ndk-build -C ../android_test \ + NDK_PROJECT_PATH=. \ + NDK_LIBS_OUT=`pwd`/libs \ + NDK_APP_OUT=`pwd`/app +``` + +## Library + +### Usage + +The internals of the library use C++11 features, and are exposed via both a C +and C++ API. 
+ +In order to use the library from an application, the include path should point +to `/include`, which will enable the application to include the +header `/include/spirv-tools/libspirv.h{|pp}` then linking against +the static library in `/source/libSPIRV-Tools.a` or +`/source/SPIRV-Tools.lib`. +For optimization, the header file is +`/include/spirv-tools/optimizer.hpp`, and the static library is +`/source/libSPIRV-Tools-opt.a` or +`/source/SPIRV-Tools-opt.lib`. + +* `SPIRV-Tools` CMake target: Creates the static library: + * `/source/libSPIRV-Tools.a` on Linux and OS X. + * `/source/libSPIRV-Tools.lib` on Windows. +* `SPIRV-Tools-opt` CMake target: Creates the static library: + * `/source/libSPIRV-Tools-opt.a` on Linux and OS X. + * `/source/libSPIRV-Tools-opt.lib` on Windows. + +#### Entry points + +The interfaces are still under development, and are expected to change. + +There are five main entry points into the library in the C interface: + +* `spvTextToBinary`: An assembler, translating text to a binary SPIR-V module. +* `spvBinaryToText`: A disassembler, translating a binary SPIR-V module to + text. +* `spvBinaryParse`: The entry point to a binary parser API. It issues callbacks + for the header and each parsed instruction. The disassembler is implemented + as a client of `spvBinaryParse`. +* `spvValidate` implements the validator functionality. *Incomplete* +* `spvValidateBinary` implements the validator functionality. *Incomplete* + +The C++ interface is comprised of three classes, `SpirvTools`, `Optimizer` and +`Linker`, all in the `spvtools` namespace. +* `SpirvTools` provides `Assemble`, `Disassemble`, and `Validate` methods. +* `Optimizer` provides methods for registering and running optimization passes. +* `Linker` provides methods for combining together multiple binaries. + +## Command line tools + +Command line tools, which wrap the above library functions, are provided to +assemble or disassemble shader files. 
 It's a convention to name SPIR-V +assembly and binary files with suffix `.spvasm` and `.spv`, respectively. + +### Assembler tool + +The assembler reads the assembly language text, and emits the binary form. + +The standalone assembler is the executable called `spirv-as`, and is located in +`/tools/spirv-as`. The functionality of the assembler is implemented +by the `spvTextToBinary` library function. + +* `spirv-as` - the standalone assembler + * `/tools/as` + +Use option `-h` to print help. + +### Disassembler tool + +The disassembler reads the binary form, and emits assembly language text. + +The standalone disassembler is the executable called `spirv-dis`, and is located in +`/tools/spirv-dis`. The functionality of the disassembler is implemented +by the `spvBinaryToText` library function. + +* `spirv-dis` - the standalone disassembler + * `/tools/dis` + +Use option `-h` to print help. + +The output includes syntax colouring when printing to the standard output stream, +on Linux, Windows, and OS X. + +### Linker tool + +The linker combines multiple SPIR-V binary modules together, resulting in a single +binary module as output. + +This is a work in progress. +The linker does not support OpenCL program linking options related to math +flags. (See section 5.6.5.2 in OpenCL 1.2) + +* `spirv-link` - the standalone linker + * `/tools/link` + +### Optimizer tool + +The optimizer processes a SPIR-V binary module, applying transformations +in the specified order. + +This is a work in progress, with initially only few available transformations. + +* `spirv-opt` - the standalone optimizer + * `/tools/opt` + +### Validator tool + +*Warning:* This functionality is under development, and is incomplete. + +The standalone validator is the executable called `spirv-val`, and is located in +`/tools/spirv-val`. The functionality of the validator is implemented +by the `spvValidate` library function. + +The validator operates on the binary form.
+ +* `spirv-val` - the standalone validator + * `/tools/val` + +### Control flow dumper tool + +The control flow dumper prints the control flow graph for a SPIR-V module as a +[GraphViz](http://www.graphviz.org/) graph. + +This is experimental. + +* `spirv-cfg` - the control flow graph dumper + * `/tools/cfg` + +### Utility filters + +* `spirv-lesspipe.sh` - Automatically disassembles `.spv` binary files for the + `less` program, on compatible systems. For example, set the `LESSOPEN` + environment variable as follows, assuming both `spirv-lesspipe.sh` and + `spirv-dis` are on your executable search path: + ``` + export LESSOPEN='| spirv-lesspipe.sh "%s"' + ``` + Then you page through a disassembled module as follows: + ``` + less foo.spv + ``` + * The `spirv-lesspipe.sh` script will pass through any extra arguments to + `spirv-dis`. So, for example, you can turn off colours and friendly ID + naming as follows: + ``` + export LESSOPEN='| spirv-lesspipe.sh "%s" --no-color --raw-id' + ``` + +* [vim-spirv](https://github.com/kbenzie/vim-spirv) - A vim plugin which + supports automatic disassembly of `.spv` files using the `:edit` command and + assembly using the `:write` command. The plugin also provides additional + features which include; syntax highlighting; highlighting of all ID's matching + the ID under the cursor; and highlighting errors where the `Instruction` + operand of `OpExtInst` is used without an appropriate `OpExtInstImport`. + +* `50spirv-tools.el` - Automatically disassembles '.spv' binary files when + loaded into the emacs text editor, and re-assembles them when saved, + provided any modifications to the file are valid. This functionality + must be explicitly requested by defining the symbol + SPIRV_TOOLS_INSTALL_EMACS_HELPERS as follows: + ``` + cmake -DSPIRV_TOOLS_INSTALL_EMACS_HELPERS=true ... 
+ ``` + + In addition, this helper is only installed if the directory /etc/emacs/site-start.d + exists, which is typically true if emacs is installed on the system. + + Note that symbol IDs are not currently preserved through a load/edit/save operation. + This may change if the ability is added to spirv-as. + + +### Tests + +Tests are only built when googletest is found. Use `ctest` to run all the +tests. + +## Future Work + + +_See the [projects pages](https://github.com/KhronosGroup/SPIRV-Tools/projects) +for more information._ + +### Assembler and disassembler + +* The disassembler could emit helpful annotations in comments. For example: + * Use variable name information from debug instructions to annotate + key operations on variables. + * Show control flow information by annotating `OpLabel` instructions with + that basic block's predecessors. +* Error messages could be improved. + +### Validator + +This is a work in progress. + +### Linker + +* The linker could accept math transformations such as allowing MADs, or other + math flags passed at linking-time in OpenCL. +* Linkage attributes can not be applied through a group. +* Check decorations of linked functions attributes. +* Remove dead instructions, such as OpName targeting imported symbols. + +## Licence + +Full license terms are in [LICENSE](LICENSE) +``` +Copyright (c) 2015-2016 The Khronos Group Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+``` + +[spirv-tools-cla]: https://cla-assistant.io/KhronosGroup/SPIRV-Tools +[spirv-tools-projects]: https://github.com/KhronosGroup/SPIRV-Tools/projects +[spirv-tools-mailing-list]: https://www.khronos.org/spir/spirv-tools-mailing-list +[spirv-registry]: https://www.khronos.org/registry/spir-v/ +[spirv-headers]: https://github.com/KhronosGroup/SPIRV-Headers +[googletest]: https://github.com/google/googletest +[googletest-pull-612]: https://github.com/google/googletest/pull/612 +[googletest-issue-610]: https://github.com/google/googletest/issues/610 +[effcee]: https://github.com/google/effcee +[re2]: https://github.com/google/re2 +[CMake]: https://cmake.org/ +[cpp-style-guide]: https://google.github.io/styleguide/cppguide.html +[clang-sanitizers]: http://clang.llvm.org/docs/UsersManual.html#controlling-code-generation +[master-tot-release]: https://github.com/KhronosGroup/SPIRV-Tools/releases/tag/master-tot diff --git a/android_test/Android.mk b/android_test/Android.mk new file mode 100644 index 000000000..dbaf93ba9 --- /dev/null +++ b/android_test/Android.mk @@ -0,0 +1,12 @@ +LOCAL_PATH:= $(call my-dir) + +include $(CLEAR_VARS) +LOCAL_CPP_EXTENSION := .cc .cpp .cxx +LOCAL_SRC_FILES:=test.cpp +LOCAL_MODULE:=spirvtools_test +LOCAL_LDLIBS:=-landroid +LOCAL_CXXFLAGS:=-std=c++11 -fno-exceptions -fno-rtti -Werror +LOCAL_STATIC_LIBRARIES=SPIRV-Tools SPIRV-Tools-opt +include $(BUILD_SHARED_LIBRARY) + +include $(LOCAL_PATH)/../Android.mk diff --git a/android_test/jni/Application.mk b/android_test/jni/Application.mk new file mode 100644 index 000000000..d7ccd349a --- /dev/null +++ b/android_test/jni/Application.mk @@ -0,0 +1,5 @@ +APP_ABI := all +APP_BUILD_SCRIPT := Android.mk +APP_STL := gnustl_static +APP_PLATFORM := android-9 +NDK_TOOLCHAIN_VERSION := 4.9 diff --git a/android_test/test.cpp b/android_test/test.cpp new file mode 100644 index 000000000..e6a57c127 --- /dev/null +++ b/android_test/test.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "spirv-tools/libspirv.hpp" +#include "spirv-tools/optimizer.hpp" + +void android_main(struct android_app* /*state*/) { + spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_2); + spvtools::Optimizer optimizer(SPV_ENV_UNIVERSAL_1_2); +} diff --git a/build_overrides/build.gni b/build_overrides/build.gni new file mode 100644 index 000000000..833fcd349 --- /dev/null +++ b/build_overrides/build.gni @@ -0,0 +1,46 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Variable that can be used to support multiple build scenarios, like having +# Chromium specific targets in a client project's GN file etc. +build_with_chromium = false + +# Don't use Chromium's third_party/binutils. 
+linux_use_bundled_binutils_override = false + +declare_args() { + # Android 32-bit non-component, non-clang builds cannot have symbol_level=2 + # due to 4GiB file size limit, see https://crbug.com/648948. + # Set this flag to true to skip the assertion. + ignore_elf32_limitations = false + + # Use the system install of Xcode for tools like ibtool, libtool, etc. + # This does not affect the compiler. When this variable is false, targets will + # instead use a hermetic install of Xcode. [The hermetic install can be + # obtained with gclient sync after setting the environment variable + # FORCE_MAC_TOOLCHAIN]. + use_system_xcode = "" +} + +if (use_system_xcode == "") { + if (target_os == "mac") { + _result = exec_script("//build/mac/should_use_hermetic_xcode.py", + [ target_os ], + "value") + use_system_xcode = _result == 0 + } + if (target_os == "ios") { + use_system_xcode = true + } +} diff --git a/build_overrides/gtest.gni b/build_overrides/gtest.gni new file mode 100644 index 000000000..c8b1bae4c --- /dev/null +++ b/build_overrides/gtest.gni @@ -0,0 +1,25 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Exclude support for registering main function in multi-process tests. +gtest_include_multiprocess = false + +# Exclude support for platform-specific operations across unit tests. +gtest_include_platform_test = false + +# Exclude support for testing Objective C code on OS X and iOS. 
+gtest_include_objc_support = false + +# Exclude support for flushing coverage files on iOS. +gtest_include_ios_coverage = false diff --git a/build_overrides/spirv_tools.gni b/build_overrides/spirv_tools.gni new file mode 100644 index 000000000..24aa033d7 --- /dev/null +++ b/build_overrides/spirv_tools.gni @@ -0,0 +1,25 @@ +# Copyright 2018 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These are variables that are overridable by projects that include +# SPIRV-Tools. The values in this file are the defaults for when we are +# building from SPIRV-Tools' repository. + +# Whether we are building from SPIRV-Tools' repository. +# MUST be set to false in other projects. 
+spirv_tools_standalone = true + +# The path to SPIRV-Tools' dependencies +spirv_tools_googletest_dir = "//external/googletest" +spirv_tools_spirv_headers_dir = "//external/spirv-headers" diff --git a/cmake/SPIRV-Tools-shared.pc.in b/cmake/SPIRV-Tools-shared.pc.in new file mode 100644 index 000000000..0dcaa2764 --- /dev/null +++ b/cmake/SPIRV-Tools-shared.pc.in @@ -0,0 +1,12 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@ + +Name: SPIRV-Tools +Description: Tools for SPIR-V +Version: @CURRENT_VERSION@ +URL: https://github.com/KhronosGroup/SPIRV-Tools + +Libs: -L${libdir} @SPIRV_SHARED_LIBRARIES@ +Cflags: -I${includedir} diff --git a/cmake/SPIRV-Tools.pc.in b/cmake/SPIRV-Tools.pc.in new file mode 100644 index 000000000..2984dc57f --- /dev/null +++ b/cmake/SPIRV-Tools.pc.in @@ -0,0 +1,12 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@ + +Name: SPIRV-Tools +Description: Tools for SPIR-V +Version: @CURRENT_VERSION@ +URL: https://github.com/KhronosGroup/SPIRV-Tools + +Libs: -L${libdir} @SPIRV_LIBRARIES@ +Cflags: -I${includedir} diff --git a/cmake/setup_build.cmake b/cmake/setup_build.cmake new file mode 100644 index 000000000..6ba4c53d7 --- /dev/null +++ b/cmake/setup_build.cmake @@ -0,0 +1,20 @@ +# Find nosetests; see spirv_add_nosetests() for opting in to nosetests in a +# specific directory. +find_program(NOSETESTS_EXE NAMES nosetests PATHS $ENV{PYTHON_PACKAGE_PATH}) +if (NOT NOSETESTS_EXE) + message(STATUS "SPIRV-Tools: nosetests was not found - python support code will not be tested") +else() + message(STATUS "SPIRV-Tools: nosetests found - python support code will be tested") +endif() + +# Run nosetests on file ${PREFIX}_nosetest.py. Nosetests will look for classes +# and functions whose names start with "nosetest". The test name will be +# ${PREFIX}_nosetests. 
+function(spirv_add_nosetests PREFIX) + if(NOT "${SPIRV_SKIP_TESTS}" AND NOSETESTS_EXE) + add_test( + NAME ${PREFIX}_nosetests + COMMAND ${NOSETESTS_EXE} -m "^[Nn]ose[Tt]est" -v + ${CMAKE_CURRENT_SOURCE_DIR}/${PREFIX}_nosetest.py) + endif() +endfunction() diff --git a/cmake/write_pkg_config.cmake b/cmake/write_pkg_config.cmake new file mode 100644 index 000000000..d367ce3e4 --- /dev/null +++ b/cmake/write_pkg_config.cmake @@ -0,0 +1,31 @@ +# Copyright (c) 2017 Pierre Moreau +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# First, retrieve the current version from CHANGES +file(STRINGS ${CHANGES_FILE} CHANGES_CONTENT) +string( +REGEX + MATCH "v[0-9]+(.[0-9]+)?(-dev)? [0-9]+-[0-9]+-[0-9]+" + FIRST_VERSION_LINE + ${CHANGES_CONTENT}) +string( +REGEX + REPLACE "^v([^ ]+) .+$" "\\1" + CURRENT_VERSION + "${FIRST_VERSION_LINE}") +# If this is a development version, replace "-dev" by ".0" as neither pkg-config nor +# CMake supports "-dev" in the version. +# If it's not a "-dev" version then ensure it ends with ".1" +string(REGEX REPLACE "-dev.1" ".0" CURRENT_VERSION "${CURRENT_VERSION}.1") +configure_file(${TEMPLATE_FILE} ${OUT_FILE} @ONLY) diff --git a/codereview.settings b/codereview.settings new file mode 100644 index 000000000..ef84cf857 --- /dev/null +++ b/codereview.settings @@ -0,0 +1,2 @@ +# This file is used by git cl to get repository specific information.
+CODE_REVIEW_SERVER: github.com diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 000000000..fd627cbda --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1,35 @@ +# Copyright (c) 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Add a SPIR-V Tools example. Signature: +# add_spvtools_example( +# TARGET target_name +# SRCS src_file1.cpp src_file2.cpp +# LIBS lib_target1 lib_target2 +# ) +function(add_spvtools_example) + if (NOT ${SPIRV_SKIP_EXECUTABLES}) + set(one_value_args TARGET) + set(multi_value_args SRCS LIBS) + cmake_parse_arguments( + ARG "" "${one_value_args}" "${multi_value_args}" ${ARGN}) + + add_executable(${ARG_TARGET} ${ARG_SRCS}) + spvtools_default_compile_options(${ARG_TARGET}) + target_link_libraries(${ARG_TARGET} PRIVATE ${ARG_LIBS}) + set_property(TARGET ${ARG_TARGET} PROPERTY FOLDER "SPIRV-Tools examples") + endif() +endfunction() + +add_subdirectory(cpp-interface) diff --git a/examples/cpp-interface/CMakeLists.txt b/examples/cpp-interface/CMakeLists.txt new file mode 100644 index 000000000..d050b0759 --- /dev/null +++ b/examples/cpp-interface/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_spvtools_example( + TARGET spirv-tools-cpp-example + SRCS main.cpp + LIBS SPIRV-Tools-opt +) \ No newline at end of file diff --git a/examples/cpp-interface/main.cpp b/examples/cpp-interface/main.cpp new file mode 100644 index 000000000..c5354b8bd --- /dev/null +++ b/examples/cpp-interface/main.cpp @@ -0,0 +1,64 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// This program demonstrates basic SPIR-V module processing using +// SPIRV-Tools C++ API: +// * Assembling +// * Validating +// * Optimizing +// * Disassembling + +#include +#include +#include + +#include "spirv-tools/libspirv.hpp" +#include "spirv-tools/optimizer.hpp" + +int main() { + const std::string source = + " OpCapability Shader " + " OpMemoryModel Logical GLSL450 " + " OpSource GLSL 450 " + " OpDecorate %spec SpecId 1 " + " %int = OpTypeInt 32 1 " + " %spec = OpSpecConstant %int 0 " + "%const = OpConstant %int 42"; + + spvtools::SpirvTools core(SPV_ENV_VULKAN_1_0); + spvtools::Optimizer opt(SPV_ENV_VULKAN_1_0); + + auto print_msg_to_stderr = [](spv_message_level_t, const char*, + const spv_position_t&, const char* m) { + std::cerr << "error: " << m << std::endl; + }; + core.SetMessageConsumer(print_msg_to_stderr); + opt.SetMessageConsumer(print_msg_to_stderr); + + std::vector spirv; + if (!core.Assemble(source, &spirv)) return 1; + if (!core.Validate(spirv)) return 1; + + opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass({{1, "42"}})) + .RegisterPass(spvtools::CreateFreezeSpecConstantValuePass()) + .RegisterPass(spvtools::CreateUnifyConstantPass()) + .RegisterPass(spvtools::CreateStripDebugInfoPass()); + if (!opt.Run(spirv.data(), spirv.size(), &spirv)) return 1; + + std::string disassembly; + if (!core.Disassemble(spirv, &disassembly)) return 1; + std::cout << disassembly << "\n"; + + return 0; +} diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt new file mode 100644 index 000000000..d1251c248 --- /dev/null +++ b/external/CMakeLists.txt @@ -0,0 +1,104 @@ +# Copyright (c) 2015-2016 The Khronos Group Inc. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if (DEFINED SPIRV-Headers_SOURCE_DIR) + # This allows flexible position of the SPIRV-Headers repo. + set(SPIRV_HEADER_DIR ${SPIRV-Headers_SOURCE_DIR}) +else() + if (IS_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/SPIRV-Headers) + set(SPIRV_HEADER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/SPIRV-Headers) + else() + set(SPIRV_HEADER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/spirv-headers) + endif() +endif() + +if (IS_DIRECTORY ${SPIRV_HEADER_DIR}) + set(SPIRV_HEADER_INCLUDE_DIR ${SPIRV_HEADER_DIR}/include PARENT_SCOPE) +else() + message(FATAL_ERROR + "SPIRV-Headers was not found - please checkout a copy under external/.") +endif() + +if (NOT ${SPIRV_SKIP_TESTS}) + # Find gmock if we can. If it's not already configured, then try finding + # it in external/googletest. + if (TARGET gmock) + message(STATUS "Google Mock already configured") + else() + set(GMOCK_DIR ${CMAKE_CURRENT_SOURCE_DIR}/googletest/googlemock) + if(EXISTS ${GMOCK_DIR}) + if(MSVC) + # Our tests use ::testing::Combine. Work around a compiler + # detection problem in googletest, where that template is + # accidentally disabled for VS 2017. + # See https://github.com/google/googletest/issues/1352 + add_definitions(-DGTEST_HAS_COMBINE=1) + endif() + if(WIN32) + option(gtest_force_shared_crt + "Use shared (DLL) run-time lib even when Google Test is built as static lib." 
+ ON) + endif() + add_subdirectory(${GMOCK_DIR} EXCLUDE_FROM_ALL) + endif() + endif() + if (TARGET gmock) + set(GTEST_TARGETS + gtest + gtest_main + gmock + gmock_main + ) + foreach(target ${GTEST_TARGETS}) + set_property(TARGET ${target} PROPERTY FOLDER GoogleTest) + endforeach() + endif() + + # Find Effcee and RE2, for testing. + + # First find RE2, since Effcee depends on it. + # If already configured, then use that. Otherwise, prefer to find it under 're2' + # in this directory. + if (NOT TARGET re2) + # If we are configuring RE2, then turn off its testing. It takes a long time and + # does not add much value for us. If an enclosing project configured RE2, then it + # has already chosen whether to enable RE2 testing. + set(RE2_BUILD_TESTING OFF CACHE STRING "Run RE2 Tests") + if (NOT RE2_SOURCE_DIR) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/re2) + set(RE2_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/re2" CACHE STRING "RE2 source dir" ) + endif() + endif() + endif() + + if (NOT TARGET effcee) + # Expect to find effcee in this directory. + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/effcee) + # If we're configuring RE2 (via Effcee), then turn off RE2 testing. + if (NOT TARGET re2) + set(RE2_BUILD_TESTING OFF) + endif() + if (MSVC) + # SPIRV-Tools uses the shared CRT with MSVC. Tell Effcee to do the same. + set(EFFCEE_ENABLE_SHARED_CRT ON) + endif() + add_subdirectory(effcee) + set_property(TARGET effcee PROPERTY FOLDER Effcee) + # Turn off warnings for effcee and re2 + set_property(TARGET effcee APPEND PROPERTY COMPILE_OPTIONS -w) + set_property(TARGET re2 APPEND PROPERTY COMPILE_OPTIONS -w) + endif() + endif() +endif() diff --git a/include/spirv-tools/instrument.hpp b/include/spirv-tools/instrument.hpp new file mode 100644 index 000000000..f8068099c --- /dev/null +++ b/include/spirv-tools/instrument.hpp @@ -0,0 +1,135 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_SPIRV_TOOLS_INSTRUMENT_HPP_ +#define INCLUDE_SPIRV_TOOLS_INSTRUMENT_HPP_ + +// Shader Instrumentation Interface +// +// This file provides an external interface for applications that wish to +// communicate with shaders instrumented by passes created by: +// +// CreateInstBindlessCheckPass +// +// More detailed documentation of this routine can be found in optimizer.hpp + +namespace spvtools { + +// Stream Output Buffer Offsets +// +// The following values provide 32-bit word offsets into the output buffer +// generated by InstrumentPass::GenDebugStreamWrite. This method is utilized +// by InstBindlessCheckPass. +// +// The first word of the debug output buffer contains the next available word +// in the data stream to be written. Shaders will atomically read and update +// this value so as not to overwrite each other's records. This value must be +// initialized to zero. +static const int kDebugOutputSizeOffset = 0; + +// The second word of the output buffer is the start of the stream of records +// written by the instrumented shaders. Each record represents a validation +// error. The format of the records is documented below. +static const int kDebugOutputDataOffset = 1; + +// Common Stream Record Offsets +// +// The following are offsets to fields which are common to all records written +// to the output stream.
+// +// Each record first contains the size of the record in 32-bit words, including +// the size word. +static const int kInstCommonOutSize = 0; + +// This is the shader id passed by the layer when the instrumentation pass is +// created. +static const int kInstCommonOutShaderId = 1; + +// This is the ordinal position of the instruction within the SPIR-V shader +// which generated the validation error. +static const int kInstCommonOutInstructionIdx = 2; + +// This is the stage which generated the validation error. This word is used +// to determine the contents of the next two words in the record. +// 0:Vert, 1:TessCtrl, 2:TessEval, 3:Geom, 4:Frag, 5:Compute +static const int kInstCommonOutStageIdx = 3; +static const int kInstCommonOutCnt = 4; + +// Stage-specific Stream Record Offsets +// +// Each stage will contain different values in the next two words of the record +// used to identify which instantiation of the shader generated the validation +// error. +// +// Vertex Shader Output Record Offsets +static const int kInstVertOutVertexIndex = kInstCommonOutCnt; +static const int kInstVertOutInstanceIndex = kInstCommonOutCnt + 1; + +// Frag Shader Output Record Offsets +static const int kInstFragOutFragCoordX = kInstCommonOutCnt; +static const int kInstFragOutFragCoordY = kInstCommonOutCnt + 1; + +// Compute Shader Output Record Offsets +static const int kInstCompOutGlobalInvocationId = kInstCommonOutCnt; +static const int kInstCompOutUnused = kInstCommonOutCnt + 1; + +// Tessellation Shader Output Record Offsets +static const int kInstTessOutInvocationId = kInstCommonOutCnt; +static const int kInstTessOutUnused = kInstCommonOutCnt + 1; + +// Geometry Shader Output Record Offsets +static const int kInstGeomOutPrimitiveId = kInstCommonOutCnt; +static const int kInstGeomOutInvocationId = kInstCommonOutCnt + 1; + +// Size of Common and Stage-specific Members +static const int kInstStageOutCnt = kInstCommonOutCnt + 2; + +// Validation Error Code +// +// This 
identifies the validation error. It also helps to identify +// how many words follow in the record and their meaning. +static const int kInstValidationOutError = kInstStageOutCnt; + +// Validation-specific Output Record Offsets +// +// Each different validation will generate a potentially different +// number of words at the end of the record giving more specifics +// about the validation error. +// +// A bindless bounds error will output the index and the bound. +static const int kInstBindlessOutDescIndex = kInstStageOutCnt + 1; +static const int kInstBindlessOutDescBound = kInstStageOutCnt + 2; +static const int kInstBindlessOutCnt = kInstStageOutCnt + 3; + +// Maximum Output Record Member Count +static const int kInstMaxOutCnt = kInstStageOutCnt + 3; + +// Validation Error Codes +// +// These are the possible validation error codes. +static const int kInstErrorBindlessBounds = 0; + +// Debug Buffer Bindings +// +// These are the bindings for the different buffers which are +// read or written by the instrumentation passes. +// +// This is the output buffer written by InstBindlessCheckPass. +static const int kDebugOutputBindingStream = 0; + +} // namespace spvtools + +#endif // INCLUDE_SPIRV_TOOLS_INSTRUMENT_HPP_ diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h new file mode 100644 index 000000000..ff7eb6b6c --- /dev/null +++ b/include/spirv-tools/libspirv.h @@ -0,0 +1,680 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_SPIRV_TOOLS_LIBSPIRV_H_ +#define INCLUDE_SPIRV_TOOLS_LIBSPIRV_H_ + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +#include +#include + +#if defined(SPIRV_TOOLS_SHAREDLIB) +#if defined(_WIN32) +#if defined(SPIRV_TOOLS_IMPLEMENTATION) +#define SPIRV_TOOLS_EXPORT __declspec(dllexport) +#else +#define SPIRV_TOOLS_EXPORT __declspec(dllimport) +#endif +#else +#if defined(SPIRV_TOOLS_IMPLEMENTATION) +#define SPIRV_TOOLS_EXPORT __attribute__((visibility("default"))) +#else +#define SPIRV_TOOLS_EXPORT +#endif +#endif +#else +#define SPIRV_TOOLS_EXPORT +#endif + +// Helpers + +#define SPV_BIT(shift) (1 << (shift)) + +#define SPV_FORCE_16_BIT_ENUM(name) _##name = 0x7fff +#define SPV_FORCE_32_BIT_ENUM(name) _##name = 0x7fffffff + +// Enumerations + +typedef enum spv_result_t { + SPV_SUCCESS = 0, + SPV_UNSUPPORTED = 1, + SPV_END_OF_STREAM = 2, + SPV_WARNING = 3, + SPV_FAILED_MATCH = 4, + SPV_REQUESTED_TERMINATION = 5, // Success, but signals early termination. + SPV_ERROR_INTERNAL = -1, + SPV_ERROR_OUT_OF_MEMORY = -2, + SPV_ERROR_INVALID_POINTER = -3, + SPV_ERROR_INVALID_BINARY = -4, + SPV_ERROR_INVALID_TEXT = -5, + SPV_ERROR_INVALID_TABLE = -6, + SPV_ERROR_INVALID_VALUE = -7, + SPV_ERROR_INVALID_DIAGNOSTIC = -8, + SPV_ERROR_INVALID_LOOKUP = -9, + SPV_ERROR_INVALID_ID = -10, + SPV_ERROR_INVALID_CFG = -11, + SPV_ERROR_INVALID_LAYOUT = -12, + SPV_ERROR_INVALID_CAPABILITY = -13, + SPV_ERROR_INVALID_DATA = -14, // Indicates data rules validation failure. + SPV_ERROR_MISSING_EXTENSION = -15, + SPV_ERROR_WRONG_VERSION = -16, // Indicates wrong SPIR-V version + SPV_FORCE_32_BIT_ENUM(spv_result_t) +} spv_result_t; + +// Severity levels of messages communicated to the consumer. +typedef enum spv_message_level_t { + SPV_MSG_FATAL, // Unrecoverable error due to environment. + // Will exit the program immediately. E.g., + // out of memory. 
+ SPV_MSG_INTERNAL_ERROR, // Unrecoverable error due to SPIRV-Tools + // internals. + // Will exit the program immediately. E.g., + // unimplemented feature. + SPV_MSG_ERROR, // Normal error due to user input. + SPV_MSG_WARNING, // Warning information. + SPV_MSG_INFO, // General information. + SPV_MSG_DEBUG, // Debug information. +} spv_message_level_t; + +typedef enum spv_endianness_t { + SPV_ENDIANNESS_LITTLE, + SPV_ENDIANNESS_BIG, + SPV_FORCE_32_BIT_ENUM(spv_endianness_t) +} spv_endianness_t; + +// The kinds of operands that an instruction may have. +// +// Some operand types are "concrete". The binary parser uses a concrete +// operand type to describe an operand of a parsed instruction. +// +// The assembler uses all operand types. In addition to determining what +// kind of value an operand may be, non-concrete operand types capture the +// fact that an operand might be optional (may be absent, or present exactly +// once), or might occur zero or more times. +// +// Sometimes we also need to be able to express the fact that an operand +// is a member of an optional tuple of values. In that case the first member +// would be optional, and the subsequent members would be required. +typedef enum spv_operand_type_t { + // A sentinel value. + SPV_OPERAND_TYPE_NONE = 0, + + // Set 1: Operands that are IDs. + SPV_OPERAND_TYPE_ID, + SPV_OPERAND_TYPE_TYPE_ID, + SPV_OPERAND_TYPE_RESULT_ID, + SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, // SPIR-V Sec 3.25 + SPV_OPERAND_TYPE_SCOPE_ID, // SPIR-V Sec 3.27 + + // Set 2: Operands that are literal numbers. + SPV_OPERAND_TYPE_LITERAL_INTEGER, // Always unsigned 32-bits. + // The Instruction argument to OpExtInst. It's an unsigned 32-bit literal + // number indicating which instruction to use from an extended instruction + // set. + SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + // The Opcode argument to OpSpecConstantOp. 
It determines the operation + // to be performed on constant operands to compute a specialization constant + // result. + SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, + // A literal number whose format and size are determined by a previous operand + // in the same instruction. It's a signed integer, an unsigned integer, or a + // floating point number. It also has a specified bit width. The width + // may be larger than 32, which would require such a typed literal value to + // occupy multiple SPIR-V words. + SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, + + // Set 3: The literal string operand type. + SPV_OPERAND_TYPE_LITERAL_STRING, + + // Set 4: Operands that are a single word enumerated value. + SPV_OPERAND_TYPE_SOURCE_LANGUAGE, // SPIR-V Sec 3.2 + SPV_OPERAND_TYPE_EXECUTION_MODEL, // SPIR-V Sec 3.3 + SPV_OPERAND_TYPE_ADDRESSING_MODEL, // SPIR-V Sec 3.4 + SPV_OPERAND_TYPE_MEMORY_MODEL, // SPIR-V Sec 3.5 + SPV_OPERAND_TYPE_EXECUTION_MODE, // SPIR-V Sec 3.6 + SPV_OPERAND_TYPE_STORAGE_CLASS, // SPIR-V Sec 3.7 + SPV_OPERAND_TYPE_DIMENSIONALITY, // SPIR-V Sec 3.8 + SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE, // SPIR-V Sec 3.9 + SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE, // SPIR-V Sec 3.10 + SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT, // SPIR-V Sec 3.11 + SPV_OPERAND_TYPE_IMAGE_CHANNEL_ORDER, // SPIR-V Sec 3.12 + SPV_OPERAND_TYPE_IMAGE_CHANNEL_DATA_TYPE, // SPIR-V Sec 3.13 + SPV_OPERAND_TYPE_FP_ROUNDING_MODE, // SPIR-V Sec 3.16 + SPV_OPERAND_TYPE_LINKAGE_TYPE, // SPIR-V Sec 3.17 + SPV_OPERAND_TYPE_ACCESS_QUALIFIER, // SPIR-V Sec 3.18 + SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE, // SPIR-V Sec 3.19 + SPV_OPERAND_TYPE_DECORATION, // SPIR-V Sec 3.20 + SPV_OPERAND_TYPE_BUILT_IN, // SPIR-V Sec 3.21 + SPV_OPERAND_TYPE_GROUP_OPERATION, // SPIR-V Sec 3.28 + SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS, // SPIR-V Sec 3.29 + SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO, // SPIR-V Sec 3.30 + SPV_OPERAND_TYPE_CAPABILITY, // SPIR-V Sec 3.31 + + // Set 5: Operands that are a single word bitmask. 
+ // Sometimes a set bit indicates the instruction requires still more operands. + SPV_OPERAND_TYPE_IMAGE, // SPIR-V Sec 3.14 + SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, // SPIR-V Sec 3.15 + SPV_OPERAND_TYPE_SELECTION_CONTROL, // SPIR-V Sec 3.22 + SPV_OPERAND_TYPE_LOOP_CONTROL, // SPIR-V Sec 3.23 + SPV_OPERAND_TYPE_FUNCTION_CONTROL, // SPIR-V Sec 3.24 + SPV_OPERAND_TYPE_MEMORY_ACCESS, // SPIR-V Sec 3.26 + +// The remaining operand types are only used internally by the assembler. +// There are two categories: +// Optional : expands to 0 or 1 operand, like ? in regular expressions. +// Variable : expands to 0, 1 or many operands or pairs of operands. +// This is similar to * in regular expressions. + +// Macros for defining bounds on optional and variable operand types. +// Any variable operand type is also optional. +#define FIRST_OPTIONAL(ENUM) ENUM, SPV_OPERAND_TYPE_FIRST_OPTIONAL_TYPE = ENUM +#define FIRST_VARIABLE(ENUM) ENUM, SPV_OPERAND_TYPE_FIRST_VARIABLE_TYPE = ENUM +#define LAST_VARIABLE(ENUM) \ + ENUM, SPV_OPERAND_TYPE_LAST_VARIABLE_TYPE = ENUM, \ + SPV_OPERAND_TYPE_LAST_OPTIONAL_TYPE = ENUM + + // An optional operand represents zero or one logical operands. + // In an instruction definition, this may only appear at the end of the + // operand types. + FIRST_OPTIONAL(SPV_OPERAND_TYPE_OPTIONAL_ID), + // An optional image operand type. + SPV_OPERAND_TYPE_OPTIONAL_IMAGE, + // An optional memory access type. + SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, + // An optional literal integer. + SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER, + // An optional literal number, which may be either integer or floating point. + SPV_OPERAND_TYPE_OPTIONAL_LITERAL_NUMBER, + // Like SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, but optional, and integral. + SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER, + // An optional literal string. 
+ SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING, + // An optional access qualifier + SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER, + // An optional context-independent value, or CIV. CIVs are tokens that we can + // assemble regardless of where they occur -- literals, IDs, immediate + // integers, etc. + SPV_OPERAND_TYPE_OPTIONAL_CIV, + + // A variable operand represents zero or more logical operands. + // In an instruction definition, this may only appear at the end of the + // operand types. + FIRST_VARIABLE(SPV_OPERAND_TYPE_VARIABLE_ID), + SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER, + // A sequence of zero or more pairs of (typed literal integer, Id). + // Expands to zero or more: + // (SPV_OPERAND_TYPE_TYPED_LITERAL_INTEGER, SPV_OPERAND_TYPE_ID) + // where the literal number must always be an integer of some sort. + SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER_ID, + // A sequence of zero or more pairs of (Id, Literal integer) + LAST_VARIABLE(SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER), + + // The following are concrete enum types. + SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS, // DebugInfo Sec 3.2. A mask. + SPV_OPERAND_TYPE_DEBUG_BASE_TYPE_ATTRIBUTE_ENCODING, // DebugInfo Sec 3.3 + SPV_OPERAND_TYPE_DEBUG_COMPOSITE_TYPE, // DebugInfo Sec 3.4 + SPV_OPERAND_TYPE_DEBUG_TYPE_QUALIFIER, // DebugInfo Sec 3.5 + SPV_OPERAND_TYPE_DEBUG_OPERATION, // DebugInfo Sec 3.6 + + // This is a sentinel value, and does not represent an operand type. + // It should come last. 
+ SPV_OPERAND_TYPE_NUM_OPERAND_TYPES, + + SPV_FORCE_32_BIT_ENUM(spv_operand_type_t) +} spv_operand_type_t; + +typedef enum spv_ext_inst_type_t { + SPV_EXT_INST_TYPE_NONE = 0, + SPV_EXT_INST_TYPE_GLSL_STD_450, + SPV_EXT_INST_TYPE_OPENCL_STD, + SPV_EXT_INST_TYPE_SPV_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER, + SPV_EXT_INST_TYPE_SPV_AMD_SHADER_TRINARY_MINMAX, + SPV_EXT_INST_TYPE_SPV_AMD_GCN_SHADER, + SPV_EXT_INST_TYPE_SPV_AMD_SHADER_BALLOT, + SPV_EXT_INST_TYPE_DEBUGINFO, + + SPV_FORCE_32_BIT_ENUM(spv_ext_inst_type_t) +} spv_ext_inst_type_t; + +// This determines at a high level the kind of a binary-encoded literal +// number, but not the bit width. +// In principle, these could probably be folded into new entries in +// spv_operand_type_t. But then we'd have some special case differences +// between the assembler and disassembler. +typedef enum spv_number_kind_t { + SPV_NUMBER_NONE = 0, // The default for value initialization. + SPV_NUMBER_UNSIGNED_INT, + SPV_NUMBER_SIGNED_INT, + SPV_NUMBER_FLOATING, +} spv_number_kind_t; + +typedef enum spv_text_to_binary_options_t { + SPV_TEXT_TO_BINARY_OPTION_NONE = SPV_BIT(0), + // Numeric IDs in the binary will have the same values as in the source. + // Non-numeric IDs are allocated by filling in the gaps, starting with 1 + // and going up. + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS = SPV_BIT(1), + SPV_FORCE_32_BIT_ENUM(spv_text_to_binary_options_t) +} spv_text_to_binary_options_t; + +typedef enum spv_binary_to_text_options_t { + SPV_BINARY_TO_TEXT_OPTION_NONE = SPV_BIT(0), + SPV_BINARY_TO_TEXT_OPTION_PRINT = SPV_BIT(1), + SPV_BINARY_TO_TEXT_OPTION_COLOR = SPV_BIT(2), + SPV_BINARY_TO_TEXT_OPTION_INDENT = SPV_BIT(3), + SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET = SPV_BIT(4), + // Do not output the module header as leading comments in the assembly. + SPV_BINARY_TO_TEXT_OPTION_NO_HEADER = SPV_BIT(5), + // Use friendly names where possible. 
The heuristic may expand over + // time, but will use common names for scalar types, and debug names from + // OpName instructions. + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES = SPV_BIT(6), + SPV_FORCE_32_BIT_ENUM(spv_binary_to_text_options_t) +} spv_binary_to_text_options_t; + +// Constants + +// The default id bound is set to the minimum value for the id limit +// in the SPIR-V specification under the section "Universal Limits". +const uint32_t kDefaultMaxIdBound = 0x3FFFFF; + +// Structures + +// Information about an operand parsed from a binary SPIR-V module. +// Note that the values are not included. You still need access to the binary +// to extract the values. +typedef struct spv_parsed_operand_t { + // Location of the operand, in words from the start of the instruction. + uint16_t offset; + // Number of words occupied by this operand. + uint16_t num_words; + // The "concrete" operand type. See the definition of spv_operand_type_t + // for details. + spv_operand_type_t type; + // If type is a literal number type, then number_kind says whether it's + // a signed integer, an unsigned integer, or a floating point number. + spv_number_kind_t number_kind; + // The number of bits for a literal number type. + uint32_t number_bit_width; +} spv_parsed_operand_t; + +// An instruction parsed from a binary SPIR-V module. +typedef struct spv_parsed_instruction_t { + // An array of words for this instruction, in native endianness. + const uint32_t* words; + // The number of words in this instruction. + uint16_t num_words; + uint16_t opcode; + // The extended instruction type, if opcode is OpExtInst. Otherwise + // this is the "none" value. + spv_ext_inst_type_t ext_inst_type; + // The type id, or 0 if this instruction doesn't have one. + uint32_t type_id; + // The result id, or 0 if this instruction doesn't have one. + uint32_t result_id; + // The array of parsed operands. 
+ const spv_parsed_operand_t* operands; + uint16_t num_operands; +} spv_parsed_instruction_t; + +typedef struct spv_const_binary_t { + const uint32_t* code; + const size_t wordCount; +} spv_const_binary_t; + +typedef struct spv_binary_t { + uint32_t* code; + size_t wordCount; +} spv_binary_t; + +typedef struct spv_text_t { + const char* str; + size_t length; +} spv_text_t; + +typedef struct spv_position_t { + size_t line; + size_t column; + size_t index; +} spv_position_t; + +typedef struct spv_diagnostic_t { + spv_position_t position; + char* error; + bool isTextSource; +} spv_diagnostic_t; + +// Opaque struct containing the context used to operate on a SPIR-V module. +// Its object is used by various translation API functions. +typedef struct spv_context_t spv_context_t; + +typedef struct spv_validator_options_t spv_validator_options_t; + +typedef struct spv_optimizer_options_t spv_optimizer_options_t; + +typedef struct spv_reducer_options_t spv_reducer_options_t; + +// Type Definitions + +typedef spv_const_binary_t* spv_const_binary; +typedef spv_binary_t* spv_binary; +typedef spv_text_t* spv_text; +typedef spv_position_t* spv_position; +typedef spv_diagnostic_t* spv_diagnostic; +typedef const spv_context_t* spv_const_context; +typedef spv_context_t* spv_context; +typedef spv_validator_options_t* spv_validator_options; +typedef const spv_validator_options_t* spv_const_validator_options; +typedef spv_optimizer_options_t* spv_optimizer_options; +typedef const spv_optimizer_options_t* spv_const_optimizer_options; +typedef spv_reducer_options_t* spv_reducer_options; +typedef const spv_reducer_options_t* spv_const_reducer_options; + +// Platform API + +// Returns the SPIRV-Tools software version as a null-terminated string. +// The contents of the underlying storage is valid for the remainder of +// the process. 
+SPIRV_TOOLS_EXPORT const char* spvSoftwareVersionString(void); +// Returns a null-terminated string containing the name of the project, +// the software version string, and commit details. +// The contents of the underlying storage is valid for the remainder of +// the process. +SPIRV_TOOLS_EXPORT const char* spvSoftwareVersionDetailsString(void); + +// Certain target environments impose additional restrictions on SPIR-V, so it's +// often necessary to specify which one applies. SPV_ENV_UNIVERSAL means +// environment-agnostic SPIR-V. +typedef enum { + SPV_ENV_UNIVERSAL_1_0, // SPIR-V 1.0 latest revision, no other restrictions. + SPV_ENV_VULKAN_1_0, // Vulkan 1.0 latest revision. + SPV_ENV_UNIVERSAL_1_1, // SPIR-V 1.1 latest revision, no other restrictions. + SPV_ENV_OPENCL_2_1, // OpenCL Full Profile 2.1 latest revision. + SPV_ENV_OPENCL_2_2, // OpenCL Full Profile 2.2 latest revision. + SPV_ENV_OPENGL_4_0, // OpenGL 4.0 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_1, // OpenGL 4.1 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_2, // OpenGL 4.2 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_3, // OpenGL 4.3 plus GL_ARB_gl_spirv, latest revisions. + // There is no variant for OpenGL 4.4. + SPV_ENV_OPENGL_4_5, // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_UNIVERSAL_1_2, // SPIR-V 1.2, latest revision, no other restrictions. + SPV_ENV_OPENCL_1_2, // OpenCL Full Profile 1.2 plus cl_khr_il_program, + // latest revision. + SPV_ENV_OPENCL_EMBEDDED_1_2, // OpenCL Embedded Profile 1.2 plus + // cl_khr_il_program, latest revision. + SPV_ENV_OPENCL_2_0, // OpenCL Full Profile 2.0 plus cl_khr_il_program, + // latest revision. + SPV_ENV_OPENCL_EMBEDDED_2_0, // OpenCL Embedded Profile 2.0 plus + // cl_khr_il_program, latest revision. + SPV_ENV_OPENCL_EMBEDDED_2_1, // OpenCL Embedded Profile 2.1 latest revision. + SPV_ENV_OPENCL_EMBEDDED_2_2, // OpenCL Embedded Profile 2.2 latest revision. 
+ SPV_ENV_UNIVERSAL_1_3, // SPIR-V 1.3 latest revision, no other restrictions. + SPV_ENV_VULKAN_1_1, // Vulkan 1.1 latest revision. + SPV_ENV_WEBGPU_0, // Work in progress WebGPU 1.0. +} spv_target_env; + +// SPIR-V Validator can be parameterized with the following Universal Limits. +typedef enum { + spv_validator_limit_max_struct_members, + spv_validator_limit_max_struct_depth, + spv_validator_limit_max_local_variables, + spv_validator_limit_max_global_variables, + spv_validator_limit_max_switch_branches, + spv_validator_limit_max_function_args, + spv_validator_limit_max_control_flow_nesting_depth, + spv_validator_limit_max_access_chain_indexes, + spv_validator_limit_max_id_bound, +} spv_validator_limit; + +// Returns a string describing the given SPIR-V target environment. +SPIRV_TOOLS_EXPORT const char* spvTargetEnvDescription(spv_target_env env); + +// Creates a context object. Returns null if env is invalid. +SPIRV_TOOLS_EXPORT spv_context spvContextCreate(spv_target_env env); + +// Destroys the given context object. +SPIRV_TOOLS_EXPORT void spvContextDestroy(spv_context context); + +// Creates a Validator options object with default options. Returns a valid +// options object. The object remains valid until it is passed into +// spvValidatorOptionsDestroy. +SPIRV_TOOLS_EXPORT spv_validator_options spvValidatorOptionsCreate(void); + +// Destroys the given Validator options object. +SPIRV_TOOLS_EXPORT void spvValidatorOptionsDestroy( + spv_validator_options options); + +// Records the maximum Universal Limit that is considered valid in the given +// Validator options object. argument must be a valid options object. +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetUniversalLimit( + spv_validator_options options, spv_validator_limit limit_type, + uint32_t limit); + +// Record whether or not the validator should relax the rules on types for +// stores to structs. When relaxed, it will allow a type mismatch as long as +// the types are structs with the same layout. 
Two structs have the same layout +// if +// +// 1) the members of the structs are either the same type or are structs with +// same layout, and +// +// 2) the decorations that affect the memory layout are identical for both +// types. Other decorations are not relevant. +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetRelaxStoreStruct( + spv_validator_options options, bool val); + +// Records whether or not the validator should relax the rules on pointer usage +// in logical addressing mode. +// +// When relaxed, it will allow the following usage cases of pointers: +// 1) OpVariable allocating an object whose type is a pointer type +// 2) OpReturnValue returning a pointer value +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetRelaxLogicalPointer( + spv_validator_options options, bool val); + +// Records whether the validator should use "relaxed" block layout rules. +// Relaxed layout rules are described by Vulkan extension +// VK_KHR_relaxed_block_layout, and they affect uniform blocks, storage blocks, +// and push constants. +// +// This is enabled by default when targeting Vulkan 1.1 or later. +// Relaxed layout is more permissive than the default rules in Vulkan 1.0. +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetRelaxBlockLayout( + spv_validator_options options, bool val); + +// Records whether the validator should use "scalar" block layout rules. +// Scalar layout rules are more permissive than relaxed block layout. +// +// See Vulkan extension VK_EXT_scalar_block_layout. 
The scalar alignment is +// defined as follows: +// - scalar alignment of a scalar is the scalar size +// - scalar alignment of a vector is the scalar alignment of its component +// - scalar alignment of a matrix is the scalar alignment of its component +// - scalar alignment of an array is the scalar alignment of its element +// - scalar alignment of a struct is the max scalar alignment among its +// members +// +// For a struct in Uniform, StorageBuffer, or PushConstant: +// - a member Offset must be a multiple of the member's scalar alignment +// - ArrayStride or MatrixStride must be a multiple of the array or matrix +// scalar alignment +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetScalarBlockLayout( + spv_validator_options options, bool val); + +// Records whether or not the validator should skip validating standard +// uniform/storage block layout. +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetSkipBlockLayout( + spv_validator_options options, bool val); + +// Creates an optimizer options object with default options. Returns a valid +// options object. The object remains valid until it is passed into +// |spvOptimizerOptionsDestroy|. +SPIRV_TOOLS_EXPORT spv_optimizer_options spvOptimizerOptionsCreate(void); + +// Destroys the given optimizer options object. +SPIRV_TOOLS_EXPORT void spvOptimizerOptionsDestroy( + spv_optimizer_options options); + +// Records whether or not the optimizer should run the validator before +// optimizing. If |val| is true, the validator will be run. +SPIRV_TOOLS_EXPORT void spvOptimizerOptionsSetRunValidator( + spv_optimizer_options options, bool val); + +// Records the validator options that should be passed to the validator if it is +// run. +SPIRV_TOOLS_EXPORT void spvOptimizerOptionsSetValidatorOptions( + spv_optimizer_options options, spv_validator_options val); + +// Records the maximum possible value for the id bound. 
+SPIRV_TOOLS_EXPORT void spvOptimizerOptionsSetMaxIdBound( + spv_optimizer_options options, uint32_t val); + +// Creates a reducer options object with default options. Returns a valid +// options object. The object remains valid until it is passed into +// |spvReducerOptionsDestroy|. +SPIRV_TOOLS_EXPORT spv_reducer_options spvReducerOptionsCreate(); + +// Destroys the given reducer options object. +SPIRV_TOOLS_EXPORT void spvReducerOptionsDestroy(spv_reducer_options options); + +// Records the maximum number of reduction steps that should run before the +// reducer gives up. +SPIRV_TOOLS_EXPORT void spvReducerOptionsSetStepLimit( + spv_reducer_options options, uint32_t step_limit); + +// Sets seed for random number generation. +SPIRV_TOOLS_EXPORT void spvReducerOptionsSetSeed(spv_reducer_options options, + uint32_t seed); + +// Encodes the given SPIR-V assembly text to its binary representation. The +// length parameter specifies the number of bytes for text. Encoded binary will +// be stored into *binary. Any error will be written into *diagnostic if +// diagnostic is non-null, otherwise the context's message consumer will be +// used. The generated binary is independent of the context and may outlive it. +SPIRV_TOOLS_EXPORT spv_result_t spvTextToBinary(const spv_const_context context, + const char* text, + const size_t length, + spv_binary* binary, + spv_diagnostic* diagnostic); + +// Encodes the given SPIR-V assembly text to its binary representation. Same as +// spvTextToBinary but with options. The options parameter is a bit field of +// spv_text_to_binary_options_t. +SPIRV_TOOLS_EXPORT spv_result_t spvTextToBinaryWithOptions( + const spv_const_context context, const char* text, const size_t length, + const uint32_t options, spv_binary* binary, spv_diagnostic* diagnostic); + +// Frees an allocated text stream. This is a no-op if the text parameter +// is a null pointer. 
+SPIRV_TOOLS_EXPORT void spvTextDestroy(spv_text text); + +// Decodes the given SPIR-V binary representation to its assembly text. The +// word_count parameter specifies the number of words for binary. The options +// parameter is a bit field of spv_binary_to_text_options_t. Decoded text will +// be stored into *text. Any error will be written into *diagnostic if +// diagnostic is non-null, otherwise the context's message consumer will be +// used. +SPIRV_TOOLS_EXPORT spv_result_t spvBinaryToText(const spv_const_context context, + const uint32_t* binary, + const size_t word_count, + const uint32_t options, + spv_text* text, + spv_diagnostic* diagnostic); + +// Frees a binary stream from memory. This is a no-op if binary is a null +// pointer. +SPIRV_TOOLS_EXPORT void spvBinaryDestroy(spv_binary binary); + +// Validates a SPIR-V binary for correctness. Any errors will be written into +// *diagnostic if diagnostic is non-null, otherwise the context's message +// consumer will be used. +SPIRV_TOOLS_EXPORT spv_result_t spvValidate(const spv_const_context context, + const spv_const_binary binary, + spv_diagnostic* diagnostic); + +// Validates a SPIR-V binary for correctness. Uses the provided Validator +// options. Any errors will be written into *diagnostic if diagnostic is +// non-null, otherwise the context's message consumer will be used. +SPIRV_TOOLS_EXPORT spv_result_t spvValidateWithOptions( + const spv_const_context context, const spv_const_validator_options options, + const spv_const_binary binary, spv_diagnostic* diagnostic); + +// Validates a raw SPIR-V binary for correctness. Any errors will be written +// into *diagnostic if diagnostic is non-null, otherwise the context's message +// consumer will be used. +SPIRV_TOOLS_EXPORT spv_result_t +spvValidateBinary(const spv_const_context context, const uint32_t* words, + const size_t num_words, spv_diagnostic* diagnostic); + +// Creates a diagnostic object. 
The position parameter specifies the location in +// the text/binary stream. The message parameter, copied into the diagnostic +// object, contains the error message to display. +SPIRV_TOOLS_EXPORT spv_diagnostic +spvDiagnosticCreate(const spv_position position, const char* message); + +// Destroys a diagnostic object. This is a no-op if diagnostic is a null +// pointer. +SPIRV_TOOLS_EXPORT void spvDiagnosticDestroy(spv_diagnostic diagnostic); + +// Prints the diagnostic to stderr. +SPIRV_TOOLS_EXPORT spv_result_t +spvDiagnosticPrint(const spv_diagnostic diagnostic); + +// The binary parser interface. + +// A pointer to a function that accepts a parsed SPIR-V header. +// The integer arguments are the 32-bit words from the header, as specified +// in SPIR-V 1.0 Section 2.3 Table 1. +// The function should return SPV_SUCCESS if parsing should continue. +typedef spv_result_t (*spv_parsed_header_fn_t)( + void* user_data, spv_endianness_t endian, uint32_t magic, uint32_t version, + uint32_t generator, uint32_t id_bound, uint32_t reserved); + +// A pointer to a function that accepts a parsed SPIR-V instruction. +// The parsed_instruction value is transient: it may be overwritten +// or released immediately after the function has returned. That also +// applies to the words array member of the parsed instruction. The +// function should return SPV_SUCCESS if and only if parsing should +// continue. +typedef spv_result_t (*spv_parsed_instruction_fn_t)( + void* user_data, const spv_parsed_instruction_t* parsed_instruction); + +// Parses a SPIR-V binary, specified as counted sequence of 32-bit words. +// Parsing feedback is provided via two callbacks provided as function +// pointers. Each callback function pointer can be a null pointer, in +// which case it is never called. Otherwise, in a valid parse the +// parsed-header callback is called once, and then the parsed-instruction +// callback once for each instruction in the stream. 
The user_data parameter +// is supplied as context to the callbacks. Returns SPV_SUCCESS on successful +// parse where the callbacks always return SPV_SUCCESS. For an invalid parse, +// returns a status code other than SPV_SUCCESS, and if diagnostic is non-null +// also emits a diagnostic. If diagnostic is null the context's message consumer +// will be used to emit any errors. If a callback returns anything other than +// SPV_SUCCESS, then that status code is returned, no further callbacks are +// issued, and no additional diagnostics are emitted. +SPIRV_TOOLS_EXPORT spv_result_t spvBinaryParse( + const spv_const_context context, void* user_data, const uint32_t* words, + const size_t num_words, spv_parsed_header_fn_t parse_header, + spv_parsed_instruction_fn_t parse_instruction, spv_diagnostic* diagnostic); + +#ifdef __cplusplus +} +#endif + +#endif // INCLUDE_SPIRV_TOOLS_LIBSPIRV_H_ diff --git a/include/spirv-tools/libspirv.hpp b/include/spirv-tools/libspirv.hpp new file mode 100644 index 000000000..9cb5afe3f --- /dev/null +++ b/include/spirv-tools/libspirv.hpp @@ -0,0 +1,245 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_SPIRV_TOOLS_LIBSPIRV_HPP_ +#define INCLUDE_SPIRV_TOOLS_LIBSPIRV_HPP_ + +#include <functional> +#include <memory> +#include <string> +#include <vector> + +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// Message consumer. 
The C strings for source and message are only alive for the +// specific invocation. 
+using MessageConsumer = std::function<void( + spv_message_level_t /* level */, const char* /* source */, + const spv_position_t& /* position */, const char* /* message */ + )>; + +// C++ RAII wrapper around the C context object spv_context. +class Context { + public: + // Constructs a context targeting the given environment |env|. + // + // The constructed instance will have an empty message consumer, which just + // ignores all messages from the library. Use SetMessageConsumer() to supply + // one if messages are of concern. + explicit Context(spv_target_env env); + + // Enables move constructor/assignment operations. + Context(Context&& other); + Context& operator=(Context&& other); + + // Disables copy constructor/assignment operations. + Context(const Context&) = delete; + Context& operator=(const Context&) = delete; + + // Destructs this instance. + ~Context(); + + // Sets the message consumer to the given |consumer|. The |consumer| will be + // invoked once for each message communicated from the library. + void SetMessageConsumer(MessageConsumer consumer); + + // Returns the underlying spv_context. + spv_context& CContext(); + const spv_context& CContext() const; + + private: + spv_context context_; +}; + +// A RAII wrapper around a validator options object. +class ValidatorOptions { + public: + ValidatorOptions() : options_(spvValidatorOptionsCreate()) {} + ~ValidatorOptions() { spvValidatorOptionsDestroy(options_); } + // Allow implicit conversion to the underlying object. + operator spv_validator_options() const { return options_; } + + // Sets a limit. + void SetUniversalLimit(spv_validator_limit limit_type, uint32_t limit) { + spvValidatorOptionsSetUniversalLimit(options_, limit_type, limit); + } + + void SetRelaxStructStore(bool val) { + spvValidatorOptionsSetRelaxStoreStruct(options_, val); + } + + // Enables VK_KHR_relaxed_block_layout when validating standard + // uniform/storage buffer/push-constant layout. 
If true, disables + // scalar block layout rules. 
+ void SetRelaxBlockLayout(bool val) { + spvValidatorOptionsSetRelaxBlockLayout(options_, val); + } + + // Enables VK_EXT_scalar_block_layout when validating standard + // uniform/storage buffer/push-constant layout. If true, disables + // relaxed block layout rules. + void SetScalarBlockLayout(bool val) { + spvValidatorOptionsSetScalarBlockLayout(options_, val); + } + + // Skips validating standard uniform/storage buffer/push-constant layout. + void SetSkipBlockLayout(bool val) { + spvValidatorOptionsSetSkipBlockLayout(options_, val); + } + + // Records whether or not the validator should relax the rules on pointer + // usage in logical addressing mode. + // + // When relaxed, it will allow the following usage cases of pointers: + // 1) OpVariable allocating an object whose type is a pointer type + // 2) OpReturnValue returning a pointer value + void SetRelaxLogicalPointer(bool val) { + spvValidatorOptionsSetRelaxLogicalPointer(options_, val); + } + + private: + spv_validator_options options_; +}; + +// A C++ wrapper around an optimization options object. +class OptimizerOptions { + public: + OptimizerOptions() : options_(spvOptimizerOptionsCreate()) {} + ~OptimizerOptions() { spvOptimizerOptionsDestroy(options_); } + + // Allow implicit conversion to the underlying object. + operator spv_optimizer_options() const { return options_; } + + // Records whether or not the optimizer should run the validator before + // optimizing. If |run| is true, the validator will be run. + void set_run_validator(bool run) { + spvOptimizerOptionsSetRunValidator(options_, run); + } + + // Records the validator options that should be passed to the validator if it + // is run. + void set_validator_options(const ValidatorOptions& val_options) { + spvOptimizerOptionsSetValidatorOptions(options_, val_options); + } + + // Records the maximum possible value for the id bound. 
+ void set_max_id_bound(uint32_t new_bound) { + spvOptimizerOptionsSetMaxIdBound(options_, new_bound); + } + + private: + spv_optimizer_options options_; +}; + +// A C++ wrapper around a reducer options object. +class ReducerOptions { + public: + ReducerOptions() : options_(spvReducerOptionsCreate()) {} + ~ReducerOptions() { spvReducerOptionsDestroy(options_); } + + // Allow implicit conversion to the underlying object. + operator spv_reducer_options() const { return options_; } + + // Records the maximum number of reduction steps that should + // run before the reducer gives up. + void set_step_limit(uint32_t step_limit) { + spvReducerOptionsSetStepLimit(options_, step_limit); + } + + // Sets a seed to be used for random number generation. + void set_seed(uint32_t seed) { spvReducerOptionsSetSeed(options_, seed); } + + private: + spv_reducer_options options_; +}; + +// C++ interface for SPIRV-Tools functionalities. It wraps the context +// (including target environment and the corresponding SPIR-V grammar) and +// provides methods for assembling, disassembling, and validating. +// +// Instances of this class provide basic thread-safety guarantee. +class SpirvTools { + public: + enum { + // Default assembling option used by assemble(): + kDefaultAssembleOption = SPV_TEXT_TO_BINARY_OPTION_NONE, + + // Default disassembling option used by Disassemble(): + // * Avoid prefix comments from decoding the SPIR-V module header, and + // * Use friendly names for variables. + kDefaultDisassembleOption = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES + }; + + // Constructs an instance targeting the given environment |env|. + // + // The constructed instance will have an empty message consumer, which just + // ignores all messages from the library. Use SetMessageConsumer() to supply + // one if messages are of concern. + explicit SpirvTools(spv_target_env env); + + // Disables copy/move constructor/assignment operations. 
+ SpirvTools(const SpirvTools&) = delete; + SpirvTools(SpirvTools&&) = delete; + SpirvTools& operator=(const SpirvTools&) = delete; + SpirvTools& operator=(SpirvTools&&) = delete; + + // Destructs this instance. + ~SpirvTools(); + + // Sets the message consumer to the given |consumer|. The |consumer| will be + // invoked once for each message communicated from the library. + void SetMessageConsumer(MessageConsumer consumer); + + // Assembles the given assembly |text| and writes the result to |binary|. + // Returns true on successful assembling. |binary| will be kept untouched if + // assembling is unsuccessful. + bool Assemble(const std::string& text, std::vector* binary, + uint32_t options = kDefaultAssembleOption) const; + // |text_size| specifies the number of bytes in |text|. A terminating null + // character is not required to present in |text| as long as |text| is valid. + bool Assemble(const char* text, size_t text_size, + std::vector* binary, + uint32_t options = kDefaultAssembleOption) const; + + // Disassembles the given SPIR-V |binary| with the given |options| and writes + // the assembly to |text|. Returns true on successful disassembling. |text| + // will be kept untouched if disassembling is unsuccessful. + bool Disassemble(const std::vector& binary, std::string* text, + uint32_t options = kDefaultDisassembleOption) const; + // |binary_size| specifies the number of words in |binary|. + bool Disassemble(const uint32_t* binary, size_t binary_size, + std::string* text, + uint32_t options = kDefaultDisassembleOption) const; + + // Validates the given SPIR-V |binary|. Returns true if no issues are found. + // Otherwise, returns false and communicates issues via the message consumer + // registered. + bool Validate(const std::vector& binary) const; + // |binary_size| specifies the number of words in |binary|. + bool Validate(const uint32_t* binary, size_t binary_size) const; + // Like the previous overload, but takes an options object. 
+ bool Validate(const uint32_t* binary, size_t binary_size, + spv_validator_options options) const; + + private: + struct Impl; // Opaque struct for holding the data fields used by this class. + std::unique_ptr impl_; // Unique pointer to implementation data. +}; + +} // namespace spvtools + +#endif // INCLUDE_SPIRV_TOOLS_LIBSPIRV_HPP_ diff --git a/include/spirv-tools/linker.hpp b/include/spirv-tools/linker.hpp new file mode 100644 index 000000000..d2f3e72ca --- /dev/null +++ b/include/spirv-tools/linker.hpp @@ -0,0 +1,97 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_SPIRV_TOOLS_LINKER_HPP_ +#define INCLUDE_SPIRV_TOOLS_LINKER_HPP_ + +#include + +#include +#include + +#include "libspirv.hpp" + +namespace spvtools { + +class LinkerOptions { + public: + LinkerOptions() + : create_library_(false), + verify_ids_(false), + allow_partial_linkage_(false) {} + + // Returns whether a library or an executable should be produced by the + // linking phase. + // + // All exported symbols are kept when creating a library, whereas they will + // be removed when creating an executable. + // The returned value will be true if creating a library, and false if + // creating an executable. + bool GetCreateLibrary() const { return create_library_; } + + // Sets whether a library or an executable should be produced. 
+ void SetCreateLibrary(bool create_library) { + create_library_ = create_library; + } + + // Returns whether to verify the uniqueness of the unique ids in the merged + // context. + bool GetVerifyIds() const { return verify_ids_; } + + // Sets whether to verify the uniqueness of the unique ids in the merged + // context. + void SetVerifyIds(bool verify_ids) { verify_ids_ = verify_ids; } + + // Returns whether to allow for imported symbols to have no corresponding + // exported symbols + bool GetAllowPartialLinkage() const { return allow_partial_linkage_; } + + // Sets whether to allow for imported symbols to have no corresponding + // exported symbols + void SetAllowPartialLinkage(bool allow_partial_linkage) { + allow_partial_linkage_ = allow_partial_linkage; + } + + private: + bool create_library_; + bool verify_ids_; + bool allow_partial_linkage_; +}; + +// Links one or more SPIR-V modules into a new SPIR-V module. That is, combine +// several SPIR-V modules into one, resolving link dependencies between them. +// +// At least one binary has to be provided in |binaries|. Those binaries do not +// have to be valid, but they should be at least parseable. +// The functions can fail due to the following: +// * The given context was not initialised using `spvContextCreate()`; +// * No input modules were given; +// * One or more of those modules were not parseable; +// * The input modules used different addressing or memory models; +// * The ID or global variable number limit were exceeded; +// * Some entry points were defined multiple times; +// * Some imported symbols did not have an exported counterpart; +// * Possibly other reasons. 
+spv_result_t Link(const Context& context, + const std::vector>& binaries, + std::vector* linked_binary, + const LinkerOptions& options = LinkerOptions()); +spv_result_t Link(const Context& context, const uint32_t* const* binaries, + const size_t* binary_sizes, size_t num_binaries, + std::vector* linked_binary, + const LinkerOptions& options = LinkerOptions()); + +} // namespace spvtools + +#endif // INCLUDE_SPIRV_TOOLS_LINKER_HPP_ diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp new file mode 100644 index 000000000..b59329db4 --- /dev/null +++ b/include/spirv-tools/optimizer.hpp @@ -0,0 +1,731 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ +#define INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ + +#include +#include +#include +#include +#include + +#include "libspirv.hpp" + +namespace spvtools { + +namespace opt { +class Pass; +} + +// C++ interface for SPIR-V optimization functionalities. It wraps the context +// (including target environment and the corresponding SPIR-V grammar) and +// provides methods for registering optimization passes and optimizing. +// +// Instances of this class provides basic thread-safety guarantee. +class Optimizer { + public: + // The token for an optimization pass. It is returned via one of the + // Create*Pass() standalone functions at the end of this header file and + // consumed by the RegisterPass() method. 
Tokens are one-time objects that + // only support move; copying is not allowed. + struct PassToken { + struct Impl; // Opaque struct for holding internal data. + + PassToken(std::unique_ptr); + + // Tokens for built-in passes should be created using Create*Pass functions + // below; for out-of-tree passes, use this constructor instead. + // Note that this API isn't guaranteed to be stable and may change without + // preserving source or binary compatibility in the future. + PassToken(std::unique_ptr&& pass); + + // Tokens can only be moved. Copying is disabled. + PassToken(const PassToken&) = delete; + PassToken(PassToken&&); + PassToken& operator=(const PassToken&) = delete; + PassToken& operator=(PassToken&&); + + ~PassToken(); + + std::unique_ptr impl_; // Unique pointer to internal data. + }; + + // Constructs an instance with the given target |env|, which is used to decode + // the binaries to be optimized later. + // + // The constructed instance will have an empty message consumer, which just + // ignores all messages from the library. Use SetMessageConsumer() to supply + // one if messages are of concern. + explicit Optimizer(spv_target_env env); + + // Disables copy/move constructor/assignment operations. + Optimizer(const Optimizer&) = delete; + Optimizer(Optimizer&&) = delete; + Optimizer& operator=(const Optimizer&) = delete; + Optimizer& operator=(Optimizer&&) = delete; + + // Destructs this instance. + ~Optimizer(); + + // Sets the message consumer to the given |consumer|. The |consumer| will be + // invoked once for each message communicated from the library. + void SetMessageConsumer(MessageConsumer consumer); + + // Returns a reference to the registered message consumer. + const MessageConsumer& consumer() const; + + // Registers the given |pass| to this optimizer. Passes will be run in the + // exact order of registration. The token passed in will be consumed by this + // method. 
+ Optimizer& RegisterPass(PassToken&& pass); + + // Registers passes that attempt to improve performance of generated code. + // This sequence of passes is subject to constant review and will change + // from time to time. + Optimizer& RegisterPerformancePasses(); + + // Registers passes that attempt to improve the size of generated code. + // This sequence of passes is subject to constant review and will change + // from time to time. + Optimizer& RegisterSizePasses(); + + // Registers passes that have been prescribed for WebGPU environments. + // This sequence of passes is subject to constant review and will change + // from time to time. + Optimizer& RegisterWebGPUPasses(); + + // Registers passes that attempt to legalize the generated code. + // + // Note: this recipe is specially designed for legalizing SPIR-V. It should be + // used by compilers after translating HLSL source code literally. It should + // *not* be used by general workloads for performance or size improvement. + // + // This sequence of passes is subject to constant review and will change + // from time to time. + Optimizer& RegisterLegalizationPasses(); + + // Register passes specified in the list of |flags|. Each flag must be a + // string of a form accepted by Optimizer::FlagHasValidForm(). + // + // If the list of flags contains an invalid entry, it returns false and an + // error message is emitted to the MessageConsumer object (use + // Optimizer::SetMessageConsumer to define a message consumer, if needed). + // + // If all the passes are registered successfully, it returns true. + bool RegisterPassesFromFlags(const std::vector& flags); + + // Registers the optimization pass associated with |flag|. This only accepts + // |flag| values of the form "--pass_name[=pass_args]". If no such pass + // exists, it returns false. Otherwise, the pass is registered and it returns + // true. 
+ // + // The following flags have special meaning: + // + // -O: Registers all performance optimization passes + // (Optimizer::RegisterPerformancePasses) + // + // -Os: Registers all size optimization passes + // (Optimizer::RegisterSizePasses). + // + // --legalize-hlsl: Registers all passes that legalize SPIR-V generated by an + // HLSL front-end. + bool RegisterPassFromFlag(const std::string& flag); + + // Validates that |flag| has a valid format. Strings accepted: + // + // --pass_name[=pass_args] + // -O + // -Os + // + // If |flag| takes one of the forms above, it returns true. Otherwise, it + // returns false. + bool FlagHasValidForm(const std::string& flag) const; + + // Allows changing, after creation time, the target environment to be + // optimized for. Should be called before calling Run(). + void SetTargetEnv(const spv_target_env env); + + // Optimizes the given SPIR-V module |original_binary| and writes the + // optimized binary into |optimized_binary|. + // Returns true on successful optimization, whether or not the module is + // modified. Returns false if |original_binary| fails to validate or if errors + // occur when processing |original_binary| using any of the registered passes. + // In that case, no further passes are executed and the contents in + // |optimized_binary| may be invalid. + // + // It's allowed to alias |original_binary| to the start of |optimized_binary|. + bool Run(const uint32_t* original_binary, size_t original_binary_size, + std::vector* optimized_binary) const; + + // DEPRECATED: Same as above, except passes |options| to the validator when + // trying to validate the binary. If |skip_validation| is true, then the + // caller is guaranteeing that |original_binary| is valid, and the validator + // will not be run. The |max_id_bound| is the limit on the max id in the + // module. 
+ bool Run(const uint32_t* original_binary, const size_t original_binary_size, + std::vector* optimized_binary, + const ValidatorOptions& options, bool skip_validation) const; + + // Same as above, except it takes an options object. See the documentation + // for |OptimizerOptions| to see which options can be set. + bool Run(const uint32_t* original_binary, const size_t original_binary_size, + std::vector* optimized_binary, + const spv_optimizer_options opt_options) const; + + // Returns a vector of strings with all the pass names added to this + // optimizer's pass manager. These strings are valid until the associated + // pass manager is destroyed. + std::vector GetPassNames() const; + + // Sets the option to print the disassembly before each pass and after the + // last pass. If |out| is null, then no output is generated. Otherwise, + // output is sent to the |out| output stream. + Optimizer& SetPrintAll(std::ostream* out); + + // Sets the option to print the resource utilization of each pass. If |out| + // is null, then no output is generated. Otherwise, output is sent to the + // |out| output stream. + Optimizer& SetTimeReport(std::ostream* out); + + private: + struct Impl; // Opaque struct for holding internal data. + std::unique_ptr impl_; // Unique pointer to internal data. +}; + +// Creates a null pass. +// A null pass does nothing to the SPIR-V module to be optimized. +Optimizer::PassToken CreateNullPass(); + +// Creates a strip-debug-info pass. +// A strip-debug-info pass removes all debug instructions (as documented in +// Section 3.32.2 of the SPIR-V spec) of the SPIR-V module to be optimized. +Optimizer::PassToken CreateStripDebugInfoPass(); + +// Creates a strip-reflect-info pass. +// A strip-reflect-info pass removes all reflections instructions. +// For now, this is limited to removing decorations defined in +// SPV_GOOGLE_hlsl_functionality1. The coverage may expand in +// the future. 
+Optimizer::PassToken CreateStripReflectInfoPass(); + +// Creates an eliminate-dead-functions pass. +// An eliminate-dead-functions pass will remove all functions that are not in +// the call trees rooted at entry points and exported functions. These +// functions are not needed because they will never be called. +Optimizer::PassToken CreateEliminateDeadFunctionsPass(); + +// Creates a set-spec-constant-default-value pass from a mapping from spec-ids +// to the default values in the form of string. +// A set-spec-constant-default-value pass sets the default values for the +// spec constants that have SpecId decorations (i.e., those defined by +// OpSpecConstant{|True|False} instructions). +Optimizer::PassToken CreateSetSpecConstantDefaultValuePass( + const std::unordered_map& id_value_map); + +// Creates a set-spec-constant-default-value pass from a mapping from spec-ids +// to the default values in the form of bit pattern. +// A set-spec-constant-default-value pass sets the default values for the +// spec constants that have SpecId decorations (i.e., those defined by +// OpSpecConstant{|True|False} instructions). +Optimizer::PassToken CreateSetSpecConstantDefaultValuePass( + const std::unordered_map>& id_value_map); + +// Creates a flatten-decoration pass. +// A flatten-decoration pass replaces grouped decorations with equivalent +// ungrouped decorations. That is, it replaces each OpDecorationGroup +// instruction and associated OpGroupDecorate and OpGroupMemberDecorate +// instructions with equivalent OpDecorate and OpMemberDecorate instructions. +// The pass does not attempt to preserve debug information for instructions +// it removes. +Optimizer::PassToken CreateFlattenDecorationPass(); + +// Creates a freeze-spec-constant-value pass. +// A freeze-spec-constant pass specializes the value of spec constants to +// their default values. 
This pass only processes the spec constants that have +// SpecId decorations (defined by OpSpecConstant, OpSpecConstantTrue, or +// OpSpecConstantFalse instructions) and replaces them with their normal +// counterparts (OpConstant, OpConstantTrue, or OpConstantFalse). The +// corresponding SpecId annotation instructions will also be removed. This +// pass does not fold the newly added normal constants and does not process +// other spec constants defined by OpSpecConstantComposite or +// OpSpecConstantOp. +Optimizer::PassToken CreateFreezeSpecConstantValuePass(); + +// Creates a fold-spec-constant-op-and-composite pass. +// A fold-spec-constant-op-and-composite pass folds spec constants defined by +// OpSpecConstantOp or OpSpecConstantComposite instruction, to normal Constants +// defined by OpConstantTrue, OpConstantFalse, OpConstant, OpConstantNull, or +// OpConstantComposite instructions. Note that spec constants defined with +// OpSpecConstant, OpSpecConstantTrue, or OpSpecConstantFalse instructions are +// not handled, as these instructions indicate their values are not determined +// and can be changed in future. A spec constant is foldable if all of its +// value(s) can be determined from the module. E.g., an integer spec constant +// defined with OpSpecConstantOp instruction can be folded if its value won't +// change later. This pass will replace the original OpSpecConstantOp instruction +// with an OpConstant instruction. When folding composite spec constants, +// new instructions may be inserted to define the components of the composite +// constant first, then the original spec constants will be replaced by +// OpConstantComposite instructions. +// +// There are some operations not supported yet: +// OpSConvert, OpFConvert, OpQuantizeToF16 and +// all the operations under Kernel capability. +// TODO(qining): Add support for the operations listed above. +Optimizer::PassToken CreateFoldSpecConstantOpAndCompositePass(); + +// Creates a unify-constant pass. 
+// A unify-constant pass de-duplicates the constants. Constants with the exact +// same value and identical form will be unified and only one constant will +// be kept for each unique pair of type and value. +// There are several cases not handled by this pass: +// 1) Constants defined by OpConstantNull instructions (null constants) and +// constants defined by OpConstantFalse, OpConstant or OpConstantComposite +// with value 0 (zero-valued normal constants) are not considered equivalent. +// So null constants won't be used to replace zero-valued normal constants, +// vice versa. +// 2) Whenever there are decorations to the constant's result id, the +// constant won't be handled, which means, it won't be used to replace any +// other constants, neither can other constants replace it. +// 3) NaN in float point format with different bit patterns are not unified. +Optimizer::PassToken CreateUnifyConstantPass(); + +// Creates an eliminate-dead-constant pass. +// An eliminate-dead-constant pass removes dead constants, including normal +// constants defined by OpConstant, OpConstantComposite, OpConstantTrue, or +// OpConstantFalse and spec constants defined by OpSpecConstant, +// OpSpecConstantComposite, OpSpecConstantTrue, OpSpecConstantFalse or +// OpSpecConstantOp. +Optimizer::PassToken CreateEliminateDeadConstantPass(); + +// Creates a strength-reduction pass. +// A strength-reduction pass will look for opportunities to replace an +// instruction with an equivalent and less expensive one. For example, +// multiplying by a power of 2 can be replaced by a bit shift. +Optimizer::PassToken CreateStrengthReductionPass(); + +// Creates a block merge pass. +// This pass searches for blocks with a single Branch to a block with no +// other predecessors and merges the blocks into a single block. Continue +// blocks and Merge blocks are not candidates for the second block. 
+// +// The pass is most useful after Dead Branch Elimination, which can leave +// such sequences of blocks. Merging them makes subsequent passes more +// effective, such as single block local store-load elimination. +// +// While this pass reduces the number of occurrences of this sequence, at +// this time it does not guarantee all such sequences are eliminated. +// +// Presence of phi instructions can inhibit this optimization. Handling +// these is left for future improvements. +Optimizer::PassToken CreateBlockMergePass(); + +// Creates an exhaustive inline pass. +// An exhaustive inline pass attempts to exhaustively inline all function +// calls in all functions in an entry point call tree. The intent is to enable, +// albeit through brute force, analysis and optimization across function +// calls by subsequent optimization passes. As the inlining is exhaustive, +// there is no attempt to optimize for size or runtime performance. Functions +// that are not in the call tree of an entry point are not changed. +Optimizer::PassToken CreateInlineExhaustivePass(); + +// Creates an opaque inline pass. +// An opaque inline pass inlines all function calls in all functions in all +// entry point call trees where the called function contains an opaque type +// in either its parameter types or return type. An opaque type is currently +// defined as Image, Sampler or SampledImage. The intent is to enable, albeit +// through brute force, analysis and optimization across these function calls +// by subsequent passes in order to remove the storing of opaque types which is +// not legal in Vulkan. Functions that are not in the call tree of an entry +// point are not changed. +Optimizer::PassToken CreateInlineOpaquePass(); + +// Creates a single-block local variable load/store elimination pass. +// For every entry point function, do single block memory optimization of +// function variables referenced only with non-access-chain loads and stores. 
+// For each targeted variable load, if previous store to that variable in the +// block, replace the load's result id with the value id of the store. +// If previous load within the block, replace the current load's result id +// with the previous load's result id. In either case, delete the current +// load. Finally, check if any remaining stores are useless, and delete store +// and variable if possible. +// +// The presence of access chain references and function calls can inhibit +// the above optimization. +// +// Only modules with relaxed logical addressing (see opt/instruction.h) are +// currently processed. +// +// This pass is most effective if preceded by Inlining and +// LocalAccessChainConvert. This pass will reduce the work needed to be done +// by LocalSingleStoreElim and LocalMultiStoreElim. +// +// Only functions in the call tree of an entry point are processed. +Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass(); + +// Create dead branch elimination pass. +// For each entry point function, this pass will look for SelectionMerge +// BranchConditionals with constant condition and convert to a Branch to +// the indicated label. It will delete resulting dead blocks. +// +// For all phi functions in merge block, replace all uses with the id +// corresponding to the living predecessor. +// +// Note that some branches and blocks may be left to avoid creating invalid +// control flow. Improving this is left to future work. +// +// This pass is most effective when preceded by passes which eliminate +// local loads and stores, effectively propagating constant values where +// possible. +Optimizer::PassToken CreateDeadBranchElimPass(); + +// Creates an SSA local variable load/store elimination pass. +// For every entry point function, eliminate all loads and stores of function +// scope variables only referenced with non-access-chain loads and stores. +// Eliminate the variables as well. 
+// +// The presence of access chain references and function calls can inhibit +// the above optimization. +// +// Only shader modules with relaxed logical addressing (see opt/instruction.h) +// are currently processed. Currently modules with any extensions enabled are +// not processed. This is left for future work. +// +// This pass is most effective if preceded by Inlining and +// LocalAccessChainConvert. LocalSingleStoreElim and LocalSingleBlockElim +// will reduce the work that this pass has to do. +Optimizer::PassToken CreateLocalMultiStoreElimPass(); + +// Creates a local access chain conversion pass. +// A local access chain conversion pass identifies all function scope +// variables which are accessed only with loads, stores and access chains +// with constant indices. It then converts all loads and stores of such +// variables into equivalent sequences of loads, stores, extracts and inserts. +// +// This pass only processes entry point functions. It currently only converts +// non-nested, non-ptr access chains. It does not process modules with +// non-32-bit integer types present. Optional memory access options on loads +// and stores are ignored as we are only processing function scope variables. +// +// This pass unifies access to these variables to a single mode and simplifies +// subsequent analysis and elimination of these variables along with their +// loads and stores allowing values to propagate to their points of use where +// possible. +Optimizer::PassToken CreateLocalAccessChainConvertPass(); + +// Creates a local single store elimination pass. +// For each entry point function, this pass eliminates loads and stores for +// function scope variable that are stored to only once, where possible. Only +// whole variable loads and stores are eliminated; access-chain references are +// not optimized. Replace all loads of such variables with the value that is +// stored and eliminate any resulting dead code. 
+// +// Currently, the presence of access chains and function calls can inhibit this +// pass, however the Inlining and LocalAccessChainConvert passes can make it +// more effective. In addition, many non-load/store memory operations are +// not supported and will prohibit optimization of a function. Support of +// these operations is future work. +// +// Only shader modules with relaxed logical addressing (see opt/instruction.h) +// are currently processed. +// +// This pass will reduce the work needed to be done by LocalSingleBlockElim +// and LocalMultiStoreElim and can improve the effectiveness of other passes +// such as DeadBranchElimination which depend on values for their analysis. +Optimizer::PassToken CreateLocalSingleStoreElimPass(); + +// Creates an insert/extract elimination pass. +// This pass processes each entry point function in the module, searching for +// extracts on a sequence of inserts. It further searches the sequence for an +// insert with indices identical to the extract. If such an insert can be +// found before hitting a conflicting insert, the extract's result id is +// replaced with the id of the values from the insert. +// +// Besides removing extracts this pass enables subsequent dead code elimination +// passes to delete the inserts. This pass performs best after access chains are +// converted to inserts and extracts and local loads and stores are eliminated. +Optimizer::PassToken CreateInsertExtractElimPass(); + +// Creates a dead insert elimination pass. +// This pass processes each entry point function in the module, searching for +// unreferenced inserts into composite types. These are most often unused +// stores to vector components. They are unused because they are never +// referenced, or because there is another insert to the same component between +// the insert and the reference. After removing the inserts, dead code +// elimination is attempted on the inserted values. 
+// +// This pass performs best after access chains are converted to inserts and +// extracts and local loads and stores are eliminated. While executing this +// pass can be advantageous on its own, it is also advantageous to execute +// this pass after CreateInsertExtractElimPass() as it will remove any unused +// inserts created by that pass. +Optimizer::PassToken CreateDeadInsertElimPass(); + +// Creates a pass to consolidate uniform references. +// For each entry point function in the module, first change all constant index +// access chain loads into equivalent composite extracts. Then consolidate +// identical uniform loads into one uniform load. Finally, consolidate +// identical uniform extracts into one uniform extract. This may require +// moving a load or extract to a point which dominates all uses. +// +// This pass requires a module to have structured control flow i.e. shader +// capability. It also requires logical addressing i.e. Addresses capability +// is not enabled. It also currently does not support any extensions. +// +// This pass currently only optimizes loads with a single index. +Optimizer::PassToken CreateCommonUniformElimPass(); + +// Create aggressive dead code elimination pass +// This pass eliminates unused code from the module. In addition, +// it detects and eliminates code which may have spurious uses but which do +// not contribute to the output of the function. The most common cause of +// such code sequences is summations in loops whose result is no longer used +// due to dead code elimination. This optimization has additional compile +// time cost over standard dead code elimination. +// +// This pass only processes entry point functions. It also only processes +// shaders with relaxed logical addressing (see opt/instruction.h). It +// currently will not process functions with function calls. Unreachable +// functions are deleted. 
+// +// This pass will be made more effective by first running passes that remove +// dead control flow and inline function calls. +// +// This pass can be especially useful after running Local Access Chain +// Conversion, which tends to cause cycles of dead code to be left after +// Store/Load elimination passes are completed. These cycles cannot be +// eliminated with standard dead code elimination. +Optimizer::PassToken CreateAggressiveDCEPass(); + +// Create line propagation pass +// This pass propagates line information based on the rules for OpLine and +// OpNoline and clones an appropriate line instruction into every instruction +// which does not already have debug line instructions. +// +// This pass is intended to maximize preservation of source line information +// through passes which delete, move and clone instructions. Ideally it should +// be run before any such pass. It is a bookend pass with EliminateDeadLines +// which can be used to remove redundant line instructions at the end of a +// run of such passes and reduce final output file size. +Optimizer::PassToken CreatePropagateLineInfoPass(); + +// Create dead line elimination pass +// This pass eliminates redundant line instructions based on the rules for +// OpLine and OpNoline. Its main purpose is to reduce the size of the file +// needed to store the SPIR-V without losing line information. +// +// This is a bookend pass with PropagateLines which attaches line instructions +// to every instruction to preserve line information during passes which +// delete, move and clone instructions. DeadLineElim should be run after +// PropagateLines and all such subsequent passes. Normally it would be one +// of the last passes to be run. +Optimizer::PassToken CreateRedundantLineInfoElimPass(); + +// Creates a compact ids pass. +// The pass remaps result ids to a compact and gapless range starting from %1. +Optimizer::PassToken CreateCompactIdsPass(); + +// Creates a remove duplicate pass. 
+// This pass removes various duplicates: +// * duplicate capabilities; +// * duplicate extended instruction imports; +// * duplicate types; +// * duplicate decorations. +Optimizer::PassToken CreateRemoveDuplicatesPass(); + +// Creates a CFG cleanup pass. +// This pass removes cruft from the control flow graph of functions that are +// reachable from entry points and exported functions. It currently includes the +// following functionality: +// +// - Removal of unreachable basic blocks. +Optimizer::PassToken CreateCFGCleanupPass(); + +// Create dead variable elimination pass. +// This pass will delete module scope variables, along with their decorations, +// that are not referenced. +Optimizer::PassToken CreateDeadVariableEliminationPass(); + +// Create merge return pass. +// Changes functions that have multiple return statements so they have a single +// return statement. +// +// For structured control flow it is assumed that the only unreachable blocks in +// the function are trivial merge and continue blocks. +// +// A trivial merge block contains the label and an OpUnreachable instruction, +// nothing else. A trivial continue block contains a label and an OpBranch to +// the header, nothing else. +// +// These conditions are guaranteed to be met after running dead-branch +// elimination. +Optimizer::PassToken CreateMergeReturnPass(); + +// Create value numbering pass. +// This pass will look for instructions in the same basic block that compute the +// same value, and remove the redundant ones. +Optimizer::PassToken CreateLocalRedundancyEliminationPass(); + +// Create LICM pass. +// This pass will look for invariant instructions inside loops and hoist them to +// the loop's preheader. +Optimizer::PassToken CreateLoopInvariantCodeMotionPass(); + +// Creates a loop fission pass. +// This pass will split all top level loops whose register pressure exceeds the +// given |threshold|. 
+Optimizer::PassToken CreateLoopFissionPass(size_t threshold); + +// Creates a loop fusion pass. +// This pass will look for adjacent loops that are compatible and legal to be +// fused. Then fuse all such loops as long as the register usage for the fused +// loop stays under the threshold defined by |max_registers_per_loop|. +Optimizer::PassToken CreateLoopFusionPass(size_t max_registers_per_loop); + +// Creates a loop peeling pass. +// This pass will look for conditions inside a loop that are true or false only +// for the N first or last iterations. For loops with such a condition, those N +// iterations of the loop will be executed outside of the main loop. +// To limit code size explosion, the loop peeling can only happen if the code +// size growth for each loop is under |code_growth_threshold|. +Optimizer::PassToken CreateLoopPeelingPass(); + +// Creates a loop unswitch pass. +// This pass will look for loop independent branch conditions and move the +// condition out of the loop and version the loop based on the taken branch. +// Works best after LICM and local multi store elimination pass. +Optimizer::PassToken CreateLoopUnswitchPass(); + +// Create global value numbering pass. +// This pass will look for instructions where the same value is computed on all +// paths leading to the instruction. Those instructions are deleted. +Optimizer::PassToken CreateRedundancyEliminationPass(); + +// Create scalar replacement pass. +// This pass replaces composite function scope variables with variables for each +// element if those elements are accessed individually. The parameter is a +// limit on the number of members in the composite variable that the pass will +// consider replacing. +Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit = 100); + +// Create a private to local pass. +// This pass looks for variables declared in the private storage class that are +// used in only one function. 
Those variables are moved to the function storage +// class in the function in which they are used. +Optimizer::PassToken CreatePrivateToLocalPass(); + +// Creates a conditional constant propagation (CCP) pass. +// This pass implements the SSA-CCP algorithm in +// +// Constant propagation with conditional branches, +// Wegman and Zadeck, ACM TOPLAS 13(2):181-210. +// +// Constant values in expressions and conditional jumps are folded and +// simplified. This may reduce code size by removing never executed jump targets +// and computations with constant operands. +Optimizer::PassToken CreateCCPPass(); + +// Creates a workaround driver bugs pass. This pass attempts to work around +// a known driver bug (issue #1209) by identifying the bad code sequences and +// rewriting them. +// +// Current workaround: Avoid OpUnreachable instructions in loops. +Optimizer::PassToken CreateWorkaround1209Pass(); + +// Creates a pass that converts if-then-else like assignments into OpSelect. +Optimizer::PassToken CreateIfConversionPass(); + +// Creates a pass that will replace instructions that are not valid for the +// current shader stage by constants. Has no effect on non-shader modules. +Optimizer::PassToken CreateReplaceInvalidOpcodePass(); + +// Creates a pass that simplifies instructions using the instruction folder. +Optimizer::PassToken CreateSimplificationPass(); + +// Create loop unroller pass. +// Creates a pass to unroll loops which have the "Unroll" loop control +// mask set. The loops must meet specific criteria in order to be unrolled +// safely; these criteria are checked before doing the unroll by the +// LoopUtils::CanPerformUnroll method. Any loop that does not meet the criteria +// won't be unrolled. See CanPerformUnroll in LoopUtils.h for more information. +Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0); + +// Create the SSA rewrite pass. +// This pass converts load/store operations on function local variables into +// operations on SSA IDs. 
This allows SSA optimizers to act on these variables. +// Only variables that are local to the function and of supported types are +// processed (see IsSSATargetVar for details). +Optimizer::PassToken CreateSSARewritePass(); + +// Create copy propagate arrays pass. +// This pass looks to copy propagate memory references for arrays. It looks +// for specific code patterns to recognize array copies. +Optimizer::PassToken CreateCopyPropagateArraysPass(); + +// Create a vector dce pass. +// This pass looks for components of vectors that are unused, and removes them +// from the vector. Note this would still leave around lots of dead code that +// a pass of ADCE will be able to remove. +Optimizer::PassToken CreateVectorDCEPass(); + +// Create a pass to reduce the size of loads. +// This pass looks for loads of structures where only a few of its members are +// used. It replaces the loads feeding an OpExtract with an OpAccessChain and +// a load of the specific elements. +Optimizer::PassToken CreateReduceLoadSizePass(); + +// Create a pass to combine chained access chains. +// This pass looks for access chains fed by other access chains and combines +// them into a single instruction where possible. +Optimizer::PassToken CreateCombineAccessChainsPass(); + +// Create a pass to instrument bindless descriptor checking +// This pass instruments all bindless references to check that descriptor +// array indices are inbounds. If the reference is invalid, a record is +// written to the debug output buffer (if space allows) and a null value is +// returned. This pass is designed to support bindless validation in the Vulkan +// validation layers. +// +// Dead code elimination should be run after this pass as the original, +// potentially invalid code is not removed and could cause undefined behavior, +// including crashes. 
It may also be beneficial to run Simplification +// (ie Constant Propagation), DeadBranchElim and BlockMerge after this pass to +// optimize instrument code involving the testing of compile-time constants. +// It is also generally recommended that this pass (and all +// instrumentation passes) be run after any legalization and optimization +// passes. This will give better analysis for the instrumentation and avoid +// potentially de-optimizing the instrument code, for example, inlining +// the debug record output function throughout the module. +// +// The instrumentation will read and write buffers in debug +// descriptor set |desc_set|. It will write |shader_id| in each output record +// to identify the shader module which generated the record. +// +// TODO(greg-lunarg): Add support for vk_ext_descriptor_indexing. +Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t desc_set, + uint32_t shader_id); + +// Create a pass to upgrade to the VulkanKHR memory model. +// This pass upgrades the Logical GLSL450 memory model to Logical VulkanKHR. +// Additionally, it modifies memory, image, atomic and barrier operations to +// conform to that model's requirements. +Optimizer::PassToken CreateUpgradeMemoryModelPass(); + +// Create a pass to do code sinking. Code sinking is a transformation +// where an instruction is moved into a more deeply nested construct. +Optimizer::PassToken CreateCodeSinkingPass(); + +} // namespace spvtools + +#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/kokoro/android/build.sh b/kokoro/android/build.sh new file mode 100644 index 000000000..e31744fd1 --- /dev/null +++ b/kokoro/android/build.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Android Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools +TARGET_ARCH="armeabi-v7a with NEON" +export ANDROID_NDK=/opt/android-ndk-r15c + +# Get NINJA. +wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +unzip -q ninja-linux.zip +export PATH="$PWD:$PATH" +git clone --depth=1 https://github.com/taka-no-me/android-cmake.git android-cmake +export TOOLCHAIN_PATH=$PWD/android-cmake/android.toolchain.cmake + + +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +mkdir build && cd $SRC/build + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting build... +cmake -DCMAKE_BUILD_TYPE=Release -DANDROID_NATIVE_API_LEVEL=android-14 -DANDROID_ABI="armeabi-v7a with NEON" -DSPIRV_BUILD_COMPRESSION=ON -DSPIRV_SKIP_TESTS=ON -DCMAKE_TOOLCHAIN_FILE=$TOOLCHAIN_PATH -GNinja -DANDROID_NDK=$ANDROID_NDK .. + +echo $(date): Build everything... +ninja +echo $(date): Build completed. diff --git a/kokoro/android/continuous.cfg b/kokoro/android/continuous.cfg new file mode 100644 index 000000000..3bdb17a57 --- /dev/null +++ b/kokoro/android/continuous.cfg @@ -0,0 +1,17 @@ +# Copyright (c) 2018 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +# +build_file: "SPIRV-Tools/kokoro/android/build.sh" diff --git a/kokoro/android/presubmit.cfg b/kokoro/android/presubmit.cfg new file mode 100644 index 000000000..21589ccc1 --- /dev/null +++ b/kokoro/android/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/android/build.sh" diff --git a/kokoro/check-format/build.sh b/kokoro/check-format/build.sh new file mode 100644 index 000000000..2a8d50fb5 --- /dev/null +++ b/kokoro/check-format/build.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Code Format Check Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools + +# Get clang-format-5.0.0. +# Once kokoro upgrades the Ubuntu VMs, we can use 'apt-get install clang-format' +curl -L http://releases.llvm.org/5.0.0/clang+llvm-5.0.0-linux-x86_64-ubuntu14.04.tar.xz -o clang-llvm.tar.xz +tar xf clang-llvm.tar.xz +export PATH=$PWD/clang+llvm-5.0.0-linux-x86_64-ubuntu14.04/bin:$PATH + +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 +curl -L http://llvm.org/svn/llvm-project/cfe/trunk/tools/clang-format/clang-format-diff.py -o utils/clang-format-diff.py; + +echo $(date): Check formatting... +./utils/check_code_format.sh; +echo $(date): check completed. diff --git a/kokoro/check-format/presubmit_check_format.cfg b/kokoro/check-format/presubmit_check_format.cfg new file mode 100644 index 000000000..1993289d6 --- /dev/null +++ b/kokoro/check-format/presubmit_check_format.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/check-format/build.sh" diff --git a/kokoro/img/linux.png b/kokoro/img/linux.png new file mode 100644 index 0000000000000000000000000000000000000000..ff066d979e1393d1f8d57198025c9a79aa4999c0 GIT binary patch literal 17369 zcmeI4c~}$I7Qlmm3L;1?h$0xGA{9a=8zgB6B4JZD0R5e(Z&;u2?WAuKVNWU0t1Q{-p~I-xp*t>b67% zhkb`MWM*Z3>oD~Y3$bn=rLXUqmg?tWq^}Q|=74JYf*9X?%^+vs38a+Eo5B-OyznLJ%JVyeE1pcoj^6aI2_2+UfJjUwNLN`*d}1`FXSU zMdlXkYj?KqjIOuOv(AfNzh~X1ja>Tz3YnZfbGQGzsON=|+iO1UnY8lD)hlPE&5{;e zD*tfLmOZb-KHhV!%)ET>7hN6?GM`V`8W!c~=208o?PFqT_0xm~tDmsF;mAGhNrq3Cc|clK~a>kQtL^{*B5_d@TMoT zyDraTwkmmLq)Gbtb06A`6I)oAN14; zHyv!+?{GQnk9`r|~X>YbpaDVjFB)|C8!Q)MSo0p{>@5(Owly7;&x$Cf#qN}#Rt|D-M z!u)D!?d_cS5`PzI>yiWUFHI}9zD%xu%&w-yy-}&yt1@iG?)DkW8~h%G-Z?jA`>Vu^ zi(Zdr`|O{x=v8w1xqy;|$xm9YhaO*X6TAjhO?$Z=K0YQj^=abkghxRr4U<4a_1RyQ#go& zcvuM%Q{v+!3Qmd}N#mD;+^d@@B%-E^a;Y21LmiM95fDuDl*(Zu)7ga#QK?j-E8CgM zWYK9%Cn61?GARI?LIubGl|uzMG=SK9k$A?)9akl_!i;QA67teu}mqJN{DK|pg@|abR&_} zf%>oBdBw~614$IU>=25S6i`N?Is=rUjD%3XjVw_frx}_MqQG%*JSzQe`kD1Xf5BDi4!N<9OOtHgFIi=YV=u z5gh_RNGwsG4j*wTXrqCQ5BP)mpc3XGwytCV`C)`nSsZ}Hp)nnir3nB!PP9EyJ_4jd zu_$$r51LCD6xIDg8N%R555(NEJT@N~4hIGM~LW%0V z;UR%w1(H#PPNtE3Yj*!kFlKOTpCV4oL-Oh@6bm^_7;uGXR5n>ar9)%}8>W(3Y>-N( zi5RYc0DxI8Od6M>-t&mjz(~}dG_8YU=*|FZgzB{u5Ww-1D3l;F0YAQGJYr-#kc)r; z0zw*Dz!12S8B7L?%yw~gC5z~6fWZQ2bXOO;W%Qo{uGLd@fdSA}1b}3w z0Qtj!nE)9CX#kl9i>NNHBAP3U#_V&gWB!)wAh{T+`(T_-pFWy@Cd3eK-U{U_R_8Vq z1zw$vP&i?5>p*Tb6Aa!|wgKq9QKGo_RRZLKCR!uhUfr8?+k>J zX8*;B|3WGhVA5Yp1QS`%ARuCqnE;IJYykw3*)W?;772v_8-%Gewky<^=s%wbZG{Gy 
z9BO}s{%{GaA-(CPYrkvr{f99$Xu3i{iX zl&gNY;viKRsR*E^`XNX7TWtfYM%zC^T}|T>HT~HvI?II(AkW~))oZJ(N88T-H@3QZ zwC#ZDT!~o8L(XK)P*Ke~R_bR!?c+%(^11@^T$oHca=f9N(H6nt6gVz|FGij%6zUU! zflN$egH{Y_IR9OSI*MLHZL)M;EHuZNwiS78LEe}s{jW@Kopk-@(BFA8Fr&ZIE6Nuh z1O)_J#BpKsL1}SZC?MD(jtiR)N{iz{0l^k=T-bb2S{xS&2)2mh!sdh0;nXA zAlM>~3!4v0i{nB8!4`2`*nCi092W`*wus}x=7ZAWxKKc_MI0A4ACwlyg#v;t;<&K+ zptLwH6cB6?$A!%YrNwcffMAO_E^Iz1EshHX1Y5*$Ve>(0aa<@M*dmS#n-5Bh<3a(! z7I9qId{9~(7YYcrh>6R1;7iD`1o`H1GV&E;=J?2Jw zgTq3ns1=dyo}~7!t+jVbUzA?opF%m;^@{c6MHjE>S^8@h@@?Z%b`8|oJTWx(mXpdv zzU9z8O3&nJcV=ZCiy_Y{ea5w=*%*9X*;Z`-iBmSc_z%0L6uT26yRElJ85ORC^~hFZ zZAA`(0H#l25F5*dERHzY0b=kd0x?Rh zXA1Lgn#}w8Qsr*TV#x7GLg@_q`(4ooice2Jm$0!m*QVz?p!u$+y=BqJ%PWR0-xPRd zU3$Yz-W|)@iK{POe5hY|?ewrMw_GL76JKTI?q)SP&b(+!SP)0;_E#>mUvSh~Phxb> zr4cCDa>D+l$tuXL_TiB`thHmp9Rp|0tPHvUQ!aW=n#MZ2l%;5C+i$h>?ERg~L)JW{ zo5qYj`Zy8+O5siKiySL9jp(wO6Mp`bS%cewvJKCFUz6uyz0oXI+VCjp zlk#6p*4oX>88OGCb<(M-+N=J1y-f;M1ACcG_0IcMK@JXX6DN_b{9r*B7L2$Y$p}3u zZ{|ZH&OhjF5A`xa!p=M))t7Y8K6@QM!6xePj@aF2y>jPwyl8yc^ujI2arJ>3BfIAY zhHhz}LAEzWb!CpWFT9|N*!InduUJ)nxlFL``>#n061d+HZZ_B0=_D!l=T_fxkc6@A;wf(k4ZU=N426bmPoB#Mx^m*4vYdg8!>*M`T zK3zEWRBQdPZGpgoqN^si?J7^yAikpgyY9+O@XzBT&2T+TxZLx4Kra#m7z(3>O5a-B16V`2acG z`#a~{`JQuc=AT)gmh$X4??7(|g2t(ml^M)unElt&o%tKuYr4*Sc$t#3Ef6$jqW#wm z+R@+-L36@LO%|P{o{wrNLllNnc|?@WU}Dq|B%5n9VcJ51=H(Fuq*2cQ@k9%sN8)n+ zLaAD)HYE@`GP%@DWR|9Aw55gGSe!ps;VrYFOaKExV?3KdZ?vE`Io}Z%Wo-K}%;z~I zbfKIdXHUq>Qm63}C^Ny6Mnwy>2!ilpVxti09I;3m$rA|?DJ+bI5urecpokC^33>e& zU*XMIGBci!W+;>T(=q?a`8t|5p)hQ)Ny^Bx$HnrfVp`7U+Y`C2{`)eRT#1a9es&B+*oK*4BuWSm=Y(rr zKBi)`-mx@X3ln<6Kp1HYn3$=FTJ6d`?0OA`VN+Y^#1aNU zKhj}6Et=&f0?r^TRIyo0B$hC@5Zd>|g7#ArgnbYzS`$oxn& z$Q)ES2Ng+YGp#8U0zsV0Au-exPUbHk6hjoPVOm-mi;9I{3_v+F2D5{3jK+Rd6i^Gq z!L?{UWj0{6f;3!H21yV7B3nZfa7y%|0<_q$0Tr8Fn;uu6LazIq8LHEZ(D=BTUeZTaF z#(!gI!2Pk%sb-Qn>o7f_v9JEOrWhv7W2ur!3)7*?fzD<6*KwL<5=DbYA5G{<$Ejq( z%of7F=j8lHo6dENxlSp&k7MtmnAU!tDYW)2LEtiY(CEOjT%P@I_Qsm;Lm09t@exg94$5mP*A^89a1&s7D%UBlKBHlG(u)dtXYJ`*QdUofy_J 
zdn89-qTkV3D&Q;KQAy`0^T5D7tH7>@mBB@JZH3?27`UU~IWQQGO9IkhwK-g@h(G~{ z3#7qnbGTR$fdUQ}NQ2epaIqo+1spDr2CL2CVnqZBI9wnNR-41ciU<^NxIh}LHiwH9 z5h&nrfizfc4i_sTP{82=X|UQHE>=XKfWrmSV6{11tcXAXhYO^^YIC?)5rF~@7f6HE z=5Vnh0tFl{kOr&G;bKJu3OHOK4OW}O#fk_NaJWDktTu;>6%i=laDg;fZ4MVJB2d8L z0%@??94=Nwpn$^#(qOeYT&##d0f!5u!D@53SP_8&4i`v+)#h-qA_4^*E|3PR{X|^e z1F!!PM&_MAEAw)nX^Bq`^I{-Ro1CGBpwgKTw0aE${rr&m?1G?TAq0J&3qfct1Wlo; zmwu20K~KG_QpRa)9WT}7==~P>pIf!m^Jt{cJoS}*eq(uOCnIe)7KK%W`Dp@^zb>0M z@snI=0>5d-jK6r!?rpBhs@NbM-TCLSAyrTyS>e&yXi#=$r0XKmjdfK`xGw|`}Bo2q{uJc)ZN`YsX^j&{p4q~`**$N4F%21XiwF)(=9v# zlIrgq@5}yc=QAZg>uS@c&fJ<4gfxA=*UO{onpdjkurTQr)%Fw0=TrT3ubw$KX^l!1 zaPZrNrKc5N%TLkb_mTQ%rxdkc%c(AUM-{O6lRceh{F)ofCw5%m_aIw#k1{u{*(NIT zTX^fIfHBtfH>)nT;_Y7fZdd$+YE!ge%cdeDa-lRt+%!D<*nJvS}!iEyE%dS z;Sbj@HB~3&NuD}v-eBG_>Yz&7GJnonONw@Mn|oF%Lq43aHtkS+b%;J9;c{f-l<-e| z7yZe-{&-!+iSi{r`*)7*X#6SSwF`}hb3NX4Z`z^`!i;oJoDQ2ZyUhK^z8I-`{_F42 z>*HUm7dN??Hgr8NIaacM=eCzO8(T3g#9kgwkJEii2Q_cY<=&qSF! zB^B=DpVPmD9IitFw>5r*zth;y5X$wpD zSNhzSVO_a<;}-+uXl2@zn z{o+OX*H{1G)*P{Cb7gn&x~`co^aaG%9%#Naxg{_3zV74N&fBS(XTF+-mqo|ae*n1P_WlzcCk?1)VQm7OIZGGQ{ zJAPlDu3sCvFP-?OcNP5ky@Nd|@ftzlRb}3P;9^bJimMgEjiuh@9hJu(7S&IhR`lTd za<3W6a;w_v0r|9rdp#&`_Hr|4KS%uX&aIQg_)G7ri=sydV}a?R6A4Kw^7Qf4DL39W i*G&F$T*|439#GI-UY#j*jF0_YU{zv@@=$zk`Tqdo!tUBpZR00s`Rd>nw1 zHJ$(vRD~xb8WNS$q=+^>07kXRSb!;ANA?B)S%^soBWaj{m5im}nqcSPqYMNB`RZBp;|r05(bDk2n0c_KuG{36!UpPe-;mfgj`UF3`a`A>{Nq6C*^XDMq_}HAE4Eza3P6A!UcI;9*;xz;1Fhw0XA_o zgpbWhmmdX2AbMP9z_l8d#V?$!%`gPB*_J@vN9VZGb=`q9L?=6vBG&}#xKIGd?aK&7 zx@~kBdbRb|P=t%Av2;vhAV@o?PdiI{=n*U|lN^dqU!ynRs-DSt#*;>JzkDwcM6vFq+%Nt{@aPz|ZBy zzOS7h6p^a5`gGVJ$J5~yjH}b6$hbX?c0IfEN*bzFYxU%9VRA^u?W@_&%9?A^NDX0t zH3$}|kdqAoIF3p&F@%Zu0ue_5ffz>s!XQV?OU5_|0`pM>NfyJP#0pWVv^VeNTA@WU zEbE1IZ7VIP79rjD+!iJjsUTDgbCP+1Vh#ia!W^+cq~eG%P!cG@uw*_c?h0h*tXH6D zJx*3WSZ&wGG8U8!CPFbtm@JZT_=t!ciwG5QV3AP8353Y8B_k3T69#sL=)+B~P!Tvm z=BwE*kI8Jd*0y-;?PpuN6l&aBqja#Iz$_~;nBBDmyNg#9-z&57vAWF!^Vq+qC=+cmbXhX19xd%l|ZuP&lK_Lu*QxwEatXDee@ 
zwe4L=WtPVaDOrWdiU3=yAGwM3YU{}wTmM8$HC1O=`v>ws@*p2V$)nTMUQIQfNpEVe zrkeIxzNx_ta$aXNO0%7nzG_56!L{KF149z5rdmBb)o_%_ff*K9|*{C zS#79_J`HE~XK1JBG_*~X-GjBYg4$Zi$9D1;1GoDpL$Bi6y$k)37d<2TBfV05F+nIm z=tTw>Js*me!9@W=FEY63`B1bBE(#EOk-u>mg+)M2yTUYIt`Iml;{b5XQU#KL!v z9sIP`cuW$|dgMtl{%Q(jY(2K>y8QOf^4mngmb%!RZ;Wz16FXMnxBzy%f7jz8@!1QX zNb1b5%cj;{REH0Nj(I*XGWmt3|YP!z-LyiT=Oci`(O_8Js`JBc{XT`+fa*VAh!6BfxE4?$QA* z&6CTAmRGLV?esq3`9^6&?P;&~7M=E-GREw_B(_Q``SI=R7xMi^KlA zh9`$gPtErpAQ;x-arx4f`{5f!1or5z-}Y17y&(_Qn)kdk_+qfz!7vG{t)oOyuk?5O z9dwafVe%l5`%Gbl|b~tcH&b=_S{ewl5{9Tr>$3@Z{BCZWxvPTk&I8D3*9&HHj?rnmD;k%LO!|E%G|j;Oyr z`qWY2`HtZ5YyOD$BF7~Sx#ho>?1H-5Z$f$H@4 zP}+s@_2>3YIS5{uZ9E=PUdJx3fWC?_3&!YXUK#S$+7Yqp>ggvr(Wc|#H!^-0#Z}B< zDc#o*`jK-`^@f9wujd@&-!B_9uw_o3V3X66N14@C*-jf8d_K6>^pW}M+4a23cVx~3 zN2q{2tk~>7_s-a>Bb*AhE)E@0QRn{iJFh(c_FXcqe2?E97r*8j?<0SxVSHWS2#tex zhv>p4ufnivPE9Ej5;F7d)cz^z>7$byujb!MV>>VKeym!)wRC>Lz48-o(_Qw3Oyn+I zw=DeQ<1Oo}L$)=xR!^Rz9TDjW3ilT@`ft79UucXEM#%3G<9F1gFY;2MWb11>xK z@V*4j-@cYdrG$6zH?^je9^j-G!`-vLac9eTMuck&B~Ih!9eooE}@URgV1?x0`ImoGFHtxa`1`R(}S1(MZ81F;8( z)}>#2F=Nti-`r3R-KqcZ$GO43WvR~}UG?6uFKgmvd!JRbpJ>C1n{J&e&pwaGrFbVE zXhMj|t@kRTb&0`$Zwh>(+cCXySKeJOub6A`*(c*iHo1 literal 0 HcmV?d00001 diff --git a/kokoro/linux-clang-debug/build.sh b/kokoro/linux-clang-debug/build.sh new file mode 100644 index 000000000..11b2968a6 --- /dev/null +++ b/kokoro/linux-clang-debug/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/linux/build.sh DEBUG clang diff --git a/kokoro/linux-clang-debug/continuous.cfg b/kokoro/linux-clang-debug/continuous.cfg new file mode 100644 index 000000000..e92f059ed --- /dev/null +++ b/kokoro/linux-clang-debug/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/linux-clang-debug/build.sh" diff --git a/kokoro/linux-clang-debug/presubmit.cfg b/kokoro/linux-clang-debug/presubmit.cfg new file mode 100644 index 000000000..5011b445e --- /dev/null +++ b/kokoro/linux-clang-debug/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. 
+build_file: "SPIRV-Tools/kokoro/linux-clang-debug/build.sh" diff --git a/kokoro/linux-clang-release/build.sh b/kokoro/linux-clang-release/build.sh new file mode 100644 index 000000000..476433171 --- /dev/null +++ b/kokoro/linux-clang-release/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/linux/build.sh RELEASE clang diff --git a/kokoro/linux-clang-release/continuous.cfg b/kokoro/linux-clang-release/continuous.cfg new file mode 100644 index 000000000..687434acc --- /dev/null +++ b/kokoro/linux-clang-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. 
+build_file: "SPIRV-Tools/kokoro/linux-clang-release/build.sh" diff --git a/kokoro/linux-clang-release/presubmit.cfg b/kokoro/linux-clang-release/presubmit.cfg new file mode 100644 index 000000000..b7b9b5594 --- /dev/null +++ b/kokoro/linux-clang-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/linux-clang-release/build.sh" diff --git a/kokoro/linux-gcc-debug/build.sh b/kokoro/linux-gcc-debug/build.sh new file mode 100644 index 000000000..3ef1e251b --- /dev/null +++ b/kokoro/linux-gcc-debug/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. 
+set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/linux/build.sh DEBUG gcc diff --git a/kokoro/linux-gcc-debug/continuous.cfg b/kokoro/linux-gcc-debug/continuous.cfg new file mode 100644 index 000000000..4f8418d84 --- /dev/null +++ b/kokoro/linux-gcc-debug/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/linux-gcc-debug/build.sh" diff --git a/kokoro/linux-gcc-debug/presubmit.cfg b/kokoro/linux-gcc-debug/presubmit.cfg new file mode 100644 index 000000000..2d9fe5c99 --- /dev/null +++ b/kokoro/linux-gcc-debug/presubmit.cfg @@ -0,0 +1,17 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. 
+build_file: "SPIRV-Tools/kokoro/linux-gcc-debug/build.sh" + diff --git a/kokoro/linux-gcc-release/build.sh b/kokoro/linux-gcc-release/build.sh new file mode 100644 index 000000000..3e97d8d3b --- /dev/null +++ b/kokoro/linux-gcc-release/build.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/linux/build.sh RELEASE gcc diff --git a/kokoro/linux-gcc-release/continuous.cfg b/kokoro/linux-gcc-release/continuous.cfg new file mode 100644 index 000000000..41a0024e7 --- /dev/null +++ b/kokoro/linux-gcc-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. 
+build_file: "SPIRV-Tools/kokoro/linux-gcc-release/build.sh" diff --git a/kokoro/linux-gcc-release/presubmit.cfg b/kokoro/linux-gcc-release/presubmit.cfg new file mode 100644 index 000000000..c249a5ab0 --- /dev/null +++ b/kokoro/linux-gcc-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/linux-gcc-release/build.sh" diff --git a/kokoro/macos-clang-debug/build.sh b/kokoro/macos-clang-debug/build.sh new file mode 100644 index 000000000..8d9a062f6 --- /dev/null +++ b/kokoro/macos-clang-debug/build.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MacOS Build Script. + +# Fail on any error. +set -e +# Display commands being run. 
+set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/macos/build.sh Debug + diff --git a/kokoro/macos-clang-debug/continuous.cfg b/kokoro/macos-clang-debug/continuous.cfg new file mode 100644 index 000000000..84aaa5c25 --- /dev/null +++ b/kokoro/macos-clang-debug/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/macos-clang-debug/build.sh" diff --git a/kokoro/macos-clang-debug/presubmit.cfg b/kokoro/macos-clang-debug/presubmit.cfg new file mode 100644 index 000000000..1d2f60da9 --- /dev/null +++ b/kokoro/macos-clang-debug/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. 
+build_file: "SPIRV-Tools/kokoro/macos-clang-debug/build.sh" diff --git a/kokoro/macos-clang-release/build.sh b/kokoro/macos-clang-release/build.sh new file mode 100644 index 000000000..ccc8b16aa --- /dev/null +++ b/kokoro/macos-clang-release/build.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MacOS Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +SCRIPT_DIR=`dirname "$BASH_SOURCE"` +source $SCRIPT_DIR/../scripts/macos/build.sh RelWithDebInfo + diff --git a/kokoro/macos-clang-release/continuous.cfg b/kokoro/macos-clang-release/continuous.cfg new file mode 100644 index 000000000..a8e23a71a --- /dev/null +++ b/kokoro/macos-clang-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. 
+build_file: "SPIRV-Tools/kokoro/macos-clang-release/build.sh" diff --git a/kokoro/macos-clang-release/presubmit.cfg b/kokoro/macos-clang-release/presubmit.cfg new file mode 100644 index 000000000..dbaa266cc --- /dev/null +++ b/kokoro/macos-clang-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/macos-clang-release/build.sh" diff --git a/kokoro/ndk-build/build.sh b/kokoro/ndk-build/build.sh new file mode 100644 index 000000000..d51f071ea --- /dev/null +++ b/kokoro/ndk-build/build.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools + +# Get NINJA. 
+wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +unzip -q ninja-linux.zip +export PATH="$PWD:$PATH" + +# NDK Path +export ANDROID_NDK=/opt/android-ndk-r15c + +# Get the dependencies. +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +mkdir build && cd $SRC/build +mkdir libs +mkdir app + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting ndk-build ... +$ANDROID_NDK/ndk-build \ + -C $SRC/android_test \ + NDK_PROJECT_PATH=. \ + NDK_LIBS_OUT=./libs \ + NDK_APP_OUT=./app \ + -j8 + +echo $(date): ndk-build completed. + diff --git a/kokoro/ndk-build/continuous.cfg b/kokoro/ndk-build/continuous.cfg new file mode 100644 index 000000000..b908a4814 --- /dev/null +++ b/kokoro/ndk-build/continuous.cfg @@ -0,0 +1,17 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +# +build_file: "SPIRV-Tools/kokoro/ndk-build/build.sh" diff --git a/kokoro/ndk-build/presubmit.cfg b/kokoro/ndk-build/presubmit.cfg new file mode 100644 index 000000000..3c1be4bf7 --- /dev/null +++ b/kokoro/ndk-build/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/ndk-build/build.sh" diff --git a/kokoro/scripts/linux/build.sh b/kokoro/scripts/linux/build.sh new file mode 100644 index 000000000..d457539d4 --- /dev/null +++ b/kokoro/scripts/linux/build.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Linux Build Script. + +# Fail on any error. +set -e +# Display commands being run. 
+set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools +CONFIG=$1 +COMPILER=$2 + +SKIP_TESTS="False" +BUILD_TYPE="Debug" + +CMAKE_C_CXX_COMPILER="" +if [ $COMPILER = "clang" ] +then + sudo ln -s /usr/bin/clang-3.8 /usr/bin/clang + sudo ln -s /usr/bin/clang++-3.8 /usr/bin/clang++ + CMAKE_C_CXX_COMPILER="-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++" +fi + +# Possible configurations are: +# ASAN, COVERAGE, RELEASE, DEBUG, DEBUG_EXCEPTION, RELEASE_MINGW + +if [ $CONFIG = "RELEASE" ] || [ $CONFIG = "RELEASE_MINGW" ] +then + BUILD_TYPE="RelWithDebInfo" +fi + +ADDITIONAL_CMAKE_FLAGS="" +if [ $CONFIG = "ASAN" ] +then + ADDITIONAL_CMAKE_FLAGS="-DCMAKE_CXX_FLAGS=-fsanitize=address -DCMAKE_C_FLAGS=-fsanitize=address" + export ASAN_SYMBOLIZER_PATH=/usr/bin/llvm-symbolizer-3.4 +elif [ $CONFIG = "COVERAGE" ] +then + ADDITIONAL_CMAKE_FLAGS="-DENABLE_CODE_COVERAGE=ON" + SKIP_TESTS="True" +elif [ $CONFIG = "DEBUG_EXCEPTION" ] +then + ADDITIONAL_CMAKE_FLAGS="-DDISABLE_EXCEPTIONS=ON -DDISABLE_RTTI=ON" +elif [ $CONFIG = "RELEASE_MINGW" ] +then + ADDITIONAL_CMAKE_FLAGS="-Dgtest_disable_pthreads=ON -DCMAKE_TOOLCHAIN_FILE=$SRC/cmake/linux-mingw-toolchain.cmake" + SKIP_TESTS="True" +fi + +# Get NINJA. +wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +unzip -q ninja-linux.zip +export PATH="$PWD:$PATH" + +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +mkdir build && cd $SRC/build + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting build... +cmake -GNinja -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF $ADDITIONAL_CMAKE_FLAGS $CMAKE_C_CXX_COMPILER .. 
+ +echo $(date): Build everything... +ninja +echo $(date): Build completed. + +if [ $CONFIG = "COVERAGE" ] +then + echo $(date): Check coverage... + ninja report-coverage + echo $(date): Check coverage completed. +fi + +echo $(date): Starting ctest... +if [ $SKIP_TESTS = "False" ] +then + ctest -j4 --output-on-failure --timeout 300 +fi +echo $(date): ctest completed. + diff --git a/kokoro/scripts/macos/build.sh b/kokoro/scripts/macos/build.sh new file mode 100644 index 000000000..a7f0453fe --- /dev/null +++ b/kokoro/scripts/macos/build.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MacOS Build Script. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +SRC=$PWD/github/SPIRV-Tools +BUILD_TYPE=$1 + +# Get NINJA. +wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-mac.zip +unzip -q ninja-mac.zip +chmod +x ninja +export PATH="$PWD:$PATH" + +cd $SRC +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +mkdir build && cd $SRC/build + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting build... 
+cmake -GNinja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=$BUILD_TYPE .. + +echo $(date): Build everything... +ninja +echo $(date): Build completed. + +echo $(date): Starting ctest... +ctest -j4 --output-on-failure --timeout 300 +echo $(date): ctest completed. + diff --git a/kokoro/scripts/windows/build.bat b/kokoro/scripts/windows/build.bat new file mode 100644 index 000000000..1985419f0 --- /dev/null +++ b/kokoro/scripts/windows/build.bat @@ -0,0 +1,95 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +set BUILD_ROOT=%cd% +set SRC=%cd%\github\SPIRV-Tools +set BUILD_TYPE=%1 +set VS_VERSION=%2 + +:: Force usage of python 2.7 rather than 3.6 +set PATH=C:\python27;%PATH% + +cd %SRC% +git clone --depth=1 https://github.com/KhronosGroup/SPIRV-Headers external/spirv-headers +git clone --depth=1 https://github.com/google/googletest external/googletest +git clone --depth=1 https://github.com/google/effcee external/effcee +git clone --depth=1 https://github.com/google/re2 external/re2 + +:: ######################################### +:: set up msvc build env +:: ######################################### +if %VS_VERSION% == 2017 ( + call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsall.bat" x64 + echo "Using VS 2017..." 
+) else if %VS_VERSION% == 2015 ( + call "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\vcvarsall.bat" x64 + echo "Using VS 2015..." +) else if %VS_VERSION% == 2013 ( + call "C:\Program Files (x86)\Microsoft Visual Studio 12.0\VC\vcvarsall.bat" x64 + echo "Using VS 2013..." +) + +cd %SRC% +mkdir build +cd build + +:: ######################################### +:: Start building. +:: ######################################### +echo "Starting build... %DATE% %TIME%" +if "%KOKORO_GITHUB_COMMIT%." == "." ( + set BUILD_SHA=%KOKORO_GITHUB_PULL_REQUEST_COMMIT% +) else ( + set BUILD_SHA=%KOKORO_GITHUB_COMMIT% +) + +set CMAKE_FLAGS=-GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%BUILD_TYPE% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF -DCMAKE_C_COMPILER=cl.exe -DCMAKE_CXX_COMPILER=cl.exe + +:: Skip building tests for VS2013 +if %VS_VERSION% == 2013 ( + set CMAKE_FLAGS=%CMAKE_FLAGS% -DSPIRV_SKIP_TESTS=ON +) + +cmake %CMAKE_FLAGS% .. + +if %ERRORLEVEL% NEQ 0 exit /b %ERRORLEVEL% + +echo "Build everything... %DATE% %TIME%" +ninja +if %ERRORLEVEL% NEQ 0 exit /b %ERRORLEVEL% +echo "Build Completed %DATE% %TIME%" + +:: This lets us use !ERRORLEVEL! inside an IF ... () and get the actual error at that point. +setlocal ENABLEDELAYEDEXPANSION + +:: ################################################ +:: Run the tests (We no longer run tests on VS2013) +:: ################################################ +echo "Running Tests... %DATE% %TIME%" +if %VS_VERSION% NEQ 2013 ( + ctest -C %BUILD_TYPE% --output-on-failure --timeout 300 + if !ERRORLEVEL! NEQ 0 exit /b !ERRORLEVEL! +) +echo "Tests Completed %DATE% %TIME%" + +:: Clean up some directories. +rm -rf %SRC%\build +rm -rf %SRC%\external + +exit /b 0 + diff --git a/kokoro/shaderc-smoketest/build.sh b/kokoro/shaderc-smoketest/build.sh new file mode 100644 index 000000000..638ca8c61 --- /dev/null +++ b/kokoro/shaderc-smoketest/build.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright (c) 2018 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Fail on any error. +set -e +# Display commands being run. +set -x + +BUILD_ROOT=$PWD +GITHUB_DIR=$BUILD_ROOT/github + +SKIP_TESTS="False" +BUILD_TYPE="Release" + +# Get NINJA. +wget -q https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip +unzip -q ninja-linux.zip +export PATH="$PWD:$PATH" + +# Get shaderc. +cd $GITHUB_DIR +git clone https://github.com/google/shaderc.git +SHADERC_DIR=$GITHUB_DIR/shaderc +cd $SHADERC_DIR/third_party + +# Get shaderc dependencies. Link the appropriate SPIRV-Tools. +git clone https://github.com/google/googletest.git +git clone https://github.com/google/glslang.git +ln -s $GITHUB_DIR/SPIRV-Tools spirv-tools +git clone https://github.com/KhronosGroup/SPIRV-Headers.git spirv-headers +git clone https://github.com/google/re2 +git clone https://github.com/google/effcee + +cd $SHADERC_DIR +mkdir build +cd $SHADERC_DIR/build + +# Invoke the build. +BUILD_SHA=${KOKORO_GITHUB_COMMIT:-$KOKORO_GITHUB_PULL_REQUEST_COMMIT} +echo $(date): Starting build... +cmake -GNinja -DRE2_BUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=$BUILD_TYPE .. + +echo $(date): Build glslang... +ninja glslangValidator + +echo $(date): Build everything... +ninja +echo $(date): Build completed. + +echo $(date): Check Shaderc for copyright notices... +ninja check-copyright + +echo $(date): Starting ctest... +if [ $SKIP_TESTS = "False" ] +then + ctest --output-on-failure -j4 +fi +echo $(date): ctest completed. 
+ diff --git a/kokoro/shaderc-smoketest/continuous.cfg b/kokoro/shaderc-smoketest/continuous.cfg new file mode 100644 index 000000000..ee151ae1a --- /dev/null +++ b/kokoro/shaderc-smoketest/continuous.cfg @@ -0,0 +1,17 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/shaderc-smoketest/build.sh" + diff --git a/kokoro/shaderc-smoketest/presubmit.cfg b/kokoro/shaderc-smoketest/presubmit.cfg new file mode 100644 index 000000000..4f2ed21b7 --- /dev/null +++ b/kokoro/shaderc-smoketest/presubmit.cfg @@ -0,0 +1,17 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. 
+build_file: "SPIRV-Tools/kokoro/shaderc-smoketest/build.sh" + diff --git a/kokoro/windows-msvc-2013-release/build.bat b/kokoro/windows-msvc-2013-release/build.bat new file mode 100644 index 000000000..e77172afc --- /dev/null +++ b/kokoro/windows-msvc-2013-release/build.bat @@ -0,0 +1,24 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +:: Find out the directory of the common build script. +set SCRIPT_DIR=%~dp0 + +:: Call with correct parameter +call %SCRIPT_DIR%\..\scripts\windows\build.bat RelWithDebInfo 2013 + diff --git a/kokoro/windows-msvc-2013-release/continuous.cfg b/kokoro/windows-msvc-2013-release/continuous.cfg new file mode 100644 index 000000000..5dfcba63b --- /dev/null +++ b/kokoro/windows-msvc-2013-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. 
+build_file: "SPIRV-Tools/kokoro/windows-msvc-2013-release/build.bat" diff --git a/kokoro/windows-msvc-2013-release/presubmit.cfg b/kokoro/windows-msvc-2013-release/presubmit.cfg new file mode 100644 index 000000000..7d3b23822 --- /dev/null +++ b/kokoro/windows-msvc-2013-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2013-release/build.bat" diff --git a/kokoro/windows-msvc-2015-release/build.bat b/kokoro/windows-msvc-2015-release/build.bat new file mode 100644 index 000000000..c0e4bd317 --- /dev/null +++ b/kokoro/windows-msvc-2015-release/build.bat @@ -0,0 +1,24 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +:: Find out the directory of the common build script. 
+set SCRIPT_DIR=%~dp0 + +:: Call with correct parameter +call %SCRIPT_DIR%\..\scripts\windows\build.bat RelWithDebInfo 2015 + diff --git a/kokoro/windows-msvc-2015-release/continuous.cfg b/kokoro/windows-msvc-2015-release/continuous.cfg new file mode 100644 index 000000000..3e47e5268 --- /dev/null +++ b/kokoro/windows-msvc-2015-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2015-release/build.bat" diff --git a/kokoro/windows-msvc-2015-release/presubmit.cfg b/kokoro/windows-msvc-2015-release/presubmit.cfg new file mode 100644 index 000000000..85a162593 --- /dev/null +++ b/kokoro/windows-msvc-2015-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. 
+build_file: "SPIRV-Tools/kokoro/windows-msvc-2015-release/build.bat" diff --git a/kokoro/windows-msvc-2017-debug/build.bat b/kokoro/windows-msvc-2017-debug/build.bat new file mode 100644 index 000000000..25783a9e5 --- /dev/null +++ b/kokoro/windows-msvc-2017-debug/build.bat @@ -0,0 +1,23 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +:: Find out the directory of the common build script. +set SCRIPT_DIR=%~dp0 + +:: Call with correct parameter +call %SCRIPT_DIR%\..\scripts\windows\build.bat Debug 2017 diff --git a/kokoro/windows-msvc-2017-debug/continuous.cfg b/kokoro/windows-msvc-2017-debug/continuous.cfg new file mode 100644 index 000000000..b842c30f1 --- /dev/null +++ b/kokoro/windows-msvc-2017-debug/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. 
+build_file: "SPIRV-Tools/kokoro/windows-msvc-2017-debug/build.bat" diff --git a/kokoro/windows-msvc-2017-debug/presubmit.cfg b/kokoro/windows-msvc-2017-debug/presubmit.cfg new file mode 100644 index 000000000..a7a553aee --- /dev/null +++ b/kokoro/windows-msvc-2017-debug/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2017-debug/build.bat" diff --git a/kokoro/windows-msvc-2017-release/build.bat b/kokoro/windows-msvc-2017-release/build.bat new file mode 100644 index 000000000..899fcbcfb --- /dev/null +++ b/kokoro/windows-msvc-2017-release/build.bat @@ -0,0 +1,24 @@ +:: Copyright (c) 2018 Google LLC. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: +:: Windows Build Script. + +@echo on + +:: Find out the directory of the common build script. 
+set SCRIPT_DIR=%~dp0 + +:: Call with correct parameter +call %SCRIPT_DIR%\..\scripts\windows\build.bat RelWithDebInfo 2017 + diff --git a/kokoro/windows-msvc-2017-release/continuous.cfg b/kokoro/windows-msvc-2017-release/continuous.cfg new file mode 100644 index 000000000..7b8c2ff2b --- /dev/null +++ b/kokoro/windows-msvc-2017-release/continuous.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Continuous build configuration. +build_file: "SPIRV-Tools/kokoro/windows-msvc-2017-release/build.bat" diff --git a/kokoro/windows-msvc-2017-release/presubmit.cfg b/kokoro/windows-msvc-2017-release/presubmit.cfg new file mode 100644 index 000000000..5efd42927 --- /dev/null +++ b/kokoro/windows-msvc-2017-release/presubmit.cfg @@ -0,0 +1,16 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Presubmit build configuration. 
+build_file: "SPIRV-Tools/kokoro/windows-msvc-2017-release/build.bat" diff --git a/projects.md b/projects.md new file mode 100644 index 000000000..8f7f0bcd9 --- /dev/null +++ b/projects.md @@ -0,0 +1,82 @@ +# Tracking SPIRV-Tools work with GitHub projects + +We are experimenting with using the [GitHub Project +feature](https://help.github.com/articles/tracking-the-progress-of-your-work-with-projects/) +to track progress toward large goals. + +For more on GitHub Projects in general, see: +* [Introductory blog post](https://github.com/blog/2256-a-whole-new-github-universe-announcing-new-tools-forums-and-features) +* [Introductory video](https://www.youtube.com/watch?v=C6MGKHkNtxU) + +The current SPIRV-Tools project list can be found at +[https://github.com/KhronosGroup/SPIRV-Tools/projects](https://github.com/KhronosGroup/SPIRV-Tools/projects) + +## How we use a Project + +A GitHub Project is a set of work with an overall purpose, and +consists of a collection of *Cards*. +Each card is either a *Note* or a regular GitHub *Issue.* +A Note can be converted to an Issue. + +In our projects, a card represents work, i.e. a change that can +be applied to the repository. +The work could be a feature, a bug to be fixed, documentation to be +updated, etc. + +A project and its cards are used as a [Kanban +board](https://en.wikipedia.org/wiki/Kanban_board), where cards progress +through a workflow starting with ideas through to implementation and completion. + +In our usage, a *project manager* is someone who organizes the work. +They manage the creation and movement of cards +through the project workflow: +* They create cards to capture ideas, or to decompose large ideas into smaller + ones. +* They determine if the work for a card has been completed. +* Normally they are the person (or persons) who can approve and merge a pull + request into the `master` branch. 
+ +Our projects organize cards into the following columns: +* `Ideas`: Work which could be done, captured either as Cards or Notes. + * A card in this column could be marked as a [PLACEHOLDER](#placeholders). +* `Ready to start`: Issues which represent work we'd like to do, and which + are not blocked by other work. + * The issue should be narrow enough that it can usually be addressed by a + single pull request. + * We want these to be Issues (not Notes) so that someone can claim the work + by updating the Issue with their intent to do the work. + Once an Issue is claimed, the project manager moves the corresponding card + from `Ready to start` to `In progress`. +* `In progress`: Issues which were in `Ready to start` but which have been + claimed by someone. +* `Done`: Issues which have been resolved, by completing their work. + * The changes have been applied to the repository, typically by being pushed + into the `master` branch. + * Other kinds of work could update repository settings, for example. +* `Rejected ideas`: Work which has been considered, but which we don't want + implemented. + * We keep rejected ideas so they are not proposed again. This serves + as a form of institutional memory. + * We should record why an idea is rejected. For this reason, a rejected + idea is likely to be an Issue which has been closed. + +## Prioritization + +We are considering prioritizing cards in the `Ideas` and `Ready to start` +columns so that things that should be considered first float up to the top. + +Experience will tell us if we stick to that rule, and if it proves helpful. + +## Placeholders + +A *placeholder* is a Note or Issue that represents a possibly large amount +of work that can be broadly defined but which may not have been broken down +into small implementable pieces of work. + +Use a placeholder to capture a big idea, but without doing the upfront work +to consider all the details of how it should be implemented. 
+Over time, break off pieces of the placeholder into implementable Issues. +Move those Issues into the `Ready to start` column when they become unblocked. + +We delete the placeholder when all its work has been decomposed into +implementable cards. diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt new file mode 100644 index 000000000..03efa9196 --- /dev/null +++ b/source/CMakeLists.txt @@ -0,0 +1,374 @@ +# Copyright (c) 2015-2016 The Khronos Group Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(GRAMMAR_PROCESSING_SCRIPT "${spirv-tools_SOURCE_DIR}/utils/generate_grammar_tables.py") +set(VIMSYNTAX_PROCESSING_SCRIPT "${spirv-tools_SOURCE_DIR}/utils/generate_vim_syntax.py") +set(XML_REGISTRY_PROCESSING_SCRIPT "${spirv-tools_SOURCE_DIR}/utils/generate_registry_tables.py") +set(LANG_HEADER_PROCESSING_SCRIPT "${spirv-tools_SOURCE_DIR}/utils/generate_language_headers.py") + +# For now, assume the DebugInfo grammar file is in the current directory. +# It might migrate to SPIRV-Headers. +set(DEBUGINFO_GRAMMAR_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/extinst.debuginfo.grammar.json") + +# macro() definitions are used in the following because we need to append .inc +# file paths into some global lists (*_CPP_DEPENDS). And those global lists are +# later used by set_source_files_properties() calls. +# function() definitions are not suitable because they create new scopes. 
+macro(spvtools_core_tables CONFIG_VERSION) + set(GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/spirv.core.grammar.json") + set(GRAMMAR_INSTS_INC_FILE "${spirv-tools_BINARY_DIR}/core.insts-${CONFIG_VERSION}.inc") + set(GRAMMAR_KINDS_INC_FILE "${spirv-tools_BINARY_DIR}/operand.kinds-${CONFIG_VERSION}.inc") + add_custom_command(OUTPUT ${GRAMMAR_INSTS_INC_FILE} ${GRAMMAR_KINDS_INC_FILE} + COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT} + --spirv-core-grammar=${GRAMMAR_JSON_FILE} + --extinst-debuginfo-grammar=${DEBUGINFO_GRAMMAR_JSON_FILE} + --core-insts-output=${GRAMMAR_INSTS_INC_FILE} + --operand-kinds-output=${GRAMMAR_KINDS_INC_FILE} + DEPENDS ${GRAMMAR_PROCESSING_SCRIPT} ${GRAMMAR_JSON_FILE} ${DEBUGINFO_GRAMMAR_JSON_FILE} + COMMENT "Generate info tables for SPIR-V v${CONFIG_VERSION} core instructions and operands.") + list(APPEND OPCODE_CPP_DEPENDS ${GRAMMAR_INSTS_INC_FILE}) + list(APPEND OPERAND_CPP_DEPENDS ${GRAMMAR_KINDS_INC_FILE}) +endmacro(spvtools_core_tables) + +macro(spvtools_enum_string_mapping CONFIG_VERSION) + set(GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/spirv.core.grammar.json") + set(GRAMMAR_EXTENSION_ENUM_INC_FILE "${spirv-tools_BINARY_DIR}/extension_enum.inc") + set(GRAMMAR_ENUM_STRING_MAPPING_INC_FILE "${spirv-tools_BINARY_DIR}/enum_string_mapping.inc") + add_custom_command(OUTPUT ${GRAMMAR_EXTENSION_ENUM_INC_FILE} + ${GRAMMAR_ENUM_STRING_MAPPING_INC_FILE} + COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT} + --spirv-core-grammar=${GRAMMAR_JSON_FILE} + --extinst-debuginfo-grammar=${DEBUGINFO_GRAMMAR_JSON_FILE} + --extension-enum-output=${GRAMMAR_EXTENSION_ENUM_INC_FILE} + --enum-string-mapping-output=${GRAMMAR_ENUM_STRING_MAPPING_INC_FILE} + DEPENDS ${GRAMMAR_PROCESSING_SCRIPT} ${GRAMMAR_JSON_FILE} ${DEBUGINFO_GRAMMAR_JSON_FILE} + COMMENT "Generate enum-string mapping for SPIR-V v${CONFIG_VERSION}.") + list(APPEND EXTENSION_H_DEPENDS ${GRAMMAR_EXTENSION_ENUM_INC_FILE}) + 
list(APPEND ENUM_STRING_MAPPING_CPP_DEPENDS ${GRAMMAR_ENUM_STRING_MAPPING_INC_FILE}) +endmacro(spvtools_enum_string_mapping) + +macro(spvtools_vimsyntax CONFIG_VERSION CLVERSION) + set(GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/spirv.core.grammar.json") + set(GLSL_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/extinst.glsl.std.450.grammar.json") + set(OPENCL_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/extinst.opencl.std.100.grammar.json") + set(VIMSYNTAX_FILE "${spirv-tools_BINARY_DIR}/spvasm.vim") + add_custom_command(OUTPUT ${VIMSYNTAX_FILE} + COMMAND ${PYTHON_EXECUTABLE} ${VIMSYNTAX_PROCESSING_SCRIPT} + --spirv-core-grammar=${GRAMMAR_JSON_FILE} + --extinst-debuginfo-grammar=${DEBUGINFO_GRAMMAR_JSON_FILE} + --extinst-glsl-grammar=${GLSL_GRAMMAR_JSON_FILE} + --extinst-opencl-grammar=${OPENCL_GRAMMAR_JSON_FILE} + >${VIMSYNTAX_FILE} + DEPENDS ${VIMSYNTAX_PROCESSING_SCRIPT} ${GRAMMAR_JSON_FILE} + ${GLSL_GRAMMAR_JSON_FILE} ${OPENCL_GRAMMAR_JSON_FILE} ${DEBUGINFO_GRAMMAR_JSON_FILE} + COMMENT "Generate spvasm.vim: Vim syntax file for SPIR-V assembly.") +endmacro(spvtools_vimsyntax) + +macro(spvtools_glsl_tables CONFIG_VERSION) + set(CORE_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/spirv.core.grammar.json") + set(GLSL_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/extinst.glsl.std.450.grammar.json") + set(GRAMMAR_INC_FILE "${spirv-tools_BINARY_DIR}/glsl.std.450.insts.inc") + add_custom_command(OUTPUT ${GRAMMAR_INC_FILE} + COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT} + --extinst-glsl-grammar=${GLSL_GRAMMAR_JSON_FILE} + --glsl-insts-output=${GRAMMAR_INC_FILE} + DEPENDS ${GRAMMAR_PROCESSING_SCRIPT} ${CORE_GRAMMAR_JSON_FILE} ${GLSL_GRAMMAR_JSON_FILE} + COMMENT "Generate info tables for GLSL extended instructions and operands v${CONFIG_VERSION}.") + list(APPEND EXTINST_CPP_DEPENDS ${GRAMMAR_INC_FILE}) +endmacro(spvtools_glsl_tables) + 
+macro(spvtools_opencl_tables CONFIG_VERSION) + set(CORE_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/spirv.core.grammar.json") + set(OPENCL_GRAMMAR_JSON_FILE "${SPIRV_HEADER_INCLUDE_DIR}/spirv/${CONFIG_VERSION}/extinst.opencl.std.100.grammar.json") + set(GRAMMAR_INC_FILE "${spirv-tools_BINARY_DIR}/opencl.std.insts.inc") + add_custom_command(OUTPUT ${GRAMMAR_INC_FILE} + COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT} + --extinst-opencl-grammar=${OPENCL_GRAMMAR_JSON_FILE} + --opencl-insts-output=${GRAMMAR_INC_FILE} + DEPENDS ${GRAMMAR_PROCESSING_SCRIPT} ${CORE_GRAMMAR_JSON_FILE} ${OPENCL_GRAMMAR_JSON_FILE} + COMMENT "Generate info tables for OpenCL extended instructions and operands v${CONFIG_VERSION}.") + list(APPEND EXTINST_CPP_DEPENDS ${GRAMMAR_INC_FILE}) +endmacro(spvtools_opencl_tables) + +macro(spvtools_vendor_tables VENDOR_TABLE) + set(INSTS_FILE "${spirv-tools_BINARY_DIR}/${VENDOR_TABLE}.insts.inc") + set(GRAMMAR_FILE "${spirv-tools_SOURCE_DIR}/source/extinst.${VENDOR_TABLE}.grammar.json") + add_custom_command(OUTPUT ${INSTS_FILE} + COMMAND ${PYTHON_EXECUTABLE} ${GRAMMAR_PROCESSING_SCRIPT} + --extinst-vendor-grammar=${GRAMMAR_FILE} + --vendor-insts-output=${INSTS_FILE} + DEPENDS ${GRAMMAR_PROCESSING_SCRIPT} ${GRAMMAR_FILE} + COMMENT "Generate extended instruction tables for ${VENDOR_TABLE}.") + add_custom_target(spirv-tools-${VENDOR_TABLE} DEPENDS ${INSTS_FILE}) + set_property(TARGET spirv-tools-${VENDOR_TABLE} PROPERTY FOLDER "SPIRV-Tools build") + list(APPEND EXTINST_CPP_DEPENDS spirv-tools-${VENDOR_TABLE}) +endmacro(spvtools_vendor_tables) + +macro(spvtools_extinst_lang_headers NAME GRAMMAR_FILE) + set(OUTBASE ${spirv-tools_BINARY_DIR}/${NAME}) + set(OUT_H ${OUTBASE}.h) + add_custom_command(OUTPUT ${OUT_H} + COMMAND ${PYTHON_EXECUTABLE} ${LANG_HEADER_PROCESSING_SCRIPT} + --extinst-name=${NAME} + --extinst-grammar=${GRAMMAR_FILE} + --extinst-output-base=${OUTBASE} + DEPENDS ${LANG_HEADER_PROCESSING_SCRIPT} 
${GRAMMAR_FILE} + COMMENT "Generate language specific header for ${NAME}.") + add_custom_target(spirv-tools-header-${NAME} DEPENDS ${OUT_H}) + set_property(TARGET spirv-tools-header-${NAME} PROPERTY FOLDER "SPIRV-Tools build") + list(APPEND EXTINST_CPP_DEPENDS spirv-tools-header-${NAME}) +endmacro(spvtools_extinst_lang_headers) + +spvtools_core_tables("unified1") +spvtools_enum_string_mapping("unified1") +spvtools_opencl_tables("unified1") +spvtools_glsl_tables("unified1") +spvtools_vendor_tables("spv-amd-shader-explicit-vertex-parameter") +spvtools_vendor_tables("spv-amd-shader-trinary-minmax") +spvtools_vendor_tables("spv-amd-gcn-shader") +spvtools_vendor_tables("spv-amd-shader-ballot") +spvtools_vendor_tables("debuginfo") +spvtools_extinst_lang_headers("DebugInfo" ${DEBUGINFO_GRAMMAR_JSON_FILE}) + +spvtools_vimsyntax("unified1" "1.0") +add_custom_target(spirv-tools-vimsyntax DEPENDS ${VIMSYNTAX_FILE}) +set_property(TARGET spirv-tools-vimsyntax PROPERTY FOLDER "SPIRV-Tools utilities") + +# Extract the list of known generators from the SPIR-V XML registry file. +set(GENERATOR_INC_FILE ${spirv-tools_BINARY_DIR}/generators.inc) +set(SPIRV_XML_REGISTRY_FILE ${SPIRV_HEADER_INCLUDE_DIR}/spirv/spir-v.xml) +add_custom_command(OUTPUT ${GENERATOR_INC_FILE} + COMMAND ${PYTHON_EXECUTABLE} ${XML_REGISTRY_PROCESSING_SCRIPT} + --xml=${SPIRV_XML_REGISTRY_FILE} + --generator-output=${GENERATOR_INC_FILE} + DEPENDS ${XML_REGISTRY_PROCESSING_SCRIPT} ${SPIRV_XML_REGISTRY_FILE} + COMMENT "Generate tables based on the SPIR-V XML registry.") +list(APPEND OPCODE_CPP_DEPENDS ${GENERATOR_INC_FILE}) + +# The following .cpp files include the above generated .inc files. +# Add those .inc files as their dependencies. +# +# We need to wrap the .inc files with a custom target to avoid problems when +# multiple targets depend on the same custom command. 
+add_custom_target(core_tables + DEPENDS ${OPCODE_CPP_DEPENDS} ${OPERAND_CPP_DEPENDS}) +add_custom_target(enum_string_mapping + DEPENDS ${EXTENSION_H_DEPENDS} ${ENUM_STRING_MAPPING_CPP_DEPENDS}) +add_custom_target(extinst_tables + DEPENDS ${EXTINST_CPP_DEPENDS}) + +set_source_files_properties( + ${CMAKE_CURRENT_SOURCE_DIR}/extensions.h + PROPERTIES HEADER_FILE_ONLY TRUE) + +set(SPIRV_TOOLS_BUILD_VERSION_INC + ${spirv-tools_BINARY_DIR}/build-version.inc) +set(SPIRV_TOOLS_BUILD_VERSION_INC_GENERATOR + ${spirv-tools_SOURCE_DIR}/utils/update_build_version.py) +set(SPIRV_TOOLS_CHANGES_FILE + ${spirv-tools_SOURCE_DIR}/CHANGES) +add_custom_command(OUTPUT ${SPIRV_TOOLS_BUILD_VERSION_INC} + COMMAND ${PYTHON_EXECUTABLE} + ${SPIRV_TOOLS_BUILD_VERSION_INC_GENERATOR} + ${spirv-tools_SOURCE_DIR} ${SPIRV_TOOLS_BUILD_VERSION_INC} + DEPENDS ${SPIRV_TOOLS_BUILD_VERSION_INC_GENERATOR} + ${SPIRV_TOOLS_CHANGES_FILE} + COMMENT "Update build-version.inc in the SPIRV-Tools build directory (if necessary).") +# Convenience target for standalone generation of the build-version.inc file. +# This is not required for any dependence chain. 
+add_custom_target(spirv-tools-build-version + DEPENDS ${SPIRV_TOOLS_BUILD_VERSION_INC}) +set_property(TARGET spirv-tools-build-version PROPERTY FOLDER "SPIRV-Tools build") + +list(APPEND PCH_DEPENDS ${ENUM_STRING_MAPPING_CPP_DEPENDS} ${OPCODE_CPP_DEPENDS} ${OPERAND_CPP_DEPENDS} ${EXTENSION_H_DEPENDS} ${EXTINST_CPP_DEPENDS} ${SPIRV_TOOLS_BUILD_VERSION_INC}) +set_source_files_properties( + ${CMAKE_CURRENT_SOURCE_DIR}/pch_source.cpp + PROPERTIES OBJECT_DEPENDS "${PCH_DEPENDS}") + +add_subdirectory(comp) +add_subdirectory(opt) +add_subdirectory(reduce) +add_subdirectory(link) + +set(SPIRV_SOURCES + ${spirv-tools_SOURCE_DIR}/include/spirv-tools/libspirv.h + + ${CMAKE_CURRENT_SOURCE_DIR}/util/bitutils.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/bit_vector.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/hex_float.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/make_unique.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/parse_number.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/small_vector.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/string_utils.h + ${CMAKE_CURRENT_SOURCE_DIR}/util/timer.h + ${CMAKE_CURRENT_SOURCE_DIR}/assembly_grammar.h + ${CMAKE_CURRENT_SOURCE_DIR}/binary.h + ${CMAKE_CURRENT_SOURCE_DIR}/cfa.h + ${CMAKE_CURRENT_SOURCE_DIR}/diagnostic.h + ${CMAKE_CURRENT_SOURCE_DIR}/disassemble.h + ${CMAKE_CURRENT_SOURCE_DIR}/enum_set.h + ${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.h + ${CMAKE_CURRENT_SOURCE_DIR}/ext_inst.h + ${CMAKE_CURRENT_SOURCE_DIR}/extensions.h + ${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.h + ${CMAKE_CURRENT_SOURCE_DIR}/instruction.h + ${CMAKE_CURRENT_SOURCE_DIR}/latest_version_glsl_std_450_header.h + ${CMAKE_CURRENT_SOURCE_DIR}/latest_version_opencl_std_header.h + ${CMAKE_CURRENT_SOURCE_DIR}/latest_version_spirv_header.h + ${CMAKE_CURRENT_SOURCE_DIR}/macro.h + ${CMAKE_CURRENT_SOURCE_DIR}/name_mapper.h + ${CMAKE_CURRENT_SOURCE_DIR}/opcode.h + ${CMAKE_CURRENT_SOURCE_DIR}/operand.h + ${CMAKE_CURRENT_SOURCE_DIR}/parsed_operand.h + ${CMAKE_CURRENT_SOURCE_DIR}/print.h + 
${CMAKE_CURRENT_SOURCE_DIR}/spirv_constant.h + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_definition.h + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_endian.h + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_optimizer_options.h + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_reducer_options.h + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_target_env.h + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_validator_options.h + ${CMAKE_CURRENT_SOURCE_DIR}/table.h + ${CMAKE_CURRENT_SOURCE_DIR}/text.h + ${CMAKE_CURRENT_SOURCE_DIR}/text_handler.h + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate.h + + ${CMAKE_CURRENT_SOURCE_DIR}/util/bit_vector.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/parse_number.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/string_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/assembly_grammar.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/diagnostic.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/disassemble.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ext_inst.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extensions.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libspirv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/name_mapper.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/opcode.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/operand.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/parsed_operand.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/print.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/software_version.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_endian.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_optimizer_options.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_reducer_options.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_target_env.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/spirv_validator_options.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/table.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/text.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/text_handler.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_adjacency.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_annotation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_arithmetics.cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_atomics.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_barriers.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_bitwise.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_builtins.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_capability.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_cfg.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_composites.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_constants.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_conversion.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_datarules.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_debug.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_decorations.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_derivatives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_extensions.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_execution_limitations.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_function.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_id.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_image.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_interfaces.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_instruction.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_layout.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_literals.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_logicals.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_memory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_memory_semantics.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_mode_setting.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_non_uniform.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_scopes.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h + ${CMAKE_CURRENT_SOURCE_DIR}/val/basic_block.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/construct.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/function.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/instruction.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validation_state.cpp) + 
+if (${SPIRV_TIMER_ENABLED}) + set(SPIRV_SOURCES + ${SPIRV_SOURCES} + ${CMAKE_CURRENT_SOURCE_DIR}/util/timer.cpp) +endif() + +# The software_version.cpp file includes build-version.inc. +# Rebuild the software_version.cpp object file if it is older than +# build-version.inc or whenever build-version.inc itself is out of +# date. In the latter case, rebuild build-version.inc first. +# CMake is not smart enough to detect this dependency automatically. +# Without this, the dependency detection system for #included files +# does not kick in on a clean build for the following reason: The +# build will fail early because it doesn't know how to build the +# missing source file build-version.inc. That occurs before the +# preprocessor is run on software_version.cpp to detect the +# #include dependency. +set_source_files_properties( + ${CMAKE_CURRENT_SOURCE_DIR}/software_version.cpp + PROPERTIES OBJECT_DEPENDS "${SPIRV_TOOLS_BUILD_VERSION_INC}") + +spvtools_pch(SPIRV_SOURCES pch_source) + +add_library(${SPIRV_TOOLS} ${SPIRV_SOURCES}) +spvtools_default_compile_options(${SPIRV_TOOLS}) +target_include_directories(${SPIRV_TOOLS} + PUBLIC ${spirv-tools_SOURCE_DIR}/include + PRIVATE ${spirv-tools_BINARY_DIR} + PRIVATE ${SPIRV_HEADER_INCLUDE_DIR} + ) +set_property(TARGET ${SPIRV_TOOLS} PROPERTY FOLDER "SPIRV-Tools libraries") +spvtools_check_symbol_exports(${SPIRV_TOOLS}) +add_dependencies( ${SPIRV_TOOLS} core_tables enum_string_mapping extinst_tables ) + +add_library(${SPIRV_TOOLS}-shared SHARED ${SPIRV_SOURCES}) +spvtools_default_compile_options(${SPIRV_TOOLS}-shared) +target_include_directories(${SPIRV_TOOLS}-shared + PUBLIC ${spirv-tools_SOURCE_DIR}/include + PRIVATE ${spirv-tools_BINARY_DIR} + PRIVATE ${SPIRV_HEADER_INCLUDE_DIR} + ) +set_target_properties(${SPIRV_TOOLS}-shared PROPERTIES CXX_VISIBILITY_PRESET hidden) +set_property(TARGET ${SPIRV_TOOLS}-shared PROPERTY FOLDER "SPIRV-Tools libraries") +spvtools_check_symbol_exports(${SPIRV_TOOLS}-shared) 
+target_compile_definitions(${SPIRV_TOOLS}-shared + PRIVATE SPIRV_TOOLS_IMPLEMENTATION + PUBLIC SPIRV_TOOLS_SHAREDLIB +) +add_dependencies( ${SPIRV_TOOLS}-shared core_tables enum_string_mapping extinst_tables ) + +if(ENABLE_SPIRV_TOOLS_INSTALL) + install(TARGETS ${SPIRV_TOOLS} ${SPIRV_TOOLS}-shared + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) +endif(ENABLE_SPIRV_TOOLS_INSTALL) + +if(MSVC) + # Enable parallel builds across four cores for this lib + add_definitions(/MP4) +endif() diff --git a/source/assembly_grammar.cpp b/source/assembly_grammar.cpp new file mode 100644 index 000000000..4d98e3dab --- /dev/null +++ b/source/assembly_grammar.cpp @@ -0,0 +1,263 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/assembly_grammar.h" + +#include +#include +#include + +#include "source/ext_inst.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/table.h" + +namespace spvtools { +namespace { + +/// @brief Parses a mask expression string for the given operand type. +/// +/// A mask expression is a sequence of one or more terms separated by '|', +/// where each term a named enum value for the given type. No whitespace +/// is permitted. +/// +/// On success, the value is written to pValue. 
+/// +/// @param[in] operandTable operand lookup table +/// @param[in] type of the operand +/// @param[in] textValue word of text to be parsed +/// @param[out] pValue where the resulting value is written +/// +/// @return result code +spv_result_t spvTextParseMaskOperand(spv_target_env env, + const spv_operand_table operandTable, + const spv_operand_type_t type, + const char* textValue, uint32_t* pValue) { + if (textValue == nullptr) return SPV_ERROR_INVALID_TEXT; + size_t text_length = strlen(textValue); + if (text_length == 0) return SPV_ERROR_INVALID_TEXT; + const char* text_end = textValue + text_length; + + // We only support mask expressions in ASCII, so the separator value is a + // char. + const char separator = '|'; + + // Accumulate the result by interpreting one word at a time, scanning + // from left to right. + uint32_t value = 0; + const char* begin = textValue; // The left end of the current word. + const char* end = nullptr; // One character past the end of the current word. + do { + end = std::find(begin, text_end, separator); + + spv_operand_desc entry = nullptr; + if (spvOperandTableNameLookup(env, operandTable, type, begin, end - begin, + &entry)) { + return SPV_ERROR_INVALID_TEXT; + } + value |= entry->value; + + // Advance to the next word by skipping over the separator. + begin = end + 1; + } while (end != text_end); + + *pValue = value; + return SPV_SUCCESS; +} + +// Associates an opcode with its name. +struct SpecConstantOpcodeEntry { + SpvOp opcode; + const char* name; +}; + +// All the opcodes allowed as the operation for OpSpecConstantOp. +// The name does not have the usual "Op" prefix. For example opcode SpvOpIAdd +// is associated with the name "IAdd". 
+// +// clang-format off +#define CASE(NAME) { SpvOp##NAME, #NAME } +const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = { + // Conversion + CASE(SConvert), + CASE(FConvert), + CASE(ConvertFToS), + CASE(ConvertSToF), + CASE(ConvertFToU), + CASE(ConvertUToF), + CASE(UConvert), + CASE(ConvertPtrToU), + CASE(ConvertUToPtr), + CASE(GenericCastToPtr), + CASE(PtrCastToGeneric), + CASE(Bitcast), + CASE(QuantizeToF16), + // Arithmetic + CASE(SNegate), + CASE(Not), + CASE(IAdd), + CASE(ISub), + CASE(IMul), + CASE(UDiv), + CASE(SDiv), + CASE(UMod), + CASE(SRem), + CASE(SMod), + CASE(ShiftRightLogical), + CASE(ShiftRightArithmetic), + CASE(ShiftLeftLogical), + CASE(BitwiseOr), + CASE(BitwiseAnd), + CASE(BitwiseXor), + CASE(FNegate), + CASE(FAdd), + CASE(FSub), + CASE(FMul), + CASE(FDiv), + CASE(FRem), + CASE(FMod), + // Composite + CASE(VectorShuffle), + CASE(CompositeExtract), + CASE(CompositeInsert), + // Logical + CASE(LogicalOr), + CASE(LogicalAnd), + CASE(LogicalNot), + CASE(LogicalEqual), + CASE(LogicalNotEqual), + CASE(Select), + // Comparison + CASE(IEqual), + CASE(INotEqual), + CASE(ULessThan), + CASE(SLessThan), + CASE(UGreaterThan), + CASE(SGreaterThan), + CASE(ULessThanEqual), + CASE(SLessThanEqual), + CASE(UGreaterThanEqual), + CASE(SGreaterThanEqual), + // Memory + CASE(AccessChain), + CASE(InBoundsAccessChain), + CASE(PtrAccessChain), + CASE(InBoundsPtrAccessChain), +}; + +// The 59 is determined by counting the OpSpecConstantOp opcodes in the spec.
+static_assert(59 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]), + "OpSpecConstantOp opcode table is incomplete"); +#undef CASE +// clang-format on + +const size_t kNumOpSpecConstantOpcodes = + sizeof(kOpSpecConstantOpcodes) / sizeof(kOpSpecConstantOpcodes[0]); + +} // namespace + +bool AssemblyGrammar::isValid() const { + return operandTable_ && opcodeTable_ && extInstTable_; +} + +CapabilitySet AssemblyGrammar::filterCapsAgainstTargetEnv( + const SpvCapability* cap_array, uint32_t count) const { + CapabilitySet cap_set; + for (uint32_t i = 0; i < count; ++i) { + spv_operand_desc cap_desc = {}; + if (SPV_SUCCESS == lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, + static_cast<uint32_t>(cap_array[i]), + &cap_desc)) { + // spvOperandTableValueLookup() filters capabilities internally + // according to the current target environment by itself. So we + // should be safe to add this capability if the lookup succeeds. + cap_set.Add(cap_array[i]); + } + } + return cap_set; +} + +spv_result_t AssemblyGrammar::lookupOpcode(const char* name, + spv_opcode_desc* desc) const { + return spvOpcodeTableNameLookup(target_env_, opcodeTable_, name, desc); +} + +spv_result_t AssemblyGrammar::lookupOpcode(SpvOp opcode, + spv_opcode_desc* desc) const { + return spvOpcodeTableValueLookup(target_env_, opcodeTable_, opcode, desc); +} + +spv_result_t AssemblyGrammar::lookupOperand(spv_operand_type_t type, + const char* name, size_t name_len, + spv_operand_desc* desc) const { + return spvOperandTableNameLookup(target_env_, operandTable_, type, name, + name_len, desc); +} + +spv_result_t AssemblyGrammar::lookupOperand(spv_operand_type_t type, + uint32_t operand, + spv_operand_desc* desc) const { + return spvOperandTableValueLookup(target_env_, operandTable_, type, operand, + desc); +} + +spv_result_t AssemblyGrammar::lookupSpecConstantOpcode(const char* name, + SpvOp* opcode) const { + const auto* last = kOpSpecConstantOpcodes + kNumOpSpecConstantOpcodes; + const auto* found = +
std::find_if(kOpSpecConstantOpcodes, last, + [name](const SpecConstantOpcodeEntry& entry) { + return 0 == strcmp(name, entry.name); + }); + if (found == last) return SPV_ERROR_INVALID_LOOKUP; + *opcode = found->opcode; + return SPV_SUCCESS; +} + +spv_result_t AssemblyGrammar::lookupSpecConstantOpcode(SpvOp opcode) const { + const auto* last = kOpSpecConstantOpcodes + kNumOpSpecConstantOpcodes; + const auto* found = + std::find_if(kOpSpecConstantOpcodes, last, + [opcode](const SpecConstantOpcodeEntry& entry) { + return opcode == entry.opcode; + }); + if (found == last) return SPV_ERROR_INVALID_LOOKUP; + return SPV_SUCCESS; +} + +spv_result_t AssemblyGrammar::parseMaskOperand(const spv_operand_type_t type, + const char* textValue, + uint32_t* pValue) const { + return spvTextParseMaskOperand(target_env_, operandTable_, type, textValue, + pValue); +} +spv_result_t AssemblyGrammar::lookupExtInst(spv_ext_inst_type_t type, + const char* textValue, + spv_ext_inst_desc* extInst) const { + return spvExtInstTableNameLookup(extInstTable_, type, textValue, extInst); +} + +spv_result_t AssemblyGrammar::lookupExtInst(spv_ext_inst_type_t type, + uint32_t firstWord, + spv_ext_inst_desc* extInst) const { + return spvExtInstTableValueLookup(extInstTable_, type, firstWord, extInst); +} + +void AssemblyGrammar::pushOperandTypesForMask( + const spv_operand_type_t type, const uint32_t mask, + spv_operand_pattern_t* pattern) const { + spvPushOperandTypesForMask(target_env_, operandTable_, type, mask, pattern); +} + +} // namespace spvtools diff --git a/source/assembly_grammar.h b/source/assembly_grammar.h new file mode 100644 index 000000000..17c2bd3ba --- /dev/null +++ b/source/assembly_grammar.h @@ -0,0 +1,138 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_ASSEMBLY_GRAMMAR_H_ +#define SOURCE_ASSEMBLY_GRAMMAR_H_ + +#include "source/enum_set.h" +#include "source/latest_version_spirv_header.h" +#include "source/operand.h" +#include "source/table.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// Encapsulates the grammar to use for SPIR-V assembly. +// Contains methods to query for valid instructions and operands. +class AssemblyGrammar { + public: + explicit AssemblyGrammar(const spv_const_context context) + : target_env_(context->target_env), + operandTable_(context->operand_table), + opcodeTable_(context->opcode_table), + extInstTable_(context->ext_inst_table) {} + + // Returns true if the internal tables have been initialized with valid data. + bool isValid() const; + + // Returns the SPIR-V target environment. + spv_target_env target_env() const { return target_env_; } + + // Removes capabilities not available in the current target environment and + // returns the rest. + CapabilitySet filterCapsAgainstTargetEnv(const SpvCapability* cap_array, + uint32_t count) const; + + // Fills in the desc parameter with the information about the opcode + // of the given name. Returns SPV_SUCCESS if the opcode was found, and + // SPV_ERROR_INVALID_LOOKUP if the opcode does not exist. + spv_result_t lookupOpcode(const char* name, spv_opcode_desc* desc) const; + + // Fills in the desc parameter with the information about the opcode + // of the given value. Returns SPV_SUCCESS if the opcode was found, and + // SPV_ERROR_INVALID_LOOKUP if the opcode does not exist.
+ spv_result_t lookupOpcode(SpvOp opcode, spv_opcode_desc* desc) const; + + // Fills in the desc parameter with the information about the given + // operand. Returns SPV_SUCCESS if the operand was found, and + // SPV_ERROR_INVALID_LOOKUP otherwise. + spv_result_t lookupOperand(spv_operand_type_t type, const char* name, + size_t name_len, spv_operand_desc* desc) const; + + // Fills in the desc parameter with the information about the given + // operand. Returns SPV_SUCCESS if the operand was found, and + // SPV_ERROR_INVALID_LOOKUP otherwise. + spv_result_t lookupOperand(spv_operand_type_t type, uint32_t operand, + spv_operand_desc* desc) const; + + // Finds operand entry in the grammar table and returns its name. + // Returns "Unknown" if not found. + const char* lookupOperandName(spv_operand_type_t type, + uint32_t operand) const { + spv_operand_desc desc = nullptr; + if (lookupOperand(type, operand, &desc) != SPV_SUCCESS || !desc) { + return "Unknown"; + } + return desc->name; + } + + // Finds the opcode for the given OpSpecConstantOp opcode name. The name + // should not have the "Op" prefix. For example, "IAdd" corresponds to + // the integer add opcode for OpSpecConstantOp. On success, returns + // SPV_SUCCESS and sends the discovered operation code through the opcode + // parameter. On failure, returns SPV_ERROR_INVALID_LOOKUP. + spv_result_t lookupSpecConstantOpcode(const char* name, SpvOp* opcode) const; + + // Returns SPV_SUCCESS if the given opcode is valid as the opcode operand + // to OpSpecConstantOp, and SPV_ERROR_INVALID_LOOKUP otherwise. + spv_result_t lookupSpecConstantOpcode(SpvOp opcode) const; + + // Parses a mask expression string for the given operand type. + // + // A mask expression is a sequence of one or more terms separated by '|', + // where each term is a named enum value for a given type. No whitespace + // is permitted. + // + // On success, the value is written to pValue, and SPV_SUCCESS is returned.
+ // The operand type is defined by the type parameter, and the text to be + // parsed is defined by the textValue parameter. + spv_result_t parseMaskOperand(const spv_operand_type_t type, + const char* textValue, uint32_t* pValue) const; + + // Writes the extended operand with the given type and text to the *extInst + // parameter. + // Returns SPV_SUCCESS if the value could be found. + spv_result_t lookupExtInst(spv_ext_inst_type_t type, const char* textValue, + spv_ext_inst_desc* extInst) const; + + // Writes the extended operand with the given type and first encoded word + // to the *extInst parameter. + // Returns SPV_SUCCESS if the value could be found. + spv_result_t lookupExtInst(spv_ext_inst_type_t type, uint32_t firstWord, + spv_ext_inst_desc* extInst) const; + + // Inserts the operands expected after the given typed mask onto the end + // of the given pattern. + // + // Each set bit in the mask represents zero or more operand types that + // should be appended onto the pattern. Operands for a less significant + // bit must always match before operands for a more significant bit, so + // the operands for a less significant bit must appear closer to the end + // of the pattern stack. + // + // If a set bit is unknown, then we assume it has no operands. + void pushOperandTypesForMask(const spv_operand_type_t type, + const uint32_t mask, + spv_operand_pattern_t* pattern) const; + + private: + const spv_target_env target_env_; + const spv_operand_table operandTable_; + const spv_opcode_table opcodeTable_; + const spv_ext_inst_table extInstTable_; +}; + +} // namespace spvtools + +#endif // SOURCE_ASSEMBLY_GRAMMAR_H_ diff --git a/source/binary.cpp b/source/binary.cpp new file mode 100644 index 000000000..636dac8c8 --- /dev/null +++ b/source/binary.cpp @@ -0,0 +1,795 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/binary.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/diagnostic.h" +#include "source/ext_inst.h" +#include "source/latest_version_spirv_header.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "source/spirv_endian.h" + +spv_result_t spvBinaryHeaderGet(const spv_const_binary binary, + const spv_endianness_t endian, + spv_header_t* pHeader) { + if (!binary->code) return SPV_ERROR_INVALID_BINARY; + if (binary->wordCount < SPV_INDEX_INSTRUCTION) + return SPV_ERROR_INVALID_BINARY; + if (!pHeader) return SPV_ERROR_INVALID_POINTER; + + // TODO: Validation checking? + pHeader->magic = spvFixWord(binary->code[SPV_INDEX_MAGIC_NUMBER], endian); + pHeader->version = spvFixWord(binary->code[SPV_INDEX_VERSION_NUMBER], endian); + pHeader->generator = + spvFixWord(binary->code[SPV_INDEX_GENERATOR_NUMBER], endian); + pHeader->bound = spvFixWord(binary->code[SPV_INDEX_BOUND], endian); + pHeader->schema = spvFixWord(binary->code[SPV_INDEX_SCHEMA], endian); + pHeader->instructions = &binary->code[SPV_INDEX_INSTRUCTION]; + + return SPV_SUCCESS; +} + +namespace { + +// A SPIR-V binary parser. A parser instance communicates detailed parse +// results via callbacks. +class Parser { + public: + // The user_data value is provided to the callbacks as context. 
+ Parser(const spv_const_context context, void* user_data, + spv_parsed_header_fn_t parsed_header_fn, + spv_parsed_instruction_fn_t parsed_instruction_fn) + : grammar_(context), + consumer_(context->consumer), + user_data_(user_data), + parsed_header_fn_(parsed_header_fn), + parsed_instruction_fn_(parsed_instruction_fn) {} + + // Parses the specified binary SPIR-V module, issuing callbacks on a parsed + // header and for each parsed instruction. Returns SPV_SUCCESS on success. + // Otherwise returns an error code and issues a diagnostic. + spv_result_t parse(const uint32_t* words, size_t num_words, + spv_diagnostic* diagnostic); + + private: + // All remaining methods work on the current module parse state. + + // Like the parse method, but works on the current module parse state. + spv_result_t parseModule(); + + // Parses an instruction at the current position of the binary. Assumes + // the header has been parsed, the endian has been set, and the word index is + // still in range. Advances the parsing position past the instruction, and + // updates other parsing state for the current module. + // On success, returns SPV_SUCCESS and issues the parsed-instruction callback. + // On failure, returns an error code and issues a diagnostic. + spv_result_t parseInstruction(); + + // Parses an instruction operand with the given type, for an instruction + // starting at inst_offset words into the SPIR-V binary. + // If the SPIR-V binary is the same endianness as the host, then the + // endian_converted_inst_words parameter is ignored. Otherwise, this method + // appends the words for this operand, converted to host native endianness, + // to the end of endian_converted_inst_words. This method also updates the + // expected_operands parameter, and the scalar members of the inst parameter. + // On success, returns SPV_SUCCESS, advances past the operand, and pushes a + // new entry on to the operands vector. Otherwise returns an error code and + // issues a diagnostic. 
+ spv_result_t parseOperand(size_t inst_offset, spv_parsed_instruction_t* inst, + const spv_operand_type_t type, + std::vector* endian_converted_inst_words, + std::vector* operands, + spv_operand_pattern_t* expected_operands); + + // Records the numeric type for an operand according to the type information + // associated with the given non-zero type Id. This can fail if the type Id + // is not a type Id, or if the type Id does not reference a scalar numeric + // type. On success, return SPV_SUCCESS and populates the num_words, + // number_kind, and number_bit_width fields of parsed_operand. + spv_result_t setNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand, + uint32_t type_id); + + // Records the number type for an instruction at the given offset, if that + // instruction generates a type. For types that aren't scalar numbers, + // record something with number kind SPV_NUMBER_NONE. + void recordNumberType(size_t inst_offset, + const spv_parsed_instruction_t* inst); + + // Returns a diagnostic stream object initialized with current position in + // the input stream, and for the given error code. Any data written to the + // returned object will be propagated to the current parse's diagnostic + // object. + spvtools::DiagnosticStream diagnostic(spv_result_t error) { + return spvtools::DiagnosticStream({0, 0, _.instruction_count}, consumer_, + "", error); + } + + // Returns a diagnostic stream object with the default parse error code. + spvtools::DiagnosticStream diagnostic() { + // The default failure for parsing is invalid binary. + return diagnostic(SPV_ERROR_INVALID_BINARY); + } + + // Issues a diagnostic describing an exhaustion of input condition when + // trying to decode an instruction operand, and returns + // SPV_ERROR_INVALID_BINARY. 
+ spv_result_t exhaustedInputDiagnostic(size_t inst_offset, SpvOp opcode, + spv_operand_type_t type) { + return diagnostic() << "End of input reached while decoding Op" + << spvOpcodeString(opcode) << " starting at word " + << inst_offset + << ((_.word_index < _.num_words) ? ": truncated " + : ": missing ") + << spvOperandTypeStr(type) << " operand at word offset " + << _.word_index - inst_offset << "."; + } + + // Returns the endian-corrected word at the current position. + uint32_t peek() const { return peekAt(_.word_index); } + + // Returns the endian-corrected word at the given position. + uint32_t peekAt(size_t index) const { + assert(index < _.num_words); + return spvFixWord(_.words[index], _.endian); + } + + // Data members + + const spvtools::AssemblyGrammar grammar_; // SPIR-V syntax utility. + const spvtools::MessageConsumer& consumer_; // Message consumer callback. + void* const user_data_; // Context for the callbacks + const spv_parsed_header_fn_t parsed_header_fn_; // Parsed header callback + const spv_parsed_instruction_fn_t + parsed_instruction_fn_; // Parsed instruction callback + + // Describes the format of a typed literal number. + struct NumberType { + spv_number_kind_t type; + uint32_t bit_width; + }; + + // The state used to parse a single SPIR-V binary module. + struct State { + State(const uint32_t* words_arg, size_t num_words_arg, + spv_diagnostic* diagnostic_arg) + : words(words_arg), + num_words(num_words_arg), + diagnostic(diagnostic_arg), + word_index(0), + instruction_count(0), + endian(), + requires_endian_conversion(false) { + // Temporary storage for parser state within a single instruction. + // Most instructions require fewer than 25 words or operands. + operands.reserve(25); + endian_converted_words.reserve(25); + expected_operands.reserve(25); + } + State() : State(0, 0, nullptr) {} + const uint32_t* words; // Words in the binary SPIR-V module. + size_t num_words; // Number of words in the module. 
+ spv_diagnostic* diagnostic; // Where diagnostics go. + size_t word_index; // The current position in words. + size_t instruction_count; // The count of processed instructions + spv_endianness_t endian; // The endianness of the binary. + // Is the SPIR-V binary in a different endiannes from the host native + // endianness? + bool requires_endian_conversion; + + // Maps a result ID to its type ID. By convention: + // - a result ID that is a type definition maps to itself. + // - a result ID without a type maps to 0. (E.g. for OpLabel) + std::unordered_map id_to_type_id; + // Maps a type ID to its number type description. + std::unordered_map type_id_to_number_type_info; + // Maps an ExtInstImport id to the extended instruction type. + std::unordered_map + import_id_to_ext_inst_type; + + // Used by parseOperand + std::vector operands; + std::vector endian_converted_words; + spv_operand_pattern_t expected_operands; + } _; +}; + +spv_result_t Parser::parse(const uint32_t* words, size_t num_words, + spv_diagnostic* diagnostic_arg) { + _ = State(words, num_words, diagnostic_arg); + + const spv_result_t result = parseModule(); + + // Clear the module state. The tables might be big. + _ = State(); + + return result; +} + +spv_result_t Parser::parseModule() { + if (!_.words) return diagnostic() << "Missing module."; + + if (_.num_words < SPV_INDEX_INSTRUCTION) + return diagnostic() << "Module has incomplete header: only " << _.num_words + << " words instead of " << SPV_INDEX_INSTRUCTION; + + // Check the magic number and detect the module's endianness. + spv_const_binary_t binary{_.words, _.num_words}; + if (spvBinaryEndianness(&binary, &_.endian)) { + return diagnostic() << "Invalid SPIR-V magic number '" << std::hex + << _.words[0] << "'."; + } + _.requires_endian_conversion = !spvIsHostEndian(_.endian); + + // Process the header. 
+ spv_header_t header; + if (spvBinaryHeaderGet(&binary, _.endian, &header)) { + // It turns out there is no way to trigger this error since the only + // failure cases are already handled above, with better messages. + return diagnostic(SPV_ERROR_INTERNAL) + << "Internal error: unhandled header parse failure"; + } + if (parsed_header_fn_) { + if (auto error = parsed_header_fn_(user_data_, _.endian, header.magic, + header.version, header.generator, + header.bound, header.schema)) { + return error; + } + } + + // Process the instructions. + _.word_index = SPV_INDEX_INSTRUCTION; + while (_.word_index < _.num_words) + if (auto error = parseInstruction()) return error; + + // Running off the end should already have been reported earlier. + assert(_.word_index == _.num_words); + + return SPV_SUCCESS; +} + +spv_result_t Parser::parseInstruction() { + _.instruction_count++; + + // The zero values for all members except for opcode are the + // correct initial values. + spv_parsed_instruction_t inst = {}; + + const uint32_t first_word = peek(); + + // If the module's endianness is different from the host native endianness, + // then converted_words contains the the endian-translated words in the + // instruction. + _.endian_converted_words.clear(); + _.endian_converted_words.push_back(first_word); + + // After a successful parse of the instruction, the inst.operands member + // will point to this vector's storage. + _.operands.clear(); + + assert(_.word_index < _.num_words); + // Decompose and check the first word. + uint16_t inst_word_count = 0; + spvOpcodeSplit(first_word, &inst_word_count, &inst.opcode); + if (inst_word_count < 1) { + return diagnostic() << "Invalid instruction word count: " + << inst_word_count; + } + spv_opcode_desc opcode_desc; + if (grammar_.lookupOpcode(static_cast(inst.opcode), &opcode_desc)) + return diagnostic() << "Invalid opcode: " << inst.opcode; + + // Advance past the opcode word. But remember the of the start + // of the instruction. 
+ const size_t inst_offset = _.word_index; + _.word_index++; + + // Maintains the ordered list of expected operand types. + // For many instructions we only need the {numTypes, operandTypes} + // entries in opcode_desc. However, sometimes we need to modify + // the list as we parse the operands. This occurs when an operand + // has its own logical operands (such as the LocalSize operand for + // ExecutionMode), or for extended instructions that may have their + // own operands depending on the selected extended instruction. + _.expected_operands.clear(); + for (auto i = 0; i < opcode_desc->numTypes; i++) + _.expected_operands.push_back( + opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); + + while (_.word_index < inst_offset + inst_word_count) { + const uint16_t inst_word_index = uint16_t(_.word_index - inst_offset); + if (_.expected_operands.empty()) { + return diagnostic() << "Invalid instruction Op" << opcode_desc->name + << " starting at word " << inst_offset + << ": expected no more operands after " + << inst_word_index + << " words, but stated word count is " + << inst_word_count << "."; + } + + spv_operand_type_t type = + spvTakeFirstMatchableOperand(&_.expected_operands); + + if (auto error = + parseOperand(inst_offset, &inst, type, &_.endian_converted_words, + &_.operands, &_.expected_operands)) { + return error; + } + } + + if (!_.expected_operands.empty() && + !spvOperandIsOptional(_.expected_operands.back())) { + return diagnostic() << "End of input reached while decoding Op" + << opcode_desc->name << " starting at word " + << inst_offset << ": expected more operands after " + << inst_word_count << " words."; + } + + if ((inst_offset + inst_word_count) != _.word_index) { + return diagnostic() << "Invalid word count: Op" << opcode_desc->name + << " starting at word " << inst_offset + << " says it has " << inst_word_count + << " words, but found " << _.word_index - inst_offset + << " words instead."; + } + + // Check the computed length of the 
endian-converted words vector against + // the declared number of words in the instruction. If endian conversion + // is required, then they should match. If no endian conversion was + // performed, then the vector only contains the initial opcode/word-count + // word. + assert(!_.requires_endian_conversion || + (inst_word_count == _.endian_converted_words.size())); + assert(_.requires_endian_conversion || + (_.endian_converted_words.size() == 1)); + + recordNumberType(inst_offset, &inst); + + if (_.requires_endian_conversion) { + // We must wait until here to set this pointer, because the vector might + // have been be resized while we accumulated its elements. + inst.words = _.endian_converted_words.data(); + } else { + // If no conversion is required, then just point to the underlying binary. + // This saves time and space. + inst.words = _.words + inst_offset; + } + inst.num_words = inst_word_count; + + // We must wait until here to set this pointer, because the vector might + // have been be resized while we accumulated its elements. + inst.operands = _.operands.data(); + inst.num_operands = uint16_t(_.operands.size()); + + // Issue the callback. The callee should know that all the storage in inst + // is transient, and will disappear immediately afterward. + if (parsed_instruction_fn_) { + if (auto error = parsed_instruction_fn_(user_data_, &inst)) return error; + } + + return SPV_SUCCESS; +} + +spv_result_t Parser::parseOperand(size_t inst_offset, + spv_parsed_instruction_t* inst, + const spv_operand_type_t type, + std::vector* words, + std::vector* operands, + spv_operand_pattern_t* expected_operands) { + const SpvOp opcode = static_cast(inst->opcode); + // We'll fill in this result as we go along. + spv_parsed_operand_t parsed_operand; + parsed_operand.offset = uint16_t(_.word_index - inst_offset); + // Most operands occupy one word. This might be be adjusted later. 
+ parsed_operand.num_words = 1; + // The type argument is the one used by the grammar to parse the instruction. + // But it can exposes internal parser details such as whether an operand is + // optional or actually represents a variable-length sequence of operands. + // The resulting type should be adjusted to avoid those internal details. + // In most cases, the resulting operand type is the same as the grammar type. + parsed_operand.type = type; + + // Assume non-numeric values. This will be updated for literal numbers. + parsed_operand.number_kind = SPV_NUMBER_NONE; + parsed_operand.number_bit_width = 0; + + if (_.word_index >= _.num_words) + return exhaustedInputDiagnostic(inst_offset, opcode, type); + + const uint32_t word = peek(); + + // Do the words in this operand have to be converted to native endianness? + // True for all but literal strings. + bool convert_operand_endianness = true; + + switch (type) { + case SPV_OPERAND_TYPE_TYPE_ID: + if (!word) + return diagnostic(SPV_ERROR_INVALID_ID) << "Error: Type Id is 0"; + inst->type_id = word; + break; + + case SPV_OPERAND_TYPE_RESULT_ID: + if (!word) + return diagnostic(SPV_ERROR_INVALID_ID) << "Error: Result Id is 0"; + inst->result_id = word; + // Save the result ID to type ID mapping. + // In the grammar, type ID always appears before result ID. + if (_.id_to_type_id.find(inst->result_id) != _.id_to_type_id.end()) + return diagnostic(SPV_ERROR_INVALID_ID) + << "Id " << inst->result_id << " is defined more than once"; + // Record it. + // A regular value maps to its type. Some instructions (e.g. OpLabel) + // have no type Id, and will map to 0. The result Id for a + // type-generating instruction (e.g. OpTypeInt) maps to itself. + _.id_to_type_id[inst->result_id] = + spvOpcodeGeneratesType(opcode) ? 
inst->result_id : inst->type_id; + break; + + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_OPTIONAL_ID: + if (!word) return diagnostic(SPV_ERROR_INVALID_ID) << "Id is 0"; + parsed_operand.type = SPV_OPERAND_TYPE_ID; + + if (opcode == SpvOpExtInst && parsed_operand.offset == 3) { + // The current word is the extended instruction set Id. + // Set the extended instruction set type for the current instruction. + auto ext_inst_type_iter = _.import_id_to_ext_inst_type.find(word); + if (ext_inst_type_iter == _.import_id_to_ext_inst_type.end()) { + return diagnostic(SPV_ERROR_INVALID_ID) + << "OpExtInst set Id " << word + << " does not reference an OpExtInstImport result Id"; + } + inst->ext_inst_type = ext_inst_type_iter->second; + } + break; + + case SPV_OPERAND_TYPE_SCOPE_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: + // Check for trivially invalid values. The operand descriptions already + // have the word "ID" in them. + if (!word) return diagnostic() << spvOperandTypeStr(type) << " is 0"; + break; + + case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { + assert(SpvOpExtInst == opcode); + assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE); + spv_ext_inst_desc ext_inst; + if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst)) + return diagnostic() << "Invalid extended instruction number: " << word; + spvPushOperandTypes(ext_inst->operandTypes, expected_operands); + } break; + + case SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER: { + assert(SpvOpSpecConstantOp == opcode); + if (grammar_.lookupSpecConstantOpcode(SpvOp(word))) { + return diagnostic() + << "Invalid " << spvOperandTypeStr(type) << ": " << word; + } + spv_opcode_desc opcode_entry = nullptr; + if (grammar_.lookupOpcode(SpvOp(word), &opcode_entry)) { + return diagnostic(SPV_ERROR_INTERNAL) + << "OpSpecConstant opcode table out of sync"; + } + // OpSpecConstant opcodes must have a type and result. 
We've already + // processed them, so skip them when preparing to parse the other + // operants for the opcode. + assert(opcode_entry->hasType); + assert(opcode_entry->hasResult); + assert(opcode_entry->numTypes >= 2); + spvPushOperandTypes(opcode_entry->operandTypes + 2, expected_operands); + } break; + + case SPV_OPERAND_TYPE_LITERAL_INTEGER: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: + // These are regular single-word literal integer operands. + // Post-parsing validation should check the range of the parsed value. + parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; + // It turns out they are always unsigned integers! + parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT; + parsed_operand.number_bit_width = 32; + break; + + case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: + case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: + parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; + if (opcode == SpvOpSwitch) { + // The literal operands have the same type as the value + // referenced by the selector Id. + const uint32_t selector_id = peekAt(inst_offset + 1); + const auto type_id_iter = _.id_to_type_id.find(selector_id); + if (type_id_iter == _.id_to_type_id.end() || + type_id_iter->second == 0) { + return diagnostic() << "Invalid OpSwitch: selector id " << selector_id + << " has no type"; + } + uint32_t type_id = type_id_iter->second; + + if (selector_id == type_id) { + // Recall that by convention, a result ID that is a type definition + // maps to itself. 
+ return diagnostic() << "Invalid OpSwitch: selector id " << selector_id + << " is a type, not a value"; + } + if (auto error = setNumericTypeInfoForType(&parsed_operand, type_id)) + return error; + if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT && + parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) { + return diagnostic() << "Invalid OpSwitch: selector id " << selector_id + << " is not a scalar integer"; + } + } else { + assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); + // The literal number type is determined by the type Id for the + // constant. + assert(inst->type_id); + if (auto error = + setNumericTypeInfoForType(&parsed_operand, inst->type_id)) + return error; + } + break; + + case SPV_OPERAND_TYPE_LITERAL_STRING: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { + convert_operand_endianness = false; + const char* string = + reinterpret_cast(_.words + _.word_index); + // Compute the length of the string, but make sure we don't run off the + // end of the input. + const size_t remaining_input_bytes = + sizeof(uint32_t) * (_.num_words - _.word_index); + const size_t string_num_content_bytes = + spv_strnlen_s(string, remaining_input_bytes); + // If there was no terminating null byte, then that's an end-of-input + // error. + if (string_num_content_bytes == remaining_input_bytes) + return exhaustedInputDiagnostic(inst_offset, opcode, type); + // Account for null in the word length, so add 1 for null, then add 3 to + // make sure we round up. The following is equivalent to: + // (string_num_content_bytes + 1 + 3) / 4 + const size_t string_num_words = string_num_content_bytes / 4 + 1; + // Make sure we can record the word count without overflow. + // + // This error can't currently be triggered because of validity + // checks elsewhere. 
+ if (string_num_words > std::numeric_limits::max()) { + return diagnostic() << "Literal string is longer than " + << std::numeric_limits::max() + << " words: " << string_num_words << " words long"; + } + parsed_operand.num_words = uint16_t(string_num_words); + parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING; + + if (SpvOpExtInstImport == opcode) { + // Record the extended instruction type for the ID for this import. + // There is only one string literal argument to OpExtInstImport, + // so it's sufficient to guard this just on the opcode. + const spv_ext_inst_type_t ext_inst_type = + spvExtInstImportTypeGet(string); + if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { + return diagnostic() + << "Invalid extended instruction import '" << string << "'"; + } + // We must have parsed a valid result ID. It's a condition + // of the grammar, and we only accept non-zero result Ids. + assert(inst->result_id); + _.import_id_to_ext_inst_type[inst->result_id] = ext_inst_type; + } + } break; + + case SPV_OPERAND_TYPE_CAPABILITY: + case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: + case SPV_OPERAND_TYPE_EXECUTION_MODEL: + case SPV_OPERAND_TYPE_ADDRESSING_MODEL: + case SPV_OPERAND_TYPE_MEMORY_MODEL: + case SPV_OPERAND_TYPE_EXECUTION_MODE: + case SPV_OPERAND_TYPE_STORAGE_CLASS: + case SPV_OPERAND_TYPE_DIMENSIONALITY: + case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: + case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: + case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: + case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: + case SPV_OPERAND_TYPE_LINKAGE_TYPE: + case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: + case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: + case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: + case SPV_OPERAND_TYPE_DECORATION: + case SPV_OPERAND_TYPE_BUILT_IN: + case SPV_OPERAND_TYPE_GROUP_OPERATION: + case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: + case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: + case SPV_OPERAND_TYPE_DEBUG_BASE_TYPE_ATTRIBUTE_ENCODING: + case SPV_OPERAND_TYPE_DEBUG_COMPOSITE_TYPE: + case 
SPV_OPERAND_TYPE_DEBUG_TYPE_QUALIFIER: + case SPV_OPERAND_TYPE_DEBUG_OPERATION: { + // A single word that is a plain enum value. + + // Map an optional operand type to its corresponding concrete type. + if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) + parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; + + spv_operand_desc entry; + if (grammar_.lookupOperand(type, word, &entry)) { + return diagnostic() + << "Invalid " << spvOperandTypeStr(parsed_operand.type) + << " operand: " << word; + } + // Prepare to accept operands to this operand, if needed. + spvPushOperandTypes(entry->operandTypes, expected_operands); + } break; + + case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: + case SPV_OPERAND_TYPE_FUNCTION_CONTROL: + case SPV_OPERAND_TYPE_LOOP_CONTROL: + case SPV_OPERAND_TYPE_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: + case SPV_OPERAND_TYPE_SELECTION_CONTROL: + case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: { + // This operand is a mask. + + // Map an optional operand type to its corresponding concrete type. + if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) + parsed_operand.type = SPV_OPERAND_TYPE_IMAGE; + else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) + parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; + + // Check validity of set mask bits. Also prepare for operands for those + // masks if they have any. To get operand order correct, scan from + // MSB to LSB since we can only prepend operands to a pattern. + // The only case in the grammar where you have more than one mask bit + // having an operand is for image operands. See SPIR-V 3.14 Image + // Operands. 
+ uint32_t remaining_word = word; + for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { + if (remaining_word & mask) { + spv_operand_desc entry; + if (grammar_.lookupOperand(type, mask, &entry)) { + return diagnostic() + << "Invalid " << spvOperandTypeStr(parsed_operand.type) + << " operand: " << word << " has invalid mask component " + << mask; + } + remaining_word ^= mask; + spvPushOperandTypes(entry->operandTypes, expected_operands); + } + } + if (word == 0) { + // An all-zeroes mask *might* also be valid. + spv_operand_desc entry; + if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { + // Prepare for its operands, if any. + spvPushOperandTypes(entry->operandTypes, expected_operands); + } + } + } break; + default: + return diagnostic() << "Internal error: Unhandled operand type: " << type; + } + + assert(spvOperandIsConcrete(parsed_operand.type)); + + operands->push_back(parsed_operand); + + const size_t index_after_operand = _.word_index + parsed_operand.num_words; + + // Avoid buffer overrun for the cases where the operand has more than one + // word, and where it isn't a string. (Those other cases have already been + // handled earlier.) For example, this error can occur for a multi-word + // argument to OpConstant, or a multi-word case literal operand for OpSwitch. + if (_.num_words < index_after_operand) + return exhaustedInputDiagnostic(inst_offset, opcode, type); + + if (_.requires_endian_conversion) { + // Copy instruction words. Translate to native endianness as needed. + if (convert_operand_endianness) { + const spv_endianness_t endianness = _.endian; + std::transform(_.words + _.word_index, _.words + index_after_operand, + std::back_inserter(*words), + [endianness](const uint32_t raw_word) { + return spvFixWord(raw_word, endianness); + }); + } else { + words->insert(words->end(), _.words + _.word_index, + _.words + index_after_operand); + } + } + + // Advance past the operand. 
+ _.word_index = index_after_operand; + + return SPV_SUCCESS; +} + +spv_result_t Parser::setNumericTypeInfoForType( + spv_parsed_operand_t* parsed_operand, uint32_t type_id) { + assert(type_id != 0); + auto type_info_iter = _.type_id_to_number_type_info.find(type_id); + if (type_info_iter == _.type_id_to_number_type_info.end()) { + return diagnostic() << "Type Id " << type_id << " is not a type"; + } + const NumberType& info = type_info_iter->second; + if (info.type == SPV_NUMBER_NONE) { + // This is a valid type, but for something other than a scalar number. + return diagnostic() << "Type Id " << type_id + << " is not a scalar numeric type"; + } + + parsed_operand->number_kind = info.type; + parsed_operand->number_bit_width = info.bit_width; + // Round up the word count. + parsed_operand->num_words = static_cast((info.bit_width + 31) / 32); + return SPV_SUCCESS; +} + +void Parser::recordNumberType(size_t inst_offset, + const spv_parsed_instruction_t* inst) { + const SpvOp opcode = static_cast(inst->opcode); + if (spvOpcodeGeneratesType(opcode)) { + NumberType info = {SPV_NUMBER_NONE, 0}; + if (SpvOpTypeInt == opcode) { + const bool is_signed = peekAt(inst_offset + 3) != 0; + info.type = is_signed ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; + info.bit_width = peekAt(inst_offset + 2); + } else if (SpvOpTypeFloat == opcode) { + info.type = SPV_NUMBER_FLOATING; + info.bit_width = peekAt(inst_offset + 2); + } + // The *result* Id of a type generating instruction is the type Id. 
+ _.type_id_to_number_type_info[inst->result_id] = info; + } +} + +} // anonymous namespace + +spv_result_t spvBinaryParse(const spv_const_context context, void* user_data, + const uint32_t* code, const size_t num_words, + spv_parsed_header_fn_t parsed_header, + spv_parsed_instruction_fn_t parsed_instruction, + spv_diagnostic* diagnostic) { + spv_context_t hijack_context = *context; + if (diagnostic) { + *diagnostic = nullptr; + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); + } + Parser parser(&hijack_context, user_data, parsed_header, parsed_instruction); + return parser.parse(code, num_words, diagnostic); +} + +// TODO(dneto): This probably belongs in text.cpp since that's the only place +// that a spv_binary_t value is created. +void spvBinaryDestroy(spv_binary binary) { + if (!binary) return; + delete[] binary->code; + delete binary; +} + +size_t spv_strnlen_s(const char* str, size_t strsz) { + if (!str) return 0; + for (size_t i = 0; i < strsz; i++) { + if (!str[i]) return i; + } + return strsz; +} diff --git a/source/binary.h b/source/binary.h new file mode 100644 index 000000000..66d24c7e4 --- /dev/null +++ b/source/binary.h @@ -0,0 +1,36 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_BINARY_H_ +#define SOURCE_BINARY_H_ + +#include "source/spirv_definition.h" +#include "spirv-tools/libspirv.h" + +// Functions + +// Grabs the header from the SPIR-V module given in the binary parameter. The +// endian parameter specifies the endianness of the binary module. On success, +// returns SPV_SUCCESS and writes the parsed header into *header. +spv_result_t spvBinaryHeaderGet(const spv_const_binary binary, + const spv_endianness_t endian, + spv_header_t* header); + +// Returns the number of non-null characters in str before the first null +// character, or strsz if there is no null character. Examines at most the +// first strsz characters in str. Returns 0 if str is nullptr. This is a +// replacement for C11's strnlen_s which might not exist in all environments. +size_t spv_strnlen_s(const char* str, size_t strsz); + +#endif // SOURCE_BINARY_H_ diff --git a/source/cfa.h b/source/cfa.h new file mode 100644 index 000000000..97ef398d6 --- /dev/null +++ b/source/cfa.h @@ -0,0 +1,347 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_CFA_H_ +#define SOURCE_CFA_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spvtools { + +// Control Flow Analysis of control flow graphs of basic block nodes |BB|. 
+template +class CFA { + using bb_ptr = BB*; + using cbb_ptr = const BB*; + using bb_iter = typename std::vector::const_iterator; + using get_blocks_func = std::function*(const BB*)>; + + struct block_info { + cbb_ptr block; ///< pointer to the block + bb_iter iter; ///< Iterator to the current child node being processed + }; + + /// Returns true if a block with @p id is found in the @p work_list vector + /// + /// @param[in] work_list Set of blocks visited in the depth first + /// traversal + /// of the CFG + /// @param[in] id The ID of the block being checked + /// + /// @return true if the edge work_list.back().block->id() => id is a back-edge + static bool FindInWorkList(const std::vector& work_list, + uint32_t id); + + public: + /// @brief Depth first traversal starting from the \p entry BasicBlock + /// + /// This function performs a depth first traversal from the \p entry + /// BasicBlock and calls the pre/postorder functions when it needs to process + /// the node in pre order, post order. It also calls the backedge function + /// when a back edge is encountered. + /// + /// @param[in] entry The root BasicBlock of a CFG + /// @param[in] successor_func A function which will return a pointer to the + /// successor nodes + /// @param[in] preorder A function that will be called for every block in a + /// CFG following preorder traversal semantics + /// @param[in] postorder A function that will be called for every block in a + /// CFG following postorder traversal semantics + /// @param[in] backedge A function that will be called when a backedge is + /// encountered during a traversal + /// NOTE: The @p successor_func and predecessor_func each return a pointer to + /// a + /// collection such that iterators to that collection remain valid for the + /// lifetime of the algorithm. 
+ static void DepthFirstTraversal( + const BB* entry, get_blocks_func successor_func, + std::function preorder, + std::function postorder, + std::function backedge); + + /// @brief Calculates dominator edges for a set of blocks + /// + /// Computes dominators using the algorithm of Cooper, Harvey, and Kennedy + /// "A Simple, Fast Dominance Algorithm", 2001. + /// + /// The algorithm assumes there is a unique root node (a node without + /// predecessors), and it is therefore at the end of the postorder vector. + /// + /// This function calculates the dominator edges for a set of blocks in the + /// CFG. + /// Uses the dominator algorithm by Cooper et al. + /// + /// @param[in] postorder A vector of blocks in post order traversal + /// order + /// in a CFG + /// @param[in] predecessor_func Function used to get the predecessor nodes of + /// a + /// block + /// + /// @return the dominator tree of the graph, as a vector of pairs of nodes. + /// The first node in the pair is a node in the graph. The second node in the + /// pair is its immediate dominator in the sense of Cooper et.al., where a + /// block + /// without predecessors (such as the root node) is its own immediate + /// dominator. + static std::vector> CalculateDominators( + const std::vector& postorder, get_blocks_func predecessor_func); + + // Computes a minimal set of root nodes required to traverse, in the forward + // direction, the CFG represented by the given vector of blocks, and successor + // and predecessor functions. When considering adding two nodes, each having + // predecessors, favour using the one that appears earlier on the input blocks + // list. 
+ static std::vector TraversalRoots(const std::vector& blocks, + get_blocks_func succ_func, + get_blocks_func pred_func); + + static void ComputeAugmentedCFG( + std::vector& ordered_blocks, BB* pseudo_entry_block, + BB* pseudo_exit_block, + std::unordered_map>* augmented_successors_map, + std::unordered_map>* + augmented_predecessors_map, + get_blocks_func succ_func, get_blocks_func pred_func); +}; + +template +bool CFA::FindInWorkList(const std::vector& work_list, + uint32_t id) { + for (const auto b : work_list) { + if (b.block->id() == id) return true; + } + return false; +} + +template +void CFA::DepthFirstTraversal( + const BB* entry, get_blocks_func successor_func, + std::function preorder, + std::function postorder, + std::function backedge) { + std::unordered_set processed; + + /// NOTE: work_list is the sequence of nodes from the root node to the node + /// being processed in the traversal + std::vector work_list; + work_list.reserve(10); + + work_list.push_back({entry, std::begin(*successor_func(entry))}); + preorder(entry); + processed.insert(entry->id()); + + while (!work_list.empty()) { + block_info& top = work_list.back(); + if (top.iter == end(*successor_func(top.block))) { + postorder(top.block); + work_list.pop_back(); + } else { + BB* child = *top.iter; + top.iter++; + if (FindInWorkList(work_list, child->id())) { + backedge(top.block, child); + } + if (processed.count(child->id()) == 0) { + preorder(child); + work_list.emplace_back( + block_info{child, std::begin(*successor_func(child))}); + processed.insert(child->id()); + } + } + } +} + +template +std::vector> CFA::CalculateDominators( + const std::vector& postorder, get_blocks_func predecessor_func) { + struct block_detail { + size_t dominator; ///< The index of block's dominator in post order array + size_t postorder_index; ///< The index of the block in the post order array + }; + const size_t undefined_dom = postorder.size(); + + std::unordered_map idoms; + for (size_t i = 0; i < 
postorder.size(); i++) { + idoms[postorder[i]] = {undefined_dom, i}; + } + idoms[postorder.back()].dominator = idoms[postorder.back()].postorder_index; + + bool changed = true; + while (changed) { + changed = false; + for (auto b = postorder.rbegin() + 1; b != postorder.rend(); ++b) { + const std::vector& predecessors = *predecessor_func(*b); + // Find the first processed/reachable predecessor that is reachable + // in the forward traversal. + auto res = std::find_if(std::begin(predecessors), std::end(predecessors), + [&idoms, undefined_dom](BB* pred) { + return idoms.count(pred) && + idoms[pred].dominator != undefined_dom; + }); + if (res == end(predecessors)) continue; + const BB* idom = *res; + size_t idom_idx = idoms[idom].postorder_index; + + // all other predecessors + for (const auto* p : predecessors) { + if (idom == p) continue; + // Only consider nodes reachable in the forward traversal. + // Otherwise the intersection doesn't make sense and will never + // terminate. + if (!idoms.count(p)) continue; + if (idoms[p].dominator != undefined_dom) { + size_t finger1 = idoms[p].postorder_index; + size_t finger2 = idom_idx; + while (finger1 != finger2) { + while (finger1 < finger2) { + finger1 = idoms[postorder[finger1]].dominator; + } + while (finger2 < finger1) { + finger2 = idoms[postorder[finger2]].dominator; + } + } + idom_idx = finger1; + } + } + if (idoms[*b].dominator != idom_idx) { + idoms[*b].dominator = idom_idx; + changed = true; + } + } + } + + std::vector> out; + for (auto idom : idoms) { + // NOTE: performing a const cast for convenient usage with + // UpdateImmediateDominators + out.push_back({const_cast(std::get<0>(idom)), + const_cast(postorder[std::get<1>(idom).dominator])}); + } + + // Sort by postorder index to generate a deterministic ordering of edges. 
+ std::sort( + out.begin(), out.end(), + [&idoms](const std::pair& lhs, + const std::pair& rhs) { + assert(lhs.first); + assert(lhs.second); + assert(rhs.first); + assert(rhs.second); + auto lhs_indices = std::make_pair(idoms[lhs.first].postorder_index, + idoms[lhs.second].postorder_index); + auto rhs_indices = std::make_pair(idoms[rhs.first].postorder_index, + idoms[rhs.second].postorder_index); + return lhs_indices < rhs_indices; + }); + return out; +} + +template +std::vector CFA::TraversalRoots(const std::vector& blocks, + get_blocks_func succ_func, + get_blocks_func pred_func) { + // The set of nodes which have been visited from any of the roots so far. + std::unordered_set visited; + + auto mark_visited = [&visited](const BB* b) { visited.insert(b); }; + auto ignore_block = [](const BB*) {}; + auto ignore_blocks = [](const BB*, const BB*) {}; + + auto traverse_from_root = [&mark_visited, &succ_func, &ignore_block, + &ignore_blocks](const BB* entry) { + DepthFirstTraversal(entry, succ_func, mark_visited, ignore_block, + ignore_blocks); + }; + + std::vector result; + + // First collect nodes without predecessors. + for (auto block : blocks) { + if (pred_func(block)->empty()) { + assert(visited.count(block) == 0 && "Malformed graph!"); + result.push_back(block); + traverse_from_root(block); + } + } + + // Now collect other stranded nodes. These must be in unreachable cycles. + for (auto block : blocks) { + if (visited.count(block) == 0) { + result.push_back(block); + traverse_from_root(block); + } + } + + return result; +} + +template +void CFA::ComputeAugmentedCFG( + std::vector& ordered_blocks, BB* pseudo_entry_block, + BB* pseudo_exit_block, + std::unordered_map>* augmented_successors_map, + std::unordered_map>* augmented_predecessors_map, + get_blocks_func succ_func, get_blocks_func pred_func) { + // Compute the successors of the pseudo-entry block, and + // the predecessors of the pseudo exit block. 
+ auto sources = TraversalRoots(ordered_blocks, succ_func, pred_func); + + // For the predecessor traversals, reverse the order of blocks. This + // will affect the post-dominance calculation as follows: + // - Suppose you have blocks A and B, with A appearing before B in + // the list of blocks. + // - Also, A branches only to B, and B branches only to A. + // - We want to compute A as dominating B, and B as post-dominating A. + // By using reversed blocks for predecessor traversal roots discovery, + // we'll add an edge from B to the pseudo-exit node, rather than from A. + // All this is needed to correctly process the dominance/post-dominance + // constraint when A is a loop header that points to itself as its + // own continue target, and B is the latch block for the loop. + std::vector reversed_blocks(ordered_blocks.rbegin(), + ordered_blocks.rend()); + auto sinks = TraversalRoots(reversed_blocks, pred_func, succ_func); + + // Wire up the pseudo entry block. + (*augmented_successors_map)[pseudo_entry_block] = sources; + for (auto block : sources) { + auto& augmented_preds = (*augmented_predecessors_map)[block]; + const auto preds = pred_func(block); + augmented_preds.reserve(1 + preds->size()); + augmented_preds.push_back(pseudo_entry_block); + augmented_preds.insert(augmented_preds.end(), preds->begin(), preds->end()); + } + + // Wire up the pseudo exit block. 
+ (*augmented_predecessors_map)[pseudo_exit_block] = sinks; + for (auto block : sinks) { + auto& augmented_succ = (*augmented_successors_map)[block]; + const auto succ = succ_func(block); + augmented_succ.reserve(1 + succ->size()); + augmented_succ.push_back(pseudo_exit_block); + augmented_succ.insert(augmented_succ.end(), succ->begin(), succ->end()); + } +} + +} // namespace spvtools + +#endif // SOURCE_CFA_H_ diff --git a/source/comp/CMakeLists.txt b/source/comp/CMakeLists.txt new file mode 100644 index 000000000..f65f9f670 --- /dev/null +++ b/source/comp/CMakeLists.txt @@ -0,0 +1,52 @@ +# Copyright (c) 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +if(SPIRV_BUILD_COMPRESSION) + add_library(SPIRV-Tools-comp + bit_stream.cpp + bit_stream.h + huffman_codec.h + markv_codec.cpp + markv_codec.h + markv.cpp + markv.h + markv_decoder.cpp + markv_decoder.h + markv_encoder.cpp + markv_encoder.h + markv_logger.h + move_to_front.h + move_to_front.cpp) + + spvtools_default_compile_options(SPIRV-Tools-comp) + target_include_directories(SPIRV-Tools-comp + PUBLIC ${spirv-tools_SOURCE_DIR}/include + PUBLIC ${SPIRV_HEADER_INCLUDE_DIR} + PRIVATE ${spirv-tools_BINARY_DIR} + ) + + target_link_libraries(SPIRV-Tools-comp + PUBLIC ${SPIRV_TOOLS}) + + set_property(TARGET SPIRV-Tools-comp PROPERTY FOLDER "SPIRV-Tools libraries") + spvtools_check_symbol_exports(SPIRV-Tools-comp) + + if(ENABLE_SPIRV_TOOLS_INSTALL) + install(TARGETS SPIRV-Tools-comp + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) + endif(ENABLE_SPIRV_TOOLS_INSTALL) + +endif(SPIRV_BUILD_COMPRESSION) diff --git a/source/comp/bit_stream.cpp b/source/comp/bit_stream.cpp new file mode 100644 index 000000000..a5769e03e --- /dev/null +++ b/source/comp/bit_stream.cpp @@ -0,0 +1,348 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "source/comp/bit_stream.h" + +namespace spvtools { +namespace comp { +namespace { + +// Returns if the system is little-endian. 
Unfortunately only works during +// runtime. +bool IsLittleEndian() { + // This constant value allows the detection of the host machine's endianness. + // Accessing it as an array of bytes is valid due to C++11 section 3.10 + // paragraph 10. + static const uint16_t kFF00 = 0xff00; + return reinterpret_cast(&kFF00)[0] == 0; +} + +// Copies bytes from the given buffer to a uint64_t buffer. +// Motivation: casting uint64_t* to uint8_t* is ok. Casting in the other +// direction is only advisable if uint8_t* is aligned to 64-bit word boundary. +std::vector ToBuffer64(const void* buffer, size_t num_bytes) { + std::vector out; + out.resize((num_bytes + 7) / 8, 0); + memcpy(out.data(), buffer, num_bytes); + return out; +} + +// Copies uint8_t buffer to a uint64_t buffer. +std::vector ToBuffer64(const std::vector& in) { + return ToBuffer64(in.data(), in.size()); +} + +// Returns uint64_t containing the same bits as |val|. +// Type size must be at most 8 bytes. +template +uint64_t ToU64(T val) { + static_assert(sizeof(T) <= 8, "Type size too big"); + uint64_t val64 = 0; + std::memcpy(&val64, &val, sizeof(T)); + return val64; +} + +// Returns value of type T containing the same bits as |val64|. +// Type size must be at most 8 bytes. Upper (unused) bits of |val64| must be +// zero (irrelevant, but is checked with assertion). +template +T FromU64(uint64_t val64) { + assert(sizeof(T) == 8 || (val64 >> (sizeof(T) * 8)) == 0); + static_assert(sizeof(T) <= 8, "Type size too big"); + T val = 0; + std::memcpy(&val, &val64, sizeof(T)); + return val; +} + +// Writes bits from |val| to |writer| in chunks of size |chunk_length|. +// Signal bit is used to signal if the reader should expect another chunk: +// 0 - no more chunks to follow +// 1 - more chunks to follow +// If number of written bits reaches |max_payload| last chunk is truncated. 
+void WriteVariableWidthInternal(BitWriterInterface* writer, uint64_t val, + size_t chunk_length, size_t max_payload) { + assert(chunk_length > 0); + assert(chunk_length < max_payload); + assert(max_payload == 64 || (val >> max_payload) == 0); + + if (val == 0) { + // Split in two writes for more readable logging. + writer->WriteBits(0, chunk_length); + writer->WriteBits(0, 1); + return; + } + + size_t payload_written = 0; + + while (val) { + if (payload_written + chunk_length >= max_payload) { + // This has to be the last chunk. + // There is no need for the signal bit and the chunk can be truncated. + const size_t left_to_write = max_payload - payload_written; + assert((val >> left_to_write) == 0); + writer->WriteBits(val, left_to_write); + break; + } + + writer->WriteBits(val, chunk_length); + payload_written += chunk_length; + val = val >> chunk_length; + + // Write a single bit to signal if there is more to come. + writer->WriteBits(val ? 1 : 0, 1); + } +} + +// Reads data written with WriteVariableWidthInternal. |chunk_length| and +// |max_payload| should be identical to those used to write the data. +// Returns false if the stream ends prematurely. +bool ReadVariableWidthInternal(BitReaderInterface* reader, uint64_t* val, + size_t chunk_length, size_t max_payload) { + assert(chunk_length > 0); + assert(chunk_length <= max_payload); + size_t payload_read = 0; + + while (payload_read + chunk_length < max_payload) { + uint64_t bits = 0; + if (reader->ReadBits(&bits, chunk_length) != chunk_length) return false; + + *val |= bits << payload_read; + payload_read += chunk_length; + + uint64_t more_to_come = 0; + if (reader->ReadBits(&more_to_come, 1) != 1) return false; + + if (!more_to_come) { + return true; + } + } + + // Need to read the last chunk which may be truncated. No signal bit follows. 
+ uint64_t bits = 0; + const size_t left_to_read = max_payload - payload_read; + if (reader->ReadBits(&bits, left_to_read) != left_to_read) return false; + + *val |= bits << payload_read; + return true; +} + +// Calls WriteVariableWidthInternal with the right max_payload argument. +template +void WriteVariableWidthUnsigned(BitWriterInterface* writer, T val, + size_t chunk_length) { + static_assert(std::is_unsigned::value, "Type must be unsigned"); + static_assert(std::is_integral::value, "Type must be integral"); + WriteVariableWidthInternal(writer, val, chunk_length, sizeof(T) * 8); +} + +// Calls ReadVariableWidthInternal with the right max_payload argument. +template +bool ReadVariableWidthUnsigned(BitReaderInterface* reader, T* val, + size_t chunk_length) { + static_assert(std::is_unsigned::value, "Type must be unsigned"); + static_assert(std::is_integral::value, "Type must be integral"); + uint64_t val64 = 0; + if (!ReadVariableWidthInternal(reader, &val64, chunk_length, sizeof(T) * 8)) + return false; + *val = static_cast(val64); + assert(*val == val64); + return true; +} + +// Encodes signed |val| to an unsigned value and calls +// WriteVariableWidthInternal with the right max_payload argument. +template +void WriteVariableWidthSigned(BitWriterInterface* writer, T val, + size_t chunk_length, size_t zigzag_exponent) { + static_assert(std::is_signed::value, "Type must be signed"); + static_assert(std::is_integral::value, "Type must be integral"); + WriteVariableWidthInternal(writer, EncodeZigZag(val, zigzag_exponent), + chunk_length, sizeof(T) * 8); +} + +// Calls ReadVariableWidthInternal with the right max_payload argument +// and decodes the value. 
+template +bool ReadVariableWidthSigned(BitReaderInterface* reader, T* val, + size_t chunk_length, size_t zigzag_exponent) { + static_assert(std::is_signed::value, "Type must be signed"); + static_assert(std::is_integral::value, "Type must be integral"); + uint64_t encoded = 0; + if (!ReadVariableWidthInternal(reader, &encoded, chunk_length, sizeof(T) * 8)) + return false; + + const int64_t decoded = DecodeZigZag(encoded, zigzag_exponent); + + *val = static_cast(decoded); + assert(*val == decoded); + return true; +} + +} // namespace + +void BitWriterInterface::WriteVariableWidthU64(uint64_t val, + size_t chunk_length) { + WriteVariableWidthUnsigned(this, val, chunk_length); +} + +void BitWriterInterface::WriteVariableWidthU32(uint32_t val, + size_t chunk_length) { + WriteVariableWidthUnsigned(this, val, chunk_length); +} + +void BitWriterInterface::WriteVariableWidthU16(uint16_t val, + size_t chunk_length) { + WriteVariableWidthUnsigned(this, val, chunk_length); +} + +void BitWriterInterface::WriteVariableWidthS64(int64_t val, size_t chunk_length, + size_t zigzag_exponent) { + WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent); +} + +BitWriterWord64::BitWriterWord64(size_t reserve_bits) : end_(0) { + buffer_.reserve(NumBitsToNumWords<64>(reserve_bits)); +} + +void BitWriterWord64::WriteBits(uint64_t bits, size_t num_bits) { + // Check that |bits| and |num_bits| are valid and consistent. + assert(num_bits <= 64); + const bool is_little_endian = IsLittleEndian(); + assert(is_little_endian && "Big-endian architecture support not implemented"); + if (!is_little_endian) return; + + if (num_bits == 0) return; + + bits = GetLowerBits(bits, num_bits); + + EmitSequence(bits, num_bits); + + // Offset from the start of the current word. + const size_t offset = end_ % 64; + + if (offset == 0) { + // If no offset, simply add |bits| as a new word to the buffer_. + buffer_.push_back(bits); + } else { + // Shift bits and add them to the current word after offset. 
+ const uint64_t first_word = bits << offset; + buffer_.back() |= first_word; + + // If we don't overflow to the next word, there is nothing more to do. + + if (offset + num_bits > 64) { + // We overflow to the next word. + const uint64_t second_word = bits >> (64 - offset); + // Add remaining bits as a new word to buffer_. + buffer_.push_back(second_word); + } + } + + // Move end_ into position for next write. + end_ += num_bits; + assert(buffer_.size() * 64 >= end_); +} + +bool BitReaderInterface::ReadVariableWidthU64(uint64_t* val, + size_t chunk_length) { + return ReadVariableWidthUnsigned(this, val, chunk_length); +} + +bool BitReaderInterface::ReadVariableWidthU32(uint32_t* val, + size_t chunk_length) { + return ReadVariableWidthUnsigned(this, val, chunk_length); +} + +bool BitReaderInterface::ReadVariableWidthU16(uint16_t* val, + size_t chunk_length) { + return ReadVariableWidthUnsigned(this, val, chunk_length); +} + +bool BitReaderInterface::ReadVariableWidthS64(int64_t* val, size_t chunk_length, + size_t zigzag_exponent) { + return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent); +} + +BitReaderWord64::BitReaderWord64(std::vector&& buffer) + : buffer_(std::move(buffer)), pos_(0) {} + +BitReaderWord64::BitReaderWord64(const std::vector& buffer) + : buffer_(ToBuffer64(buffer)), pos_(0) {} + +BitReaderWord64::BitReaderWord64(const void* buffer, size_t num_bytes) + : buffer_(ToBuffer64(buffer, num_bytes)), pos_(0) {} + +size_t BitReaderWord64::ReadBits(uint64_t* bits, size_t num_bits) { + assert(num_bits <= 64); + const bool is_little_endian = IsLittleEndian(); + assert(is_little_endian && "Big-endian architecture support not implemented"); + if (!is_little_endian) return 0; + + if (ReachedEnd()) return 0; + + // Index of the current word. + const size_t index = pos_ / 64; + + // Bit position in the current word where we start reading. 
+ const size_t offset = pos_ % 64; + + // Read all bits from the current word (it might be too much, but + // excessive bits will be removed later). + *bits = buffer_[index] >> offset; + + const size_t num_read_from_first_word = std::min(64 - offset, num_bits); + pos_ += num_read_from_first_word; + + if (pos_ >= buffer_.size() * 64) { + // Reached end of buffer_. + EmitSequence(*bits, num_read_from_first_word); + return num_read_from_first_word; + } + + if (offset + num_bits > 64) { + // Requested |num_bits| overflows to next word. + // Write all bits from the beginning of next word to *bits after offset. + *bits |= buffer_[index + 1] << (64 - offset); + pos_ += offset + num_bits - 64; + } + + // We likely have written more bits than requested. Clear excessive bits. + *bits = GetLowerBits(*bits, num_bits); + EmitSequence(*bits, num_bits); + return num_bits; +} + +bool BitReaderWord64::ReachedEnd() const { return pos_ >= buffer_.size() * 64; } + +bool BitReaderWord64::OnlyZeroesLeft() const { + if (ReachedEnd()) return true; + + const size_t index = pos_ / 64; + if (index < buffer_.size() - 1) return false; + + assert(index == buffer_.size() - 1); + + const size_t offset = pos_ % 64; + const uint64_t remaining_bits = buffer_[index] >> offset; + return !remaining_bits; +} + +} // namespace comp +} // namespace spvtools diff --git a/source/comp/bit_stream.h b/source/comp/bit_stream.h new file mode 100644 index 000000000..5f82344d6 --- /dev/null +++ b/source/comp/bit_stream.h @@ -0,0 +1,280 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Contains utils for reading, writing and debug printing bit streams. + +#ifndef SOURCE_COMP_BIT_STREAM_H_ +#define SOURCE_COMP_BIT_STREAM_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spvtools { +namespace comp { + +// Terminology: +// Bits - usually used for a uint64 word, first bit is the lowest. +// Stream - std::string of '0' and '1', read left-to-right, +// i.e. first bit is at the front and not at the end as in +// std::bitset::to_string(). +// Bitset - std::bitset corresponding to uint64 bits and to reverse(stream). + +// Converts number of bits to a respective number of chunks of size N. +// For example NumBitsToNumWords<8> returns how many bytes are needed to store +// |num_bits|. +template +inline size_t NumBitsToNumWords(size_t num_bits) { + return (num_bits + (N - 1)) / N; +} + +// Returns value of the same type as |in|, where all but the first |num_bits| +// are set to zero. +template +inline T GetLowerBits(T in, size_t num_bits) { + return sizeof(T) * 8 == num_bits ? in : in & T((T(1) << num_bits) - T(1)); +} + +// Encodes signed integer as unsigned. This is a generalized version of +// EncodeZigZag, designed to favor small positive numbers. +// Values are transformed in blocks of 2^|block_exponent|. +// If |block_exponent| is zero, then this degenerates into normal EncodeZigZag. 
+// Example when |block_exponent| is 1 (return value is the index): +// 0, 1, -1, -2, 2, 3, -3, -4, 4, 5, -5, -6, 6, 7, -7, -8 +// Example when |block_exponent| is 2: +// 0, 1, 2, 3, -1, -2, -3, -4, 4, 5, 6, 7, -5, -6, -7, -8 +inline uint64_t EncodeZigZag(int64_t val, size_t block_exponent) { + assert(block_exponent < 64); + const uint64_t uval = static_cast(val >= 0 ? val : -val - 1); + const uint64_t block_num = + ((uval >> block_exponent) << 1) + (val >= 0 ? 0 : 1); + const uint64_t pos = GetLowerBits(uval, block_exponent); + return (block_num << block_exponent) + pos; +} + +// Decodes signed integer encoded with EncodeZigZag. |block_exponent| must be +// the same. +inline int64_t DecodeZigZag(uint64_t val, size_t block_exponent) { + assert(block_exponent < 64); + const uint64_t block_num = val >> block_exponent; + const uint64_t pos = GetLowerBits(val, block_exponent); + if (block_num & 1) { + // Negative. + return -1LL - ((block_num >> 1) << block_exponent) - pos; + } else { + // Positive. + return ((block_num >> 1) << block_exponent) + pos; + } +} + +// Converts first |num_bits| stored in uint64 to a left-to-right stream of bits. +inline std::string BitsToStream(uint64_t bits, size_t num_bits = 64) { + std::bitset<64> bitset(bits); + std::string str = bitset.to_string().substr(64 - num_bits); + std::reverse(str.begin(), str.end()); + return str; +} + +// Base class for writing sequences of bits. +class BitWriterInterface { + public: + BitWriterInterface() = default; + virtual ~BitWriterInterface() = default; + + // Writes lower |num_bits| in |bits| to the stream. + // |num_bits| must be no greater than 64. + virtual void WriteBits(uint64_t bits, size_t num_bits) = 0; + + // Writes bits from value of type |T| to the stream. No encoding is done. + // Always writes 8 * sizeof(T) bits. 
+ template + void WriteUnencoded(T val) { + static_assert(sizeof(T) <= 64, "Type size too large"); + uint64_t bits = 0; + memcpy(&bits, &val, sizeof(T)); + WriteBits(bits, sizeof(T) * 8); + } + + // Writes |val| in chunks of size |chunk_length| followed by a signal bit: + // 0 - no more chunks to follow + // 1 - more chunks to follow + // for example 255 is encoded into 1111 1 1111 0 for chunk length 4. + // The last chunk can be truncated and signal bit omitted, if the entire + // payload (for example 16 bit for uint16_t has already been written). + void WriteVariableWidthU64(uint64_t val, size_t chunk_length); + void WriteVariableWidthU32(uint32_t val, size_t chunk_length); + void WriteVariableWidthU16(uint16_t val, size_t chunk_length); + void WriteVariableWidthS64(int64_t val, size_t chunk_length, + size_t zigzag_exponent); + + // Returns number of bits written. + virtual size_t GetNumBits() const = 0; + + // Provides direct access to the buffer data if implemented. + virtual const uint8_t* GetData() const { return nullptr; } + + // Returns buffer size in bytes. + size_t GetDataSizeBytes() const { return NumBitsToNumWords<8>(GetNumBits()); } + + // Generates and returns byte array containing written bits. + virtual std::vector GetDataCopy() const = 0; + + BitWriterInterface(const BitWriterInterface&) = delete; + BitWriterInterface& operator=(const BitWriterInterface&) = delete; +}; + +// This class is an implementation of BitWriterInterface, using +// std::vector to store written bits. 
+class BitWriterWord64 : public BitWriterInterface { + public: + explicit BitWriterWord64(size_t reserve_bits = 64); + + void WriteBits(uint64_t bits, size_t num_bits) override; + + size_t GetNumBits() const override { return end_; } + + const uint8_t* GetData() const override { + return reinterpret_cast(buffer_.data()); + } + + std::vector GetDataCopy() const override { + return std::vector(GetData(), GetData() + GetDataSizeBytes()); + } + + // Sets callback to emit bit sequences after every write. + void SetCallback(std::function callback) { + callback_ = callback; + } + + protected: + // Sends string generated from arguments to callback_ if defined. + void EmitSequence(uint64_t bits, size_t num_bits) const { + if (callback_) callback_(BitsToStream(bits, num_bits)); + } + + private: + std::vector buffer_; + // Total number of bits written so far. Named 'end' as analogy to std::end(). + size_t end_; + + // If not null, the writer will use the callback to emit the written bit + // sequence as a string of '0' and '1'. + std::function callback_; +}; + +// Base class for reading sequences of bits. +class BitReaderInterface { + public: + BitReaderInterface() {} + virtual ~BitReaderInterface() {} + + // Reads |num_bits| from the stream, stores them in |bits|. + // Returns number of read bits. |num_bits| must be no greater than 64. + virtual size_t ReadBits(uint64_t* bits, size_t num_bits) = 0; + + // Reads 8 * sizeof(T) bits and stores them in |val|. + template + bool ReadUnencoded(T* val) { + static_assert(sizeof(T) <= 64, "Type size too large"); + uint64_t bits = 0; + const size_t num_read = ReadBits(&bits, sizeof(T) * 8); + if (num_read != sizeof(T) * 8) return false; + memcpy(val, &bits, sizeof(T)); + return true; + } + + // Returns number of bits already read. + virtual size_t GetNumReadBits() const = 0; + + // These two functions define 'hard' and 'soft' EOF. + // + // Returns true if the end of the buffer was reached. 
+ virtual bool ReachedEnd() const = 0; + // Returns true if we reached the end of the buffer or are nearing it and only + // zero bits are left to read. Implementations of this function are allowed to + // commit a "false negative" error if the end of the buffer was not reached, + // i.e. it can return false even if indeed only zeroes are left. + // It is assumed that the consumer expects that + // the buffer stream ends with padding zeroes, and would accept this as a + // 'soft' EOF. Implementations of this class do not necessarily need to + // implement this, default behavior can simply delegate to ReachedEnd(). + virtual bool OnlyZeroesLeft() const { return ReachedEnd(); } + + // Reads value encoded with WriteVariableWidthXXX (see BitWriterInterface). + // Reader and writer must use the same |chunk_length| and variable type. + // Returns true on success, false if the bit stream ends prematurely. + bool ReadVariableWidthU64(uint64_t* val, size_t chunk_length); + bool ReadVariableWidthU32(uint32_t* val, size_t chunk_length); + bool ReadVariableWidthU16(uint16_t* val, size_t chunk_length); + bool ReadVariableWidthS64(int64_t* val, size_t chunk_length, + size_t zigzag_exponent); + + BitReaderInterface(const BitReaderInterface&) = delete; + BitReaderInterface& operator=(const BitReaderInterface&) = delete; +}; + +// This class is an implementation of BitReaderInterface which accepts both +// uint8_t and uint64_t buffers as input. uint64_t buffers are consumed and +// owned. uint8_t buffers are copied. +class BitReaderWord64 : public BitReaderInterface { + public: + // Consumes and owns the buffer. + explicit BitReaderWord64(std::vector&& buffer); + + // Copies the buffer and casts it to uint64. + // Consuming the original buffer and casting it to uint64 is difficult, + // as it would potentially cause data misalignment and poor performance. 
+ explicit BitReaderWord64(const std::vector& buffer); + BitReaderWord64(const void* buffer, size_t num_bytes); + + size_t ReadBits(uint64_t* bits, size_t num_bits) override; + + size_t GetNumReadBits() const override { return pos_; } + + bool ReachedEnd() const override; + bool OnlyZeroesLeft() const override; + + BitReaderWord64() = delete; + + // Sets callback to emit bit sequences after every read. + void SetCallback(std::function callback) { + callback_ = callback; + } + + protected: + // Sends string generated from arguments to callback_ if defined. + void EmitSequence(uint64_t bits, size_t num_bits) const { + if (callback_) callback_(BitsToStream(bits, num_bits)); + } + + private: + const std::vector buffer_; + size_t pos_; + + // If not null, the reader will use the callback to emit the read bit + // sequence as a string of '0' and '1'. + std::function callback_; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_BIT_STREAM_H_ diff --git a/source/comp/huffman_codec.h b/source/comp/huffman_codec.h new file mode 100644 index 000000000..166021614 --- /dev/null +++ b/source/comp/huffman_codec.h @@ -0,0 +1,389 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Contains utils for reading, writing and debug printing bit streams. 
+ +#ifndef SOURCE_COMP_HUFFMAN_CODEC_H_ +#define SOURCE_COMP_HUFFMAN_CODEC_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spvtools { +namespace comp { + +// Used to generate and apply a Huffman coding scheme. +// |Val| is the type of variable being encoded (for example a string or a +// literal). +template +class HuffmanCodec { + public: + // Huffman tree node. + struct Node { + Node() {} + + // Creates Node from serialization leaving weight and id undefined. + Node(const Val& in_value, uint32_t in_left, uint32_t in_right) + : value(in_value), left(in_left), right(in_right) {} + + Val value = Val(); + uint32_t weight = 0; + // Ids are issued sequentially starting from 1. Ids are used as an ordering + // tie-breaker, to make sure that the ordering (and resulting coding scheme) + // are consistent accross multiple platforms. + uint32_t id = 0; + // Handles of children. + uint32_t left = 0; + uint32_t right = 0; + }; + + // Creates Huffman codec from a histogramm. + // Histogramm counts must not be zero. + explicit HuffmanCodec(const std::map& hist) { + if (hist.empty()) return; + + // Heuristic estimate. + nodes_.reserve(3 * hist.size()); + + // Create NIL. + CreateNode(); + + // The queue is sorted in ascending order by weight (or by node id if + // weights are equal). + std::vector queue_vector; + queue_vector.reserve(hist.size()); + std::priority_queue, + std::function> + queue(std::bind(&HuffmanCodec::LeftIsBigger, this, + std::placeholders::_1, std::placeholders::_2), + std::move(queue_vector)); + + // Put all leaves in the queue. 
+ for (const auto& pair : hist) { + const uint32_t node = CreateNode(); + MutableValueOf(node) = pair.first; + MutableWeightOf(node) = pair.second; + assert(WeightOf(node)); + queue.push(node); + } + + // Form the tree by combining two subtrees with the least weight, + // and pushing the root of the new tree in the queue. + while (true) { + // We push a node at the end of each iteration, so the queue is never + // supposed to be empty at this point, unless there are no leaves, but + // that case was already handled. + assert(!queue.empty()); + const uint32_t right = queue.top(); + queue.pop(); + + // If the queue is empty at this point, then the last node is + // the root of the complete Huffman tree. + if (queue.empty()) { + root_ = right; + break; + } + + const uint32_t left = queue.top(); + queue.pop(); + + // Combine left and right into a new tree and push it into the queue. + const uint32_t parent = CreateNode(); + MutableWeightOf(parent) = WeightOf(right) + WeightOf(left); + MutableLeftOf(parent) = left; + MutableRightOf(parent) = right; + queue.push(parent); + } + + // Traverse the tree and form encoding table. + CreateEncodingTable(); + } + + // Creates Huffman codec from saved tree structure. + // |nodes| is the list of nodes of the tree, nodes[0] being NIL. + // |root_handle| is the index of the root node. + HuffmanCodec(uint32_t root_handle, std::vector&& nodes) { + nodes_ = std::move(nodes); + assert(!nodes_.empty()); + assert(root_handle > 0 && root_handle < nodes_.size()); + assert(!LeftOf(0) && !RightOf(0)); + + root_ = root_handle; + + // Traverse the tree and form encoding table. + CreateEncodingTable(); + } + + // Serializes the codec in the following text format: + // (, { + // {0, 0, 0}, + // {val1, left1, right1}, + // {val2, left2, right2}, + // ... 
+ // }) + std::string SerializeToText(int indent_num_whitespaces) const { + const bool value_is_text = std::is_same::value; + + const std::string indent1 = std::string(indent_num_whitespaces, ' '); + const std::string indent2 = std::string(indent_num_whitespaces + 2, ' '); + + std::stringstream code; + code << "(" << root_ << ", {\n"; + + for (const Node& node : nodes_) { + code << indent2 << "{"; + if (value_is_text) code << "\""; + code << node.value; + if (value_is_text) code << "\""; + code << ", " << node.left << ", " << node.right << "},\n"; + } + + code << indent1 << "})"; + + return code.str(); + } + + // Prints the Huffman tree in the following format: + // w------w------'x' + // w------'y' + // Where w stands for the weight of the node. + // Right tree branches appear above left branches. Taking the right path + // adds 1 to the code, taking the left adds 0. + void PrintTree(std::ostream& out) const { PrintTreeInternal(out, root_, 0); } + + // Traverses the tree and prints the Huffman table: value, code + // and optionally node weight for every leaf. + void PrintTable(std::ostream& out, bool print_weights = true) { + std::queue> queue; + queue.emplace(root_, ""); + + while (!queue.empty()) { + const uint32_t node = queue.front().first; + const std::string code = queue.front().second; + queue.pop(); + if (!RightOf(node) && !LeftOf(node)) { + out << ValueOf(node); + if (print_weights) out << " " << WeightOf(node); + out << " " << code << std::endl; + } else { + if (LeftOf(node)) queue.emplace(LeftOf(node), code + "0"); + + if (RightOf(node)) queue.emplace(RightOf(node), code + "1"); + } + } + } + + // Returns the Huffman table. The table was built at at construction time, + // this function just returns a const reference. + const std::unordered_map>& GetEncodingTable() + const { + return encoding_table_; + } + + // Encodes |val| and stores its Huffman code in the lower |num_bits| of + // |bits|. Returns false of |val| is not in the Huffman table. 
+ bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const { + auto it = encoding_table_.find(val); + if (it == encoding_table_.end()) return false; + *bits = it->second.first; + *num_bits = it->second.second; + return true; + } + + // Reads bits one-by-one using callback |read_bit| until a match is found. + // Matching value is stored in |val|. Returns false if |read_bit| terminates + // before a code was mathced. + // |read_bit| has type bool func(bool* bit). When called, the next bit is + // stored in |bit|. |read_bit| returns false if the stream terminates + // prematurely. + bool DecodeFromStream(const std::function& read_bit, + Val* val) const { + uint32_t node = root_; + while (true) { + assert(node); + + if (!RightOf(node) && !LeftOf(node)) { + *val = ValueOf(node); + return true; + } + + bool go_right; + if (!read_bit(&go_right)) return false; + + if (go_right) + node = RightOf(node); + else + node = LeftOf(node); + } + + assert(0); + return false; + } + + private: + // Returns value of the node referenced by |handle|. + Val ValueOf(uint32_t node) const { return nodes_.at(node).value; } + + // Returns left child of |node|. + uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; } + + // Returns right child of |node|. + uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; } + + // Returns weight of |node|. + uint32_t WeightOf(uint32_t node) const { return nodes_.at(node).weight; } + + // Returns id of |node|. + uint32_t IdOf(uint32_t node) const { return nodes_.at(node).id; } + + // Returns mutable reference to value of |node|. + Val& MutableValueOf(uint32_t node) { + assert(node); + return nodes_.at(node).value; + } + + // Returns mutable reference to handle of left child of |node|. + uint32_t& MutableLeftOf(uint32_t node) { + assert(node); + return nodes_.at(node).left; + } + + // Returns mutable reference to handle of right child of |node|. 
+ uint32_t& MutableRightOf(uint32_t node) { + assert(node); + return nodes_.at(node).right; + } + + // Returns mutable reference to weight of |node|. + uint32_t& MutableWeightOf(uint32_t node) { return nodes_.at(node).weight; } + + // Returns mutable reference to id of |node|. + uint32_t& MutableIdOf(uint32_t node) { return nodes_.at(node).id; } + + // Returns true if |left| has bigger weight than |right|. Node ids are + // used as tie-breaker. + bool LeftIsBigger(uint32_t left, uint32_t right) const { + if (WeightOf(left) == WeightOf(right)) { + assert(IdOf(left) != IdOf(right)); + return IdOf(left) > IdOf(right); + } + return WeightOf(left) > WeightOf(right); + } + + // Prints subtree (helper function used by PrintTree). + void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const { + if (!node) return; + + const size_t kTextFieldWidth = 7; + + if (!RightOf(node) && !LeftOf(node)) { + out << ValueOf(node) << std::endl; + } else { + if (RightOf(node)) { + std::stringstream label; + label << std::setfill('-') << std::left << std::setw(kTextFieldWidth) + << WeightOf(RightOf(node)); + out << label.str(); + PrintTreeInternal(out, RightOf(node), depth + 1); + } + + if (LeftOf(node)) { + out << std::string(depth * kTextFieldWidth, ' '); + std::stringstream label; + label << std::setfill('-') << std::left << std::setw(kTextFieldWidth) + << WeightOf(LeftOf(node)); + out << label.str(); + PrintTreeInternal(out, LeftOf(node), depth + 1); + } + } + } + + // Traverses the Huffman tree and saves paths to the leaves as bit + // sequences to encoding_table_. + void CreateEncodingTable() { + struct Context { + Context(uint32_t in_node, uint64_t in_bits, size_t in_depth) + : node(in_node), bits(in_bits), depth(in_depth) {} + uint32_t node; + // Huffman tree depth cannot exceed 64 as histogramm counts are expected + // to be positive and limited by numeric_limits::max(). + // For practical applications tree depth would be much smaller than 64. 
+ uint64_t bits; + size_t depth; + }; + + std::queue queue; + queue.emplace(root_, 0, 0); + + while (!queue.empty()) { + const Context& context = queue.front(); + const uint32_t node = context.node; + const uint64_t bits = context.bits; + const size_t depth = context.depth; + queue.pop(); + + if (!RightOf(node) && !LeftOf(node)) { + auto insertion_result = encoding_table_.emplace( + ValueOf(node), std::pair(bits, depth)); + assert(insertion_result.second); + (void)insertion_result; + } else { + if (LeftOf(node)) queue.emplace(LeftOf(node), bits, depth + 1); + + if (RightOf(node)) + queue.emplace(RightOf(node), bits | (1ULL << depth), depth + 1); + } + } + } + + // Creates new Huffman tree node and stores it in the deleter array. + uint32_t CreateNode() { + const uint32_t handle = static_cast(nodes_.size()); + nodes_.emplace_back(Node()); + nodes_.back().id = next_node_id_++; + return handle; + } + + // Huffman tree root handle. + uint32_t root_ = 0; + + // Huffman tree deleter. + std::vector nodes_; + + // Encoding table value -> {bits, num_bits}. + // Huffman codes are expected to never exceed 64 bit length (this is in fact + // impossible if frequencies are stored as uint32_t). + std::unordered_map> encoding_table_; + + // Next node id issued by CreateNode(); + uint32_t next_node_id_ = 1; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_HUFFMAN_CODEC_H_ diff --git a/source/comp/markv.cpp b/source/comp/markv.cpp new file mode 100644 index 000000000..736bc51ba --- /dev/null +++ b/source/comp/markv.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/markv.h" + +#include "source/comp/markv_decoder.h" +#include "source/comp/markv_encoder.h" + +namespace spvtools { +namespace comp { +namespace { + +spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian, + uint32_t magic, uint32_t version, uint32_t generator, + uint32_t id_bound, uint32_t schema) { + MarkvEncoder* encoder = reinterpret_cast(user_data); + return encoder->EncodeHeader(endian, magic, version, generator, id_bound, + schema); +} + +spv_result_t EncodeInstruction(void* user_data, + const spv_parsed_instruction_t* inst) { + MarkvEncoder* encoder = reinterpret_cast(user_data); + return encoder->EncodeInstruction(*inst); +} + +} // namespace + +spv_result_t SpirvToMarkv( + spv_const_context context, const std::vector& spirv, + const MarkvCodecOptions& options, const MarkvModel& markv_model, + MessageConsumer message_consumer, MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer, std::vector* markv) { + spv_context_t hijack_context = *context; + SetContextMessageConsumer(&hijack_context, message_consumer); + + spv_validator_options validator_options = + MarkvDecoder::GetValidatorOptions(options); + if (validator_options) { + spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()}; + const spv_result_t result = spvValidateWithOptions( + &hijack_context, validator_options, &spirv_binary, nullptr); + if (result != SPV_SUCCESS) return result; + } + + MarkvEncoder encoder(&hijack_context, options, &markv_model); + + spv_position_t position = {}; + if (log_consumer || debug_consumer) { + 
encoder.CreateLogger(log_consumer, debug_consumer); + + spv_text text = nullptr; + if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(), + SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, + nullptr) != SPV_SUCCESS) { + return DiagnosticStream(position, hijack_context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Failed to disassemble SPIR-V binary."; + } + assert(text); + encoder.SetDisassembly(std::string(text->str, text->length)); + spvTextDestroy(text); + } + + if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(), + EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) { + return DiagnosticStream(position, hijack_context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Unable to encode to MARK-V."; + } + + *markv = encoder.GetMarkvBinary(); + return SPV_SUCCESS; +} + +spv_result_t MarkvToSpirv( + spv_const_context context, const std::vector& markv, + const MarkvCodecOptions& options, const MarkvModel& markv_model, + MessageConsumer message_consumer, MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer, std::vector* spirv) { + spv_position_t position = {}; + spv_context_t hijack_context = *context; + SetContextMessageConsumer(&hijack_context, message_consumer); + + MarkvDecoder decoder(&hijack_context, markv, options, &markv_model); + + if (log_consumer || debug_consumer) + decoder.CreateLogger(log_consumer, debug_consumer); + + if (decoder.DecodeModule(spirv) != SPV_SUCCESS) { + return DiagnosticStream(position, hijack_context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Unable to decode MARK-V."; + } + + assert(!spirv->empty()); + return SPV_SUCCESS; +} + +} // namespace comp +} // namespace spvtools diff --git a/source/comp/markv.h b/source/comp/markv.h new file mode 100644 index 000000000..587086f91 --- /dev/null +++ b/source/comp/markv.h @@ -0,0 +1,74 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance 
with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// MARK-V is a compression format for SPIR-V binaries. It strips away +// non-essential information (such as result ids which can be regenerated) and +// uses various bit reduction techiniques to reduce the size of the binary and +// make it more similar to other compressed SPIR-V files to further improve +// compression of the dataset. + +#ifndef SOURCE_COMP_MARKV_H_ +#define SOURCE_COMP_MARKV_H_ + +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace comp { + +class MarkvModel; + +struct MarkvCodecOptions { + bool validate_spirv_binary = false; +}; + +// Debug callback. Called once per instruction. +// |words| is instruction SPIR-V words. +// |bits| is a textual representation of the MARK-V bit sequence used to encode +// the instruction (char '0' for 0, char '1' for 1). +// |comment| contains all logs generated while processing the instruction. +using MarkvDebugConsumer = + std::function& words, + const std::string& bits, const std::string& comment)>; + +// Logging callback. Called often (if decoder reads a single bit, the log +// consumer will receive 1 character string with that bit). +// This callback is more suitable for continous output than MarkvDebugConsumer, +// for example if the codec crashes it would allow to pinpoint on which operand +// or bit the crash happened. +// |snippet| could be any atomic fragment of text logged by the codec. It can +// contain a paragraph of text with newlines, or can be just one character. 
+using MarkvLogConsumer = std::function; + +// Encodes the given SPIR-V binary to MARK-V binary. +// |log_consumer| is optional (pass MarkvLogConsumer() to disable). +// |debug_consumer| is optional (pass MarkvDebugConsumer() to disable). +spv_result_t SpirvToMarkv( + spv_const_context context, const std::vector& spirv, + const MarkvCodecOptions& options, const MarkvModel& markv_model, + MessageConsumer message_consumer, MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer, std::vector* markv); + +// Decodes a SPIR-V binary from the given MARK-V binary. +// |log_consumer| is optional (pass MarkvLogConsumer() to disable). +// |debug_consumer| is optional (pass MarkvDebugConsumer() to disable). +spv_result_t MarkvToSpirv( + spv_const_context context, const std::vector& markv, + const MarkvCodecOptions& options, const MarkvModel& markv_model, + MessageConsumer message_consumer, MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer, std::vector* spirv); + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_H_ diff --git a/source/comp/markv_codec.cpp b/source/comp/markv_codec.cpp new file mode 100644 index 000000000..ae3ce79f2 --- /dev/null +++ b/source/comp/markv_codec.cpp @@ -0,0 +1,793 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// MARK-V is a compression format for SPIR-V binaries. 
It strips away +// non-essential information (such as result IDs which can be regenerated) and +// uses various bit reduction techniques to reduce the size of the binary. + +#include "source/comp/markv_codec.h" + +#include "source/comp/markv_logger.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/latest_version_opencl_std_header.h" +#include "source/opcode.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace comp { +namespace { + +// Custom hash function used to produce short descriptors. +uint32_t ShortHashU32Array(const std::vector& words) { + // The hash function is a sum of hashes of each word seeded by word index. + // Knuth's multiplicative hash is used to hash the words. + const uint32_t kKnuthMulHash = 2654435761; + uint32_t val = 0; + for (uint32_t i = 0; i < words.size(); ++i) { + val += (words[i] + i + 123) * kKnuthMulHash; + } + return 1 + val % ((1 << MarkvCodec::kShortDescriptorNumBits) - 1); +} + +// Returns a set of mtf rank codecs based on a plausible hand-coded +// distribution. 
+std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
+GetMtfHuffmanCodecs() {
+  // Keyed by mtf handle; looked up through MarkvCodec::GetMtfHuffmanCodec().
+  std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
+
+  std::unique_ptr<HuffmanCodec<uint32_t>> codec;
+
+  // Histogram for kMtfAll. Keys are mtf ranks 0..9 plus the
+  // kMtfRankEncodedByValueSignal escape used for larger ranks; values are
+  // presumably relative frequencies (the hand-coded distribution).
+  codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
+      {0, 5},
+      {1, 40},
+      {2, 10},
+      {3, 5},
+      {4, 5},
+      {5, 5},
+      {6, 3},
+      {7, 3},
+      {8, 3},
+      {9, 3},
+      {MarkvCodec::kMtfRankEncodedByValueSignal, 10},
+  }));
+  codecs.emplace(kMtfAll, std::move(codec));
+
+  // Histogram for kMtfGenericNonZeroRank: same layout, but rank 0 has no
+  // entry.
+  codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
+      {1, 50},
+      {2, 20},
+      {3, 5},
+      {4, 5},
+      {5, 2},
+      {6, 1},
+      {7, 1},
+      {8, 1},
+      {9, 1},
+      {MarkvCodec::kMtfRankEncodedByValueSignal, 10},
+  }));
+  codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
+
+  return codecs;
+}
+
+}  // namespace
+
+const uint32_t MarkvCodec::kMarkvMagicNumber = 0x07230303;
+
+const uint32_t MarkvCodec::kMtfSmallestRankEncodedByValue = 10;
+
+const uint32_t MarkvCodec::kMtfRankEncodedByValueSignal =
+    std::numeric_limits<uint32_t>::max();
+
+const uint32_t MarkvCodec::kShortDescriptorNumBits = 8;
+
+const size_t MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte = 8;
+
+// |validator_options| is destroyed by the destructor below, so the codec
+// effectively takes ownership of it. |model| and |context| are only stored.
+MarkvCodec::MarkvCodec(spv_const_context context,
+                       spv_validator_options validator_options,
+                       const MarkvModel* model)
+    : validator_options_(validator_options),
+      grammar_(context),
+      model_(model),
+      short_id_descriptors_(ShortHashU32Array),
+      mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
+      context_(context) {}
+
+MarkvCodec::~MarkvCodec() { spvValidatorOptionsDestroy(validator_options_); }
+
+MarkvCodec::MarkvHeader::MarkvHeader()
+    : magic_number(MarkvCodec::kMarkvMagicNumber),
+      markv_version(MarkvCodec::GetMarkvVersion()) {}
+
+// Defines and returns current MARK-V version.
+// The major version occupies the bits above bit 15, the minor version the
+// low bits.
+// static
+uint32_t MarkvCodec::GetMarkvVersion() {
+  const uint32_t kVersionMajor = 1;
+  const uint32_t kVersionMinor = 4;
+  return kVersionMinor | (kVersionMajor << 16);
+}
+
+// Returns the number of bits between |bit_pos| and the next byte boundary;
+// 0 when |bit_pos| is already a multiple of 8.
+size_t MarkvCodec::GetNumBitsToNextByte(size_t bit_pos) const {
+  return (8 - (bit_pos % 8)) % 8;
+}
+
+// Returns true if the opcode has a fixed number of operands. May return a
+// false negative.
+bool MarkvCodec::OpcodeHasFixedNumberOfOperands(SpvOp opcode) const { + switch (opcode) { + // TODO(atgoo@github.com) This is not a complete list. + case SpvOpNop: + case SpvOpName: + case SpvOpUndef: + case SpvOpSizeOf: + case SpvOpLine: + case SpvOpNoLine: + case SpvOpDecorationGroup: + case SpvOpExtension: + case SpvOpExtInstImport: + case SpvOpMemoryModel: + case SpvOpCapability: + case SpvOpTypeVoid: + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypeVector: + case SpvOpTypeMatrix: + case SpvOpTypeSampler: + case SpvOpTypeSampledImage: + case SpvOpTypeArray: + case SpvOpTypePointer: + case SpvOpConstantTrue: + case SpvOpConstantFalse: + case SpvOpLabel: + case SpvOpBranch: + case SpvOpFunction: + case SpvOpFunctionParameter: + case SpvOpFunctionEnd: + case SpvOpBitcast: + case SpvOpCopyObject: + case SpvOpTranspose: + case SpvOpSNegate: + case SpvOpFNegate: + case SpvOpIAdd: + case SpvOpFAdd: + case SpvOpISub: + case SpvOpFSub: + case SpvOpIMul: + case SpvOpFMul: + case SpvOpUDiv: + case SpvOpSDiv: + case SpvOpFDiv: + case SpvOpUMod: + case SpvOpSRem: + case SpvOpSMod: + case SpvOpFRem: + case SpvOpFMod: + case SpvOpVectorTimesScalar: + case SpvOpMatrixTimesScalar: + case SpvOpVectorTimesMatrix: + case SpvOpMatrixTimesVector: + case SpvOpMatrixTimesMatrix: + case SpvOpOuterProduct: + case SpvOpDot: + return true; + default: + break; + } + return false; +} + +void MarkvCodec::ProcessCurInstruction() { + instructions_.emplace_back(new val::Instruction(&inst_)); + + const SpvOp opcode = SpvOp(inst_.opcode); + + if (inst_.result_id) { + id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get()); + + // Collect ids local to the current function. + if (cur_function_id_) { + ids_local_to_cur_function_.push_back(inst_.result_id); + } + + // Starting new function. 
+ if (opcode == SpvOpFunction) { + cur_function_id_ = inst_.result_id; + cur_function_return_type_ = inst_.type_id; + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id), + inst_.result_id); + } + + // Store function parameter types in a queue, so that we know which types + // to expect in the following OpFunctionParameter instructions. + const val::Instruction* def_inst = FindDef(inst_.words[4]); + assert(def_inst); + assert(def_inst->opcode() == SpvOpTypeFunction); + for (uint32_t i = 3; i < def_inst->words().size(); ++i) { + remaining_function_parameter_types_.push_back(def_inst->word(i)); + } + } + } + + // Remove local ids from MTFs if function end. + if (opcode == SpvOpFunctionEnd) { + cur_function_id_ = 0; + for (uint32_t id : ids_local_to_cur_function_) multi_mtf_.RemoveFromAll(id); + ids_local_to_cur_function_.clear(); + assert(remaining_function_parameter_types_.empty()); + } + + if (!inst_.result_id) return; + + { + // Save the result ID to type ID mapping. + // In the grammar, type ID always appears before result ID. + // A regular value maps to its type. Some instructions (e.g. OpLabel) + // have no type Id, and will map to 0. The result Id for a + // type-generating instruction (e.g. OpTypeInt) maps to itself. + auto insertion_result = id_to_type_id_.emplace( + inst_.result_id, spvOpcodeGeneratesType(SpvOp(inst_.opcode)) + ? inst_.result_id + : inst_.type_id); + (void)insertion_result; + assert(insertion_result.second); + } + + // Add result_id to MTFs. 
+ if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + switch (opcode) { + case SpvOpTypeFloat: + case SpvOpTypeInt: + case SpvOpTypeBool: + case SpvOpTypeVector: + case SpvOpTypePointer: + case SpvOpExtInstImport: + case SpvOpTypeSampledImage: + case SpvOpTypeImage: + case SpvOpTypeSampler: + multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id); + break; + default: + break; + } + + if (spvOpcodeIsComposite(opcode)) { + multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id); + } + + if (opcode == SpvOpLabel) { + multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id); + } + + if (opcode == SpvOpTypeInt) { + multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id); + multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id); + } + + if (opcode == SpvOpTypeFloat) { + multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id); + multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id); + } + + if (opcode == SpvOpTypeBool) { + multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id); + multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id); + } + + if (opcode == SpvOpTypeVector) { + const uint32_t component_type_id = inst_.words[2]; + const uint32_t size = inst_.words[3]; + if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat), + component_type_id)) { + multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id); + } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt), + component_type_id)) { + multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id); + } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool), + component_type_id)) { + multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id); + } + multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id); + } + + if (inst_.opcode == SpvOpTypeFunction) { + const uint32_t return_type = inst_.words[2]; + multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type); + 
multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type), + inst_.result_id); + } + + if (inst_.type_id) { + const val::Instruction* type_inst = FindDef(inst_.type_id); + assert(type_inst); + + multi_mtf_.Insert(kMtfObject, inst_.result_id); + + multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id); + + if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) { + multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id); + } + + if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id)) + multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id); + + if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id)) + multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id); + + if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id)) + multi_mtf_.Insert(kMtfComposite, inst_.result_id); + + switch (type_inst->opcode()) { + case SpvOpTypeInt: + case SpvOpTypeBool: + case SpvOpTypePointer: + case SpvOpTypeVector: + case SpvOpTypeImage: + case SpvOpTypeSampledImage: + case SpvOpTypeSampler: + multi_mtf_.Insert( + GetMtfIdWithTypeGeneratedByOpcode(type_inst->opcode()), + inst_.result_id); + break; + default: + break; + } + + if (type_inst->opcode() == SpvOpTypeVector) { + const uint32_t component_type = type_inst->word(2); + multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type), + inst_.result_id); + } + + if (type_inst->opcode() == SpvOpTypePointer) { + assert(type_inst->operands().size() > 2); + assert(type_inst->words().size() > type_inst->operands()[2].offset); + const uint32_t data_type = + type_inst->word(type_inst->operands()[2].offset); + multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id); + + if (multi_mtf_.HasValue(kMtfTypeComposite, data_type)) + multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id); + } + } + + if (spvOpcodeGeneratesType(opcode)) { + if (opcode != SpvOpTypeFunction) { + multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id); + } + } + } + + if 
(model_->AnyDescriptorHasCodingScheme()) { + const uint32_t long_descriptor = + long_id_descriptors_.ProcessInstruction(inst_); + if (model_->DescriptorHasCodingScheme(long_descriptor)) + multi_mtf_.Insert(GetMtfLongIdDescriptor(long_descriptor), + inst_.result_id); + } + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kShortDescriptor) { + const uint32_t short_descriptor = + short_id_descriptors_.ProcessInstruction(inst_); + multi_mtf_.Insert(GetMtfShortIdDescriptor(short_descriptor), + inst_.result_id); + } +} + +uint64_t MarkvCodec::GetRuleBasedMtf() { + // This function is only called for id operands (but not result ids). + assert(spvIsIdType(operand_.type) || + operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID); + assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID); + + const SpvOp opcode = static_cast(inst_.opcode); + + // All operand slots which expect label id. + if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) || + (inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) || + (inst_.opcode == SpvOpBranch && operand_index_ == 0) || + (inst_.opcode == SpvOpBranchConditional && + (operand_index_ == 1 || operand_index_ == 2)) || + (inst_.opcode == SpvOpPhi && operand_index_ >= 3 && + operand_index_ % 2 == 1) || + (inst_.opcode == SpvOpSwitch && operand_index_ > 0)) { + return kMtfLabel; + } + + switch (opcode) { + case SpvOpFAdd: + case SpvOpFSub: + case SpvOpFMul: + case SpvOpFDiv: + case SpvOpFRem: + case SpvOpFMod: + case SpvOpFNegate: { + if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector; + return GetMtfIdOfType(inst_.type_id); + } + + case SpvOpISub: + case SpvOpIAdd: + case SpvOpIMul: + case SpvOpSDiv: + case SpvOpUDiv: + case SpvOpSMod: + case SpvOpUMod: + case SpvOpSRem: + case SpvOpSNegate: { + if (operand_index_ == 0) return kMtfTypeIntScalarOrVector; + + return kMtfIntScalarOrVector; + } + + // TODO(atgoo@github.com) Add OpConvertFToU and other opcodes. 
+ + case SpvOpFOrdEqual: + case SpvOpFUnordEqual: + case SpvOpFOrdNotEqual: + case SpvOpFUnordNotEqual: + case SpvOpFOrdLessThan: + case SpvOpFUnordLessThan: + case SpvOpFOrdGreaterThan: + case SpvOpFUnordGreaterThan: + case SpvOpFOrdLessThanEqual: + case SpvOpFUnordLessThanEqual: + case SpvOpFOrdGreaterThanEqual: + case SpvOpFUnordGreaterThanEqual: { + if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector; + if (operand_index_ == 2) return kMtfFloatScalarOrVector; + if (operand_index_ == 3) { + const uint32_t first_operand_id = GetInstWords()[3]; + const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id); + return GetMtfIdOfType(first_operand_type); + } + break; + } + + case SpvOpVectorShuffle: { + if (operand_index_ == 0) { + assert(inst_.num_operands > 4); + return GetMtfTypeVectorOfSize(inst_.num_operands - 4); + } + + assert(inst_.type_id); + if (operand_index_ == 2 || operand_index_ == 3) + return GetMtfVectorOfComponentType( + GetVectorComponentType(inst_.type_id)); + break; + } + + case SpvOpVectorTimesScalar: { + if (operand_index_ == 0) { + // TODO(atgoo@github.com) Could be narrowed to vector of floats. 
+ return GetMtfIdGeneratedByOpcode(SpvOpTypeVector); + } + + assert(inst_.type_id); + if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id); + if (operand_index_ == 3) + return GetMtfIdOfType(GetVectorComponentType(inst_.type_id)); + break; + } + + case SpvOpDot: { + if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat); + + assert(inst_.type_id); + if (operand_index_ == 2) + return GetMtfVectorOfComponentType(inst_.type_id); + if (operand_index_ == 3) { + const uint32_t vector_id = GetInstWords()[3]; + const uint32_t vector_type = id_to_type_id_.at(vector_id); + return GetMtfIdOfType(vector_type); + } + break; + } + + case SpvOpTypeVector: { + if (operand_index_ == 1) { + return kMtfTypeScalar; + } + break; + } + + case SpvOpTypeMatrix: { + if (operand_index_ == 1) { + return GetMtfIdGeneratedByOpcode(SpvOpTypeVector); + } + break; + } + + case SpvOpTypePointer: { + if (operand_index_ == 2) { + return kMtfTypeNonFunction; + } + break; + } + + case SpvOpTypeStruct: { + if (operand_index_ >= 1) { + return kMtfTypeNonFunction; + } + break; + } + + case SpvOpTypeFunction: { + if (operand_index_ == 1) { + return kMtfTypeNonFunction; + } + + if (operand_index_ >= 2) { + return kMtfTypeNonFunction; + } + break; + } + + case SpvOpLoad: { + if (operand_index_ == 0) return kMtfTypeNonFunction; + + if (operand_index_ == 2) { + assert(inst_.type_id); + return GetMtfPointerToType(inst_.type_id); + } + break; + } + + case SpvOpStore: { + if (operand_index_ == 0) + return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer); + if (operand_index_ == 1) { + const uint32_t pointer_id = GetInstWords()[1]; + const uint32_t pointer_type = id_to_type_id_.at(pointer_id); + const val::Instruction* pointer_inst = FindDef(pointer_type); + assert(pointer_inst); + assert(pointer_inst->opcode() == SpvOpTypePointer); + const uint32_t data_type = + pointer_inst->word(pointer_inst->operands()[2].offset); + return GetMtfIdOfType(data_type); + } + break; + } + + case 
SpvOpVariable: { + if (operand_index_ == 0) + return GetMtfIdGeneratedByOpcode(SpvOpTypePointer); + break; + } + + case SpvOpAccessChain: { + if (operand_index_ == 0) + return GetMtfIdGeneratedByOpcode(SpvOpTypePointer); + if (operand_index_ == 2) return kMtfTypePointerToComposite; + if (operand_index_ >= 3) + return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt); + break; + } + + case SpvOpCompositeConstruct: { + if (operand_index_ == 0) return kMtfTypeComposite; + if (operand_index_ >= 2) { + const uint32_t composite_type = GetInstWords()[1]; + if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type)) + return kMtfFloatScalarOrVector; + if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type)) + return kMtfIntScalarOrVector; + if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type)) + return kMtfBoolScalarOrVector; + } + break; + } + + case SpvOpCompositeExtract: { + if (operand_index_ == 2) return kMtfComposite; + break; + } + + case SpvOpConstantComposite: { + if (operand_index_ == 0) return kMtfTypeComposite; + if (operand_index_ >= 2) { + const val::Instruction* composite_type_inst = FindDef(inst_.type_id); + assert(composite_type_inst); + if (composite_type_inst->opcode() == SpvOpTypeVector) { + return GetMtfIdOfType(composite_type_inst->word(2)); + } + } + break; + } + + case SpvOpExtInst: { + if (operand_index_ == 2) + return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport); + if (operand_index_ >= 4) { + const uint32_t return_type = GetInstWords()[1]; + const uint32_t ext_inst_type = inst_.ext_inst_type; + const uint32_t ext_inst_index = GetInstWords()[4]; + // TODO(atgoo@github.com) The list of extended instructions is + // incomplete. Only common instructions and low-hanging fruits listed. 
+ if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) { + switch (ext_inst_index) { + case GLSLstd450FAbs: + case GLSLstd450FClamp: + case GLSLstd450FMax: + case GLSLstd450FMin: + case GLSLstd450FMix: + case GLSLstd450Step: + case GLSLstd450SmoothStep: + case GLSLstd450Fma: + case GLSLstd450Pow: + case GLSLstd450Exp: + case GLSLstd450Exp2: + case GLSLstd450Log: + case GLSLstd450Log2: + case GLSLstd450Sqrt: + case GLSLstd450InverseSqrt: + case GLSLstd450Fract: + case GLSLstd450Floor: + case GLSLstd450Ceil: + case GLSLstd450Radians: + case GLSLstd450Degrees: + case GLSLstd450Sin: + case GLSLstd450Cos: + case GLSLstd450Tan: + case GLSLstd450Sinh: + case GLSLstd450Cosh: + case GLSLstd450Tanh: + case GLSLstd450Asin: + case GLSLstd450Acos: + case GLSLstd450Atan: + case GLSLstd450Atan2: + case GLSLstd450Asinh: + case GLSLstd450Acosh: + case GLSLstd450Atanh: + case GLSLstd450MatrixInverse: + case GLSLstd450Cross: + case GLSLstd450Normalize: + case GLSLstd450Reflect: + case GLSLstd450FaceForward: + return GetMtfIdOfType(return_type); + case GLSLstd450Length: + case GLSLstd450Distance: + case GLSLstd450Refract: + return kMtfFloatScalarOrVector; + default: + break; + } + } else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) { + switch (ext_inst_index) { + case OpenCLLIB::Fabs: + case OpenCLLIB::FClamp: + case OpenCLLIB::Fmax: + case OpenCLLIB::Fmin: + case OpenCLLIB::Step: + case OpenCLLIB::Smoothstep: + case OpenCLLIB::Fma: + case OpenCLLIB::Pow: + case OpenCLLIB::Exp: + case OpenCLLIB::Exp2: + case OpenCLLIB::Log: + case OpenCLLIB::Log2: + case OpenCLLIB::Sqrt: + case OpenCLLIB::Rsqrt: + case OpenCLLIB::Fract: + case OpenCLLIB::Floor: + case OpenCLLIB::Ceil: + case OpenCLLIB::Radians: + case OpenCLLIB::Degrees: + case OpenCLLIB::Sin: + case OpenCLLIB::Cos: + case OpenCLLIB::Tan: + case OpenCLLIB::Sinh: + case OpenCLLIB::Cosh: + case OpenCLLIB::Tanh: + case OpenCLLIB::Asin: + case OpenCLLIB::Acos: + case OpenCLLIB::Atan: + case OpenCLLIB::Atan2: + case OpenCLLIB::Asinh: 
+ case OpenCLLIB::Acosh: + case OpenCLLIB::Atanh: + case OpenCLLIB::Cross: + case OpenCLLIB::Normalize: + return GetMtfIdOfType(return_type); + case OpenCLLIB::Length: + case OpenCLLIB::Distance: + return kMtfFloatScalarOrVector; + default: + break; + } + } + } + break; + } + + case SpvOpFunction: { + if (operand_index_ == 0) return kMtfTypeReturnedByFunction; + + if (operand_index_ == 3) { + const uint32_t return_type = GetInstWords()[1]; + return GetMtfFunctionTypeWithReturnType(return_type); + } + break; + } + + case SpvOpFunctionCall: { + if (operand_index_ == 0) return kMtfTypeReturnedByFunction; + + if (operand_index_ == 2) { + const uint32_t return_type = GetInstWords()[1]; + return GetMtfFunctionWithReturnType(return_type); + } + + if (operand_index_ >= 3) { + const uint32_t function_id = GetInstWords()[3]; + const val::Instruction* function_inst = FindDef(function_id); + if (!function_inst) return kMtfObject; + + assert(function_inst->opcode() == SpvOpFunction); + + const uint32_t function_type_id = function_inst->word(4); + const val::Instruction* function_type_inst = FindDef(function_type_id); + assert(function_type_inst); + assert(function_type_inst->opcode() == SpvOpTypeFunction); + + const uint32_t argument_type = function_type_inst->word(operand_index_); + return GetMtfIdOfType(argument_type); + } + break; + } + + case SpvOpReturnValue: { + if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_); + break; + } + + case SpvOpBranchConditional: { + if (operand_index_ == 0) + return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool); + break; + } + + case SpvOpSampledImage: { + if (operand_index_ == 0) + return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage); + if (operand_index_ == 2) + return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage); + if (operand_index_ == 3) + return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler); + break; + } + + case SpvOpImageSampleImplicitLod: { + if (operand_index_ == 0) + return 
GetMtfIdGeneratedByOpcode(SpvOpTypeVector); + if (operand_index_ == 2) + return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage); + if (operand_index_ == 3) + return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector); + break; + } + + default: + break; + } + + return kMtfNone; +} + +} // namespace comp +} // namespace spvtools diff --git a/source/comp/markv_codec.h b/source/comp/markv_codec.h new file mode 100644 index 000000000..f313d6178 --- /dev/null +++ b/source/comp/markv_codec.h @@ -0,0 +1,337 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_COMP_MARKV_CODEC_H_ +#define SOURCE_COMP_MARKV_CODEC_H_ + +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/comp/huffman_codec.h" +#include "source/comp/markv_model.h" +#include "source/comp/move_to_front.h" +#include "source/diagnostic.h" +#include "source/id_descriptor.h" + +#include "source/val/instruction.h" + +// Base class for MARK-V encoder and decoder. Contains common functionality +// such as: +// - Validator connection and validation state. +// - SPIR-V grammar and helper functions. + +namespace spvtools { +namespace comp { + +class MarkvLogger; + +// Handles for move-to-front sequences. Enums which end with "Begin" define +// handle spaces which start at that value and span 16 or 32 bit wide. +enum : uint64_t { + kMtfNone = 0, + // All ids. + kMtfAll, + // All forward declared ids. 
+ kMtfForwardDeclared, + // All type ids except for generated by OpTypeFunction. + kMtfTypeNonFunction, + // All labels. + kMtfLabel, + // All ids created by instructions which had type_id. + kMtfObject, + // All types generated by OpTypeFloat, OpTypeInt, OpTypeBool. + kMtfTypeScalar, + // All composite types. + kMtfTypeComposite, + // Boolean type or any vector type of it. + kMtfTypeBoolScalarOrVector, + // All float types or any vector floats type. + kMtfTypeFloatScalarOrVector, + // All int types or any vector int type. + kMtfTypeIntScalarOrVector, + // All types declared as return types in OpTypeFunction. + kMtfTypeReturnedByFunction, + // All composite objects. + kMtfComposite, + // All bool objects or vectors of bools. + kMtfBoolScalarOrVector, + // All float objects or vectors of float. + kMtfFloatScalarOrVector, + // All int objects or vectors of int. + kMtfIntScalarOrVector, + // All pointer types which point to composited. + kMtfTypePointerToComposite, + // Used by EncodeMtfRankHuffman. + kMtfGenericNonZeroRank, + // Handle space for ids of specific type. + kMtfIdOfTypeBegin = 0x10000, + // Handle space for ids generated by specific opcode. + kMtfIdGeneratedByOpcode = 0x20000, + // Handle space for ids of objects with type generated by specific opcode. + kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000, + // All vectors of specific component type. + kMtfVectorOfComponentTypeBegin = 0x40000, + // All vector types of specific size. + kMtfTypeVectorOfSizeBegin = 0x50000, + // All pointer types to specific type. + kMtfPointerToTypeBegin = 0x60000, + // All function types which return specific type. + kMtfFunctionTypeWithReturnTypeBegin = 0x70000, + // All function objects which return specific type. + kMtfFunctionWithReturnTypeBegin = 0x80000, + // Short id descriptor space (max 16-bit). + kMtfShortIdDescriptorSpaceBegin = 0x90000, + // Long id descriptor space (32-bit). 
+ kMtfLongIdDescriptorSpaceBegin = 0x100000000, +}; + +class MarkvCodec { + public: + static const uint32_t kMarkvMagicNumber; + + // Mtf ranks smaller than this are encoded with Huffman coding. + static const uint32_t kMtfSmallestRankEncodedByValue; + + // Signals that the mtf rank is too large to be encoded with Huffman. + static const uint32_t kMtfRankEncodedByValueSignal; + + static const uint32_t kShortDescriptorNumBits; + + static const size_t kByteBreakAfterInstIfLessThanUntilNextByte; + + static uint32_t GetMarkvVersion(); + + virtual ~MarkvCodec(); + + protected: + struct MarkvHeader { + MarkvHeader(); + + uint32_t magic_number; + uint32_t markv_version; + // Magic number to identify or verify MarkvModel used for encoding. + uint32_t markv_model = 0; + uint32_t markv_length_in_bits = 0; + uint32_t spirv_version = 0; + uint32_t spirv_generator = 0; + }; + + // |model| is owned by the caller, must be not null and valid during the + // lifetime of the codec. + MarkvCodec(spv_const_context context, spv_validator_options validator_options, + const MarkvModel* model); + + // Returns instruction which created |id| or nullptr if such instruction was + // not registered. + const val::Instruction* FindDef(uint32_t id) const { + const auto it = id_to_def_instruction_.find(id); + if (it == id_to_def_instruction_.end()) return nullptr; + return it->second; + } + + size_t GetNumBitsToNextByte(size_t bit_pos) const; + bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) const; + + // Returns type id of vector type component. + uint32_t GetVectorComponentType(uint32_t vector_type_id) const { + const val::Instruction* type_inst = FindDef(vector_type_id); + assert(type_inst); + assert(type_inst->opcode() == SpvOpTypeVector); + + const uint32_t component_type = + type_inst->word(type_inst->operands()[1].offset); + return component_type; + } + + // Returns mtf handle for ids of given type. 
+ uint64_t GetMtfIdOfType(uint32_t type_id) const { + return kMtfIdOfTypeBegin + type_id; + } + + // Returns mtf handle for ids generated by given opcode. + uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const { + return kMtfIdGeneratedByOpcode + opcode; + } + + // Returns mtf handle for ids of type generated by given opcode. + uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const { + return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode; + } + + // Returns mtf handle for vectors of specific component type. + uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const { + return kMtfVectorOfComponentTypeBegin + type_id; + } + + // Returns mtf handle for vector type of specific size. + uint64_t GetMtfTypeVectorOfSize(uint32_t size) const { + return kMtfTypeVectorOfSizeBegin + size; + } + + // Returns mtf handle for pointers to specific type. + uint64_t GetMtfPointerToType(uint32_t type_id) const { + return kMtfPointerToTypeBegin + type_id; + } + + // Returns mtf handle for function types with given return type. + uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const { + return kMtfFunctionTypeWithReturnTypeBegin + type_id; + } + + // Returns mtf handle for functions with given return type. + uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const { + return kMtfFunctionWithReturnTypeBegin + type_id; + } + + // Returns mtf handle for the given long id descriptor. + uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const { + return kMtfLongIdDescriptorSpaceBegin + descriptor; + } + + // Returns mtf handle for the given short id descriptor. + uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const { + return kMtfShortIdDescriptorSpaceBegin + descriptor; + } + + // Process data from the current instruction. This would update MTFs and + // other data containers. + void ProcessCurInstruction(); + + // Returns move-to-front handle to be used for the current operand slot.
+  // Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
+  uint64_t GetRuleBasedMtf();
+
+  // Returns words of the current instruction. Decoder has a different
+  // implementation and the array is valid only until the previously decoded
+  // word.
+  virtual const uint32_t* GetInstWords() const { return inst_.words; }
+
+  // Returns the opcode of the previous instruction, or SpvOpNop if no
+  // instruction has been recorded yet.
+  SpvOp GetPrevOpcode() const {
+    if (instructions_.empty()) return SpvOpNop;
+
+    return instructions_.back()->opcode();
+  }
+
+  // Returns diagnostic stream, position index is set to instruction number.
+  DiagnosticStream Diag(spv_result_t error_code) const {
+    return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer,
+                            "", error_code);
+  }
+
+  // Returns current id bound.
+  uint32_t GetIdBound() const { return id_bound_; }
+
+  // Sets current id bound, expected to be no lower than the previous one.
+  void SetIdBound(uint32_t id_bound) {
+    assert(id_bound >= id_bound_);
+    id_bound_ = id_bound;
+  }
+
+  // Returns Huffman codec for ranks of the mtf with given |handle|.
+  // Different mtfs can use different rank distributions.
+  // May return nullptr if the codec doesn't exist. The returned pointer is
+  // owned by |mtf_huffman_codecs_|.
+  const HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(uint64_t handle) const {
+    const auto it = mtf_huffman_codecs_.find(handle);
+    if (it == mtf_huffman_codecs_.end()) return nullptr;
+    return it->second.get();
+  }
+
+  // Promotes id in all move-to-front sequences if ids can be shared by multiple
+  // sequences.
+  void PromoteIfNeeded(uint32_t id) {
+    // Skip only when no descriptor has a coding scheme AND the fallback is
+    // short descriptors; every other configuration promotes.
+    if (!model_->AnyDescriptorHasCodingScheme() &&
+        model_->id_fallback_strategy() ==
+            MarkvModel::IdFallbackStrategy::kShortDescriptor) {
+      // Move-to-front sequences do not share ids. Nothing to do.
+      return;
+    }
+    multi_mtf_.Promote(id);
+  }
+
+  spv_validator_options validator_options_ = nullptr;
+  const AssemblyGrammar grammar_;
+  MarkvHeader header_;
+
+  // MARK-V model, not owned.
+  const MarkvModel* model_ = nullptr;
+
+  // Current instruction, current operand and current operand index.
+  // NOTE(review): these three have no default initializer here; presumably the
+  // derived encoder/decoder sets them before use - confirm.
+  spv_parsed_instruction_t inst_;
+  spv_parsed_operand_t operand_;
+  uint32_t operand_index_;
+
+  // Maps a result ID to its type ID. By convention:
+  //  - a result ID that is a type definition maps to itself.
+  //  - a result ID without a type maps to 0. (E.g. for OpLabel)
+  std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
+
+  // Container for all move-to-front sequences.
+  MultiMoveToFront multi_mtf_;
+
+  // Id of the current function or zero if outside of function.
+  uint32_t cur_function_id_ = 0;
+
+  // Return type of the current function.
+  uint32_t cur_function_return_type_ = 0;
+
+  // Remaining function parameter types. This container is filled on OpFunction,
+  // and drained on OpFunctionParameter.
+  std::list<uint32_t> remaining_function_parameter_types_;
+
+  // List of ids local to the current function.
+  std::vector<uint32_t> ids_local_to_cur_function_;
+
+  // List of instructions in the order they are given in the module.
+  // Owns the instructions; FindDef() hands out raw pointers into this list.
+  std::vector<std::unique_ptr<val::Instruction>> instructions_;
+
+  // Container/computer for long (32-bit) id descriptors.
+  IdDescriptorCollection long_id_descriptors_;
+
+  // Container/computer for short id descriptors.
+  // Short descriptors are stored in uint32_t, but their actual bit width is
+  // defined with kShortDescriptorNumBits.
+  // It doesn't seem logical to have a different computer for short id
+  // descriptors, since one could actually map/truncate long descriptors.
+  // But as short descriptors have collisions, the efficiency of
+  // compression depends on the collision pattern, and short descriptors
+  // produced by function ShortHashU32Array have been empirically proven to
+  // produce better results.
+  IdDescriptorCollection short_id_descriptors_;
+
+  // Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
+  // need to contain a different codec for every handle as most use one and the
+  // same.
+ std::map>> + mtf_huffman_codecs_; + + // If not nullptr, codec will log comments on the compression process. + std::unique_ptr logger_; + + spv_const_context context_ = nullptr; + + private: + // Maps result id to the instruction which defined it. + std::unordered_map id_to_def_instruction_; + + uint32_t id_bound_ = 1; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_CODEC_H_ diff --git a/source/comp/markv_decoder.cpp b/source/comp/markv_decoder.cpp new file mode 100644 index 000000000..22115831d --- /dev/null +++ b/source/comp/markv_decoder.cpp @@ -0,0 +1,925 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/markv_decoder.h" + +#include +#include +#include + +#include "source/ext_inst.h" +#include "source/opcode.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace comp { + +spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) { + auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); + + if (codec) { + uint64_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to decode non-id word with Huffman"; + + if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { + // The word decoded successfully. 
+ *word = uint32_t(decoded_value); + assert(*word == decoded_value); + return SPV_SUCCESS; + } + + // Received kMarkvNoneOfTheAbove signal, use fallback decoding. + } + + const size_t chunk_length = + model_->GetOperandVariableWidthChunkLength(operand_.type); + if (chunk_length) { + if (!reader_.ReadVariableWidthU32(word, chunk_length)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to decode non-id word with varint"; + } else { + if (!reader_.ReadUnencoded(word)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read unencoded non-id word"; + } + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands( + uint32_t* opcode, uint32_t* num_operands) { + // First try to use the Markov chain codec. + auto* codec = + model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); + if (codec) { + uint64_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode opcode_and_num_operands, previous opcode is " + << spvOpcodeString(GetPrevOpcode()); + + if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { + // The word was successfully decoded. + *opcode = uint32_t(decoded_value & 0xFFFF); + *num_operands = uint32_t(decoded_value >> 16); + return SPV_SUCCESS; + } + + // Received kMarkvNoneOfTheAbove signal, use fallback decoding. + } + + // Fallback to base-rate codec. + codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); + assert(codec); + uint64_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode opcode_and_num_operands with global codec"; + + if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) { + // Received kMarkvNoneOfTheAbove signal, fallback further. 
+ return SPV_UNSUPPORTED; + } + + *opcode = uint32_t(decoded_value & 0xFFFF); + *num_operands = uint32_t(decoded_value >> 16); + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf, + uint32_t fallback_method, + uint32_t* rank) { + const auto* codec = GetMtfHuffmanCodec(mtf); + if (!codec) { + assert(fallback_method != kMtfNone); + codec = GetMtfHuffmanCodec(fallback_method); + } + + if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank"; + + uint32_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman"; + + if (decoded_value == kMtfRankEncodedByValueSignal) { + // Decode by value. + if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode MTF rank with varint"; + *rank += MarkvCodec::kMtfSmallestRankEncodedByValue; + } else { + // Decode using Huffman coding. 
+ assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue); + *rank = decoded_value; + } + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) { + auto* codec = + model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); + + uint64_t mtf = kMtfNone; + if (codec) { + uint64_t decoded_value = 0; + if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode descriptor with Huffman"; + + if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { + const uint32_t long_descriptor = uint32_t(decoded_value); + mtf = GetMtfLongIdDescriptor(long_descriptor); + } + } + + if (mtf == kMtfNone) { + if (model_->id_fallback_strategy() != + MarkvModel::IdFallbackStrategy::kShortDescriptor) { + return SPV_UNSUPPORTED; + } + + uint64_t decoded_value = 0; + if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits)) + return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor"; + const uint32_t short_descriptor = uint32_t(decoded_value); + if (short_descriptor == 0) { + // Forward declared id. 
+ return SPV_UNSUPPORTED; + } + mtf = GetMtfShortIdDescriptor(short_descriptor); + } + + return DecodeExistingId(mtf, id); +} + +spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) { + assert(multi_mtf_.GetSize(mtf) > 0); + *id = 0; + + uint32_t rank = 0; + + if (multi_mtf_.GetSize(mtf) == 1) { + rank = 1; + } else { + const spv_result_t result = + DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank); + if (result != SPV_SUCCESS) return result; + } + + assert(rank); + if (!multi_mtf_.ValueFromRank(mtf, rank, id)) + return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds"; + + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) { + { + const spv_result_t result = DecodeIdWithDescriptor(id); + if (result != SPV_UNSUPPORTED) return result; + } + + const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( + SpvOp(inst_.opcode))(operand_index_); + uint32_t rank = 0; + *id = 0; + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + uint64_t mtf = GetRuleBasedMtf(); + if (mtf != kMtfNone && !can_forward_declare) { + return DecodeExistingId(mtf, id); + } + + if (mtf == kMtfNone) mtf = kMtfAll; + { + const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank); + if (result != SPV_SUCCESS) return result; + } + + if (rank == 0) { + // This is the first occurrence of a forward declared id. 
+ *id = GetIdBound(); + SetIdBound(*id + 1); + multi_mtf_.Insert(kMtfAll, *id); + multi_mtf_.Insert(kMtfForwardDeclared, *id); + if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id); + } else { + if (!multi_mtf_.ValueFromRank(mtf, rank, id)) + return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; + } + } else { + assert(can_forward_declare); + + if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to decode MTF rank with varint"; + + if (rank == 0) { + // This is the first occurrence of a forward declared id. + *id = GetIdBound(); + SetIdBound(*id + 1); + multi_mtf_.Insert(kMtfForwardDeclared, *id); + } else { + if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id)) + return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; + } + } + assert(*id); + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeTypeId() { + if (inst_.opcode == SpvOpFunctionParameter) { + assert(!remaining_function_parameter_types_.empty()); + inst_.type_id = remaining_function_parameter_types_.front(); + remaining_function_parameter_types_.pop_front(); + return SPV_SUCCESS; + } + + { + const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id); + if (result != SPV_UNSUPPORTED) return result; + } + + assert(model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased); + + uint64_t mtf = GetRuleBasedMtf(); + assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( + operand_index_)); + + if (mtf == kMtfNone) { + mtf = kMtfTypeNonFunction; + // Function types should have been handled by GetRuleBasedMtf. + assert(inst_.opcode != SpvOpFunction); + } + + return DecodeExistingId(mtf, &inst_.type_id); +} + +spv_result_t MarkvDecoder::DecodeResultId() { + uint32_t rank = 0; + + const uint64_t num_still_forward_declared = + multi_mtf_.GetSize(kMtfForwardDeclared); + + if (num_still_forward_declared) { + // Some ids were forward declared. Check if this id is one of them. 
+ uint64_t id_was_forward_declared; + if (!reader_.ReadBits(&id_was_forward_declared, 1)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read id_was_forward_declared flag"; + + if (id_was_forward_declared) { + if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read MTF rank of forward declared id"; + + if (rank) { + // The id was forward declared, recover it from kMtfForwardDeclared. + if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, + &inst_.result_id)) + return Diag(SPV_ERROR_INTERNAL) + << "Forward declared MTF rank is out of bounds"; + + // We can now remove the id from kMtfForwardDeclared. + if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to remove id from kMtfForwardDeclared"; + } + } + } + + if (inst_.result_id == 0) { + // The id was not forward declared, issue a new id. + inst_.result_id = GetIdBound(); + SetIdBound(inst_.result_id + 1); + } + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + if (!rank) { + multi_mtf_.Insert(kMtfAll, inst_.result_id); + } + } + + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeLiteralNumber( + const spv_parsed_operand_t& operand) { + if (operand.number_bit_width <= 32) { + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + inst_words_.push_back(word); + } else { + assert(operand.number_bit_width <= 64); + uint64_t word = 0; + if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { + if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) + return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64"; + } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { + int64_t val = 0; + if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(), + model_->s64_block_exponent())) + return Diag(SPV_ERROR_INVALID_BINARY) << "Failed 
to read literal S64"; + std::memcpy(&word, &val, 8); + } else if (operand.number_kind == SPV_NUMBER_FLOATING) { + if (!reader_.ReadUnencoded(&word)) + return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64"; + } else { + return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; + } + inst_words_.push_back(static_cast(word)); + inst_words_.push_back(static_cast(word >> 32)); + } + return SPV_SUCCESS; +} + +bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) { + const size_t num_bits_to_next_byte = + GetNumBitsToNextByte(reader_.GetNumReadBits()); + if (num_bits_to_next_byte == 0 || + num_bits_to_next_byte > byte_break_if_less_than) + return true; + + uint64_t bits = 0; + if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false; + + assert(bits == 0); + if (bits != 0) return false; + + return true; +} + +spv_result_t MarkvDecoder::DecodeModule(std::vector* spirv_binary) { + const bool header_read_success = + reader_.ReadUnencoded(&header_.magic_number) && + reader_.ReadUnencoded(&header_.markv_version) && + reader_.ReadUnencoded(&header_.markv_model) && + reader_.ReadUnencoded(&header_.markv_length_in_bits) && + reader_.ReadUnencoded(&header_.spirv_version) && + reader_.ReadUnencoded(&header_.spirv_generator); + + if (!header_read_success) + return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header"; + + if (header_.markv_length_in_bits == 0) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Header markv_length_in_bits field is zero"; + + if (header_.magic_number != MarkvCodec::kMarkvMagicNumber) + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary has incorrect magic number"; + + // TODO(atgoo@github.com): Print version strings. 
+ if (header_.markv_version != MarkvCodec::GetMarkvVersion()) + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary and the codec have different versions"; + + const uint32_t model_type = header_.markv_model >> 16; + const uint32_t model_version = header_.markv_model & 0xFFFF; + if (model_type != model_->model_type()) + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary and the codec use different MARK-V models"; + + if (model_version != model_->model_version()) + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary and the codec use different versions if the same " + << "MARK-V model"; + + spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. + spirv_.resize(5, 0); + spirv_[0] = SpvMagicNumber; + spirv_[1] = header_.spirv_version; + spirv_[2] = header_.spirv_generator; + + if (logger_) { + reader_.SetCallback( + [this](const std::string& str) { logger_->AppendBitSequence(str); }); + } + + while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { + inst_ = {}; + const spv_result_t decode_result = DecodeInstruction(); + if (decode_result != SPV_SUCCESS) return decode_result; + } + + if (validator_options_) { + spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()}; + const spv_result_t result = spvValidateWithOptions( + context_, validator_options_, &validation_binary, nullptr); + if (result != SPV_SUCCESS) return result; + } + + // Validate the decode binary + if (reader_.GetNumReadBits() != header_.markv_length_in_bits || + !reader_.OnlyZeroesLeft()) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "MARK-V binary has wrong stated bit length " + << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; + } + + // Decoding of the module is finished, validation state should have correct + // id bound. + spirv_[3] = GetIdBound(); + + *spirv_binary = std::move(spirv_); + return SPV_SUCCESS; +} + +// TODO(atgoo@github.com): The implementation borrows heavily from +// Parser::parseOperand. 
+// Consider coupling them together in some way once MARK-V codec is more mature. +// For now it's better to keep the code independent for experimentation +// purposes. +spv_result_t MarkvDecoder::DecodeOperand( + size_t operand_offset, const spv_operand_type_t type, + spv_operand_pattern_t* expected_operands) { + const SpvOp opcode = static_cast(inst_.opcode); + + memset(&operand_, 0, sizeof(operand_)); + + assert((operand_offset >> 16) == 0); + operand_.offset = static_cast(operand_offset); + operand_.type = type; + + // Set default values, may be updated later. + operand_.number_kind = SPV_NUMBER_NONE; + operand_.number_bit_width = 0; + + const size_t first_word_index = inst_words_.size(); + + switch (type) { + case SPV_OPERAND_TYPE_RESULT_ID: { + const spv_result_t result = DecodeResultId(); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(inst_.result_id); + SetIdBound(std::max(GetIdBound(), inst_.result_id + 1)); + PromoteIfNeeded(inst_.result_id); + break; + } + + case SPV_OPERAND_TYPE_TYPE_ID: { + const spv_result_t result = DecodeTypeId(); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(inst_.type_id); + SetIdBound(std::max(GetIdBound(), inst_.type_id + 1)); + PromoteIfNeeded(inst_.type_id); + break; + } + + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_OPTIONAL_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { + uint32_t id = 0; + const spv_result_t result = DecodeRefId(&id); + if (result != SPV_SUCCESS) return result; + + if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; + + if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { + operand_.type = SPV_OPERAND_TYPE_ID; + + if (opcode == SpvOpExtInst && operand_.offset == 3) { + // The current word is the extended instruction set id. + // Set the extended instruction set type for the current + // instruction. 
+ auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); + if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { + return Diag(SPV_ERROR_INVALID_ID) + << "OpExtInst set id " << id + << " does not reference an OpExtInstImport result Id"; + } + inst_.ext_inst_type = ext_inst_type_iter->second; + } + } + + inst_words_.push_back(id); + SetIdBound(std::max(GetIdBound(), id + 1)); + PromoteIfNeeded(id); + break; + } + + case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(word); + + assert(SpvOpExtInst == opcode); + assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE); + spv_ext_inst_desc ext_inst; + if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid extended instruction number: " << word; + spvPushOperandTypes(ext_inst->operandTypes, expected_operands); + break; + } + + case SPV_OPERAND_TYPE_LITERAL_INTEGER: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { + // These are regular single-word literal integer operands. + // Post-parsing validation should check the range of the parsed value. + operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; + // It turns out they are always unsigned integers! + operand_.number_kind = SPV_NUMBER_UNSIGNED_INT; + operand_.number_bit_width = 32; + + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(word); + break; + } + + case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: + case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: { + operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; + if (opcode == SpvOpSwitch) { + // The literal operands have the same type as the value + // referenced by the selector Id. 
+ const uint32_t selector_id = inst_words_.at(1); + const auto type_id_iter = id_to_type_id_.find(selector_id); + if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid OpSwitch: selector id " << selector_id + << " has no type"; + } + uint32_t type_id = type_id_iter->second; + + if (selector_id == type_id) { + // Recall that by convention, a result ID that is a type definition + // maps to itself. + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid OpSwitch: selector id " << selector_id + << " is a type, not a value"; + } + if (auto error = SetNumericTypeInfoForType(&operand_, type_id)) + return error; + if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT && + operand_.number_kind != SPV_NUMBER_SIGNED_INT) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid OpSwitch: selector id " << selector_id + << " is not a scalar integer"; + } + } else { + assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); + // The literal number type is determined by the type Id for the + // constant. 
+ assert(inst_.type_id); + if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id)) + return error; + } + + if (auto error = DecodeLiteralNumber(operand_)) return error; + + break; + } + + case SPV_OPERAND_TYPE_LITERAL_STRING: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { + operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING; + std::vector str; + auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode); + + if (codec) { + std::string decoded_string; + const bool huffman_result = + codec->DecodeFromStream(GetReadBitCallback(), &decoded_string); + assert(huffman_result); + if (!huffman_result) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read literal string"; + + if (decoded_string != "kMarkvNoneOfTheAbove") { + std::copy(decoded_string.begin(), decoded_string.end(), + std::back_inserter(str)); + str.push_back('\0'); + } + } + + // The loop is expected to terminate once we encounter '\0' or exhaust + // the bit stream. + if (str.empty()) { + while (true) { + char ch = 0; + if (!reader_.ReadUnencoded(&ch)) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read literal string"; + + str.push_back(ch); + + if (ch == '\0') break; + } + } + + while (str.size() % 4 != 0) str.push_back('\0'); + + inst_words_.resize(inst_words_.size() + str.size() / 4); + std::memcpy(&inst_words_[first_word_index], str.data(), str.size()); + + if (SpvOpExtInstImport == opcode) { + // Record the extended instruction type for the ID for this import. + // There is only one string literal argument to OpExtInstImport, + // so it's sufficient to guard this just on the opcode. + const spv_ext_inst_type_t ext_inst_type = + spvExtInstImportTypeGet(str.data()); + if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid extended instruction import '" << str.data() + << "'"; + } + // We must have parsed a valid result ID. It's a condition + // of the grammar, and we only accept non-zero result Ids. 
+ assert(inst_.result_id); + const bool inserted = + import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type) + .second; + (void)inserted; + assert(inserted); + } + break; + } + + case SPV_OPERAND_TYPE_CAPABILITY: + case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: + case SPV_OPERAND_TYPE_EXECUTION_MODEL: + case SPV_OPERAND_TYPE_ADDRESSING_MODEL: + case SPV_OPERAND_TYPE_MEMORY_MODEL: + case SPV_OPERAND_TYPE_EXECUTION_MODE: + case SPV_OPERAND_TYPE_STORAGE_CLASS: + case SPV_OPERAND_TYPE_DIMENSIONALITY: + case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: + case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: + case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: + case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: + case SPV_OPERAND_TYPE_LINKAGE_TYPE: + case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: + case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: + case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: + case SPV_OPERAND_TYPE_DECORATION: + case SPV_OPERAND_TYPE_BUILT_IN: + case SPV_OPERAND_TYPE_GROUP_OPERATION: + case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: + case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { + // A single word that is a plain enum value. + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(word); + + // Map an optional operand type to its corresponding concrete type. + if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) + operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; + + spv_operand_desc entry; + if (grammar_.lookupOperand(type, word, &entry)) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid " << spvOperandTypeStr(operand_.type) + << " operand: " << word; + } + + // Prepare to accept operands to this operand, if needed. 
+ spvPushOperandTypes(entry->operandTypes, expected_operands); + break; + } + + case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: + case SPV_OPERAND_TYPE_FUNCTION_CONTROL: + case SPV_OPERAND_TYPE_LOOP_CONTROL: + case SPV_OPERAND_TYPE_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: + case SPV_OPERAND_TYPE_SELECTION_CONTROL: { + // This operand is a mask. + uint32_t word = 0; + const spv_result_t result = DecodeNonIdWord(&word); + if (result != SPV_SUCCESS) return result; + + inst_words_.push_back(word); + + // Map an optional operand type to its corresponding concrete type. + if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) + operand_.type = SPV_OPERAND_TYPE_IMAGE; + else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) + operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; + + // Check validity of set mask bits. Also prepare for operands for those + // masks if they have any. To get operand order correct, scan from + // MSB to LSB since we can only prepend operands to a pattern. + // The only case in the grammar where you have more than one mask bit + // having an operand is for image operands. See SPIR-V 3.14 Image + // Operands. + uint32_t remaining_word = word; + for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { + if (remaining_word & mask) { + spv_operand_desc entry; + if (grammar_.lookupOperand(type, mask, &entry)) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Invalid " << spvOperandTypeStr(operand_.type) + << " operand: " << word << " has invalid mask component " + << mask; + } + remaining_word ^= mask; + spvPushOperandTypes(entry->operandTypes, expected_operands); + } + } + if (word == 0) { + // An all-zeroes mask *might* also be valid. + spv_operand_desc entry; + if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { + // Prepare for its operands, if any. 
+ spvPushOperandTypes(entry->operandTypes, expected_operands); + } + } + break; + } + default: + return Diag(SPV_ERROR_INVALID_BINARY) + << "Internal error: Unhandled operand type: " << type; + } + + operand_.num_words = uint16_t(inst_words_.size() - first_word_index); + + assert(spvOperandIsConcrete(operand_.type)); + + parsed_operands_.push_back(operand_); + + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::DecodeInstruction() { + parsed_operands_.clear(); + inst_words_.clear(); + + // Opcode/num_words placeholder, the word will be filled in later. + inst_words_.push_back(0); + + bool num_operands_still_unknown = true; + { + uint32_t opcode = 0; + uint32_t num_operands = 0; + + const spv_result_t opcode_decoding_result = + DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands); + if (opcode_decoding_result < 0) return opcode_decoding_result; + + if (opcode_decoding_result == SPV_SUCCESS) { + inst_.num_operands = static_cast(num_operands); + num_operands_still_unknown = false; + } else { + if (!reader_.ReadVariableWidthU32(&opcode, + model_->opcode_chunk_length())) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read opcode of instruction"; + } + } + + inst_.opcode = static_cast(opcode); + } + + const SpvOp opcode = static_cast(inst_.opcode); + + spv_opcode_desc opcode_desc; + if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) { + return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; + } + + spv_operand_pattern_t expected_operands; + expected_operands.reserve(opcode_desc->numTypes); + for (auto i = 0; i < opcode_desc->numTypes; i++) { + expected_operands.push_back( + opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); + } + + if (num_operands_still_unknown) { + if (!OpcodeHasFixedNumberOfOperands(opcode)) { + if (!reader_.ReadVariableWidthU16(&inst_.num_operands, + model_->num_operands_chunk_length())) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read num_operands of instruction"; + } else { + inst_.num_operands 
= static_cast(expected_operands.size()); + } + } + + for (operand_index_ = 0; + operand_index_ < static_cast(inst_.num_operands); + ++operand_index_) { + assert(!expected_operands.empty()); + const spv_operand_type_t type = + spvTakeFirstMatchableOperand(&expected_operands); + + const size_t operand_offset = inst_words_.size(); + + const spv_result_t decode_result = + DecodeOperand(operand_offset, type, &expected_operands); + + if (decode_result != SPV_SUCCESS) return decode_result; + } + + assert(inst_.num_operands == parsed_operands_.size()); + + // Only valid while inst_words_ and parsed_operands_ remain unchanged (until + // next DecodeInstruction call). + inst_.words = inst_words_.data(); + inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data(); + inst_.num_words = static_cast(inst_words_.size()); + inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode)); + + std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_)); + + assert(inst_.num_words == + std::accumulate( + parsed_operands_.begin(), parsed_operands_.end(), 1, + [](int num_words, const spv_parsed_operand_t& operand) { + return num_words += operand.num_words; + }) && + "num_words in instruction doesn't correspond to the sum of num_words" + "in the operands"); + + RecordNumberType(); + ProcessCurInstruction(); + + if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte)) + return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break"; + + if (logger_) { + logger_->NewLine(); + std::stringstream ss; + ss << spvOpcodeString(opcode) << " "; + for (size_t index = 1; index < inst_words_.size(); ++index) + ss << inst_words_[index] << " "; + logger_->AppendText(ss.str()); + logger_->NewLine(); + logger_->NewLine(); + if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; + } + + return SPV_SUCCESS; +} + +spv_result_t MarkvDecoder::SetNumericTypeInfoForType( + spv_parsed_operand_t* parsed_operand, uint32_t 
type_id) { + assert(type_id != 0); + auto type_info_iter = type_id_to_number_type_info_.find(type_id); + if (type_info_iter == type_id_to_number_type_info_.end()) { + return Diag(SPV_ERROR_INVALID_BINARY) + << "Type Id " << type_id << " is not a type"; + } + + const NumberType& info = type_info_iter->second; + if (info.type == SPV_NUMBER_NONE) { + // This is a valid type, but for something other than a scalar number. + return Diag(SPV_ERROR_INVALID_BINARY) + << "Type Id " << type_id << " is not a scalar numeric type"; + } + + parsed_operand->number_kind = info.type; + parsed_operand->number_bit_width = info.bit_width; + // Round up the word count. + parsed_operand->num_words = static_cast((info.bit_width + 31) / 32); + return SPV_SUCCESS; +} + +void MarkvDecoder::RecordNumberType() { + const SpvOp opcode = static_cast(inst_.opcode); + if (spvOpcodeGeneratesType(opcode)) { + NumberType info = {SPV_NUMBER_NONE, 0}; + if (SpvOpTypeInt == opcode) { + info.bit_width = inst_.words[inst_.operands[1].offset]; + info.type = inst_.words[inst_.operands[2].offset] + ? SPV_NUMBER_SIGNED_INT + : SPV_NUMBER_UNSIGNED_INT; + } else if (SpvOpTypeFloat == opcode) { + info.bit_width = inst_.words[inst_.operands[1].offset]; + info.type = SPV_NUMBER_FLOATING; + } + // The *result* Id of a type generating instruction is the type Id. + type_id_to_number_type_info_[inst_.result_id] = info; + } +} + +} // namespace comp +} // namespace spvtools diff --git a/source/comp/markv_decoder.h b/source/comp/markv_decoder.h new file mode 100644 index 000000000..4d8402b44 --- /dev/null +++ b/source/comp/markv_decoder.h @@ -0,0 +1,175 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/bit_stream.h" +#include "source/comp/markv.h" +#include "source/comp/markv_codec.h" +#include "source/comp/markv_logger.h" +#include "source/util/make_unique.h" + +#ifndef SOURCE_COMP_MARKV_DECODER_H_ +#define SOURCE_COMP_MARKV_DECODER_H_ + +namespace spvtools { +namespace comp { + +class MarkvLogger; + +// Decodes MARK-V buffers written by MarkvEncoder. +class MarkvDecoder : public MarkvCodec { + public: + // |model| is owned by the caller, must be not null and valid during the + // lifetime of MarkvEncoder. + MarkvDecoder(spv_const_context context, const std::vector& markv, + const MarkvCodecOptions& options, const MarkvModel* model) + : MarkvCodec(context, GetValidatorOptions(options), model), + options_(options), + reader_(markv) { + SetIdBound(1); + parsed_operands_.reserve(25); + inst_words_.reserve(25); + } + ~MarkvDecoder() = default; + + // Creates an internal logger which writes comments on the decoding process. + void CreateLogger(MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer) { + logger_ = MakeUnique(log_consumer, debug_consumer); + } + + // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|. + // Can be called only once. Fails if data of wrong format or ends prematurely, + // of if validation fails. + spv_result_t DecodeModule(std::vector* spirv_binary); + + // Creates and returns validator options. Returned value owned by the caller. + static spv_validator_options GetValidatorOptions( + const MarkvCodecOptions& options) { + return options.validate_spirv_binary ? 
spvValidatorOptionsCreate() + : nullptr; + } + + private: + // Describes the format of a typed literal number. + struct NumberType { + spv_number_kind_t type; + uint32_t bit_width; + }; + + // Reads a single bit from reader_. The read bit is stored in |bit|. + // Returns false iff reader_ fails. + bool ReadBit(bool* bit) { + uint64_t bits = 0; + const bool result = reader_.ReadBits(&bits, 1); + if (result) *bit = bits ? true : false; + return result; + }; + + // Returns ReadBit bound to the class object. + std::function GetReadBitCallback() { + return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1); + } + + // Reads a single non-id word from bit stream. operand_.type determines if + // the word needs to be decoded and how. + spv_result_t DecodeNonIdWord(uint32_t* word); + + // Reads and decodes both opcode and num_operands as a single code. + // Returns SPV_UNSUPPORTED iff no suitable codec was found. + spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode, + uint32_t* num_operands); + + // Reads mtf rank from bit stream. |mtf| is used to determine the codec + // scheme. |fallback_method| is used if no codec defined for |mtf|. + spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method, + uint32_t* rank); + + // Reads id using coding based on mtf associated with the id descriptor. + // Returns SPV_UNSUPPORTED iff fallback method needs to be used. + spv_result_t DecodeIdWithDescriptor(uint32_t* id); + + // Reads id using coding based on the given |mtf|, which is expected to + // contain the needed |id|. + spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id); + + // Reads type id of the current instruction if can't be inferred. + spv_result_t DecodeTypeId(); + + // Reads result id of the current instruction if can't be inferred. + spv_result_t DecodeResultId(); + + // Reads id which is neither type nor result id. 
+ spv_result_t DecodeRefId(uint32_t* id); + + // Reads and discards bits until the beginning of the next byte if the + // number of bits until the next byte is less than |byte_break_if_less_than|. + bool ReadToByteBreak(size_t byte_break_if_less_than); + + // Returns instruction words decoded up to this point. + const uint32_t* GetInstWords() const override { return inst_words_.data(); } + + // Reads a literal number as it is described in |operand| from the bit stream, + // decodes and writes it to spirv_. + spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand); + + // Reads instruction from bit stream, decodes and validates it. + // Decoded instruction is valid until the next call of DecodeInstruction(). + spv_result_t DecodeInstruction(); + + // Read operand from the stream decodes and validates it. + spv_result_t DecodeOperand(size_t operand_offset, + const spv_operand_type_t type, + spv_operand_pattern_t* expected_operands); + + // Records the numeric type for an operand according to the type information + // associated with the given non-zero type Id. This can fail if the type Id + // is not a type Id, or if the type Id does not reference a scalar numeric + // type. On success, return SPV_SUCCESS and populates the num_words, + // number_kind, and number_bit_width fields of parsed_operand. + spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand, + uint32_t type_id); + + // Records the number type for the current instruction, if it generates a + // type. For types that aren't scalar numbers, record something with number + // kind SPV_NUMBER_NONE. + void RecordNumberType(); + + MarkvCodecOptions options_; + + // Temporary sink where decoded SPIR-V words are written. Once it contains the + // entire module, the container is moved and returned. + std::vector spirv_; + + // Bit stream containing encoded data. + BitReaderWord64 reader_; + + // Temporary storage for operands of the currently parsed instruction. 
+ // Valid until next DecodeInstruction call. + std::vector parsed_operands_; + + // Temporary storage for current instruction words. + // Valid until next DecodeInstruction call. + std::vector inst_words_; + + // Maps a type ID to its number type description. + std::unordered_map type_id_to_number_type_info_; + + // Maps an ExtInstImport id to the extended instruction type. + std::unordered_map import_id_to_ext_inst_type_; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_DECODER_H_ diff --git a/source/comp/markv_encoder.cpp b/source/comp/markv_encoder.cpp new file mode 100644 index 000000000..1abd58646 --- /dev/null +++ b/source/comp/markv_encoder.cpp @@ -0,0 +1,486 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/markv_encoder.h" + +#include "source/binary.h" +#include "source/opcode.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace comp { +namespace { + +const size_t kCommentNumWhitespaces = 2; + +} // namespace + +spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) { + auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); + + if (codec) { + uint64_t bits = 0; + size_t num_bits = 0; + if (codec->Encode(word, &bits, &num_bits)) { + // Encoding successful. + writer_.WriteBits(bits, num_bits); + return SPV_SUCCESS; + } else { + // Encoding failed, write kMarkvNoneOfTheAbove flag. 
+ if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, + &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Non-id word Huffman table for " + << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " + << operand_index_ << " is missing kMarkvNoneOfTheAbove"; + writer_.WriteBits(bits, num_bits); + } + } + + // Fallback encoding. + const size_t chunk_length = + model_->GetOperandVariableWidthChunkLength(operand_.type); + if (chunk_length) { + writer_.WriteVariableWidthU32(word, chunk_length); + } else { + writer_.WriteUnencoded(word); + } + return SPV_SUCCESS; +} + +spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode, + uint32_t num_operands) { + uint64_t bits = 0; + size_t num_bits = 0; + + const uint32_t word = opcode | (num_operands << 16); + + // First try to use the Markov chain codec. + auto* codec = + model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); + if (codec) { + if (codec->Encode(word, &bits, &num_bits)) { + // The word was successfully encoded into bits/num_bits. + writer_.WriteBits(bits, num_bits); + return SPV_SUCCESS; + } else { + // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove + // and use fallback encoding. + if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, + &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "opcode_and_num_operands Huffman table for " + << spvOpcodeString(GetPrevOpcode()) + << "is missing kMarkvNoneOfTheAbove"; + writer_.WriteBits(bits, num_bits); + } + } + + // Fallback to base-rate codec. + codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); + assert(codec); + if (codec->Encode(word, &bits, &num_bits)) { + // The word was successfully encoded into bits/num_bits. + writer_.WriteBits(bits, num_bits); + return SPV_SUCCESS; + } else { + // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove + // and return false. 
+ if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Global opcode_and_num_operands Huffman table is missing " + << "kMarkvNoneOfTheAbove"; + writer_.WriteBits(bits, num_bits); + return SPV_UNSUPPORTED; + } +} + +spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, + uint64_t fallback_method) { + const auto* codec = GetMtfHuffmanCodec(mtf); + if (!codec) { + assert(fallback_method != kMtfNone); + codec = GetMtfHuffmanCodec(fallback_method); + } + + if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank"; + + uint64_t bits = 0; + size_t num_bits = 0; + if (rank < MarkvCodec::kMtfSmallestRankEncodedByValue) { + // Encode using Huffman coding. + if (!codec->Encode(rank, &bits, &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to encode MTF rank with Huffman"; + + writer_.WriteBits(bits, num_bits); + } else { + // Encode by value. + if (!codec->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &bits, + &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to encode kMtfRankEncodedByValueSignal"; + + writer_.WriteBits(bits, num_bits); + writer_.WriteVariableWidthU32( + rank - MarkvCodec::kMtfSmallestRankEncodedByValue, + model_->mtf_rank_chunk_length()); + } + return SPV_SUCCESS; +} + +spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) { + // Get the descriptor for id. + const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id); + auto* codec = + model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); + uint64_t bits = 0; + size_t num_bits = 0; + uint64_t mtf = kMtfNone; + if (long_descriptor && codec && + codec->Encode(long_descriptor, &bits, &num_bits)) { + // If the descriptor exists and is in the table, write the descriptor and + // proceed to encoding the rank. 
+ writer_.WriteBits(bits, num_bits); + mtf = GetMtfLongIdDescriptor(long_descriptor); + } else { + if (codec) { + // The descriptor doesn't exist or we have no coding for it. Write + // kMarkvNoneOfTheAbove and go to fallback method. + if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, + &num_bits)) + return Diag(SPV_ERROR_INTERNAL) + << "Descriptor Huffman table for " + << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " + << operand_index_ << " is missing kMarkvNoneOfTheAbove"; + + writer_.WriteBits(bits, num_bits); + } + + if (model_->id_fallback_strategy() != + MarkvModel::IdFallbackStrategy::kShortDescriptor) { + return SPV_UNSUPPORTED; + } + + const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id); + writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits); + + if (short_descriptor == 0) { + // Forward declared id. + return SPV_UNSUPPORTED; + } + + mtf = GetMtfShortIdDescriptor(short_descriptor); + } + + // Descriptor has been encoded. Now encode the rank of the id in the + // associated mtf sequence. + return EncodeExistingId(mtf, id); +} + +spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) { + assert(multi_mtf_.GetSize(mtf) > 0); + if (multi_mtf_.GetSize(mtf) == 1) { + // If the sequence has only one element no need to write rank, the decoder + // would make the same decision. + return SPV_SUCCESS; + } + + uint32_t rank = 0; + if (!multi_mtf_.RankFromValue(mtf, id, &rank)) + return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence"; + + return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank); +} + +spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) { + { + // Try to encode using id descriptor mtfs. + const spv_result_t result = EncodeIdWithDescriptor(id); + if (result != SPV_UNSUPPORTED) return result; + // If can't be done continue with other methods. 
+ } + + const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( + SpvOp(inst_.opcode))(operand_index_); + uint32_t rank = 0; + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + // Encode using rule-based mtf. + uint64_t mtf = GetRuleBasedMtf(); + + if (mtf != kMtfNone && !can_forward_declare) { + assert(multi_mtf_.HasValue(kMtfAll, id)); + return EncodeExistingId(mtf, id); + } + + if (mtf == kMtfNone) mtf = kMtfAll; + + if (!multi_mtf_.RankFromValue(mtf, id, &rank)) { + // This is the first occurrence of a forward declared id. + multi_mtf_.Insert(kMtfAll, id); + multi_mtf_.Insert(kMtfForwardDeclared, id); + if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id); + rank = 0; + } + + return EncodeMtfRankHuffman(rank, mtf, kMtfAll); + } else { + assert(can_forward_declare); + + if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) { + // This is the first occurrence of a forward declared id. + multi_mtf_.Insert(kMtfForwardDeclared, id); + rank = 0; + } + + writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); + return SPV_SUCCESS; + } +} + +spv_result_t MarkvEncoder::EncodeTypeId() { + if (inst_.opcode == SpvOpFunctionParameter) { + assert(!remaining_function_parameter_types_.empty()); + assert(inst_.type_id == remaining_function_parameter_types_.front()); + remaining_function_parameter_types_.pop_front(); + return SPV_SUCCESS; + } + + { + // Try to encode using id descriptor mtfs. + const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id); + if (result != SPV_UNSUPPORTED) return result; + // If can't be done continue with other methods. + } + + assert(model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased); + + uint64_t mtf = GetRuleBasedMtf(); + assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( + operand_index_)); + + if (mtf == kMtfNone) { + mtf = kMtfTypeNonFunction; + // Function types should have been handled by GetRuleBasedMtf. 
+ assert(inst_.opcode != SpvOpFunction); + } + + return EncodeExistingId(mtf, inst_.type_id); +} + +spv_result_t MarkvEncoder::EncodeResultId() { + uint32_t rank = 0; + + const uint64_t num_still_forward_declared = + multi_mtf_.GetSize(kMtfForwardDeclared); + + if (num_still_forward_declared) { + // We write the rank only if kMtfForwardDeclared is not empty. If it is + // empty the decoder knows that there are no forward declared ids to expect. + if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) { + // This is a definition of a forward declared id. We can remove the id + // from kMtfForwardDeclared. + if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) + return Diag(SPV_ERROR_INTERNAL) + << "Failed to remove id from kMtfForwardDeclared"; + writer_.WriteBits(1, 1); + writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); + } else { + rank = 0; + writer_.WriteBits(0, 1); + } + } + + if (model_->id_fallback_strategy() == + MarkvModel::IdFallbackStrategy::kRuleBased) { + if (!rank) { + multi_mtf_.Insert(kMtfAll, inst_.result_id); + } + } + + return SPV_SUCCESS; +} + +spv_result_t MarkvEncoder::EncodeLiteralNumber( + const spv_parsed_operand_t& operand) { + if (operand.number_bit_width <= 32) { + const uint32_t word = inst_.words[operand.offset]; + return EncodeNonIdWord(word); + } else { + assert(operand.number_bit_width <= 64); + const uint64_t word = uint64_t(inst_.words[operand.offset]) | + (uint64_t(inst_.words[operand.offset + 1]) << 32); + if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { + writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); + } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { + int64_t val = 0; + std::memcpy(&val, &word, 8); + writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), + model_->s64_block_exponent()); + } else if (operand.number_kind == SPV_NUMBER_FLOATING) { + writer_.WriteUnencoded(word); + } else { + return Diag(SPV_ERROR_INTERNAL) << "Unsupported 
bit length"; + } + } + return SPV_SUCCESS; +} + +void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) { + const size_t num_bits_to_next_byte = + GetNumBitsToNextByte(writer_.GetNumBits()); + if (num_bits_to_next_byte == 0 || + num_bits_to_next_byte > byte_break_if_less_than) + return; + + if (logger_) { + logger_->AppendWhitespaces(kCommentNumWhitespaces); + logger_->AppendText(""); + } + + writer_.WriteBits(0, num_bits_to_next_byte); +} + +spv_result_t MarkvEncoder::EncodeInstruction( + const spv_parsed_instruction_t& inst) { + SpvOp opcode = SpvOp(inst.opcode); + inst_ = inst; + + LogDisassemblyInstruction(); + + const spv_result_t opcode_encodig_result = + EncodeOpcodeAndNumOperands(opcode, inst.num_operands); + if (opcode_encodig_result < 0) return opcode_encodig_result; + + if (opcode_encodig_result != SPV_SUCCESS) { + // Fallback encoding for opcode and num_operands. + writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length()); + + if (!OpcodeHasFixedNumberOfOperands(opcode)) { + // If the opcode has a variable number of operands, encode the number of + // operands with the instruction. + + if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces); + + writer_.WriteVariableWidthU16(inst.num_operands, + model_->num_operands_chunk_length()); + } + } + + // Write operands. 
+ const uint32_t num_operands = inst_.num_operands; + for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) { + operand_ = inst_.operands[operand_index_]; + + if (logger_) { + logger_->AppendWhitespaces(kCommentNumWhitespaces); + logger_->AppendText("<"); + logger_->AppendText(spvOperandTypeStr(operand_.type)); + logger_->AppendText(">"); + } + + switch (operand_.type) { + case SPV_OPERAND_TYPE_RESULT_ID: + case SPV_OPERAND_TYPE_TYPE_ID: + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_OPTIONAL_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { + const uint32_t id = inst_.words[operand_.offset]; + if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) { + const spv_result_t result = EncodeTypeId(); + if (result != SPV_SUCCESS) return result; + } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) { + const spv_result_t result = EncodeResultId(); + if (result != SPV_SUCCESS) return result; + } else { + const spv_result_t result = EncodeRefId(id); + if (result != SPV_SUCCESS) return result; + } + + PromoteIfNeeded(id); + break; + } + + case SPV_OPERAND_TYPE_LITERAL_INTEGER: { + const spv_result_t result = + EncodeNonIdWord(inst_.words[operand_.offset]); + if (result != SPV_SUCCESS) return result; + break; + } + + case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { + const spv_result_t result = EncodeLiteralNumber(operand_); + if (result != SPV_SUCCESS) return result; + break; + } + + case SPV_OPERAND_TYPE_LITERAL_STRING: { + const char* src = + reinterpret_cast(&inst_.words[operand_.offset]); + + auto* codec = model_->GetLiteralStringHuffmanCodec(opcode); + if (codec) { + uint64_t bits = 0; + size_t num_bits = 0; + const std::string str = src; + if (codec->Encode(str, &bits, &num_bits)) { + writer_.WriteBits(bits, num_bits); + break; + } else { + bool result = + codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits); + (void)result; + assert(result); + writer_.WriteBits(bits, num_bits); + } + } + + const size_t length = 
spv_strnlen_s(src, operand_.num_words * 4); + if (length == operand_.num_words * 4) + return Diag(SPV_ERROR_INVALID_BINARY) + << "Failed to find terminal character of literal string"; + for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]); + break; + } + + default: { + for (int i = 0; i < operand_.num_words; ++i) { + const uint32_t word = inst_.words[operand_.offset + i]; + const spv_result_t result = EncodeNonIdWord(word); + if (result != SPV_SUCCESS) return result; + } + break; + } + } + } + + AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte); + + if (logger_) { + logger_->NewLine(); + logger_->NewLine(); + if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; + } + + ProcessCurInstruction(); + + return SPV_SUCCESS; +} + +} // namespace comp +} // namespace spvtools diff --git a/source/comp/markv_encoder.h b/source/comp/markv_encoder.h new file mode 100644 index 000000000..21843123f --- /dev/null +++ b/source/comp/markv_encoder.h @@ -0,0 +1,167 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/comp/bit_stream.h" +#include "source/comp/markv.h" +#include "source/comp/markv_codec.h" +#include "source/comp/markv_logger.h" +#include "source/util/make_unique.h" + +#ifndef SOURCE_COMP_MARKV_ENCODER_H_ +#define SOURCE_COMP_MARKV_ENCODER_H_ + +#include + +namespace spvtools { +namespace comp { + +// SPIR-V to MARK-V encoder. 
Exposes functions EncodeHeader and +// EncodeInstruction which can be used as callback by spvBinaryParse. +// Encoded binary is written to an internally maintained bitstream. +// After the last instruction is encoded, the resulting MARK-V binary can be +// acquired by calling GetMarkvBinary(). +// +// The encoder uses SPIR-V validator to keep internal state, therefore +// SPIR-V binary needs to be able to pass validator checks. +// CreateCommentsLogger() can be used to enable the encoder to write comments +// on how encoding was done, which can later be accessed with GetComments(). +class MarkvEncoder : public MarkvCodec { + public: + // |model| is owned by the caller, must be not null and valid during the + // lifetime of MarkvEncoder. + MarkvEncoder(spv_const_context context, const MarkvCodecOptions& options, + const MarkvModel* model) + : MarkvCodec(context, GetValidatorOptions(options), model), + options_(options) {} + ~MarkvEncoder() override = default; + + // Writes data from SPIR-V header to MARK-V header. + spv_result_t EncodeHeader(spv_endianness_t /* endian */, uint32_t /* magic */, + uint32_t version, uint32_t generator, + uint32_t id_bound, uint32_t /* schema */) { + SetIdBound(id_bound); + header_.spirv_version = version; + header_.spirv_generator = generator; + return SPV_SUCCESS; + } + + // Creates an internal logger which writes comments on the encoding process. + void CreateLogger(MarkvLogConsumer log_consumer, + MarkvDebugConsumer debug_consumer) { + logger_ = MakeUnique(log_consumer, debug_consumer); + writer_.SetCallback( + [this](const std::string& str) { logger_->AppendBitSequence(str); }); + } + + // Encodes SPIR-V instruction to MARK-V and writes to bit stream. + // Operation can fail if the instruction fails to pass the validator or if + // the encoder stubmles on something unexpected. 
+ spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst); + + // Concatenates MARK-V header and the bit stream with encoded instructions + // into a single buffer and returns it as spv_markv_binary. The returned + // value is owned by the caller and needs to be destroyed with + // spvMarkvBinaryDestroy(). + std::vector GetMarkvBinary() { + header_.markv_length_in_bits = + static_cast(sizeof(header_) * 8 + writer_.GetNumBits()); + header_.markv_model = + (model_->model_type() << 16) | model_->model_version(); + + const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes(); + std::vector markv(num_bytes); + + assert(writer_.GetData()); + std::memcpy(markv.data(), &header_, sizeof(header_)); + std::memcpy(markv.data() + sizeof(header_), writer_.GetData(), + writer_.GetDataSizeBytes()); + return markv; + } + + // Optionally adds disassembly to the comments. + // Disassembly should contain all instructions in the module separated by + // \n, and no header. + void SetDisassembly(std::string&& disassembly) { + disassembly_ = MakeUnique(std::move(disassembly)); + } + + // Extracts the next instruction line from the disassembly and logs it. + void LogDisassemblyInstruction() { + if (logger_ && disassembly_) { + std::string line; + std::getline(*disassembly_, line, '\n'); + logger_->AppendTextNewLine(line); + } + } + + private: + // Creates and returns validator options. Returned value owned by the caller. + static spv_validator_options GetValidatorOptions( + const MarkvCodecOptions& options) { + return options.validate_spirv_binary ? spvValidatorOptionsCreate() + : nullptr; + } + + // Writes a single word to bit stream. operand_.type determines if the word is + // encoded and how. + spv_result_t EncodeNonIdWord(uint32_t word); + + // Writes both opcode and num_operands as a single code. + // Returns SPV_UNSUPPORTED iff no suitable codec was found. 
+ spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode, + uint32_t num_operands); + + // Writes mtf rank to bit stream. |mtf| is used to determine the codec + // scheme. |fallback_method| is used if no codec defined for |mtf|. + spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, + uint64_t fallback_method); + + // Writes id using coding based on mtf associated with the id descriptor. + // Returns SPV_UNSUPPORTED iff fallback method needs to be used. + spv_result_t EncodeIdWithDescriptor(uint32_t id); + + // Writes id using coding based on the given |mtf|, which is expected to + // contain the given |id|. + spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id); + + // Writes type id of the current instruction if can't be inferred. + spv_result_t EncodeTypeId(); + + // Writes result id of the current instruction if can't be inferred. + spv_result_t EncodeResultId(); + + // Writes ids which are neither type nor result ids. + spv_result_t EncodeRefId(uint32_t id); + + // Writes bits to the stream until the beginning of the next byte if the + // number of bits until the next byte is less than |byte_break_if_less_than|. + void AddByteBreak(size_t byte_break_if_less_than); + + // Encodes a literal number operand and writes it to the bit stream. + spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand); + + MarkvCodecOptions options_; + + // Bit stream where encoded instructions are written. + BitWriterWord64 writer_; + + // If not nullptr, disassembled instruction lines will be written to comments. + // Format: \n separated instruction lines, no header. 
+ std::unique_ptr disassembly_; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_ENCODER_H_ diff --git a/source/comp/markv_logger.h b/source/comp/markv_logger.h new file mode 100644 index 000000000..c07fe97b7 --- /dev/null +++ b/source/comp/markv_logger.h @@ -0,0 +1,93 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_COMP_MARKV_LOGGER_H_ +#define SOURCE_COMP_MARKV_LOGGER_H_ + +#include "source/comp/markv.h" + +namespace spvtools { +namespace comp { + +class MarkvLogger { + public: + MarkvLogger(MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer) + : log_consumer_(log_consumer), debug_consumer_(debug_consumer) {} + + void AppendText(const std::string& str) { + Append(str); + use_delimiter_ = false; + } + + void AppendTextNewLine(const std::string& str) { + Append(str); + Append("\n"); + use_delimiter_ = false; + } + + void AppendBitSequence(const std::string& str) { + if (debug_consumer_) instruction_bits_ << str; + if (use_delimiter_) Append("-"); + Append(str); + use_delimiter_ = true; + } + + void AppendWhitespaces(size_t num) { + Append(std::string(num, ' ')); + use_delimiter_ = false; + } + + void NewLine() { + Append("\n"); + use_delimiter_ = false; + } + + bool DebugInstruction(const spv_parsed_instruction_t& inst) { + bool result = true; + if (debug_consumer_) { + result = debug_consumer_( + std::vector(inst.words, inst.words + inst.num_words), + 
instruction_bits_.str(), instruction_comment_.str()); + instruction_bits_.str(std::string()); + instruction_comment_.str(std::string()); + } + return result; + } + + private: + MarkvLogger(const MarkvLogger&) = delete; + MarkvLogger(MarkvLogger&&) = delete; + MarkvLogger& operator=(const MarkvLogger&) = delete; + MarkvLogger& operator=(MarkvLogger&&) = delete; + + void Append(const std::string& str) { + if (log_consumer_) log_consumer_(str); + if (debug_consumer_) instruction_comment_ << str; + } + + MarkvLogConsumer log_consumer_; + MarkvDebugConsumer debug_consumer_; + + std::stringstream instruction_bits_; + std::stringstream instruction_comment_; + + // If true a delimiter will be appended before the next bit sequence. + // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0. + bool use_delimiter_ = false; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_LOGGER_H_ diff --git a/source/comp/markv_model.h b/source/comp/markv_model.h new file mode 100644 index 000000000..d03df02df --- /dev/null +++ b/source/comp/markv_model.h @@ -0,0 +1,232 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_COMP_MARKV_MODEL_H_ +#define SOURCE_COMP_MARKV_MODEL_H_ + +#include + +#include "source/comp/huffman_codec.h" +#include "source/latest_version_spirv_header.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace comp { + +// Base class for MARK-V models. 
+// The class contains encoding/decoding model with various constants and +// codecs used by the compression algorithm. +class MarkvModel { + public: + MarkvModel() + : operand_chunk_lengths_( + static_cast(SPV_OPERAND_TYPE_NUM_OPERAND_TYPES), 0) { + // Set default values. + operand_chunk_lengths_[SPV_OPERAND_TYPE_TYPE_ID] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_RESULT_ID] = 8; + operand_chunk_lengths_[SPV_OPERAND_TYPE_ID] = 8; + operand_chunk_lengths_[SPV_OPERAND_TYPE_SCOPE_ID] = 8; + operand_chunk_lengths_[SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID] = 8; + operand_chunk_lengths_[SPV_OPERAND_TYPE_LITERAL_INTEGER] = 6; + operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER] = 6; + operand_chunk_lengths_[SPV_OPERAND_TYPE_CAPABILITY] = 6; + operand_chunk_lengths_[SPV_OPERAND_TYPE_SOURCE_LANGUAGE] = 3; + operand_chunk_lengths_[SPV_OPERAND_TYPE_EXECUTION_MODEL] = 3; + operand_chunk_lengths_[SPV_OPERAND_TYPE_ADDRESSING_MODEL] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_MEMORY_MODEL] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_EXECUTION_MODE] = 6; + operand_chunk_lengths_[SPV_OPERAND_TYPE_STORAGE_CLASS] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_DIMENSIONALITY] = 3; + operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE] = 3; + operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT] = 6; + operand_chunk_lengths_[SPV_OPERAND_TYPE_FP_ROUNDING_MODE] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_LINKAGE_TYPE] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_ACCESS_QUALIFIER] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE] = 3; + operand_chunk_lengths_[SPV_OPERAND_TYPE_DECORATION] = 6; + operand_chunk_lengths_[SPV_OPERAND_TYPE_BUILT_IN] = 6; + operand_chunk_lengths_[SPV_OPERAND_TYPE_GROUP_OPERATION] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS] = 2; + 
operand_chunk_lengths_[SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO] = 2; + operand_chunk_lengths_[SPV_OPERAND_TYPE_FP_FAST_MATH_MODE] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_FUNCTION_CONTROL] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_LOOP_CONTROL] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_IMAGE] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_IMAGE] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_SELECTION_CONTROL] = 4; + operand_chunk_lengths_[SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER] = 6; + operand_chunk_lengths_[SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER] = 6; + } + + uint32_t model_type() const { return model_type_; } + uint32_t model_version() const { return model_version_; } + + uint32_t opcode_chunk_length() const { return opcode_chunk_length_; } + uint32_t num_operands_chunk_length() const { + return num_operands_chunk_length_; + } + uint32_t mtf_rank_chunk_length() const { return mtf_rank_chunk_length_; } + + uint32_t u64_chunk_length() const { return u64_chunk_length_; } + uint32_t s64_chunk_length() const { return s64_chunk_length_; } + uint32_t s64_block_exponent() const { return s64_block_exponent_; } + + enum class IdFallbackStrategy { + kRuleBased = 0, + kShortDescriptor, + }; + + IdFallbackStrategy id_fallback_strategy() const { + return id_fallback_strategy_; + } + + // Returns a codec for common opcode_and_num_operands words for the given + // previous opcode. May return nullptr if the codec doesn't exist. + const HuffmanCodec* GetOpcodeAndNumOperandsMarkovHuffmanCodec( + uint32_t prev_opcode) const { + if (prev_opcode == SpvOpNop) + return opcode_and_num_operands_huffman_codec_.get(); + + const auto it = + opcode_and_num_operands_markov_huffman_codecs_.find(prev_opcode); + if (it == opcode_and_num_operands_markov_huffman_codecs_.end()) + return nullptr; + return it->second.get(); + } + + // Returns a codec for common non-id words used for given operand slot. 
+ // Operand slot is defined by the opcode and the operand index. + // May return nullptr if the codec doesn't exist. + const HuffmanCodec* GetNonIdWordHuffmanCodec( + uint32_t opcode, uint32_t operand_index) const { + const auto it = non_id_word_huffman_codecs_.find( + std::pair(opcode, operand_index)); + if (it == non_id_word_huffman_codecs_.end()) return nullptr; + return it->second.get(); + } + + // Returns a codec for common id descriptors used for given operand slot. + // Operand slot is defined by the opcode and the operand index. + // May return nullptr if the codec doesn't exist. + const HuffmanCodec* GetIdDescriptorHuffmanCodec( + uint32_t opcode, uint32_t operand_index) const { + const auto it = id_descriptor_huffman_codecs_.find( + std::pair(opcode, operand_index)); + if (it == id_descriptor_huffman_codecs_.end()) return nullptr; + return it->second.get(); + } + + // Returns a codec for common strings used by the given opcode. + // The codec is looked up by the opcode alone; no operand index is used. + // May return nullptr if the codec doesn't exist. + const HuffmanCodec* GetLiteralStringHuffmanCodec( + uint32_t opcode) const { + const auto it = literal_string_huffman_codecs_.find(opcode); + if (it == literal_string_huffman_codecs_.end()) return nullptr; + return it->second.get(); + } + + // Checks if |descriptor| has a coding scheme in any of + // id_descriptor_huffman_codecs_. + bool DescriptorHasCodingScheme(uint32_t descriptor) const { + return descriptors_with_coding_scheme_.count(descriptor); + } + + // Checks if any descriptor has a coding scheme. + bool AnyDescriptorHasCodingScheme() const { + return !descriptors_with_coding_scheme_.empty(); + } + + // Returns chunk length used for variable length encoding of spirv operand + // words. + uint32_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) const { + return operand_chunk_lengths_.at(static_cast(type)); + } + + // Sets model type.
+ void SetModelType(uint32_t in_model_type) { model_type_ = in_model_type; } + + // Sets model version. + void SetModelVersion(uint32_t in_model_version) { + model_version_ = in_model_version; + } + + // Returns value used by Huffman codecs as a signal that a value is not in the + // coding table. + static uint64_t GetMarkvNoneOfTheAbove() { + // Magic number. + return 1111111111111111111; + } + + MarkvModel(const MarkvModel&) = delete; + const MarkvModel& operator=(const MarkvModel&) = delete; + + protected: + // Huffman codec for base-rate of opcode_and_num_operands. + std::unique_ptr> + opcode_and_num_operands_huffman_codec_; + + // Huffman codecs for opcode_and_num_operands. The map key is previous opcode. + std::map>> + opcode_and_num_operands_markov_huffman_codecs_; + + // Huffman codecs for non-id single-word operand values. + // The map key is pair (opcode, operand_index). + std::map, + std::unique_ptr>> + non_id_word_huffman_codecs_; + + // Huffman codecs for id descriptors. The map key is pair + // (opcode, operand_index). + std::map, + std::unique_ptr>> + id_descriptor_huffman_codecs_; + + // Set of all descriptors which have a coding scheme in any of + // id_descriptor_huffman_codecs_. + std::unordered_set descriptors_with_coding_scheme_; + + // Huffman codecs for literal strings. The map key is the opcode of the + // current instruction. This assumes that there is no more than one literal + // string operand per instruction, but would still work even if this is not + // the case. Names and debug information strings are not collected. + std::map>> + literal_string_huffman_codecs_; + + // Chunk lengths used for variable width encoding of operands (index is + // spv_operand_type of the operand).
+ std::vector operand_chunk_lengths_; + + uint32_t opcode_chunk_length_ = 7; + uint32_t num_operands_chunk_length_ = 3; + uint32_t mtf_rank_chunk_length_ = 5; + + uint32_t u64_chunk_length_ = 8; + uint32_t s64_chunk_length_ = 8; + uint32_t s64_block_exponent_ = 10; + + IdFallbackStrategy id_fallback_strategy_ = + IdFallbackStrategy::kShortDescriptor; + + uint32_t model_type_ = 0; + uint32_t model_version_ = 0; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MARKV_MODEL_H_ diff --git a/source/comp/move_to_front.cpp b/source/comp/move_to_front.cpp new file mode 100644 index 000000000..9d35a3f5b --- /dev/null +++ b/source/comp/move_to_front.cpp @@ -0,0 +1,456 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/comp/move_to_front.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace spvtools { +namespace comp { + +bool MoveToFront::Insert(uint32_t value) { + auto it = value_to_node_.find(value); + if (it != value_to_node_.end() && IsInTree(it->second)) return false; + + const uint32_t old_size = GetSize(); + (void)old_size; + + InsertNode(CreateNode(next_timestamp_++, value)); + + last_accessed_value_ = value; + last_accessed_value_valid_ = true; + + assert(value_to_node_.count(value)); + assert(old_size + 1 == GetSize()); + return true; +} + +bool MoveToFront::Remove(uint32_t value) { + auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) return false; + + if (!IsInTree(it->second)) return false; + + if (last_accessed_value_ == value) last_accessed_value_valid_ = false; + + const uint32_t orphan = RemoveNode(it->second); + (void)orphan; + // The node of |value| is still alive but it's orphaned now. Can still be + // reused later. + assert(!IsInTree(orphan)); + assert(ValueOf(orphan) == value); + return true; +} + +bool MoveToFront::RankFromValue(uint32_t value, uint32_t* rank) { + if (last_accessed_value_valid_ && last_accessed_value_ == value) { + *rank = 1; + return true; + } + + const uint32_t old_size = GetSize(); + if (old_size == 1) { + if (ValueOf(root_) == value) { + *rank = 1; + return true; + } else { + return false; + } + } + + const auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) { + return false; + } + + uint32_t target = it->second; + + if (!IsInTree(target)) { + return false; + } + + uint32_t node = target; + *rank = 1 + SizeOf(LeftOf(node)); + while (node) { + if (IsRightChild(node)) *rank += 1 + SizeOf(LeftOf(ParentOf(node))); + node = ParentOf(node); + } + + // Don't update timestamp if the node has rank 1. + if (*rank != 1) { + // Update timestamp and reposition the node. 
+ target = RemoveNode(target); + assert(ValueOf(target) == value); + assert(old_size == GetSize() + 1); + MutableTimestampOf(target) = next_timestamp_++; + InsertNode(target); + assert(old_size == GetSize()); + } + + last_accessed_value_ = value; + last_accessed_value_valid_ = true; + return true; +} + +bool MoveToFront::HasValue(uint32_t value) const { + const auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) { + return false; + } + + return IsInTree(it->second); +} + +bool MoveToFront::Promote(uint32_t value) { + if (last_accessed_value_valid_ && last_accessed_value_ == value) { + return true; + } + + const uint32_t old_size = GetSize(); + if (old_size == 1) return ValueOf(root_) == value; + + const auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) { + return false; + } + + uint32_t target = it->second; + + if (!IsInTree(target)) { + return false; + } + + // Update timestamp and reposition the node. + target = RemoveNode(target); + assert(ValueOf(target) == value); + assert(old_size == GetSize() + 1); + + MutableTimestampOf(target) = next_timestamp_++; + InsertNode(target); + assert(old_size == GetSize()); + + last_accessed_value_ = value; + last_accessed_value_valid_ = true; + return true; +} + +bool MoveToFront::ValueFromRank(uint32_t rank, uint32_t* value) { + if (last_accessed_value_valid_ && rank == 1) { + *value = last_accessed_value_; + return true; + } + + const uint32_t old_size = GetSize(); + if (rank <= 0 || rank > old_size) { + return false; + } + + if (old_size == 1) { + *value = ValueOf(root_); + return true; + } + + const bool update_timestamp = (rank != 1); + + uint32_t node = root_; + while (node) { + const uint32_t left_subtree_num_nodes = SizeOf(LeftOf(node)); + if (rank == left_subtree_num_nodes + 1) { + // This is the node we are looking for. + // Don't update timestamp if the node has rank 1. 
+ if (update_timestamp) { + node = RemoveNode(node); + assert(old_size == GetSize() + 1); + MutableTimestampOf(node) = next_timestamp_++; + InsertNode(node); + assert(old_size == GetSize()); + } + *value = ValueOf(node); + last_accessed_value_ = *value; + last_accessed_value_valid_ = true; + return true; + } + + if (rank < left_subtree_num_nodes + 1) { + // Descend into the left subtree. The rank is still valid. + node = LeftOf(node); + } else { + // Descend into the right subtree. We leave behind the left subtree and + // the current node, adjust the |rank| accordingly. + rank -= left_subtree_num_nodes + 1; + node = RightOf(node); + } + } + + assert(0); + return false; +} + +uint32_t MoveToFront::CreateNode(uint32_t timestamp, uint32_t value) { + uint32_t handle = static_cast(nodes_.size()); + const auto result = value_to_node_.emplace(value, handle); + if (result.second) { + // Create new node. + nodes_.emplace_back(Node()); + Node& node = nodes_.back(); + node.timestamp = timestamp; + node.value = value; + node.size = 1; + // Non-NIL nodes start with height 1 because their NIL children are + // leaves. + node.height = 1; + } else { + // Reuse old node. + handle = result.first->second; + assert(!IsInTree(handle)); + assert(ValueOf(handle) == value); + assert(SizeOf(handle) == 1); + assert(HeightOf(handle) == 1); + MutableTimestampOf(handle) = timestamp; + } + + return handle; +} + +void MoveToFront::InsertNode(uint32_t node) { + assert(!IsInTree(node)); + assert(SizeOf(node) == 1); + assert(HeightOf(node) == 1); + assert(TimestampOf(node)); + + if (!root_) { + root_ = node; + return; + } + + uint32_t iter = root_; + uint32_t parent = 0; + + // Will determine if |node| will become the right or left child after + // insertion (but before balancing). + bool right_child = true; + + // Find the node which will become |node|'s parent after insertion + // (but before balancing). 
+ while (iter) { + parent = iter; + assert(TimestampOf(iter) != TimestampOf(node)); + right_child = TimestampOf(iter) > TimestampOf(node); + iter = right_child ? RightOf(iter) : LeftOf(iter); + } + + assert(parent); + + // Connect node and parent. + MutableParentOf(node) = parent; + if (right_child) + MutableRightOf(parent) = node; + else + MutableLeftOf(parent) = node; + + // Insertion is finished. Start the balancing process. + bool needs_rebalancing = true; + parent = ParentOf(node); + + while (parent) { + UpdateNode(parent); + + if (needs_rebalancing) { + const int parent_balance = BalanceOf(parent); + + if (RightOf(parent) == node) { + // Added node to the right subtree. + if (parent_balance > 1) { + // Parent is right heavy, rotate left. + if (BalanceOf(node) < 0) RotateRight(node); + parent = RotateLeft(parent); + } else if (parent_balance == 0 || parent_balance == -1) { + // Parent is balanced or left heavy, no need to balance further. + needs_rebalancing = false; + } + } else { + // Added node to the left subtree. + if (parent_balance < -1) { + // Parent is left heavy, rotate right. + if (BalanceOf(node) > 0) RotateLeft(node); + parent = RotateRight(parent); + } else if (parent_balance == 0 || parent_balance == 1) { + // Parent is balanced or right heavy, no need to balance further. + needs_rebalancing = false; + } + } + } + + assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1)); + + node = parent; + parent = ParentOf(parent); + } +} + +uint32_t MoveToFront::RemoveNode(uint32_t node) { + if (LeftOf(node) && RightOf(node)) { + // If |node| has two children, then use another node as scapegoat and swap + // their contents. We pick the scapegoat on the side of the tree which has + // more nodes. + const uint32_t scapegoat = SizeOf(LeftOf(node)) >= SizeOf(RightOf(node)) + ? 
RightestDescendantOf(LeftOf(node)) + : LeftestDescendantOf(RightOf(node)); + assert(scapegoat); + std::swap(MutableValueOf(node), MutableValueOf(scapegoat)); + std::swap(MutableTimestampOf(node), MutableTimestampOf(scapegoat)); + value_to_node_[ValueOf(node)] = node; + value_to_node_[ValueOf(scapegoat)] = scapegoat; + node = scapegoat; + } + + // |node| may have only one child at this point. + assert(!RightOf(node) || !LeftOf(node)); + + uint32_t parent = ParentOf(node); + uint32_t child = RightOf(node) ? RightOf(node) : LeftOf(node); + + // Orphan |node| and reconnect parent and child. + if (child) MutableParentOf(child) = parent; + + if (parent) { + if (LeftOf(parent) == node) + MutableLeftOf(parent) = child; + else + MutableRightOf(parent) = child; + } + + MutableParentOf(node) = 0; + MutableLeftOf(node) = 0; + MutableRightOf(node) = 0; + UpdateNode(node); + const uint32_t orphan = node; + + if (root_ == node) root_ = child; + + // Removal is finished. Start the balancing process. + bool needs_rebalancing = true; + node = child; + + while (parent) { + UpdateNode(parent); + + if (needs_rebalancing) { + const int parent_balance = BalanceOf(parent); + + if (parent_balance == 1 || parent_balance == -1) { + // The height of the subtree was not changed. + needs_rebalancing = false; + } else { + if (RightOf(parent) == node) { + // Removed node from the right subtree. + if (parent_balance < -1) { + // Parent is left heavy, rotate right. + const uint32_t sibling = LeftOf(parent); + if (BalanceOf(sibling) > 0) RotateLeft(sibling); + parent = RotateRight(parent); + } + } else { + // Removed node from the left subtree. + if (parent_balance > 1) { + // Parent is right heavy, rotate left. 
+ const uint32_t sibling = RightOf(parent); + if (BalanceOf(sibling) < 0) RotateRight(sibling); + parent = RotateLeft(parent); + } + } + } + } + + assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1)); + + node = parent; + parent = ParentOf(parent); + } + + return orphan; +} + +uint32_t MoveToFront::RotateLeft(const uint32_t node) { + const uint32_t pivot = RightOf(node); + assert(pivot); + + // LeftOf(pivot) gets attached to node in place of pivot. + MutableRightOf(node) = LeftOf(pivot); + if (RightOf(node)) MutableParentOf(RightOf(node)) = node; + + // Pivot gets attached to ParentOf(node) in place of node. + MutableParentOf(pivot) = ParentOf(node); + if (!ParentOf(node)) + root_ = pivot; + else if (IsLeftChild(node)) + MutableLeftOf(ParentOf(node)) = pivot; + else + MutableRightOf(ParentOf(node)) = pivot; + + // Node is child of pivot. + MutableLeftOf(pivot) = node; + MutableParentOf(node) = pivot; + + // Update both node and pivot. Pivot is the new parent of node, so node should + // be updated first. + UpdateNode(node); + UpdateNode(pivot); + + return pivot; +} + +uint32_t MoveToFront::RotateRight(const uint32_t node) { + const uint32_t pivot = LeftOf(node); + assert(pivot); + + // RightOf(pivot) gets attached to node in place of pivot. + MutableLeftOf(node) = RightOf(pivot); + if (LeftOf(node)) MutableParentOf(LeftOf(node)) = node; + + // Pivot gets attached to ParentOf(node) in place of node. + MutableParentOf(pivot) = ParentOf(node); + if (!ParentOf(node)) + root_ = pivot; + else if (IsLeftChild(node)) + MutableLeftOf(ParentOf(node)) = pivot; + else + MutableRightOf(ParentOf(node)) = pivot; + + // Node is child of pivot. + MutableRightOf(pivot) = node; + MutableParentOf(node) = pivot; + + // Update both node and pivot. Pivot is the new parent of node, so node should + // be updated first. 
+ UpdateNode(node); + UpdateNode(pivot); + + return pivot; +} + +void MoveToFront::UpdateNode(uint32_t node) { + MutableSizeOf(node) = 1 + SizeOf(LeftOf(node)) + SizeOf(RightOf(node)); + MutableHeightOf(node) = + 1 + std::max(HeightOf(LeftOf(node)), HeightOf(RightOf(node))); +} + +} // namespace comp +} // namespace spvtools diff --git a/source/comp/move_to_front.h b/source/comp/move_to_front.h new file mode 100644 index 000000000..8752194ec --- /dev/null +++ b/source/comp/move_to_front.h @@ -0,0 +1,384 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_COMP_MOVE_TO_FRONT_H_ +#define SOURCE_COMP_MOVE_TO_FRONT_H_ + +#include +#include +#include +#include +#include +#include + +namespace spvtools { +namespace comp { + +// Log(n) move-to-front implementation. Implements the following functions: +// Insert - pushes value to the front of the mtf sequence +// (only unique values allowed). +// Remove - remove value from the sequence. +// ValueFromRank - access value by its 1-indexed rank in the sequence. +// RankFromValue - get the rank of the given value in the sequence. +// Accessing a value with ValueFromRank or RankFromValue moves the value to the +// front of the sequence (rank of 1). +// +// The implementation is based on an AVL-based order statistic tree. 
The tree +// is ordered by timestamps issued when values are inserted or accessed (recent +// values go to the left side of the tree, old values are gradually rotated to +// the right side). +// +// Terminology +// rank: 1-indexed rank showing how recently the value was inserted or accessed. +// node: handle used internally to access node data. +// size: size of the subtree of a node (including the node). +// height: distance from a node to the farthest leaf. +class MoveToFront { + public: + explicit MoveToFront(size_t reserve_capacity = 4) { + nodes_.reserve(reserve_capacity); + + // Create NIL node. + nodes_.emplace_back(Node()); + } + + virtual ~MoveToFront() = default; + + // Inserts value in the move-to-front sequence. Does nothing if the value is + // already in the sequence. Returns true if insertion was successful. + // The inserted value is placed at the front of the sequence (rank 1). + bool Insert(uint32_t value); + + // Removes value from move-to-front sequence. Returns false iff the value + // was not found. + bool Remove(uint32_t value); + + // Computes 1-indexed rank of value in the move-to-front sequence and moves + // the value to the front. Example: + // Before the call: 4 8 2 1 7 + // RankFromValue(8) returns 2 + // After the call: 8 4 2 1 7 + // Returns true iff the value was found in the sequence. + bool RankFromValue(uint32_t value, uint32_t* rank); + + // Returns value corresponding to a 1-indexed rank in the move-to-front + // sequence and moves the value to the front. Example: + // Before the call: 4 8 2 1 7 + // ValueFromRank(2) returns 8 + // After the call: 8 4 2 1 7 + // Returns true iff the rank is within bounds [1, GetSize()]. + bool ValueFromRank(uint32_t rank, uint32_t* value); + + // Moves the value to the front of the sequence. + // Returns false iff value is not in the sequence. + bool Promote(uint32_t value); + + // Returns true iff the move-to-front sequence contains the value. 
+ bool HasValue(uint32_t value) const; + + // Returns the number of elements in the move-to-front sequence. + uint32_t GetSize() const { return SizeOf(root_); } + + protected: + // Internal tree data structure uses handles instead of pointers. Leaves and + // root parent reference a singleton under handle 0. Although dereferencing + // a null pointer is not possible, inappropriate access to handle 0 would + // cause an assertion. Handles are not garbage collected if value was + // removed + // with Remove(). But handles are recycled when a node is + // repositioned. + + // Internal tree data structure node. + struct Node { + // Timestamp from a logical clock which updates every time the element is + // accessed through ValueFromRank or RankFromValue. + uint32_t timestamp = 0; + // The size of the node's subtree, including the node. + // SizeOf(LeftOf(node)) + SizeOf(RightOf(node)) + 1. + uint32_t size = 0; + // Handles to connected nodes. + uint32_t left = 0; + uint32_t right = 0; + uint32_t parent = 0; + // Distance to the farthest leaf. + // Leaves have height 0, real nodes at least 1. + uint32_t height = 0; + // Stored value. + uint32_t value = 0; + }; + + // Creates node and sets correct values. Non-NIL nodes should be created only + // through this function. If the node with this value has been created + // previously + // and since orphaned, reuses the old node instead of creating a new one. + uint32_t CreateNode(uint32_t timestamp, uint32_t value); + + // Node accessor methods. Naming is designed to be similar to natural + // language as these functions tend to be used in sequences, for example: + // ParentOf(LeftestDescendantOf(RightOf(node))) + + // Returns value of the node referenced by |node|. + uint32_t ValueOf(uint32_t node) const { return nodes_.at(node).value; } + + // Returns left child of |node|. + uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; } + + // Returns right child of |node|.
+ uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; } + + // Returns parent of |node|. + uint32_t ParentOf(uint32_t node) const { return nodes_.at(node).parent; } + + // Returns timestamp of |node|. + uint32_t TimestampOf(uint32_t node) const { + assert(node); + return nodes_.at(node).timestamp; + } + + // Returns size of |node|. + uint32_t SizeOf(uint32_t node) const { return nodes_.at(node).size; } + + // Returns height of |node|. + uint32_t HeightOf(uint32_t node) const { return nodes_.at(node).height; } + + // Returns mutable reference to value of |node|. + uint32_t& MutableValueOf(uint32_t node) { + assert(node); + return nodes_.at(node).value; + } + + // Returns mutable reference to handle of left child of |node|. + uint32_t& MutableLeftOf(uint32_t node) { + assert(node); + return nodes_.at(node).left; + } + + // Returns mutable reference to handle of right child of |node|. + uint32_t& MutableRightOf(uint32_t node) { + assert(node); + return nodes_.at(node).right; + } + + // Returns mutable reference to handle of parent of |node|. + uint32_t& MutableParentOf(uint32_t node) { + assert(node); + return nodes_.at(node).parent; + } + + // Returns mutable reference to timestamp of |node|. + uint32_t& MutableTimestampOf(uint32_t node) { + assert(node); + return nodes_.at(node).timestamp; + } + + // Returns mutable reference to size of |node|. + uint32_t& MutableSizeOf(uint32_t node) { + assert(node); + return nodes_.at(node).size; + } + + // Returns mutable reference to height of |node|. + uint32_t& MutableHeightOf(uint32_t node) { + assert(node); + return nodes_.at(node).height; + } + + // Returns true iff |node| is left child of its parent. + bool IsLeftChild(uint32_t node) const { + assert(node); + return LeftOf(ParentOf(node)) == node; + } + + // Returns true iff |node| is right child of its parent. 
+ bool IsRightChild(uint32_t node) const { + assert(node); + return RightOf(ParentOf(node)) == node; + } + + // Returns true iff |node| has no relatives. + bool IsOrphan(uint32_t node) const { + assert(node); + return !ParentOf(node) && !LeftOf(node) && !RightOf(node); + } + + // Returns true iff |node| is in the tree. + bool IsInTree(uint32_t node) const { + assert(node); + return node == root_ || !IsOrphan(node); + } + + // Returns the height difference between right and left subtrees. + int BalanceOf(uint32_t node) const { + return int(HeightOf(RightOf(node))) - int(HeightOf(LeftOf(node))); + } + + // Updates size and height of the node, assuming that the children have + // correct values. + void UpdateNode(uint32_t node); + + // Returns the most LeftOf(LeftOf(... descendent which is not leaf. + uint32_t LeftestDescendantOf(uint32_t node) const { + uint32_t parent = 0; + while (node) { + parent = node; + node = LeftOf(node); + } + return parent; + } + + // Returns the most RightOf(RightOf(... descendent which is not leaf. + uint32_t RightestDescendantOf(uint32_t node) const { + uint32_t parent = 0; + while (node) { + parent = node; + node = RightOf(node); + } + return parent; + } + + // Inserts node in the tree. The node must be an orphan. + void InsertNode(uint32_t node); + + // Removes node from the tree. May change value_to_node_ if removal uses a + // scapegoat. Returns the removed (orphaned) handle for recycling. The + // returned handle may not be equal to |node| if scapegoat was used. + uint32_t RemoveNode(uint32_t node); + + // Rotates |node| left, reassigns all connections and returns the node + // which takes place of the |node|. + uint32_t RotateLeft(const uint32_t node); + + // Rotates |node| right, reassigns all connections and returns the node + // which takes place of the |node|. + uint32_t RotateRight(const uint32_t node); + + // Root node handle. The tree is empty if root_ is 0. 
+ uint32_t root_ = 0; + + // Incremented counter for the next timestamp. + uint32_t next_timestamp_ = 1; + + // Holds all tree nodes. Indices of this vector are node handles. + std::vector nodes_; + + // Maps ids to node handles. + std::unordered_map value_to_node_; + + // Cache for the last accessed value in the sequence. + uint32_t last_accessed_value_ = 0; + bool last_accessed_value_valid_ = false; +}; + +class MultiMoveToFront { + public: + // Inserts |value| to sequence with handle |mtf|. + // Returns false if |mtf| already has |value|. + bool Insert(uint64_t mtf, uint32_t value) { + if (GetMtf(mtf).Insert(value)) { + val_to_mtfs_[value].insert(mtf); + return true; + } + return false; + } + + // Removes |value| from sequence |mtf|; returns false if it was not there. + // The sanity check below only reads val_to_mtfs_ (no operator[] insert). + bool Remove(uint64_t mtf, uint32_t value) { + if (GetMtf(mtf).Remove(value)) { + val_to_mtfs_[value].erase(mtf); + return true; + } + assert(!val_to_mtfs_.count(value) || !val_to_mtfs_.at(value).count(mtf)); + return false; + } + + // Removes |value| from all sequences which have it. + void RemoveFromAll(uint32_t value) { + auto it = val_to_mtfs_.find(value); + if (it == val_to_mtfs_.end()) return; + + auto& mtfs_containing_value = it->second; + for (uint64_t mtf : mtfs_containing_value) { + GetMtf(mtf).Remove(value); + } + + val_to_mtfs_.erase(value); + } + + // Computes rank of |value| in sequence |mtf|. + // Returns false if |mtf| doesn't have |value|. + bool RankFromValue(uint64_t mtf, uint32_t value, uint32_t* rank) { + return GetMtf(mtf).RankFromValue(value, rank); + } + + // Finds |value| with |rank| in sequence |mtf|. + // Returns false if |rank| is out of bounds. + bool ValueFromRank(uint64_t mtf, uint32_t rank, uint32_t* value) { + return GetMtf(mtf).ValueFromRank(rank, value); + } + + // Returns size of |mtf| sequence. + uint32_t GetSize(uint64_t mtf) { return GetMtf(mtf).GetSize(); } + + // Promotes |value| in all sequences which have it.
+ void Promote(uint32_t value) { + const auto it = val_to_mtfs_.find(value); + if (it == val_to_mtfs_.end()) return; + + const auto& mtfs_containing_value = it->second; + for (uint64_t mtf : mtfs_containing_value) { + GetMtf(mtf).Promote(value); + } + } + + // Inserts |value| in sequence |mtf| or promotes if it's already there. + void InsertOrPromote(uint64_t mtf, uint32_t value) { + if (!Insert(mtf, value)) { + GetMtf(mtf).Promote(value); + } + } + + // Returns if |mtf| sequence has |value|. + bool HasValue(uint64_t mtf, uint32_t value) { + return GetMtf(mtf).HasValue(value); + } + + private: + // Returns actual MoveToFront object corresponding to |handle|. + // As multiple operations are often performed consecutively for the same + // sequence, the last returned value is cached. + MoveToFront& GetMtf(uint64_t handle) { + if (!cached_mtf_ || cached_handle_ != handle) { + cached_handle_ = handle; + cached_mtf_ = &mtfs_[handle]; + } + + return *cached_mtf_; + } + + // Container holding MoveToFront objects. Map key is sequence handle. + std::map mtfs_; + + // Container mapping value to sequences which contain that value. + std::unordered_map> val_to_mtfs_; + + // Cache for the last accessed sequence. + uint64_t cached_handle_ = 0; + MoveToFront* cached_mtf_ = nullptr; +}; + +} // namespace comp +} // namespace spvtools + +#endif // SOURCE_COMP_MOVE_TO_FRONT_H_ diff --git a/source/diagnostic.cpp b/source/diagnostic.cpp new file mode 100644 index 000000000..edc27c8fd --- /dev/null +++ b/source/diagnostic.cpp @@ -0,0 +1,193 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/diagnostic.h" + +#include +#include +#include +#include +#include + +#include "source/table.h" + +// Diagnostic API + +spv_diagnostic spvDiagnosticCreate(const spv_position position, + const char* message) { + spv_diagnostic diagnostic = new spv_diagnostic_t; + if (!diagnostic) return nullptr; + size_t length = strlen(message) + 1; + diagnostic->error = new char[length]; + if (!diagnostic->error) { + delete diagnostic; + return nullptr; + } + diagnostic->position = *position; + diagnostic->isTextSource = false; + memset(diagnostic->error, 0, length); + strncpy(diagnostic->error, message, length); + return diagnostic; +} + +void spvDiagnosticDestroy(spv_diagnostic diagnostic) { + if (!diagnostic) return; + delete[] diagnostic->error; + delete diagnostic; +} + +spv_result_t spvDiagnosticPrint(const spv_diagnostic diagnostic) { + if (!diagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; + + if (diagnostic->isTextSource) { + // NOTE: This is a text position + // NOTE: add 1 to the line as editors start at line 1, we are counting new + // line characters to start at line 0 + std::cerr << "error: " << diagnostic->position.line + 1 << ": " + << diagnostic->position.column + 1 << ": " << diagnostic->error + << "\n"; + return SPV_SUCCESS; + } + + // NOTE: Assume this is a binary position + std::cerr << "error: "; + if (diagnostic->position.index > 0) + std::cerr << diagnostic->position.index << ": "; + std::cerr << diagnostic->error << "\n"; + return SPV_SUCCESS; +} + +namespace spvtools { + 
+DiagnosticStream::DiagnosticStream(DiagnosticStream&& other)
+    : stream_(),
+      position_(other.position_),
+      consumer_(other.consumer_),
+      disassembled_instruction_(std::move(other.disassembled_instruction_)),
+      error_(other.error_) {
+  // Prevent the other object from emitting output during destruction.
+  other.error_ = SPV_FAILED_MATCH;
+  // Some platforms are missing support for std::ostringstream functionality,
+  // including: move constructor, swap method. Either would have been a
+  // better choice than copying the string.
+  stream_ << other.stream_.str();
+}
+
+DiagnosticStream::~DiagnosticStream() {
+  if (error_ != SPV_FAILED_MATCH && consumer_ != nullptr) {
+    auto level = SPV_MSG_ERROR;  // Used for codes not listed below.
+    switch (error_) {
+      case SPV_SUCCESS:
+      case SPV_REQUESTED_TERMINATION:  // Essentially success.
+        level = SPV_MSG_INFO;
+        break;
+      case SPV_WARNING:
+        level = SPV_MSG_WARNING;
+        break;
+      case SPV_UNSUPPORTED:
+      case SPV_ERROR_INTERNAL:
+      case SPV_ERROR_INVALID_TABLE:
+        level = SPV_MSG_INTERNAL_ERROR;
+        break;
+      case SPV_ERROR_OUT_OF_MEMORY:
+        level = SPV_MSG_FATAL;
+        break;
+      default:
+        break;
+    }
+    if (disassembled_instruction_.size() > 0)
+      stream_ << std::endl << " " << disassembled_instruction_ << std::endl;
+
+    consumer_(level, "input", position_, stream_.str().c_str());
+  }
+}
+
+void UseDiagnosticAsMessageConsumer(spv_context context,
+                                    spv_diagnostic* diagnostic) {
+  assert(diagnostic && *diagnostic == nullptr);
+
+  auto create_diagnostic = [diagnostic](spv_message_level_t, const char*,
+                                        const spv_position_t& position,
+                                        const char* message) {
+    auto p = position;  // Local copy whose address is passed below.
+    spvDiagnosticDestroy(*diagnostic);  // Avoid memory leak.
+    *diagnostic = spvDiagnosticCreate(&p, message);
+  };
+  SetContextMessageConsumer(context, std::move(create_diagnostic));
+}
+
+std::string spvResultToString(spv_result_t res) {
+  std::string out;
+  switch (res) {
+    case SPV_SUCCESS:
+      out = "SPV_SUCCESS";
+      break;
+    case SPV_UNSUPPORTED:
+      out = "SPV_UNSUPPORTED";
+      break;
+    case SPV_END_OF_STREAM:
+      out = "SPV_END_OF_STREAM";
+      break;
+    case SPV_WARNING:
+      out = "SPV_WARNING";
+      break;
+    case SPV_FAILED_MATCH:
+      out = "SPV_FAILED_MATCH";
+      break;
+    case SPV_REQUESTED_TERMINATION:
+      out = "SPV_REQUESTED_TERMINATION";
+      break;
+    case SPV_ERROR_INTERNAL:
+      out = "SPV_ERROR_INTERNAL";
+      break;
+    case SPV_ERROR_OUT_OF_MEMORY:
+      out = "SPV_ERROR_OUT_OF_MEMORY";
+      break;
+    case SPV_ERROR_INVALID_POINTER:
+      out = "SPV_ERROR_INVALID_POINTER";
+      break;
+    case SPV_ERROR_INVALID_BINARY:
+      out = "SPV_ERROR_INVALID_BINARY";
+      break;
+    case SPV_ERROR_INVALID_TEXT:
+      out = "SPV_ERROR_INVALID_TEXT";
+      break;
+    case SPV_ERROR_INVALID_TABLE:
+      out = "SPV_ERROR_INVALID_TABLE";
+      break;
+    case SPV_ERROR_INVALID_VALUE:
+      out = "SPV_ERROR_INVALID_VALUE";
+      break;
+    case SPV_ERROR_INVALID_DIAGNOSTIC:
+      out = "SPV_ERROR_INVALID_DIAGNOSTIC";
+      break;
+    case SPV_ERROR_INVALID_LOOKUP:
+      out = "SPV_ERROR_INVALID_LOOKUP";
+      break;
+    case SPV_ERROR_INVALID_ID:
+      out = "SPV_ERROR_INVALID_ID";
+      break;
+    case SPV_ERROR_INVALID_CFG:
+      out = "SPV_ERROR_INVALID_CFG";
+      break;
+    case SPV_ERROR_INVALID_LAYOUT:
+      out = "SPV_ERROR_INVALID_LAYOUT";
+      break;
+    default:
+      out = "Unknown Error";  // Any value without a case above.
+  }
+  return out;
+}
+
+}  // namespace spvtools
diff --git a/source/diagnostic.h b/source/diagnostic.h
new file mode 100644
index 000000000..22df96143
--- /dev/null
+++ b/source/diagnostic.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_DIAGNOSTIC_H_ +#define SOURCE_DIAGNOSTIC_H_ + +#include +#include + +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { + +// A DiagnosticStream remembers the current position of the input and an error +// code, and captures diagnostic messages via the left-shift operator. +// If the error code is not SPV_FAILED_MATCH, then captured messages are +// emitted during the destructor. +class DiagnosticStream { + public: + DiagnosticStream(spv_position_t position, const MessageConsumer& consumer, + const std::string& disassembled_instruction, + spv_result_t error) + : position_(position), + consumer_(consumer), + disassembled_instruction_(disassembled_instruction), + error_(error) {} + + // Creates a DiagnosticStream from an expiring DiagnosticStream. + // The new object takes the contents of the other, and prevents the + // other from emitting anything during destruction. + DiagnosticStream(DiagnosticStream&& other); + + // Destroys a DiagnosticStream. + // If its status code is something other than SPV_FAILED_MATCH + // then emit the accumulated message to the consumer. + ~DiagnosticStream(); + + // Adds the given value to the diagnostic message to be written. + template + DiagnosticStream& operator<<(const T& val) { + stream_ << val; + return *this; + } + + // Conversion operator to spv_result, returning the error code. + operator spv_result_t() { return error_; } + + private: + std::ostringstream stream_; + spv_position_t position_; + MessageConsumer consumer_; // Message consumer callback. 
+ std::string disassembled_instruction_; + spv_result_t error_; +}; + +// Changes the MessageConsumer in |context| to one that updates |diagnostic| +// with the last message received. +// +// This function expects that |diagnostic| is not nullptr and its content is a +// nullptr. +void UseDiagnosticAsMessageConsumer(spv_context context, + spv_diagnostic* diagnostic); + +std::string spvResultToString(spv_result_t res); + +} // namespace spvtools + +#endif // SOURCE_DIAGNOSTIC_H_ diff --git a/source/disassemble.cpp b/source/disassemble.cpp new file mode 100644 index 000000000..c116f5072 --- /dev/null +++ b/source/disassemble.cpp @@ -0,0 +1,478 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains a disassembler: It converts a SPIR-V binary +// to text. + +#include +#include +#include +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/disassemble.h" +#include "source/ext_inst.h" +#include "source/name_mapper.h" +#include "source/opcode.h" +#include "source/parsed_operand.h" +#include "source/print.h" +#include "source/spirv_constant.h" +#include "source/spirv_endian.h" +#include "source/util/hex_float.h" +#include "source/util/make_unique.h" +#include "spirv-tools/libspirv.h" + +namespace { + +// A Disassembler instance converts a SPIR-V binary to its assembly +// representation. 
+class Disassembler {
+ public:
+  Disassembler(const spvtools::AssemblyGrammar& grammar, uint32_t options,
+               spvtools::NameMapper name_mapper)
+      : grammar_(grammar),
+        print_(spvIsInBitfield(SPV_BINARY_TO_TEXT_OPTION_PRINT, options)),
+        color_(spvIsInBitfield(SPV_BINARY_TO_TEXT_OPTION_COLOR, options)),
+        indent_(spvIsInBitfield(SPV_BINARY_TO_TEXT_OPTION_INDENT, options)
+                    ? kStandardIndent
+                    : 0),
+        text_(),
+        out_(print_ ? out_stream() : out_stream(text_)),
+        stream_(out_.get()),
+        header_(!spvIsInBitfield(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, options)),
+        show_byte_offset_(spvIsInBitfield(
+            SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET, options)),
+        byte_offset_(0),
+        name_mapper_(std::move(name_mapper)) {}
+
+  // Emits the assembly header for the module, and sets up internal state
+  // so subsequent callbacks can handle the cases where the entire module
+  // is either big-endian or little-endian.
+  spv_result_t HandleHeader(spv_endianness_t endian, uint32_t version,
+                            uint32_t generator, uint32_t id_bound,
+                            uint32_t schema);
+  // Emits the assembly text for the given instruction.
+  spv_result_t HandleInstruction(const spv_parsed_instruction_t& inst);
+
+  // If not printing, populates text_result with the accumulated text.
+  // Returns SPV_SUCCESS on success.
+  spv_result_t SaveTextResult(spv_text* text_result) const;
+
+ private:
+  enum { kStandardIndent = 15 };
+
+  using out_stream = spvtools::out_stream;
+
+  // Emits an operand for the given instruction, where the instruction
+  // is at offset words from the start of the binary.
+  void EmitOperand(const spv_parsed_instruction_t& inst,
+                   const uint16_t operand_index);
+
+  // Emits a mask expression for the given mask word of the specified type.
+  void EmitMaskOperand(const spv_operand_type_t type, const uint32_t word);
+
+  // Resets the output color, if color is turned on.
+  void ResetColor() {
+    if (color_) out_.get() << spvtools::clr::reset{print_};
+  }
+  // Sets the output to grey, if color is turned on.
+  void SetGrey() {
+    if (color_) out_.get() << spvtools::clr::grey{print_};
+  }
+  // Sets the output to blue, if color is turned on.
+  void SetBlue() {
+    if (color_) out_.get() << spvtools::clr::blue{print_};
+  }
+  // Sets the output to yellow, if color is turned on.
+  void SetYellow() {
+    if (color_) out_.get() << spvtools::clr::yellow{print_};
+  }
+  // Sets the output to red, if color is turned on.
+  void SetRed() {
+    if (color_) out_.get() << spvtools::clr::red{print_};
+  }
+  // Sets the output to green, if color is turned on.
+  void SetGreen() {
+    if (color_) out_.get() << spvtools::clr::green{print_};
+  }
+
+  const spvtools::AssemblyGrammar& grammar_;
+  const bool print_;  // Should we also print to the standard output stream?
+  const bool color_;  // Should we print in colour?
+  const int indent_;  // How much to indent. 0 means don't indent
+  spv_endianness_t endian_;  // The detected endianness of the binary.
+  std::stringstream text_;  // Captures the text, if not printing.
+  out_stream out_;  // The Output stream. Either to text_ or standard output.
+  std::ostream& stream_;  // The output std::ostream, i.e. out_.get().
+  const bool header_;  // Should we output header as the leading comment?
+  const bool show_byte_offset_;  // Should we print byte offset, in hex?
+  size_t byte_offset_;  // The number of bytes processed so far.
+  spvtools::NameMapper name_mapper_;
+};
+
+spv_result_t Disassembler::HandleHeader(spv_endianness_t endian,
+                                        uint32_t version, uint32_t generator,
+                                        uint32_t id_bound, uint32_t schema) {
+  endian_ = endian;
+
+  if (header_) {
+    SetGrey();
+    const char* generator_tool =
+        spvGeneratorStr(SPV_GENERATOR_TOOL_PART(generator));
+    stream_ << "; SPIR-V\n"
+            << "; Version: " << SPV_SPIRV_VERSION_MAJOR_PART(version) << "."
+            << SPV_SPIRV_VERSION_MINOR_PART(version) << "\n"
+            << "; Generator: " << generator_tool;
+    // For unknown tools, print the numeric tool value.
+    if (0 == strcmp("Unknown", generator_tool)) {
+      stream_ << "(" << SPV_GENERATOR_TOOL_PART(generator) << ")";
+    }
+    // Print the miscellaneous part of the generator word on the same
+    // line as the tool name.
+    stream_ << "; " << SPV_GENERATOR_MISC_PART(generator) << "\n"
+            << "; Bound: " << id_bound << "\n"
+            << "; Schema: " << schema << "\n";
+    ResetColor();
+  }
+
+  byte_offset_ = SPV_INDEX_INSTRUCTION * sizeof(uint32_t);
+
+  return SPV_SUCCESS;
+}
+
+spv_result_t Disassembler::HandleInstruction(
+    const spv_parsed_instruction_t& inst) {
+  if (inst.result_id) {
+    SetBlue();
+    const std::string id_name = name_mapper_(inst.result_id);
+    if (indent_)
+      stream_ << std::setw(std::max(0, indent_ - 3 - int(id_name.size())));
+    stream_ << "%" << id_name;
+    ResetColor();
+    stream_ << " = ";
+  } else {
+    stream_ << std::string(indent_, ' ');
+  }
+
+  stream_ << "Op" << spvOpcodeString(static_cast<SpvOp>(inst.opcode));
+
+  for (uint16_t i = 0; i < inst.num_operands; i++) {
+    const spv_operand_type_t type = inst.operands[i].type;
+    assert(type != SPV_OPERAND_TYPE_NONE);
+    if (type == SPV_OPERAND_TYPE_RESULT_ID) continue;  // Printed before "=".
+    stream_ << " ";
+    EmitOperand(inst, i);
+  }
+
+  if (show_byte_offset_) {
+    SetGrey();
+    auto saved_flags = stream_.flags();
+    auto saved_fill = stream_.fill();
+    stream_ << " ; 0x" << std::setw(8) << std::hex << std::setfill('0')
+            << byte_offset_;
+    stream_.flags(saved_flags);
+    stream_.fill(saved_fill);
+    ResetColor();
+  }
+
+  byte_offset_ += inst.num_words * sizeof(uint32_t);
+
+  stream_ << "\n";
+  return SPV_SUCCESS;
+}
+
+void Disassembler::EmitOperand(const spv_parsed_instruction_t& inst,
+                               const uint16_t operand_index) {
+  assert(operand_index < inst.num_operands);
+  const spv_parsed_operand_t& operand = inst.operands[operand_index];
+  const uint32_t word = inst.words[operand.offset];
+  switch (operand.type) {
+    case SPV_OPERAND_TYPE_RESULT_ID:
+      assert(false && "<result-id> is not supposed to be handled here");
+      SetBlue();
+      stream_ << "%" << name_mapper_(word);
+      break;
+    case SPV_OPERAND_TYPE_ID:
+    case SPV_OPERAND_TYPE_TYPE_ID:
+    case SPV_OPERAND_TYPE_SCOPE_ID:
+    case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
+      SetYellow();
+      stream_ << "%" << name_mapper_(word);
+      break;
+    case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
+      spv_ext_inst_desc ext_inst;
+      if (grammar_.lookupExtInst(inst.ext_inst_type, word, &ext_inst))
+        assert(false && "should have caught this earlier");
+      SetRed();
+      stream_ << ext_inst->name;
+    } break;
+    case SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER: {
+      spv_opcode_desc opcode_desc;
+      if (grammar_.lookupOpcode(SpvOp(word), &opcode_desc))
+        assert(false && "should have caught this earlier");
+      SetRed();
+      stream_ << opcode_desc->name;
+    } break;
+    case SPV_OPERAND_TYPE_LITERAL_INTEGER:
+    case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
+      SetRed();
+      spvtools::EmitNumericLiteral(&stream_, inst, operand);
+      ResetColor();
+    } break;
+    case SPV_OPERAND_TYPE_LITERAL_STRING: {
+      stream_ << "\"";
+      SetGreen();
+      // Strings are always little-endian, and null-terminated.
+      // Write out the characters, escaping as needed, and without copying
+      // the entire string.
+      auto c_str = reinterpret_cast<const char*>(inst.words + operand.offset);
+      for (auto p = c_str; *p; ++p) {
+        if (*p == '"' || *p == '\\') stream_ << '\\';
+        stream_ << *p;
+      }
+      ResetColor();
+      stream_ << '"';
+    } break;
+    case SPV_OPERAND_TYPE_CAPABILITY:
+    case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
+    case SPV_OPERAND_TYPE_EXECUTION_MODEL:
+    case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
+    case SPV_OPERAND_TYPE_MEMORY_MODEL:
+    case SPV_OPERAND_TYPE_EXECUTION_MODE:
+    case SPV_OPERAND_TYPE_STORAGE_CLASS:
+    case SPV_OPERAND_TYPE_DIMENSIONALITY:
+    case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
+    case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
+    case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
+    case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
+    case SPV_OPERAND_TYPE_LINKAGE_TYPE:
+    case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
+    case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
+    case SPV_OPERAND_TYPE_DECORATION:
+    case SPV_OPERAND_TYPE_BUILT_IN:
+    case SPV_OPERAND_TYPE_GROUP_OPERATION:
+    case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
+    case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
+    case SPV_OPERAND_TYPE_DEBUG_BASE_TYPE_ATTRIBUTE_ENCODING:
+    case SPV_OPERAND_TYPE_DEBUG_COMPOSITE_TYPE:
+    case SPV_OPERAND_TYPE_DEBUG_TYPE_QUALIFIER:
+    case SPV_OPERAND_TYPE_DEBUG_OPERATION: {
+      spv_operand_desc entry;
+      if (grammar_.lookupOperand(operand.type, word, &entry))
+        assert(false && "should have caught this earlier");
+      stream_ << entry->name;
+    } break;
+    case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
+    case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
+    case SPV_OPERAND_TYPE_LOOP_CONTROL:
+    case SPV_OPERAND_TYPE_IMAGE:
+    case SPV_OPERAND_TYPE_MEMORY_ACCESS:
+    case SPV_OPERAND_TYPE_SELECTION_CONTROL:
+    case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
+      EmitMaskOperand(operand.type, word);
+      break;
+    default:
+      assert(false && "unhandled or invalid case");
+  }
+  ResetColor();
+}
+
+void Disassembler::EmitMaskOperand(const spv_operand_type_t type,
+                                   const uint32_t word) {
+  // Scan the mask from least significant bit to most significant bit. For each
+  // set bit, emit the name of that bit. Separate multiple names with '|'.
+  uint32_t remaining_word = word;
+  uint32_t mask;
+  int num_emitted = 0;
+  for (mask = 1; remaining_word; mask <<= 1) {
+    if (remaining_word & mask) {
+      remaining_word ^= mask;  // Clear the bit so the loop can terminate.
+      spv_operand_desc entry;
+      if (grammar_.lookupOperand(type, mask, &entry))
+        assert(false && "should have caught this earlier");
+      if (num_emitted) stream_ << "|";
+      stream_ << entry->name;
+      num_emitted++;
+    }
+  }
+  if (!num_emitted) {
+    // An operand value of 0 was provided, so represent it by the name
+    // of the 0 value. In many cases, that's "None".
+    spv_operand_desc entry;
+    if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry))
+      stream_ << entry->name;
+  }
+}
+
+spv_result_t Disassembler::SaveTextResult(spv_text* text_result) const {
+  if (!print_) {  // When printing, *text_result is left untouched.
+    size_t length = text_.str().size();
+    char* str = new char[length + 1];
+    if (!str) return SPV_ERROR_OUT_OF_MEMORY;
+    strncpy(str, text_.str().c_str(), length + 1);
+    spv_text text = new spv_text_t();
+    if (!text) {
+      delete[] str;
+      return SPV_ERROR_OUT_OF_MEMORY;
+    }
+    text->str = str;
+    text->length = length;
+    *text_result = text;
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t DisassembleHeader(void* user_data, spv_endianness_t endian,
+                               uint32_t /* magic */, uint32_t version,
+                               uint32_t generator, uint32_t id_bound,
+                               uint32_t schema) {
+  assert(user_data);
+  auto disassembler = static_cast<Disassembler*>(user_data);
+  return disassembler->HandleHeader(endian, version, generator, id_bound,
+                                    schema);
+}
+
+spv_result_t DisassembleInstruction(
+    void* user_data, const spv_parsed_instruction_t* parsed_instruction) {
+  assert(user_data);
+  auto disassembler = static_cast<Disassembler*>(user_data);
+  return disassembler->HandleInstruction(*parsed_instruction);
+}
+
+// Simple wrapper class to provide extra data necessary for targeted
+// instruction disassembly.
+class WrappedDisassembler { + public: + WrappedDisassembler(Disassembler* dis, const uint32_t* binary, size_t wc) + : disassembler_(dis), inst_binary_(binary), word_count_(wc) {} + + Disassembler* disassembler() { return disassembler_; } + const uint32_t* inst_binary() const { return inst_binary_; } + size_t word_count() const { return word_count_; } + + private: + Disassembler* disassembler_; + const uint32_t* inst_binary_; + const size_t word_count_; +}; + +spv_result_t DisassembleTargetHeader(void* user_data, spv_endianness_t endian, + uint32_t /* magic */, uint32_t version, + uint32_t generator, uint32_t id_bound, + uint32_t schema) { + assert(user_data); + auto wrapped = static_cast(user_data); + return wrapped->disassembler()->HandleHeader(endian, version, generator, + id_bound, schema); +} + +spv_result_t DisassembleTargetInstruction( + void* user_data, const spv_parsed_instruction_t* parsed_instruction) { + assert(user_data); + auto wrapped = static_cast(user_data); + // Check if this is the instruction we want to disassemble. + if (wrapped->word_count() == parsed_instruction->num_words && + std::equal(wrapped->inst_binary(), + wrapped->inst_binary() + wrapped->word_count(), + parsed_instruction->words)) { + // Found the target instruction. Disassemble it and signal that we should + // stop searching so we don't output the same instruction again. 
+ if (auto error = + wrapped->disassembler()->HandleInstruction(*parsed_instruction)) + return error; + return SPV_REQUESTED_TERMINATION; + } + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t spvBinaryToText(const spv_const_context context, + const uint32_t* code, const size_t wordCount, + const uint32_t options, spv_text* pText, + spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + const spvtools::AssemblyGrammar grammar(&hijack_context); + if (!grammar.isValid()) return SPV_ERROR_INVALID_TABLE; + + // Generate friendly names for Ids if requested. + std::unique_ptr friendly_mapper; + spvtools::NameMapper name_mapper = spvtools::GetTrivialNameMapper(); + if (options & SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) { + friendly_mapper = spvtools::MakeUnique( + &hijack_context, code, wordCount); + name_mapper = friendly_mapper->GetNameMapper(); + } + + // Now disassemble! + Disassembler disassembler(grammar, options, name_mapper); + if (auto error = spvBinaryParse(&hijack_context, &disassembler, code, + wordCount, DisassembleHeader, + DisassembleInstruction, pDiagnostic)) { + return error; + } + + return disassembler.SaveTextResult(pText); +} + +std::string spvtools::spvInstructionBinaryToText(const spv_target_env env, + const uint32_t* instCode, + const size_t instWordCount, + const uint32_t* code, + const size_t wordCount, + const uint32_t options) { + spv_context context = spvContextCreate(env); + const spvtools::AssemblyGrammar grammar(context); + if (!grammar.isValid()) { + spvContextDestroy(context); + return ""; + } + + // Generate friendly names for Ids if requested. 
+ std::unique_ptr friendly_mapper; + spvtools::NameMapper name_mapper = spvtools::GetTrivialNameMapper(); + if (options & SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) { + friendly_mapper = spvtools::MakeUnique( + context, code, wordCount); + name_mapper = friendly_mapper->GetNameMapper(); + } + + // Now disassemble! + Disassembler disassembler(grammar, options, name_mapper); + WrappedDisassembler wrapped(&disassembler, instCode, instWordCount); + spvBinaryParse(context, &wrapped, code, wordCount, DisassembleTargetHeader, + DisassembleTargetInstruction, nullptr); + + spv_text text = nullptr; + std::string output; + if (disassembler.SaveTextResult(&text) == SPV_SUCCESS) { + output.assign(text->str, text->str + text->length); + // Drop trailing newline characters. + while (!output.empty() && output.back() == '\n') output.pop_back(); + } + spvTextDestroy(text); + spvContextDestroy(context); + + return output; +} diff --git a/source/disassemble.h b/source/disassemble.h new file mode 100644 index 000000000..ac3574272 --- /dev/null +++ b/source/disassemble.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_DISASSEMBLE_H_ +#define SOURCE_DISASSEMBLE_H_ + +#include + +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// Decodes the given SPIR-V instruction binary representation to its assembly +// text. The context is inferred from the provided module binary. 
The options +// parameter is a bit field of spv_binary_to_text_options_t. Decoded text will +// be stored into *text. Any error will be written into *diagnostic if +// diagnostic is non-null. +std::string spvInstructionBinaryToText(const spv_target_env env, + const uint32_t* inst_binary, + const size_t inst_word_count, + const uint32_t* binary, + const size_t word_count, + const uint32_t options); + +} // namespace spvtools + +#endif // SOURCE_DISASSEMBLE_H_ diff --git a/source/enum_set.h b/source/enum_set.h new file mode 100644 index 000000000..e4ef297cd --- /dev/null +++ b/source/enum_set.h @@ -0,0 +1,173 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_ENUM_SET_H_ +#define SOURCE_ENUM_SET_H_ + +#include +#include +#include +#include +#include + +#include "source/latest_version_spirv_header.h" +#include "source/util/make_unique.h" + +namespace spvtools { + +// A set of values of a 32-bit enum type. +// It is fast and compact for the common case, where enum values +// are at most 63. But it can represent enums with larger values, +// as may appear in extensions. +template +class EnumSet { + private: + // The ForEach method will call the functor on enum values in + // enum value order (lowest to highest). To make that easier, use + // an ordered set for the overflow values. + using OverflowSetType = std::set; + + public: + // Construct an empty set. 
+ EnumSet() {} + // Construct an set with just the given enum value. + explicit EnumSet(EnumType c) { Add(c); } + // Construct an set from an initializer list of enum values. + EnumSet(std::initializer_list cs) { + for (auto c : cs) Add(c); + } + EnumSet(uint32_t count, const EnumType* ptr) { + for (uint32_t i = 0; i < count; ++i) Add(ptr[i]); + } + // Copy constructor. + EnumSet(const EnumSet& other) { *this = other; } + // Move constructor. The moved-from set is emptied. + EnumSet(EnumSet&& other) { + mask_ = other.mask_; + overflow_ = std::move(other.overflow_); + other.mask_ = 0; + other.overflow_.reset(nullptr); + } + // Assignment operator. + EnumSet& operator=(const EnumSet& other) { + if (&other != this) { + mask_ = other.mask_; + overflow_.reset(other.overflow_ ? new OverflowSetType(*other.overflow_) + : nullptr); + } + return *this; + } + + // Adds the given enum value to the set. This has no effect if the + // enum value is already in the set. + void Add(EnumType c) { AddWord(ToWord(c)); } + + // Returns true if this enum value is in the set. + bool Contains(EnumType c) const { return ContainsWord(ToWord(c)); } + + // Applies f to each enum in the set, in order from smallest enum + // value to largest. + void ForEach(std::function f) const { + for (uint32_t i = 0; i < 64; ++i) { + if (mask_ & AsMask(i)) f(static_cast(i)); + } + if (overflow_) { + for (uint32_t c : *overflow_) f(static_cast(c)); + } + } + + // Returns true if the set is empty. + bool IsEmpty() const { + if (mask_) return false; + if (overflow_ && !overflow_->empty()) return false; + return true; + } + + // Returns true if the set contains ANY of the elements of |in_set|, + // or if |in_set| is empty. 
+ bool HasAnyOf(const EnumSet& in_set) const { + if (in_set.IsEmpty()) return true; + + if (mask_ & in_set.mask_) return true; + + if (!overflow_ || !in_set.overflow_) return false; + + for (uint32_t item : *in_set.overflow_) { + if (overflow_->find(item) != overflow_->end()) return true; + } + + return false; + } + + private: + // Adds the given enum value (as a 32-bit word) to the set. This has no + // effect if the enum value is already in the set. + void AddWord(uint32_t word) { + if (auto new_bits = AsMask(word)) { + mask_ |= new_bits; + } else { + Overflow().insert(word); + } + } + + // Returns true if the enum represented as a 32-bit word is in the set. + bool ContainsWord(uint32_t word) const { + // We shouldn't call Overflow() since this is a const method. + if (auto bits = AsMask(word)) { + return (mask_ & bits) != 0; + } else if (auto overflow = overflow_.get()) { + return overflow->find(word) != overflow->end(); + } + // The word is large, but the set doesn't have large members, so + // it doesn't have an overflow set. + return false; + } + + // Returns the enum value as a uint32_t. + uint32_t ToWord(EnumType value) const { + static_assert(sizeof(EnumType) <= sizeof(uint32_t), + "EnumType must statically castable to uint32_t"); + return static_cast(value); + } + + // Determines whether the given enum value can be represented + // as a bit in a uint64_t mask. If so, then returns that mask bit. + // Otherwise, returns 0. + uint64_t AsMask(uint32_t word) const { + if (word > 63) return 0; + return uint64_t(1) << word; + } + + // Ensures that overflow_set_ references a set. A new empty set is + // allocated if one doesn't exist yet. Returns overflow_set_. + OverflowSetType& Overflow() { + if (overflow_.get() == nullptr) { + overflow_ = MakeUnique(); + } + return *overflow_; + } + + // Enums with values up to 63 are stored as bits in this mask. + uint64_t mask_ = 0; + // Enums with values larger than 63 are stored in this set. 
+ // This set should normally be empty or very small. + std::unique_ptr overflow_ = {}; +}; + +// A set of SpvCapability, optimized for small capability values. +using CapabilitySet = EnumSet; + +} // namespace spvtools + +#endif // SOURCE_ENUM_SET_H_ diff --git a/source/enum_string_mapping.cpp b/source/enum_string_mapping.cpp new file mode 100644 index 000000000..32361a08d --- /dev/null +++ b/source/enum_string_mapping.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/enum_string_mapping.h" + +#include +#include +#include +#include +#include + +#include "source/extensions.h" + +namespace spvtools { + +#include "enum_string_mapping.inc" + +} // namespace spvtools diff --git a/source/enum_string_mapping.h b/source/enum_string_mapping.h new file mode 100644 index 000000000..af8f56b82 --- /dev/null +++ b/source/enum_string_mapping.h @@ -0,0 +1,36 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_ENUM_STRING_MAPPING_H_ +#define SOURCE_ENUM_STRING_MAPPING_H_ + +#include + +#include "source/extensions.h" +#include "source/latest_version_spirv_header.h" + +namespace spvtools { + +// Finds Extension enum corresponding to |str|. Returns false if not found. +bool GetExtensionFromString(const char* str, Extension* extension); + +// Returns text string corresponding to |extension|. +const char* ExtensionToString(Extension extension); + +// Returns text string corresponding to |capability|. +const char* CapabilityToString(SpvCapability capability); + +} // namespace spvtools + +#endif // SOURCE_ENUM_STRING_MAPPING_H_ diff --git a/source/ext_inst.cpp b/source/ext_inst.cpp new file mode 100644 index 000000000..08c775eb3 --- /dev/null +++ b/source/ext_inst.cpp @@ -0,0 +1,160 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/ext_inst.h" + +#include + +// DebugInfo extended instruction set. +// See https://www.khronos.org/registry/spir-v/specs/1.0/DebugInfo.html +// TODO(dneto): DebugInfo.h should probably move to SPIRV-Headers. 
+#include "DebugInfo.h"
+
+#include "source/latest_version_glsl_std_450_header.h"
+#include "source/latest_version_opencl_std_header.h"
+#include "source/macro.h"
+#include "source/spirv_definition.h"
+
+#include "debuginfo.insts.inc"
+#include "glsl.std.450.insts.inc"
+#include "opencl.std.insts.inc"
+
+#include "spv-amd-gcn-shader.insts.inc"
+#include "spv-amd-shader-ballot.insts.inc"
+#include "spv-amd-shader-explicit-vertex-parameter.insts.inc"
+#include "spv-amd-shader-trinary-minmax.insts.inc"
+
+static const spv_ext_inst_group_t kGroups_1_0[] = {
+    {SPV_EXT_INST_TYPE_GLSL_STD_450, ARRAY_SIZE(glsl_entries), glsl_entries},
+    {SPV_EXT_INST_TYPE_OPENCL_STD, ARRAY_SIZE(opencl_entries), opencl_entries},
+    {SPV_EXT_INST_TYPE_SPV_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER,
+     ARRAY_SIZE(spv_amd_shader_explicit_vertex_parameter_entries),
+     spv_amd_shader_explicit_vertex_parameter_entries},
+    {SPV_EXT_INST_TYPE_SPV_AMD_SHADER_TRINARY_MINMAX,
+     ARRAY_SIZE(spv_amd_shader_trinary_minmax_entries),
+     spv_amd_shader_trinary_minmax_entries},
+    {SPV_EXT_INST_TYPE_SPV_AMD_GCN_SHADER,
+     ARRAY_SIZE(spv_amd_gcn_shader_entries), spv_amd_gcn_shader_entries},
+    {SPV_EXT_INST_TYPE_SPV_AMD_SHADER_BALLOT,
+     ARRAY_SIZE(spv_amd_shader_ballot_entries), spv_amd_shader_ballot_entries},
+    {SPV_EXT_INST_TYPE_DEBUGINFO, ARRAY_SIZE(debuginfo_entries),
+     debuginfo_entries},
+};
+
+static const spv_ext_inst_table_t kTable_1_0 = {ARRAY_SIZE(kGroups_1_0),
+                                                kGroups_1_0};
+
+spv_result_t spvExtInstTableGet(spv_ext_inst_table* pExtInstTable,
+                                spv_target_env env) {
+  if (!pExtInstTable) return SPV_ERROR_INVALID_POINTER;
+
+  switch (env) {
+    // The extended instruction sets are all version 1.0 so far.
+    case SPV_ENV_UNIVERSAL_1_0:
+    case SPV_ENV_VULKAN_1_0:
+    case SPV_ENV_UNIVERSAL_1_1:
+    case SPV_ENV_UNIVERSAL_1_2:
+    case SPV_ENV_OPENCL_1_2:
+    case SPV_ENV_OPENCL_EMBEDDED_1_2:
+    case SPV_ENV_OPENCL_2_0:
+    case SPV_ENV_OPENCL_EMBEDDED_2_0:
+    case SPV_ENV_OPENCL_2_1:
+    case SPV_ENV_OPENCL_EMBEDDED_2_1:
+    case SPV_ENV_OPENCL_2_2:
+    case SPV_ENV_OPENCL_EMBEDDED_2_2:
+    case SPV_ENV_OPENGL_4_0:
+    case SPV_ENV_OPENGL_4_1:
+    case SPV_ENV_OPENGL_4_2:
+    case SPV_ENV_OPENGL_4_3:
+    case SPV_ENV_OPENGL_4_5:
+    case SPV_ENV_UNIVERSAL_1_3:
+    case SPV_ENV_VULKAN_1_1:
+    case SPV_ENV_WEBGPU_0:
+      *pExtInstTable = &kTable_1_0;  // Every listed env shares one table.
+      return SPV_SUCCESS;
+    default:
+      return SPV_ERROR_INVALID_TABLE;  // Environment not listed above.
+  }
+}
+
+spv_ext_inst_type_t spvExtInstImportTypeGet(const char* name) {
+  // The names are specified by the respective extension instruction
+  // specifications.
+  if (!strcmp("GLSL.std.450", name)) {
+    return SPV_EXT_INST_TYPE_GLSL_STD_450;
+  }
+  if (!strcmp("OpenCL.std", name)) {
+    return SPV_EXT_INST_TYPE_OPENCL_STD;
+  }
+  if (!strcmp("SPV_AMD_shader_explicit_vertex_parameter", name)) {
+    return SPV_EXT_INST_TYPE_SPV_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER;
+  }
+  if (!strcmp("SPV_AMD_shader_trinary_minmax", name)) {
+    return SPV_EXT_INST_TYPE_SPV_AMD_SHADER_TRINARY_MINMAX;
+  }
+  if (!strcmp("SPV_AMD_gcn_shader", name)) {
+    return SPV_EXT_INST_TYPE_SPV_AMD_GCN_SHADER;
+  }
+  if (!strcmp("SPV_AMD_shader_ballot", name)) {
+    return SPV_EXT_INST_TYPE_SPV_AMD_SHADER_BALLOT;
+  }
+  if (!strcmp("DebugInfo", name)) {
+    return SPV_EXT_INST_TYPE_DEBUGINFO;
+  }
+  return SPV_EXT_INST_TYPE_NONE;  // No exact (case-sensitive) match.
+}
+
+spv_result_t spvExtInstTableNameLookup(const spv_ext_inst_table table,
+                                       const spv_ext_inst_type_t type,
+                                       const char* name,
+                                       spv_ext_inst_desc* pEntry) {
+  if (!table) return SPV_ERROR_INVALID_TABLE;
+  if (!pEntry) return SPV_ERROR_INVALID_POINTER;
+
+  for (uint32_t groupIndex = 0; groupIndex < table->count; groupIndex++) {
+    const auto& group = table->groups[groupIndex];
+    if (type != group.type) continue;
+ for (uint32_t index = 0; index < group.count; index++) { + const auto& entry = group.entries[index]; + if (!strcmp(name, entry.name)) { + *pEntry = &entry; + return SPV_SUCCESS; + } + } + } + + return SPV_ERROR_INVALID_LOOKUP; +} + +spv_result_t spvExtInstTableValueLookup(const spv_ext_inst_table table, + const spv_ext_inst_type_t type, + const uint32_t value, + spv_ext_inst_desc* pEntry) { + if (!table) return SPV_ERROR_INVALID_TABLE; + if (!pEntry) return SPV_ERROR_INVALID_POINTER; + + for (uint32_t groupIndex = 0; groupIndex < table->count; groupIndex++) { + const auto& group = table->groups[groupIndex]; + if (type != group.type) continue; + for (uint32_t index = 0; index < group.count; index++) { + const auto& entry = group.entries[index]; + if (value == entry.ext_inst) { + *pEntry = &entry; + return SPV_SUCCESS; + } + } + } + + return SPV_ERROR_INVALID_LOOKUP; +} diff --git a/source/ext_inst.h b/source/ext_inst.h new file mode 100644 index 000000000..a821cc2bc --- /dev/null +++ b/source/ext_inst.h @@ -0,0 +1,40 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_EXT_INST_H_ +#define SOURCE_EXT_INST_H_ + +#include "source/table.h" +#include "spirv-tools/libspirv.h" + +// Gets the type of the extended instruction set with the specified name. 
+spv_ext_inst_type_t spvExtInstImportTypeGet(const char* name); + +// Finds the named extended instruction of the given type in the given extended +// instruction table. On success, returns SPV_SUCCESS and writes a handle of +// the instruction entry into *entry. +spv_result_t spvExtInstTableNameLookup(const spv_ext_inst_table table, + const spv_ext_inst_type_t type, + const char* name, + spv_ext_inst_desc* entry); + +// Finds the extended instruction of the given type in the given extended +// instruction table by value. On success, returns SPV_SUCCESS and writes a +// handle of the instruction entry into *pEntry. +spv_result_t spvExtInstTableValueLookup(const spv_ext_inst_table table, + const spv_ext_inst_type_t type, + const uint32_t value, + spv_ext_inst_desc* pEntry); + +#endif // SOURCE_EXT_INST_H_ diff --git a/source/extensions.cpp b/source/extensions.cpp new file mode 100644 index 000000000..a94db273e --- /dev/null +++ b/source/extensions.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "source/extensions.h" + +#include +#include +#include + +#include "source/enum_string_mapping.h" + +namespace spvtools { + +std::string GetExtensionString(const spv_parsed_instruction_t* inst) { + if (inst->opcode != SpvOpExtension) return "ERROR_not_op_extension"; + + assert(inst->num_operands == 1); + + const auto& operand = inst->operands[0]; + assert(operand.type == SPV_OPERAND_TYPE_LITERAL_STRING); + assert(inst->num_words > operand.offset); + + return reinterpret_cast(inst->words + operand.offset); +} + +std::string ExtensionSetToString(const ExtensionSet& extensions) { + std::stringstream ss; + extensions.ForEach( + [&ss](Extension ext) { ss << ExtensionToString(ext) << " "; }); + return ss.str(); +} + +} // namespace spvtools diff --git a/source/extensions.h b/source/extensions.h new file mode 100644 index 000000000..8023444c3 --- /dev/null +++ b/source/extensions.h @@ -0,0 +1,40 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_EXTENSIONS_H_ +#define SOURCE_EXTENSIONS_H_ + +#include + +#include "source/enum_set.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// The known SPIR-V extensions. +enum Extension { +#include "extension_enum.inc" +}; + +using ExtensionSet = EnumSet; + +// Returns literal string operand of OpExtension instruction. 
+std::string GetExtensionString(const spv_parsed_instruction_t* inst); + +// Returns text string listing |extensions| separated by whitespace. +std::string ExtensionSetToString(const ExtensionSet& extensions); + +} // namespace spvtools + +#endif // SOURCE_EXTENSIONS_H_ diff --git a/source/extinst.debuginfo.grammar.json b/source/extinst.debuginfo.grammar.json new file mode 100644 index 000000000..9212f6f48 --- /dev/null +++ b/source/extinst.debuginfo.grammar.json @@ -0,0 +1,568 @@ +{ + "copyright" : [ + "Copyright (c) 2017 The Khronos Group Inc.", + "", + "Permission is hereby granted, free of charge, to any person obtaining a copy", + "of this software and/or associated documentation files (the \"Materials\"),", + "to deal in the Materials without restriction, including without limitation", + "the rights to use, copy, modify, merge, publish, distribute, sublicense,", + "and/or sell copies of the Materials, and to permit persons to whom the", + "Materials are furnished to do so, subject to the following conditions:", + "", + "The above copyright notice and this permission notice shall be included in", + "all copies or substantial portions of the Materials.", + "", + "MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS", + "STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND", + "HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/ ", + "", + "THE MATERIALS ARE PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS", + "OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,", + "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL", + "THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER", + "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING", + "FROM,OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS", + "IN THE MATERIALS." 
+ ], + "version" : 100, + "revision" : 1, + "instructions" : [ + { + "opname" : "DebugInfoNone", + "opcode" : 0 + }, + { + "opname" : "DebugCompilationUnit", + "opcode" : 1, + "operands" : [ + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Version'" }, + { "kind" : "LiteralInteger", "name" : "'DWARF Version'" } + ] + }, + { + "opname" : "DebugTypeBasic", + "opcode" : 2, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Size'" }, + { "kind" : "DebugBaseTypeAttributeEncoding", "name" : "'Encoding'" } + ] + }, + { + "opname" : "DebugTypePointer", + "opcode" : 3, + "operands" : [ + { "kind" : "IdRef", "name" : "'Base Type'" }, + { "kind" : "StorageClass", "name" : "'Storage Class'" }, + { "kind" : "DebugInfoFlags", "name" : "'Literal Flags'" } + ] + }, + { + "opname" : "DebugTypeQualifier", + "opcode" : 4, + "operands" : [ + { "kind" : "IdRef", "name" : "'Base Type'" }, + { "kind" : "DebugTypeQualifier", "name" : "'Type Qualifier'" } + ] + }, + { + "opname" : "DebugTypeArray", + "opcode" : 5, + "operands" : [ + { "kind" : "IdRef", "name" : "'Base Type'" }, + { "kind" : "IdRef", "name" : "'Component Counts'", "quantifier" : "*" } + ] + }, + { + "opname" : "DebugTypeVector", + "opcode" : 6, + "operands" : [ + { "kind" : "IdRef", "name" : "'Base Type'" }, + { "kind" : "LiteralInteger", "name" : "'Component Count'" } + ] + }, + { + "opname" : "DebugTypedef", + "opcode" : 7, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Base Type'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" } + ] + }, + { + "opname" : "DebugTypeFunction", + "opcode" : 8, + "operands" : [ + { "kind" : "IdRef", "name" : "'Return Type'" }, + { "kind" : "IdRef", "name" : "'Paramter Types'", "quantifier" : "*" } + ] + }, + { + "opname" 
: "DebugTypeEnum", + "opcode" : 9, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Underlying Type'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "IdRef", "name" : "'Size'" }, + { "kind" : "DebugInfoFlags", "name" : "'Flags'" }, + { "kind" : "PairIdRefIdRef", "name" : "'Value, Name, Value, Name, ...'", "quantifier" : "*" } + ] + }, + { + "opname" : "DebugTypeComposite", + "opcode" : 10, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "DebugCompositeType", "name" : "'Tag'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "IdRef", "name" : "'Size'" }, + { "kind" : "DebugInfoFlags", "name" : "'Flags'" }, + { "kind" : "IdRef", "name" : "'Members'", "quantifier" : "*" } + ] + }, + { + "opname" : "DebugTypeMember", + "opcode" : 11, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Type'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "IdRef", "name" : "'Offset'" }, + { "kind" : "IdRef", "name" : "'Size'" }, + { "kind" : "DebugInfoFlags", "name" : "'Flags'" }, + { "kind" : "IdRef", "name" : "'Value'", "quantifier" : "?" 
} + ] + }, + { + "opname" : "DebugTypeInheritance", + "opcode" : 12, + "operands" : [ + { "kind" : "IdRef", "name" : "'Child'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "IdRef", "name" : "'Offset'" }, + { "kind" : "IdRef", "name" : "'Size'" }, + { "kind" : "DebugInfoFlags", "name" : "'Flags'" } + ] + }, + { + "opname" : "DebugTypePtrToMember", + "opcode" : 13, + "operands" : [ + { "kind" : "IdRef", "name" : "'Member Type'" }, + { "kind" : "IdRef", "name" : "'Parent'" } + ] + }, + { + "opname" : "DebugTypeTemplate", + "opcode" : 14, + "operands" : [ + { "kind" : "IdRef", "name" : "'Target'" }, + { "kind" : "IdRef", "name" : "'Parameters'", "quantifier" : "*" } + ] + }, + { + "opname" : "DebugTypeTemplateParameter", + "opcode" : 15, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Actual Type'" }, + { "kind" : "IdRef", "name" : "'Value'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" } + ] + }, + { + "opname" : "DebugTypeTemplateTemplateParameter", + "opcode" : 16, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Template Name'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" } + ] + }, + { + "opname" : "DebugTypeTemplateParameterPack", + "opcode" : 17, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Template Parameters'", "quantifier" : "*" } + ] + }, + { + "opname" : "DebugGlobalVariable", + "opcode" : 18, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Type'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : 
"LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "IdRef", "name" : "'Linkage Name'" }, + { "kind" : "IdRef", "name" : "'Variable'" }, + { "kind" : "DebugInfoFlags", "name" : "'Flags'" }, + { "kind" : "IdRef", "name" : "'Static Member Declaration'", "quantifier" : "?" } + ] + }, + { + "opname" : "DebugFunctionDeclaration", + "opcode" : 19, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Type'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "IdRef", "name" : "'Linkage Name'" }, + { "kind" : "DebugInfoFlags", "name" : "'Flags'" } + ] + }, + { + "opname" : "DebugFunction", + "opcode" : 20, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Type'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "IdRef", "name" : "'Linkage Name'" }, + { "kind" : "DebugInfoFlags", "name" : "'Flags'" }, + { "kind" : "LiteralInteger", "name" : "'Scope Line'" }, + { "kind" : "IdRef", "name" : "'Function'" }, + { "kind" : "IdRef", "name" : "'Declaration'", "quantifier" : "?" } + ] + }, + { + "opname" : "DebugLexicalBlock", + "opcode" : 21, + "operands" : [ + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "IdRef", "name" : "'Name'", "quantifier" : "?" 
} + ] + }, + { + "opname" : "DebugLexicalBlockDiscriminator", + "opcode" : 22, + "operands" : [ + { "kind" : "IdRef", "name" : "'Scope'" }, + { "kind" : "LiteralInteger", "name" : "'Discriminator'" }, + { "kind" : "IdRef", "name" : "'Parent'" } + ] + }, + { + "opname" : "DebugScope", + "opcode" : 23, + "operands" : [ + { "kind" : "IdRef", "name" : "'Scope'" }, + { "kind" : "IdRef", "name" : "'Inlined At'", "quantifier" : "?" } + ] + }, + { + "opname" : "DebugNoScope", + "opcode" : 24 + }, + { + "opname" : "DebugInlinedAt", + "opcode" : 25, + "operands" : [ + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "IdRef", "name" : "'Scope'" }, + { "kind" : "IdRef", "name" : "'Inlined'", "quantifier" : "?" } + ] + }, + { + "opname" : "DebugLocalVariable", + "opcode" : 26, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Type'" }, + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "LiteralInteger", "name" : "'Column'" }, + { "kind" : "IdRef", "name" : "'Parent'" }, + { "kind" : "LiteralInteger", "name" : "'Arg Number'", "quantifier" : "?" 
} + ] + }, + { + "opname" : "DebugInlinedVariable", + "opcode" : 27, + "operands" : [ + { "kind" : "IdRef", "name" : "'Variable'" }, + { "kind" : "IdRef", "name" : "'Inlined'" } + ] + }, + { + "opname" : "DebugDeclare", + "opcode" : 28, + "operands" : [ + { "kind" : "IdRef", "name" : "'Local Variable'" }, + { "kind" : "IdRef", "name" : "'Variable'" }, + { "kind" : "IdRef", "name" : "'Expression'" } + ] + }, + { + "opname" : "DebugValue", + "opcode" : 29, + "operands" : [ + { "kind" : "IdRef", "name" : "'Value'" }, + { "kind" : "IdRef", "name" : "'Expression'" }, + { "kind" : "IdRef", "name" : "'Indexes'", "quantifier" : "*" } + ] + }, + { + "opname" : "DebugOperation", + "opcode" : 30, + "operands" : [ + { "kind" : "DebugOperation", "name" : "'OpCode'" }, + { "kind" : "LiteralInteger", "name" : "'Operands ...'", "quantifier" : "*" } + ] + }, + { + "opname" : "DebugExpression", + "opcode" : 31, + "operands" : [ + { "kind" : "IdRef", "name" : "'Operands ...'", "quantifier" : "*" } + ] + }, + { + "opname" : "DebugMacroDef", + "opcode" : 32, + "operands" : [ + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Value'", "quantifier" : "?" 
} + ] + }, + { + "opname" : "DebugMacroUndef", + "opcode" : 33, + "operands" : [ + { "kind" : "IdRef", "name" : "'Source'" }, + { "kind" : "LiteralInteger", "name" : "'Line'" }, + { "kind" : "IdRef", "name" : "'Macro'" } + ] + } + ], + "operand_kinds" : [ + { + "category" : "BitEnum", + "kind" : "DebugInfoFlags", + "enumerants" : [ + { + "enumerant" : "FlagIsProtected", + "value" : "0x01" + }, + { + "enumerant" : "FlagIsPrivate", + "value" : "0x02" + }, + { + "enumerant" : "FlagIsPublic", + "value" : "0x03" + }, + { + "enumerant" : "FlagIsLocal", + "value" : "0x04" + }, + { + "enumerant" : "FlagIsDefinition", + "value" : "0x08" + }, + { + "enumerant" : "FlagFwdDecl", + "value" : "0x10" + }, + { + "enumerant" : "FlagArtificial", + "value" : "0x20" + }, + { + "enumerant" : "FlagExplicit", + "value" : "0x40" + }, + { + "enumerant" : "FlagPrototyped", + "value" : "0x80" + }, + { + "enumerant" : "FlagObjectPointer", + "value" : "0x100" + }, + { + "enumerant" : "FlagStaticMember", + "value" : "0x200" + }, + { + "enumerant" : "FlagIndirectVariable", + "value" : "0x400" + }, + { + "enumerant" : "FlagLValueReference", + "value" : "0x800" + }, + { + "enumerant" : "FlagRValueReference", + "value" : "0x1000" + }, + { + "enumerant" : "FlagIsOptimized", + "value" : "0x2000" + } + ] + }, + { + "category" : "ValueEnum", + "kind" : "DebugBaseTypeAttributeEncoding", + "enumerants" : [ + { + "enumerant" : "Unspecified", + "value" : "0" + }, + { + "enumerant" : "Address", + "value" : "1" + }, + { + "enumerant" : "Boolean", + "value" : "2" + }, + { + "enumerant" : "Float", + "value" : "4" + }, + { + "enumerant" : "Signed", + "value" : "5" + }, + { + "enumerant" : "SignedChar", + "value" : "6" + }, + { + "enumerant" : "Unsigned", + "value" : "7" + }, + { + "enumerant" : "UnsignedChar", + "value" : "8" + } + ] + }, + { + "category" : "ValueEnum", + "kind" : "DebugCompositeType", + "enumerants" : [ + { + "enumerant" : "Class", + "value" : "0" + }, + { + "enumerant" : "Structure", + 
"value" : "1" + }, + { + "enumerant" : "Union", + "value" : "2" + } + ] + }, + { + "category" : "ValueEnum", + "kind" : "DebugTypeQualifier", + "enumerants" : [ + { + "enumerant" : "ConstType", + "value" : "0" + }, + { + "enumerant" : "VolatileType", + "value" : "1" + }, + { + "enumerant" : "RestrictType", + "value" : "2" + } + ] + }, + { + "category" : "ValueEnum", + "kind" : "DebugOperation", + "enumerants" : [ + { + "enumerant" : "Deref", + "value" : "0" + }, + { + "enumerant" : "Plus", + "value" : "1" + }, + { + "enumerant" : "Minus", + "value" : "2" + }, + { + "enumerant" : "PlusUconst", + "value" : "3", + "parameters" : [ + { "kind" : "LiteralInteger" } + ] + }, + { + "enumerant" : "BitPiece", + "value" : "4", + "parameters" : [ + { "kind" : "LiteralInteger" }, + { "kind" : "LiteralInteger" } + ] + }, + { + "enumerant" : "Swap", + "value" : "5" + }, + { + "enumerant" : "Xderef", + "value" : "6" + }, + { + "enumerant" : "StackValue", + "value" : "7" + }, + { + "enumerant" : "Constu", + "value" : "8", + "parameters" : [ + { "kind" : "LiteralInteger" } + ] + } + ] + } + ] +} diff --git a/source/extinst.spv-amd-gcn-shader.grammar.json b/source/extinst.spv-amd-gcn-shader.grammar.json new file mode 100644 index 000000000..e18251bba --- /dev/null +++ b/source/extinst.spv-amd-gcn-shader.grammar.json @@ -0,0 +1,26 @@ +{ + "revision" : 2, + "instructions" : [ + { + "opname" : "CubeFaceIndexAMD", + "opcode" : 1, + "operands" : [ + { "kind" : "IdRef", "name" : "'P'" } + ], + "extensions" : [ "SPV_AMD_gcn_shader" ] + }, + { + "opname" : "CubeFaceCoordAMD", + "opcode" : 2, + "operands" : [ + { "kind" : "IdRef", "name" : "'P'" } + ], + "extensions" : [ "SPV_AMD_gcn_shader" ] + }, + { + "opname" : "TimeAMD", + "opcode" : 3, + "extensions" : [ "SPV_AMD_gcn_shader" ] + } + ] +} diff --git a/source/extinst.spv-amd-shader-ballot.grammar.json b/source/extinst.spv-amd-shader-ballot.grammar.json new file mode 100644 index 000000000..62a470eeb --- /dev/null +++ 
b/source/extinst.spv-amd-shader-ballot.grammar.json @@ -0,0 +1,41 @@ +{ + "revision" : 5, + "instructions" : [ + { + "opname" : "SwizzleInvocationsAMD", + "opcode" : 1, + "operands" : [ + { "kind" : "IdRef", "name" : "'data'" }, + { "kind" : "IdRef", "name" : "'offset'" } + ], + "extensions" : [ "SPV_AMD_shader_ballot" ] + }, + { + "opname" : "SwizzleInvocationsMaskedAMD", + "opcode" : 2, + "operands" : [ + { "kind" : "IdRef", "name" : "'data'" }, + { "kind" : "IdRef", "name" : "'mask'" } + ], + "extensions" : [ "SPV_AMD_shader_ballot" ] + }, + { + "opname" : "WriteInvocationAMD", + "opcode" : 3, + "operands" : [ + { "kind" : "IdRef", "name" : "'inputValue'" }, + { "kind" : "IdRef", "name" : "'writeValue'" }, + { "kind" : "IdRef", "name" : "'invocationIndex'" } + ], + "extensions" : [ "SPV_AMD_shader_ballot" ] + }, + { + "opname" : "MbcntAMD", + "opcode" : 4, + "operands" : [ + { "kind" : "IdRef", "name" : "'mask'" } + ], + "extensions" : [ "SPV_AMD_shader_ballot" ] + } + ] +} diff --git a/source/extinst.spv-amd-shader-explicit-vertex-parameter.grammar.json b/source/extinst.spv-amd-shader-explicit-vertex-parameter.grammar.json new file mode 100644 index 000000000..e156b1b6f --- /dev/null +++ b/source/extinst.spv-amd-shader-explicit-vertex-parameter.grammar.json @@ -0,0 +1,14 @@ +{ + "revision" : 4, + "instructions" : [ + { + "opname" : "InterpolateAtVertexAMD", + "opcode" : 1, + "operands" : [ + { "kind" : "IdRef", "name" : "'interpolant'" }, + { "kind" : "IdRef", "name" : "'vertexIdx'" } + ], + "extensions" : [ "SPV_AMD_shader_explicit_vertex_parameter" ] + } + ] +} diff --git a/source/extinst.spv-amd-shader-trinary-minmax.grammar.json b/source/extinst.spv-amd-shader-trinary-minmax.grammar.json new file mode 100644 index 000000000..c681976fe --- /dev/null +++ b/source/extinst.spv-amd-shader-trinary-minmax.grammar.json @@ -0,0 +1,95 @@ +{ + "revision" : 4, + "instructions" : [ + { + "opname" : "FMin3AMD", + "opcode" : 1, + "operands" : [ + { "kind" : "IdRef", 
"name" : "'x'" }, + { "kind" : "IdRef", "name" : "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + }, + { + "opname" : "UMin3AMD", + "opcode" : 2, + "operands" : [ + { "kind" : "IdRef", "name" : "'x'" }, + { "kind" : "IdRef", "name" : "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + }, + { + "opname" : "SMin3AMD", + "opcode" : 3, + "operands" : [ + { "kind" : "IdRef", "name" : "'x'" }, + { "kind" : "IdRef", "name" : "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + }, + { + "opname" : "FMax3AMD", + "opcode" : 4, + "operands" : [ + { "kind" : "IdRef", "name" : "'x'" }, + { "kind" : "IdRef", "name" : "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + }, + { + "opname" : "UMax3AMD", + "opcode" : 5, + "operands" : [ + { "kind" : "IdRef", "name" : "'x'" }, + { "kind" : "IdRef", "name" : "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + }, + { + "opname" : "SMax3AMD", + "opcode" : 6, + "operands" : [ + { "kind" : "IdRef", "name" : "'x'" }, + { "kind" : "IdRef", "name" : "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + }, + { + "opname" : "FMid3AMD", + "opcode" : 7, + "operands" : [ + { "kind" : "IdRef", "name" : "'x'" }, + { "kind" : "IdRef", "name" : "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + }, + { + "opname" : "UMid3AMD", + "opcode" : 8, + "operands" : [ + { "kind" : "IdRef", "name" : "'x'" }, + { "kind" : "IdRef", "name" : "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + }, + { + "opname" : "SMid3AMD", + "opcode" : 9, + "operands" : [ + { "kind" : "IdRef", "name" : "'x'" }, + { "kind" : "IdRef", "name" 
: "'y'" }, + { "kind" : "IdRef", "name" : "'z'" } + ], + "extensions" : [ "SPV_AMD_shader_trinary_minmax" ] + } + ] +} diff --git a/source/id_descriptor.cpp b/source/id_descriptor.cpp new file mode 100644 index 000000000..d44ed672c --- /dev/null +++ b/source/id_descriptor.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/id_descriptor.h" + +#include +#include + +#include "source/opcode.h" +#include "source/operand.h" + +namespace spvtools { +namespace { + +// Hashes an array of words. Order of words is important. +uint32_t HashU32Array(const std::vector& words) { + // The hash function is a sum of hashes of each word seeded by word index. + // Knuth's multiplicative hash is used to hash the words. 
+ const uint32_t kKnuthMulHash = 2654435761; + uint32_t val = 0; + for (uint32_t i = 0; i < words.size(); ++i) { + val += (words[i] + i + 123) * kKnuthMulHash; + } + return val; +} + +} // namespace + +uint32_t IdDescriptorCollection::ProcessInstruction( + const spv_parsed_instruction_t& inst) { + if (!inst.result_id) return 0; + + assert(words_.empty()); + words_.push_back(inst.words[0]); + + for (size_t operand_index = 0; operand_index < inst.num_operands; + ++operand_index) { + const auto& operand = inst.operands[operand_index]; + if (spvIsIdType(operand.type)) { + const uint32_t id = inst.words[operand.offset]; + const auto it = id_to_descriptor_.find(id); + // Forward declared ids are not hashed. + if (it != id_to_descriptor_.end()) { + words_.push_back(it->second); + } + } else { + for (size_t operand_word_index = 0; + operand_word_index < operand.num_words; ++operand_word_index) { + words_.push_back(inst.words[operand.offset + operand_word_index]); + } + } + } + + uint32_t descriptor = + custom_hash_func_ ? custom_hash_func_(words_) : HashU32Array(words_); + if (descriptor == 0) descriptor = 1; + assert(descriptor); + + words_.clear(); + + const auto result = id_to_descriptor_.emplace(inst.result_id, descriptor); + assert(result.second); + (void)result; + return descriptor; +} + +} // namespace spvtools diff --git a/source/id_descriptor.h b/source/id_descriptor.h new file mode 100644 index 000000000..add23343a --- /dev/null +++ b/source/id_descriptor.h @@ -0,0 +1,63 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_ID_DESCRIPTOR_H_ +#define SOURCE_ID_DESCRIPTOR_H_ + +#include +#include + +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { + +using CustomHashFunc = std::function&)>; + +// Computes and stores id descriptors. +// +// Descriptors are computed as hash of all words in the instruction where ids +// were substituted with previously computed descriptors. +class IdDescriptorCollection { + public: + explicit IdDescriptorCollection( + CustomHashFunc custom_hash_func = CustomHashFunc()) + : custom_hash_func_(custom_hash_func) { + words_.reserve(16); + } + + // Computes descriptor for the result id of the given instruction and + // registers it in id_to_descriptor_. Returns the computed descriptor. + // This function needs to be sequentially called for every instruction in the + // module. + uint32_t ProcessInstruction(const spv_parsed_instruction_t& inst); + + // Returns a previously computed descriptor id. + uint32_t GetDescriptor(uint32_t id) const { + const auto it = id_to_descriptor_.find(id); + if (it == id_to_descriptor_.end()) return 0; + return it->second; + } + + private: + std::unordered_map id_to_descriptor_; + + std::function&)> custom_hash_func_; + + // Scratch buffer used for hashing. Class member to optimize on allocation. + std::vector words_; +}; + +} // namespace spvtools + +#endif // SOURCE_ID_DESCRIPTOR_H_ diff --git a/source/instruction.h b/source/instruction.h new file mode 100644 index 000000000..9e7dccd03 --- /dev/null +++ b/source/instruction.h @@ -0,0 +1,49 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_INSTRUCTION_H_ +#define SOURCE_INSTRUCTION_H_ + +#include <cstdint> +#include <vector> + +#include "source/latest_version_spirv_header.h" +#include "spirv-tools/libspirv.h" + +// Describes an instruction. +struct spv_instruction_t { + // Normally, both opcode and extInstType contain valid data. + // However, when the assembler parses !<number> as the first word in + // an instruction, opcode and extInstType are invalid. + SpvOp opcode; + spv_ext_inst_type_t extInstType; + + // The Id of the result type, if this instruction has one. Zero otherwise. + uint32_t resultTypeId; + + // The instruction, as a sequence of 32-bit words. + // For a regular instruction the opcode and word count are combined + // in words[0], as described in the SPIR-V spec. + // Otherwise, the first token was !<number>, and that number appears + // in words[0]. Subsequent elements are the result of parsing + // tokens in the alternate parsing mode as described in syntax.md. + std::vector<uint32_t> words; +}; + +// Appends a word to an instruction, without checking for overflow. +inline void spvInstructionAddWord(spv_instruction_t* inst, uint32_t value) { + inst->words.push_back(value); +} + +#endif // SOURCE_INSTRUCTION_H_ diff --git a/source/latest_version_glsl_std_450_header.h b/source/latest_version_glsl_std_450_header.h new file mode 100644 index 000000000..bed1f2502 --- /dev/null +++ b/source/latest_version_glsl_std_450_header.h @@ -0,0 +1,20 @@ +// Copyright (c) 2017 Google Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_LATEST_VERSION_GLSL_STD_450_HEADER_H_ +#define SOURCE_LATEST_VERSION_GLSL_STD_450_HEADER_H_ + +#include "spirv/unified1/GLSL.std.450.h" + +#endif // SOURCE_LATEST_VERSION_GLSL_STD_450_HEADER_H_ diff --git a/source/latest_version_opencl_std_header.h b/source/latest_version_opencl_std_header.h new file mode 100644 index 000000000..90ff9c033 --- /dev/null +++ b/source/latest_version_opencl_std_header.h @@ -0,0 +1,20 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_LATEST_VERSION_OPENCL_STD_HEADER_H_ +#define SOURCE_LATEST_VERSION_OPENCL_STD_HEADER_H_ + +#include "spirv/unified1/OpenCL.std.h" + +#endif // SOURCE_LATEST_VERSION_OPENCL_STD_HEADER_H_ diff --git a/source/latest_version_spirv_header.h b/source/latest_version_spirv_header.h new file mode 100644 index 000000000..e4f28e43e --- /dev/null +++ b/source/latest_version_spirv_header.h @@ -0,0 +1,20 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_LATEST_VERSION_SPIRV_HEADER_H_ +#define SOURCE_LATEST_VERSION_SPIRV_HEADER_H_ + +#include "spirv/unified1/spirv.h" + +#endif // SOURCE_LATEST_VERSION_SPIRV_HEADER_H_ diff --git a/source/libspirv.cpp b/source/libspirv.cpp new file mode 100644 index 000000000..b5fe89766 --- /dev/null +++ b/source/libspirv.cpp @@ -0,0 +1,131 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "spirv-tools/libspirv.hpp" + +#include <iostream> + +#include <string> +#include <utility> +#include <vector> + +#include "source/table.h" + +namespace spvtools { + +Context::Context(spv_target_env env) : context_(spvContextCreate(env)) {} + +Context::Context(Context&& other) : context_(other.context_) { + other.context_ = nullptr; +} + +Context& Context::operator=(Context&& other) { + spvContextDestroy(context_); + context_ = other.context_; + other.context_ = nullptr; + + return *this; +} + +Context::~Context() { spvContextDestroy(context_); } + +void Context::SetMessageConsumer(MessageConsumer consumer) { + SetContextMessageConsumer(context_, std::move(consumer)); +} + +spv_context& Context::CContext() { return context_; } + +const spv_context& Context::CContext() const { return context_; } + +// Struct for holding the data members of SpirvTools. +struct SpirvTools::Impl { + explicit Impl(spv_target_env env) : context(spvContextCreate(env)) { + // The default consumer in spv_context_t is a null consumer, which provides + // equivalent functionality (from the user's perspective) to a real consumer + // that does nothing. + } + ~Impl() { spvContextDestroy(context); } + + spv_context context; // C interface context object.
+}; + +SpirvTools::SpirvTools(spv_target_env env) : impl_(new Impl(env)) {} + +SpirvTools::~SpirvTools() {} + +void SpirvTools::SetMessageConsumer(MessageConsumer consumer) { + SetContextMessageConsumer(impl_->context, std::move(consumer)); +} + +bool SpirvTools::Assemble(const std::string& text, + std::vector<uint32_t>* binary, + uint32_t options) const { + return Assemble(text.data(), text.size(), binary, options); +} + +bool SpirvTools::Assemble(const char* text, const size_t text_size, + std::vector<uint32_t>* binary, + uint32_t options) const { + spv_binary spvbinary = nullptr; + spv_result_t status = spvTextToBinaryWithOptions( + impl_->context, text, text_size, options, &spvbinary, nullptr); + if (status == SPV_SUCCESS) { + binary->assign(spvbinary->code, spvbinary->code + spvbinary->wordCount); + } + spvBinaryDestroy(spvbinary); + return status == SPV_SUCCESS; +} + +bool SpirvTools::Disassemble(const std::vector<uint32_t>& binary, + std::string* text, uint32_t options) const { + return Disassemble(binary.data(), binary.size(), text, options); +} + +bool SpirvTools::Disassemble(const uint32_t* binary, const size_t binary_size, + std::string* text, uint32_t options) const { + spv_text spvtext = nullptr; + spv_result_t status = spvBinaryToText(impl_->context, binary, binary_size, + options, &spvtext, nullptr); + if (status == SPV_SUCCESS) { + text->assign(spvtext->str, spvtext->str + spvtext->length); + } + spvTextDestroy(spvtext); + return status == SPV_SUCCESS; +} + +bool SpirvTools::Validate(const std::vector<uint32_t>& binary) const { + return Validate(binary.data(), binary.size()); +} + +bool SpirvTools::Validate(const uint32_t* binary, + const size_t binary_size) const { + return spvValidateBinary(impl_->context, binary, binary_size, nullptr) == + SPV_SUCCESS; +} + +bool SpirvTools::Validate(const uint32_t* binary, const size_t binary_size, + spv_validator_options options) const { + spv_const_binary_t the_binary{binary, binary_size}; + spv_diagnostic diagnostic = nullptr; + bool valid = 
spvValidateWithOptions(impl_->context, options, &the_binary, + &diagnostic) == SPV_SUCCESS; + if (!valid && diagnostic && impl_->context->consumer) { // |diagnostic| starts out null; only forward it if the validator filled it in. + impl_->context->consumer.operator()( + SPV_MSG_ERROR, nullptr, diagnostic->position, diagnostic->error); + } + spvDiagnosticDestroy(diagnostic); + return valid; +} + +} // namespace spvtools diff --git a/source/link/CMakeLists.txt b/source/link/CMakeLists.txt new file mode 100644 index 000000000..8ca4df39f --- /dev/null +++ b/source/link/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (c) 2017 Pierre Moreau + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+add_library(SPIRV-Tools-link + linker.cpp +) + +spvtools_default_compile_options(SPIRV-Tools-link) +target_include_directories(SPIRV-Tools-link + PUBLIC ${spirv-tools_SOURCE_DIR}/include + PUBLIC ${SPIRV_HEADER_INCLUDE_DIR} + PRIVATE ${spirv-tools_BINARY_DIR} +) +# We need the IR functionalities from the optimizer +target_link_libraries(SPIRV-Tools-link + PUBLIC SPIRV-Tools-opt) + +set_property(TARGET SPIRV-Tools-link PROPERTY FOLDER "SPIRV-Tools libraries") +spvtools_check_symbol_exports(SPIRV-Tools-link) + +if(ENABLE_SPIRV_TOOLS_INSTALL) + install(TARGETS SPIRV-Tools-link + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) +endif(ENABLE_SPIRV_TOOLS_INSTALL) diff --git a/source/link/linker.cpp b/source/link/linker.cpp new file mode 100644 index 000000000..f28b7595a --- /dev/null +++ b/source/link/linker.cpp @@ -0,0 +1,769 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "spirv-tools/linker.hpp" + +#include <algorithm> +#include <cstdio> +#include <cstring> +#include <iostream> +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "source/assembly_grammar.h" +#include "source/diagnostic.h" +#include "source/opt/build_module.h" +#include "source/opt/compact_ids_pass.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/ir_loader.h" +#include "source/opt/pass_manager.h" +#include "source/opt/remove_duplicates_pass.h" +#include "source/spirv_target_env.h" +#include "source/util/make_unique.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace { + +using opt::IRContext; +using opt::Instruction; +using opt::Module; +using opt::Operand; +using opt::PassManager; +using opt::RemoveDuplicatesPass; +using opt::analysis::DecorationManager; +using opt::analysis::DefUseManager; + +// Stores various information about an imported or exported symbol. +struct LinkageSymbolInfo { + SpvId id; // ID of the symbol + SpvId type_id; // ID of the type of the symbol + std::string name; // unique name defining the symbol and used for matching + // imports and exports together + std::vector<SpvId> parameter_ids; // ID of the parameters of the symbol, if + // it is a function +}; +struct LinkageEntry { // One imported symbol together with the exported symbol it is matched to. + LinkageSymbolInfo imported_symbol; + LinkageSymbolInfo exported_symbol; + + LinkageEntry(const LinkageSymbolInfo& import_info, + const LinkageSymbolInfo& export_info) + : imported_symbol(import_info), exported_symbol(export_info) {} +}; +using LinkageTable = std::vector<LinkageEntry>; + +// Shifts the IDs used in each binary of |modules| so that they occupy a +// disjoint range from the other binaries, and compute the new ID bound which +// is returned in |max_id_bound|. +// +// Both |modules| and |max_id_bound| should not be null, and |modules| should +// not be empty either. Furthermore |modules| should not contain any null +// pointers.
+spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, + std::vector<opt::Module*>* modules, + uint32_t* max_id_bound); + +// Generates the header for the linked module and returns it in |header|. +// +// |header| should not be null, |modules| should not be empty and pointers +// should be non-null. |max_id_bound| should be strictly greater than 0. +// +// TODO(pierremoreau): What to do when binaries use different versions of +// SPIR-V? For now, use the max of all versions found in +// the input modules. +spv_result_t GenerateHeader(const MessageConsumer& consumer, + const std::vector<opt::Module*>& modules, + uint32_t max_id_bound, opt::ModuleHeader* header); + +// Merge all the modules from |in_modules| into a single module owned by +// |linked_context|. +// +// |linked_context| should not be null. +spv_result_t MergeModules(const MessageConsumer& consumer, + const std::vector<Module*>& in_modules, + const AssemblyGrammar& grammar, + IRContext* linked_context); + +// Compute all pairs of import and export and return them in |linkings_to_do|. +// +// |linkings_to_do| should not be null. Built-in symbols will be ignored. +// +// TODO(pierremoreau): Linkage attributes applied by a group decoration are +// currently not handled. (You could have a group being +// applied to a single ID.) +// TODO(pierremoreau): What should be the proper behaviour with built-in +// symbols? +spv_result_t GetImportExportPairs(const MessageConsumer& consumer, + const opt::IRContext& linked_context, + const DefUseManager& def_use_manager, + const DecorationManager& decoration_manager, + bool allow_partial_linkage, + LinkageTable* linkings_to_do); + +// Checks that for each pair of import and export, the import and export have +// the same type as well as the same decorations. +// +// TODO(pierremoreau): Decorations on function parameters are currently not +// checked.
+spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer, + const LinkageTable& linkings_to_do, + opt::IRContext* context); + +// Remove linkage specific instructions, such as prototypes of imported +// functions, declarations of imported variables, import (and export if +// necessary) linkage attributes. +// +// |linked_context| and |decoration_manager| should not be null, and the +// 'RemoveDuplicatesPass' should be run first. +// +// TODO(pierremoreau): Linkage attributes applied by a group decoration are +// currently not handled. (You could have a group being +// applied to a single ID.) +// TODO(pierremoreau): Run a pass for removing dead instructions, for example +// OpName for prototypes of imported functions. +spv_result_t RemoveLinkageSpecificInstructions( + const MessageConsumer& consumer, const LinkerOptions& options, + const LinkageTable& linkings_to_do, DecorationManager* decoration_manager, + opt::IRContext* linked_context); + +// Verify that the unique ids of each instruction in |linked_context| (i.e. the +// merged module) are truly unique.
Does not check the validity of other ids +spv_result_t VerifyIds(const MessageConsumer& consumer, + opt::IRContext* linked_context); + +spv_result_t ShiftIdsInModules(const MessageConsumer& consumer, + std::vector* modules, + uint32_t* max_id_bound) { + spv_position_t position = {}; + + if (modules == nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << "|modules| of ShiftIdsInModules should not be null."; + if (modules->empty()) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << "|modules| of ShiftIdsInModules should not be empty."; + if (max_id_bound == nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << "|max_id_bound| of ShiftIdsInModules should not be null."; + + uint32_t id_bound = modules->front()->IdBound() - 1u; + for (auto module_iter = modules->begin() + 1; module_iter != modules->end(); + ++module_iter) { + Module* module = *module_iter; + module->ForEachInst([&id_bound](Instruction* insn) { + insn->ForEachId([&id_bound](uint32_t* id) { *id += id_bound; }); + }); + id_bound += module->IdBound() - 1u; + if (id_bound > 0x3FFFFF) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_ID) + << "The limit of IDs, 4194303, was exceeded:" + << " " << id_bound << " is the current ID bound."; + + // Invalidate the DefUseManager + module->context()->InvalidateAnalyses(opt::IRContext::kAnalysisDefUse); + } + ++id_bound; + if (id_bound > 0x3FFFFF) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_ID) + << "The limit of IDs, 4194303, was exceeded:" + << " " << id_bound << " is the current ID bound."; + + *max_id_bound = id_bound; + + return SPV_SUCCESS; +} + +spv_result_t GenerateHeader(const MessageConsumer& consumer, + const std::vector& modules, + uint32_t max_id_bound, opt::ModuleHeader* header) { + spv_position_t position = {}; + + if (modules.empty()) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << 
"|modules| of GenerateHeader should not be empty."; + if (max_id_bound == 0u) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << "|max_id_bound| of GenerateHeader should not be null."; + + uint32_t version = 0u; + for (const auto& module : modules) + version = std::max(version, module->version()); + + header->magic_number = SpvMagicNumber; + header->version = version; + header->generator = 17u; + header->bound = max_id_bound; + header->reserved = 0u; + + return SPV_SUCCESS; +} + +spv_result_t MergeModules(const MessageConsumer& consumer, + const std::vector& input_modules, + const AssemblyGrammar& grammar, + IRContext* linked_context) { + spv_position_t position = {}; + + if (linked_context == nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << "|linked_module| of MergeModules should not be null."; + Module* linked_module = linked_context->module(); + + if (input_modules.empty()) return SPV_SUCCESS; + + for (const auto& module : input_modules) + for (const auto& inst : module->capabilities()) + linked_module->AddCapability( + std::unique_ptr(inst.Clone(linked_context))); + + for (const auto& module : input_modules) + for (const auto& inst : module->extensions()) + linked_module->AddExtension( + std::unique_ptr(inst.Clone(linked_context))); + + for (const auto& module : input_modules) + for (const auto& inst : module->ext_inst_imports()) + linked_module->AddExtInstImport( + std::unique_ptr(inst.Clone(linked_context))); + + do { + const Instruction* memory_model_inst = input_modules[0]->GetMemoryModel(); + if (memory_model_inst == nullptr) break; + + uint32_t addressing_model = memory_model_inst->GetSingleWordOperand(0u); + uint32_t memory_model = memory_model_inst->GetSingleWordOperand(1u); + for (const auto& module : input_modules) { + memory_model_inst = module->GetMemoryModel(); + if (memory_model_inst == nullptr) continue; + + if (addressing_model != memory_model_inst->GetSingleWordOperand(0u)) { + 
spv_operand_desc initial_desc = nullptr, current_desc = nullptr; + grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL, + addressing_model, &initial_desc); + grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL, + memory_model_inst->GetSingleWordOperand(0u), + ¤t_desc); + return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL) + << "Conflicting addressing models: " << initial_desc->name + << " vs " << current_desc->name << "."; + } + if (memory_model != memory_model_inst->GetSingleWordOperand(1u)) { + spv_operand_desc initial_desc = nullptr, current_desc = nullptr; + grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, memory_model, + &initial_desc); + grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, + memory_model_inst->GetSingleWordOperand(1u), + ¤t_desc); + return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL) + << "Conflicting memory models: " << initial_desc->name << " vs " + << current_desc->name << "."; + } + } + + if (memory_model_inst != nullptr) + linked_module->SetMemoryModel(std::unique_ptr( + memory_model_inst->Clone(linked_context))); + } while (false); + + std::vector> entry_points; + for (const auto& module : input_modules) + for (const auto& inst : module->entry_points()) { + const uint32_t model = inst.GetSingleWordInOperand(0); + const char* const name = + reinterpret_cast(inst.GetInOperand(2).words.data()); + const auto i = std::find_if( + entry_points.begin(), entry_points.end(), + [model, name](const std::pair& v) { + return v.first == model && strcmp(name, v.second) == 0; + }); + if (i != entry_points.end()) { + spv_operand_desc desc = nullptr; + grammar.lookupOperand(SPV_OPERAND_TYPE_EXECUTION_MODEL, model, &desc); + return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL) + << "The entry point \"" << name << "\", with execution model " + << desc->name << ", was already defined."; + } + linked_module->AddEntryPoint( + std::unique_ptr(inst.Clone(linked_context))); + 
entry_points.emplace_back(model, name); + } + + for (const auto& module : input_modules) + for (const auto& inst : module->execution_modes()) + linked_module->AddExecutionMode( + std::unique_ptr(inst.Clone(linked_context))); + + for (const auto& module : input_modules) + for (const auto& inst : module->debugs1()) + linked_module->AddDebug1Inst( + std::unique_ptr(inst.Clone(linked_context))); + + for (const auto& module : input_modules) + for (const auto& inst : module->debugs2()) + linked_module->AddDebug2Inst( + std::unique_ptr(inst.Clone(linked_context))); + + for (const auto& module : input_modules) + for (const auto& inst : module->debugs3()) + linked_module->AddDebug3Inst( + std::unique_ptr(inst.Clone(linked_context))); + + // If the generated module uses SPIR-V 1.1 or higher, add an + // OpModuleProcessed instruction about the linking step. + if (linked_module->version() >= 0x10100) { + const std::string processed_string("Linked by SPIR-V Tools Linker"); + const auto num_chars = processed_string.size(); + // Compute num words, accommodate the terminating null character. + const auto num_words = (num_chars + 1 + 3) / 4; + std::vector processed_words(num_words, 0u); + std::memcpy(processed_words.data(), processed_string.data(), num_chars); + linked_module->AddDebug3Inst(std::unique_ptr( + new Instruction(linked_context, SpvOpModuleProcessed, 0u, 0u, + {{SPV_OPERAND_TYPE_LITERAL_STRING, processed_words}}))); + } + + for (const auto& module : input_modules) + for (const auto& inst : module->annotations()) + linked_module->AddAnnotationInst( + std::unique_ptr(inst.Clone(linked_context))); + + // TODO(pierremoreau): Since the modules have not been validate, should we + // expect SpvStorageClassFunction variables outside + // functions? 
+ uint32_t num_global_values = 0u; + for (const auto& module : input_modules) { + for (const auto& inst : module->types_values()) { + linked_module->AddType( + std::unique_ptr(inst.Clone(linked_context))); + num_global_values += inst.opcode() == SpvOpVariable; + } + } + if (num_global_values > 0xFFFF) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL) + << "The limit of global values, 65535, was exceeded;" + << " " << num_global_values << " global values were found."; + + // Process functions and their basic blocks + for (const auto& module : input_modules) { + for (const auto& func : *module) { + std::unique_ptr cloned_func(func.Clone(linked_context)); + linked_module->AddFunction(std::move(cloned_func)); + } + } + + return SPV_SUCCESS; +} + +spv_result_t GetImportExportPairs(const MessageConsumer& consumer, + const opt::IRContext& linked_context, + const DefUseManager& def_use_manager, + const DecorationManager& decoration_manager, + bool allow_partial_linkage, + LinkageTable* linkings_to_do) { + spv_position_t position = {}; + + if (linkings_to_do == nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << "|linkings_to_do| of GetImportExportPairs should not be empty."; + + std::vector imports; + std::unordered_map> exports; + + // Figure out the imports and exports + for (const auto& decoration : linked_context.annotations()) { + if (decoration.opcode() != SpvOpDecorate || + decoration.GetSingleWordInOperand(1u) != SpvDecorationLinkageAttributes) + continue; + + const SpvId id = decoration.GetSingleWordInOperand(0u); + // Ignore if the targeted symbol is a built-in + bool is_built_in = false; + for (const auto& id_decoration : + decoration_manager.GetDecorationsFor(id, false)) { + if (id_decoration->GetSingleWordInOperand(1u) == SpvDecorationBuiltIn) { + is_built_in = true; + break; + } + } + if (is_built_in) { + continue; + } + + const uint32_t type = decoration.GetSingleWordInOperand(3u); + + LinkageSymbolInfo 
symbol_info; + symbol_info.name = + reinterpret_cast(decoration.GetInOperand(2u).words.data()); + symbol_info.id = id; + symbol_info.type_id = 0u; + + // Retrieve the type of the current symbol. This information will be used + // when checking that the imported and exported symbols have the same + // types. + const Instruction* def_inst = def_use_manager.GetDef(id); + if (def_inst == nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "ID " << id << " is never defined:\n"; + + if (def_inst->opcode() == SpvOpVariable) { + symbol_info.type_id = def_inst->type_id(); + } else if (def_inst->opcode() == SpvOpFunction) { + symbol_info.type_id = def_inst->GetSingleWordInOperand(1u); + + // range-based for loop calls begin()/end(), but never cbegin()/cend(), + // which will not work here. + for (auto func_iter = linked_context.module()->cbegin(); + func_iter != linked_context.module()->cend(); ++func_iter) { + if (func_iter->result_id() != id) continue; + func_iter->ForEachParam([&symbol_info](const Instruction* inst) { + symbol_info.parameter_ids.push_back(inst->result_id()); + }); + } + } else { + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Only global variables and functions can be decorated using" + << " LinkageAttributes; " << id << " is neither of them.\n"; + } + + if (type == SpvLinkageTypeImport) + imports.push_back(symbol_info); + else if (type == SpvLinkageTypeExport) + exports[symbol_info.name].push_back(symbol_info); + } + + // Find the import/export pairs + for (const auto& import : imports) { + std::vector possible_exports; + const auto& exp = exports.find(import.name); + if (exp != exports.end()) possible_exports = exp->second; + if (possible_exports.empty() && !allow_partial_linkage) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Unresolved external reference to \"" << import.name << "\"."; + else if (possible_exports.size() > 1u) + return 
DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Too many external references, " << possible_exports.size() + << ", were found for \"" << import.name << "\"."; + + if (!possible_exports.empty()) + linkings_to_do->emplace_back(import, possible_exports.front()); + } + + return SPV_SUCCESS; +} + +spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer, + const LinkageTable& linkings_to_do, + opt::IRContext* context) { + spv_position_t position = {}; + + // Ensure th import and export types are the same. + const DefUseManager& def_use_manager = *context->get_def_use_mgr(); + const DecorationManager& decoration_manager = *context->get_decoration_mgr(); + for (const auto& linking_entry : linkings_to_do) { + if (!RemoveDuplicatesPass::AreTypesEqual( + *def_use_manager.GetDef(linking_entry.imported_symbol.type_id), + *def_use_manager.GetDef(linking_entry.exported_symbol.type_id), + context)) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Type mismatch on symbol \"" + << linking_entry.imported_symbol.name + << "\" between imported variable/function %" + << linking_entry.imported_symbol.id + << " and exported variable/function %" + << linking_entry.exported_symbol.id << "."; + } + + // Ensure the import and export decorations are similar + for (const auto& linking_entry : linkings_to_do) { + if (!decoration_manager.HaveTheSameDecorations( + linking_entry.imported_symbol.id, linking_entry.exported_symbol.id)) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Decorations mismatch on symbol \"" + << linking_entry.imported_symbol.name + << "\" between imported variable/function %" + << linking_entry.imported_symbol.id + << " and exported variable/function %" + << linking_entry.exported_symbol.id << "."; + // TODO(pierremoreau): Decorations on function parameters should probably + // match, except for FuncParamAttr if I understand the + // spec correctly. 
+ // TODO(pierremoreau): Decorations on the function return type should + // match, except for FuncParamAttr. + } + + return SPV_SUCCESS; +} + +spv_result_t RemoveLinkageSpecificInstructions( + const MessageConsumer& consumer, const LinkerOptions& options, + const LinkageTable& linkings_to_do, DecorationManager* decoration_manager, + opt::IRContext* linked_context) { + spv_position_t position = {}; + + if (decoration_manager == nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << "|decoration_manager| of RemoveLinkageSpecificInstructions " + "should not be empty."; + if (linked_context == nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA) + << "|linked_module| of RemoveLinkageSpecificInstructions should not " + "be empty."; + + // TODO(pierremoreau): Remove FuncParamAttr decorations of imported + // functions' return type. + + // Remove FuncParamAttr decorations of imported functions' parameters. + // From the SPIR-V specification, Sec. 2.13: + // When resolving imported functions, the Function Control and all Function + // Parameter Attributes are taken from the function definition, and not + // from the function declaration. 
+ for (const auto& linking_entry : linkings_to_do) { + for (const auto parameter_id : + linking_entry.imported_symbol.parameter_ids) { + decoration_manager->RemoveDecorationsFrom( + parameter_id, [](const Instruction& inst) { + return (inst.opcode() == SpvOpDecorate || + inst.opcode() == SpvOpMemberDecorate) && + inst.GetSingleWordInOperand(1u) == + SpvDecorationFuncParamAttr; + }); + } + } + + // Remove prototypes of imported functions + for (const auto& linking_entry : linkings_to_do) { + for (auto func_iter = linked_context->module()->begin(); + func_iter != linked_context->module()->end();) { + if (func_iter->result_id() == linking_entry.imported_symbol.id) + func_iter = func_iter.Erase(); + else + ++func_iter; + } + } + + // Remove declarations of imported variables + for (const auto& linking_entry : linkings_to_do) { + auto next = linked_context->types_values_begin(); + for (auto inst = next; inst != linked_context->types_values_end(); + inst = next) { + ++next; + if (inst->result_id() == linking_entry.imported_symbol.id) { + linked_context->KillInst(&*inst); + } + } + } + + // If partial linkage is allowed, we need an efficient way to check whether + // an imported ID had a corresponding export symbol. As uses of the imported + // symbol have already been replaced by the exported symbol, use the exported + // symbol ID. + // TODO(pierremoreau): This will not work if the decoration is applied + // through a group, but the linker does not support that + // either. 
+ std::unordered_set imports; + if (options.GetAllowPartialLinkage()) { + imports.reserve(linkings_to_do.size()); + for (const auto& linking_entry : linkings_to_do) + imports.emplace(linking_entry.exported_symbol.id); + } + + // Remove import linkage attributes + auto next = linked_context->annotation_begin(); + for (auto inst = next; inst != linked_context->annotation_end(); + inst = next) { + ++next; + // If this is an import annotation: + // * if we do not allow partial linkage, remove all import annotations; + // * otherwise, remove the annotation only if there was a corresponding + // export. + if (inst->opcode() == SpvOpDecorate && + inst->GetSingleWordOperand(1u) == SpvDecorationLinkageAttributes && + inst->GetSingleWordOperand(3u) == SpvLinkageTypeImport && + (!options.GetAllowPartialLinkage() || + imports.find(inst->GetSingleWordOperand(0u)) != imports.end())) { + linked_context->KillInst(&*inst); + } + } + + // Remove export linkage attributes if making an executable + if (!options.GetCreateLibrary()) { + next = linked_context->annotation_begin(); + for (auto inst = next; inst != linked_context->annotation_end(); + inst = next) { + ++next; + if (inst->opcode() == SpvOpDecorate && + inst->GetSingleWordOperand(1u) == SpvDecorationLinkageAttributes && + inst->GetSingleWordOperand(3u) == SpvLinkageTypeExport) { + linked_context->KillInst(&*inst); + } + } + } + + // Remove Linkage capability if making an executable and partial linkage is + // not allowed + if (!options.GetCreateLibrary() && !options.GetAllowPartialLinkage()) { + for (auto& inst : linked_context->capabilities()) + if (inst.GetSingleWordInOperand(0u) == SpvCapabilityLinkage) { + linked_context->KillInst(&inst); + // The RemoveDuplicatesPass did remove duplicated capabilities, so we + // know there are no more SpvCapabilityLinkage further down.
+ break; + } + } + + return SPV_SUCCESS; +} + +spv_result_t VerifyIds(const MessageConsumer& consumer, + opt::IRContext* linked_context) { + std::unordered_set ids; + bool ok = true; + linked_context->module()->ForEachInst( + [&ids, &ok](const opt::Instruction* inst) { + ok &= ids.insert(inst->unique_id()).second; + }); + + if (!ok) { + consumer(SPV_MSG_INTERNAL_ERROR, "", {}, "Non-unique id in merged module"); + return SPV_ERROR_INVALID_ID; + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t Link(const Context& context, + const std::vector>& binaries, + std::vector* linked_binary, + const LinkerOptions& options) { + std::vector binary_ptrs; + binary_ptrs.reserve(binaries.size()); + std::vector binary_sizes; + binary_sizes.reserve(binaries.size()); + + for (const auto& binary : binaries) { + binary_ptrs.push_back(binary.data()); + binary_sizes.push_back(binary.size()); + } + + return Link(context, binary_ptrs.data(), binary_sizes.data(), binaries.size(), + linked_binary, options); +} + +spv_result_t Link(const Context& context, const uint32_t* const* binaries, + const size_t* binary_sizes, size_t num_binaries, + std::vector* linked_binary, + const LinkerOptions& options) { + spv_position_t position = {}; + const spv_context& c_context = context.CContext(); + const MessageConsumer& consumer = c_context->consumer; + + linked_binary->clear(); + if (num_binaries == 0u) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "No modules were given."; + + std::vector> ir_contexts; + std::vector modules; + modules.reserve(num_binaries); + for (size_t i = 0u; i < num_binaries; ++i) { + const uint32_t schema = binaries[i][4u]; + if (schema != 0u) { + position.index = 4u; + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Schema is non-zero for module " << i << "."; + } + + std::unique_ptr ir_context = BuildModule( + c_context->target_env, consumer, binaries[i], binary_sizes[i]); + if (ir_context == 
nullptr) + return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) + << "Failed to build a module out of " << ir_contexts.size() << "."; + modules.push_back(ir_context->module()); + ir_contexts.push_back(std::move(ir_context)); + } + + // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint + // range from the other binaries, and compute the new ID bound. + uint32_t max_id_bound = 0u; + spv_result_t res = ShiftIdsInModules(consumer, &modules, &max_id_bound); + if (res != SPV_SUCCESS) return res; + + // Phase 2: Generate the header + opt::ModuleHeader header; + res = GenerateHeader(consumer, modules, max_id_bound, &header); + if (res != SPV_SUCCESS) return res; + IRContext linked_context(c_context->target_env, consumer); + linked_context.module()->SetHeader(header); + + // Phase 3: Merge all the binaries into a single one. + AssemblyGrammar grammar(c_context); + res = MergeModules(consumer, modules, grammar, &linked_context); + if (res != SPV_SUCCESS) return res; + + if (options.GetVerifyIds()) { + res = VerifyIds(consumer, &linked_context); + if (res != SPV_SUCCESS) return res; + } + + // Phase 4: Find the import/export pairs + LinkageTable linkings_to_do; + res = GetImportExportPairs(consumer, linked_context, + *linked_context.get_def_use_mgr(), + *linked_context.get_decoration_mgr(), + options.GetAllowPartialLinkage(), &linkings_to_do); + if (res != SPV_SUCCESS) return res; + + // Phase 5: Ensure the import and export have the same types and decorations. 
+ res = + CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context); + if (res != SPV_SUCCESS) return res; + + // Phase 6: Remove duplicates + PassManager manager; + manager.SetMessageConsumer(consumer); + manager.AddPass(); + opt::Pass::Status pass_res = manager.Run(&linked_context); + if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; + + // Phase 7: Rematch import variables/functions to export variables/functions + for (const auto& linking_entry : linkings_to_do) + linked_context.ReplaceAllUsesWith(linking_entry.imported_symbol.id, + linking_entry.exported_symbol.id); + + // Phase 8: Remove linkage specific instructions, such as import/export + // attributes, linkage capability, etc. if applicable + res = RemoveLinkageSpecificInstructions(consumer, options, linkings_to_do, + linked_context.get_decoration_mgr(), + &linked_context); + if (res != SPV_SUCCESS) return res; + + // Phase 9: Compact the IDs used in the module + manager.AddPass(); + pass_res = manager.Run(&linked_context); + if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; + + // Phase 10: Output the module + linked_context.module()->ToBinary(linked_binary, true); + + return SPV_SUCCESS; +} + +} // namespace spvtools diff --git a/source/macro.h b/source/macro.h new file mode 100644 index 000000000..7219ffed1 --- /dev/null +++ b/source/macro.h @@ -0,0 +1,25 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_MACRO_H_ +#define SOURCE_MACRO_H_ + +// Evaluates to the number of elements of array A. +// +// If we could use constexpr, then we could make this a template function. +// If the source arrays were std::array, then we could have used +// std::array::size. +#define ARRAY_SIZE(A) (static_cast(sizeof(A) / sizeof(A[0]))) + +#endif // SOURCE_MACRO_H_ diff --git a/source/name_mapper.cpp b/source/name_mapper.cpp new file mode 100644 index 000000000..43fdfb34b --- /dev/null +++ b/source/name_mapper.cpp @@ -0,0 +1,331 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/name_mapper.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "spirv-tools/libspirv.h" + +#include "source/latest_version_spirv_header.h" +#include "source/parsed_operand.h" + +namespace spvtools { +namespace { + +// Converts a uint32_t to its string decimal representation. +std::string to_string(uint32_t id) { + // Use stringstream, since some versions of Android compilers lack + // std::to_string. 
+ std::stringstream os; + os << id; + return os.str(); +} + +} // anonymous namespace + +NameMapper GetTrivialNameMapper() { return to_string; } + +FriendlyNameMapper::FriendlyNameMapper(const spv_const_context context, + const uint32_t* code, + const size_t wordCount) + : grammar_(AssemblyGrammar(context)) { + spv_diagnostic diag = nullptr; + // We don't care if the parse fails. + spvBinaryParse(context, this, code, wordCount, nullptr, + ParseInstructionForwarder, &diag); + spvDiagnosticDestroy(diag); +} + +std::string FriendlyNameMapper::NameForId(uint32_t id) { + auto iter = name_for_id_.find(id); + if (iter == name_for_id_.end()) { + // It must have been an invalid module, so just return a trivial mapping. + // We don't care about uniqueness. + return to_string(id); + } else { + return iter->second; + } +} + +std::string FriendlyNameMapper::Sanitize(const std::string& suggested_name) { + if (suggested_name.empty()) return "_"; + // Otherwise, replace invalid characters by '_'. + std::string result; + std::string valid = + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "_0123456789"; + std::transform(suggested_name.begin(), suggested_name.end(), + std::back_inserter(result), [&valid](const char c) { + return (std::string::npos == valid.find(c)) ? 
'_' : c; + }); + return result; +} + +void FriendlyNameMapper::SaveName(uint32_t id, + const std::string& suggested_name) { + if (name_for_id_.find(id) != name_for_id_.end()) return; + + const std::string sanitized_suggested_name = Sanitize(suggested_name); + std::string name = sanitized_suggested_name; + auto inserted = used_names_.insert(name); + if (!inserted.second) { + const std::string base_name = sanitized_suggested_name + "_"; + for (uint32_t index = 0; !inserted.second; ++index) { + name = base_name + to_string(index); + inserted = used_names_.insert(name); + } + } + name_for_id_[id] = name; +} + +void FriendlyNameMapper::SaveBuiltInName(uint32_t target_id, + uint32_t built_in) { +#define GLCASE(name) \ + case SpvBuiltIn##name: \ + SaveName(target_id, "gl_" #name); \ + return; +#define GLCASE2(name, suggested) \ + case SpvBuiltIn##name: \ + SaveName(target_id, "gl_" #suggested); \ + return; +#define CASE(name) \ + case SpvBuiltIn##name: \ + SaveName(target_id, #name); \ + return; + switch (built_in) { + GLCASE(Position) + GLCASE(PointSize) + GLCASE(ClipDistance) + GLCASE(CullDistance) + GLCASE2(VertexId, VertexID) + GLCASE2(InstanceId, InstanceID) + GLCASE2(PrimitiveId, PrimitiveID) + GLCASE2(InvocationId, InvocationID) + GLCASE(Layer) + GLCASE(ViewportIndex) + GLCASE(TessLevelOuter) + GLCASE(TessLevelInner) + GLCASE(TessCoord) + GLCASE(PatchVertices) + GLCASE(FragCoord) + GLCASE(PointCoord) + GLCASE(FrontFacing) + GLCASE2(SampleId, SampleID) + GLCASE(SamplePosition) + GLCASE(SampleMask) + GLCASE(FragDepth) + GLCASE(HelperInvocation) + GLCASE2(NumWorkgroups, NumWorkGroups) + GLCASE2(WorkgroupSize, WorkGroupSize) + GLCASE2(WorkgroupId, WorkGroupID) + GLCASE2(LocalInvocationId, LocalInvocationID) + GLCASE2(GlobalInvocationId, GlobalInvocationID) + GLCASE(LocalInvocationIndex) + CASE(WorkDim) + CASE(GlobalSize) + CASE(EnqueuedWorkgroupSize) + CASE(GlobalOffset) + CASE(GlobalLinearId) + CASE(SubgroupSize) + CASE(SubgroupMaxSize) + CASE(NumSubgroups) + 
CASE(NumEnqueuedSubgroups) + CASE(SubgroupId) + CASE(SubgroupLocalInvocationId) + GLCASE(VertexIndex) + GLCASE(InstanceIndex) + CASE(SubgroupEqMaskKHR) + CASE(SubgroupGeMaskKHR) + CASE(SubgroupGtMaskKHR) + CASE(SubgroupLeMaskKHR) + CASE(SubgroupLtMaskKHR) + default: + break; + } +#undef GLCASE +#undef GLCASE2 +#undef CASE +} + +spv_result_t FriendlyNameMapper::ParseInstruction( + const spv_parsed_instruction_t& inst) { + const auto result_id = inst.result_id; + switch (inst.opcode) { + case SpvOpName: + SaveName(inst.words[1], reinterpret_cast(inst.words + 2)); + break; + case SpvOpDecorate: + // Decorations come after OpName. So OpName will take precedence over + // decorations. + // + // In theory, we should also handle OpGroupDecorate. But that's unlikely + // to occur. + if (inst.words[2] == SpvDecorationBuiltIn) { + assert(inst.num_words > 3); + SaveBuiltInName(inst.words[1], inst.words[3]); + } + break; + case SpvOpTypeVoid: + SaveName(result_id, "void"); + break; + case SpvOpTypeBool: + SaveName(result_id, "bool"); + break; + case SpvOpTypeInt: { + std::string signedness; + std::string root; + const auto bit_width = inst.words[2]; + switch (bit_width) { + case 8: + root = "char"; + break; + case 16: + root = "short"; + break; + case 32: + root = "int"; + break; + case 64: + root = "long"; + break; + default: + root = to_string(bit_width); + signedness = "i"; + break; + } + if (0 == inst.words[3]) signedness = "u"; + SaveName(result_id, signedness + root); + } break; + case SpvOpTypeFloat: { + const auto bit_width = inst.words[2]; + switch (bit_width) { + case 16: + SaveName(result_id, "half"); + break; + case 32: + SaveName(result_id, "float"); + break; + case 64: + SaveName(result_id, "double"); + break; + default: + SaveName(result_id, std::string("fp") + to_string(bit_width)); + break; + } + } break; + case SpvOpTypeVector: + SaveName(result_id, std::string("v") + to_string(inst.words[3]) + + NameForId(inst.words[2])); + break; + case SpvOpTypeMatrix: + 
SaveName(result_id, std::string("mat") + to_string(inst.words[3]) + + NameForId(inst.words[2])); + break; + case SpvOpTypeArray: + SaveName(result_id, std::string("_arr_") + NameForId(inst.words[2]) + + "_" + NameForId(inst.words[3])); + break; + case SpvOpTypeRuntimeArray: + SaveName(result_id, + std::string("_runtimearr_") + NameForId(inst.words[2])); + break; + case SpvOpTypePointer: + SaveName(result_id, std::string("_ptr_") + + NameForEnumOperand(SPV_OPERAND_TYPE_STORAGE_CLASS, + inst.words[2]) + + "_" + NameForId(inst.words[3])); + break; + case SpvOpTypePipe: + SaveName(result_id, + std::string("Pipe") + + NameForEnumOperand(SPV_OPERAND_TYPE_ACCESS_QUALIFIER, + inst.words[2])); + break; + case SpvOpTypeEvent: + SaveName(result_id, "Event"); + break; + case SpvOpTypeDeviceEvent: + SaveName(result_id, "DeviceEvent"); + break; + case SpvOpTypeReserveId: + SaveName(result_id, "ReserveId"); + break; + case SpvOpTypeQueue: + SaveName(result_id, "Queue"); + break; + case SpvOpTypeOpaque: + SaveName(result_id, + std::string("Opaque_") + + Sanitize(reinterpret_cast(inst.words + 2))); + break; + case SpvOpTypePipeStorage: + SaveName(result_id, "PipeStorage"); + break; + case SpvOpTypeNamedBarrier: + SaveName(result_id, "NamedBarrier"); + break; + case SpvOpTypeStruct: + // Structs are mapped rather simplistically. Just indicate that they + // are a struct and then give the raw Id number. + SaveName(result_id, std::string("_struct_") + to_string(result_id)); + break; + case SpvOpConstantTrue: + SaveName(result_id, "true"); + break; + case SpvOpConstantFalse: + SaveName(result_id, "false"); + break; + case SpvOpConstant: { + std::ostringstream value; + EmitNumericLiteral(&value, inst, inst.operands[2]); + auto value_str = value.str(); + // Use 'n' to signify negative. Other invalid characters will be mapped + // to underscore.
+ for (auto& c : value_str) + if (c == '-') c = 'n'; + SaveName(result_id, NameForId(inst.type_id) + "_" + value_str); + } break; + default: + // If this instruction otherwise defines an Id, then save a mapping for + // it. This is needed to ensure uniqueness in case there is an OpName with + // a string like "1" that might collide with this result_id. + // We should only do this if a name hasn't already been registered by some + // previous forward reference. + if (result_id && name_for_id_.find(result_id) == name_for_id_.end()) + SaveName(result_id, to_string(result_id)); + break; + } + return SPV_SUCCESS; +} + +std::string FriendlyNameMapper::NameForEnumOperand(spv_operand_type_t type, + uint32_t word) { + spv_operand_desc desc = nullptr; + if (SPV_SUCCESS == grammar_.lookupOperand(type, word, &desc)) { + return desc->name; + } else { + // Invalid input. Give something sane ("StorageClass" for any |type|). + return std::string("StorageClass") + to_string(word); + } +} + +} // namespace spvtools diff --git a/source/name_mapper.h b/source/name_mapper.h new file mode 100644 index 000000000..6902141b1 --- /dev/null +++ b/source/name_mapper.h @@ -0,0 +1,122 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef SOURCE_NAME_MAPPER_H_ +#define SOURCE_NAME_MAPPER_H_ + +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// A NameMapper maps SPIR-V Id values to names. Each name is valid to use in +// SPIR-V assembly. The mapping is one-to-one, i.e. no two Ids map to the same +// name. +using NameMapper = std::function; + +// Returns a NameMapper which always maps an Id to its decimal representation. +NameMapper GetTrivialNameMapper(); + +// A FriendlyNameMapper parses a module upon construction. If the parse is +// successful, then the NameForId method maps an Id to a friendly name +// while also satisfying the constraints on a NameMapper. +// +// The mapping is friendly in the following sense: +// - If an Id has a debug name (via OpName), then that will be used when +// possible. +// - Well known scalar types map to friendly names. For example, +// OpTypeVoid should be %void. Scalar types map to their names in OpenCL +// when +// there is a correspondence, and otherwise as follows: +// - unsigned integer type of n bits map to "u" followed by n +// - signed integer type of n bits map to "i" followed by n +// - floating point type of n bits map to "fp" followed by n +// - Vector type names map to "v" followed by the number of components, +// followed by the friendly name for the base type. +// - Matrix type names map to "mat" followed by the number of columns, +// followed by the friendly name for the base vector type. +// - Pointer types map to "_ptr_", then the name of the storage class, then the +// name for the pointee type. +// - Exotic types like event, pipe, opaque, queue, reserve-id map to their own +// human readable names. +// - A struct type maps to "_struct_" followed by the raw Id number. That's +// pretty simplistic, but workable. +// - A built-in variable maps to its GLSL variable name. +// - Numeric literals in OpConstant map to a human-friendly name. 
+class FriendlyNameMapper { + public: + // Construct a friendly name mapper, and determine friendly names for each + // defined Id in the specified module. The module is specified by the code + // wordCount, and should be parseable in the specified context. + FriendlyNameMapper(const spv_const_context context, const uint32_t* code, + const size_t wordCount); + + // Returns a NameMapper which maps ids to the friendly names parsed from the + // module provided to the constructor. + NameMapper GetNameMapper() { + return [this](uint32_t id) { return this->NameForId(id); }; + } + + // Returns the friendly name for the given id. If the module parsed during + // construction is valid, then the mapping satisfies the rules for a + // NameMapper. + std::string NameForId(uint32_t id); + + private: + // Transforms the given string so that it is acceptable as an Id name in + // assembly language. Two distinct inputs can map to the same output. + std::string Sanitize(const std::string& suggested_name); + + // Records a name for the given id. If this id already has a name, then + // this is a no-op. If the id doesn't have a name, use the given + // suggested_name if it hasn't already been taken, and otherwise generate + // a new (unused) name based on the suggested name. + void SaveName(uint32_t id, const std::string& suggested_name); + + // Records a built-in variable name for target_id. If target_id already + // has a name then this is a no-op. + void SaveBuiltInName(uint32_t target_id, uint32_t built_in); + + // Collects information from the given parsed instruction to populate + // name_for_id_. Returns SPV_SUCCESS; + spv_result_t ParseInstruction(const spv_parsed_instruction_t& inst); + + // Forwards a parsed-instruction callback from the binary parser into the + // FriendlyNameMapper hidden inside the user_data parameter. 
+ static spv_result_t ParseInstructionForwarder( + void* user_data, const spv_parsed_instruction_t* parsed_instruction) { + return reinterpret_cast(user_data)->ParseInstruction( + *parsed_instruction); + } + + // Returns the friendly name for an enumerant. + std::string NameForEnumOperand(spv_operand_type_t type, uint32_t word); + + // Maps an id to its friendly name. This will have an entry for each Id + // defined in the module. + std::unordered_map name_for_id_; + // The set of names that have a mapping in name_for_id_; + std::unordered_set used_names_; + // The assembly grammar for the current context. + const AssemblyGrammar grammar_; +}; + +} // namespace spvtools + +#endif // SOURCE_NAME_MAPPER_H_ diff --git a/source/opcode.cpp b/source/opcode.cpp new file mode 100644 index 000000000..78c238686 --- /dev/null +++ b/source/opcode.cpp @@ -0,0 +1,605 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opcode.h" + +#include +#include + +#include +#include + +#include "source/instruction.h" +#include "source/macro.h" +#include "source/spirv_constant.h" +#include "source/spirv_endian.h" +#include "source/spirv_target_env.h" +#include "spirv-tools/libspirv.h" + +namespace { +struct OpcodeDescPtrLen { + const spv_opcode_desc_t* ptr; + uint32_t len; +}; + +#include "core.insts-unified1.inc" + +static const spv_opcode_table_t kOpcodeTable = {ARRAY_SIZE(kOpcodeTableEntries), + kOpcodeTableEntries}; + +// Represents a vendor tool entry in the SPIR-V XML Registry. +struct VendorTool { + uint32_t value; + const char* vendor; + const char* tool; // Might be empty string. + const char* vendor_tool; // Combination of vendor and tool. +}; + +const VendorTool vendor_tools[] = { +#include "generators.inc" +}; + +} // anonymous namespace + +// TODO(dneto): Move this to another file. It doesn't belong with opcode +// processing. +const char* spvGeneratorStr(uint32_t generator) { + auto where = std::find_if( + std::begin(vendor_tools), std::end(vendor_tools), + [generator](const VendorTool& vt) { return generator == vt.value; }); + if (where != std::end(vendor_tools)) return where->vendor_tool; + return "Unknown"; +} + +uint32_t spvOpcodeMake(uint16_t wordCount, SpvOp opcode) { + return ((uint32_t)opcode) | (((uint32_t)wordCount) << 16); +} + +void spvOpcodeSplit(const uint32_t word, uint16_t* pWordCount, + uint16_t* pOpcode) { + if (pWordCount) { + *pWordCount = (uint16_t)((0xffff0000 & word) >> 16); + } + if (pOpcode) { + *pOpcode = 0x0000ffff & word; + } +} + +spv_result_t spvOpcodeTableGet(spv_opcode_table* pInstTable, spv_target_env) { + if (!pInstTable) return SPV_ERROR_INVALID_POINTER; + + // Descriptions of each opcode. Each entry describes the format of the + // instruction that follows a particular opcode.
+ + *pInstTable = &kOpcodeTable; + return SPV_SUCCESS; +} + +spv_result_t spvOpcodeTableNameLookup(spv_target_env env, + const spv_opcode_table table, + const char* name, + spv_opcode_desc* pEntry) { + if (!name || !pEntry) return SPV_ERROR_INVALID_POINTER; + if (!table) return SPV_ERROR_INVALID_TABLE; + + // TODO: This lookup of the Opcode table is suboptimal! Binary search would be + // preferable, but that needs the table sorted by Opcode name, and it's + // static const initialized and matches the order of the spec. + const size_t nameLength = strlen(name); + for (uint64_t opcodeIndex = 0; opcodeIndex < table->count; ++opcodeIndex) { + const spv_opcode_desc_t& entry = table->entries[opcodeIndex]; + // We consider the current opcode as available as long as + // 1. The target environment satisfies the minimal requirement of the + // opcode; or + // 2. There is at least one extension or capability enabling this opcode. + // + // Note that the second rule assumes the extension enabling this instruction + // is indeed requested in the SPIR-V code; checking that should be + // validator's work. + if ((spvVersionForTargetEnv(env) >= entry.minVersion || + entry.numExtensions > 0u || entry.numCapabilities > 0u) && + nameLength == strlen(entry.name) && + !strncmp(name, entry.name, nameLength)) { + // NOTE: Found the opcode!
+ *pEntry = &entry; + return SPV_SUCCESS; + } + } + + return SPV_ERROR_INVALID_LOOKUP; +} + +spv_result_t spvOpcodeTableValueLookup(spv_target_env env, + const spv_opcode_table table, + const SpvOp opcode, + spv_opcode_desc* pEntry) { + if (!table) return SPV_ERROR_INVALID_TABLE; + if (!pEntry) return SPV_ERROR_INVALID_POINTER; + + const auto beg = table->entries; + const auto end = table->entries + table->count; + + spv_opcode_desc_t needle = {"", opcode, 0, nullptr, 0, {}, + false, false, 0, nullptr, ~0u}; + + auto comp = [](const spv_opcode_desc_t& lhs, const spv_opcode_desc_t& rhs) { + return lhs.opcode < rhs.opcode; + }; + + // We need to loop here because there can exist multiple symbols for the same + // opcode value, and they can be introduced in different target environments, + // which means they can have different minimal version requirements. + // Assumes the underlying table is already sorted in ascending order by + // opcode value. + for (auto it = std::lower_bound(beg, end, needle, comp); + it != end && it->opcode == opcode; ++it) { + // We consider the current opcode as available as long as + // 1. The target environment satisfies the minimal requirement of the + // opcode; or + // 2. There is at least one extension or capability enabling this opcode. + // + // Note that the second rule assumes the extension enabling this instruction + // is indeed requested in the SPIR-V code; checking that should be + // validator's work.
+ if (spvVersionForTargetEnv(env) >= it->minVersion || + it->numExtensions > 0u || it->numCapabilities > 0u) { + *pEntry = it; + return SPV_SUCCESS; + } + } + + return SPV_ERROR_INVALID_LOOKUP; +} + +void spvInstructionCopy(const uint32_t* words, const SpvOp opcode, + const uint16_t wordCount, const spv_endianness_t endian, + spv_instruction_t* pInst) { + pInst->opcode = opcode; + pInst->words.resize(wordCount); + for (uint16_t wordIndex = 0; wordIndex < wordCount; ++wordIndex) { + pInst->words[wordIndex] = spvFixWord(words[wordIndex], endian); + if (!wordIndex) { + uint16_t thisWordCount; + uint16_t thisOpcode; + spvOpcodeSplit(pInst->words[wordIndex], &thisWordCount, &thisOpcode); + assert(opcode == static_cast(thisOpcode) && + wordCount == thisWordCount && "Endianness failed!"); + } + } +} + +const char* spvOpcodeString(const SpvOp opcode) { + const auto beg = kOpcodeTableEntries; + const auto end = kOpcodeTableEntries + ARRAY_SIZE(kOpcodeTableEntries); + spv_opcode_desc_t needle = {"", opcode, 0, nullptr, 0, {}, + false, false, 0, nullptr, ~0u}; + auto comp = [](const spv_opcode_desc_t& lhs, const spv_opcode_desc_t& rhs) { + return lhs.opcode < rhs.opcode; + }; + auto it = std::lower_bound(beg, end, needle, comp); + if (it != end && it->opcode == opcode) { + return it->name; + } + + assert(0 && "Unreachable!"); + return "unknown"; +} + +int32_t spvOpcodeIsScalarType(const SpvOp opcode) { + switch (opcode) { + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypeBool: + return true; + default: + return false; + } +} + +int32_t spvOpcodeIsSpecConstant(const SpvOp opcode) { + switch (opcode) { + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstant: + case SpvOpSpecConstantComposite: + case SpvOpSpecConstantOp: + return true; + default: + return false; + } +} + +int32_t spvOpcodeIsConstant(const SpvOp opcode) { + switch (opcode) { + case SpvOpConstantTrue: + case SpvOpConstantFalse: + case SpvOpConstant: + case 
SpvOpConstantComposite: + case SpvOpConstantSampler: + case SpvOpConstantNull: + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstant: + case SpvOpSpecConstantComposite: + case SpvOpSpecConstantOp: + return true; + default: + return false; + } +} + +bool spvOpcodeIsConstantOrUndef(const SpvOp opcode) { + return opcode == SpvOpUndef || spvOpcodeIsConstant(opcode); +} + +bool spvOpcodeIsScalarSpecConstant(const SpvOp opcode) { + switch (opcode) { + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstant: + return true; + default: + return false; + } +} + +int32_t spvOpcodeIsComposite(const SpvOp opcode) { + switch (opcode) { + case SpvOpTypeVector: + case SpvOpTypeMatrix: + case SpvOpTypeArray: + case SpvOpTypeStruct: + return true; + default: + return false; + } +} + +bool spvOpcodeReturnsLogicalVariablePointer(const SpvOp opcode) { + switch (opcode) { + case SpvOpVariable: + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpFunctionParameter: + case SpvOpImageTexelPointer: + case SpvOpCopyObject: + case SpvOpSelect: + case SpvOpPhi: + case SpvOpFunctionCall: + case SpvOpPtrAccessChain: + case SpvOpLoad: + case SpvOpConstantNull: + return true; + default: + return false; + } +} + +int32_t spvOpcodeReturnsLogicalPointer(const SpvOp opcode) { + switch (opcode) { + case SpvOpVariable: + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpFunctionParameter: + case SpvOpImageTexelPointer: + case SpvOpCopyObject: + return true; + default: + return false; + } +} + +int32_t spvOpcodeGeneratesType(SpvOp op) { + switch (op) { + case SpvOpTypeVoid: + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypeVector: + case SpvOpTypeMatrix: + case SpvOpTypeImage: + case SpvOpTypeSampler: + case SpvOpTypeSampledImage: + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + case SpvOpTypeStruct: + case SpvOpTypeOpaque: + case SpvOpTypePointer: + case SpvOpTypeFunction: 
+ case SpvOpTypeEvent: + case SpvOpTypeDeviceEvent: + case SpvOpTypeReserveId: + case SpvOpTypeQueue: + case SpvOpTypePipe: + case SpvOpTypePipeStorage: + case SpvOpTypeNamedBarrier: + case SpvOpTypeAccelerationStructureNV: + return true; + default: + // In particular, OpTypeForwardPointer does not generate a type, + // but declares a storage class for a pointer type generated + // by a different instruction. + break; + } + return 0; +} + +bool spvOpcodeIsDecoration(const SpvOp opcode) { + switch (opcode) { + case SpvOpDecorate: + case SpvOpDecorateId: + case SpvOpMemberDecorate: + case SpvOpGroupDecorate: + case SpvOpGroupMemberDecorate: + case SpvOpDecorateStringGOOGLE: + case SpvOpMemberDecorateStringGOOGLE: + return true; + default: + break; + } + return false; +} + +bool spvOpcodeIsLoad(const SpvOp opcode) { + switch (opcode) { + case SpvOpLoad: + case SpvOpImageSampleExplicitLod: + case SpvOpImageSampleImplicitLod: + case SpvOpImageSampleDrefImplicitLod: + case SpvOpImageSampleDrefExplicitLod: + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageFetch: + case SpvOpImageGather: + case SpvOpImageDrefGather: + case SpvOpImageRead: + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleExplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseFetch: + case SpvOpImageSparseGather: + case SpvOpImageSparseDrefGather: + case SpvOpImageSparseRead: + return true; + default: + return false; + } +} + +bool spvOpcodeIsBranch(SpvOp opcode) { + switch (opcode) { + case SpvOpBranch: + case SpvOpBranchConditional: + case SpvOpSwitch: + return true; + default: + return false; + } +} + +bool spvOpcodeIsAtomicWithLoad(const SpvOp opcode) { + switch (opcode) { + case SpvOpAtomicLoad: + case SpvOpAtomicExchange: + case SpvOpAtomicCompareExchange: + case 
SpvOpAtomicCompareExchangeWeak: + case SpvOpAtomicIIncrement: + case SpvOpAtomicIDecrement: + case SpvOpAtomicIAdd: + case SpvOpAtomicISub: + case SpvOpAtomicSMin: + case SpvOpAtomicUMin: + case SpvOpAtomicSMax: + case SpvOpAtomicUMax: + case SpvOpAtomicAnd: + case SpvOpAtomicOr: + case SpvOpAtomicXor: + case SpvOpAtomicFlagTestAndSet: + return true; + default: + return false; + } +} + +bool spvOpcodeIsAtomicOp(const SpvOp opcode) { + return (spvOpcodeIsAtomicWithLoad(opcode) || opcode == SpvOpAtomicStore || + opcode == SpvOpAtomicFlagClear); +} + +bool spvOpcodeIsReturn(SpvOp opcode) { + switch (opcode) { + case SpvOpReturn: + case SpvOpReturnValue: + return true; + default: + return false; + } +} + +bool spvOpcodeIsReturnOrAbort(SpvOp opcode) { + return spvOpcodeIsReturn(opcode) || opcode == SpvOpKill || + opcode == SpvOpUnreachable; +} + +bool spvOpcodeIsBlockTerminator(SpvOp opcode) { + return spvOpcodeIsBranch(opcode) || spvOpcodeIsReturnOrAbort(opcode); +} + +bool spvOpcodeIsBaseOpaqueType(SpvOp opcode) { + switch (opcode) { + case SpvOpTypeImage: + case SpvOpTypeSampler: + case SpvOpTypeSampledImage: + case SpvOpTypeOpaque: + case SpvOpTypeEvent: + case SpvOpTypeDeviceEvent: + case SpvOpTypeReserveId: + case SpvOpTypeQueue: + case SpvOpTypePipe: + case SpvOpTypeForwardPointer: + case SpvOpTypePipeStorage: + case SpvOpTypeNamedBarrier: + return true; + default: + return false; + } +} + +bool spvOpcodeIsNonUniformGroupOperation(SpvOp opcode) { + switch (opcode) { + case SpvOpGroupNonUniformElect: + case SpvOpGroupNonUniformAll: + case SpvOpGroupNonUniformAny: + case SpvOpGroupNonUniformAllEqual: + case SpvOpGroupNonUniformBroadcast: + case SpvOpGroupNonUniformBroadcastFirst: + case SpvOpGroupNonUniformBallot: + case SpvOpGroupNonUniformInverseBallot: + case SpvOpGroupNonUniformBallotBitExtract: + case SpvOpGroupNonUniformBallotBitCount: + case SpvOpGroupNonUniformBallotFindLSB: + case SpvOpGroupNonUniformBallotFindMSB: + case SpvOpGroupNonUniformShuffle: + 
case SpvOpGroupNonUniformShuffleXor: + case SpvOpGroupNonUniformShuffleUp: + case SpvOpGroupNonUniformShuffleDown: + case SpvOpGroupNonUniformIAdd: + case SpvOpGroupNonUniformFAdd: + case SpvOpGroupNonUniformIMul: + case SpvOpGroupNonUniformFMul: + case SpvOpGroupNonUniformSMin: + case SpvOpGroupNonUniformUMin: + case SpvOpGroupNonUniformFMin: + case SpvOpGroupNonUniformSMax: + case SpvOpGroupNonUniformUMax: + case SpvOpGroupNonUniformFMax: + case SpvOpGroupNonUniformBitwiseAnd: + case SpvOpGroupNonUniformBitwiseOr: + case SpvOpGroupNonUniformBitwiseXor: + case SpvOpGroupNonUniformLogicalAnd: + case SpvOpGroupNonUniformLogicalOr: + case SpvOpGroupNonUniformLogicalXor: + case SpvOpGroupNonUniformQuadBroadcast: + case SpvOpGroupNonUniformQuadSwap: + return true; + default: + return false; + } +} + +bool spvOpcodeIsScalarizable(SpvOp opcode) { + switch (opcode) { + case SpvOpPhi: + case SpvOpCopyObject: + case SpvOpConvertFToU: + case SpvOpConvertFToS: + case SpvOpConvertSToF: + case SpvOpConvertUToF: + case SpvOpUConvert: + case SpvOpSConvert: + case SpvOpFConvert: + case SpvOpQuantizeToF16: + case SpvOpVectorInsertDynamic: + case SpvOpSNegate: + case SpvOpFNegate: + case SpvOpIAdd: + case SpvOpFAdd: + case SpvOpISub: + case SpvOpFSub: + case SpvOpIMul: + case SpvOpFMul: + case SpvOpUDiv: + case SpvOpSDiv: + case SpvOpFDiv: + case SpvOpUMod: + case SpvOpSRem: + case SpvOpSMod: + case SpvOpFRem: + case SpvOpFMod: + case SpvOpVectorTimesScalar: + case SpvOpIAddCarry: + case SpvOpISubBorrow: + case SpvOpUMulExtended: + case SpvOpSMulExtended: + case SpvOpShiftRightLogical: + case SpvOpShiftRightArithmetic: + case SpvOpShiftLeftLogical: + case SpvOpBitwiseOr: + case SpvOpBitwiseAnd: + case SpvOpNot: + case SpvOpBitFieldInsert: + case SpvOpBitFieldSExtract: + case SpvOpBitFieldUExtract: + case SpvOpBitReverse: + case SpvOpBitCount: + case SpvOpIsNan: + case SpvOpIsInf: + case SpvOpIsFinite: + case SpvOpIsNormal: + case SpvOpSignBitSet: + case SpvOpLessOrGreater: + case 
SpvOpOrdered: + case SpvOpUnordered: + case SpvOpLogicalEqual: + case SpvOpLogicalNotEqual: + case SpvOpLogicalOr: + case SpvOpLogicalAnd: + case SpvOpLogicalNot: + case SpvOpSelect: + case SpvOpIEqual: + case SpvOpINotEqual: + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: + case SpvOpULessThan: + case SpvOpSLessThan: + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + case SpvOpFOrdEqual: + case SpvOpFUnordEqual: + case SpvOpFOrdNotEqual: + case SpvOpFUnordNotEqual: + case SpvOpFOrdLessThan: + case SpvOpFUnordLessThan: + case SpvOpFOrdGreaterThan: + case SpvOpFUnordGreaterThan: + case SpvOpFOrdLessThanEqual: + case SpvOpFUnordLessThanEqual: + case SpvOpFOrdGreaterThanEqual: + case SpvOpFUnordGreaterThanEqual: + return true; + default: + return false; + } +} + +bool spvOpcodeIsDebug(SpvOp opcode) { + switch (opcode) { + case SpvOpName: + case SpvOpMemberName: + case SpvOpSource: + case SpvOpSourceContinued: + case SpvOpSourceExtension: + case SpvOpString: + case SpvOpLine: + case SpvOpNoLine: + return true; + default: + return false; + } +} diff --git a/source/opcode.h b/source/opcode.h new file mode 100644 index 000000000..76f9a0e84 --- /dev/null +++ b/source/opcode.h @@ -0,0 +1,136 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPCODE_H_ +#define SOURCE_OPCODE_H_ + +#include "source/instruction.h" +#include "source/latest_version_spirv_header.h" +#include "source/table.h" +#include "spirv-tools/libspirv.h" + +// Returns the name of a registered SPIR-V generator as a null-terminated +// string. If the generator is not known, then returns the string "Unknown". +// The generator parameter should be most significant 16-bits of the generator +// word in the SPIR-V module header. +// +// See the registry at https://www.khronos.org/registry/spir-v/api/spir-v.xml. +const char* spvGeneratorStr(uint32_t generator); + +// Combines word_count and opcode enumerant in a single word. +uint32_t spvOpcodeMake(uint16_t word_count, SpvOp opcode); + +// Splits word into two constituent parts: word_count and opcode. +void spvOpcodeSplit(const uint32_t word, uint16_t* word_count, + uint16_t* opcode); + +// Finds the named opcode in the given opcode table. On success, returns +// SPV_SUCCESS and writes a handle of the table entry into *entry. +spv_result_t spvOpcodeTableNameLookup(spv_target_env, + const spv_opcode_table table, + const char* name, spv_opcode_desc* entry); + +// Finds the opcode by enumerant in the given opcode table. On success, returns +// SPV_SUCCESS and writes a handle of the table entry into *entry. +spv_result_t spvOpcodeTableValueLookup(spv_target_env, + const spv_opcode_table table, + const SpvOp opcode, + spv_opcode_desc* entry); + +// Copies an instruction's word and fixes the endianness to host native. The +// source instruction's stream/opcode/endianness is in the words/opcode/endian +// parameter. The word_count parameter specifies the number of words to copy. +// Writes copied instruction into *inst. +void spvInstructionCopy(const uint32_t* words, const SpvOp opcode, + const uint16_t word_count, + const spv_endianness_t endian, spv_instruction_t* inst); + +// Gets the name of an instruction, without the "Op" prefix. 
+const char* spvOpcodeString(const SpvOp opcode); + +// Determine if the given opcode is a scalar type. Returns zero if false, +// non-zero otherwise. +int32_t spvOpcodeIsScalarType(const SpvOp opcode); + +// Determines if the given opcode is a specialization constant. Returns zero if +// false, non-zero otherwise. +int32_t spvOpcodeIsSpecConstant(const SpvOp opcode); + +// Determines if the given opcode is a constant. Returns zero if false, non-zero +// otherwise. +int32_t spvOpcodeIsConstant(const SpvOp opcode); + +// Returns true if the given opcode is a constant or undef. +bool spvOpcodeIsConstantOrUndef(const SpvOp opcode); + +// Returns true if the given opcode is a scalar specialization constant. +bool spvOpcodeIsScalarSpecConstant(const SpvOp opcode); + +// Determines if the given opcode is a composite type. Returns zero if false, +// non-zero otherwise. +int32_t spvOpcodeIsComposite(const SpvOp opcode); + +// Determines if the given opcode results in a pointer when using the logical +// addressing model. Returns zero if false, non-zero otherwise. +int32_t spvOpcodeReturnsLogicalPointer(const SpvOp opcode); + +// Returns whether the given opcode could result in a pointer or a variable +// pointer when using the logical addressing model. +bool spvOpcodeReturnsLogicalVariablePointer(const SpvOp opcode); + +// Determines if the given opcode generates a type. Returns zero if false, +// non-zero otherwise. +int32_t spvOpcodeGeneratesType(SpvOp opcode); + +// Returns true if the opcode adds a decoration to an id. +bool spvOpcodeIsDecoration(const SpvOp opcode); + +// Returns true if the opcode is a load from memory into a result id. This +// function only considers core instructions. +bool spvOpcodeIsLoad(const SpvOp opcode); + +// Returns true if the opcode is an atomic operation that uses the original +// value. +bool spvOpcodeIsAtomicWithLoad(const SpvOp opcode); + +// Returns true if the opcode is an atomic operation. 
+bool spvOpcodeIsAtomicOp(const SpvOp opcode); + +// Returns true if the given opcode is a branch instruction. +bool spvOpcodeIsBranch(SpvOp opcode); + +// Returns true if the given opcode is a return instruction. +bool spvOpcodeIsReturn(SpvOp opcode); + +// Returns true if the given opcode is a return instruction or it aborts +// execution. +bool spvOpcodeIsReturnOrAbort(SpvOp opcode); + +// Returns true if the given opcode is a basic block terminator. +bool spvOpcodeIsBlockTerminator(SpvOp opcode); + +// Returns true if the given opcode always defines an opaque type. +bool spvOpcodeIsBaseOpaqueType(SpvOp opcode); + +// Returns true if the given opcode is a non-uniform group operation. +bool spvOpcodeIsNonUniformGroupOperation(SpvOp opcode); + +// Returns true if the opcode with vector inputs could be divided into a series +// of independent scalar operations that would give the same result. +bool spvOpcodeIsScalarizable(SpvOp opcode); + +// Returns true if the given opcode is a debug instruction. +bool spvOpcodeIsDebug(SpvOp opcode); + +#endif // SOURCE_OPCODE_H_ diff --git a/source/operand.cpp b/source/operand.cpp new file mode 100644 index 000000000..00dc53d1d --- /dev/null +++ b/source/operand.cpp @@ -0,0 +1,492 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/operand.h" + +#include +#include +#include + +#include "source/macro.h" +#include "source/spirv_constant.h" +#include "source/spirv_target_env.h" + +// For now, assume unified1 contains up to SPIR-V 1.3 and no later +// SPIR-V version. +// TODO(dneto): Make one set of tables, but with version tags on a +// per-item basis. https://github.com/KhronosGroup/SPIRV-Tools/issues/1195 + +#include "operand.kinds-unified1.inc" + +static const spv_operand_table_t kOperandTable = { + ARRAY_SIZE(pygen_variable_OperandInfoTable), + pygen_variable_OperandInfoTable}; + +spv_result_t spvOperandTableGet(spv_operand_table* pOperandTable, + spv_target_env) { + if (!pOperandTable) return SPV_ERROR_INVALID_POINTER; + + *pOperandTable = &kOperandTable; + return SPV_SUCCESS; +} + +spv_result_t spvOperandTableNameLookup(spv_target_env env, + const spv_operand_table table, + const spv_operand_type_t type, + const char* name, + const size_t nameLength, + spv_operand_desc* pEntry) { + if (!table) return SPV_ERROR_INVALID_TABLE; + if (!name || !pEntry) return SPV_ERROR_INVALID_POINTER; + + for (uint64_t typeIndex = 0; typeIndex < table->count; ++typeIndex) { + const auto& group = table->types[typeIndex]; + if (type != group.type) continue; + for (uint64_t index = 0; index < group.count; ++index) { + const auto& entry = group.entries[index]; + // We consider the current operand as available as long as + // 1. The target environment satisfies the minimal requirement of the + // operand; or + // 2. There is at least one extension enabling this operand; or + // 3. There is at least one capability enabling this operand. + // + // Note that the second rule assumes the extension enabling this operand + // is indeed requested in the SPIR-V code; checking that should be + // validator's work. 
+ if ((spvVersionForTargetEnv(env) >= entry.minVersion || + entry.numExtensions > 0u || entry.numCapabilities > 0u) && + nameLength == strlen(entry.name) && + !strncmp(entry.name, name, nameLength)) { + *pEntry = &entry; + return SPV_SUCCESS; + } + } + } + + return SPV_ERROR_INVALID_LOOKUP; +} + +spv_result_t spvOperandTableValueLookup(spv_target_env env, + const spv_operand_table table, + const spv_operand_type_t type, + const uint32_t value, + spv_operand_desc* pEntry) { + if (!table) return SPV_ERROR_INVALID_TABLE; + if (!pEntry) return SPV_ERROR_INVALID_POINTER; + + spv_operand_desc_t needle = {"", value, 0, nullptr, 0, nullptr, {}, ~0u}; + + auto comp = [](const spv_operand_desc_t& lhs, const spv_operand_desc_t& rhs) { + return lhs.value < rhs.value; + }; + + for (uint64_t typeIndex = 0; typeIndex < table->count; ++typeIndex) { + const auto& group = table->types[typeIndex]; + if (type != group.type) continue; + + const auto beg = group.entries; + const auto end = group.entries + group.count; + + // We need to loop here because there can exist multiple symbols for the + // same operand value, and they can be introduced in different target + // environments, which means they can have different minimal version + // requirements. For example, SubgroupEqMaskKHR can exist in any SPIR-V + // version as long as the SPV_KHR_shader_ballot extension is there; but + // starting from SPIR-V 1.3, SubgroupEqMask, which has the same numeric + // value as SubgroupEqMaskKHR, is available in core SPIR-V without extension + // requirements. + // Assumes the underlying table is already sorted ascendingly according to + // opcode value. + for (auto it = std::lower_bound(beg, end, needle, comp); + it != end && it->value == value; ++it) { + // We consider the current operand as available as long as + // 1. The target environment satisfies the minimal requirement of the + // operand; or + // 2. There is at least one extension enabling this operand; or + // 3. 
There is at least one capability enabling this operand. + // + // Note that the second rule assumes the extension enabling this operand + // is indeed requested in the SPIR-V code; checking that should be + // validator's work. + if (spvVersionForTargetEnv(env) >= it->minVersion || + it->numExtensions > 0u || it->numCapabilities > 0u) { + *pEntry = it; + return SPV_SUCCESS; + } + } + } + + return SPV_ERROR_INVALID_LOOKUP; +} + +const char* spvOperandTypeStr(spv_operand_type_t type) { + switch (type) { + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_OPTIONAL_ID: + return "ID"; + case SPV_OPERAND_TYPE_TYPE_ID: + return "type ID"; + case SPV_OPERAND_TYPE_RESULT_ID: + return "result ID"; + case SPV_OPERAND_TYPE_LITERAL_INTEGER: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_NUMBER: + return "literal number"; + case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: + return "possibly multi-word literal integer"; + case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: + return "possibly multi-word literal number"; + case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: + return "extension instruction number"; + case SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER: + return "OpSpecConstantOp opcode"; + case SPV_OPERAND_TYPE_LITERAL_STRING: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: + return "literal string"; + case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: + return "source language"; + case SPV_OPERAND_TYPE_EXECUTION_MODEL: + return "execution model"; + case SPV_OPERAND_TYPE_ADDRESSING_MODEL: + return "addressing model"; + case SPV_OPERAND_TYPE_MEMORY_MODEL: + return "memory model"; + case SPV_OPERAND_TYPE_EXECUTION_MODE: + return "execution mode"; + case SPV_OPERAND_TYPE_STORAGE_CLASS: + return "storage class"; + case SPV_OPERAND_TYPE_DIMENSIONALITY: + return "dimensionality"; + case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: + return "sampler addressing mode"; + case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: + return "sampler filter mode"; + case 
SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: + return "image format"; + case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: + return "floating-point fast math mode"; + case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: + return "floating-point rounding mode"; + case SPV_OPERAND_TYPE_LINKAGE_TYPE: + return "linkage type"; + case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: + case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: + return "access qualifier"; + case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: + return "function parameter attribute"; + case SPV_OPERAND_TYPE_DECORATION: + return "decoration"; + case SPV_OPERAND_TYPE_BUILT_IN: + return "built-in"; + case SPV_OPERAND_TYPE_SELECTION_CONTROL: + return "selection control"; + case SPV_OPERAND_TYPE_LOOP_CONTROL: + return "loop control"; + case SPV_OPERAND_TYPE_FUNCTION_CONTROL: + return "function control"; + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: + return "memory semantics ID"; + case SPV_OPERAND_TYPE_MEMORY_ACCESS: + case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: + return "memory access"; + case SPV_OPERAND_TYPE_SCOPE_ID: + return "scope ID"; + case SPV_OPERAND_TYPE_GROUP_OPERATION: + return "group operation"; + case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: + return "kernel enqeue flags"; + case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: + return "kernel profiling info"; + case SPV_OPERAND_TYPE_CAPABILITY: + return "capability"; + case SPV_OPERAND_TYPE_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: + return "image"; + case SPV_OPERAND_TYPE_OPTIONAL_CIV: + return "context-insensitive value"; + case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: + return "debug info flags"; + case SPV_OPERAND_TYPE_DEBUG_BASE_TYPE_ATTRIBUTE_ENCODING: + return "debug base type encoding"; + case SPV_OPERAND_TYPE_DEBUG_COMPOSITE_TYPE: + return "debug composite type"; + case SPV_OPERAND_TYPE_DEBUG_TYPE_QUALIFIER: + return "debug type qualifier"; + case SPV_OPERAND_TYPE_DEBUG_OPERATION: + return "debug operation"; + + // The next values are for values returned from an instruction, not 
actually + // an operand. So the specific strings don't matter. But let's add them + // for completeness and ease of testing. + case SPV_OPERAND_TYPE_IMAGE_CHANNEL_ORDER: + return "image channel order"; + case SPV_OPERAND_TYPE_IMAGE_CHANNEL_DATA_TYPE: + return "image channel data type"; + + case SPV_OPERAND_TYPE_NONE: + return "NONE"; + default: + assert(0 && "Unhandled operand type!"); + break; + } + return "unknown"; +} + +void spvPushOperandTypes(const spv_operand_type_t* types, + spv_operand_pattern_t* pattern) { + const spv_operand_type_t* endTypes; + for (endTypes = types; *endTypes != SPV_OPERAND_TYPE_NONE; ++endTypes) { + } + + while (endTypes-- != types) { + pattern->push_back(*endTypes); + } +} + +void spvPushOperandTypesForMask(spv_target_env env, + const spv_operand_table operandTable, + const spv_operand_type_t type, + const uint32_t mask, + spv_operand_pattern_t* pattern) { + // Scan from highest bits to lowest bits because we will append in LIFO + // fashion, and we need the operands for lower order bits to be consumed first + for (uint32_t candidate_bit = (1u << 31u); candidate_bit; + candidate_bit >>= 1) { + if (candidate_bit & mask) { + spv_operand_desc entry = nullptr; + if (SPV_SUCCESS == spvOperandTableValueLookup(env, operandTable, type, + candidate_bit, &entry)) { + spvPushOperandTypes(entry->operandTypes, pattern); + } + } + } +} + +bool spvOperandIsConcrete(spv_operand_type_t type) { + if (spvIsIdType(type) || spvOperandIsConcreteMask(type)) { + return true; + } + switch (type) { + case SPV_OPERAND_TYPE_LITERAL_INTEGER: + case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: + case SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER: + case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: + case SPV_OPERAND_TYPE_LITERAL_STRING: + case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: + case SPV_OPERAND_TYPE_EXECUTION_MODEL: + case SPV_OPERAND_TYPE_ADDRESSING_MODEL: + case SPV_OPERAND_TYPE_MEMORY_MODEL: + case SPV_OPERAND_TYPE_EXECUTION_MODE: + case 
SPV_OPERAND_TYPE_STORAGE_CLASS: + case SPV_OPERAND_TYPE_DIMENSIONALITY: + case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: + case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: + case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: + case SPV_OPERAND_TYPE_IMAGE_CHANNEL_ORDER: + case SPV_OPERAND_TYPE_IMAGE_CHANNEL_DATA_TYPE: + case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: + case SPV_OPERAND_TYPE_LINKAGE_TYPE: + case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: + case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: + case SPV_OPERAND_TYPE_DECORATION: + case SPV_OPERAND_TYPE_BUILT_IN: + case SPV_OPERAND_TYPE_GROUP_OPERATION: + case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: + case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: + case SPV_OPERAND_TYPE_CAPABILITY: + case SPV_OPERAND_TYPE_DEBUG_BASE_TYPE_ATTRIBUTE_ENCODING: + case SPV_OPERAND_TYPE_DEBUG_COMPOSITE_TYPE: + case SPV_OPERAND_TYPE_DEBUG_TYPE_QUALIFIER: + case SPV_OPERAND_TYPE_DEBUG_OPERATION: + return true; + default: + break; + } + return false; +} + +bool spvOperandIsConcreteMask(spv_operand_type_t type) { + switch (type) { + case SPV_OPERAND_TYPE_IMAGE: + case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: + case SPV_OPERAND_TYPE_SELECTION_CONTROL: + case SPV_OPERAND_TYPE_LOOP_CONTROL: + case SPV_OPERAND_TYPE_FUNCTION_CONTROL: + case SPV_OPERAND_TYPE_MEMORY_ACCESS: + case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: + return true; + default: + break; + } + return false; +} + +bool spvOperandIsOptional(spv_operand_type_t type) { + return SPV_OPERAND_TYPE_FIRST_OPTIONAL_TYPE <= type && + type <= SPV_OPERAND_TYPE_LAST_OPTIONAL_TYPE; +} + +bool spvOperandIsVariable(spv_operand_type_t type) { + return SPV_OPERAND_TYPE_FIRST_VARIABLE_TYPE <= type && + type <= SPV_OPERAND_TYPE_LAST_VARIABLE_TYPE; +} + +bool spvExpandOperandSequenceOnce(spv_operand_type_t type, + spv_operand_pattern_t* pattern) { + switch (type) { + case SPV_OPERAND_TYPE_VARIABLE_ID: + pattern->push_back(type); + pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_ID); + return true; + case 
SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER: + pattern->push_back(type); + pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER); + return true; + case SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER_ID: + // Represents Zero or more (Literal number, Id) pairs, + // where the literal number must be a scalar integer. + pattern->push_back(type); + pattern->push_back(SPV_OPERAND_TYPE_ID); + pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER); + return true; + case SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER: + // Represents Zero or more (Id, Literal number) pairs. + pattern->push_back(type); + pattern->push_back(SPV_OPERAND_TYPE_LITERAL_INTEGER); + pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_ID); + return true; + default: + break; + } + return false; +} + +spv_operand_type_t spvTakeFirstMatchableOperand( + spv_operand_pattern_t* pattern) { + assert(!pattern->empty()); + spv_operand_type_t result; + do { + result = pattern->back(); + pattern->pop_back(); + } while (spvExpandOperandSequenceOnce(result, pattern)); + return result; +} + +spv_operand_pattern_t spvAlternatePatternFollowingImmediate( + const spv_operand_pattern_t& pattern) { + auto it = + std::find(pattern.crbegin(), pattern.crend(), SPV_OPERAND_TYPE_RESULT_ID); + if (it != pattern.crend()) { + spv_operand_pattern_t alternatePattern(it - pattern.crbegin() + 2, + SPV_OPERAND_TYPE_OPTIONAL_CIV); + alternatePattern[1] = SPV_OPERAND_TYPE_RESULT_ID; + return alternatePattern; + } + + // No result-id found, so just expect CIVs. + return {SPV_OPERAND_TYPE_OPTIONAL_CIV}; +} + +bool spvIsIdType(spv_operand_type_t type) { + switch (type) { + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_TYPE_ID: + case SPV_OPERAND_TYPE_RESULT_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + return true; + default: + return false; + } +} + +bool spvIsInIdType(spv_operand_type_t type) { + if (!spvIsIdType(type)) { + // If it is not an ID it cannot be an input ID. 
+ return false; + } + switch (type) { + // Blacklist non-input IDs. + case SPV_OPERAND_TYPE_TYPE_ID: + case SPV_OPERAND_TYPE_RESULT_ID: + return false; + default: + return true; + } +} + +std::function spvOperandCanBeForwardDeclaredFunction( + SpvOp opcode) { + std::function out; + switch (opcode) { + case SpvOpExecutionMode: + case SpvOpExecutionModeId: + case SpvOpEntryPoint: + case SpvOpName: + case SpvOpMemberName: + case SpvOpSelectionMerge: + case SpvOpDecorate: + case SpvOpMemberDecorate: + case SpvOpDecorateId: + case SpvOpDecorateStringGOOGLE: + case SpvOpMemberDecorateStringGOOGLE: + case SpvOpTypeStruct: + case SpvOpBranch: + case SpvOpLoopMerge: + out = [](unsigned) { return true; }; + break; + case SpvOpGroupDecorate: + case SpvOpGroupMemberDecorate: + case SpvOpBranchConditional: + case SpvOpSwitch: + out = [](unsigned index) { return index != 0; }; + break; + + case SpvOpFunctionCall: + // The Function parameter. + out = [](unsigned index) { return index == 2; }; + break; + + case SpvOpPhi: + out = [](unsigned index) { return index > 1; }; + break; + + case SpvOpEnqueueKernel: + // The Invoke parameter. + out = [](unsigned index) { return index == 8; }; + break; + + case SpvOpGetKernelNDrangeSubGroupCount: + case SpvOpGetKernelNDrangeMaxSubGroupSize: + // The Invoke parameter. + out = [](unsigned index) { return index == 3; }; + break; + + case SpvOpGetKernelWorkGroupSize: + case SpvOpGetKernelPreferredWorkGroupSizeMultiple: + // The Invoke parameter. 
+ out = [](unsigned index) { return index == 2; }; + break; + case SpvOpTypeForwardPointer: + out = [](unsigned index) { return index == 0; }; + break; + case SpvOpTypeArray: + out = [](unsigned index) { return index == 1; }; + break; + default: + out = [](unsigned) { return false; }; + break; + } + return out; +} diff --git a/source/operand.h b/source/operand.h new file mode 100644 index 000000000..15a182583 --- /dev/null +++ b/source/operand.h @@ -0,0 +1,144 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPERAND_H_ +#define SOURCE_OPERAND_H_ + +#include +#include + +#include "source/table.h" +#include "spirv-tools/libspirv.h" + +// A sequence of operand types. +// +// A SPIR-V parser uses an operand pattern to describe what is expected +// next on the input. +// +// As we parse an instruction in text or binary form from left to right, +// we pop and push at the end of the pattern vector. Symbols later in the +// pattern vector are matched against the input before symbols earlier in the +// pattern vector are matched. + +// Using a vector in this way reduces memory traffic, which is good for +// performance. +using spv_operand_pattern_t = std::vector; + +// Finds the named operand in the table. The type parameter specifies the +// operand's group. A handle of the operand table entry for this operand will +// be written into *entry. 
+spv_result_t spvOperandTableNameLookup(spv_target_env, + const spv_operand_table table, + const spv_operand_type_t type, + const char* name, + const size_t name_length, + spv_operand_desc* entry); + +// Finds the operand with value in the table. The type parameter specifies the +// operand's group. A handle of the operand table entry for this operand will +// be written into *entry. +spv_result_t spvOperandTableValueLookup(spv_target_env, + const spv_operand_table table, + const spv_operand_type_t type, + const uint32_t value, + spv_operand_desc* entry); + +// Gets the name string of the non-variable operand type. +const char* spvOperandTypeStr(spv_operand_type_t type); + +// Returns true if the given type is concrete. +bool spvOperandIsConcrete(spv_operand_type_t type); + +// Returns true if the given type is concrete and also a mask. +bool spvOperandIsConcreteMask(spv_operand_type_t type); + +// Returns true if an operand of the given type is optional. +bool spvOperandIsOptional(spv_operand_type_t type); + +// Returns true if an operand type represents zero or more logical operands. +// +// Note that a single logical operand may still be a variable number of words. +// For example, a literal string may be many words, but is just one logical +// operand. +bool spvOperandIsVariable(spv_operand_type_t type); + +// Append a list of operand types to the end of the pattern vector. +// The types parameter specifies the source array of types, ending with +// SPV_OPERAND_TYPE_NONE. +void spvPushOperandTypes(const spv_operand_type_t* types, + spv_operand_pattern_t* pattern); + +// Appends the operands expected after the given typed mask onto the +// end of the given pattern. +// +// Each set bit in the mask represents zero or more operand types that should +// be appended onto the pattern. Operands for a less significant bit always +// appear after operands for a more significant bit. +// +// If a set bit is unknown, then we assume it has no operands. 
+void spvPushOperandTypesForMask(spv_target_env, + const spv_operand_table operand_table, + const spv_operand_type_t mask_type, + const uint32_t mask, + spv_operand_pattern_t* pattern); + +// Expands an operand type representing zero or more logical operands, +// exactly once. +// +// If the given type represents potentially several logical operands, +// then prepend the given pattern with the first expansion of the logical +// operands, followed by the original type. Otherwise, don't modify the pattern. +// +// For example, the SPV_OPERAND_TYPE_VARIABLE_ID represents zero or more +// IDs. In that case we would prepend the pattern with +// SPV_OPERAND_TYPE_OPTIONAL_ID followed by SPV_OPERAND_TYPE_VARIABLE_ID again. +// +// This also applies to zero or more tuples of logical operands. In that case +// we prepend the pattern with the members of the tuple, followed by the +// original type argument. The pattern must encode the fact that if any part +// of the tuple is present, then all tuple members should be. So the first +// member of the tuple must be optional, and the remaining members +// non-optional. +// +// Returns true if we modified the pattern. +bool spvExpandOperandSequenceOnce(spv_operand_type_t type, + spv_operand_pattern_t* pattern); + +// Expands the first element in the pattern until it is a matchable operand +// type, then pops it off the front and returns it. The pattern must not be +// empty. +// +// A matchable operand type is anything other than a zero-or-more-items +// operand type. +spv_operand_type_t spvTakeFirstMatchableOperand(spv_operand_pattern_t* pattern); + +// Calculates the corresponding post-immediate alternate pattern, which allows +// a limited set of operand types. +spv_operand_pattern_t spvAlternatePatternFollowingImmediate( + const spv_operand_pattern_t& pattern); + +// Is the operand an ID? +bool spvIsIdType(spv_operand_type_t type); + +// Is the operand an input ID? 
+bool spvIsInIdType(spv_operand_type_t type); + +// Takes the opcode of an instruction and returns +// a function object that will return true if the index +// of the operand can be forward declared. This function will +// be used in the SSA validation stage of the pipeline. +std::function spvOperandCanBeForwardDeclaredFunction( + SpvOp opcode); + +#endif // SOURCE_OPERAND_H_ diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt new file mode 100644 index 000000000..0f3835cbb --- /dev/null +++ b/source/opt/CMakeLists.txt @@ -0,0 +1,225 @@ +# Copyright (c) 2016 Google Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+set(SPIRV_TOOLS_OPT_SOURCES + aggressive_dead_code_elim_pass.h + basic_block.h + block_merge_pass.h + build_module.h + ccp_pass.h + cfg_cleanup_pass.h + cfg.h + code_sink.h + combine_access_chains.h + common_uniform_elim_pass.h + compact_ids_pass.h + composite.h + const_folding_rules.h + constants.h + copy_prop_arrays.h + dead_branch_elim_pass.h + dead_insert_elim_pass.h + dead_variable_elimination.h + decoration_manager.h + def_use_manager.h + dominator_analysis.h + dominator_tree.h + eliminate_dead_constant_pass.h + eliminate_dead_functions_pass.h + feature_manager.h + flatten_decoration_pass.h + fold.h + folding_rules.h + fold_spec_constant_op_and_composite_pass.h + freeze_spec_constant_value_pass.h + function.h + if_conversion.h + inline_exhaustive_pass.h + inline_opaque_pass.h + inline_pass.h + inst_bindless_check_pass.h + instruction.h + instruction_list.h + instrument_pass.h + ir_builder.h + ir_context.h + ir_loader.h + licm_pass.h + local_access_chain_convert_pass.h + local_redundancy_elimination.h + local_single_block_elim_pass.h + local_single_store_elim_pass.h + local_ssa_elim_pass.h + log.h + loop_dependence.h + loop_descriptor.h + loop_fission.h + loop_fusion.h + loop_fusion_pass.h + loop_peeling.h + loop_unroller.h + loop_utils.h + loop_unswitch_pass.h + mem_pass.h + merge_return_pass.h + module.h + null_pass.h + passes.h + pass.h + pass_manager.h + private_to_local_pass.h + process_lines_pass.h + propagator.h + reduce_load_size.h + redundancy_elimination.h + reflect.h + register_pressure.h + remove_duplicates_pass.h + replace_invalid_opc.h + scalar_analysis.h + scalar_analysis_nodes.h + scalar_replacement_pass.h + set_spec_constant_default_value_pass.h + simplification_pass.h + ssa_rewrite_pass.h + strength_reduction_pass.h + strip_debug_info_pass.h + strip_reflect_info_pass.h + struct_cfg_analysis.h + tree_iterator.h + type_manager.h + types.h + unify_const_pass.h + upgrade_memory_model.h + value_number_table.h + vector_dce.h + workaround1209.h + + 
aggressive_dead_code_elim_pass.cpp + basic_block.cpp + block_merge_pass.cpp + build_module.cpp + ccp_pass.cpp + cfg_cleanup_pass.cpp + cfg.cpp + code_sink.cpp + combine_access_chains.cpp + common_uniform_elim_pass.cpp + compact_ids_pass.cpp + composite.cpp + const_folding_rules.cpp + constants.cpp + copy_prop_arrays.cpp + dead_branch_elim_pass.cpp + dead_insert_elim_pass.cpp + dead_variable_elimination.cpp + decoration_manager.cpp + def_use_manager.cpp + dominator_analysis.cpp + dominator_tree.cpp + eliminate_dead_constant_pass.cpp + eliminate_dead_functions_pass.cpp + feature_manager.cpp + flatten_decoration_pass.cpp + fold.cpp + folding_rules.cpp + fold_spec_constant_op_and_composite_pass.cpp + freeze_spec_constant_value_pass.cpp + function.cpp + if_conversion.cpp + inline_exhaustive_pass.cpp + inline_opaque_pass.cpp + inline_pass.cpp + inst_bindless_check_pass.cpp + instruction.cpp + instruction_list.cpp + instrument_pass.cpp + ir_context.cpp + ir_loader.cpp + licm_pass.cpp + local_access_chain_convert_pass.cpp + local_redundancy_elimination.cpp + local_single_block_elim_pass.cpp + local_single_store_elim_pass.cpp + local_ssa_elim_pass.cpp + loop_dependence.cpp + loop_dependence_helpers.cpp + loop_descriptor.cpp + loop_fission.cpp + loop_fusion.cpp + loop_fusion_pass.cpp + loop_peeling.cpp + loop_utils.cpp + loop_unroller.cpp + loop_unswitch_pass.cpp + mem_pass.cpp + merge_return_pass.cpp + module.cpp + optimizer.cpp + pass.cpp + pass_manager.cpp + private_to_local_pass.cpp + process_lines_pass.cpp + propagator.cpp + reduce_load_size.cpp + redundancy_elimination.cpp + register_pressure.cpp + remove_duplicates_pass.cpp + replace_invalid_opc.cpp + scalar_analysis.cpp + scalar_analysis_simplification.cpp + scalar_replacement_pass.cpp + set_spec_constant_default_value_pass.cpp + simplification_pass.cpp + ssa_rewrite_pass.cpp + strength_reduction_pass.cpp + strip_debug_info_pass.cpp + strip_reflect_info_pass.cpp + struct_cfg_analysis.cpp + type_manager.cpp + 
types.cpp + unify_const_pass.cpp + upgrade_memory_model.cpp + value_number_table.cpp + vector_dce.cpp + workaround1209.cpp +) + +if(MSVC) + # Enable parallel builds across four cores for this lib + add_definitions(/MP4) +endif() + +spvtools_pch(SPIRV_TOOLS_OPT_SOURCES pch_source_opt) + +add_library(SPIRV-Tools-opt ${SPIRV_TOOLS_OPT_SOURCES}) + +spvtools_default_compile_options(SPIRV-Tools-opt) +target_include_directories(SPIRV-Tools-opt + PUBLIC ${spirv-tools_SOURCE_DIR}/include + PUBLIC ${SPIRV_HEADER_INCLUDE_DIR} + PRIVATE ${spirv-tools_BINARY_DIR} +) +# We need the assembling and disassembling functionalities in the main library. +target_link_libraries(SPIRV-Tools-opt + PUBLIC ${SPIRV_TOOLS}) + +set_property(TARGET SPIRV-Tools-opt PROPERTY FOLDER "SPIRV-Tools libraries") +spvtools_check_symbol_exports(SPIRV-Tools-opt) + +if(ENABLE_SPIRV_TOOLS_INSTALL) + install(TARGETS SPIRV-Tools-opt + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) +endif(ENABLE_SPIRV_TOOLS_INSTALL) diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp new file mode 100644 index 000000000..82d749905 --- /dev/null +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -0,0 +1,813 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/aggressive_dead_code_elim_pass.h" + +#include +#include + +#include "source/cfa.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/opt/iterator.h" +#include "source/opt/reflect.h" + +namespace spvtools { +namespace opt { + +namespace { + +const uint32_t kTypePointerStorageClassInIdx = 0; +const uint32_t kEntryPointFunctionIdInIdx = 1; +const uint32_t kSelectionMergeMergeBlockIdInIdx = 0; +const uint32_t kLoopMergeMergeBlockIdInIdx = 0; +const uint32_t kLoopMergeContinueBlockIdInIdx = 1; +const uint32_t kCopyMemoryTargetAddrInIdx = 0; +const uint32_t kCopyMemorySourceAddrInIdx = 1; + +// Sorting functor to present annotation instructions in an easy-to-process +// order. The functor orders by opcode first and falls back on unique id +// ordering if both instructions have the same opcode. +// +// Desired priority: +// SpvOpGroupDecorate +// SpvOpGroupMemberDecorate +// SpvOpDecorate +// SpvOpMemberDecorate +// SpvOpDecorateId +// SpvOpDecorateStringGOOGLE +// SpvOpDecorationGroup +struct DecorationLess { + bool operator()(const Instruction* lhs, const Instruction* rhs) const { + assert(lhs && rhs); + SpvOp lhsOp = lhs->opcode(); + SpvOp rhsOp = rhs->opcode(); + if (lhsOp != rhsOp) { +#define PRIORITY_CASE(opcode) \ + if (lhsOp == opcode && rhsOp != opcode) return true; \ + if (rhsOp == opcode && lhsOp != opcode) return false; + // OpGroupDecorate and OpGroupMember decorate are highest priority to + // eliminate dead targets early and simplify subsequent checks. + PRIORITY_CASE(SpvOpGroupDecorate) + PRIORITY_CASE(SpvOpGroupMemberDecorate) + PRIORITY_CASE(SpvOpDecorate) + PRIORITY_CASE(SpvOpMemberDecorate) + PRIORITY_CASE(SpvOpDecorateId) + PRIORITY_CASE(SpvOpDecorateStringGOOGLE) + // OpDecorationGroup is lowest priority to ensure use/def chains remain + // usable for instructions that target this group. 
+ PRIORITY_CASE(SpvOpDecorationGroup) +#undef PRIORITY_CASE + } + + // Fall back to maintain total ordering (compare unique ids). + return *lhs < *rhs; + } +}; + +} // namespace + +bool AggressiveDCEPass::IsVarOfStorage(uint32_t varId, uint32_t storageClass) { + if (varId == 0) return false; + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); + const SpvOp op = varInst->opcode(); + if (op != SpvOpVariable) return false; + const uint32_t varTypeId = varInst->type_id(); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + if (varTypeInst->opcode() != SpvOpTypePointer) return false; + return varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) == + storageClass; +} + +bool AggressiveDCEPass::IsLocalVar(uint32_t varId) { + if (IsVarOfStorage(varId, SpvStorageClassFunction)) { + return true; + } + if (!private_like_local_) { + return false; + } + + return IsVarOfStorage(varId, SpvStorageClassPrivate) || + IsVarOfStorage(varId, SpvStorageClassWorkgroup); +} + +void AggressiveDCEPass::AddStores(uint32_t ptrId) { + get_def_use_mgr()->ForEachUser(ptrId, [this, ptrId](Instruction* user) { + switch (user->opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpCopyObject: + this->AddStores(user->result_id()); + break; + case SpvOpLoad: + break; + case SpvOpCopyMemory: + case SpvOpCopyMemorySized: + if (user->GetSingleWordInOperand(kCopyMemoryTargetAddrInIdx) == ptrId) { + AddToWorklist(user); + } + break; + // If default, assume it stores e.g. 
frexp, modf, function call + case SpvOpStore: + default: + AddToWorklist(user); + break; + } + }); +} + +bool AggressiveDCEPass::AllExtensionsSupported() const { + // If any extension not in whitelist, return false + for (auto& ei : get_module()->extensions()) { + const char* extName = + reinterpret_cast(&ei.GetInOperand(0).words[0]); + if (extensions_whitelist_.find(extName) == extensions_whitelist_.end()) + return false; + } + return true; +} + +bool AggressiveDCEPass::IsDead(Instruction* inst) { + if (IsLive(inst)) return false; + if ((inst->IsBranch() || inst->opcode() == SpvOpUnreachable) && + !IsStructuredHeader(context()->get_instr_block(inst), nullptr, nullptr, + nullptr)) + return false; + return true; +} + +bool AggressiveDCEPass::IsTargetDead(Instruction* inst) { + const uint32_t tId = inst->GetSingleWordInOperand(0); + Instruction* tInst = get_def_use_mgr()->GetDef(tId); + if (IsAnnotationInst(tInst->opcode())) { + // This must be a decoration group. We go through annotations in a specific + // order. So if this is not used by any group or group member decorates, it + // is dead. 
+ assert(tInst->opcode() == SpvOpDecorationGroup); + bool dead = true; + get_def_use_mgr()->ForEachUser(tInst, [&dead](Instruction* user) { + if (user->opcode() == SpvOpGroupDecorate || + user->opcode() == SpvOpGroupMemberDecorate) + dead = false; + }); + return dead; + } + return IsDead(tInst); +} + +void AggressiveDCEPass::ProcessLoad(uint32_t varId) { + // Only process locals + if (!IsLocalVar(varId)) return; + // Return if already processed + if (live_local_vars_.find(varId) != live_local_vars_.end()) return; + // Mark all stores to varId as live + AddStores(varId); + // Cache varId as processed + live_local_vars_.insert(varId); +} + +bool AggressiveDCEPass::IsStructuredHeader(BasicBlock* bp, + Instruction** mergeInst, + Instruction** branchInst, + uint32_t* mergeBlockId) { + if (!bp) return false; + Instruction* mi = bp->GetMergeInst(); + if (mi == nullptr) return false; + Instruction* bri = &*bp->tail(); + if (branchInst != nullptr) *branchInst = bri; + if (mergeInst != nullptr) *mergeInst = mi; + if (mergeBlockId != nullptr) *mergeBlockId = mi->GetSingleWordInOperand(0); + return true; +} + +void AggressiveDCEPass::ComputeBlock2HeaderMaps( + std::list& structuredOrder) { + block2headerBranch_.clear(); + header2nextHeaderBranch_.clear(); + branch2merge_.clear(); + structured_order_index_.clear(); + std::stack currentHeaderBranch; + currentHeaderBranch.push(nullptr); + uint32_t currentMergeBlockId = 0; + uint32_t index = 0; + for (auto bi = structuredOrder.begin(); bi != structuredOrder.end(); + ++bi, ++index) { + structured_order_index_[*bi] = index; + // If this block is the merge block of the current control construct, + // we are leaving the current construct so we must update state + if ((*bi)->id() == currentMergeBlockId) { + currentHeaderBranch.pop(); + Instruction* chb = currentHeaderBranch.top(); + if (chb != nullptr) + currentMergeBlockId = branch2merge_[chb]->GetSingleWordInOperand(0); + } + Instruction* mergeInst; + Instruction* branchInst; + 
uint32_t mergeBlockId; + bool is_header = + IsStructuredHeader(*bi, &mergeInst, &branchInst, &mergeBlockId); + // Map header block to next enclosing header. + if (is_header) header2nextHeaderBranch_[*bi] = currentHeaderBranch.top(); + // If this is a loop header, update state first so the block will map to + // itself. + if (is_header && mergeInst->opcode() == SpvOpLoopMerge) { + currentHeaderBranch.push(branchInst); + branch2merge_[branchInst] = mergeInst; + currentMergeBlockId = mergeBlockId; + } + // Map the block to the current construct. + block2headerBranch_[*bi] = currentHeaderBranch.top(); + // If this is an if header, update state so following blocks map to the if. + if (is_header && mergeInst->opcode() == SpvOpSelectionMerge) { + currentHeaderBranch.push(branchInst); + branch2merge_[branchInst] = mergeInst; + currentMergeBlockId = mergeBlockId; + } + } +} + +void AggressiveDCEPass::AddBranch(uint32_t labelId, BasicBlock* bp) { + std::unique_ptr newBranch( + new Instruction(context(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); + context()->AnalyzeDefUse(&*newBranch); + context()->set_instr_block(&*newBranch, bp); + bp->AddInstruction(std::move(newBranch)); +} + +void AggressiveDCEPass::AddBreaksAndContinuesToWorklist( + Instruction* mergeInst) { + assert(mergeInst->opcode() == SpvOpSelectionMerge || + mergeInst->opcode() == SpvOpLoopMerge); + + BasicBlock* header = context()->get_instr_block(mergeInst); + uint32_t headerIndex = structured_order_index_[header]; + const uint32_t mergeId = mergeInst->GetSingleWordInOperand(0); + BasicBlock* merge = context()->get_instr_block(mergeId); + uint32_t mergeIndex = structured_order_index_[merge]; + get_def_use_mgr()->ForEachUser( + mergeId, [headerIndex, mergeIndex, this](Instruction* user) { + if (!user->IsBranch()) return; + BasicBlock* block = context()->get_instr_block(user); + uint32_t index = structured_order_index_[block]; + if (headerIndex < index && index < mergeIndex) { 
+ // This is a break from the loop. + AddToWorklist(user); + // Add branch's merge if there is one. + Instruction* userMerge = branch2merge_[user]; + if (userMerge != nullptr) AddToWorklist(userMerge); + } + }); + + if (mergeInst->opcode() != SpvOpLoopMerge) { + return; + } + + // For loops we need to find the continues as well. + const uint32_t contId = + mergeInst->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx); + get_def_use_mgr()->ForEachUser(contId, [&contId, this](Instruction* user) { + SpvOp op = user->opcode(); + if (op == SpvOpBranchConditional || op == SpvOpSwitch) { + // A conditional branch or switch can only be a continue if it does not + // have a merge instruction or its merge block is not the continue block. + Instruction* hdrMerge = branch2merge_[user]; + if (hdrMerge != nullptr && hdrMerge->opcode() == SpvOpSelectionMerge) { + uint32_t hdrMergeId = + hdrMerge->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx); + if (hdrMergeId == contId) return; + // Need to mark merge instruction too + AddToWorklist(hdrMerge); + } + } else if (op == SpvOpBranch) { + // An unconditional branch can only be a continue if it is not + // branching to its own merge block. + BasicBlock* blk = context()->get_instr_block(user); + Instruction* hdrBranch = block2headerBranch_[blk]; + if (hdrBranch == nullptr) return; + Instruction* hdrMerge = branch2merge_[hdrBranch]; + if (hdrMerge->opcode() == SpvOpLoopMerge) return; + uint32_t hdrMergeId = + hdrMerge->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx); + if (contId == hdrMergeId) return; + } else { + return; + } + AddToWorklist(user); + }); +} + +bool AggressiveDCEPass::AggressiveDCE(Function* func) { + // Mark function parameters as live. 
+ AddToWorklist(&func->DefInst()); + func->ForEachParam( + [this](const Instruction* param) { + AddToWorklist(const_cast(param)); + }, + false); + + // Compute map from block to controlling conditional branch + std::list structuredOrder; + cfg()->ComputeStructuredOrder(func, &*func->begin(), &structuredOrder); + ComputeBlock2HeaderMaps(structuredOrder); + bool modified = false; + // Add instructions with external side effects to worklist. Also add branches + // EXCEPT those immediately contained in an "if" selection construct or a loop + // or continue construct. + // TODO(greg-lunarg): Handle Frexp, Modf more optimally + call_in_func_ = false; + func_is_entry_point_ = false; + private_stores_.clear(); + // Stacks to keep track of when we are inside an if- or loop-construct. + // When immediately inside an if- or loop-construct, we do not initially + // mark branches live. All other branches must be marked live. + std::stack assume_branches_live; + std::stack currentMergeBlockId; + // Push sentinel values on stack for when outside of any control flow. + assume_branches_live.push(true); + currentMergeBlockId.push(0); + for (auto bi = structuredOrder.begin(); bi != structuredOrder.end(); ++bi) { + // If exiting if or loop, update stacks + if ((*bi)->id() == currentMergeBlockId.top()) { + assume_branches_live.pop(); + currentMergeBlockId.pop(); + } + for (auto ii = (*bi)->begin(); ii != (*bi)->end(); ++ii) { + SpvOp op = ii->opcode(); + switch (op) { + case SpvOpStore: { + uint32_t varId; + (void)GetPtr(&*ii, &varId); + // Mark stores as live if their variable is not function scope + // and is not private scope. Remember private stores for possible + // later inclusion. We cannot call IsLocalVar at this point because + // private_like_local_ has not been set yet. 
+ if (IsVarOfStorage(varId, SpvStorageClassPrivate) || + IsVarOfStorage(varId, SpvStorageClassWorkgroup)) + private_stores_.push_back(&*ii); + else if (!IsVarOfStorage(varId, SpvStorageClassFunction)) + AddToWorklist(&*ii); + } break; + case SpvOpCopyMemory: + case SpvOpCopyMemorySized: { + uint32_t varId; + (void)GetPtr(ii->GetSingleWordInOperand(kCopyMemoryTargetAddrInIdx), + &varId); + if (IsVarOfStorage(varId, SpvStorageClassPrivate) || + IsVarOfStorage(varId, SpvStorageClassWorkgroup)) + private_stores_.push_back(&*ii); + else if (!IsVarOfStorage(varId, SpvStorageClassFunction)) + AddToWorklist(&*ii); + } break; + case SpvOpLoopMerge: { + assume_branches_live.push(false); + currentMergeBlockId.push( + ii->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx)); + } break; + case SpvOpSelectionMerge: { + assume_branches_live.push(false); + currentMergeBlockId.push( + ii->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx)); + } break; + case SpvOpSwitch: + case SpvOpBranch: + case SpvOpBranchConditional: + case SpvOpUnreachable: { + if (assume_branches_live.top()) { + AddToWorklist(&*ii); + } + } break; + default: { + // Function calls, atomics, function params, function returns, etc. 
+ // TODO(greg-lunarg): function calls live only if write to non-local + if (!ii->IsOpcodeSafeToDelete()) { + AddToWorklist(&*ii); + } + // Remember function calls + if (op == SpvOpFunctionCall) call_in_func_ = true; + } break; + } + } + } + // See if current function is an entry point + for (auto& ei : get_module()->entry_points()) { + if (ei.GetSingleWordInOperand(kEntryPointFunctionIdInIdx) == + func->result_id()) { + func_is_entry_point_ = true; + break; + } + } + // If the current function is an entry point and has no function calls, + // we can optimize private variables as locals + private_like_local_ = func_is_entry_point_ && !call_in_func_; + // If privates are not like local, add their stores to worklist + if (!private_like_local_) + for (auto& ps : private_stores_) AddToWorklist(ps); + // Perform closure on live instruction set. + while (!worklist_.empty()) { + Instruction* liveInst = worklist_.front(); + // Add all operand instructions if not already live + liveInst->ForEachInId([&liveInst, this](const uint32_t* iid) { + Instruction* inInst = get_def_use_mgr()->GetDef(*iid); + // Do not add label if an operand of a branch. This is not needed + // as part of live code discovery and can create false live code, + // for example, the branch to a header of a loop. + if (inInst->opcode() == SpvOpLabel && liveInst->IsBranch()) return; + AddToWorklist(inInst); + }); + if (liveInst->type_id() != 0) { + AddToWorklist(get_def_use_mgr()->GetDef(liveInst->type_id())); + } + // If in a structured if or loop construct, add the controlling + // conditional branch and its merge. + BasicBlock* blk = context()->get_instr_block(liveInst); + Instruction* branchInst = block2headerBranch_[blk]; + if (branchInst != nullptr) { + AddToWorklist(branchInst); + Instruction* mergeInst = branch2merge_[branchInst]; + AddToWorklist(mergeInst); + } + // If the block is a header, add the next outermost controlling + // conditional branch and its merge. 
+ Instruction* nextBranchInst = header2nextHeaderBranch_[blk]; + if (nextBranchInst != nullptr) { + AddToWorklist(nextBranchInst); + Instruction* mergeInst = branch2merge_[nextBranchInst]; + AddToWorklist(mergeInst); + } + // If local load, add all variable's stores if variable not already live + if (liveInst->opcode() == SpvOpLoad || liveInst->IsAtomicWithLoad()) { + uint32_t varId; + (void)GetPtr(liveInst, &varId); + if (varId != 0) { + ProcessLoad(varId); + } + // Process memory copies like loads + } else if (liveInst->opcode() == SpvOpCopyMemory || + liveInst->opcode() == SpvOpCopyMemorySized) { + uint32_t varId; + (void)GetPtr(liveInst->GetSingleWordInOperand(kCopyMemorySourceAddrInIdx), + &varId); + if (varId != 0) { + ProcessLoad(varId); + } + // If merge, add other branches that are part of its control structure + } else if (liveInst->opcode() == SpvOpLoopMerge || + liveInst->opcode() == SpvOpSelectionMerge) { + AddBreaksAndContinuesToWorklist(liveInst); + // If function call, treat as if it loads from all pointer arguments + } else if (liveInst->opcode() == SpvOpFunctionCall) { + liveInst->ForEachInId([this](const uint32_t* iid) { + // Skip non-ptr args + if (!IsPtr(*iid)) return; + uint32_t varId; + (void)GetPtr(*iid, &varId); + ProcessLoad(varId); + }); + // If function parameter, treat as if it's result id is loaded from + } else if (liveInst->opcode() == SpvOpFunctionParameter) { + ProcessLoad(liveInst->result_id()); + // We treat an OpImageTexelPointer as a load of the pointer, and + // that value is manipulated to get the result. 
+ } else if (liveInst->opcode() == SpvOpImageTexelPointer) { + uint32_t varId; + (void)GetPtr(liveInst, &varId); + if (varId != 0) { + ProcessLoad(varId); + } + } + worklist_.pop(); + } + + // Kill dead instructions and remember dead blocks + for (auto bi = structuredOrder.begin(); bi != structuredOrder.end();) { + uint32_t mergeBlockId = 0; + (*bi)->ForEachInst([this, &modified, &mergeBlockId](Instruction* inst) { + if (!IsDead(inst)) return; + if (inst->opcode() == SpvOpLabel) return; + // If dead instruction is selection merge, remember merge block + // for new branch at end of block + if (inst->opcode() == SpvOpSelectionMerge || + inst->opcode() == SpvOpLoopMerge) + mergeBlockId = inst->GetSingleWordInOperand(0); + to_kill_.push_back(inst); + modified = true; + }); + // If a structured if or loop was deleted, add a branch to its merge + // block, and traverse to the merge block and continue processing there. + // We know the block still exists because the label is not deleted. + if (mergeBlockId != 0) { + AddBranch(mergeBlockId, *bi); + for (++bi; (*bi)->id() != mergeBlockId; ++bi) { + } + } else { + ++bi; + } + } + + return modified; +} + +void AggressiveDCEPass::InitializeModuleScopeLiveInstructions() { + // Keep all execution modes. + for (auto& exec : get_module()->execution_modes()) { + AddToWorklist(&exec); + } + // Keep all entry points. + for (auto& entry : get_module()->entry_points()) { + AddToWorklist(&entry); + } + // Keep workgroup size. 
+ for (auto& anno : get_module()->annotations()) { + if (anno.opcode() == SpvOpDecorate) { + if (anno.GetSingleWordInOperand(1u) == SpvDecorationBuiltIn && + anno.GetSingleWordInOperand(2u) == SpvBuiltInWorkgroupSize) { + AddToWorklist(&anno); + } + } + } +} + +Pass::Status AggressiveDCEPass::ProcessImpl() { + // Current functionality assumes shader capability + // TODO(greg-lunarg): Handle additional capabilities + if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) + return Status::SuccessWithoutChange; + // Current functionality assumes relaxed logical addressing (see + // instruction.h) + // TODO(greg-lunarg): Handle non-logical addressing + if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) + return Status::SuccessWithoutChange; + // If any extensions in the module are not explicitly supported, + // return unmodified. + if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; + + // If the decoration manager is kept live then the context will try to keep it + // up to date. ADCE deals with group decorations by changing the operands in + // |OpGroupDecorate| instruction directly without informing the decoration + // manager. This can put it in an invalid state which will cause an error + // when the context tries to update it. To avoid this problem invalidate + // the decoration manager upfront. + context()->InvalidateAnalyses(IRContext::Analysis::kAnalysisDecorations); + + // Eliminate Dead functions. + bool modified = EliminateDeadFunctions(); + + InitializeModuleScopeLiveInstructions(); + + // Process all entry point functions. + ProcessFunction pfn = [this](Function* fp) { return AggressiveDCE(fp); }; + modified |= context()->ProcessEntryPointCallTree(pfn); + + // Process module-level instructions. Now that all live instructions have + // been marked, it is safe to remove dead global values. + modified |= ProcessGlobalValues(); + + // Kill all dead instructions. 
+ for (auto inst : to_kill_) { + context()->KillInst(inst); + } + + // Cleanup all CFG including all unreachable blocks. + ProcessFunction cleanup = [this](Function* f) { return CFGCleanup(f); }; + modified |= context()->ProcessEntryPointCallTree(cleanup); + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool AggressiveDCEPass::EliminateDeadFunctions() { + // Identify live functions first. Those that are not live + // are dead. ADCE is disabled for non-shaders so we do not check for exported + // functions here. + std::unordered_set live_function_set; + ProcessFunction mark_live = [&live_function_set](Function* fp) { + live_function_set.insert(fp); + return false; + }; + context()->ProcessEntryPointCallTree(mark_live); + + bool modified = false; + for (auto funcIter = get_module()->begin(); + funcIter != get_module()->end();) { + if (live_function_set.count(&*funcIter) == 0) { + modified = true; + EliminateFunction(&*funcIter); + funcIter = funcIter.Erase(); + } else { + ++funcIter; + } + } + + return modified; +} + +void AggressiveDCEPass::EliminateFunction(Function* func) { + // Remove all of the instruction in the function body + func->ForEachInst([this](Instruction* inst) { context()->KillInst(inst); }, + true); +} + +bool AggressiveDCEPass::ProcessGlobalValues() { + // Remove debug and annotation statements referencing dead instructions. + // This must be done before killing the instructions, otherwise there are + // dead objects in the def/use database. + bool modified = false; + Instruction* instruction = &*get_module()->debug2_begin(); + while (instruction) { + if (instruction->opcode() != SpvOpName) { + instruction = instruction->NextNode(); + continue; + } + + if (IsTargetDead(instruction)) { + instruction = context()->KillInst(instruction); + modified = true; + } else { + instruction = instruction->NextNode(); + } + } + + // This code removes all unnecessary decorations safely (see #1174). 
It also + // does so in a more efficient manner than deleting them only as the targets + // are deleted. + std::vector annotations; + for (auto& inst : get_module()->annotations()) annotations.push_back(&inst); + std::sort(annotations.begin(), annotations.end(), DecorationLess()); + for (auto annotation : annotations) { + switch (annotation->opcode()) { + case SpvOpDecorate: + case SpvOpMemberDecorate: + case SpvOpDecorateStringGOOGLE: + case SpvOpMemberDecorateStringGOOGLE: + if (IsTargetDead(annotation)) { + context()->KillInst(annotation); + modified = true; + } + break; + case SpvOpDecorateId: + if (IsTargetDead(annotation)) { + context()->KillInst(annotation); + modified = true; + } else { + if (annotation->GetSingleWordInOperand(1) == + SpvDecorationHlslCounterBufferGOOGLE) { + // HlslCounterBuffer will reference an id other than the target. + // If that id is dead, then the decoration can be removed as well. + uint32_t counter_buffer_id = annotation->GetSingleWordInOperand(2); + Instruction* counter_buffer_inst = + get_def_use_mgr()->GetDef(counter_buffer_id); + if (IsDead(counter_buffer_inst)) { + context()->KillInst(annotation); + modified = true; + } + } + } + break; + case SpvOpGroupDecorate: { + // Go through the targets of this group decorate. Remove each dead + // target. If all targets are dead, remove this decoration. + bool dead = true; + bool removed_operand = false; + for (uint32_t i = 1; i < annotation->NumOperands();) { + Instruction* opInst = + get_def_use_mgr()->GetDef(annotation->GetSingleWordOperand(i)); + if (IsDead(opInst)) { + // Don't increment |i|. + annotation->RemoveOperand(i); + modified = true; + removed_operand = true; + } else { + i++; + dead = false; + } + } + if (dead) { + context()->KillInst(annotation); + modified = true; + } else if (removed_operand) { + context()->UpdateDefUse(annotation); + } + break; + } + case SpvOpGroupMemberDecorate: { + // Go through the targets of this group member decorate. 
Remove each + // dead target (and member index). If all targets are dead, remove this + // decoration. + bool dead = true; + bool removed_operand = false; + for (uint32_t i = 1; i < annotation->NumOperands();) { + Instruction* opInst = + get_def_use_mgr()->GetDef(annotation->GetSingleWordOperand(i)); + if (IsDead(opInst)) { + // Don't increment |i|. + annotation->RemoveOperand(i + 1); + annotation->RemoveOperand(i); + modified = true; + removed_operand = true; + } else { + i += 2; + dead = false; + } + } + if (dead) { + context()->KillInst(annotation); + modified = true; + } else if (removed_operand) { + context()->UpdateDefUse(annotation); + } + break; + } + case SpvOpDecorationGroup: + // By the time we hit decoration groups we've checked everything that + // can target them. So if they have no uses they must be dead. + if (get_def_use_mgr()->NumUsers(annotation) == 0) { + context()->KillInst(annotation); + modified = true; + } + break; + default: + assert(false); + break; + } + } + + // Since ADCE is disabled for non-shaders, we don't check for export linkage + // attributes here. 
+ for (auto& val : get_module()->types_values()) { + if (IsDead(&val)) { + to_kill_.push_back(&val); + } + } + + return modified; +} + +AggressiveDCEPass::AggressiveDCEPass() = default; + +Pass::Status AggressiveDCEPass::Process() { + // Initialize extensions whitelist + InitExtensions(); + return ProcessImpl(); +} + +void AggressiveDCEPass::InitExtensions() { + extensions_whitelist_.clear(); + extensions_whitelist_.insert({ + "SPV_AMD_shader_explicit_vertex_parameter", + "SPV_AMD_shader_trinary_minmax", + "SPV_AMD_gcn_shader", + "SPV_KHR_shader_ballot", + "SPV_AMD_shader_ballot", + "SPV_AMD_gpu_shader_half_float", + "SPV_KHR_shader_draw_parameters", + "SPV_KHR_subgroup_vote", + "SPV_KHR_16bit_storage", + "SPV_KHR_device_group", + "SPV_KHR_multiview", + "SPV_NVX_multiview_per_view_attributes", + "SPV_NV_viewport_array2", + "SPV_NV_stereo_view_rendering", + "SPV_NV_sample_mask_override_coverage", + "SPV_NV_geometry_shader_passthrough", + "SPV_AMD_texture_gather_bias_lod", + "SPV_KHR_storage_buffer_storage_class", + // SPV_KHR_variable_pointers + // Currently do not support extended pointer expressions + "SPV_AMD_gpu_shader_int16", + "SPV_KHR_post_depth_coverage", + "SPV_KHR_shader_atomic_counter_ops", + "SPV_EXT_shader_stencil_export", + "SPV_EXT_shader_viewport_index_layer", + "SPV_AMD_shader_image_load_store_lod", + "SPV_AMD_shader_fragment_mask", + "SPV_EXT_fragment_fully_covered", + "SPV_AMD_gpu_shader_half_float_fetch", + "SPV_GOOGLE_decorate_string", + "SPV_GOOGLE_hlsl_functionality1", + "SPV_NV_shader_subgroup_partitioned", + "SPV_EXT_descriptor_indexing", + "SPV_NV_fragment_shader_barycentric", + "SPV_NV_compute_shader_derivatives", + "SPV_NV_shader_image_footprint", + "SPV_NV_shading_rate", + "SPV_NV_mesh_shader", + "SPV_NV_ray_tracing", + "SPV_EXT_fragment_invocation_density", + }); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/aggressive_dead_code_elim_pass.h b/source/opt/aggressive_dead_code_elim_pass.h new file mode 100644 
index 000000000..c043a96f4 --- /dev/null +++ b/source/opt/aggressive_dead_code_elim_pass.h @@ -0,0 +1,200 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_AGGRESSIVE_DEAD_CODE_ELIM_PASS_H_ +#define SOURCE_OPT_AGGRESSIVE_DEAD_CODE_ELIM_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" +#include "source/util/bit_vector.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class AggressiveDCEPass : public MemPass { + using cbb_ptr = const BasicBlock*; + + public: + using GetBlocksFunction = + std::function*(const BasicBlock*)>; + + AggressiveDCEPass(); + const char* name() const override { return "eliminate-dead-code-aggressive"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Return true if |varId| is a variable of |storageClass|. |varId| must either + // be 0 or the result of an instruction. 
+ bool IsVarOfStorage(uint32_t varId, uint32_t storageClass); + + // Return true if |varId| is a variable of function storage class or is a + // private variable and privates can be optimized like locals (see + // private_like_local_). + bool IsLocalVar(uint32_t varId); + + // Return true if |inst| is marked live. + bool IsLive(const Instruction* inst) const { + return live_insts_.Get(inst->unique_id()); + } + + // Returns true if |inst| is dead. + bool IsDead(Instruction* inst); + + // Adds entry points, execution modes and workgroup size decorations to the + // worklist for processing with the first function. + void InitializeModuleScopeLiveInstructions(); + + // Add |inst| to worklist_ and live_insts_. + void AddToWorklist(Instruction* inst) { + if (!live_insts_.Set(inst->unique_id())) { + worklist_.push(inst); + } + } + + // Add all store instructions which use |ptrId|, directly or indirectly, + // to the live instruction worklist. + void AddStores(uint32_t ptrId); + + // Initialize extensions whitelist + void InitExtensions(); + + // Return true if all extensions in this module are supported by this pass. + bool AllExtensionsSupported() const; + + // Returns true if the target of |inst| is dead. An instruction is dead if + // its result id is used in decoration or debug instructions only. |inst| is + // assumed to be OpName, OpMemberName or an annotation instruction. + bool IsTargetDead(Instruction* inst); + + // If |varId| is local, mark all stores of varId as live. + void ProcessLoad(uint32_t varId); + + // If |bp| is a structured header block, returns true and sets |mergeInst| to + // the merge instruction, |branchInst| to the branch and |mergeBlockId| to the + // merge block if they are not nullptr. Any of |mergeInst|, |branchInst| or + // |mergeBlockId| may be a null pointer. Returns false if |bp| is a null + // pointer.
+ bool IsStructuredHeader(BasicBlock* bp, Instruction** mergeInst, + Instruction** branchInst, uint32_t* mergeBlockId); + + // Initialize block2headerBranch_, header2nextHeaderBranch_, and + // branch2merge_ using |structuredOrder| to order blocks. + void ComputeBlock2HeaderMaps(std::list& structuredOrder); + + // Add branch to |labelId| to end of block |bp|. + void AddBranch(uint32_t labelId, BasicBlock* bp); + + // Add all break and continue branches in the construct associated with + // |mergeInst| to worklist if not already live + void AddBreaksAndContinuesToWorklist(Instruction* mergeInst); + + // Eliminates dead debug2 and annotation instructions. Marks dead globals for + // removal (e.g. types, constants and variables). + bool ProcessGlobalValues(); + + // Erases functions that are unreachable from the entry points of the module. + bool EliminateDeadFunctions(); + + // Removes |func| from the module and deletes all its instructions. + void EliminateFunction(Function* func); + + // For function |func|, mark all Stores to non-function-scope variables + // and block terminating instructions as live. Recursively mark the values + // they use. When complete, mark any non-live instructions to be deleted. + // Returns true if the function has been modified. + // + // Note: This function does not delete useless control structures. All + // existing control structures will remain. This can leave not-insignificant + // sequences of ultimately useless code. + // TODO(): Remove useless control constructs. + bool AggressiveDCE(Function* func); + + Pass::Status ProcessImpl(); + + // True if current function has a call instruction contained in it + bool call_in_func_; + + // True if current function is an entry point + bool func_is_entry_point_; + + // True if current function is entry point and has no function calls. + bool private_like_local_; + + // Live Instruction Worklist. 
An instruction is added to this list + // if it might have a side effect, either directly or indirectly. + // If we don't know, then add it to this list. Instructions are + // removed from this list as the algorithm traces side effects, + // building up the live instructions set |live_insts_|. + std::queue worklist_; + + // Map from block to the branch instruction in the header of the most + // immediate controlling structured if or loop. A loop header block points + // to its own branch instruction. An if-selection block points to the branch + // of an enclosing construct's header, if one exists. + std::unordered_map block2headerBranch_; + + // Map from header block to the branch instruction in the header of the + // structured construct enclosing it. + // The liveness algorithm is designed to iteratively mark as live all + // structured constructs enclosing a live instruction. + std::unordered_map header2nextHeaderBranch_; + + // Maps basic block to their index in the structured order traversal. + std::unordered_map structured_order_index_; + + // Map from branch to its associated merge instruction, if any + std::unordered_map branch2merge_; + + // Store instructions to variables of private storage + std::vector private_stores_; + + // Live Instructions + utils::BitVector live_insts_; + + // Live Local Variables + std::unordered_set live_local_vars_; + + // List of instructions to delete. Deletion is delayed until debug and + // annotation instructions are processed. + std::vector to_kill_; + + // Extensions supported by this pass. + std::unordered_set extensions_whitelist_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_AGGRESSIVE_DEAD_CODE_ELIM_PASS_H_ diff --git a/source/opt/basic_block.cpp b/source/opt/basic_block.cpp new file mode 100644 index 000000000..aafee5143 --- /dev/null +++ b/source/opt/basic_block.cpp @@ -0,0 +1,266 @@ +// Copyright (c) 2016 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/basic_block.h" + +#include + +#include "source/opt/function.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/reflect.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kLoopMergeContinueBlockIdInIdx = 1; +const uint32_t kLoopMergeMergeBlockIdInIdx = 0; +const uint32_t kSelectionMergeMergeBlockIdInIdx = 0; + +} // namespace + +BasicBlock* BasicBlock::Clone(IRContext* context) const { + BasicBlock* clone = new BasicBlock( + std::unique_ptr(GetLabelInst()->Clone(context))); + for (const auto& inst : insts_) { + // Use the incoming context + clone->AddInstruction(std::unique_ptr(inst.Clone(context))); + } + + if (context->AreAnalysesValid( + IRContext::Analysis::kAnalysisInstrToBlockMapping)) { + for (auto& inst : *clone) { + context->set_instr_block(&inst, clone); + } + } + + return clone; +} + +const Instruction* BasicBlock::GetMergeInst() const { + const Instruction* result = nullptr; + // If it exists, the merge instruction immediately precedes the + // terminator. 
+ auto iter = ctail(); + if (iter != cbegin()) { + --iter; + const auto opcode = iter->opcode(); + if (opcode == SpvOpLoopMerge || opcode == SpvOpSelectionMerge) { + result = &*iter; + } + } + return result; +} + +Instruction* BasicBlock::GetMergeInst() { + Instruction* result = nullptr; + // If it exists, the merge instruction immediately precedes the + // terminator. + auto iter = tail(); + if (iter != begin()) { + --iter; + const auto opcode = iter->opcode(); + if (opcode == SpvOpLoopMerge || opcode == SpvOpSelectionMerge) { + result = &*iter; + } + } + return result; +} + +const Instruction* BasicBlock::GetLoopMergeInst() const { + if (auto* merge = GetMergeInst()) { + if (merge->opcode() == SpvOpLoopMerge) { + return merge; + } + } + return nullptr; +} + +Instruction* BasicBlock::GetLoopMergeInst() { + if (auto* merge = GetMergeInst()) { + if (merge->opcode() == SpvOpLoopMerge) { + return merge; + } + } + return nullptr; +} + +void BasicBlock::KillAllInsts(bool killLabel) { + ForEachInst([killLabel](Instruction* ip) { + if (killLabel || ip->opcode() != SpvOpLabel) { + ip->context()->KillInst(ip); + } + }); +} + +void BasicBlock::ForEachSuccessorLabel( + const std::function& f) const { + const auto br = &insts_.back(); + switch (br->opcode()) { + case SpvOpBranch: { + f(br->GetOperand(0).words[0]); + } break; + case SpvOpBranchConditional: + case SpvOpSwitch: { + bool is_first = true; + br->ForEachInId([&is_first, &f](const uint32_t* idp) { + if (!is_first) f(*idp); + is_first = false; + }); + } break; + default: + break; + } +} + +void BasicBlock::ForEachSuccessorLabel( + const std::function& f) { + auto br = &insts_.back(); + switch (br->opcode()) { + case SpvOpBranch: { + uint32_t tmp_id = br->GetOperand(0).words[0]; + f(&tmp_id); + if (tmp_id != br->GetOperand(0).words[0]) br->SetOperand(0, {tmp_id}); + } break; + case SpvOpBranchConditional: + case SpvOpSwitch: { + bool is_first = true; + br->ForEachInId([&is_first, &f](uint32_t* idp) { + if (!is_first) 
f(idp); + is_first = false; + }); + } break; + default: + break; + } +} + +bool BasicBlock::IsSuccessor(const BasicBlock* block) const { + uint32_t succId = block->id(); + bool isSuccessor = false; + ForEachSuccessorLabel([&isSuccessor, succId](const uint32_t label) { + if (label == succId) isSuccessor = true; + }); + return isSuccessor; +} + +void BasicBlock::ForMergeAndContinueLabel( + const std::function& f) { + auto ii = insts_.end(); + --ii; + if (ii == insts_.begin()) return; + --ii; + if (ii->opcode() == SpvOpSelectionMerge || ii->opcode() == SpvOpLoopMerge) { + ii->ForEachInId([&f](const uint32_t* idp) { f(*idp); }); + } +} + +uint32_t BasicBlock::MergeBlockIdIfAny() const { + auto merge_ii = cend(); + --merge_ii; + uint32_t mbid = 0; + if (merge_ii != cbegin()) { + --merge_ii; + if (merge_ii->opcode() == SpvOpLoopMerge) { + mbid = merge_ii->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx); + } else if (merge_ii->opcode() == SpvOpSelectionMerge) { + mbid = merge_ii->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx); + } + } + + return mbid; +} + +uint32_t BasicBlock::ContinueBlockIdIfAny() const { + auto merge_ii = cend(); + --merge_ii; + uint32_t cbid = 0; + if (merge_ii != cbegin()) { + --merge_ii; + if (merge_ii->opcode() == SpvOpLoopMerge) { + cbid = merge_ii->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx); + } + } + return cbid; +} + +std::ostream& operator<<(std::ostream& str, const BasicBlock& block) { + str << block.PrettyPrint(); + return str; +} + +void BasicBlock::Dump() const { + std::cerr << "Basic block #" << id() << "\n" << *this << "\n "; +} + +std::string BasicBlock::PrettyPrint(uint32_t options) const { + std::ostringstream str; + ForEachInst([&str, options](const Instruction* inst) { + str << inst->PrettyPrint(options); + if (!IsTerminatorInst(inst->opcode())) { + str << std::endl; + } + }); + return str.str(); +} + +BasicBlock* BasicBlock::SplitBasicBlock(IRContext* context, uint32_t label_id, + iterator iter) { + 
assert(!insts_.empty()); + + std::unique_ptr new_block_temp = + MakeUnique(MakeUnique( + context, SpvOpLabel, 0, label_id, std::initializer_list{})); + BasicBlock* new_block = new_block_temp.get(); + function_->InsertBasicBlockAfter(std::move(new_block_temp), this); + + new_block->insts_.Splice(new_block->end(), &insts_, iter, end()); + new_block->SetParent(GetParent()); + + context->AnalyzeDefUse(new_block->GetLabelInst()); + + // Update the phi nodes in the successor blocks to reference the new block id. + const_cast(new_block)->ForEachSuccessorLabel( + [new_block, this, context](const uint32_t label) { + BasicBlock* target_bb = context->get_instr_block(label); + target_bb->ForEachPhiInst( + [this, new_block, context](Instruction* phi_inst) { + bool changed = false; + for (uint32_t i = 1; i < phi_inst->NumInOperands(); i += 2) { + if (phi_inst->GetSingleWordInOperand(i) == this->id()) { + changed = true; + phi_inst->SetInOperand(i, {new_block->id()}); + } + } + + if (changed) { + context->UpdateDefUse(phi_inst); + } + }); + }); + + if (context->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) { + context->set_instr_block(new_block->GetLabelInst(), new_block); + new_block->ForEachInst([new_block, context](Instruction* inst) { + context->set_instr_block(inst, new_block); + }); + } + + return new_block; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h new file mode 100644 index 000000000..ff3a41280 --- /dev/null +++ b/source/opt/basic_block.h @@ -0,0 +1,331 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file defines the language constructs for representing a SPIR-V +// module in memory. + +#ifndef SOURCE_OPT_BASIC_BLOCK_H_ +#define SOURCE_OPT_BASIC_BLOCK_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/instruction_list.h" +#include "source/opt/iterator.h" + +namespace spvtools { +namespace opt { + +class Function; +class IRContext; + +// A SPIR-V basic block. +class BasicBlock { + public: + using iterator = InstructionList::iterator; + using const_iterator = InstructionList::const_iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = + std::reverse_iterator; + + // Creates a basic block with the given starting |label|. + inline explicit BasicBlock(std::unique_ptr label); + + explicit BasicBlock(const BasicBlock& bb) = delete; + + // Creates a clone of the basic block in the given |context| + // + // The parent function will default to null and needs to be explicitly set by + // the user. + // + // If the inst-to-block map in |context| is valid, then the new instructions + // will be inserted into the map. + BasicBlock* Clone(IRContext*) const; + + // Sets the enclosing function for this basic block. + void SetParent(Function* function) { function_ = function; } + + // Return the enclosing function + inline Function* GetParent() const { return function_; } + + // Appends an instruction to this basic block. 
+ inline void AddInstruction(std::unique_ptr i); + + // Appends all of |bp|'s instructions (except label) to this block + inline void AddInstructions(BasicBlock* bp); + + // The pointer to the label starting this basic block. + std::unique_ptr& GetLabel() { return label_; } + + // The label starting this basic block. + Instruction* GetLabelInst() { return label_.get(); } + const Instruction* GetLabelInst() const { return label_.get(); } + + // Returns the merge instruction in this basic block, if it exists. + // Otherwise return null. May be used whenever tail() can be used. + const Instruction* GetMergeInst() const; + Instruction* GetMergeInst(); + + // Returns the OpLoopMerge instruction in this basic block, if it exists. + // Otherwise return null. May be used whenever tail() can be used. + const Instruction* GetLoopMergeInst() const; + Instruction* GetLoopMergeInst(); + + // Returns the id of the label at the top of this block + inline uint32_t id() const { return label_->result_id(); } + + iterator begin() { return insts_.begin(); } + iterator end() { return insts_.end(); } + const_iterator begin() const { return insts_.cbegin(); } + const_iterator end() const { return insts_.cend(); } + const_iterator cbegin() const { return insts_.cbegin(); } + const_iterator cend() const { return insts_.cend(); } + + reverse_iterator rbegin() { return reverse_iterator(end()); } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(cend()); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(cbegin()); + } + const_reverse_iterator crbegin() const { + return const_reverse_iterator(cend()); + } + const_reverse_iterator crend() const { + return const_reverse_iterator(cbegin()); + } + + // Returns an iterator pointing to the last instruction. This may only + // be used if this block has an instruction other than the OpLabel + // that defines it.
+ iterator tail() { + assert(!insts_.empty()); + return --end(); + } + + // Returns a const iterator, but otherwise similar to tail(). + const_iterator ctail() const { + assert(!insts_.empty()); + return --insts_.cend(); + } + + // Returns true if the basic block has at least one successor. + inline bool hasSuccessor() const { return ctail()->IsBranch(); } + + // Runs the given function |f| on each instruction in this basic block, and + // optionally on the debug line instructions that might precede them. + inline void ForEachInst(const std::function& f, + bool run_on_debug_line_insts = false); + inline void ForEachInst(const std::function& f, + bool run_on_debug_line_insts = false) const; + + // Runs the given function |f| on each instruction in this basic block, and + // optionally on the debug line instructions that might precede them. If |f| + // returns false, iteration is terminated and this function returns false. + inline bool WhileEachInst(const std::function& f, + bool run_on_debug_line_insts = false); + inline bool WhileEachInst(const std::function& f, + bool run_on_debug_line_insts = false) const; + + // Runs the given function |f| on each Phi instruction in this basic block, + // and optionally on the debug line instructions that might precede them. + inline void ForEachPhiInst(const std::function& f, + bool run_on_debug_line_insts = false); + + // Runs the given function |f| on each Phi instruction in this basic block, + // and optionally on the debug line instructions that might precede them. If + // |f| returns false, iteration is terminated and this function returns false. + inline bool WhileEachPhiInst(const std::function& f, + bool run_on_debug_line_insts = false); + + // Runs the given function |f| on each label id of each successor block + void ForEachSuccessorLabel( + const std::function& f) const; + + // Runs the given function |f| on each label id of each successor block.
+ // Modifying the pointed-to value will change the branch taken by the basic + // block. It is the caller's responsibility to update or invalidate the CFG. + void ForEachSuccessorLabel(const std::function& f); + + // Returns true if |block| is a direct successor of |this|. + bool IsSuccessor(const BasicBlock* block) const; + + // Runs the given function |f| on the merge and continue label, if any + void ForMergeAndContinueLabel(const std::function& f); + + // Returns true if this basic block has any Phi instructions. + bool HasPhiInstructions() { + return !WhileEachPhiInst([](Instruction*) { return false; }); + } + + // Return true if this block is a loop header block. + bool IsLoopHeader() const { return GetLoopMergeInst() != nullptr; } + + // Returns the ID of the merge block declared by a merge instruction in this + // block, if any. If none, returns zero. + uint32_t MergeBlockIdIfAny() const; + + // Returns the ID of the continue block declared by a merge instruction in + // this block, if any. If none, returns zero. + uint32_t ContinueBlockIdIfAny() const; + + // Returns the terminator instruction. Assumes the terminator exists. + Instruction* terminator() { return &*tail(); } + const Instruction* terminator() const { return &*ctail(); } + + // Returns true if this basic block exits this function and returns to its + // caller. + bool IsReturn() const { return ctail()->IsReturn(); } + + // Returns true if this basic block exits this function or aborts execution. + bool IsReturnOrAbort() const { return ctail()->IsReturnOrAbort(); } + + // Kill all instructions in this block. Whether or not to kill the label is + // indicated by |killLabel|. + void KillAllInsts(bool killLabel); + + // Splits this basic block into two. Returns a new basic block with label + // |label_id| containing the instructions from |iter| onwards. Instructions + // prior to |iter| remain in this basic block. The new block will be added + // to the function immediately after the original block.
+ BasicBlock* SplitBasicBlock(IRContext* context, uint32_t label_id, + iterator iter); + + // Pretty-prints this basic block into a std::string by printing every + // instruction in it. + // + // |options| are the disassembly options. SPV_BINARY_TO_TEXT_OPTION_NO_HEADER + // is always added to |options|. + std::string PrettyPrint(uint32_t options = 0u) const; + + // Dump this basic block on stderr. Useful when running interactive + // debuggers. + void Dump() const; + + private: + // The enclosing function. + Function* function_; + // The label starting this basic block. + std::unique_ptr label_; + // Instructions inside this basic block, but not the OpLabel. + InstructionList insts_; +}; + +// Pretty-prints |block| to |str|. Returns |str|. +std::ostream& operator<<(std::ostream& str, const BasicBlock& block); + +inline BasicBlock::BasicBlock(std::unique_ptr label) + : function_(nullptr), label_(std::move(label)) {} + +inline void BasicBlock::AddInstruction(std::unique_ptr i) { + insts_.push_back(std::move(i)); +} + +inline void BasicBlock::AddInstructions(BasicBlock* bp) { + auto bEnd = end(); + (void)bEnd.MoveBefore(&bp->insts_); +} + +inline bool BasicBlock::WhileEachInst( + const std::function& f, bool run_on_debug_line_insts) { + if (label_) { + if (!label_->WhileEachInst(f, run_on_debug_line_insts)) return false; + } + if (insts_.empty()) { + return true; + } + + Instruction* inst = &insts_.front(); + while (inst != nullptr) { + Instruction* next_instruction = inst->NextNode(); + if (!inst->WhileEachInst(f, run_on_debug_line_insts)) return false; + inst = next_instruction; + } + return true; +} + +inline bool BasicBlock::WhileEachInst( + const std::function& f, + bool run_on_debug_line_insts) const { + if (label_) { + if (!static_cast(label_.get()) + ->WhileEachInst(f, run_on_debug_line_insts)) + return false; + } + for (const auto& inst : insts_) { + if (!static_cast(&inst)->WhileEachInst( + f, run_on_debug_line_insts)) + return false; + } + return true; +} 
+ +inline void BasicBlock::ForEachInst(const std::function& f, + bool run_on_debug_line_insts) { + WhileEachInst( + [&f](Instruction* inst) { + f(inst); + return true; + }, + run_on_debug_line_insts); +} + +inline void BasicBlock::ForEachInst( + const std::function& f, + bool run_on_debug_line_insts) const { + WhileEachInst( + [&f](const Instruction* inst) { + f(inst); + return true; + }, + run_on_debug_line_insts); +} + +inline bool BasicBlock::WhileEachPhiInst( + const std::function& f, bool run_on_debug_line_insts) { + if (insts_.empty()) { + return true; + } + + Instruction* inst = &insts_.front(); + while (inst != nullptr) { + Instruction* next_instruction = inst->NextNode(); + if (inst->opcode() != SpvOpPhi) break; + if (!inst->WhileEachInst(f, run_on_debug_line_insts)) return false; + inst = next_instruction; + } + return true; +} + +inline void BasicBlock::ForEachPhiInst( + const std::function& f, bool run_on_debug_line_insts) { + WhileEachPhiInst( + [&f](Instruction* inst) { + f(inst); + return true; + }, + run_on_debug_line_insts); +} + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_BASIC_BLOCK_H_ diff --git a/source/opt/block_merge_pass.cpp b/source/opt/block_merge_pass.cpp new file mode 100644 index 000000000..09deb217a --- /dev/null +++ b/source/opt/block_merge_pass.cpp @@ -0,0 +1,146 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/block_merge_pass.h" + +#include + +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" + +namespace spvtools { +namespace opt { + +bool BlockMergePass::MergeBlocks(Function* func) { + bool modified = false; + for (auto bi = func->begin(); bi != func->end();) { + // Find block with single successor which has no other predecessors. + auto ii = bi->end(); + --ii; + Instruction* br = &*ii; + if (br->opcode() != SpvOpBranch) { + ++bi; + continue; + } + + const uint32_t lab_id = br->GetSingleWordInOperand(0); + if (cfg()->preds(lab_id).size() != 1) { + ++bi; + continue; + } + + bool pred_is_merge = IsMerge(&*bi); + bool succ_is_merge = IsMerge(lab_id); + if (pred_is_merge && succ_is_merge) { + // Cannot merge two merges together. + ++bi; + continue; + } + + Instruction* merge_inst = bi->GetMergeInst(); + bool pred_is_header = IsHeader(&*bi); + if (pred_is_header && lab_id != merge_inst->GetSingleWordInOperand(0u)) { + bool succ_is_header = IsHeader(lab_id); + if (pred_is_header && succ_is_header) { + // Cannot merge two headers together when the successor is not the merge + // block of the predecessor. + ++bi; + continue; + } + + // If this is a header block and the successor is not its merge, we must + // be careful about which blocks we are willing to merge together. + // OpLoopMerge must be followed by a conditional or unconditional branch. + // The merge must be a loop merge because a selection merge cannot be + // followed by an unconditional branch. + BasicBlock* succ_block = context()->get_instr_block(lab_id); + SpvOp succ_term_op = succ_block->terminator()->opcode(); + assert(merge_inst->opcode() == SpvOpLoopMerge); + if (succ_term_op != SpvOpBranch && + succ_term_op != SpvOpBranchConditional) { + ++bi; + continue; + } + } + + // Merge blocks. 
+ context()->KillInst(br); + auto sbi = bi; + for (; sbi != func->end(); ++sbi) + if (sbi->id() == lab_id) break; + // If bi is sbi's only predecessor, it dominates sbi and thus + // sbi must follow bi in func's ordering. + assert(sbi != func->end()); + + // Update the inst-to-block mapping for the instructions in sbi. + for (auto& inst : *sbi) { + context()->set_instr_block(&inst, &*bi); + } + + // Now actually move the instructions. + bi->AddInstructions(&*sbi); + + if (merge_inst) { + if (pred_is_header && lab_id == merge_inst->GetSingleWordInOperand(0u)) { + // Merging the header and merge blocks, so remove the structured control + // flow declaration. + context()->KillInst(merge_inst); + } else { + // Move the merge instruction to just before the terminator. + merge_inst->InsertBefore(bi->terminator()); + } + } + context()->ReplaceAllUsesWith(lab_id, bi->id()); + context()->KillInst(sbi->GetLabelInst()); + (void)sbi.Erase(); + // Reprocess block. + modified = true; + } + return modified; +} + +bool BlockMergePass::IsHeader(BasicBlock* block) { + return block->GetMergeInst() != nullptr; +} + +bool BlockMergePass::IsHeader(uint32_t id) { + return IsHeader(context()->get_instr_block(get_def_use_mgr()->GetDef(id))); +} + +bool BlockMergePass::IsMerge(uint32_t id) { + return !get_def_use_mgr()->WhileEachUse(id, [](Instruction* user, + uint32_t index) { + SpvOp op = user->opcode(); + if ((op == SpvOpLoopMerge || op == SpvOpSelectionMerge) && index == 0u) { + return false; + } + return true; + }); +} + +bool BlockMergePass::IsMerge(BasicBlock* block) { return IsMerge(block->id()); } + +Pass::Status BlockMergePass::Process() { + // Process all entry point functions. + ProcessFunction pfn = [this](Function* fp) { return MergeBlocks(fp); }; + bool modified = context()->ProcessEntryPointCallTree(pfn); + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +BlockMergePass::BlockMergePass() = default; + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/block_merge_pass.h b/source/opt/block_merge_pass.h new file mode 100644 index 000000000..3ae7a5c00 --- /dev/null +++ b/source/opt/block_merge_pass.h @@ -0,0 +1,70 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_BLOCK_MERGE_PASS_H_ +#define SOURCE_OPT_BLOCK_MERGE_PASS_H_ + +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. 
+class BlockMergePass : public Pass { + public: + BlockMergePass(); + const char* name() const override { return "merge-blocks"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + + // Search |func| for blocks which have a single Branch to a block + // with no other predecessors. Merge these blocks into a single block. + bool MergeBlocks(Function* func); + + // Returns true if |block| (or |id|) contains a merge instruction. + bool IsHeader(BasicBlock* block); + bool IsHeader(uint32_t id); + + // Returns true if |block| (or |id|) is the merge target of a merge + // instruction. + bool IsMerge(BasicBlock* block); + bool IsMerge(uint32_t id); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_BLOCK_MERGE_PASS_H_ diff --git a/source/opt/build_module.cpp b/source/opt/build_module.cpp new file mode 100644 index 000000000..fc76a3c29 --- /dev/null +++ b/source/opt/build_module.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/build_module.h" + +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/ir_loader.h" +#include "source/table.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace { + +// Sets the module header for IrLoader. Meets the interface requirement of +// spvBinaryParse(). +spv_result_t SetSpvHeader(void* builder, spv_endianness_t, uint32_t magic, + uint32_t version, uint32_t generator, + uint32_t id_bound, uint32_t reserved) { + reinterpret_cast(builder)->SetModuleHeader( + magic, version, generator, id_bound, reserved); + return SPV_SUCCESS; +} + +// Processes a parsed instruction for IrLoader. Meets the interface requirement +// of spvBinaryParse(). +spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) { + if (reinterpret_cast(builder)->AddInstruction(inst)) { + return SPV_SUCCESS; + } + return SPV_ERROR_INVALID_BINARY; +} + +} // namespace + +std::unique_ptr BuildModule(spv_target_env env, + MessageConsumer consumer, + const uint32_t* binary, + const size_t size) { + auto context = spvContextCreate(env); + SetContextMessageConsumer(context, consumer); + + auto irContext = MakeUnique(env, consumer); + opt::IrLoader loader(consumer, irContext->module()); + + spv_result_t status = spvBinaryParse(context, &loader, binary, size, + SetSpvHeader, SetSpvInst, nullptr); + loader.EndModule(); + + spvContextDestroy(context); + + return status == SPV_SUCCESS ? 
std::move(irContext) : nullptr; +} + +std::unique_ptr BuildModule(spv_target_env env, + MessageConsumer consumer, + const std::string& text, + uint32_t assemble_options) { + SpirvTools t(env); + t.SetMessageConsumer(consumer); + std::vector binary; + if (!t.Assemble(text, &binary, assemble_options)) return nullptr; + return BuildModule(env, consumer, binary.data(), binary.size()); +} + +} // namespace spvtools diff --git a/source/opt/build_module.h b/source/opt/build_module.h new file mode 100644 index 000000000..c9d1cf2e4 --- /dev/null +++ b/source/opt/build_module.h @@ -0,0 +1,46 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_BUILD_MODULE_H_ +#define SOURCE_OPT_BUILD_MODULE_H_ + +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { + +// Builds a Module and returns the owning IRContext from the given SPIR-V +// |binary|. |size| specifies number of words in |binary|. The |binary| will be +// decoded according to the given target |env|. Returns nullptr if errors occur +// and sends the errors to |consumer|. +std::unique_ptr BuildModule(spv_target_env env, + MessageConsumer consumer, + const uint32_t* binary, + size_t size); + +// Builds a Module and returns the owning IRContext from the given +// SPIR-V assembly |text|. The |text| will be encoded according to the given +// target |env|.
Returns nullptr if errors occur and sends the errors to +// |consumer|. +std::unique_ptr BuildModule( + spv_target_env env, MessageConsumer consumer, const std::string& text, + uint32_t assemble_options = SpirvTools::kDefaultAssembleOption); + +} // namespace spvtools + +#endif // SOURCE_OPT_BUILD_MODULE_H_ diff --git a/source/opt/ccp_pass.cpp b/source/opt/ccp_pass.cpp new file mode 100644 index 000000000..835619530 --- /dev/null +++ b/source/opt/ccp_pass.cpp @@ -0,0 +1,328 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements conditional constant propagation as described in +// +// Constant propagation with conditional branches, +// Wegman and Zadeck, ACM TOPLAS 13(2):181-210. + +#include "source/opt/ccp_pass.h" + +#include +#include + +#include "source/opt/fold.h" +#include "source/opt/function.h" +#include "source/opt/module.h" +#include "source/opt/propagator.h" + +namespace spvtools { +namespace opt { + +namespace { + +// This SSA id is never defined nor referenced in the IR. It is a special ID +// which represents varying values. When an ID is found to have a varying +// value, its entry in the |values_| table maps to kVaryingSSAId. 
+const uint32_t kVaryingSSAId = std::numeric_limits::max(); + +} // namespace + +bool CCPPass::IsVaryingValue(uint32_t id) const { return id == kVaryingSSAId; } + +SSAPropagator::PropStatus CCPPass::MarkInstructionVarying(Instruction* instr) { + assert(instr->result_id() != 0 && + "Instructions with no result cannot be marked varying."); + values_[instr->result_id()] = kVaryingSSAId; + return SSAPropagator::kVarying; +} + +SSAPropagator::PropStatus CCPPass::VisitPhi(Instruction* phi) { + uint32_t meet_val_id = 0; + + // Implement the lattice meet operation. The result of this Phi instruction is + // interesting only if the meet operation over arguments coming through + // executable edges yields the same constant value. + for (uint32_t i = 2; i < phi->NumOperands(); i += 2) { + if (!propagator_->IsPhiArgExecutable(phi, i)) { + // Ignore arguments coming through non-executable edges. + continue; + } + uint32_t phi_arg_id = phi->GetSingleWordOperand(i); + auto it = values_.find(phi_arg_id); + if (it != values_.end()) { + // We found an argument with a constant value. Apply the meet operation + // with the previous arguments. + if (it->second == kVaryingSSAId) { + // The "constant" value is actually a placeholder for varying. Return + // varying for this phi. + return MarkInstructionVarying(phi); + } else if (meet_val_id == 0) { + // This is the first argument we find. Initialize the result to its + // constant value id. + meet_val_id = it->second; + } else if (it->second == meet_val_id) { + // The argument is the same constant value already computed. Continue + // looking. + continue; + } else { + // We either found a varying value, or another constant value different + // from the previous computed meet value. This Phi will never be + // constant. + return MarkInstructionVarying(phi); + } + } else { + // The incoming value has no recorded value and is therefore not + // interesting. A not interesting value joined with any other value is the + // other value. 
+ continue; + } + } + + // If there are no incoming executable edges, the meet ID will still be 0. In + // that case, return not interesting to evaluate the Phi node again. + if (meet_val_id == 0) { + return SSAPropagator::kNotInteresting; + } + + // All the operands have the same constant value represented by |meet_val_id|. + // Set the Phi's result to that value and declare it interesting. + values_[phi->result_id()] = meet_val_id; + return SSAPropagator::kInteresting; +} + +SSAPropagator::PropStatus CCPPass::VisitAssignment(Instruction* instr) { + assert(instr->result_id() != 0 && + "Expecting an instruction that produces a result"); + + // If this is a copy operation, and the RHS is a known constant, assign its + // value to the LHS. + if (instr->opcode() == SpvOpCopyObject) { + uint32_t rhs_id = instr->GetSingleWordInOperand(0); + auto it = values_.find(rhs_id); + if (it != values_.end()) { + if (IsVaryingValue(it->second)) { + return MarkInstructionVarying(instr); + } else { + values_[instr->result_id()] = it->second; + return SSAPropagator::kInteresting; + } + } + return SSAPropagator::kNotInteresting; + } + + // Instructions with a RHS that cannot produce a constant are always varying. + if (!instr->IsFoldable()) { + return MarkInstructionVarying(instr); + } + + // See if the RHS of the assignment folds into a constant value. + auto map_func = [this](uint32_t id) { + auto it = values_.find(id); + if (it == values_.end() || IsVaryingValue(it->second)) { + return id; + } + return it->second; + }; + Instruction* folded_inst = + context()->get_instruction_folder().FoldInstructionToConstant(instr, + map_func); + if (folded_inst != nullptr) { + // We do not want to change the body of the function by adding new + // instructions. When folding we can only generate new constants. 
+ assert(folded_inst->IsConstant() && "CCP is only interested in constant."); + values_[instr->result_id()] = folded_inst->result_id(); + return SSAPropagator::kInteresting; + } + + // Conservatively mark this instruction as varying if any input id is varying. + if (!instr->WhileEachInId([this](uint32_t* op_id) { + auto iter = values_.find(*op_id); + if (iter != values_.end() && IsVaryingValue(iter->second)) return false; + return true; + })) { + return MarkInstructionVarying(instr); + } + + // If not, see if there is at least one unknown operand to the instruction. If + // so, we might be able to fold it later. + if (!instr->WhileEachInId([this](uint32_t* op_id) { + auto it = values_.find(*op_id); + if (it == values_.end()) return false; + return true; + })) { + return SSAPropagator::kNotInteresting; + } + + // Otherwise, we will never be able to fold this instruction, so mark it + // varying. + return MarkInstructionVarying(instr); +} + +SSAPropagator::PropStatus CCPPass::VisitBranch(Instruction* instr, + BasicBlock** dest_bb) const { + assert(instr->IsBranch() && "Expected a branch instruction."); + + *dest_bb = nullptr; + uint32_t dest_label = 0; + if (instr->opcode() == SpvOpBranch) { + // An unconditional jump always goes to its unique destination. + dest_label = instr->GetSingleWordInOperand(0); + } else if (instr->opcode() == SpvOpBranchConditional) { + // For a conditional branch, determine whether the predicate selector has a + // known value in |values_|. If it does, set the destination block + // according to the selector's boolean value. + uint32_t pred_id = instr->GetSingleWordOperand(0); + auto it = values_.find(pred_id); + if (it == values_.end() || IsVaryingValue(it->second)) { + // The predicate has an unknown value, either branch could be taken. + return SSAPropagator::kVarying; + } + + // Get the constant value for the predicate selector from the value table. + // Use it to decide which branch will be taken.
+ uint32_t pred_val_id = it->second; + const analysis::Constant* c = const_mgr_->FindDeclaredConstant(pred_val_id); + assert(c && "Expected to find a constant declaration for a known value."); + // Undef values should have returned as varying above. + assert(c->AsBoolConstant() || c->AsNullConstant()); + if (c->AsNullConstant()) { + dest_label = instr->GetSingleWordOperand(2u); + } else { + const analysis::BoolConstant* val = c->AsBoolConstant(); + dest_label = val->value() ? instr->GetSingleWordOperand(1) + : instr->GetSingleWordOperand(2); + } + } else { + // For an OpSwitch, extract the value taken by the switch selector and check + // which of the target literals it matches. The branch associated with that + // literal is the taken branch. + assert(instr->opcode() == SpvOpSwitch); + if (instr->GetOperand(0).words.size() != 1) { + // If the selector is wider than 32-bits, return varying. TODO(dnovillo): + // Add support for wider constants. + return SSAPropagator::kVarying; + } + uint32_t select_id = instr->GetSingleWordOperand(0); + auto it = values_.find(select_id); + if (it == values_.end() || IsVaryingValue(it->second)) { + // The selector has an unknown value, any of the branches could be taken. + return SSAPropagator::kVarying; + } + + // Get the constant value for the selector from the value table. Use it to + // decide which branch will be taken. + uint32_t select_val_id = it->second; + const analysis::Constant* c = + const_mgr_->FindDeclaredConstant(select_val_id); + assert(c && "Expected to find a constant declaration for a known value."); + // TODO: support 64-bit integer switches. + uint32_t constant_cond = 0; + if (const analysis::IntConstant* val = c->AsIntConstant()) { + constant_cond = val->words()[0]; + } else { + // Undef values should have returned varying above. 
+ assert(c->AsNullConstant()); + constant_cond = 0; + } + + // Start assuming that the selector will take the default value. + dest_label = instr->GetSingleWordOperand(1); + for (uint32_t i = 2; i < instr->NumOperands(); i += 2) { + if (constant_cond == instr->GetSingleWordOperand(i)) { + dest_label = instr->GetSingleWordOperand(i + 1); + break; + } + } + } + + assert(dest_label && "Destination label should be set at this point."); + *dest_bb = context()->cfg()->block(dest_label); + return SSAPropagator::kInteresting; +} + +SSAPropagator::PropStatus CCPPass::VisitInstruction(Instruction* instr, + BasicBlock** dest_bb) { + *dest_bb = nullptr; + if (instr->opcode() == SpvOpPhi) { + return VisitPhi(instr); + } else if (instr->IsBranch()) { + return VisitBranch(instr, dest_bb); + } else if (instr->result_id()) { + return VisitAssignment(instr); + } + return SSAPropagator::kVarying; +} + +bool CCPPass::ReplaceValues() { + bool retval = false; + for (const auto& it : values_) { + uint32_t id = it.first; + uint32_t cst_id = it.second; + if (!IsVaryingValue(cst_id) && id != cst_id) { + retval |= context()->ReplaceAllUsesWith(id, cst_id); + } + } + return retval; +} + +bool CCPPass::PropagateConstants(Function* fp) { + // Mark function parameters as varying. + fp->ForEachParam([this](const Instruction* inst) { + values_[inst->result_id()] = kVaryingSSAId; + }); + + const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) { + return VisitInstruction(instr, dest_bb); + }; + + propagator_ = + std::unique_ptr(new SSAPropagator(context(), visit_fn)); + + if (propagator_->Run(fp)) { + return ReplaceValues(); + } + + return false; +} + +void CCPPass::Initialize() { + const_mgr_ = context()->get_constant_mgr(); + + // Populate the constant table with values from constant declarations in the + // module. The value of each OpConstant declaration is the identity + // assignment (i.e., each constant is its own value).
+ for (const auto& inst : get_module()->types_values()) { + // Record compile time constant ids. Treat all other global values as + // varying. + if (inst.IsConstant()) { + values_[inst.result_id()] = inst.result_id(); + } else { + values_[inst.result_id()] = kVaryingSSAId; + } + } +} + +Pass::Status CCPPass::Process() { + Initialize(); + + // Run propagation on every function visited by ProcessReachableCallTree. + ProcessFunction pfn = [this](Function* fp) { return PropagateConstants(fp); }; + bool modified = context()->ProcessReachableCallTree(pfn); + return modified ? Pass::Status::SuccessWithChange + : Pass::Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/ccp_pass.h b/source/opt/ccp_pass.h new file mode 100644 index 000000000..527f459c5 --- /dev/null +++ b/source/opt/ccp_pass.h @@ -0,0 +1,113 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef SOURCE_OPT_CCP_PASS_H_ +#define SOURCE_OPT_CCP_PASS_H_ + +#include +#include + +#include "source/opt/constants.h" +#include "source/opt/function.h" +#include "source/opt/ir_context.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" +#include "source/opt/propagator.h" + +namespace spvtools { +namespace opt { + +class CCPPass : public MemPass { + public: + CCPPass() = default; + + const char* name() const override { return "ccp"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + // Initializes the pass. + void Initialize(); + + // Runs constant propagation on the given function |fp|. Returns true if any + // constants were propagated and the IR modified. + bool PropagateConstants(Function* fp); + + // Visits a single instruction |instr|. If the instruction is a conditional + // branch that always jumps to the same basic block, it sets the destination + // block in |dest_bb|. + SSAPropagator::PropStatus VisitInstruction(Instruction* instr, + BasicBlock** dest_bb); + + // Visits an OpPhi instruction |phi|. This applies the meet operator for the + // CCP lattice. Essentially, if all the operands in |phi| have the same + // constant value C, the result for |phi| gets assigned the value C. + SSAPropagator::PropStatus VisitPhi(Instruction* phi); + + // Visits an SSA assignment instruction |instr|. If the RHS of |instr| folds + // into a constant value C, then the LHS of |instr| is assigned the value C in + // |values_|. + SSAPropagator::PropStatus VisitAssignment(Instruction* instr); + + // Visits a branch instruction |instr|. 
If the branch is conditional + // (OpBranchConditional or OpSwitch), and the value of its selector is known, + // |dest_bb| will be set to the corresponding destination block. Unconditional + // branches always set |dest_bb| to the single destination block. + SSAPropagator::PropStatus VisitBranch(Instruction* instr, + BasicBlock** dest_bb) const; + + // Replaces all uses of the ids recorded in |values_| with their corresponding + // constant values, skipping ids marked varying. Returns true if any operands + // were replaced, and false otherwise. + bool ReplaceValues(); + + // Marks |instr| as varying by registering a varying value for its result + // into the |values_| table. Returns SSAPropagator::kVarying. + SSAPropagator::PropStatus MarkInstructionVarying(Instruction* instr); + + // Returns true if |id| is the special SSA id that corresponds to a varying + // value. + bool IsVaryingValue(uint32_t id) const; + + // Constant manager for the parent IR context. Used to record new constants + // generated during propagation. + analysis::ConstantManager* const_mgr_; + + // Constant value table. Each entry (id, const_decl_id) in this map + // represents the compile-time constant value for |id| as declared by + // |const_decl_id|. Each |const_decl_id| in this table is an OpConstant + // declaration for the current module. + // + // Additionally, this table keeps track of SSA IDs with varying values. If an + // SSA ID is found to have a varying value, it will have an entry in this + // table that maps to the special SSA id kVaryingSSAId. These values are + // never replaced in the IR, they are used by CCP during propagation. + std::unordered_map values_; + + // Propagator engine used. + std::unique_ptr propagator_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_CCP_PASS_H_ diff --git a/source/opt/cfg.cpp b/source/opt/cfg.cpp new file mode 100644 index 000000000..7e1097e38 --- /dev/null +++ b/source/opt/cfg.cpp @@ -0,0 +1,317 @@ +// Copyright (c) 2017 Google Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/cfg.h" + +#include +#include + +#include "source/cfa.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { +namespace { + +using cbb_ptr = const opt::BasicBlock*; + +// Universal Limit of ResultID + 1 +const int kMaxResultId = 0x400000; + +} // namespace + +CFG::CFG(Module* module) + : module_(module), + pseudo_entry_block_(std::unique_ptr( + new Instruction(module->context(), SpvOpLabel, 0, 0, {}))), + pseudo_exit_block_(std::unique_ptr(new Instruction( + module->context(), SpvOpLabel, 0, kMaxResultId, {}))) { + for (auto& fn : *module) { + for (auto& blk : fn) { + RegisterBlock(&blk); + } + } +} + +void CFG::AddEdges(BasicBlock* blk) { + uint32_t blk_id = blk->id(); + // Force the creation of an entry, not all basic block have predecessors + // (such as the entry blocks and some unreachables). 
+ label2preds_[blk_id]; + const auto* const_blk = blk; + const_blk->ForEachSuccessorLabel( + [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); }); +} + +void CFG::RemoveNonExistingEdges(uint32_t blk_id) { + std::vector updated_pred_list; + for (uint32_t id : preds(blk_id)) { + const BasicBlock* pred_blk = block(id); + bool has_branch = false; + pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) { + if (succ == blk_id) { + has_branch = true; + } + }); + if (has_branch) updated_pred_list.push_back(id); + } + + label2preds_.at(blk_id) = std::move(updated_pred_list); +} + +void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root, + std::list* order) { + assert(module_->context()->get_feature_mgr()->HasCapability( + SpvCapabilityShader) && + "This only works on structured control flow"); + + // Compute structured successors and do DFS. + ComputeStructuredSuccessors(func); + auto ignore_block = [](cbb_ptr) {}; + auto ignore_edge = [](cbb_ptr, cbb_ptr) {}; + auto get_structured_successors = [this](const BasicBlock* b) { + return &(block2structured_succs_[b]); + }; + + // TODO(greg-lunarg): Get rid of const_cast by moving const + // out of the cfa.h prototypes and into the invoking code.
+ auto post_order = [&](cbb_ptr b) { + order->push_front(const_cast(b)); + }; + CFA::DepthFirstTraversal(root, get_structured_successors, + ignore_block, post_order, ignore_edge); +} + +void CFG::ForEachBlockInPostOrder(BasicBlock* bb, + const std::function& f) { + std::vector po; + std::unordered_set seen; + ComputePostOrderTraversal(bb, &po, &seen); + + for (BasicBlock* current_bb : po) { + if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) { + f(current_bb); + } + } +} + +void CFG::ForEachBlockInReversePostOrder( + BasicBlock* bb, const std::function& f) { + std::vector po; + std::unordered_set seen; + ComputePostOrderTraversal(bb, &po, &seen); + + for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) { + if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) { + f(*current_bb); + } + } +} + +void CFG::ComputeStructuredSuccessors(Function* func) { + block2structured_succs_.clear(); + for (auto& blk : *func) { + // If no predecessors in function, make successor to pseudo entry. + if (label2preds_[blk.id()].size() == 0) + block2structured_succs_[&pseudo_entry_block_].push_back(&blk); + + // If header, make merge block first successor and continue block second + // successor if there is one. + uint32_t mbid = blk.MergeBlockIdIfAny(); + if (mbid != 0) { + block2structured_succs_[&blk].push_back(block(mbid)); + uint32_t cbid = blk.ContinueBlockIdIfAny(); + if (cbid != 0) { + block2structured_succs_[&blk].push_back(block(cbid)); + } + } + + // Add true successors. 
+ const auto& const_blk = blk; + const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) { + block2structured_succs_[&blk].push_back(block(sbid)); + }); + } +} + +void CFG::ComputePostOrderTraversal(BasicBlock* bb, + std::vector* order, + std::unordered_set* seen) { + seen->insert(bb); + static_cast(bb)->ForEachSuccessorLabel( + [&order, &seen, this](const uint32_t sbid) { + BasicBlock* succ_bb = id2block_[sbid]; + if (!seen->count(succ_bb)) { + ComputePostOrderTraversal(succ_bb, order, seen); + } + }); + order->push_back(bb); +} + +BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) { + assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop."); + + Function* fn = bb->GetParent(); + IRContext* context = module_->context(); + + // Get the new header id up front. If we are out of ids, then we cannot split + // the loop. + uint32_t new_header_id = context->TakeNextId(); + if (new_header_id == 0) { + return nullptr; + } + + // Find the insertion point for the new bb. + Function::iterator header_it = std::find_if( + fn->begin(), fn->end(), + [bb](BasicBlock& block_in_func) { return &block_in_func == bb; }); + assert(header_it != fn->end()); + + const std::vector& pred = preds(bb->id()); + // Find the back edge. + BasicBlock* latch_block = nullptr; + Function::iterator latch_block_iter = header_it; + while (++latch_block_iter != fn->end()) { + // If blocks are in the proper order, then the only predecessor that appears + // after the header is the latch. + if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) != + pred.end()) { + break; + } + } + assert(latch_block_iter != fn->end() && "Could not find the latch."); + latch_block = &*latch_block_iter; + + RemoveSuccessorEdges(bb); + + // Create the new header basic block. + // Leave the phi instructions behind.
+ auto iter = bb->begin(); + while (iter->opcode() == SpvOpPhi) { + ++iter; + } + + BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter); + context->AnalyzeDefUse(new_header->GetLabelInst()); + + // Update cfg + RegisterBlock(new_header); + + // Update bb mappings. + context->set_instr_block(new_header->GetLabelInst(), new_header); + new_header->ForEachInst([new_header, context](Instruction* inst) { + context->set_instr_block(inst, new_header); + }); + + // Adjust the OpPhi instructions as needed. + bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) { + std::vector preheader_phi_ops; + std::vector header_phi_ops; + + // Identify where the original inputs to original OpPhi belong: header or + // preheader. + for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { + uint32_t def_id = phi->GetSingleWordInOperand(i); + uint32_t branch_id = phi->GetSingleWordInOperand(i + 1); + if (branch_id == latch_block->id()) { + header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}}); + header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}}); + } else { + preheader_phi_ops.push_back(def_id); + preheader_phi_ops.push_back(branch_id); + } + } + + // Create a phi instruction if and only if the preheader_phi_ops has more + // than one pair. + if (preheader_phi_ops.size() > 2) { + InstructionBuilder builder( + context, &*bb->begin(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + + Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops); + + // Add the OpPhi to the header bb. + header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}}); + header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}}); + } else { + // An OpPhi with a single entry is just a copy. In this case use the same + // instruction in the new header. 
+ header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}}); + header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}}); + } + + phi->RemoveFromList(); + std::unique_ptr phi_owner(phi); + phi->SetInOperands(std::move(header_phi_ops)); + new_header->begin()->InsertBefore(std::move(phi_owner)); + context->set_instr_block(phi, new_header); + context->AnalyzeUses(phi); + }); + + // Add a branch to the new header. + InstructionBuilder branch_builder( + context, bb, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + bb->AddInstruction( + MakeUnique(context, SpvOpBranch, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {new_header->id()}}})); + context->AnalyzeUses(bb->terminator()); + context->set_instr_block(bb->terminator(), bb); + label2preds_[new_header->id()].push_back(bb->id()); + + // Update the latch to branch to the new header. + latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) { + if (*id == bb->id()) { + *id = new_header_id; + } + }); + Instruction* latch_branch = latch_block->terminator(); + context->AnalyzeUses(latch_branch); + label2preds_[new_header->id()].push_back(latch_block->id()); + + auto& block_preds = label2preds_[bb->id()]; + auto latch_pos = + std::find(block_preds.begin(), block_preds.end(), latch_block->id()); + assert(latch_pos != block_preds.end() && "The cfg was invalid."); + block_preds.erase(latch_pos); + + // Update the loop descriptors + if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) { + LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent()); + Loop* loop = (*loop_desc)[bb->id()]; + + loop->AddBasicBlock(new_header_id); + loop->SetHeaderBlock(new_header); + loop_desc->SetBasicBlockToLoop(new_header_id, loop); + + loop->RemoveBasicBlock(bb->id()); + loop->SetPreHeaderBlock(bb); + + Loop* parent_loop = loop->GetParent(); + if (parent_loop != nullptr) { + parent_loop->AddBasicBlock(bb->id()); + loop_desc->SetBasicBlockToLoop(bb->id(), 
parent_loop); + } else { + loop_desc->SetBasicBlockToLoop(bb->id(), nullptr); + } + } + return new_header; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/cfg.h b/source/opt/cfg.h new file mode 100644 index 000000000..5ff3aa03e --- /dev/null +++ b/source/opt/cfg.h @@ -0,0 +1,177 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_CFG_H_ +#define SOURCE_OPT_CFG_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" + +namespace spvtools { +namespace opt { + +// Control flow graph bookkeeping: maps block label ids to their BasicBlock and +// to the label ids of their predecessors. +class CFG { + public: + explicit CFG(Module* module); + + // Return the list of predecessors for basic block with label |blk_id|. The + // id must already have an entry in the predecessor map. + // TODO(dnovillo): Move this to BasicBlock. + const std::vector& preds(uint32_t blk_id) const { + assert(label2preds_.count(blk_id)); + return label2preds_.at(blk_id); + } + + // Return a pointer to the basic block instance corresponding to the label + // |blk_id|. + BasicBlock* block(uint32_t blk_id) const { return id2block_.at(blk_id); } + + // Return the pseudo entry and exit blocks. + const BasicBlock* pseudo_entry_block() const { return &pseudo_entry_block_; } + BasicBlock* pseudo_entry_block() { return &pseudo_entry_block_; } + + const BasicBlock* pseudo_exit_block() const { return &pseudo_exit_block_; } + BasicBlock* pseudo_exit_block() { return &pseudo_exit_block_; } + + // Return true if |block_ptr| is the pseudo-entry block.
+ bool IsPseudoEntryBlock(BasicBlock* block_ptr) const { + return block_ptr == &pseudo_entry_block_; + } + + // Return true if |block_ptr| is the pseudo-exit block. + bool IsPseudoExitBlock(BasicBlock* block_ptr) const { + return block_ptr == &pseudo_exit_block_; + } + + // Compute structured block order into |order| for |func| starting at |root|. + // This order has the property that dominators come before all blocks they + // dominate and merge blocks come after all blocks that are in the control + // constructs of their header. + void ComputeStructuredOrder(Function* func, BasicBlock* root, + std::list* order); + + // Applies |f| to the basic blocks in post order starting with |bb|. + // Note that basic blocks that cannot be reached from |bb| node will not be + // processed. + void ForEachBlockInPostOrder(BasicBlock* bb, + const std::function& f); + + // Applies |f| to the basic blocks in reverse post order starting with |bb|. + // Note that basic blocks that cannot be reached from |bb| node will not be + // processed. + void ForEachBlockInReversePostOrder( + BasicBlock* bb, const std::function& f); + + // Registers |blk| as a basic block in the cfg, this also updates the + // predecessor lists of each successor of |blk|. |blk| must have a terminator + // instruction at the end of the block. + void RegisterBlock(BasicBlock* blk) { + assert(blk->begin() != blk->end() && + "Basic blocks must have a terminator before registering."); + assert(blk->tail()->IsBlockTerminator() && + "Basic blocks must have a terminator before registering."); + uint32_t blk_id = blk->id(); + id2block_[blk_id] = blk; + AddEdges(blk); + } + + // Removes from the CFG every mapping for the basic block |blk|: its entry in + // the id-to-block map, its own predecessor list, and the edges from |blk| to + // each of its successors.
+ void ForgetBlock(const BasicBlock* blk) { + id2block_.erase(blk->id()); + label2preds_.erase(blk->id()); + RemoveSuccessorEdges(blk); + } + + // Removes the first occurrence of |pred_blk_id| from the predecessor list of + // |succ_blk_id|. Does nothing if |succ_blk_id| has no predecessor list or + // the edge is not recorded. + void RemoveEdge(uint32_t pred_blk_id, uint32_t succ_blk_id) { + auto pred_it = label2preds_.find(succ_blk_id); + if (pred_it == label2preds_.end()) return; + auto& preds_list = pred_it->second; + auto it = std::find(preds_list.begin(), preds_list.end(), pred_blk_id); + if (it != preds_list.end()) preds_list.erase(it); + } + + // Registers |blk| to all of its successors. + void AddEdges(BasicBlock* blk); + + // Registers the basic block id |pred_blk_id| as being a predecessor of the + // basic block id |succ_blk_id|. + void AddEdge(uint32_t pred_blk_id, uint32_t succ_blk_id) { + label2preds_[succ_blk_id].push_back(pred_blk_id); + } + + // Removes any edges that no longer exist from the predecessor mapping for + // the basic block id |blk_id|. + void RemoveNonExistingEdges(uint32_t blk_id); + + // Remove all edges that leave |bb|. + void RemoveSuccessorEdges(const BasicBlock* bb) { + bb->ForEachSuccessorLabel( + [bb, this](uint32_t succ_id) { RemoveEdge(bb->id(), succ_id); }); + } + + // Divides |block| into two basic blocks. The first block will have the same + // id as |block| and will become a preheader for the loop. The other block + // is a new block that will be the new loop header. + // + // Returns a pointer to the new loop header. Returns |nullptr| if the new + // loop pointer could not be created. + BasicBlock* SplitLoopHeader(BasicBlock* bb); + + private: + // Compute structured successors for function |func|. A block's structured + // successors are the blocks it branches to together with its declared merge + // block and continue block if it has them. When order matters, the merge + // block and continue block always appear first. This assures correct depth + // first search in the presence of early returns and kills.
If the successor + // vector contains duplicates of the merge or continue blocks, they are safely + // ignored by DFS. + void ComputeStructuredSuccessors(Function* func); + + // Computes the post-order traversal of the cfg starting at |bb| skipping + // nodes in |seen|. The order of the traversal is appended to |order|, and + // all nodes in the traversal are added to |seen|. + void ComputePostOrderTraversal(BasicBlock* bb, + std::vector* order, + std::unordered_set* seen); + + // Module for this CFG. + Module* module_; + + // Map from block to its structured successor blocks. See + // ComputeStructuredSuccessors() for definition. + std::unordered_map> + block2structured_succs_; + + // Extra block whose successors are all blocks with no predecessors + // in function. + BasicBlock pseudo_entry_block_; + + // Augmented CFG Exit Block. + BasicBlock pseudo_exit_block_; + + // Map from block's label id to its predecessor blocks ids + std::unordered_map> label2preds_; + + // Map from block's label id to block. + std::unordered_map id2block_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_CFG_H_ diff --git a/source/opt/cfg_cleanup_pass.cpp b/source/opt/cfg_cleanup_pass.cpp new file mode 100644 index 000000000..6d48637a4 --- /dev/null +++ b/source/opt/cfg_cleanup_pass.cpp @@ -0,0 +1,39 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +// This file implements a pass to clean up the CFG to remove superfluous +// constructs (e.g., unreachable basic blocks, empty control flow structures, +// etc) + +#include +#include + +#include "source/opt/cfg_cleanup_pass.h" + +#include "source/opt/function.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// Runs CFGCleanup() on each function that ProcessReachableCallTree() hands to +// the callback, and reports SuccessWithChange when that call returns true. +Pass::Status CFGCleanupPass::Process() { + // Process all entry point functions. + ProcessFunction pfn = [this](Function* fp) { return CFGCleanup(fp); }; + bool modified = context()->ProcessReachableCallTree(pfn); + return modified ? Pass::Status::SuccessWithChange + : Pass::Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/cfg_cleanup_pass.h b/source/opt/cfg_cleanup_pass.h new file mode 100644 index 000000000..509542890 --- /dev/null +++ b/source/opt/cfg_cleanup_pass.h @@ -0,0 +1,41 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef SOURCE_OPT_CFG_CLEANUP_PASS_H_ +#define SOURCE_OPT_CFG_CLEANUP_PASS_H_ + +#include "source/opt/function.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// Pass whose name() is "cfg-cleanup". Process() is defined in +// cfg_cleanup_pass.cpp. +class CFGCleanupPass : public MemPass { + public: + CFGCleanupPass() = default; + + const char* name() const override { return "cfg-cleanup"; } + Status Process() override; + + // Return the mask of analyses this pass reports as preserved. + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_CFG_CLEANUP_PASS_H_ diff --git a/source/opt/code_sink.cpp b/source/opt/code_sink.cpp new file mode 100644 index 000000000..e1e819017 --- /dev/null +++ b/source/opt/code_sink.cpp @@ -0,0 +1,316 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "code_sink.h" + +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/util/bit_vector.h" + +namespace spvtools { +namespace opt { + +// Walks the blocks of every function in post order, starting at the function's +// entry block, and tries to sink instructions in each block visited. +Pass::Status CodeSinkingPass::Process() { + bool modified = false; + for (Function& function : *get_module()) { + cfg()->ForEachBlockInPostOrder(function.entry().get(), + [&modified, this](BasicBlock* bb) { + if (SinkInstructionsInBB(bb)) { + modified = true; + } + }); + } + return modified ?
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +// Scans |bb| from its last instruction backwards and tries to sink each +// instruction. Returns true if at least one instruction was moved. +bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) { + bool modified = false; + for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) { + if (SinkInstruction(&*inst)) { + // |inst| was moved into another block, so restart the reverse scan of + // |bb| from its end. + inst = bb->rbegin(); + modified = true; + } + } + return modified; +} + +// Moves |inst| to the block chosen by FindNewBasicBlockFor(), if any. Only +// OpLoad and OpAccessChain instructions that do not reference mutable memory +// are considered. Returns true if |inst| was moved. +bool CodeSinkingPass::SinkInstruction(Instruction* inst) { + if (inst->opcode() != SpvOpLoad && inst->opcode() != SpvOpAccessChain) { + return false; + } + + if (ReferencesMutableMemory(inst)) { + return false; + } + + if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) { + // Skip the OpPhi instructions at the start of the target block; |inst| is + // inserted before the first non-phi instruction. + Instruction* pos = &*target_bb->begin(); + while (pos->opcode() == SpvOpPhi) { + pos = pos->NextNode(); + } + + inst->InsertBefore(pos); + context()->set_instr_block(inst, target_bb); + return true; + } + return false; +} + +BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) { + assert(inst->result_id() != 0 && "Instruction should have a result."); + BasicBlock* original_bb = context()->get_instr_block(inst); + BasicBlock* bb = original_bb; + + // Collect the ids of the blocks that use |inst|. A use in an OpPhi is + // attributed to the block id held in the operand that follows the use, not + // to the block containing the OpPhi. + std::unordered_set bbs_with_uses; + get_def_use_mgr()->ForEachUse( + inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) { + if (use->opcode() != SpvOpPhi) { + bbs_with_uses.insert(context()->get_instr_block(use)->id()); + } else { + bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1)); + } + }); + + while (true) { + // If |inst| is used in |bb|, then |inst| cannot be moved any further. + if (bbs_with_uses.count(bb->id())) { + break; + } + + // If |bb| has one successor (succ_bb), and |bb| is the only predecessor + // of succ_bb, then |inst| can be moved to succ_bb. If succ_bb has more + // than one predecessor, then moving |inst| into succ_bb could cause it to + // be executed more often, so the search has to stop.
+ if (bb->terminator()->opcode() == SpvOpBranch) { + uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0); + if (cfg()->preds(succ_bb_id).size() == 1) { + bb = context()->get_instr_block(succ_bb_id); + continue; + } else { + break; + } + } + + // The remaining checks need to know the merge node. If there is no merge + // instruction or an OpLoopMerge, then it is a break or continue. We could + // figure it out, but not worth doing it now. + Instruction* merge_inst = bb->GetMergeInst(); + if (merge_inst == nullptr || merge_inst->opcode() != SpvOpSelectionMerge) { + break; + } + + // Check all of the successors of |bb| to see which lead to a use of |inst| + // before reaching the merge node. + bool used_in_multiple_blocks = false; + uint32_t bb_used_in = 0; + bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks, + &bbs_with_uses](uint32_t* succ_bb_id) { + if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) { + if (bb_used_in == 0) { + bb_used_in = *succ_bb_id; + } else { + used_in_multiple_blocks = true; + } + } + }); + + // If more than one successor, which is not the merge block, uses |inst| + // then we have to leave |inst| in bb because none of the + // successors dominate all uses of |inst|. + if (used_in_multiple_blocks) { + break; + } + + if (bb_used_in == 0) { + // If |inst| is not used before reaching the merge node, then we can move + // |inst| to the merge node. + bb = context()->get_instr_block(bb->MergeBlockIdIfAny()); + } else { + // If the only successor that leads to a use of |inst| has more than 1 + // predecessor, then moving |inst| could cause it to be executed more + // often, so we cannot move it. + if (cfg()->preds(bb_used_in).size() != 1) { + break; + } + + // If |inst| is used after the merge block, then |bb_used_in| does not + // dominate all of the uses. So we cannot move |inst| any further.
+ if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(), + bbs_with_uses)) { + break; + } + + // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that + // block. + bb = context()->get_instr_block(bb_used_in); + } + continue; + } + return (bb != original_bb ? bb : nullptr); +} + +// Returns false for anything that is not a load and for loads from a read-only +// variable. Returns true when the base address is not an OpVariable, when the +// module contains a uniform memory sync, or when the variable is not in the +// Uniform storage class. Otherwise the answer is whether the variable may be +// stored to. +bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) { + if (!inst->IsLoad()) { + return false; + } + + Instruction* base_ptr = inst->GetBaseAddress(); + if (base_ptr->opcode() != SpvOpVariable) { + return true; + } + + if (base_ptr->IsReadOnlyVariable()) { + return false; + } + + if (HasUniformMemorySync()) { + return true; + } + + if (base_ptr->GetSingleWordInOperand(0) != SpvStorageClassUniform) { + return true; + } + + return HasPossibleStore(base_ptr); +} + +// Scans every instruction of the module for a barrier or atomic instruction +// whose memory semantics operand(s) satisfy IsSyncOnUniform(). The answer is +// stored in |has_uniform_sync_| and reused by later calls. +bool CodeSinkingPass::HasUniformMemorySync() { + if (checked_for_uniform_sync_) { + return has_uniform_sync_; + } + + bool has_sync = false; + get_module()->ForEachInst([this, &has_sync](Instruction* inst) { + switch (inst->opcode()) { + case SpvOpMemoryBarrier: { + uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1); + if (IsSyncOnUniform(mem_semantics_id)) { + has_sync = true; + } + break; + } + case SpvOpControlBarrier: + case SpvOpAtomicLoad: + case SpvOpAtomicStore: + case SpvOpAtomicExchange: + case SpvOpAtomicIIncrement: + case SpvOpAtomicIDecrement: + case SpvOpAtomicIAdd: + case SpvOpAtomicISub: + case SpvOpAtomicSMin: + case SpvOpAtomicUMin: + case SpvOpAtomicSMax: + case SpvOpAtomicUMax: + case SpvOpAtomicAnd: + case SpvOpAtomicOr: + case SpvOpAtomicXor: + case SpvOpAtomicFlagTestAndSet: + case SpvOpAtomicFlagClear: { + uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2); + if (IsSyncOnUniform(mem_semantics_id)) { + has_sync = true; + } + break; + } + case SpvOpAtomicCompareExchange: + case SpvOpAtomicCompareExchangeWeak: + if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) || + IsSyncOnUniform(inst->GetSingleWordInOperand(3))) { + has_sync = true; + } + break; +
default: + break; + } + }); + // Mark the scan as done; without this the cached value is never used and + // the whole module is rescanned on every call. + has_uniform_sync_ = has_sync; + checked_for_uniform_sync_ = true; + return has_sync; +} + +// Returns true if the constant |mem_semantics_id| has the UniformMemory bit +// set together with at least one of the Acquire, Release or AcquireRelease +// bits. +bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const { + const analysis::Constant* mem_semantics_const = + context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id); + assert(mem_semantics_const != nullptr && + "Expecting memory semantics id to be a constant."); + assert(mem_semantics_const->AsIntConstant() && + "Memory semantics should be an integer."); + uint32_t mem_semantics_int = mem_semantics_const->GetU32(); + + // If it does not affect uniform memory, then it does not apply to uniform + // memory. + if ((mem_semantics_int & SpvMemorySemanticsUniformMemoryMask) == 0) { + return false; + } + + // Check if there is an acquire or release. If not, then it does not add + // any memory constraints. + return (mem_semantics_int & (SpvMemorySemanticsAcquireMask | + SpvMemorySemanticsAcquireReleaseMask | + SpvMemorySemanticsReleaseMask)) != 0; +} + +// Returns true if |var_inst|, or an access chain derived from it, has an +// OpStore among its users. Users other than OpStore, OpAccessChain and +// OpPtrAccessChain are treated as not storing. +bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) { + assert(var_inst->opcode() == SpvOpVariable || + var_inst->opcode() == SpvOpAccessChain || + var_inst->opcode() == SpvOpPtrAccessChain); + + // WhileEachUser() stops at the first user for which the callback returns + // false and then returns false itself. The callback therefore returns false + // as soon as a store is found, and the overall result is negated. Every user + // must be inspected: a load seen before a store must not hide the store. + return !get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) { + switch (use->opcode()) { + case SpvOpStore: + return false; + case SpvOpAccessChain: + case SpvOpPtrAccessChain: + return !HasPossibleStore(use); + default: + return true; + } + }); +} + +bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end, + const std::unordered_set& set) { + std::vector worklist; + worklist.push_back(start); + std::unordered_set already_done; + already_done.insert(start); + + while (!worklist.empty()) { + BasicBlock* bb = context()->get_instr_block(worklist.back()); + worklist.pop_back(); + + if (bb->id() == end) { + continue; + } + + if (set.count(bb->id())) { + return true; + } + + bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) { + if
(already_done.insert(*succ_bb_id).second) { + worklist.push_back(*succ_bb_id); + } + }); + } + return false; +} + +// namespace opt + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/code_sink.h b/source/opt/code_sink.h new file mode 100644 index 000000000..d24df030d --- /dev/null +++ b/source/opt/code_sink.h @@ -0,0 +1,107 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_CODE_SINK_H_ +#define SOURCE_OPT_CODE_SINK_H_ + +#include + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// This pass does code sinking for OpAccessChain and OpLoad on variables in +// uniform storage or in read only memory. Code sinking is a transformation +// where an instruction is moved into a more deeply nested construct. +// +// The goal is to move these instructions as close as possible to their uses +// without having to execute them more often or to replicate the instruction. +// Moving the instruction in this way can lead to shorter live ranges, which can +// lead to less register pressure. It can also cause instructions to be +// executed less often because they could be moved into one path of a selection +// construct. +// +// This optimization can cause register pressure to rise if the operands of the +// instructions go dead after the instructions being moved.
That is why we only +// move certain OpLoad and OpAccessChain instructions. They generally have +// constants, loop induction variables, and global pointers as operands. The +// operands are live for a longer time in most cases. +class CodeSinkingPass : public Pass { + public: + const char* name() const override { return "code-sink"; } + Status Process() override; + + // Return the mask of preserved Analyses. + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG | + IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Sinks the instructions in |bb| as much as possible. Returns true if + // something changes. + bool SinkInstructionsInBB(BasicBlock* bb); + + // Tries to sink |inst| as much as possible. Returns true if the instruction + // is moved. + bool SinkInstruction(Instruction* inst); + + // Returns the basic block in which to move |inst| to move it as close as + // possible to the uses of |inst| without increasing the number of times + // |inst| will be executed. Return |nullptr| if there is no need to move + // |inst|. + BasicBlock* FindNewBasicBlockFor(Instruction* inst); + + // Return true if |inst| references memory and it is possible that the data in + // the memory changes at some point. + bool ReferencesMutableMemory(Instruction* inst); + + // Returns true if the module contains an instruction that has a memory + // semantics id as an operand, and the memory semantics enforces a + // synchronization of uniform memory. See section 3.25 of the SPIR-V + // specification. + bool HasUniformMemorySync(); + + // Returns true if there may be a store to the variable |var_inst|.
+ bool HasPossibleStore(Instruction* var_inst); + + // Returns true if one of the basic blocks in |set| exists on a path from the + // basic block |start| to |end|. + bool IntersectsPath(uint32_t start, uint32_t end, + const std::unordered_set& set); + + // Returns true if |mem_semantics_id| is the id of a constant that, when + // interpreted as a memory semantics mask enforces synchronization of uniform + // memory. See section 3.25 of the SPIR-V specification. + bool IsSyncOnUniform(uint32_t mem_semantics_id) const; + + // True if a check for a sync on uniform storage has taken place. Starts out + // false: the class has no constructor, so without an initializer the first + // read in HasUniformMemorySync() would be of an indeterminate value. + bool checked_for_uniform_sync_ = false; + + // Cache of whether or not the module has a memory sync on uniform storage. + // Only valid if |checked_for_uniform_sync_| is true. + bool has_uniform_sync_ = false; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_CODE_SINK_H_ diff --git a/source/opt/combine_access_chains.cpp b/source/opt/combine_access_chains.cpp new file mode 100644 index 000000000..facfc24b6 --- /dev/null +++ b/source/opt/combine_access_chains.cpp @@ -0,0 +1,290 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "source/opt/combine_access_chains.h" + +#include + +#include "source/opt/constants.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +// Runs ProcessFunction() on every function in the module and reports whether +// any of them was modified. +Pass::Status CombineAccessChains::Process() { + bool modified = false; + + for (auto& function : *get_module()) { + modified |= ProcessFunction(function); + } + + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +// Visits the blocks of |function| in reverse post order from its entry block +// and calls CombineAccessChain() on each of the four access chain opcodes. +bool CombineAccessChains::ProcessFunction(Function& function) { + bool modified = false; + + cfg()->ForEachBlockInReversePostOrder( + function.entry().get(), [&modified, this](BasicBlock* block) { + block->ForEachInst([&modified, this](Instruction* inst) { + switch (inst->opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + modified |= CombineAccessChain(inst); + break; + default: + break; + } + }); + }); + + return modified; +} + +// Returns the value of an integer constant whose width is at most 32 bits. A +// signed value is read with GetS32() and cast. Wider constants trip the assert +// and yield 0. +uint32_t CombineAccessChains::GetConstantValue( + const analysis::Constant* constant_inst) { + if (constant_inst->type()->AsInteger()->width() <= 32) { + if (constant_inst->type()->AsInteger()->IsSigned()) { + return static_cast(constant_inst->GetS32()); + } else { + return constant_inst->GetU32(); + } + } else { + assert(false); + return 0u; + } +} + +// Returns the literal of an ArrayStride decoration on the result type of +// |inst|, or 0 if the callback is never invoked. +uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) { + uint32_t array_stride = 0; + context()->get_decoration_mgr()->WhileEachDecoration( + inst->type_id(), SpvDecorationArrayStride, + [&array_stride](const Instruction& decoration) { + assert(decoration.opcode() != SpvOpDecorateId); + // The literal sits one operand later when the decoration is not a + // plain OpDecorate. + if (decoration.opcode() == SpvOpDecorate) { + array_stride = decoration.GetSingleWordInOperand(1); + } else { + array_stride = decoration.GetSingleWordInOperand(2); + } + // Only one decoration is needed; returning false presumably ends the + // iteration. + return false; + }); + return array_stride; +} + +const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) { + analysis::DefUseManager* def_use_mgr =
context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + + Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + const analysis::Type* type = type_mgr->GetType(base_ptr->type_id()); + assert(type->AsPointer()); + type = type->AsPointer()->pointee_type(); + std::vector element_indices; + uint32_t starting_index = 1; + if (IsPtrAccessChain(inst->opcode())) { + // Skip the first index of OpPtrAccessChain as it does not affect type + // resolution. + starting_index = 2; + } + for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { + Instruction* index_inst = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(i)); + const analysis::Constant* index_constant = + context()->get_constant_mgr()->GetConstantFromInst(index_inst); + if (index_constant) { + uint32_t index_value = GetConstantValue(index_constant); + element_indices.push_back(index_value); + } else { + // This index must not matter to resolve the type in valid SPIR-V. + element_indices.push_back(0); + } + } + type = type_mgr->GetMemberType(type, element_indices); + return type; +} + +// Appends to |new_operands| one id operand that combines the last index of +// |ptr_input| with the element operand of |inst|: a constant holding their sum +// when both are constants, otherwise the result of an integer add built with +// InstructionBuilder. Returns false, appending nothing, when the add would +// index into a struct. +bool CombineAccessChains::CombineIndices(Instruction* ptr_input, + Instruction* inst, + std::vector* new_operands) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); + + Instruction* last_index_inst = def_use_mgr->GetDef( + ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1)); + const analysis::Constant* last_index_constant = + constant_mgr->GetConstantFromInst(last_index_inst); + + Instruction* element_inst = + def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); + const analysis::Constant* element_constant = + constant_mgr->GetConstantFromInst(element_inst); + + // Combine the last index of the AccessChain (|ptr_input|) with the element + // operand of the PtrAccessChain (|inst|).
+ const bool combining_element_operands = + IsPtrAccessChain(inst->opcode()) && + IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2; + uint32_t new_value_id = 0; + const analysis::Type* type = GetIndexedType(ptr_input); + if (last_index_constant && element_constant) { + // Combine the constants. + uint32_t new_value = GetConstantValue(last_index_constant) + + GetConstantValue(element_constant); + const analysis::Constant* new_value_constant = + constant_mgr->GetConstant(last_index_constant->type(), {new_value}); + Instruction* new_value_inst = + constant_mgr->GetDefiningInstruction(new_value_constant); + new_value_id = new_value_inst->result_id(); + } else if (!type->AsStruct() || combining_element_operands) { + // Generate an addition of the two indices. + InstructionBuilder builder( + context(), inst, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* addition = builder.AddIAdd(last_index_inst->type_id(), + last_index_inst->result_id(), + element_inst->result_id()); + new_value_id = addition->result_id(); + } else { + // Indexing into structs must be constant, so bail out here. + return false; + } + new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}}); + return true; +} + +// Builds in |new_operands| the operand list of the merged access chain: every +// operand of |ptr_input| but the last, then the last index of |ptr_input| +// (combined with the element operand of |inst| when |inst| is a pointer access +// chain), then the remaining indices of |inst|. Returns false if the indices +// could not be combined. +bool CombineAccessChains::CreateNewInputOperands( + Instruction* ptr_input, Instruction* inst, + std::vector* new_operands) { + // Start by copying all the input operands of the feeder access chain. + for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) { + new_operands->push_back(ptr_input->GetInOperand(i)); + } + + // Deal with the last index of the feeder access chain. + if (IsPtrAccessChain(inst->opcode())) { + // The last index of the feeder should be combined with the element operand + // of |inst|. + if (!CombineIndices(ptr_input, inst, new_operands)) return false; + } else { + // The indices aren't being combined so now add the last index operand of + // |ptr_input|.
+ new_operands->push_back( + ptr_input->GetInOperand(ptr_input->NumInOperands() - 1)); + } + + // Copy the remaining index operands. + uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1; + for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { + new_operands->push_back(inst->GetInOperand(i)); + } + + return true; +} + +bool CombineAccessChains::CombineAccessChain(Instruction* inst) { + assert((inst->opcode() == SpvOpPtrAccessChain || + inst->opcode() == SpvOpAccessChain || + inst->opcode() == SpvOpInBoundsAccessChain || + inst->opcode() == SpvOpInBoundsPtrAccessChain) && + "Wrong opcode. Expected an access chain."); + + // Nothing to combine unless the base pointer of |inst| is itself produced + // by an access chain. + Instruction* ptr_input = + context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); + if (ptr_input->opcode() != SpvOpAccessChain && + ptr_input->opcode() != SpvOpInBoundsAccessChain && + ptr_input->opcode() != SpvOpPtrAccessChain && + ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) { + return false; + } + + if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false; + + // Handles the following cases: + // 1. |ptr_input| is an index-less access chain. Replace the pointer + // in |inst| with |ptr_input|'s pointer. + // 2. |inst| is an index-less access chain. Change |inst| to an + // OpCopyObject. + // 3. |inst| is not a pointer access chain. + // |inst|'s indices are appended to |ptr_input|'s indices. + // 4. |ptr_input| is not pointer access chain. + // |inst| is a pointer access chain. + // |inst|'s element operand is combined with the last index in + // |ptr_input| to form a new operand. + // 5. |ptr_input| is a pointer access chain. + // Like the above scenario, |inst|'s element operand is combined + // with |ptr_input|'s last index. This results in either a + // combined element operand or combined regular index. + + // TODO(alan-baker): Support this properly. Requires analyzing the + // size/alignment of the type and converting the stride into an element + // index.
+ uint32_t array_stride = GetArrayStride(ptr_input); + if (array_stride != 0) return false; + + if (ptr_input->NumInOperands() == 1) { + // The input is effectively a no-op. + inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)}); + context()->AnalyzeUses(inst); + } else if (inst->NumInOperands() == 1) { + // |inst| is a no-op, change it to a copy. Instruction simplification will + // clean it up. + inst->SetOpcode(SpvOpCopyObject); + } else { + std::vector new_operands; + if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false; + + // Update the instruction. + inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode())); + inst->SetInOperands(std::move(new_operands)); + context()->AnalyzeUses(inst); + } + return true; +} + +// Returns the opcode for the merged access chain. The result is +// |input_opcode|, except that an in-bounds |input_opcode| is downgraded to its +// non-in-bounds form when |base_opcode| is not in-bounds. +SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) { + auto IsInBounds = [](SpvOp opcode) { + return opcode == SpvOpInBoundsPtrAccessChain || + opcode == SpvOpInBoundsAccessChain; + }; + + if (input_opcode == SpvOpInBoundsPtrAccessChain) { + if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain; + } else if (input_opcode == SpvOpInBoundsAccessChain) { + if (!IsInBounds(base_opcode)) return SpvOpAccessChain; + } + + return input_opcode; +} + +// Returns true for OpPtrAccessChain and OpInBoundsPtrAccessChain. +bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) { + return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain; +} + +// Despite the name, returns true if any index operand of |inst| has a type +// that is not a 32-bit integer; this includes non-integer types and integers +// of any other width. +bool CombineAccessChains::Has64BitIndices(Instruction* inst) { + for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + Instruction* index_inst = + context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i)); + const analysis::Type* index_type = + context()->get_type_mgr()->GetType(index_inst->type_id()); + if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32) + return true; + } + return false; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/combine_access_chains.h b/source/opt/combine_access_chains.h new file mode 100644 index
000000000..531209ec1 --- /dev/null +++ b/source/opt/combine_access_chains.h @@ -0,0 +1,83 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_COMBINE_ACCESS_CHAINS_H_ +#define SOURCE_OPT_COMBINE_ACCESS_CHAINS_H_ + +#include + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class CombineAccessChains : public Pass { + public: + const char* name() const override { return "combine-access-chains"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + // Combine access chains in |function|. Blocks are processed in reverse + // post-order. Returns true if the function is modified. + bool ProcessFunction(Function& function); + + // Combines an access chain (normal, in bounds or pointer) |inst| if its base + // pointer is another access chain. Returns true if the access chain was + // modified. + bool CombineAccessChain(Instruction* inst); + + // Returns the value of |constant_inst| as a uint32_t. 
+ uint32_t GetConstantValue(const analysis::Constant* constant_inst); + + // Returns the array stride of |inst|'s type. + uint32_t GetArrayStride(const Instruction* inst); + + // Returns the type by resolving the index operands |inst|. |inst| must be an + // access chain instruction. + const analysis::Type* GetIndexedType(Instruction* inst); + + // Populates |new_operands| with the operands for the combined access chain. + // Returns false if the access chains cannot be combined. + bool CreateNewInputOperands(Instruction* ptr_input, Instruction* inst, + std::vector* new_operands); + + // Combines the last index of |ptr_input| with the element operand of |inst|. + // Adds the combined operand to |new_operands|. + bool CombineIndices(Instruction* ptr_input, Instruction* inst, + std::vector* new_operands); + + // Returns the opcode to use for the combined access chain. + SpvOp UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode); + + // Returns true if |opcode| is a pointer access chain. + bool IsPtrAccessChain(SpvOp opcode); + + // Returns true if |inst| (an access chain) has 64-bit indices. + bool Has64BitIndices(Instruction* inst); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_COMBINE_ACCESS_CHAINS_H_ diff --git a/source/opt/common_uniform_elim_pass.cpp b/source/opt/common_uniform_elim_pass.cpp new file mode 100644 index 000000000..52d279fdb --- /dev/null +++ b/source/opt/common_uniform_elim_pass.cpp @@ -0,0 +1,592 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/common_uniform_elim_pass.h" +#include "source/cfa.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +namespace { + +const uint32_t kAccessChainPtrIdInIdx = 0; +const uint32_t kTypePointerStorageClassInIdx = 0; +const uint32_t kTypePointerTypeIdInIdx = 1; +const uint32_t kConstantValueInIdx = 0; +const uint32_t kExtractCompositeIdInIdx = 0; +const uint32_t kExtractIdx0InIdx = 1; +const uint32_t kStorePtrIdInIdx = 0; +const uint32_t kLoadPtrIdInIdx = 0; +const uint32_t kCopyObjectOperandInIdx = 0; +const uint32_t kTypeIntWidthInIdx = 0; + +} // anonymous namespace + +bool CommonUniformElimPass::IsNonPtrAccessChain(const SpvOp opcode) const { + return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain; +} + +bool CommonUniformElimPass::IsSamplerOrImageType( + const Instruction* typeInst) const { + switch (typeInst->opcode()) { + case SpvOpTypeSampler: + case SpvOpTypeImage: + case SpvOpTypeSampledImage: + return true; + default: + break; + } + if (typeInst->opcode() != SpvOpTypeStruct) return false; + // Return true if any member is a sampler or image + return !typeInst->WhileEachInId([this](const uint32_t* tid) { + const Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid); + if (IsSamplerOrImageType(compTypeInst)) { + return false; + } + return true; + }); +} + +bool CommonUniformElimPass::IsSamplerOrImageVar(uint32_t varId) const { + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); + assert(varInst->opcode() == SpvOpVariable); + const uint32_t varTypeId = 
varInst->type_id(); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + const uint32_t varPteTypeId = + varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx); + Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId); + return IsSamplerOrImageType(varPteTypeInst); +} + +Instruction* CommonUniformElimPass::GetPtr(Instruction* ip, uint32_t* objId) { + const SpvOp op = ip->opcode(); + assert(op == SpvOpStore || op == SpvOpLoad); + *objId = ip->GetSingleWordInOperand(op == SpvOpStore ? kStorePtrIdInIdx + : kLoadPtrIdInIdx); + Instruction* ptrInst = get_def_use_mgr()->GetDef(*objId); + while (ptrInst->opcode() == SpvOpCopyObject) { + *objId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx); + ptrInst = get_def_use_mgr()->GetDef(*objId); + } + Instruction* objInst = ptrInst; + while (objInst->opcode() != SpvOpVariable && + objInst->opcode() != SpvOpFunctionParameter) { + if (IsNonPtrAccessChain(objInst->opcode())) { + *objId = objInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx); + } else { + assert(objInst->opcode() == SpvOpCopyObject); + *objId = objInst->GetSingleWordInOperand(kCopyObjectOperandInIdx); + } + objInst = get_def_use_mgr()->GetDef(*objId); + } + return ptrInst; +} + +bool CommonUniformElimPass::IsVolatileStruct(uint32_t type_id) { + assert(get_def_use_mgr()->GetDef(type_id)->opcode() == SpvOpTypeStruct); + return !get_decoration_mgr()->WhileEachDecoration( + type_id, SpvDecorationVolatile, [](const Instruction&) { return false; }); +} + +bool CommonUniformElimPass::IsAccessChainToVolatileStructType( + const Instruction& AccessChainInst) { + assert(AccessChainInst.opcode() == SpvOpAccessChain); + + uint32_t ptr_id = AccessChainInst.GetSingleWordInOperand(0); + const Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id); + uint32_t pointee_type_id = GetPointeeTypeId(ptr_inst); + const uint32_t num_operands = AccessChainInst.NumOperands(); + + // walk the type tree: + for (uint32_t idx = 3; idx < 
num_operands; ++idx) { + Instruction* pointee_type = get_def_use_mgr()->GetDef(pointee_type_id); + + switch (pointee_type->opcode()) { + case SpvOpTypeMatrix: + case SpvOpTypeVector: + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + pointee_type_id = pointee_type->GetSingleWordOperand(1); + break; + case SpvOpTypeStruct: + // check for volatile decorations: + if (IsVolatileStruct(pointee_type_id)) return true; + + if (idx < num_operands - 1) { + const uint32_t index_id = AccessChainInst.GetSingleWordOperand(idx); + const Instruction* index_inst = get_def_use_mgr()->GetDef(index_id); + uint32_t index_value = index_inst->GetSingleWordOperand( + 2); // TODO: replace with GetUintValueFromConstant() + pointee_type_id = pointee_type->GetSingleWordInOperand(index_value); + } + break; + default: + assert(false && "Unhandled pointee type."); + } + } + return false; +} + +bool CommonUniformElimPass::IsVolatileLoad(const Instruction& loadInst) { + assert(loadInst.opcode() == SpvOpLoad); + // Check if this Load instruction has Volatile Memory Access flag + if (loadInst.NumOperands() == 4) { + uint32_t memory_access_mask = loadInst.GetSingleWordOperand(3); + if (memory_access_mask & SpvMemoryAccessVolatileMask) return true; + } + // If we load a struct directly (result type is struct), + // check if the struct is decorated volatile + uint32_t type_id = loadInst.type_id(); + if (get_def_use_mgr()->GetDef(type_id)->opcode() == SpvOpTypeStruct) + return IsVolatileStruct(type_id); + else + return false; +} + +bool CommonUniformElimPass::IsUniformVar(uint32_t varId) { + const Instruction* varInst = + get_def_use_mgr()->id_to_defs().find(varId)->second; + if (varInst->opcode() != SpvOpVariable) return false; + const uint32_t varTypeId = varInst->type_id(); + const Instruction* varTypeInst = + get_def_use_mgr()->id_to_defs().find(varTypeId)->second; + return varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) == + SpvStorageClassUniform || + 
varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) == + SpvStorageClassUniformConstant; +} + +bool CommonUniformElimPass::HasUnsupportedDecorates(uint32_t id) const { + return !get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) { + if (IsNonTypeDecorate(user->opcode())) return false; + return true; + }); +} + +bool CommonUniformElimPass::HasOnlyNamesAndDecorates(uint32_t id) const { + return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) { + SpvOp op = user->opcode(); + if (op != SpvOpName && !IsNonTypeDecorate(op)) return false; + return true; + }); +} + +void CommonUniformElimPass::DeleteIfUseless(Instruction* inst) { + const uint32_t resId = inst->result_id(); + assert(resId != 0); + if (HasOnlyNamesAndDecorates(resId)) { + context()->KillInst(inst); + } +} + +Instruction* CommonUniformElimPass::ReplaceAndDeleteLoad(Instruction* loadInst, + uint32_t replId, + Instruction* ptrInst) { + const uint32_t loadId = loadInst->result_id(); + context()->KillNamesAndDecorates(loadId); + (void)context()->ReplaceAllUsesWith(loadId, replId); + // remove load instruction + Instruction* next_instruction = context()->KillInst(loadInst); + // if access chain, see if it can be removed as well + if (IsNonPtrAccessChain(ptrInst->opcode())) DeleteIfUseless(ptrInst); + return next_instruction; +} + +void CommonUniformElimPass::GenACLoadRepl( + const Instruction* ptrInst, + std::vector>* newInsts, uint32_t* resultId) { + // Build and append Load + const uint32_t ldResultId = TakeNextId(); + const uint32_t varId = + ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx); + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); + assert(varInst->opcode() == SpvOpVariable); + const uint32_t varPteTypeId = GetPointeeTypeId(varInst); + std::vector load_in_operands; + load_in_operands.push_back(Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, + std::initializer_list{varId})); + std::unique_ptr newLoad(new Instruction( + context(), 
SpvOpLoad, varPteTypeId, ldResultId, load_in_operands)); + get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad); + newInsts->emplace_back(std::move(newLoad)); + + // Build and append Extract + const uint32_t extResultId = TakeNextId(); + const uint32_t ptrPteTypeId = GetPointeeTypeId(ptrInst); + std::vector ext_in_opnds; + ext_in_opnds.push_back(Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, + std::initializer_list{ldResultId})); + uint32_t iidIdx = 0; + ptrInst->ForEachInId([&iidIdx, &ext_in_opnds, this](const uint32_t* iid) { + if (iidIdx > 0) { + const Instruction* cInst = get_def_use_mgr()->GetDef(*iid); + uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx); + ext_in_opnds.push_back( + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, + std::initializer_list{val})); + } + ++iidIdx; + }); + std::unique_ptr newExt( + new Instruction(context(), SpvOpCompositeExtract, ptrPteTypeId, + extResultId, ext_in_opnds)); + get_def_use_mgr()->AnalyzeInstDefUse(&*newExt); + newInsts->emplace_back(std::move(newExt)); + *resultId = extResultId; +} + +bool CommonUniformElimPass::IsConstantIndexAccessChain(Instruction* acp) { + uint32_t inIdx = 0; + return acp->WhileEachInId([&inIdx, this](uint32_t* tid) { + if (inIdx > 0) { + Instruction* opInst = get_def_use_mgr()->GetDef(*tid); + if (opInst->opcode() != SpvOpConstant) return false; + } + ++inIdx; + return true; + }); +} + +bool CommonUniformElimPass::UniformAccessChainConvert(Function* func) { + bool modified = false; + for (auto bi = func->begin(); bi != func->end(); ++bi) { + for (Instruction* inst = &*bi->begin(); inst; inst = inst->NextNode()) { + if (inst->opcode() != SpvOpLoad) continue; + uint32_t varId; + Instruction* ptrInst = GetPtr(inst, &varId); + if (!IsNonPtrAccessChain(ptrInst->opcode())) continue; + // Do not convert nested access chains + if (ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx) != varId) + continue; + if (!IsUniformVar(varId)) continue; + if 
(!IsConstantIndexAccessChain(ptrInst)) continue; + if (HasUnsupportedDecorates(inst->result_id())) continue; + if (HasUnsupportedDecorates(ptrInst->result_id())) continue; + if (IsVolatileLoad(*inst)) continue; + if (IsAccessChainToVolatileStructType(*ptrInst)) continue; + std::vector> newInsts; + uint32_t replId; + GenACLoadRepl(ptrInst, &newInsts, &replId); + inst = ReplaceAndDeleteLoad(inst, replId, ptrInst); + assert(inst->opcode() != SpvOpPhi); + inst = inst->InsertBefore(std::move(newInsts)); + modified = true; + } + } + return modified; +} + +void CommonUniformElimPass::ComputeStructuredSuccessors(Function* func) { + block2structured_succs_.clear(); + for (auto& blk : *func) { + // If header, make merge block first successor. + uint32_t mbid = blk.MergeBlockIdIfAny(); + if (mbid != 0) { + block2structured_succs_[&blk].push_back(cfg()->block(mbid)); + uint32_t cbid = blk.ContinueBlockIdIfAny(); + if (cbid != 0) { + block2structured_succs_[&blk].push_back(cfg()->block(mbid)); + } + } + // add true successors + const auto& const_blk = blk; + const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) { + block2structured_succs_[&blk].push_back(cfg()->block(sbid)); + }); + } +} + +void CommonUniformElimPass::ComputeStructuredOrder( + Function* func, std::list* order) { + // Compute structured successors and do DFS + ComputeStructuredSuccessors(func); + auto ignore_block = [](cbb_ptr) {}; + auto ignore_edge = [](cbb_ptr, cbb_ptr) {}; + auto get_structured_successors = [this](const BasicBlock* block) { + return &(block2structured_succs_[block]); + }; + // TODO(greg-lunarg): Get rid of const_cast by making moving const + // out of the cfa.h prototypes and into the invoking code. 
+ auto post_order = [&](cbb_ptr b) { + order->push_front(const_cast(b)); + }; + + order->clear(); + CFA::DepthFirstTraversal(&*func->begin(), + get_structured_successors, ignore_block, + post_order, ignore_edge); +} + +bool CommonUniformElimPass::CommonUniformLoadElimination(Function* func) { + // Process all blocks in structured order. This is just one way (the + // simplest?) to keep track of the most recent block outside of control + // flow, used to copy common instructions, guaranteed to dominate all + // following load sites. + std::list structuredOrder; + ComputeStructuredOrder(func, &structuredOrder); + uniform2load_id_.clear(); + bool modified = false; + // Find insertion point in first block to copy non-dominating loads. + auto insertItr = func->begin()->begin(); + while (insertItr->opcode() == SpvOpVariable || + insertItr->opcode() == SpvOpNop) + ++insertItr; + // Update insertItr until it will not be removed. Without this code, + // ReplaceAndDeleteLoad() can set |insertItr| as a dangling pointer. + while (IsUniformLoadToBeRemoved(&*insertItr)) ++insertItr; + uint32_t mergeBlockId = 0; + for (auto bi = structuredOrder.begin(); bi != structuredOrder.end(); ++bi) { + BasicBlock* bp = *bi; + // Check if we are exiting outermost control construct. If so, remember + // new load insertion point. Trying to keep register pressure down. + if (mergeBlockId == bp->id()) { + mergeBlockId = 0; + insertItr = bp->begin(); + while (insertItr->opcode() == SpvOpPhi) { + ++insertItr; + } + + // Update insertItr until it will not be removed. Without this code, + // ReplaceAndDeleteLoad() can set |insertItr| as a dangling pointer. 
+ while (IsUniformLoadToBeRemoved(&*insertItr)) ++insertItr; + } + for (Instruction* inst = &*bp->begin(); inst; inst = inst->NextNode()) { + if (inst->opcode() != SpvOpLoad) continue; + uint32_t varId; + Instruction* ptrInst = GetPtr(inst, &varId); + if (ptrInst->opcode() != SpvOpVariable) continue; + if (!IsUniformVar(varId)) continue; + if (IsSamplerOrImageVar(varId)) continue; + if (HasUnsupportedDecorates(inst->result_id())) continue; + if (IsVolatileLoad(*inst)) continue; + uint32_t replId; + const auto uItr = uniform2load_id_.find(varId); + if (uItr != uniform2load_id_.end()) { + replId = uItr->second; + } else { + if (mergeBlockId == 0) { + // Load is in dominating block; just remember it + uniform2load_id_[varId] = inst->result_id(); + continue; + } else { + // Copy load into most recent dominating block and remember it + replId = TakeNextId(); + std::unique_ptr newLoad(new Instruction( + context(), SpvOpLoad, inst->type_id(), replId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}})); + get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad); + insertItr = insertItr.InsertBefore(std::move(newLoad)); + ++insertItr; + uniform2load_id_[varId] = replId; + } + } + inst = ReplaceAndDeleteLoad(inst, replId, ptrInst); + modified = true; + } + // If we are outside of any control construct and entering one, remember + // the id of the merge block + if (mergeBlockId == 0) { + mergeBlockId = bp->MergeBlockIdIfAny(); + } + } + return modified; +} + +bool CommonUniformElimPass::CommonUniformLoadElimBlock(Function* func) { + bool modified = false; + for (auto& blk : *func) { + uniform2load_id_.clear(); + for (Instruction* inst = &*blk.begin(); inst; inst = inst->NextNode()) { + if (inst->opcode() != SpvOpLoad) continue; + uint32_t varId; + Instruction* ptrInst = GetPtr(inst, &varId); + if (ptrInst->opcode() != SpvOpVariable) continue; + if (!IsUniformVar(varId)) continue; + if (!IsSamplerOrImageVar(varId)) continue; + if (HasUnsupportedDecorates(inst->result_id())) 
continue; + if (IsVolatileLoad(*inst)) continue; + uint32_t replId; + const auto uItr = uniform2load_id_.find(varId); + if (uItr != uniform2load_id_.end()) { + replId = uItr->second; + } else { + uniform2load_id_[varId] = inst->result_id(); + continue; + } + inst = ReplaceAndDeleteLoad(inst, replId, ptrInst); + modified = true; + } + } + return modified; +} + +bool CommonUniformElimPass::CommonExtractElimination(Function* func) { + // Find all composite ids with duplicate extracts. + for (auto bi = func->begin(); bi != func->end(); ++bi) { + for (auto ii = bi->begin(); ii != bi->end(); ++ii) { + if (ii->opcode() != SpvOpCompositeExtract) continue; + // TODO(greg-lunarg): Support multiple indices + if (ii->NumInOperands() > 2) continue; + if (HasUnsupportedDecorates(ii->result_id())) continue; + uint32_t compId = ii->GetSingleWordInOperand(kExtractCompositeIdInIdx); + uint32_t idx = ii->GetSingleWordInOperand(kExtractIdx0InIdx); + comp2idx2inst_[compId][idx].push_back(&*ii); + } + } + // For all defs of ids with duplicate extracts, insert new extracts + // after def, and replace and delete old extracts + bool modified = false; + for (auto bi = func->begin(); bi != func->end(); ++bi) { + for (auto ii = bi->begin(); ii != bi->end(); ++ii) { + const auto cItr = comp2idx2inst_.find(ii->result_id()); + if (cItr == comp2idx2inst_.end()) continue; + for (auto idxItr : cItr->second) { + if (idxItr.second.size() < 2) continue; + uint32_t replId = TakeNextId(); + std::unique_ptr newExtract( + idxItr.second.front()->Clone(context())); + newExtract->SetResultId(replId); + get_def_use_mgr()->AnalyzeInstDefUse(&*newExtract); + ++ii; + ii = ii.InsertBefore(std::move(newExtract)); + for (auto instItr : idxItr.second) { + uint32_t resId = instItr->result_id(); + context()->KillNamesAndDecorates(resId); + (void)context()->ReplaceAllUsesWith(resId, replId); + context()->KillInst(instItr); + } + modified = true; + } + } + } + return modified; +} + +bool 
CommonUniformElimPass::EliminateCommonUniform(Function* func) { + bool modified = false; + modified |= UniformAccessChainConvert(func); + modified |= CommonUniformLoadElimination(func); + modified |= CommonExtractElimination(func); + + modified |= CommonUniformLoadElimBlock(func); + return modified; +} + +void CommonUniformElimPass::Initialize() { + // Clear collections. + comp2idx2inst_.clear(); + + // Initialize extension whitelist + InitExtensions(); +} + +bool CommonUniformElimPass::AllExtensionsSupported() const { + // If any extension not in whitelist, return false + for (auto& ei : get_module()->extensions()) { + const char* extName = + reinterpret_cast(&ei.GetInOperand(0).words[0]); + if (extensions_whitelist_.find(extName) == extensions_whitelist_.end()) + return false; + } + return true; +} + +Pass::Status CommonUniformElimPass::ProcessImpl() { + // Assumes all control flow structured. + // TODO(greg-lunarg): Do SSA rewrite for non-structured control flow + if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) + return Status::SuccessWithoutChange; + // Assumes logical addressing only + // TODO(greg-lunarg): Add support for physical addressing + if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) + return Status::SuccessWithoutChange; + // Do not process if any disallowed extensions are enabled + if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; + // Do not process if module contains OpGroupDecorate. Additional + // support required in KillNamesAndDecorates(). 
+ // TODO(greg-lunarg): Add support for OpGroupDecorate + for (auto& ai : get_module()->annotations()) + if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; + // If non-32-bit integer type in module, terminate processing + // TODO(): Handle non-32-bit integer constants in access chains + for (const Instruction& inst : get_module()->types_values()) + if (inst.opcode() == SpvOpTypeInt && + inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) + return Status::SuccessWithoutChange; + // Process entry point functions + ProcessFunction pfn = [this](Function* fp) { + return EliminateCommonUniform(fp); + }; + bool modified = context()->ProcessEntryPointCallTree(pfn); + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +CommonUniformElimPass::CommonUniformElimPass() = default; + +Pass::Status CommonUniformElimPass::Process() { + Initialize(); + return ProcessImpl(); +} + +void CommonUniformElimPass::InitExtensions() { + extensions_whitelist_.clear(); + extensions_whitelist_.insert({ + "SPV_AMD_shader_explicit_vertex_parameter", + "SPV_AMD_shader_trinary_minmax", + "SPV_AMD_gcn_shader", + "SPV_KHR_shader_ballot", + "SPV_AMD_shader_ballot", + "SPV_AMD_gpu_shader_half_float", + "SPV_KHR_shader_draw_parameters", + "SPV_KHR_subgroup_vote", + "SPV_KHR_16bit_storage", + "SPV_KHR_device_group", + "SPV_KHR_multiview", + "SPV_NVX_multiview_per_view_attributes", + "SPV_NV_viewport_array2", + "SPV_NV_stereo_view_rendering", + "SPV_NV_sample_mask_override_coverage", + "SPV_NV_geometry_shader_passthrough", + "SPV_AMD_texture_gather_bias_lod", + "SPV_KHR_storage_buffer_storage_class", + // SPV_KHR_variable_pointers + // Currently do not support extended pointer expressions + "SPV_AMD_gpu_shader_int16", + "SPV_KHR_post_depth_coverage", + "SPV_KHR_shader_atomic_counter_ops", + "SPV_EXT_shader_stencil_export", + "SPV_EXT_shader_viewport_index_layer", + "SPV_AMD_shader_image_load_store_lod", + "SPV_AMD_shader_fragment_mask", + 
"SPV_EXT_fragment_fully_covered", + "SPV_AMD_gpu_shader_half_float_fetch", + "SPV_GOOGLE_decorate_string", + "SPV_GOOGLE_hlsl_functionality1", + "SPV_NV_shader_subgroup_partitioned", + "SPV_EXT_descriptor_indexing", + "SPV_NV_fragment_shader_barycentric", + "SPV_NV_compute_shader_derivatives", + "SPV_NV_shader_image_footprint", + "SPV_NV_shading_rate", + "SPV_NV_mesh_shader", + "SPV_NV_ray_tracing", + "SPV_EXT_fragment_invocation_density", + }); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/common_uniform_elim_pass.h b/source/opt/common_uniform_elim_pass.h new file mode 100644 index 000000000..e6ef69c5d --- /dev/null +++ b/source/opt/common_uniform_elim_pass.h @@ -0,0 +1,213 @@ +// Copyright (c) 2016 The Khronos Group Inc. +// Copyright (c) 2016 Valve Corporation +// Copyright (c) 2016 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_COMMON_UNIFORM_ELIM_PASS_H_ +#define SOURCE_OPT_COMMON_UNIFORM_ELIM_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. 
+class CommonUniformElimPass : public Pass { + using cbb_ptr = const BasicBlock*; + + public: + using GetBlocksFunction = + std::function*(const BasicBlock*)>; + + CommonUniformElimPass(); + + const char* name() const override { return "eliminate-common-uniform"; } + Status Process() override; + + private: + // Returns true if |opcode| is a non-ptr access chain op + bool IsNonPtrAccessChain(const SpvOp opcode) const; + + // Returns true if |typeInst| is a sampler or image type or a struct + // containing one, recursively. + bool IsSamplerOrImageType(const Instruction* typeInst) const; + + // Returns true if |varId| is a variable containing a sampler or image. + bool IsSamplerOrImageVar(uint32_t varId) const; + + // Given a load or store pointed at by |ip|, return the top-most + // non-CopyObj in its pointer operand. Also return the base pointer + // in |objId|. + Instruction* GetPtr(Instruction* ip, uint32_t* objId); + + // Return true if variable is uniform + bool IsUniformVar(uint32_t varId); + + // Given the type id for a struct type, checks if the struct type + // or any struct member is volatile decorated + bool IsVolatileStruct(uint32_t type_id); + + // Given an OpAccessChain instruction, return true + // if the accessed variable belongs to a volatile + // decorated object or member of a struct type + bool IsAccessChainToVolatileStructType(const Instruction& AccessChainInst); + + // Given an OpLoad instruction, return true if + // OpLoad has a Volatile Memory Access flag or if + // the resulting type is a volatile decorated struct + bool IsVolatileLoad(const Instruction& loadInst); + + // Return true if any uses of |id| are decorate ops. + bool HasUnsupportedDecorates(uint32_t id) const; + + // Return true if all uses of |id| are only name or decorate ops. + bool HasOnlyNamesAndDecorates(uint32_t id) const; + + // Delete inst if it has no uses. Assumes inst has a resultId. 
+ void DeleteIfUseless(Instruction* inst); + + // Replace all instances of load's id with replId and delete load + // and its access chain, if any + Instruction* ReplaceAndDeleteLoad(Instruction* loadInst, uint32_t replId, + Instruction* ptrInst); + + // For the (constant index) access chain ptrInst, create an + // equivalent load and extract + void GenACLoadRepl(const Instruction* ptrInst, + std::vector>* newInsts, + uint32_t* resultId); + + // Return true if all indices are constant + bool IsConstantIndexAccessChain(Instruction* acp); + + // Convert all uniform access chain loads into load/extract. + bool UniformAccessChainConvert(Function* func); + + // Compute structured successors for function |func|. + // A block's structured successors are the blocks it branches to + // together with its declared merge block if it has one. + // When order matters, the merge block always appears first. + // This assures correct depth first search in the presence of early + // returns and kills. If the successor vector contain duplicates + // if the merge block, they are safely ignored by DFS. + // + // TODO(dnovillo): This pass computes structured successors slightly different + // than the implementation in class Pass. Can this be re-factored? + void ComputeStructuredSuccessors(Function* func); + + // Compute structured block order for |func| into |structuredOrder|. This + // order has the property that dominators come before all blocks they + // dominate and merge blocks come after all blocks that are in the control + // constructs of their header. + // + // TODO(dnovillo): This pass computes structured order slightly different + // than the implementation in class Pass. Can this be re-factored? + void ComputeStructuredOrder(Function* func, std::list* order); + + // Eliminate loads of uniform variables which have previously been loaded. + // If first load is in control flow, move it to first block of function. + // Most effective if preceded by UniformAccessChainRemoval(). 
+ bool CommonUniformLoadElimination(Function* func); + + // Eliminate loads of uniform sampler and image variables which have + // previously + // been loaded in the same block for types whose loads cannot cross blocks. + bool CommonUniformLoadElimBlock(Function* func); + + // Eliminate duplicated extracts of same id. Extract may be moved to same + // block as the id definition. This is primarily intended for extracts + // from uniform loads. Most effective if preceded by + // CommonUniformLoadElimination(). + bool CommonExtractElimination(Function* func); + + // For function |func|, first change all uniform constant index + // access chain loads into equivalent composite extracts. Then consolidate + // identical uniform loads into one uniform load. Finally, consolidate + // identical uniform extracts into one uniform extract. This may require + // moving a load or extract to a point which dominates all uses. + // Return true if func is modified. + // + // This pass requires the function to have structured control flow ie shader + // capability. It also requires logical addressing ie Addresses capability + // is not enabled. It also currently does not support any extensions. + // + // This function currently only optimizes loads with a single index. + bool EliminateCommonUniform(Function* func); + + // Initialize extensions whitelist + void InitExtensions(); + + // Return true if all extensions in this module are allowed by this pass. + bool AllExtensionsSupported() const; + + // Return true if |op| is a decorate for non-type instruction + inline bool IsNonTypeDecorate(uint32_t op) const { + return (op == SpvOpDecorate || op == SpvOpDecorateId); + } + + // Return true if |inst| is an instruction that loads uniform variable and + // can be replaced with other uniform load instruction. 
+ bool IsUniformLoadToBeRemoved(Instruction* inst) { + if (inst->opcode() == SpvOpLoad) { + uint32_t varId; + Instruction* ptrInst = GetPtr(inst, &varId); + if (ptrInst->opcode() == SpvOpVariable && IsUniformVar(varId) && + !IsSamplerOrImageVar(varId) && + !HasUnsupportedDecorates(inst->result_id()) && !IsVolatileLoad(*inst)) + return true; + } + return false; + } + + void Initialize(); + Pass::Status ProcessImpl(); + + // Map from uniform variable id to its common load id + std::unordered_map uniform2load_id_; + + // Map of extract composite ids to map of indices to insts + // TODO(greg-lunarg): Consider std::vector. + std::unordered_map>> + comp2idx2inst_; + + // Extensions supported by this pass. + std::unordered_set extensions_whitelist_; + + // Map from block to its structured successor blocks. See + // ComputeStructuredSuccessors() for definition. + std::unordered_map> + block2structured_succs_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_COMMON_UNIFORM_ELIM_PASS_H_ diff --git a/source/opt/compact_ids_pass.cpp b/source/opt/compact_ids_pass.cpp new file mode 100644 index 000000000..68b940f1d --- /dev/null +++ b/source/opt/compact_ids_pass.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/compact_ids_pass.h" + +#include +#include + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +Pass::Status CompactIdsPass::Process() { + bool modified = false; + std::unordered_map result_id_mapping; + + context()->module()->ForEachInst( + [&result_id_mapping, &modified](Instruction* inst) { + auto operand = inst->begin(); + while (operand != inst->end()) { + const auto type = operand->type; + if (spvIsIdType(type)) { + assert(operand->words.size() == 1); + uint32_t& id = operand->words[0]; + auto it = result_id_mapping.find(id); + if (it == result_id_mapping.end()) { + const uint32_t new_id = + static_cast(result_id_mapping.size()) + 1; + const auto insertion_result = + result_id_mapping.emplace(id, new_id); + it = insertion_result.first; + assert(insertion_result.second); + } + if (id != it->second) { + modified = true; + id = it->second; + // Update data cached in the instruction object. + if (type == SPV_OPERAND_TYPE_RESULT_ID) { + inst->SetResultId(id); + } else if (type == SPV_OPERAND_TYPE_TYPE_ID) { + inst->SetResultType(id); + } + } + } + ++operand; + } + }, + true); + + if (modified) + context()->module()->SetIdBound( + static_cast(result_id_mapping.size() + 1)); + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/compact_ids_pass.h b/source/opt/compact_ids_pass.h new file mode 100644 index 000000000..d97ae0fa4 --- /dev/null +++ b/source/opt/compact_ids_pass.h @@ -0,0 +1,42 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_COMPACT_IDS_PASS_H_ +#define SOURCE_OPT_COMPACT_IDS_PASS_H_ + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class CompactIdsPass : public Pass { + public: + const char* name() const override { return "compact-ids"; } + Status Process() override; + + // Return the mask of preserved Analyses. + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisLoopAnalysis; + } +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_COMPACT_IDS_PASS_H_ diff --git a/source/opt/composite.cpp b/source/opt/composite.cpp new file mode 100644 index 000000000..2b4dca257 --- /dev/null +++ b/source/opt/composite.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/composite.h" + +#include + +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" +#include "spirv/1.2/GLSL.std.450.h" + +namespace spvtools { +namespace opt { + +bool ExtInsMatch(const std::vector& extIndices, + const Instruction* insInst, const uint32_t extOffset) { + uint32_t numIndices = static_cast(extIndices.size()) - extOffset; + if (numIndices != insInst->NumInOperands() - 2) return false; + for (uint32_t i = 0; i < numIndices; ++i) + if (extIndices[i + extOffset] != insInst->GetSingleWordInOperand(i + 2)) + return false; + return true; +} + +bool ExtInsConflict(const std::vector& extIndices, + const Instruction* insInst, const uint32_t extOffset) { + if (extIndices.size() - extOffset == insInst->NumInOperands() - 2) + return false; + uint32_t extNumIndices = static_cast(extIndices.size()) - extOffset; + uint32_t insNumIndices = insInst->NumInOperands() - 2; + uint32_t numIndices = std::min(extNumIndices, insNumIndices); + for (uint32_t i = 0; i < numIndices; ++i) + if (extIndices[i + extOffset] != insInst->GetSingleWordInOperand(i + 2)) + return false; + return true; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/composite.h b/source/opt/composite.h new file mode 100644 index 000000000..3cc036e4d --- /dev/null +++ b/source/opt/composite.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_COMPOSITE_H_ +#define SOURCE_OPT_COMPOSITE_H_ + +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// Return true if the extract indices in |extIndices| starting at |extOffset| +// match indices of insert |insInst|. +bool ExtInsMatch(const std::vector& extIndices, + const Instruction* insInst, const uint32_t extOffset); + +// Return true if indices in |extIndices| starting at |extOffset| and +// indices of insert |insInst| conflict, specifically, if the insert +// changes bits specified by the extract, but changes either more bits +// or less bits than the extract specifies, meaning the exact value being +// inserted cannot be used to replace the extract. +bool ExtInsConflict(const std::vector& extIndices, + const Instruction* insInst, const uint32_t extOffset); + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_COMPOSITE_H_ diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp new file mode 100644 index 000000000..f6013a3d7 --- /dev/null +++ b/source/opt/const_folding_rules.cpp @@ -0,0 +1,849 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/const_folding_rules.h" + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kExtractCompositeIdInIdx = 0; + +// Returns true if |type| is Float or a vector of Float. +bool HasFloatingPoint(const analysis::Type* type) { + if (type->AsFloat()) { + return true; + } else if (const analysis::Vector* vec_type = type->AsVector()) { + return vec_type->element_type()->AsFloat() != nullptr; + } + + return false; +} + +// Folds an OpcompositeExtract where input is a composite constant. +ConstantFoldingRule FoldExtractWithConstants() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + const analysis::Constant* c = constants[kExtractCompositeIdInIdx]; + if (c == nullptr) { + return nullptr; + } + + for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + uint32_t element_index = inst->GetSingleWordInOperand(i); + if (c->AsNullConstant()) { + // Return Null for the return type. 
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {}); + } + + auto cc = c->AsCompositeConstant(); + assert(cc != nullptr); + auto components = cc->GetComponents(); + c = components[element_index]; + } + return c; + }; +} + +ConstantFoldingRule FoldVectorShuffleWithConstants() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + assert(inst->opcode() == SpvOpVectorShuffle); + const analysis::Constant* c1 = constants[0]; + const analysis::Constant* c2 = constants[1]; + if (c1 == nullptr || c2 == nullptr) { + return nullptr; + } + + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Type* element_type = c1->type()->AsVector()->element_type(); + + std::vector c1_components; + if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) { + c1_components = vec_const->GetComponents(); + } else { + assert(c1->AsNullConstant()); + const analysis::Constant* element = + const_mgr->GetConstant(element_type, {}); + c1_components.resize(c1->type()->AsVector()->element_count(), element); + } + std::vector c2_components; + if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) { + c2_components = vec_const->GetComponents(); + } else { + assert(c2->AsNullConstant()); + const analysis::Constant* element = + const_mgr->GetConstant(element_type, {}); + c2_components.resize(c2->type()->AsVector()->element_count(), element); + } + + std::vector ids; + const uint32_t undef_literal_value = 0xffffffff; + for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { + uint32_t index = inst->GetSingleWordInOperand(i); + if (index == undef_literal_value) { + // Don't fold shuffle with undef literal value. 
+ return nullptr; + } else if (index < c1_components.size()) { + Instruction* member_inst = + const_mgr->GetDefiningInstruction(c1_components[index]); + ids.push_back(member_inst->result_id()); + } else { + Instruction* member_inst = const_mgr->GetDefiningInstruction( + c2_components[index - c1_components.size()]); + ids.push_back(member_inst->result_id()); + } + } + + analysis::TypeManager* type_mgr = context->get_type_mgr(); + return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); + }; +} + +ConstantFoldingRule FoldVectorTimesScalar() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + assert(inst->opcode() == SpvOpVectorTimesScalar); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { + return nullptr; + } + } + + const analysis::Constant* c1 = constants[0]; + const analysis::Constant* c2 = constants[1]; + + if (c1 && c1->IsZero()) { + return c1; + } + + if (c2 && c2->IsZero()) { + // Get or create the NullConstant for this type. + std::vector ids; + return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); + } + + if (c1 == nullptr || c2 == nullptr) { + return nullptr; + } + + // Check result type. + const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); + const analysis::Vector* vector_type = result_type->AsVector(); + assert(vector_type != nullptr); + const analysis::Type* element_type = vector_type->element_type(); + assert(element_type != nullptr); + const analysis::Float* float_type = element_type->AsFloat(); + assert(float_type != nullptr); + + // Check types of c1 and c2. 
+ assert(c1->type()->AsVector() == vector_type); + assert(c1->type()->AsVector()->element_type() == element_type && + c2->type() == element_type); + + // Get a float vector that is the result of vector-times-scalar. + std::vector c1_components = + c1->GetVectorComponents(const_mgr); + std::vector ids; + if (float_type->width() == 32) { + float scalar = c2->GetFloat(); + for (uint32_t i = 0; i < c1_components.size(); ++i) { + utils::FloatProxy result(c1_components[i]->GetFloat() * scalar); + std::vector words = result.GetWords(); + const analysis::Constant* new_elem = + const_mgr->GetConstant(float_type, words); + ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } else if (float_type->width() == 64) { + double scalar = c2->GetDouble(); + for (uint32_t i = 0; i < c1_components.size(); ++i) { + utils::FloatProxy result(c1_components[i]->GetDouble() * + scalar); + std::vector words = result.GetWords(); + const analysis::Constant* new_elem = + const_mgr->GetConstant(float_type, words); + ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } + return nullptr; + }; +} + +ConstantFoldingRule FoldCompositeWithConstants() { + // Folds an OpCompositeConstruct where all of the inputs are constants to a + // constant. A new constant is created if necessary. 
+ return [](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); + Instruction* type_inst = + context->get_def_use_mgr()->GetDef(inst->type_id()); + + std::vector ids; + for (uint32_t i = 0; i < constants.size(); ++i) { + const analysis::Constant* element_const = constants[i]; + if (element_const == nullptr) { + return nullptr; + } + + uint32_t component_type_id = 0; + if (type_inst->opcode() == SpvOpTypeStruct) { + component_type_id = type_inst->GetSingleWordInOperand(i); + } else if (type_inst->opcode() == SpvOpTypeArray) { + component_type_id = type_inst->GetSingleWordInOperand(0); + } + + uint32_t element_id = + const_mgr->FindDeclaredConstant(element_const, component_type_id); + if (element_id == 0) { + return nullptr; + } + ids.push_back(element_id); + } + return const_mgr->GetConstant(new_type, ids); + }; +} + +// The interface for a function that returns the result of applying a scalar +// floating-point binary operation on |a| and |b|. The type of the return value +// will be |type|. The input constants must also be of type |type|. +using UnaryScalarFoldingRule = std::function; + +// The interface for a function that returns the result of applying a scalar +// floating-point binary operation on |a| and |b|. The type of the return value +// will be |type|. The input constants must also be of type |type|. +using BinaryScalarFoldingRule = std::function; + +// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops +// using |scalar_rule| and unary float point vectors ops by applying +// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| +// that is returned assumes that |constants| contains 1 entry. 
If they are +// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| +// whose element type is |Float| or |Integer|. +ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { + return [scalar_rule](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); + const analysis::Vector* vector_type = result_type->AsVector(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return nullptr; + } + + if (constants[0] == nullptr) { + return nullptr; + } + + if (vector_type != nullptr) { + std::vector a_components; + std::vector results_components; + + a_components = constants[0]->GetVectorComponents(const_mgr); + + // Fold each component of the vector. + for (uint32_t i = 0; i < a_components.size(); ++i) { + results_components.push_back(scalar_rule(vector_type->element_type(), + a_components[i], const_mgr)); + if (results_components[i] == nullptr) { + return nullptr; + } + } + + // Build the constant object and return it. + std::vector ids; + for (const analysis::Constant* member : results_components) { + ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } else { + return scalar_rule(result_type, constants[0], const_mgr); + } + }; +} + +// Returns a |ConstantFoldingRule| that folds floating point scalars using +// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the +// elements of the vector. The |ConstantFoldingRule| that is returned assumes +// that |constants| contains 2 entries. If they are not |nullptr|, then their +// type is either |Float| or a |Vector| whose element type is |Float|. 
+ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { + return [scalar_rule](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); + const analysis::Vector* vector_type = result_type->AsVector(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return nullptr; + } + + if (constants[0] == nullptr || constants[1] == nullptr) { + return nullptr; + } + + if (vector_type != nullptr) { + std::vector a_components; + std::vector b_components; + std::vector results_components; + + a_components = constants[0]->GetVectorComponents(const_mgr); + b_components = constants[1]->GetVectorComponents(const_mgr); + + // Fold each component of the vector. + for (uint32_t i = 0; i < a_components.size(); ++i) { + results_components.push_back(scalar_rule(vector_type->element_type(), + a_components[i], + b_components[i], const_mgr)); + if (results_components[i] == nullptr) { + return nullptr; + } + } + + // Build the constant object and return it. + std::vector ids; + for (const analysis::Constant* member : results_components) { + ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } else { + return scalar_rule(result_type, constants[0], constants[1], const_mgr); + } + }; +} + +// This macro defines a |UnaryScalarFoldingRule| that performs float to +// integer conversion. +// TODO(greg-lunarg): Support for 64-bit integer types. 
+UnaryScalarFoldingRule FoldFToIOp() { + return [](const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + const analysis::Integer* integer_type = result_type->AsInteger(); + const analysis::Float* float_type = a->type()->AsFloat(); + assert(float_type != nullptr); + assert(integer_type != nullptr); + if (integer_type->width() != 32) return nullptr; + if (float_type->width() == 32) { + float fa = a->GetFloat(); + uint32_t result = integer_type->IsSigned() + ? static_cast(static_cast(fa)) + : static_cast(fa); + std::vector words = {result}; + return const_mgr->GetConstant(result_type, words); + } else if (float_type->width() == 64) { + double fa = a->GetDouble(); + uint32_t result = integer_type->IsSigned() + ? static_cast(static_cast(fa)) + : static_cast(fa); + std::vector words = {result}; + return const_mgr->GetConstant(result_type, words); + } + return nullptr; + }; +} + +// This function defines a |UnaryScalarFoldingRule| that performs integer to +// float conversion. +// TODO(greg-lunarg): Support for 64-bit integer types. +UnaryScalarFoldingRule FoldIToFOp() { + return [](const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + const analysis::Integer* integer_type = a->type()->AsInteger(); + const analysis::Float* float_type = result_type->AsFloat(); + assert(float_type != nullptr); + assert(integer_type != nullptr); + if (integer_type->width() != 32) return nullptr; + uint32_t ua = a->GetU32(); + if (float_type->width() == 32) { + float result_val = integer_type->IsSigned() + ? 
static_cast(static_cast(ua)) + : static_cast(ua); + utils::FloatProxy result(result_val); + std::vector words = {result.data()}; + return const_mgr->GetConstant(result_type, words); + } else if (float_type->width() == 64) { + double result_val = integer_type->IsSigned() + ? static_cast(static_cast(ua)) + : static_cast(ua); + utils::FloatProxy result(result_val); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(result_type, words); + } + return nullptr; + }; +} + +// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The +// operator |op| must work for both float and double, and use syntax "f1 op f2". +#define FOLD_FPARITH_OP(op) \ + [](const analysis::Type* result_type, const analysis::Constant* a, \ + const analysis::Constant* b, \ + analysis::ConstantManager* const_mgr_in_macro) \ + -> const analysis::Constant* { \ + assert(result_type != nullptr && a != nullptr && b != nullptr); \ + assert(result_type == a->type() && result_type == b->type()); \ + const analysis::Float* float_type_in_macro = result_type->AsFloat(); \ + assert(float_type_in_macro != nullptr); \ + if (float_type_in_macro->width() == 32) { \ + float fa = a->GetFloat(); \ + float fb = b->GetFloat(); \ + utils::FloatProxy result_in_macro(fa op fb); \ + std::vector words_in_macro = result_in_macro.GetWords(); \ + return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ + } else if (float_type_in_macro->width() == 64) { \ + double fa = a->GetDouble(); \ + double fb = b->GetDouble(); \ + utils::FloatProxy result_in_macro(fa op fb); \ + std::vector words_in_macro = result_in_macro.GetWords(); \ + return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ + } \ + return nullptr; \ + } + +// Define the folding rule for conversion between floating point and integer +ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); } +ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); } + +// Define the folding rules for 
// Combines the raw result of a floating point comparison with the
// "unordered" state of its operands.
//   |op_result|    - result of applying the comparison operator.
//   |op_unordered| - true if the operands are unordered.
//   |need_ordered| - true for an ordered comparison, false for an unordered
//                    one.
// An ordered comparison holds only if the operands are ordered and
// |op_result| is true; an unordered comparison holds if the operands are
// unordered or |op_result| is true.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (!op_unordered && op_result)
                      : (op_unordered || op_result);
}
comparison for floating +// point values. +ConstantFoldingRule FoldFOrdEqual() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true)); +} +ConstantFoldingRule FoldFUnordEqual() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false)); +} +ConstantFoldingRule FoldFOrdNotEqual() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true)); +} +ConstantFoldingRule FoldFUnordNotEqual() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false)); +} +ConstantFoldingRule FoldFOrdLessThan() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true)); +} +ConstantFoldingRule FoldFUnordLessThan() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false)); +} +ConstantFoldingRule FoldFOrdGreaterThan() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true)); +} +ConstantFoldingRule FoldFUnordGreaterThan() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false)); +} +ConstantFoldingRule FoldFOrdLessThanEqual() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true)); +} +ConstantFoldingRule FoldFUnordLessThanEqual() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false)); +} +ConstantFoldingRule FoldFOrdGreaterThanEqual() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true)); +} +ConstantFoldingRule FoldFUnordGreaterThanEqual() { + return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false)); +} + +// Folds an OpDot where all of the inputs are constants to a +// constant. A new constant is created if necessary. +ConstantFoldingRule FoldOpDotWithConstants() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); + assert(new_type->AsFloat() && "OpDot should have a float return type."); + const analysis::Float* float_type = new_type->AsFloat(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return nullptr; + } + + // If one of the operands is 0, then the result is 0. 
+ bool has_zero_operand = false; + + for (int i = 0; i < 2; ++i) { + if (constants[i]) { + if (constants[i]->AsNullConstant() || + constants[i]->AsVectorConstant()->IsZero()) { + has_zero_operand = true; + break; + } + } + } + + if (has_zero_operand) { + if (float_type->width() == 32) { + utils::FloatProxy result(0.0f); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(float_type, words); + } + if (float_type->width() == 64) { + utils::FloatProxy result(0.0); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(float_type, words); + } + return nullptr; + } + + if (constants[0] == nullptr || constants[1] == nullptr) { + return nullptr; + } + + std::vector a_components; + std::vector b_components; + + a_components = constants[0]->GetVectorComponents(const_mgr); + b_components = constants[1]->GetVectorComponents(const_mgr); + + utils::FloatProxy result(0.0); + std::vector words = result.GetWords(); + const analysis::Constant* result_const = + const_mgr->GetConstant(float_type, words); + for (uint32_t i = 0; i < a_components.size(); ++i) { + if (a_components[i] == nullptr || b_components[i] == nullptr) { + return nullptr; + } + + const analysis::Constant* component = FOLD_FPARITH_OP(*)( + new_type, a_components[i], b_components[i], const_mgr); + result_const = + FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr); + } + return result_const; + }; +} + +// This function defines a |UnaryScalarFoldingRule| that subtracts the constant +// from zero. 
+UnaryScalarFoldingRule FoldFNegateOp() { + return [](const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + assert(result_type == a->type()); + const analysis::Float* float_type = result_type->AsFloat(); + assert(float_type != nullptr); + if (float_type->width() == 32) { + float fa = a->GetFloat(); + utils::FloatProxy result(-fa); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(result_type, words); + } else if (float_type->width() == 64) { + double da = a->GetDouble(); + utils::FloatProxy result(-da); + std::vector words = result.GetWords(); + return const_mgr->GetConstant(result_type, words); + } + return nullptr; + }; +} + +ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); } + +ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) { + return [cmp_opcode](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return nullptr; + } + + uint32_t non_const_idx = (constants[0] ? 
1 : 0); + uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx); + Instruction* operand_inst = def_use_mgr->GetDef(operand_id); + + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* operand_type = + type_mgr->GetType(operand_inst->type_id()); + + if (!operand_type->AsFloat()) { + return nullptr; + } + + if (operand_type->AsFloat()->width() != 32 && + operand_type->AsFloat()->width() != 64) { + return nullptr; + } + + if (operand_inst->opcode() != SpvOpExtInst) { + return nullptr; + } + + if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) { + return nullptr; + } + + if (constants[1] == nullptr && constants[0] == nullptr) { + return nullptr; + } + + uint32_t max_id = operand_inst->GetSingleWordInOperand(4); + const analysis::Constant* max_const = + const_mgr->FindDeclaredConstant(max_id); + + uint32_t min_id = operand_inst->GetSingleWordInOperand(3); + const analysis::Constant* min_const = + const_mgr->FindDeclaredConstant(min_id); + + bool found_result = false; + bool result = false; + + switch (cmp_opcode) { + case SpvOpFOrdLessThan: + case SpvOpFUnordLessThan: + case SpvOpFOrdGreaterThanEqual: + case SpvOpFUnordGreaterThanEqual: + if (constants[0]) { + if (min_const) { + if (constants[0]->GetValueAsDouble() < + min_const->GetValueAsDouble()) { + found_result = true; + result = (cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); + } + } + if (max_const) { + if (constants[0]->GetValueAsDouble() >= + max_const->GetValueAsDouble()) { + found_result = true; + result = !(cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); + } + } + } + + if (constants[1]) { + if (max_const) { + if (max_const->GetValueAsDouble() < + constants[1]->GetValueAsDouble()) { + found_result = true; + result = (cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); + } + } + + if (min_const) { + if (min_const->GetValueAsDouble() >= + constants[1]->GetValueAsDouble()) { + 
found_result = true; + result = !(cmp_opcode == SpvOpFOrdLessThan || + cmp_opcode == SpvOpFUnordLessThan); + } + } + } + break; + case SpvOpFOrdGreaterThan: + case SpvOpFUnordGreaterThan: + case SpvOpFOrdLessThanEqual: + case SpvOpFUnordLessThanEqual: + if (constants[0]) { + if (min_const) { + if (constants[0]->GetValueAsDouble() <= + min_const->GetValueAsDouble()) { + found_result = true; + result = (cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); + } + } + if (max_const) { + if (constants[0]->GetValueAsDouble() > + max_const->GetValueAsDouble()) { + found_result = true; + result = !(cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); + } + } + } + + if (constants[1]) { + if (max_const) { + if (max_const->GetValueAsDouble() <= + constants[1]->GetValueAsDouble()) { + found_result = true; + result = (cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); + } + } + + if (min_const) { + if (min_const->GetValueAsDouble() > + constants[1]->GetValueAsDouble()) { + found_result = true; + result = !(cmp_opcode == SpvOpFOrdLessThanEqual || + cmp_opcode == SpvOpFUnordLessThanEqual); + } + } + } + break; + default: + return nullptr; + } + + if (!found_result) { + return nullptr; + } + + const analysis::Type* bool_type = + context->get_type_mgr()->GetType(inst->type_id()); + const analysis::Constant* result_const = + const_mgr->GetConstant(bool_type, {static_cast(result)}); + assert(result_const); + return result_const; + }; +} + +} // namespace + +ConstantFoldingRules::ConstantFoldingRules() { + // Add all folding rules to the list for the opcodes to which they apply. + // Note that the order in which rules are added to the list matters. If a rule + // applies to the instruction, the rest of the rules will not be attempted. + // Take that into consideration. 
+ + rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants()); + + rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants()); + + rules_[SpvOpConvertFToS].push_back(FoldFToI()); + rules_[SpvOpConvertFToU].push_back(FoldFToI()); + rules_[SpvOpConvertSToF].push_back(FoldIToF()); + rules_[SpvOpConvertUToF].push_back(FoldIToF()); + + rules_[SpvOpDot].push_back(FoldOpDotWithConstants()); + rules_[SpvOpFAdd].push_back(FoldFAdd()); + rules_[SpvOpFDiv].push_back(FoldFDiv()); + rules_[SpvOpFMul].push_back(FoldFMul()); + rules_[SpvOpFSub].push_back(FoldFSub()); + + rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual()); + + rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual()); + + rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual()); + + rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual()); + + rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan()); + rules_[SpvOpFOrdLessThan].push_back( + FoldFClampFeedingCompare(SpvOpFOrdLessThan)); + + rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan()); + rules_[SpvOpFUnordLessThan].push_back( + FoldFClampFeedingCompare(SpvOpFUnordLessThan)); + + rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan()); + rules_[SpvOpFOrdGreaterThan].push_back( + FoldFClampFeedingCompare(SpvOpFOrdGreaterThan)); + + rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan()); + rules_[SpvOpFUnordGreaterThan].push_back( + FoldFClampFeedingCompare(SpvOpFUnordGreaterThan)); + + rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual()); + rules_[SpvOpFOrdLessThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual)); + + rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); + rules_[SpvOpFUnordLessThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual)); + + rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); + rules_[SpvOpFOrdGreaterThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual)); + + 
rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual()); + rules_[SpvOpFUnordGreaterThanEqual].push_back( + FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual)); + + rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); + rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar()); + + rules_[SpvOpFNegate].push_back(FoldFNegate()); +} +} // namespace opt +} // namespace spvtools diff --git a/source/opt/const_folding_rules.h b/source/opt/const_folding_rules.h new file mode 100644 index 000000000..c1865792b --- /dev/null +++ b/source/opt/const_folding_rules.h @@ -0,0 +1,80 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_CONST_FOLDING_RULES_H_ +#define SOURCE_OPT_CONST_FOLDING_RULES_H_ + +#include +#include + +#include "source/opt/constants.h" + +namespace spvtools { +namespace opt { + +// Constant Folding Rules: +// +// The folding mechanism is built around the concept of a |ConstantFoldingRule|. +// A constant folding rule is a function that implements a method of simplifying +// an instruction to a constant. +// +// The inputs to a folding rule are: +// |inst| - the instruction to be simplified. +// |constants| - if an in-operand is the id of a constant, then the +// corresponding value in |constants| contains that +// constant value. Otherwise, the corresponding entry in +// |constants| is |nullptr|.
+// +// A constant folding rule returns a pointer to a Constant if |inst| can be +// simplified using this rule. Otherwise, it returns |nullptr|. +// +// See const_folding_rules.cpp for examples on how to write a constant folding +// rule. +// +// Be sure to add new constant folding rules to the table of constant folding +// rules in the constructor for ConstantFoldingRules. The new rule should be +// added to the list for every opcode that it applies to. Note that earlier +// rules in the list are given priority. That is, if an earlier rule is able to +// fold an instruction, the later rules will not be attempted. + +using ConstantFoldingRule = std::function& constants)>; + +class ConstantFoldingRules { + public: + ConstantFoldingRules(); + + // Returns true if there is at least 1 folding rule for |opcode|. + bool HasFoldingRule(SpvOp opcode) const { return rules_.count(opcode); } + + // Returns the constant folding rules for |opcode|; empty if there are none. + const std::vector& GetRulesForOpcode( + SpvOp opcode) const { + auto it = rules_.find(opcode); + if (it != rules_.end()) { + return it->second; + } + return empty_vector_; + } + + private: + std::unordered_map> rules_; + std::vector empty_vector_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_CONST_FOLDING_RULES_H_ diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp new file mode 100644 index 000000000..768364be8 --- /dev/null +++ b/source/opt/constants.cpp @@ -0,0 +1,374 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/constants.h" + +#include +#include + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace analysis { + +float Constant::GetFloat() const { + assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32); + + if (const FloatConstant* fc = AsFloatConstant()) { + return fc->GetFloatValue(); + } else { + assert(AsNullConstant() && "Must be a floating point constant."); + return 0.0f; + } +} + +double Constant::GetDouble() const { + assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64); + + if (const FloatConstant* fc = AsFloatConstant()) { + return fc->GetDoubleValue(); + } else { + assert(AsNullConstant() && "Must be a floating point constant."); + return 0.0; + } +} + +double Constant::GetValueAsDouble() const { + assert(type()->AsFloat() != nullptr); + if (type()->AsFloat()->width() == 32) { + return GetFloat(); + } else { + assert(type()->AsFloat()->width() == 64); + return GetDouble(); + } +} + +uint32_t Constant::GetU32() const { + assert(type()->AsInteger() != nullptr); + assert(type()->AsInteger()->width() == 32); + + if (const IntConstant* ic = AsIntConstant()) { + return ic->GetU32BitValue(); + } else { + assert(AsNullConstant() && "Must be an integer constant."); + return 0u; + } +} + +uint64_t Constant::GetU64() const { + assert(type()->AsInteger() != nullptr); + assert(type()->AsInteger()->width() == 64); + + if (const IntConstant* ic = AsIntConstant()) { + return ic->GetU64BitValue(); + } else { + assert(AsNullConstant() && "Must be an integer constant."); + return 0u; + } +} + +int32_t Constant::GetS32() const { + assert(type()->AsInteger() != nullptr); + assert(type()->AsInteger()->width() == 32); + + if (const IntConstant* ic = AsIntConstant()) { + return ic->GetS32BitValue(); + } else { + assert(AsNullConstant() && "Must be an integer constant."); + return 0; 
+ } +} + +int64_t Constant::GetS64() const { + assert(type()->AsInteger() != nullptr); + assert(type()->AsInteger()->width() == 64); + + if (const IntConstant* ic = AsIntConstant()) { + return ic->GetS64BitValue(); + } else { + assert(AsNullConstant() && "Must be an integer constant."); + return 0; + } +} + +ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) { + // Populate the constant table with values from constant declarations in the + // module. The values of each OpConstant declaration is the identity + // assignment (i.e., each constant is its own value). + for (const auto& inst : ctx_->module()->GetConstants()) { + MapInst(inst); + } +} + +Type* ConstantManager::GetType(const Instruction* inst) const { + return context()->get_type_mgr()->GetType(inst->type_id()); +} + +std::vector ConstantManager::GetOperandConstants( + Instruction* inst) const { + std::vector constants; + for (uint32_t i = 0; i < inst->NumInOperands(); i++) { + const Operand* operand = &inst->GetInOperand(i); + if (operand->type != SPV_OPERAND_TYPE_ID) { + constants.push_back(nullptr); + } else { + uint32_t id = operand->words[0]; + const analysis::Constant* constant = FindDeclaredConstant(id); + constants.push_back(constant); + } + } + return constants; +} + +uint32_t ConstantManager::FindDeclaredConstant(const Constant* c, + uint32_t type_id) const { + c = FindConstant(c); + if (c == nullptr) { + return 0; + } + + for (auto range = const_val_to_id_.equal_range(c); + range.first != range.second; ++range.first) { + Instruction* const_def = + context()->get_def_use_mgr()->GetDef(range.first->second); + if (type_id == 0 || const_def->type_id() == type_id) { + return range.first->second; + } + } + return 0; +} + +std::vector ConstantManager::GetConstantsFromIds( + const std::vector& ids) const { + std::vector constants; + for (uint32_t id : ids) { + if (const Constant* c = FindDeclaredConstant(id)) { + constants.push_back(c); + } else { + return {}; + } + } + return constants; +} + 
+Instruction* ConstantManager::BuildInstructionAndAddToModule( + const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) { + // TODO(1841): Handle id overflow. + uint32_t new_id = context()->TakeNextId(); + auto new_inst = CreateInstruction(new_id, new_const, type_id); + if (!new_inst) { + return nullptr; + } + auto* new_inst_ptr = new_inst.get(); + *pos = pos->InsertBefore(std::move(new_inst)); + ++(*pos); + context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr); + MapConstantToInst(new_const, new_inst_ptr); + return new_inst_ptr; +} + +Instruction* ConstantManager::GetDefiningInstruction( + const Constant* c, uint32_t type_id, Module::inst_iterator* pos) { + assert(type_id == 0 || + context()->get_type_mgr()->GetType(type_id) == c->type()); + uint32_t decl_id = FindDeclaredConstant(c, type_id); + if (decl_id == 0) { + auto iter = context()->types_values_end(); + if (pos == nullptr) pos = &iter; + return BuildInstructionAndAddToModule(c, pos, type_id); + } else { + auto def = context()->get_def_use_mgr()->GetDef(decl_id); + assert(def != nullptr); + assert((type_id == 0 || def->type_id() == type_id) && + "This constant already has an instruction with a different type."); + return def; + } +} + +std::unique_ptr ConstantManager::CreateConstant( + const Type* type, const std::vector& literal_words_or_ids) const { + if (literal_words_or_ids.size() == 0) { + // Constant declared with OpConstantNull + return MakeUnique(type); + } else if (auto* bt = type->AsBool()) { + assert(literal_words_or_ids.size() == 1 && + "Bool constant should be declared with one operand"); + return MakeUnique(bt, literal_words_or_ids.front()); + } else if (auto* it = type->AsInteger()) { + return MakeUnique(it, literal_words_or_ids); + } else if (auto* ft = type->AsFloat()) { + return MakeUnique(ft, literal_words_or_ids); + } else if (auto* vt = type->AsVector()) { + auto components = GetConstantsFromIds(literal_words_or_ids); + if (components.empty()) return nullptr; + 
// All components of VectorConstant must be of type Bool, Integer or Float. + if (!std::all_of(components.begin(), components.end(), + [](const Constant* c) { + if (c->type()->AsBool() || c->type()->AsInteger() || + c->type()->AsFloat()) { + return true; + } else { + return false; + } + })) + return nullptr; + // All components of VectorConstant must be in the same type. + const auto* component_type = components.front()->type(); + if (!std::all_of(components.begin(), components.end(), + [&component_type](const Constant* c) { + if (c->type() == component_type) return true; + return false; + })) + return nullptr; + return MakeUnique(vt, components); + } else if (auto* mt = type->AsMatrix()) { + auto components = GetConstantsFromIds(literal_words_or_ids); + if (components.empty()) return nullptr; + return MakeUnique(mt, components); + } else if (auto* st = type->AsStruct()) { + auto components = GetConstantsFromIds(literal_words_or_ids); + if (components.empty()) return nullptr; + return MakeUnique(st, components); + } else if (auto* at = type->AsArray()) { + auto components = GetConstantsFromIds(literal_words_or_ids); + if (components.empty()) return nullptr; + return MakeUnique(at, components); + } else { + return nullptr; + } +} + +const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) { + std::vector literal_words_or_ids; + + // Collect the constant defining literals or component ids. + for (uint32_t i = 0; i < inst->NumInOperands(); i++) { + literal_words_or_ids.insert(literal_words_or_ids.end(), + inst->GetInOperand(i).words.begin(), + inst->GetInOperand(i).words.end()); + } + + switch (inst->opcode()) { + // OpConstant{True|False} have the value embedded in the opcode. So they + // are not handled by the for-loop above. Here we add the value explicitly. 
+ case SpvOp::SpvOpConstantTrue: + literal_words_or_ids.push_back(true); + break; + case SpvOp::SpvOpConstantFalse: + literal_words_or_ids.push_back(false); + break; + case SpvOp::SpvOpConstantNull: + case SpvOp::SpvOpConstant: + case SpvOp::SpvOpConstantComposite: + case SpvOp::SpvOpSpecConstantComposite: + break; + default: + return nullptr; + } + + return GetConstant(GetType(inst), literal_words_or_ids); +} + +std::unique_ptr ConstantManager::CreateInstruction( + uint32_t id, const Constant* c, uint32_t type_id) const { + uint32_t type = + (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id; + if (c->AsNullConstant()) { + return MakeUnique(context(), SpvOp::SpvOpConstantNull, type, + id, std::initializer_list{}); + } else if (const BoolConstant* bc = c->AsBoolConstant()) { + return MakeUnique( + context(), + bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, + type, id, std::initializer_list{}); + } else if (const IntConstant* ic = c->AsIntConstant()) { + return MakeUnique( + context(), SpvOp::SpvOpConstant, type, id, + std::initializer_list{ + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, + ic->words())}); + } else if (const FloatConstant* fc = c->AsFloatConstant()) { + return MakeUnique( + context(), SpvOp::SpvOpConstant, type, id, + std::initializer_list{ + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, + fc->words())}); + } else if (const CompositeConstant* cc = c->AsCompositeConstant()) { + return CreateCompositeInstruction(id, cc, type_id); + } else { + return nullptr; + } +} + +std::unique_ptr ConstantManager::CreateCompositeInstruction( + uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const { + std::vector operands; + Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); + uint32_t component_index = 0; + for (const Constant* component_const : cc->GetComponents()) { + uint32_t component_type_id = 0; + if (type_inst && type_inst->opcode() == 
SpvOpTypeStruct) { + component_type_id = type_inst->GetSingleWordInOperand(component_index); + } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) { + component_type_id = type_inst->GetSingleWordInOperand(0); + } + uint32_t id = FindDeclaredConstant(component_const, component_type_id); + + if (id == 0) { + // Cannot get the id of the component constant, while all components + // should have been added to the module prior to the composite constant. + // Cannot create OpConstantComposite instruction in this case. + return nullptr; + } + operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, + std::initializer_list{id}); + component_index++; + } + uint32_t type = + (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id; + return MakeUnique(context(), SpvOp::SpvOpConstantComposite, type, + result_id, std::move(operands)); +} + +const Constant* ConstantManager::GetConstant( + const Type* type, const std::vector& literal_words_or_ids) { + auto cst = CreateConstant(type, literal_words_or_ids); + return cst ? 
RegisterConstant(std::move(cst)) : nullptr; +} + +std::vector Constant::GetVectorComponents( + analysis::ConstantManager* const_mgr) const { + std::vector components; + const analysis::VectorConstant* a = this->AsVectorConstant(); + const analysis::Vector* vector_type = this->type()->AsVector(); + assert(vector_type != nullptr); + if (a != nullptr) { + for (uint32_t i = 0; i < vector_type->element_count(); ++i) { + components.push_back(a->GetComponents()[i]); + } + } else { + const analysis::Type* element_type = vector_type->element_type(); + const analysis::Constant* element_null_const = + const_mgr->GetConstant(element_type, {}); + for (uint32_t i = 0; i < vector_type->element_count(); ++i) { + components.push_back(element_null_const); + } + } + return components; +} + +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/source/opt/constants.h b/source/opt/constants.h new file mode 100644 index 000000000..de2dfc3d0 --- /dev/null +++ b/source/opt/constants.h @@ -0,0 +1,696 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_CONSTANTS_H_ +#define SOURCE_OPT_CONSTANTS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/module.h" +#include "source/opt/type_manager.h" +#include "source/opt/types.h" +#include "source/util/hex_float.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +class IRContext; + +namespace analysis { + +// Class hierarchy to represent the normal constants defined through +// OpConstantTrue, OpConstantFalse, OpConstant, OpConstantNull and +// OpConstantComposite instructions. +// TODO(qining): Add class for constants defined with OpConstantSampler. +class Constant; +class ScalarConstant; +class IntConstant; +class FloatConstant; +class BoolConstant; +class CompositeConstant; +class StructConstant; +class VectorConstant; +class MatrixConstant; +class ArrayConstant; +class NullConstant; +class ConstantManager; + +// Abstract class for a SPIR-V constant. It has a bunch of As*() methods, +// which are used as a way to probe the actual subclass of the constant. +class Constant { + public: + Constant() = delete; + virtual ~Constant() {} + + // Make a deep copy of this constant.
+ virtual std::unique_ptr Copy() const = 0; + + // reflections + virtual ScalarConstant* AsScalarConstant() { return nullptr; } + virtual IntConstant* AsIntConstant() { return nullptr; } + virtual FloatConstant* AsFloatConstant() { return nullptr; } + virtual BoolConstant* AsBoolConstant() { return nullptr; } + virtual CompositeConstant* AsCompositeConstant() { return nullptr; } + virtual StructConstant* AsStructConstant() { return nullptr; } + virtual VectorConstant* AsVectorConstant() { return nullptr; } + virtual MatrixConstant* AsMatrixConstant() { return nullptr; } + virtual ArrayConstant* AsArrayConstant() { return nullptr; } + virtual NullConstant* AsNullConstant() { return nullptr; } + + virtual const ScalarConstant* AsScalarConstant() const { return nullptr; } + virtual const IntConstant* AsIntConstant() const { return nullptr; } + virtual const FloatConstant* AsFloatConstant() const { return nullptr; } + virtual const BoolConstant* AsBoolConstant() const { return nullptr; } + virtual const CompositeConstant* AsCompositeConstant() const { + return nullptr; + } + virtual const StructConstant* AsStructConstant() const { return nullptr; } + virtual const VectorConstant* AsVectorConstant() const { return nullptr; } + virtual const MatrixConstant* AsMatrixConstant() const { return nullptr; } + virtual const ArrayConstant* AsArrayConstant() const { return nullptr; } + virtual const NullConstant* AsNullConstant() const { return nullptr; } + + // Returns the float representation of the constant. Must be a 32 bit + // Float type. + float GetFloat() const; + + // Returns the double representation of the constant. Must be a 64 bit + // Float type. + double GetDouble() const; + + // Returns the double representation of the constant. Must be a 32-bit or + // 64-bit Float type. + double GetValueAsDouble() const; + + // Returns uint32_t representation of the constant. Must be a 32 bit + // Integer type. 
+ uint32_t GetU32() const; + + // Returns uint64_t representation of the constant. Must be a 64 bit + // Integer type. + uint64_t GetU64() const; + + // Returns int32_t representation of the constant. Must be a 32 bit + // Integer type. + int32_t GetS32() const; + + // Returns int64_t representation of the constant. Must be a 64 bit + // Integer type. + int64_t GetS64() const; + + // Returns true if the constant is a zero or a composite containing only 0s. + virtual bool IsZero() const { return false; } + + const Type* type() const { return type_; } + + // Returns a std::vector containing the elements of |this|. The type of + // |this| must be |Vector|. + std::vector GetVectorComponents( + ConstantManager* const_mgr) const; + + protected: + Constant(const Type* ty) : type_(ty) {} + + // The type of this constant. + const Type* type_; +}; + +// Abstract class for scalar type constants. +class ScalarConstant : public Constant { + public: + ScalarConstant() = delete; + ScalarConstant* AsScalarConstant() override { return this; } + const ScalarConstant* AsScalarConstant() const override { return this; } + + // Returns a const reference to the value of this constant in 32-bit words. + virtual const std::vector& words() const { return words_; } + + // Returns true if the value is zero. + bool IsZero() const override { + bool is_zero = true; + for (uint32_t v : words()) { + if (v != 0) { + is_zero = false; + break; + } + } + return is_zero; + } + + protected: + ScalarConstant(const Type* ty, const std::vector& w) + : Constant(ty), words_(w) {} + ScalarConstant(const Type* ty, std::vector&& w) + : Constant(ty), words_(std::move(w)) {} + std::vector words_; +}; + +// Integer type constant.
+class IntConstant : public ScalarConstant { + public: + IntConstant(const Integer* ty, const std::vector& w) + : ScalarConstant(ty, w) {} + IntConstant(const Integer* ty, std::vector&& w) + : ScalarConstant(ty, std::move(w)) {} + + IntConstant* AsIntConstant() override { return this; } + const IntConstant* AsIntConstant() const override { return this; } + + int32_t GetS32BitValue() const { + // Relies on signed values smaller than 32-bit being sign extended. See + // section 2.2.1 of the SPIR-V spec. + assert(words().size() == 1); + return words()[0]; + } + + uint32_t GetU32BitValue() const { + // Relies on unsigned values smaller than 32-bit being zero extended. See + // section 2.2.1 of the SPIR-V spec. + assert(words().size() == 1); + return words()[0]; + } + + int64_t GetS64BitValue() const { + // Relies on signed values smaller than 64-bit being sign extended. See + // section 2.2.1 of the SPIR-V spec. + assert(words().size() == 2); + return static_cast(words()[1]) << 32 | + static_cast(words()[0]); + } + + uint64_t GetU64BitValue() const { + // Relies on unsigned values smaller than 64-bit being zero extended. See + // section 2.2.1 of the SPIR-V spec. + assert(words().size() == 2); + return static_cast(words()[1]) << 32 | + static_cast(words()[0]); + } + + // Make a copy of this IntConstant instance. + std::unique_ptr CopyIntConstant() const { + return MakeUnique(type_->AsInteger(), words_); + } + std::unique_ptr Copy() const override { + return std::unique_ptr(CopyIntConstant().release()); + } +}; + +// Float type constant. +class FloatConstant : public ScalarConstant { + public: + FloatConstant(const Float* ty, const std::vector& w) + : ScalarConstant(ty, w) {} + FloatConstant(const Float* ty, std::vector&& w) + : ScalarConstant(ty, std::move(w)) {} + + FloatConstant* AsFloatConstant() override { return this; } + const FloatConstant* AsFloatConstant() const override { return this; } + + // Make a copy of this FloatConstant instance.
+ std::unique_ptr CopyFloatConstant() const { + return MakeUnique(type_->AsFloat(), words_); + } + std::unique_ptr Copy() const override { + return std::unique_ptr(CopyFloatConstant().release()); + } + + // Returns the float value of |this|. The type of |this| must be |Float| with + // width of 32. + float GetFloatValue() const { + assert(type()->AsFloat()->width() == 32 && + "Not a 32-bit floating point value."); + utils::FloatProxy a(words()[0]); + return a.getAsFloat(); + } + + // Returns the double value of |this|. The type of |this| must be |Float| + // with width of 64. words()[0] holds the low-order 32 bits. + double GetDoubleValue() const { + assert(type()->AsFloat()->width() == 64 && + "Not a 64-bit floating point value."); + uint64_t combined_words = words()[1]; + combined_words = combined_words << 32; + combined_words |= words()[0]; + utils::FloatProxy a(combined_words); + return a.getAsFloat(); + } +}; + +// Bool type constant. +class BoolConstant : public ScalarConstant { + public: + BoolConstant(const Bool* ty, bool v) + : ScalarConstant(ty, {static_cast(v)}), value_(v) {} + + BoolConstant* AsBoolConstant() override { return this; } + const BoolConstant* AsBoolConstant() const override { return this; } + + // Make a copy of this BoolConstant instance. + std::unique_ptr CopyBoolConstant() const { + return MakeUnique(type_->AsBool(), value_); + } + std::unique_ptr Copy() const override { + return std::unique_ptr(CopyBoolConstant().release()); + } + + bool value() const { return value_; } + + private: + bool value_; +}; + +// Abstract class for composite constants. +class CompositeConstant : public Constant { + public: + CompositeConstant() = delete; + CompositeConstant* AsCompositeConstant() override { return this; } + const CompositeConstant* AsCompositeConstant() const override { return this; } + + // Returns a const reference of the components held in this composite + // constant.
+ virtual const std::vector& GetComponents() const { + return components_; + } + + bool IsZero() const override { + for (const Constant* c : GetComponents()) { + if (!c->IsZero()) { + return false; + } + } + return true; + } + + protected: + CompositeConstant(const Type* ty) : Constant(ty), components_() {} + CompositeConstant(const Type* ty, + const std::vector& components) + : Constant(ty), components_(components) {} + CompositeConstant(const Type* ty, std::vector&& components) + : Constant(ty), components_(std::move(components)) {} + std::vector components_; +}; + +// Struct type constant. +class StructConstant : public CompositeConstant { + public: + StructConstant(const Struct* ty) : CompositeConstant(ty) {} + StructConstant(const Struct* ty, + const std::vector& components) + : CompositeConstant(ty, components) {} + StructConstant(const Struct* ty, std::vector&& components) + : CompositeConstant(ty, std::move(components)) {} + + StructConstant* AsStructConstant() override { return this; } + const StructConstant* AsStructConstant() const override { return this; } + + // Make a copy of this StructConstant instance. + std::unique_ptr CopyStructConstant() const { + return MakeUnique(type_->AsStruct(), components_); + } + std::unique_ptr Copy() const override { + return std::unique_ptr(CopyStructConstant().release()); + } +}; + +// Vector type constant. 
+class VectorConstant : public CompositeConstant {
+ public:
+  VectorConstant(const Vector* ty)
+      : CompositeConstant(ty), component_type_(ty->element_type()) {}
+  VectorConstant(const Vector* ty,
+                 const std::vector<const Constant*>& components)
+      : CompositeConstant(ty, components),
+        component_type_(ty->element_type()) {}
+  VectorConstant(const Vector* ty, std::vector<const Constant*>&& components)
+      : CompositeConstant(ty, std::move(components)),
+        component_type_(ty->element_type()) {}
+
+  VectorConstant* AsVectorConstant() override { return this; }
+  const VectorConstant* AsVectorConstant() const override { return this; }
+
+  // Make a copy of this VectorConstant instance.
+  std::unique_ptr<VectorConstant> CopyVectorConstant() const {
+    auto another = MakeUnique<VectorConstant>(type_->AsVector());
+    another->components_.insert(another->components_.end(), components_.begin(),
+                                components_.end());
+    return another;
+  }
+  std::unique_ptr<Constant> Copy() const override {
+    return std::unique_ptr<Constant>(CopyVectorConstant().release());
+  }
+
+  // Returns the element type of the vector type given at construction.
+  const Type* component_type() const { return component_type_; }
+
+ private:
+  const Type* component_type_;
+};
+
+// Matrix type constant.
+class MatrixConstant : public CompositeConstant {
+ public:
+  MatrixConstant(const Matrix* ty)
+      : CompositeConstant(ty), component_type_(ty->element_type()) {}
+  MatrixConstant(const Matrix* ty,
+                 const std::vector<const Constant*>& components)
+      : CompositeConstant(ty, components),
+        component_type_(ty->element_type()) {}
+  // Takes a Matrix type like the other constructors; this overload previously
+  // declared its parameter as |const Vector*|.
+  MatrixConstant(const Matrix* ty, std::vector<const Constant*>&& components)
+      : CompositeConstant(ty, std::move(components)),
+        component_type_(ty->element_type()) {}
+
+  MatrixConstant* AsMatrixConstant() override { return this; }
+  const MatrixConstant* AsMatrixConstant() const override { return this; }
+
+  // Make a copy of this MatrixConstant instance.
+  std::unique_ptr<MatrixConstant> CopyMatrixConstant() const {
+    auto another = MakeUnique<MatrixConstant>(type_->AsMatrix());
+    another->components_.insert(another->components_.end(), components_.begin(),
+                                components_.end());
+    return another;
+  }
+  std::unique_ptr<Constant> Copy() const override {
+    return std::unique_ptr<Constant>(CopyMatrixConstant().release());
+  }
+
+  // Returns the element type of the matrix type given at construction.
+  // Const-qualified to match VectorConstant::component_type().
+  const Type* component_type() const { return component_type_; }
+
+ private:
+  const Type* component_type_;
+};
+
+// Array type constant.
+class ArrayConstant : public CompositeConstant {
+ public:
+  ArrayConstant(const Array* ty) : CompositeConstant(ty) {}
+  ArrayConstant(const Array* ty, const std::vector<const Constant*>& components)
+      : CompositeConstant(ty, components) {}
+  ArrayConstant(const Array* ty, std::vector<const Constant*>&& components)
+      : CompositeConstant(ty, std::move(components)) {}
+
+  ArrayConstant* AsArrayConstant() override { return this; }
+  const ArrayConstant* AsArrayConstant() const override { return this; }
+
+  // Make a copy of this ArrayConstant instance.
+  std::unique_ptr<ArrayConstant> CopyArrayConstant() const {
+    return MakeUnique<ArrayConstant>(type_->AsArray(), components_);
+  }
+  std::unique_ptr<Constant> Copy() const override {
+    return std::unique_ptr<Constant>(CopyArrayConstant().release());
+  }
+};
+
+// Null type constant.
+class NullConstant : public Constant {
+ public:
+  NullConstant(const Type* ty) : Constant(ty) {}
+  NullConstant* AsNullConstant() override { return this; }
+  const NullConstant* AsNullConstant() const override { return this; }
+
+  // Make a copy of this NullConstant instance.
+  std::unique_ptr<NullConstant> CopyNullConstant() const {
+    return MakeUnique<NullConstant>(type_);
+  }
+  std::unique_ptr<Constant> Copy() const override {
+    return std::unique_ptr<Constant>(CopyNullConstant().release());
+  }
+  // A null constant is always zero.
+  bool IsZero() const override { return true; }
+};
+
+// Hash function for Constant instances. Use the structure of the constant as
+// the key.
+struct ConstantHash { + void add_pointer(std::u32string* h, const void* p) const { + uint64_t ptr_val = reinterpret_cast(p); + h->push_back(static_cast(ptr_val >> 32)); + h->push_back(static_cast(ptr_val)); + } + + size_t operator()(const Constant* const_val) const { + std::u32string h; + add_pointer(&h, const_val->type()); + if (const auto scalar = const_val->AsScalarConstant()) { + for (const auto& w : scalar->words()) { + h.push_back(w); + } + } else if (const auto composite = const_val->AsCompositeConstant()) { + for (const auto& c : composite->GetComponents()) { + add_pointer(&h, c); + } + } else if (const_val->AsNullConstant()) { + h.push_back(0); + } else { + assert( + false && + "Tried to compute the hash value of an invalid Constant instance."); + } + + return std::hash()(h); + } +}; + +// Equality comparison structure for two constants. +struct ConstantEqual { + bool operator()(const Constant* c1, const Constant* c2) const { + if (c1->type() != c2->type()) { + return false; + } + + if (const auto& s1 = c1->AsScalarConstant()) { + const auto& s2 = c2->AsScalarConstant(); + return s2 && s1->words() == s2->words(); + } else if (const auto& composite1 = c1->AsCompositeConstant()) { + const auto& composite2 = c2->AsCompositeConstant(); + return composite2 && + composite1->GetComponents() == composite2->GetComponents(); + } else if (c1->AsNullConstant()) { + return c2->AsNullConstant() != nullptr; + } else { + assert(false && "Tried to compare two invalid Constant instances."); + } + return false; + } +}; + +// This class represents a pool of constants. +class ConstantManager { + public: + ConstantManager(IRContext* ctx); + + IRContext* context() const { return ctx_; } + + // Gets or creates a unique Constant instance of type |type| and a vector of + // constant defining words |words|. If a Constant instance existed already in + // the constant pool, it returns a pointer to it. Otherwise, it creates one + // using CreateConstant. 
If a new Constant instance cannot be created, it + // returns nullptr. + const Constant* GetConstant( + const Type* type, const std::vector& literal_words_or_ids); + + template + const Constant* GetConstant(const Type* type, const C& literal_words_or_ids) { + return GetConstant(type, std::vector(literal_words_or_ids.begin(), + literal_words_or_ids.end())); + } + + // Gets or creates a Constant instance to hold the constant value of the given + // instruction. It returns a pointer to a Constant instance or nullptr if it + // could not create the constant. + const Constant* GetConstantFromInst(Instruction* inst); + + // Gets or creates a constant defining instruction for the given Constant |c|. + // If |c| had already been defined, it returns a pointer to the existing + // declaration. Otherwise, it calls BuildInstructionAndAddToModule. If the + // optional |pos| is given, it will insert any newly created instructions at + // the given instruction iterator position. Otherwise, it inserts the new + // instruction at the end of the current module's types section. + // + // |type_id| is an optional argument for disambiguating equivalent types. If + // |type_id| is specified, it is used as the type of the constant when a new + // instruction is created. Otherwise the type of the constant is derived by + // getting an id from the type manager for |c|. + // + // When |type_id| is not zero, the type of |c| must be the type returned by + // type manager when given |type_id|. + Instruction* GetDefiningInstruction(const Constant* c, uint32_t type_id = 0, + Module::inst_iterator* pos = nullptr); + + // Creates a constant defining instruction for the given Constant instance + // and inserts the instruction at the position specified by the given + // instruction iterator. Returns a pointer to the created instruction if + // succeeded, otherwise returns a null pointer. The instruction iterator + // points to the same instruction before and after the insertion. 
This is the + // only method that actually manages id creation/assignment and instruction + // creation/insertion for a new Constant instance. + // + // |type_id| is an optional argument for disambiguating equivalent types. If + // |type_id| is specified, it is used as the type of the constant. Otherwise + // the type of the constant is derived by getting an id from the type manager + // for |c|. + Instruction* BuildInstructionAndAddToModule(const Constant* c, + Module::inst_iterator* pos, + uint32_t type_id = 0); + + // A helper function to get the result type of the given instruction. Returns + // nullptr if the instruction does not have a type id (type id is 0). + Type* GetType(const Instruction* inst) const; + + // A helper function to get the collected normal constant with the given id. + // Returns the pointer to the Constant instance in case it is found. + // Otherwise, it returns a null pointer. + const Constant* FindDeclaredConstant(uint32_t id) const { + auto iter = id_to_const_val_.find(id); + return (iter != id_to_const_val_.end()) ? iter->second : nullptr; + } + + // A helper function to get the id of a collected constant with the pointer + // to the Constant instance. Returns 0 in case the constant is not found. + uint32_t FindDeclaredConstant(const Constant* c, uint32_t type_id) const; + + // Returns the canonical constant that has the same structure and value as the + // given Constant |cst|. If none is found, it returns nullptr. + // + // TODO: Should be able to give a type id to disambiguate types with the same + // structure. + const Constant* FindConstant(const Constant* c) const { + auto it = const_pool_.find(c); + return (it != const_pool_.end()) ? *it : nullptr; + } + + // Registers a new constant |cst| in the constant pool. If the constant + // existed already, it returns a pointer to the previously existing Constant + // in the pool. Otherwise, it returns |cst|. 
+ const Constant* RegisterConstant(std::unique_ptr cst) { + auto ret = const_pool_.insert(cst.get()); + if (ret.second) { + owned_constants_.emplace_back(std::move(cst)); + } + return *ret.first; + } + + // A helper function to get a vector of Constant instances with the specified + // ids. If it can not find the Constant instance for any one of the ids, + // it returns an empty vector. + std::vector GetConstantsFromIds( + const std::vector& ids) const; + + // Returns a vector of constants representing each in operand. If an operand + // is not constant its entry is nullptr. + std::vector GetOperandConstants(Instruction* inst) const; + + // Records a mapping between |inst| and the constant value generated by it. + // It returns true if a new Constant was successfully mapped, false if |inst| + // generates no constant values. + bool MapInst(Instruction* inst) { + if (auto cst = GetConstantFromInst(inst)) { + MapConstantToInst(cst, inst); + return true; + } + return false; + } + + void RemoveId(uint32_t id) { + auto it = id_to_const_val_.find(id); + if (it != id_to_const_val_.end()) { + const_val_to_id_.erase(it->second); + id_to_const_val_.erase(it); + } + } + + // Records a new mapping between |inst| and |const_value|. This updates the + // two mappings |id_to_const_val_| and |const_val_to_id_|. + void MapConstantToInst(const Constant* const_value, Instruction* inst) { + if (id_to_const_val_.insert({inst->result_id(), const_value}).second) { + const_val_to_id_.insert({const_value, inst->result_id()}); + } + } + + private: + // Creates a Constant instance with the given type and a vector of constant + // defining words. Returns a unique pointer to the created Constant instance + // if the Constant instance can be created successfully. To create scalar + // type constants, the vector should contain the constant value in 32 bit + // words and the given type must be of type Bool, Integer or Float. 
To create + // composite type constants, the vector should contain the component ids, and + // those component ids should have been recorded before as Normal Constants. + // And the given type must be of type Struct, Vector or Array. When creating + // VectorType Constant instance, the components must be scalars of the same + // type, either Bool, Integer or Float. If any of the rules above failed, the + // creation will fail and nullptr will be returned. If the vector is empty, + // a NullConstant instance will be created with the given type. + std::unique_ptr CreateConstant( + const Type* type, + const std::vector& literal_words_or_ids) const; + + // Creates an instruction with the given result id to declare a constant + // represented by the given Constant instance. Returns an unique pointer to + // the created instruction if the instruction can be created successfully. + // Otherwise, returns a null pointer. + // + // |type_id| is an optional argument for disambiguating equivalent types. If + // |type_id| is specified, it is used as the type of the constant. Otherwise + // the type of the constant is derived by getting an id from the type manager + // for |c|. + std::unique_ptr CreateInstruction(uint32_t result_id, + const Constant* c, + uint32_t type_id = 0) const; + + // Creates an OpConstantComposite instruction with the given result id and + // the CompositeConst instance which represents a composite constant. Returns + // an unique pointer to the created instruction if succeeded. Otherwise + // returns a null pointer. + // + // |type_id| is an optional argument for disambiguating equivalent types. If + // |type_id| is specified, it is used as the type of the constant. Otherwise + // the type of the constant is derived by getting an id from the type manager + // for |c|. + std::unique_ptr CreateCompositeInstruction( + uint32_t result_id, const CompositeConstant* cc, + uint32_t type_id = 0) const; + + // IR context that owns this constant manager. 
+ IRContext* ctx_; + + // A mapping from the result ids of Normal Constants to their + // Constant instances. All Normal Constants in the module, either + // existing ones before optimization or the newly generated ones, should have + // their Constant instance stored and their result id registered in this map. + std::unordered_map id_to_const_val_; + + // A mapping from the Constant instance of Normal Constants to their + // result id in the module. This is a mirror map of |id_to_const_val_|. All + // Normal Constants that defining instructions in the module should have + // their Constant and their result id registered here. + std::multimap const_val_to_id_; + + // The constant pool. All created constants are registered here. + std::unordered_set const_pool_; + + // The constant that are owned by the constant manager. Every constant in + // |const_pool_| should be in |owned_constants_| as well. + std::vector> owned_constants_; +}; + +} // namespace analysis +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_CONSTANTS_H_ diff --git a/source/opt/copy_prop_arrays.cpp b/source/opt/copy_prop_arrays.cpp new file mode 100644 index 000000000..e508c05d1 --- /dev/null +++ b/source/opt/copy_prop_arrays.cpp @@ -0,0 +1,872 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/copy_prop_arrays.h" + +#include + +#include "source/opt/ir_builder.h" + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kLoadPointerInOperand = 0; +const uint32_t kStorePointerInOperand = 0; +const uint32_t kStoreObjectInOperand = 1; +const uint32_t kCompositeExtractObjectInOperand = 0; +const uint32_t kTypePointerStorageClassInIdx = 0; +const uint32_t kTypePointerPointeeInIdx = 1; + +} // namespace + +Pass::Status CopyPropagateArrays::Process() { + bool modified = false; + for (Function& function : *get_module()) { + BasicBlock* entry_bb = &*function.begin(); + + for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable; + ++var_inst) { + if (!IsPointerToArrayType(var_inst->type_id())) { + continue; + } + + // Find the only store to the entire memory location, if it exists. + Instruction* store_inst = FindStoreInstruction(&*var_inst); + + if (!store_inst) { + continue; + } + + std::unique_ptr source_object = + FindSourceObjectIfPossible(&*var_inst, store_inst); + + if (source_object != nullptr) { + if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) { + modified = true; + PropagateObject(&*var_inst, source_object.get(), store_inst); + } + } + } + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +std::unique_ptr +CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst, + Instruction* store_inst) { + assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable."); + + // Check that the variable is a composite object where |store_inst| + // dominates all of its loads. + if (!store_inst) { + return nullptr; + } + + // Look at the loads to ensure they are dominated by the store. + if (!HasValidReferencesOnly(var_inst, store_inst)) { + return nullptr; + } + + // If so, look at the store to see if it is the copy of an object. 
+ std::unique_ptr source = GetSourceObjectIfAny( + store_inst->GetSingleWordInOperand(kStoreObjectInOperand)); + + if (!source) { + return nullptr; + } + + // Ensure that |source| does not change between the point at which it is + // loaded, and the position in which |var_inst| is loaded. + // + // For now we will go with the easy to implement approach, and check that the + // entire variable (not just the specific component) is never written to. + + if (!HasNoStores(source->GetVariable())) { + return nullptr; + } + return source; +} + +Instruction* CopyPropagateArrays::FindStoreInstruction( + const Instruction* var_inst) const { + Instruction* store_inst = nullptr; + get_def_use_mgr()->WhileEachUser( + var_inst, [&store_inst, var_inst](Instruction* use) { + if (use->opcode() == SpvOpStore && + use->GetSingleWordInOperand(kStorePointerInOperand) == + var_inst->result_id()) { + if (store_inst == nullptr) { + store_inst = use; + } else { + store_inst = nullptr; + return false; + } + } + return true; + }); + return store_inst; +} + +void CopyPropagateArrays::PropagateObject(Instruction* var_inst, + MemoryObject* source, + Instruction* insertion_point) { + assert(var_inst->opcode() == SpvOpVariable && + "This function propagates variables."); + + Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source); + context()->KillNamesAndDecorates(var_inst); + UpdateUses(var_inst, new_access_chain); +} + +Instruction* CopyPropagateArrays::BuildNewAccessChain( + Instruction* insertion_point, + CopyPropagateArrays::MemoryObject* source) const { + InstructionBuilder builder( + context(), insertion_point, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + + if (source->AccessChain().size() == 0) { + return source->GetVariable(); + } + + return builder.AddAccessChain(source->GetPointerTypeId(this), + source->GetVariable()->result_id(), + source->AccessChain()); +} + +bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) { + return 
get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) { + if (use->opcode() == SpvOpLoad) { + return true; + } else if (use->opcode() == SpvOpAccessChain) { + return HasNoStores(use); + } else if (use->IsDecoration() || use->opcode() == SpvOpName) { + return true; + } else if (use->opcode() == SpvOpStore) { + return false; + } else if (use->opcode() == SpvOpImageTexelPointer) { + return true; + } + // Some other instruction. Be conservative. + return false; + }); +} + +bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst, + Instruction* store_inst) { + BasicBlock* store_block = context()->get_instr_block(store_inst); + DominatorAnalysis* dominator_analysis = + context()->GetDominatorAnalysis(store_block->GetParent()); + + return get_def_use_mgr()->WhileEachUser( + ptr_inst, + [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) { + if (use->opcode() == SpvOpLoad || + use->opcode() == SpvOpImageTexelPointer) { + // TODO: If there are many load in the same BB as |store_inst| the + // time to do the multiple traverses can add up. Consider collecting + // those loads and doing a single traversal. + return dominator_analysis->Dominates(store_inst, use); + } else if (use->opcode() == SpvOpAccessChain) { + return HasValidReferencesOnly(use, store_inst); + } else if (use->IsDecoration() || use->opcode() == SpvOpName) { + return true; + } else if (use->opcode() == SpvOpStore) { + // If we are storing to part of the object it is not an candidate. + return ptr_inst->opcode() == SpvOpVariable && + store_inst->GetSingleWordInOperand(kStorePointerInOperand) == + ptr_inst->result_id(); + } + // Some other instruction. Be conservative. 
+ return false; + }); +} + +std::unique_ptr +CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) { + Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result); + + switch (result_inst->opcode()) { + case SpvOpLoad: + return BuildMemoryObjectFromLoad(result_inst); + case SpvOpCompositeExtract: + return BuildMemoryObjectFromExtract(result_inst); + case SpvOpCompositeConstruct: + return BuildMemoryObjectFromCompositeConstruct(result_inst); + case SpvOpCopyObject: + return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0)); + case SpvOpCompositeInsert: + return BuildMemoryObjectFromInsert(result_inst); + default: + return nullptr; + } +} + +std::unique_ptr +CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) { + std::vector components_in_reverse; + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + Instruction* current_inst = def_use_mgr->GetDef( + load_inst->GetSingleWordInOperand(kLoadPointerInOperand)); + + // Build the access chain for the memory object by collecting the indices used + // in the OpAccessChain instructions. If we find a variable index, then + // return |nullptr| because we cannot know for sure which memory location is + // used. + // + // It is built in reverse order because the different |OpAccessChain| + // instructions are visited in reverse order from which they are applied. + while (current_inst->opcode() == SpvOpAccessChain) { + for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) { + uint32_t element_index_id = current_inst->GetSingleWordInOperand(i); + components_in_reverse.push_back(element_index_id); + } + current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0)); + } + + // If the address in the load is not constructed from an |OpVariable| + // instruction followed by a series of |OpAccessChain| instructions, then + // return |nullptr| because we cannot identify the owner or access chain + // exactly. 
+ if (current_inst->opcode() != SpvOpVariable) { + return nullptr; + } + + // Build the memory object. Use |rbegin| and |rend| to put the access chain + // back in the correct order. + return std::unique_ptr( + new MemoryObject(current_inst, components_in_reverse.rbegin(), + components_in_reverse.rend())); +} + +std::unique_ptr +CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) { + assert(extract_inst->opcode() == SpvOpCompositeExtract && + "Expecting an OpCompositeExtract instruction."); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + + std::unique_ptr result = GetSourceObjectIfAny( + extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand)); + + if (result) { + analysis::Integer int_type(32, false); + const analysis::Type* uint32_type = + context()->get_type_mgr()->GetRegisteredType(&int_type); + + std::vector components; + // Convert the indices in the extract instruction to a series of ids that + // can be used by the |OpAccessChain| instruction. + for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) { + uint32_t index = extract_inst->GetSingleWordInOperand(i); + const analysis::Constant* index_const = + const_mgr->GetConstant(uint32_type, {index}); + components.push_back( + const_mgr->GetDefiningInstruction(index_const)->result_id()); + } + result->GetMember(components); + return result; + } + return nullptr; +} + +std::unique_ptr +CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct( + Instruction* conststruct_inst) { + assert(conststruct_inst->opcode() == SpvOpCompositeConstruct && + "Expecting an OpCompositeConstruct instruction."); + + // If every operand in the instruction are part of the same memory object, and + // are being combined in the same order, then the result is the same as the + // parent. 
+ + std::unique_ptr memory_object = + GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0)); + + if (!memory_object) { + return nullptr; + } + + if (!memory_object->IsMember()) { + return nullptr; + } + + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + const analysis::Constant* last_access = + const_mgr->FindDeclaredConstant(memory_object->AccessChain().back()); + if (!last_access || + (!last_access->AsIntConstant() && !last_access->AsNullConstant())) { + return nullptr; + } + + if (last_access->GetU32() != 0) { + return nullptr; + } + + memory_object->GetParent(); + + if (memory_object->GetNumberOfMembers() != + conststruct_inst->NumInOperands()) { + return nullptr; + } + + for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) { + std::unique_ptr member_object = + GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i)); + + if (!member_object) { + return nullptr; + } + + if (!member_object->IsMember()) { + return nullptr; + } + + if (!memory_object->Contains(member_object.get())) { + return nullptr; + } + + last_access = + const_mgr->FindDeclaredConstant(member_object->AccessChain().back()); + if (!last_access || !last_access->AsIntConstant()) { + return nullptr; + } + + if (last_access->GetU32() != i) { + return nullptr; + } + } + return memory_object; +} + +std::unique_ptr +CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) { + assert(insert_inst->opcode() == SpvOpCompositeInsert && + "Expecting an OpCompositeInsert instruction."); + + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id()); + + uint32_t number_of_elements = 0; + if (const analysis::Struct* struct_type = result_type->AsStruct()) { + number_of_elements = + 
static_cast(struct_type->element_types().size()); + } else if (const analysis::Array* array_type = result_type->AsArray()) { + const analysis::Constant* length_const = + const_mgr->FindDeclaredConstant(array_type->LengthId()); + assert(length_const->AsIntConstant()); + number_of_elements = length_const->AsIntConstant()->GetU32(); + } else if (const analysis::Vector* vector_type = result_type->AsVector()) { + number_of_elements = vector_type->element_count(); + } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) { + number_of_elements = matrix_type->element_count(); + } + + if (number_of_elements == 0) { + return nullptr; + } + + if (insert_inst->NumInOperands() != 3) { + return nullptr; + } + + if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) { + return nullptr; + } + + std::unique_ptr memory_object = + GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0)); + + if (!memory_object) { + return nullptr; + } + + if (!memory_object->IsMember()) { + return nullptr; + } + + const analysis::Constant* last_access = + const_mgr->FindDeclaredConstant(memory_object->AccessChain().back()); + if (!last_access || !last_access->AsIntConstant()) { + return nullptr; + } + + if (last_access->GetU32() != number_of_elements - 1) { + return nullptr; + } + + memory_object->GetParent(); + + Instruction* current_insert = + def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1)); + for (uint32_t i = number_of_elements - 1; i > 0; --i) { + if (current_insert->opcode() != SpvOpCompositeInsert) { + return nullptr; + } + + if (current_insert->NumInOperands() != 3) { + return nullptr; + } + + if (current_insert->GetSingleWordInOperand(2) != i - 1) { + return nullptr; + } + + std::unique_ptr current_memory_object = + GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0)); + + if (!current_memory_object) { + return nullptr; + } + + if (!current_memory_object->IsMember()) { + return nullptr; + } + + if 
(memory_object->AccessChain().size() + 1 != + current_memory_object->AccessChain().size()) { + return nullptr; + } + + if (!memory_object->Contains(current_memory_object.get())) { + return nullptr; + } + + const analysis::Constant* current_last_access = + const_mgr->FindDeclaredConstant( + current_memory_object->AccessChain().back()); + if (!current_last_access || !current_last_access->AsIntConstant()) { + return nullptr; + } + + if (current_last_access->GetU32() != i - 1) { + return nullptr; + } + current_insert = + def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1)); + } + + return memory_object; +} + +bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer(); + if (pointer_type) { + return pointer_type->pointee_type()->kind() == analysis::Type::kArray || + pointer_type->pointee_type()->kind() == analysis::Type::kImage; + } + return false; +} + +bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst, + uint32_t type_id) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + analysis::Type* type = type_mgr->GetType(type_id); + if (type->AsRuntimeArray()) { + return false; + } + + if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) { + // If the type is not an aggregate, then the desired type must be the + // same as the current type. No work to do, and we can do that. 
+ return true; + } + + return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr, + const_mgr, + type](Instruction* use, + uint32_t) { + switch (use->opcode()) { + case SpvOpLoad: { + analysis::Pointer* pointer_type = type->AsPointer(); + uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type()); + + if (new_type_id != use->type_id()) { + return CanUpdateUses(use, new_type_id); + } + return true; + } + case SpvOpAccessChain: { + analysis::Pointer* pointer_type = type->AsPointer(); + const analysis::Type* pointee_type = pointer_type->pointee_type(); + + std::vector access_chain; + for (uint32_t i = 1; i < use->NumInOperands(); ++i) { + const analysis::Constant* index_const = + const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i)); + if (index_const) { + access_chain.push_back(index_const->AsIntConstant()->GetU32()); + } else { + // Variable index means the type is a type where every element + // is the same type. Use element 0 to get the type. + access_chain.push_back(0); + } + } + + const analysis::Type* new_pointee_type = + type_mgr->GetMemberType(pointee_type, access_chain); + analysis::Pointer pointerTy(new_pointee_type, + pointer_type->storage_class()); + uint32_t new_pointer_type_id = + context()->get_type_mgr()->GetTypeInstruction(&pointerTy); + + if (new_pointer_type_id != use->type_id()) { + return CanUpdateUses(use, new_pointer_type_id); + } + return true; + } + case SpvOpCompositeExtract: { + std::vector access_chain; + for (uint32_t i = 1; i < use->NumInOperands(); ++i) { + access_chain.push_back(use->GetSingleWordInOperand(i)); + } + + const analysis::Type* new_type = + type_mgr->GetMemberType(type, access_chain); + uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type); + + if (new_type_id != use->type_id()) { + return CanUpdateUses(use, new_type_id); + } + return true; + } + case SpvOpStore: + // If needed, we can create an element-by-element copy to change the + // type of the value being stored. 
This way we can always handled + // stores. + return true; + case SpvOpImageTexelPointer: + case SpvOpName: + return true; + default: + return use->IsDecoration(); + } + }); +} +void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, + Instruction* new_ptr_inst) { + // TODO (s-perron): Keep the def-use manager up to date. Not done now because + // it can cause problems for the |ForEachUse| traversals. Can be use by + // keeping a list of instructions that need updating, and then updating them + // in |PropagateObject|. + + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + std::vector > uses; + def_use_mgr->ForEachUse(original_ptr_inst, + [&uses](Instruction* use, uint32_t index) { + uses.push_back({use, index}); + }); + + for (auto pair : uses) { + Instruction* use = pair.first; + uint32_t index = pair.second; + switch (use->opcode()) { + case SpvOpLoad: { + // Replace the actual use. + context()->ForgetUses(use); + use->SetOperand(index, {new_ptr_inst->result_id()}); + + // Update the type. + Instruction* pointer_type_inst = + def_use_mgr->GetDef(new_ptr_inst->type_id()); + uint32_t new_type_id = + pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx); + if (new_type_id != use->type_id()) { + use->SetResultType(new_type_id); + context()->AnalyzeUses(use); + UpdateUses(use, use); + } else { + context()->AnalyzeUses(use); + } + } break; + case SpvOpAccessChain: { + // Update the actual use. + context()->ForgetUses(use); + use->SetOperand(index, {new_ptr_inst->result_id()}); + + // Convert the ids on the OpAccessChain to indices that can be used to + // get the specific member. 
+ std::vector access_chain; + for (uint32_t i = 1; i < use->NumInOperands(); ++i) { + const analysis::Constant* index_const = + const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i)); + if (index_const) { + access_chain.push_back(index_const->AsIntConstant()->GetU32()); + } else { + // Variable index means the type is an type where every element + // is the same type. Use element 0 to get the type. + access_chain.push_back(0); + } + } + + Instruction* pointer_type_inst = + get_def_use_mgr()->GetDef(new_ptr_inst->type_id()); + + uint32_t new_pointee_type_id = GetMemberTypeId( + pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx), + access_chain); + + SpvStorageClass storage_class = static_cast( + pointer_type_inst->GetSingleWordInOperand( + kTypePointerStorageClassInIdx)); + + uint32_t new_pointer_type_id = + type_mgr->FindPointerToType(new_pointee_type_id, storage_class); + + if (new_pointer_type_id != use->type_id()) { + use->SetResultType(new_pointer_type_id); + context()->AnalyzeUses(use); + UpdateUses(use, use); + } else { + context()->AnalyzeUses(use); + } + } break; + case SpvOpCompositeExtract: { + // Update the actual use. + context()->ForgetUses(use); + use->SetOperand(index, {new_ptr_inst->result_id()}); + + uint32_t new_type_id = new_ptr_inst->type_id(); + std::vector access_chain; + for (uint32_t i = 1; i < use->NumInOperands(); ++i) { + access_chain.push_back(use->GetSingleWordInOperand(i)); + } + + new_type_id = GetMemberTypeId(new_type_id, access_chain); + + if (new_type_id != use->type_id()) { + use->SetResultType(new_type_id); + context()->AnalyzeUses(use); + UpdateUses(use, use); + } else { + context()->AnalyzeUses(use); + } + } break; + case SpvOpStore: + // If the use is the pointer, then it is the single store to that + // variable. We do not want to replace it. Instead, it will become + // dead after all of the loads are removed, and ADCE will get rid of it. 
+ // + // If the use is the object being stored, we will create a copy of the + // object turning it into the correct type. The copy is done by + // decomposing the object into the base type, which must be the same, + // and then rebuilding them. + if (index == 1) { + Instruction* target_pointer = def_use_mgr->GetDef( + use->GetSingleWordInOperand(kStorePointerInOperand)); + Instruction* pointer_type = + def_use_mgr->GetDef(target_pointer->type_id()); + uint32_t pointee_type_id = + pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx); + uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use); + + context()->ForgetUses(use); + use->SetInOperand(index, {copy}); + context()->AnalyzeUses(use); + } + break; + case SpvOpImageTexelPointer: + // We treat an OpImageTexelPointer as a load. The result type should + // always have the Image storage class, and should not need to be + // updated. + + // Replace the actual use. + context()->ForgetUses(use); + use->SetOperand(index, {new_ptr_inst->result_id()}); + context()->AnalyzeUses(use); + break; + default: + assert(false && "Don't know how to rewrite instruction"); + break; + } + } +} + +uint32_t CopyPropagateArrays::GenerateCopy(Instruction* object_inst, + uint32_t new_type_id, + Instruction* insertion_position) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + + uint32_t original_type_id = object_inst->type_id(); + if (original_type_id == new_type_id) { + return object_inst->result_id(); + } + + InstructionBuilder ir_builder( + context(), insertion_position, + IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse); + + analysis::Type* original_type = type_mgr->GetType(original_type_id); + analysis::Type* new_type = type_mgr->GetType(new_type_id); + + if (const analysis::Array* original_array_type = original_type->AsArray()) { + uint32_t original_element_type_id = + 
type_mgr->GetId(original_array_type->element_type()); + + analysis::Array* new_array_type = new_type->AsArray(); + assert(new_array_type != nullptr && "Can't copy an array to a non-array."); + uint32_t new_element_type_id = + type_mgr->GetId(new_array_type->element_type()); + + std::vector element_ids; + const analysis::Constant* length_const = + const_mgr->FindDeclaredConstant(original_array_type->LengthId()); + assert(length_const->AsIntConstant()); + uint32_t array_length = length_const->AsIntConstant()->GetU32(); + for (uint32_t i = 0; i < array_length; i++) { + Instruction* extract = ir_builder.AddCompositeExtract( + original_element_type_id, object_inst->result_id(), {i}); + element_ids.push_back( + GenerateCopy(extract, new_element_type_id, insertion_position)); + } + + return ir_builder.AddCompositeConstruct(new_type_id, element_ids) + ->result_id(); + } else if (const analysis::Struct* original_struct_type = + original_type->AsStruct()) { + analysis::Struct* new_struct_type = new_type->AsStruct(); + + const std::vector& original_types = + original_struct_type->element_types(); + const std::vector& new_types = + new_struct_type->element_types(); + std::vector element_ids; + for (uint32_t i = 0; i < original_types.size(); i++) { + Instruction* extract = ir_builder.AddCompositeExtract( + type_mgr->GetId(original_types[i]), object_inst->result_id(), {i}); + element_ids.push_back(GenerateCopy(extract, type_mgr->GetId(new_types[i]), + insertion_position)); + } + return ir_builder.AddCompositeConstruct(new_type_id, element_ids) + ->result_id(); + } else { + // If we do not have an aggregate type, then we have a problem. Either we + // found multiple instances of the same type, or we are copying to an + // incompatible type. Either way the code is illegal. + assert(false && + "Don't know how to copy this type. 
Code is likely illegal."); + } + return 0; +} + +uint32_t CopyPropagateArrays::GetMemberTypeId( + uint32_t id, const std::vector& access_chain) const { + for (uint32_t element_index : access_chain) { + Instruction* type_inst = get_def_use_mgr()->GetDef(id); + switch (type_inst->opcode()) { + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + case SpvOpTypeMatrix: + case SpvOpTypeVector: + id = type_inst->GetSingleWordInOperand(0); + break; + case SpvOpTypeStruct: + id = type_inst->GetSingleWordInOperand(element_index); + break; + default: + break; + } + assert(id != 0 && + "Tried to extract from an object where it cannot be done."); + } + return id; +} + +void CopyPropagateArrays::MemoryObject::GetMember( + const std::vector& access_chain) { + access_chain_.insert(access_chain_.end(), access_chain.begin(), + access_chain.end()); +} + +uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() { + IRContext* context = variable_inst_->context(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + + const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id()); + type = type->AsPointer()->pointee_type(); + + std::vector access_indices = GetAccessIds(); + type = type_mgr->GetMemberType(type, access_indices); + + if (const analysis::Struct* struct_type = type->AsStruct()) { + return static_cast(struct_type->element_types().size()); + } else if (const analysis::Array* array_type = type->AsArray()) { + const analysis::Constant* length_const = + context->get_constant_mgr()->FindDeclaredConstant( + array_type->LengthId()); + assert(length_const->AsIntConstant()); + return length_const->AsIntConstant()->GetU32(); + } else if (const analysis::Vector* vector_type = type->AsVector()) { + return vector_type->element_count(); + } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) { + return matrix_type->element_count(); + } else { + return 0; + } +} + +template +CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst, + 
iterator begin, iterator end) + : variable_inst_(var_inst), access_chain_(begin, end) {} + +std::vector CopyPropagateArrays::MemoryObject::GetAccessIds() const { + analysis::ConstantManager* const_mgr = + variable_inst_->context()->get_constant_mgr(); + + std::vector access_indices; + for (uint32_t id : AccessChain()) { + const analysis::Constant* element_index_const = + const_mgr->FindDeclaredConstant(id); + if (!element_index_const) { + access_indices.push_back(0); + } else { + assert(element_index_const->AsIntConstant()); + access_indices.push_back(element_index_const->AsIntConstant()->GetU32()); + } + } + return access_indices; +} + +bool CopyPropagateArrays::MemoryObject::Contains( + CopyPropagateArrays::MemoryObject* other) { + if (this->GetVariable() != other->GetVariable()) { + return false; + } + + if (AccessChain().size() > other->AccessChain().size()) { + return false; + } + + for (uint32_t i = 0; i < AccessChain().size(); i++) { + if (AccessChain()[i] != other->AccessChain()[i]) { + return false; + } + } + return true; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/copy_prop_arrays.h b/source/opt/copy_prop_arrays.h new file mode 100644 index 000000000..eb7cc68b7 --- /dev/null +++ b/source/opt/copy_prop_arrays.h @@ -0,0 +1,241 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_COPY_PROP_ARRAYS_H_ +#define SOURCE_OPT_COPY_PROP_ARRAYS_H_ + +#include +#include + +#include "source/opt/mem_pass.h" + +namespace spvtools { +namespace opt { + +// This pass implements a simple array copy propagation. It does not do a full +// array data flow. It looks for simple cases that meet the following +// conditions: +// +// 1) The source must never be stored to. +// 2) The target must be stored to exactly once. +// 3) The store to the target must be a store to the entire array, and be a +// copy of the entire source. +// 4) All loads of the target must be dominated by the store. +// +// The hard part is keeping all of the types correct. We do not want to +// have to do too large a search to update everything, which may not be +// possible, so we give up if we see any instruction that might be hard to +// update. + +class CopyPropagateArrays : public MemPass { + public: + const char* name() const override { return "copy-propagate-arrays"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisCFG | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisDecorations | + IRContext::kAnalysisDominatorAnalysis | IRContext::kAnalysisNameMap | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // The class used to identify a particular memory object. This memory object + // will be owned by a particular variable, meaning that the memory is part of + // that variable. It could be the entire variable or a member of the + // variable. + class MemoryObject { + public: + // Constructs a memory object that is owned by |var_inst|. The iterator + // |begin| and |end| traverse a container of integers that identify which + // member of |var_inst| this memory object will represent. These integers + // are interpreted the same way they would be in an |OpAccessChain| + // instruction.
+ template + MemoryObject(Instruction* var_inst, iterator begin, iterator end); + + // Change |this| to now point to the member identified by |access_chain| + // (starting from the current member). The elements in |access_chain| are + // interpreted the same as the indices in the |OpAccessChain| + // instruction. + void GetMember(const std::vector& access_chain); + + // Change |this| to now represent the first enclosing object to which it + // belongs. (Remove the last element off the access_chain). It is invalid + // to call this function if |this| does not represent a member of its owner. + void GetParent() { + assert(IsMember()); + access_chain_.pop_back(); + } + + // Returns true if |this| represents a member of its owner, and not the + // entire variable. + bool IsMember() const { return !access_chain_.empty(); } + + // Returns the number of members in the object represented by |this|. If + // |this| does not represent a composite type, the return value will be 0. + uint32_t GetNumberOfMembers(); + + // Returns the owning variable that the memory object is contained in. + Instruction* GetVariable() const { return variable_inst_; } + + // Returns a vector of integers that can be used to access the specific + // member that |this| represents starting from the owning variable. These + // values are to be interpreted the same way the indices are in an + // |OpAccessChain| instruction. + const std::vector& AccessChain() const { return access_chain_; } + + // Returns the type id of the pointer type that can be used to point to this + // memory object.
+ uint32_t GetPointerTypeId(const CopyPropagateArrays* pass) const { + analysis::DefUseManager* def_use_mgr = + GetVariable()->context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = + GetVariable()->context()->get_type_mgr(); + + Instruction* var_pointer_inst = + def_use_mgr->GetDef(GetVariable()->type_id()); + + uint32_t member_type_id = pass->GetMemberTypeId( + var_pointer_inst->GetSingleWordInOperand(1), GetAccessIds()); + + uint32_t member_pointer_type_id = type_mgr->FindPointerToType( + member_type_id, static_cast( + var_pointer_inst->GetSingleWordInOperand(0))); + return member_pointer_type_id; + } + + // Returns the storage class of the memory object. + SpvStorageClass GetStorageClass() const { + analysis::TypeManager* type_mgr = + GetVariable()->context()->get_type_mgr(); + const analysis::Pointer* pointer_type = + type_mgr->GetType(GetVariable()->type_id())->AsPointer(); + return pointer_type->storage_class(); + } + + // Returns true if |other| represents memory that is contained inside of the + // memory represented by |this|. + bool Contains(MemoryObject* other); + + private: + // The variable that owns this memory object. + Instruction* variable_inst_; + + // The access chain to reach the particular member the memory object + // represents. It should be interpreted the same way the indices in an + // |OpAccessChain| are interpreted. + std::vector access_chain_; + std::vector GetAccessIds() const; + }; + + // Returns the memory object being stored to |var_inst| in the store + // instruction |store_inst|, if one exists, that can be used in place of + // |var_inst| in all of the loads of |var_inst|. This code is conservative + // and only identifies very simple cases. If no such memory object can be + // found, the return value is |nullptr|. + std::unique_ptr FindSourceObjectIfPossible( + Instruction* var_inst, Instruction* store_inst); + + // Replaces all loads of |var_inst| with a load from |source| instead.
+ // |insertion_pos| is a position where it is possible to construct the + // address of |source| and also dominates all of the loads of |var_inst|. + void PropagateObject(Instruction* var_inst, MemoryObject* source, + Instruction* insertion_pos); + + // Returns true if all of the references to |ptr_inst| can be rewritten and + // are dominated by |store_inst|. + bool HasValidReferencesOnly(Instruction* ptr_inst, Instruction* store_inst); + + // Returns a memory object that at one time was equivalent to the value in + // |result|. If no such memory object exists, the return value is |nullptr|. + std::unique_ptr GetSourceObjectIfAny(uint32_t result); + + // Returns the memory object that is loaded by |load_inst|. If a memory + // object cannot be identified, the return value is |nullptr|. The opcode of + // |load_inst| must be |OpLoad|. + std::unique_ptr BuildMemoryObjectFromLoad( + Instruction* load_inst); + + // Returns the memory object that at some point was equivalent to the result + // of |extract_inst|. If a memory object cannot be identified, the return + // value is |nullptr|. The opcode of |extract_inst| must be + // |OpCompositeExtract|. + std::unique_ptr BuildMemoryObjectFromExtract( + Instruction* extract_inst); + + // Returns the memory object that at some point was equivalent to the result + // of |construct_inst|. If a memory object cannot be identified, the return + // value is |nullptr|. The opcode of |construct_inst| must be + // |OpCompositeConstruct|. + std::unique_ptr BuildMemoryObjectFromCompositeConstruct( + Instruction* conststruct_inst); + + // Returns the memory object that at some point was equivalent to the result + // of |insert_inst|. If a memory object cannot be identified, the return + // value is |nullptr|. The opcode of |insert_inst| must be + // |OpCompositeInsert|. This function looks for a series of + // |OpCompositeInsert| instructions that insert the elements one at a time in + // order from beginning to end.
+ std::unique_ptr BuildMemoryObjectFromInsert( + Instruction* insert_inst); + + // Return true if |type_id| is a pointer to an array or image type. + bool IsPointerToArrayType(uint32_t type_id); + + // Returns true if there are no stores using |ptr_inst| or something derived + // from it. + bool HasNoStores(Instruction* ptr_inst); + + // Creates an |OpAccessChain| instruction whose result is a pointer to the memory + // represented by |source|. The new instruction will be placed before + // |insertion_point|. |insertion_point| must be part of a function. Returns + // the new instruction. + Instruction* BuildNewAccessChain(Instruction* insertion_point, + MemoryObject* source) const; + + // Rewrites all uses of |original_ptr_inst| to use |new_pointer_inst| updating + // types of other instructions as needed. This function should not be called + // if |CanUpdateUses(original_ptr_inst, new_pointer_inst->type_id())| returns + // false. + void UpdateUses(Instruction* original_ptr_inst, + Instruction* new_pointer_inst); + + // Return true if |UpdateUses| is able to change all of the uses of + // |original_ptr_inst| to |type_id| and still have valid code. + bool CanUpdateUses(Instruction* original_ptr_inst, uint32_t type_id); + + // Returns the id whose value is the same as |object_to_copy| except its type + // is |new_type_id|. Any instructions needed to generate this value will be + // inserted before |insertion_position|. + uint32_t GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id, + Instruction* insertion_position); + + // Returns a store to |var_inst| that writes to the entire variable, and is + // the only store that does so. Note it does not look through OpAccessChain + // instruction, so partial stores are not considered. + Instruction* FindStoreInstruction(const Instruction* var_inst) const; + + // Return the type id of the member of the type |id| accessed using + // |access_chain|.
The elements of |access_chain| are to be interpreted the + // same way the indexes are used in an |OpCompositeExtract| instruction. + uint32_t GetMemberTypeId(uint32_t id, + const std::vector& access_chain) const; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_COPY_PROP_ARRAYS_H_ diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp new file mode 100644 index 000000000..98935361f --- /dev/null +++ b/source/opt/dead_branch_elim_pass.cpp @@ -0,0 +1,542 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "source/opt/dead_branch_elim_pass.h" + +#include +#include +#include + +#include "source/cfa.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" +#include "source/opt/struct_cfg_analysis.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +namespace { + +const uint32_t kBranchCondTrueLabIdInIdx = 1; +const uint32_t kBranchCondFalseLabIdInIdx = 2; + +} // anonymous namespace + +bool DeadBranchElimPass::GetConstCondition(uint32_t condId, bool* condVal) { + bool condIsConst; + Instruction* cInst = get_def_use_mgr()->GetDef(condId); + switch (cInst->opcode()) { + case SpvOpConstantFalse: { + *condVal = false; + condIsConst = true; + } break; + case SpvOpConstantTrue: { + *condVal = true; + condIsConst = true; + } break; + case SpvOpLogicalNot: { + bool negVal; + condIsConst = + GetConstCondition(cInst->GetSingleWordInOperand(0), &negVal); + if (condIsConst) *condVal = !negVal; + } break; + default: { condIsConst = false; } break; + } + return condIsConst; +} + +bool DeadBranchElimPass::GetConstInteger(uint32_t selId, uint32_t* selVal) { + Instruction* sInst = get_def_use_mgr()->GetDef(selId); + uint32_t typeId = sInst->type_id(); + Instruction* typeInst = get_def_use_mgr()->GetDef(typeId); + if (!typeInst || (typeInst->opcode() != SpvOpTypeInt)) return false; + // TODO(greg-lunarg): Support non-32 bit ints + if (typeInst->GetSingleWordInOperand(0) != 32) return false; + if (sInst->opcode() == SpvOpConstant) { + *selVal = sInst->GetSingleWordInOperand(0); + return true; + } else if (sInst->opcode() == SpvOpConstantNull) { + *selVal = 0; + return true; + } + return false; +} + +void DeadBranchElimPass::AddBranch(uint32_t labelId, BasicBlock* bp) { + assert(get_def_use_mgr()->GetDef(labelId) != nullptr); + std::unique_ptr newBranch( + new Instruction(context(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); + context()->AnalyzeDefUse(&*newBranch); + 
context()->set_instr_block(&*newBranch, bp); + bp->AddInstruction(std::move(newBranch)); +} + +BasicBlock* DeadBranchElimPass::GetParentBlock(uint32_t id) { + return context()->get_instr_block(get_def_use_mgr()->GetDef(id)); +} + +bool DeadBranchElimPass::MarkLiveBlocks( + Function* func, std::unordered_set* live_blocks) { + StructuredCFGAnalysis* cfgAnalysis = context()->GetStructuredCFGAnalysis(); + + std::unordered_set continues; + std::vector stack; + stack.push_back(&*func->begin()); + bool modified = false; + while (!stack.empty()) { + BasicBlock* block = stack.back(); + stack.pop_back(); + + // Live blocks doubles as visited set. + if (!live_blocks->insert(block).second) continue; + + uint32_t cont_id = block->ContinueBlockIdIfAny(); + if (cont_id != 0) continues.insert(GetParentBlock(cont_id)); + + Instruction* terminator = block->terminator(); + uint32_t live_lab_id = 0; + // Check if the terminator has a single valid successor. + if (terminator->opcode() == SpvOpBranchConditional) { + bool condVal; + if (GetConstCondition(terminator->GetSingleWordInOperand(0u), &condVal)) { + live_lab_id = terminator->GetSingleWordInOperand( + condVal ? kBranchCondTrueLabIdInIdx : kBranchCondFalseLabIdInIdx); + } + } else if (terminator->opcode() == SpvOpSwitch) { + uint32_t sel_val; + if (GetConstInteger(terminator->GetSingleWordInOperand(0u), &sel_val)) { + // Search switch operands for selector value, set live_lab_id to + // corresponding label, use default if not found. + uint32_t icnt = 0; + uint32_t case_val; + terminator->WhileEachInOperand( + [&icnt, &case_val, &sel_val, &live_lab_id](const uint32_t* idp) { + if (icnt == 1) { + // Start with default label. + live_lab_id = *idp; + } else if (icnt > 1) { + if (icnt % 2 == 0) { + case_val = *idp; + } else { + if (case_val == sel_val) { + live_lab_id = *idp; + return false; + } + } + } + ++icnt; + return true; + }); + } + } + + // Don't simplify branches of continue blocks. 
A path from the continue to + // the header is required. + // TODO(alan-baker): They can be simplified iff there remains a path to the + // backedge. Structured control flow should guarantee one path hits the + // backedge, but I've removed the requirement for structured control flow + // from this pass. + bool simplify = live_lab_id != 0 && !continues.count(block); + + if (simplify) { + modified = true; + // Replace with unconditional branch. + // Remove the merge instruction if it is a selection merge. + AddBranch(live_lab_id, block); + context()->KillInst(terminator); + Instruction* mergeInst = block->GetMergeInst(); + if (mergeInst && mergeInst->opcode() == SpvOpSelectionMerge) { + Instruction* first_break = FindFirstExitFromSelectionMerge( + live_lab_id, mergeInst->GetSingleWordInOperand(0), + cfgAnalysis->LoopMergeBlock(live_lab_id), + cfgAnalysis->LoopContinueBlock(live_lab_id)); + if (first_break == nullptr) { + context()->KillInst(mergeInst); + } else { + mergeInst->RemoveFromList(); + first_break->InsertBefore(std::unique_ptr(mergeInst)); + context()->set_instr_block(mergeInst, + context()->get_instr_block(first_break)); + } + } + stack.push_back(GetParentBlock(live_lab_id)); + } else { + // All successors are live. 
+ const auto* const_block = block; + const_block->ForEachSuccessorLabel([&stack, this](const uint32_t label) { + stack.push_back(GetParentBlock(label)); + }); + } + } + + return modified; +} + +void DeadBranchElimPass::MarkUnreachableStructuredTargets( + const std::unordered_set& live_blocks, + std::unordered_set* unreachable_merges, + std::unordered_map* unreachable_continues) { + for (auto block : live_blocks) { + if (auto merge_id = block->MergeBlockIdIfAny()) { + BasicBlock* merge_block = GetParentBlock(merge_id); + if (!live_blocks.count(merge_block)) { + unreachable_merges->insert(merge_block); + } + if (auto cont_id = block->ContinueBlockIdIfAny()) { + BasicBlock* cont_block = GetParentBlock(cont_id); + if (!live_blocks.count(cont_block)) { + (*unreachable_continues)[cont_block] = block; + } + } + } + } +} + +bool DeadBranchElimPass::FixPhiNodesInLiveBlocks( + Function* func, const std::unordered_set& live_blocks, + const std::unordered_map& unreachable_continues) { + bool modified = false; + for (auto& block : *func) { + if (live_blocks.count(&block)) { + for (auto iter = block.begin(); iter != block.end();) { + if (iter->opcode() != SpvOpPhi) { + break; + } + + bool changed = false; + bool backedge_added = false; + Instruction* inst = &*iter; + std::vector operands; + // Build a complete set of operands (not just input operands). Start + // with type and result id operands. + operands.push_back(inst->GetOperand(0u)); + operands.push_back(inst->GetOperand(1u)); + // Iterate through the incoming labels and determine which to keep + // and/or modify. If there in an unreachable continue block, there will + // be an edge from that block to the header. We need to keep it to + // maintain the structured control flow. If the header has more that 2 + // incoming edges, then the OpPhi must have an entry for that edge. + // However, if there is only one other incoming edge, the OpPhi can be + // eliminated. 
+ for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) { + BasicBlock* inc = GetParentBlock(inst->GetSingleWordInOperand(i)); + auto cont_iter = unreachable_continues.find(inc); + if (cont_iter != unreachable_continues.end() && + cont_iter->second == &block && inst->NumInOperands() > 4) { + if (get_def_use_mgr() + ->GetDef(inst->GetSingleWordInOperand(i - 1)) + ->opcode() == SpvOpUndef) { + // Already undef incoming value, no change necessary. + operands.push_back(inst->GetInOperand(i - 1)); + operands.push_back(inst->GetInOperand(i)); + backedge_added = true; + } else { + // Replace incoming value with undef if this phi exists in the + // loop header. Otherwise, this edge is not live since the + // unreachable continue block will be replaced with an + // unconditional branch to the header only. + operands.emplace_back( + SPV_OPERAND_TYPE_ID, + std::initializer_list{Type2Undef(inst->type_id())}); + operands.push_back(inst->GetInOperand(i)); + changed = true; + backedge_added = true; + } + } else if (live_blocks.count(inc) && inc->IsSuccessor(&block)) { + // Keep live incoming edge. + operands.push_back(inst->GetInOperand(i - 1)); + operands.push_back(inst->GetInOperand(i)); + } else { + // Remove incoming edge. + changed = true; + } + } + + if (changed) { + modified = true; + uint32_t continue_id = block.ContinueBlockIdIfAny(); + if (!backedge_added && continue_id != 0 && + unreachable_continues.count(GetParentBlock(continue_id)) && + operands.size() > 4) { + // Changed the backedge to branch from the continue block instead + // of a successor of the continue block. Add an entry to the phi to + // provide an undef for the continue block. Since the successor of + // the continue must also be unreachable (dominated by the continue + // block), any entry for the original backedge has been removed + // from the phi operands. 
+ operands.emplace_back( + SPV_OPERAND_TYPE_ID, + std::initializer_list{Type2Undef(inst->type_id())}); + operands.emplace_back(SPV_OPERAND_TYPE_ID, + std::initializer_list{continue_id}); + } + + // Either replace the phi with a single value or rebuild the phi out + // of |operands|. + // + // We always have type and result id operands. So this phi has a + // single source if there are two more operands beyond those. + if (operands.size() == 4) { + // First input data operands is at index 2. + uint32_t replId = operands[2u].words[0]; + context()->ReplaceAllUsesWith(inst->result_id(), replId); + iter = context()->KillInst(&*inst); + } else { + // We've rewritten the operands, so first instruct the def/use + // manager to forget uses in the phi before we replace them. After + // replacing operands update the def/use manager by re-analyzing + // the used ids in this phi. + get_def_use_mgr()->EraseUseRecordsOfOperandIds(inst); + inst->ReplaceOperands(operands); + get_def_use_mgr()->AnalyzeInstUse(inst); + ++iter; + } + } else { + ++iter; + } + } + } + } + + return modified; +} + +bool DeadBranchElimPass::EraseDeadBlocks( + Function* func, const std::unordered_set& live_blocks, + const std::unordered_set& unreachable_merges, + const std::unordered_map& unreachable_continues) { + bool modified = false; + for (auto ebi = func->begin(); ebi != func->end();) { + if (unreachable_merges.count(&*ebi)) { + if (ebi->begin() != ebi->tail() || + ebi->terminator()->opcode() != SpvOpUnreachable) { + // Make unreachable, but leave the label. + KillAllInsts(&*ebi, false); + // Add unreachable terminator. 
+ ebi->AddInstruction( + MakeUnique(context(), SpvOpUnreachable, 0, 0, + std::initializer_list{})); + context()->AnalyzeUses(ebi->terminator()); + context()->set_instr_block(ebi->terminator(), &*ebi); + modified = true; + } + ++ebi; + } else if (unreachable_continues.count(&*ebi)) { + uint32_t cont_id = unreachable_continues.find(&*ebi)->second->id(); + if (ebi->begin() != ebi->tail() || + ebi->terminator()->opcode() != SpvOpBranch || + ebi->terminator()->GetSingleWordInOperand(0u) != cont_id) { + // Make unreachable, but leave the label. + KillAllInsts(&*ebi, false); + // Add unconditional branch to header. + assert(unreachable_continues.count(&*ebi)); + ebi->AddInstruction(MakeUnique( + context(), SpvOpBranch, 0, 0, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {cont_id}}})); + get_def_use_mgr()->AnalyzeInstUse(&*ebi->tail()); + context()->set_instr_block(&*ebi->tail(), &*ebi); + modified = true; + } + ++ebi; + } else if (!live_blocks.count(&*ebi)) { + // Kill this block. + KillAllInsts(&*ebi); + ebi = ebi.Erase(); + modified = true; + } else { + ++ebi; + } + } + + return modified; +} + +bool DeadBranchElimPass::EliminateDeadBranches(Function* func) { + bool modified = false; + std::unordered_set live_blocks; + modified |= MarkLiveBlocks(func, &live_blocks); + + std::unordered_set unreachable_merges; + std::unordered_map unreachable_continues; + MarkUnreachableStructuredTargets(live_blocks, &unreachable_merges, + &unreachable_continues); + modified |= FixPhiNodesInLiveBlocks(func, live_blocks, unreachable_continues); + modified |= EraseDeadBlocks(func, live_blocks, unreachable_merges, + unreachable_continues); + + return modified; +} + +void DeadBranchElimPass::FixBlockOrder() { + context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG | + IRContext::kAnalysisDominatorAnalysis); + // Reorders blocks according to DFS of dominator tree. 
+ ProcessFunction reorder_dominators = [this](Function* function) { + DominatorAnalysis* dominators = context()->GetDominatorAnalysis(function); + std::vector blocks; + for (auto iter = dominators->GetDomTree().begin(); + iter != dominators->GetDomTree().end(); ++iter) { + if (iter->id() != 0) { + blocks.push_back(iter->bb_); + } + } + for (uint32_t i = 1; i < blocks.size(); ++i) { + function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]); + } + return true; + }; + + // Reorders blocks according to structured order. + ProcessFunction reorder_structured = [this](Function* function) { + std::list order; + context()->cfg()->ComputeStructuredOrder(function, &*function->begin(), + &order); + std::vector blocks; + for (auto block : order) { + blocks.push_back(block); + } + for (uint32_t i = 1; i < blocks.size(); ++i) { + function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]); + } + return true; + }; + + // Structured order is more intuitive so use it where possible. + if (context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) { + context()->ProcessReachableCallTree(reorder_structured); + } else { + context()->ProcessReachableCallTree(reorder_dominators); + } +} + +Pass::Status DeadBranchElimPass::Process() { + // Do not process if module contains OpGroupDecorate. Additional + // support required in KillNamesAndDecorates(). + // TODO(greg-lunarg): Add support for OpGroupDecorate + for (auto& ai : get_module()->annotations()) + if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; + // Process all entry point functions + ProcessFunction pfn = [this](Function* fp) { + return EliminateDeadBranches(fp); + }; + bool modified = context()->ProcessReachableCallTree(pfn); + if (modified) FixBlockOrder(); + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +Instruction* DeadBranchElimPass::FindFirstExitFromSelectionMerge( + uint32_t start_block_id, uint32_t merge_block_id, uint32_t loop_merge_id, + uint32_t loop_continue_id) { + // To find the "first" exit, we follow branches looking for a conditional + // branch that is not in a nested construct and is not the header of a new + // construct. We follow the control flow from |start_block_id| to find the + // first one. + while (start_block_id != merge_block_id && start_block_id != loop_merge_id && + start_block_id != loop_continue_id) { + BasicBlock* start_block = context()->get_instr_block(start_block_id); + Instruction* branch = start_block->terminator(); + uint32_t next_block_id = 0; + switch (branch->opcode()) { + case SpvOpBranchConditional: + next_block_id = start_block->MergeBlockIdIfAny(); + if (next_block_id == 0) { + // If a possible target is the |loop_merge_id| or |loop_continue_id|, + // which are not the current merge node, then we continue the search + // with the other target. + for (uint32_t i = 1; i < 3; i++) { + if (branch->GetSingleWordInOperand(i) == loop_merge_id && + loop_merge_id != merge_block_id) { + next_block_id = branch->GetSingleWordInOperand(3 - i); + break; + } + if (branch->GetSingleWordInOperand(i) == loop_continue_id && + loop_continue_id != merge_block_id) { + next_block_id = branch->GetSingleWordInOperand(3 - i); + break; + } + } + + if (next_block_id == 0) { + return branch; + } + } + break; + case SpvOpSwitch: + next_block_id = start_block->MergeBlockIdIfAny(); + if (next_block_id == 0) { + // A switch with no merge instructions can have at most 4 targets: + // a. |merge_block_id| + // b. |loop_merge_id| + // c. |loop_continue_id| + // d. 1 block inside the current region. + // + // This leads to a number of cases of what to do. + // + // 1. Does not jump to a block inside of the current construct. 
In + // this case, there is no conditional break, so we should return + // |nullptr|. + // + // 2. Jumps to |merge_block_id| and a block inside the current + // construct. In this case, this branch conditionally breaks to the + // end of the current construct, so return the current branch. + // + // 3. Otherwise, this branch may break, but not to the current merge + // block. So we continue with the block that is inside the loop. + + bool found_break = false; + for (uint32_t i = 1; i < branch->NumInOperands(); i += 2) { + uint32_t target = branch->GetSingleWordInOperand(i); + if (target == merge_block_id) { + found_break = true; + } else if (target != loop_merge_id && target != loop_continue_id) { + next_block_id = branch->GetSingleWordInOperand(i); + } + } + + if (next_block_id == 0) { + // Case 1. + return nullptr; + } + + if (found_break) { + // Case 2. + return branch; + } + + // The fall through is case 3. + } + break; + case SpvOpBranch: + // Need to check if this is the header of a loop nested in the + // selection construct. + next_block_id = start_block->MergeBlockIdIfAny(); + if (next_block_id == 0) { + next_block_id = branch->GetSingleWordInOperand(0); + } + break; + default: + return nullptr; + } + start_block_id = next_block_id; + } + return nullptr; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/dead_branch_elim_pass.h b/source/opt/dead_branch_elim_pass.h new file mode 100644 index 000000000..4a1b6b49b --- /dev/null +++ b/source/opt/dead_branch_elim_pass.h @@ -0,0 +1,156 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_DEAD_BRANCH_ELIM_PASS_H_ +#define SOURCE_OPT_DEAD_BRANCH_ELIM_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class DeadBranchElimPass : public MemPass { + using cbb_ptr = const BasicBlock*; + + public: + DeadBranchElimPass() = default; + + const char* name() const override { return "eliminate-dead-branches"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // If |condId| is boolean constant, return conditional value in |condVal| and + // return true, otherwise return false. + bool GetConstCondition(uint32_t condId, bool* condVal); + + // If |valId| is a 32-bit integer constant, return value via |value| and + // return true, otherwise return false. + bool GetConstInteger(uint32_t valId, uint32_t* value); + + // Add branch to |labelId| to end of block |bp|. + void AddBranch(uint32_t labelId, BasicBlock* bp); + + // For function |func|, look for BranchConditionals with constant condition + // and convert to a Branch to the indicated label. Delete resulting dead + // blocks. Note some such branches and blocks may be left to avoid creating + // invalid control flow. 
+ // TODO(greg-lunarg): Remove remaining constant conditional branches and dead + // blocks. + bool EliminateDeadBranches(Function* func); + + // Returns the basic block containing |id|. + // Note: this pass only requires correct instruction block mappings for the + // input. This pass does not preserve the block mapping, so it is not kept + // up-to-date during processing. + BasicBlock* GetParentBlock(uint32_t id); + + // Marks live blocks reachable from the entry of |func|. Simplifies constant + // branches and switches as it proceeds, to limit the number of live blocks. + // It is careful not to eliminate backedges even if they are dead, but the + // header is live. Likewise, unreachable merge blocks named in live merge + // instruction must be retained (though they may be clobbered). + bool MarkLiveBlocks(Function* func, + std::unordered_set* live_blocks); + + // Checks for unreachable merge and continue blocks with live headers; those + // blocks must be retained. Continues are tracked separately so that a live + // phi can be updated to take an undef value from any of its predecessors + // that are unreachable continues. + // + // |unreachable_continues| maps the id of an unreachable continue target to + // the header block that declares it. + void MarkUnreachableStructuredTargets( + const std::unordered_set& live_blocks, + std::unordered_set* unreachable_merges, + std::unordered_map* unreachable_continues); + + // Fix phis in reachable blocks so that only live (or unremovable) incoming + // edges are present. If the block now only has a single live incoming edge, + // remove the phi and replace its uses with its data input. If the single + // remaining incoming edge is from the phi itself, then the phi is in an + // unreachable single block loop. Either the block is dead and will be + // removed, or it's reachable from an unreachable continue target. 
In the + // latter case that continue target block will be collapsed into a block that + // only branches back to its header and we'll eliminate the block with the + // phi. + // + // |unreachable_continues| maps continue targets that cannot be reached to + // merge instruction that declares them. + bool FixPhiNodesInLiveBlocks( + Function* func, const std::unordered_set& live_blocks, + const std::unordered_map& + unreachable_continues); + + // Erases dead blocks. Any block captured in |unreachable_merges| or + // |unreachable_continues| is a dead block that is required to remain due to + // a live merge instruction in the corresponding header. These blocks will + // have their instructions clobbered and will become a label and terminator. + // Unreachable merge blocks are terminated by OpUnreachable, while + // unreachable continue blocks are terminated by an unconditional branch to + // the header. Otherwise, blocks are dead if not explicitly captured in + // |live_blocks| and are totally removed. + // + // |unreachable_continues| maps continue targets that cannot be reached to + // corresponding header block that declares them. + bool EraseDeadBlocks( + Function* func, const std::unordered_set& live_blocks, + const std::unordered_set& unreachable_merges, + const std::unordered_map& + unreachable_continues); + + // Reorders blocks in reachable functions so that they satisfy dominator + // block ordering rules. + void FixBlockOrder(); + + // Return the first branch instruction that is a conditional branch to + // |merge_block_id|. Returns |nullptr| if no such branch exists. If there are + // multiple such branches, the first one is the one that would be executed + // first when running the code. That is, the one that dominates all of the + // others. + // + // |start_block_id| must be a block whose innermost containing merge construct + // has |merge_block_id| as the merge block. 
+ // + // |loop_merge_id| and |loop_continue_id| are the merge and continue block ids + // of the innermost loop containing |start_block_id|. + Instruction* FindFirstExitFromSelectionMerge(uint32_t start_block_id, + uint32_t merge_block_id, + uint32_t loop_merge_id, + uint32_t loop_continue_id); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DEAD_BRANCH_ELIM_PASS_H_ diff --git a/source/opt/dead_insert_elim_pass.cpp b/source/opt/dead_insert_elim_pass.cpp new file mode 100644 index 000000000..7d5634383 --- /dev/null +++ b/source/opt/dead_insert_elim_pass.cpp @@ -0,0 +1,263 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/dead_insert_elim_pass.h" + +#include "source/opt/composite.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" +#include "spirv/1.2/GLSL.std.450.h" + +namespace spvtools { +namespace opt { + +namespace { + +const uint32_t kTypeVectorCountInIdx = 1; +const uint32_t kTypeMatrixCountInIdx = 1; +const uint32_t kTypeArrayLengthIdInIdx = 1; +const uint32_t kTypeIntWidthInIdx = 0; +const uint32_t kConstantValueInIdx = 0; +const uint32_t kInsertObjectIdInIdx = 0; +const uint32_t kInsertCompositeIdInIdx = 1; + +} // anonymous namespace + +uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) { + switch (typeInst->opcode()) { + case SpvOpTypeVector: { + return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx); + } break; + case SpvOpTypeMatrix: { + return typeInst->GetSingleWordInOperand(kTypeMatrixCountInIdx); + } break; + case SpvOpTypeArray: { + uint32_t lenId = + typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx); + Instruction* lenInst = get_def_use_mgr()->GetDef(lenId); + if (lenInst->opcode() != SpvOpConstant) return 0; + uint32_t lenTypeId = lenInst->type_id(); + Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId); + // TODO(greg-lunarg): Support non-32-bit array length + if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) + return 0; + return lenInst->GetSingleWordInOperand(kConstantValueInIdx); + } break; + case SpvOpTypeStruct: { + return typeInst->NumInOperands(); + } break; + default: { return 0; } break; + } +} + +void DeadInsertElimPass::MarkInsertChain( + Instruction* insertChain, std::vector* pExtIndices, + uint32_t extOffset, std::unordered_set* visited_phis) { + // Not currently optimizing array inserts. 
+ Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id()); + if (typeInst->opcode() == SpvOpTypeArray) return; + // Insert chains are only composed of inserts and phis + if (insertChain->opcode() != SpvOpCompositeInsert && + insertChain->opcode() != SpvOpPhi) + return; + // If extract indices are empty, mark all subcomponents if type + // is constant length. + if (pExtIndices == nullptr) { + uint32_t cnum = NumComponents(typeInst); + if (cnum > 0) { + std::vector extIndices; + for (uint32_t i = 0; i < cnum; i++) { + extIndices.clear(); + extIndices.push_back(i); + std::unordered_set sub_visited_phis; + MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis); + } + return; + } + } + Instruction* insInst = insertChain; + while (insInst->opcode() == SpvOpCompositeInsert) { + // If no extract indices, mark insert and inserted object (which might + // also be an insert chain) and continue up the chain though the input + // composite. + // + // Note: We mark inserted objects in this function (rather than in + // EliminateDeadInsertsOnePass) because in some cases, we can do it + // more accurately here. + if (pExtIndices == nullptr) { + liveInserts_.insert(insInst->result_id()); + uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); + // If extract indices match insert, we are done. Mark insert and + // inserted object. 
+ } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) { + liveInserts_.insert(insInst->result_id()); + uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); + break; + // If non-matching intersection, mark insert + } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) { + liveInserts_.insert(insInst->result_id()); + // If more extract indices than insert, we are done. Use remaining + // extract indices to mark inserted object. + uint32_t numInsertIndices = insInst->NumInOperands() - 2; + if (pExtIndices->size() - extOffset > numInsertIndices) { + uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices, + extOffset + numInsertIndices, &obj_visited_phis); + break; + // If fewer extract indices than insert, also mark inserted object and + // continue up chain. + } else { + uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); + std::unordered_set obj_visited_phis; + MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, + &obj_visited_phis); + } + } + // Get next insert in chain + const uint32_t compId = + insInst->GetSingleWordInOperand(kInsertCompositeIdInIdx); + insInst = get_def_use_mgr()->GetDef(compId); + } + // If insert chain ended with phi, do recursive call on each operand + if (insInst->opcode() != SpvOpPhi) return; + // Mark phi visited to prevent potential infinite loop. If phi is already + // visited, return to avoid infinite loop. + if (visited_phis->count(insInst->result_id()) != 0) return; + visited_phis->insert(insInst->result_id()); + + // Phis may have duplicate inputs values for different edges, prune incoming + // ids lists before recursing. 
+ std::vector ids; + for (uint32_t i = 0; i < insInst->NumInOperands(); i += 2) { + ids.push_back(insInst->GetSingleWordInOperand(i)); + } + std::sort(ids.begin(), ids.end()); + auto new_end = std::unique(ids.begin(), ids.end()); + for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) { + Instruction* pi = get_def_use_mgr()->GetDef(*id_iter); + MarkInsertChain(pi, pExtIndices, extOffset, visited_phis); + } +} + +bool DeadInsertElimPass::EliminateDeadInserts(Function* func) { + bool modified = false; + bool lastmodified = true; + // Each pass can delete dead instructions, thus potentially revealing + // new dead insertions ie insertions with no uses. + while (lastmodified) { + lastmodified = EliminateDeadInsertsOnePass(func); + modified |= lastmodified; + } + return modified; +} + +bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) { + bool modified = false; + liveInserts_.clear(); + visitedPhis_.clear(); + // Mark all live inserts + for (auto bi = func->begin(); bi != func->end(); ++bi) { + for (auto ii = bi->begin(); ii != bi->end(); ++ii) { + // Only process Inserts and composite Phis + SpvOp op = ii->opcode(); + Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id()); + if (op != SpvOpCompositeInsert && + (op != SpvOpPhi || !spvOpcodeIsComposite(typeInst->opcode()))) + continue; + // The marking algorithm can be expensive for large arrays and the + // efficacy of eliminating dead inserts into arrays is questionable. + // Skip optimizing array inserts for now. Just mark them live. 
+ // TODO(greg-lunarg): Eliminate dead array inserts + if (op == SpvOpCompositeInsert) { + if (typeInst->opcode() == SpvOpTypeArray) { + liveInserts_.insert(ii->result_id()); + continue; + } + } + const uint32_t id = ii->result_id(); + get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) { + switch (user->opcode()) { + case SpvOpCompositeInsert: + case SpvOpPhi: + // Use by insert or phi does not initiate marking + break; + case SpvOpCompositeExtract: { + // Capture extract indices + std::vector extIndices; + uint32_t icnt = 0; + user->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) { + if (icnt > 0) extIndices.push_back(*idp); + ++icnt; + }); + // Mark all inserts in chain that intersect with extract + std::unordered_set visited_phis; + MarkInsertChain(&*ii, &extIndices, 0, &visited_phis); + } break; + default: { + // Mark inserts in chain for all components + MarkInsertChain(&*ii, nullptr, 0, nullptr); + } break; + } + }); + } + } + // Find and disconnect dead inserts + std::vector dead_instructions; + for (auto bi = func->begin(); bi != func->end(); ++bi) { + for (auto ii = bi->begin(); ii != bi->end(); ++ii) { + if (ii->opcode() != SpvOpCompositeInsert) continue; + const uint32_t id = ii->result_id(); + if (liveInserts_.find(id) != liveInserts_.end()) continue; + const uint32_t replId = + ii->GetSingleWordInOperand(kInsertCompositeIdInIdx); + (void)context()->ReplaceAllUsesWith(id, replId); + dead_instructions.push_back(&*ii); + modified = true; + } + } + // DCE dead inserts + while (!dead_instructions.empty()) { + Instruction* inst = dead_instructions.back(); + dead_instructions.pop_back(); + DCEInst(inst, [&dead_instructions](Instruction* other_inst) { + auto i = std::find(dead_instructions.begin(), dead_instructions.end(), + other_inst); + if (i != dead_instructions.end()) { + dead_instructions.erase(i); + } + }); + } + return modified; +} + +Pass::Status DeadInsertElimPass::Process() { + // Process all entry point functions. 
+ ProcessFunction pfn = [this](Function* fp) { + return EliminateDeadInserts(fp); + }; + bool modified = context()->ProcessEntryPointCallTree(pfn); + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/dead_insert_elim_pass.h b/source/opt/dead_insert_elim_pass.h new file mode 100644 index 000000000..01f12bb04 --- /dev/null +++ b/source/opt/dead_insert_elim_pass.h @@ -0,0 +1,90 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_DEAD_INSERT_ELIM_PASS_H_ +#define SOURCE_OPT_DEAD_INSERT_ELIM_PASS_H_ + +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. 
+class DeadInsertElimPass : public MemPass { + public: + DeadInsertElimPass() = default; + + const char* name() const override { return "eliminate-dead-inserts"; } + Status Process() override; + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + // Return the number of subcomponents in the composite type |typeId|. + // Return 0 if not a composite type or number of components is not a + // 32-bit constant. + uint32_t NumComponents(Instruction* typeInst); + + // Mark all inserts in instruction chain ending at |insertChain| with + // indices that intersect with extract indices |extIndices| starting with + // index at |extOffset|. Chains are composed solely of Inserts and Phis. + // Mark all inserts in chain if |extIndices| is nullptr. + void MarkInsertChain(Instruction* insertChain, + std::vector* extIndices, uint32_t extOffset, + std::unordered_set* visited_phis); + + // Perform EliminateDeadInsertsOnePass(|func|) until no modification is + // made. Return true if modified. + bool EliminateDeadInserts(Function* func); + + // DCE all dead struct, matrix and vector inserts in |func|. An insert is + // dead if the value it inserts is never used. Replace any reference to the + // insert with its original composite. Return true if modified. Dead inserts + // in dependence cycles are not currently eliminated. Dead inserts into + // arrays are not currently eliminated. + bool EliminateDeadInsertsOnePass(Function* func); + + // Return true if all extensions in this module are allowed by this pass. 
+ bool AllExtensionsSupported() const; + + // Live inserts + std::unordered_set liveInserts_; + + // Visited phis as insert chain is traversed; used to avoid infinite loop + std::unordered_map visitedPhis_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DEAD_INSERT_ELIM_PASS_H_ diff --git a/source/opt/dead_variable_elimination.cpp b/source/opt/dead_variable_elimination.cpp new file mode 100644 index 000000000..283710684 --- /dev/null +++ b/source/opt/dead_variable_elimination.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/dead_variable_elimination.h" + +#include + +#include "source/opt/ir_context.h" +#include "source/opt/reflect.h" + +namespace spvtools { +namespace opt { + +// This optimization removes global variables that are not needed because they +// are definitely not accessed. +Pass::Status DeadVariableElimination::Process() { + // The algorithm will compute the reference count for every global variable. + // Anything with a reference count of 0 will then be deleted. For variables + // that might have references that are not explicit in this context, we use + // the value kMustKeep as the reference count. + std::vector ids_to_remove; + + // Get the reference count for all of the global OpVariable instructions. 
+ for (auto& inst : context()->types_values()) { + if (inst.opcode() != SpvOp::SpvOpVariable) { + continue; + } + + size_t count = 0; + uint32_t result_id = inst.result_id(); + + // Check the linkage. If it is exported, it could be referenced somewhere + // else, so we must keep the variable around. + get_decoration_mgr()->ForEachDecoration( + result_id, SpvDecorationLinkageAttributes, + [&count](const Instruction& linkage_instruction) { + uint32_t last_operand = linkage_instruction.NumOperands() - 1; + if (linkage_instruction.GetSingleWordOperand(last_operand) == + SpvLinkageTypeExport) { + count = kMustKeep; + } + }); + + if (count != kMustKeep) { + // If we don't have to keep the instruction for other reasons, then look + // at the uses and count the number of real references. + count = 0; + get_def_use_mgr()->ForEachUser(result_id, [&count](Instruction* user) { + if (!IsAnnotationInst(user->opcode()) && user->opcode() != SpvOpName) { + ++count; + } + }); + } + reference_count_[result_id] = count; + if (count == 0) { + ids_to_remove.push_back(result_id); + } + } + + // Remove all of the variables that have a reference count of 0. + bool modified = false; + if (!ids_to_remove.empty()) { + modified = true; + for (auto result_id : ids_to_remove) { + DeleteVariable(result_id); + } + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +void DeadVariableElimination::DeleteVariable(uint32_t result_id) { + Instruction* inst = get_def_use_mgr()->GetDef(result_id); + assert(inst->opcode() == SpvOpVariable && + "Should not be trying to delete anything other than an OpVariable."); + + // Look for an initializer that references another variable. We need to know + // if that variable can be deleted after the reference is removed. + if (inst->NumOperands() == 4) { + Instruction* initializer = + get_def_use_mgr()->GetDef(inst->GetSingleWordOperand(3)); + + // TODO: Handle OpSpecConstantOP which might be defined in terms of other + // variables. 
Will probably require a unified dead code pass that does all + // instruction types. (Issue 906) + if (initializer->opcode() == SpvOpVariable) { + uint32_t initializer_id = initializer->result_id(); + size_t& count = reference_count_[initializer_id]; + if (count != kMustKeep) { + --count; + } + + if (count == 0) { + DeleteVariable(initializer_id); + } + } + } + context()->KillDef(result_id); +} +} // namespace opt +} // namespace spvtools diff --git a/source/opt/dead_variable_elimination.h b/source/opt/dead_variable_elimination.h new file mode 100644 index 000000000..5dde71ba7 --- /dev/null +++ b/source/opt/dead_variable_elimination.h @@ -0,0 +1,56 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_DEAD_VARIABLE_ELIMINATION_H_ +#define SOURCE_OPT_DEAD_VARIABLE_ELIMINATION_H_ + +#include +#include + +#include "source/opt/decoration_manager.h" +#include "source/opt/mem_pass.h" + +namespace spvtools { +namespace opt { + +class DeadVariableElimination : public MemPass { + public: + const char* name() const override { return "eliminate-dead-variables"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + // Deletes the OpVariable instruction whose result id is |result_id|. 
+ void DeleteVariable(uint32_t result_id); + + // Keeps track of the number of references of an id. Once that value is 0, it + // is safe to remove the corresponding instruction. + // + // Note that the special value kMustKeep is used to indicate that the + // instruction cannot be deleted for reasons other than being explicitly + // referenced. + std::unordered_map reference_count_; + + // Special value used to indicate that an id cannot be safely deleted. + enum { kMustKeep = INT_MAX }; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DEAD_VARIABLE_ELIMINATION_H_ diff --git a/source/opt/decoration_manager.cpp b/source/opt/decoration_manager.cpp new file mode 100644 index 000000000..a12326ba5 --- /dev/null +++ b/source/opt/decoration_manager.cpp @@ -0,0 +1,527 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+
+#include "source/opt/decoration_manager.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace opt {
+namespace analysis {
+
+// Detaches from |id| every decoration for which |pred| returns true.
+//
+// Direct decorations accepted by |pred| are killed.  For each group applied to
+// |id| (through OpGroupDecorate or OpGroupMemberDecorate), |id| stays a target
+// only when the group is non-empty and |pred| rejects all of its decorations;
+// otherwise |id| is dropped from the applying instruction and the group
+// decorations that |pred| rejects are cloned directly onto |id|.
+//
+// Returns immediately when |id| is not tracked by this manager.
+void DecorationManager::RemoveDecorationsFrom(
+    uint32_t id, std::function pred) {
+  const auto ids_iter = id_to_decoration_insts_.find(id);
+  if (ids_iter == id_to_decoration_insts_.end()) {
+    return;
+  }
+
+  TargetData& decorations_info = ids_iter->second;
+  auto context = module_->context();
+  std::vector insts_to_kill;
+  // |id| is treated as a decoration group when at least one instruction
+  // applies it to targets.  Captured up front because the lists change below.
+  const bool is_group = !decorations_info.decorate_insts.empty();
+
+  // Schedule all direct decorations for removal if instructed as such by
+  // |pred|.
+  for (Instruction* inst : decorations_info.direct_decorations)
+    if (pred(*inst)) insts_to_kill.push_back(inst);
+
+  // For all groups being directly applied to |id|, remove |id| (and the
+  // literal if |inst| is an OpGroupMemberDecorate) from the instruction
+  // applying the group.
+  //
+  // Instructions that must leave |id|'s indirect list are only collected
+  // here; the list itself is pruned after the loop so that it is not modified
+  // while being iterated.
+  std::unordered_set indirect_decorations_to_remove;
+  for (Instruction* inst : decorations_info.indirect_decorations) {
+    assert(inst->opcode() == SpvOpGroupDecorate ||
+           inst->opcode() == SpvOpGroupMemberDecorate);
+
+    // The group's decorations that |pred| does not ask to remove.
+    std::vector group_decorations_to_keep;
+    const uint32_t group_id = inst->GetSingleWordInOperand(0u);
+    const auto group_iter = id_to_decoration_insts_.find(group_id);
+    assert(group_iter != id_to_decoration_insts_.end() &&
+           "Unknown decoration group");
+    const auto& group_decorations = group_iter->second.direct_decorations;
+    for (Instruction* decoration : group_decorations) {
+      if (!pred(*decoration)) group_decorations_to_keep.push_back(decoration);
+    }
+
+    // If all decorations should be kept, then we can keep |id| part of the
+    // group.  However, if the group itself has no decorations, we should remove
+    // the id from the group.  This is needed to make |KillNameAndDecorate| work
+    // correctly when a decoration group has no decorations.
+    if (group_decorations_to_keep.size() == group_decorations.size() &&
+        group_decorations.size() != 0) {
+      continue;
+    }
+
+    // Otherwise, remove |id| from the targets of |group_id|.
+    // Targets start at in-operand 1; OpGroupMemberDecorate stores them as
+    // (id, literal) pairs, hence the stride of 2.
+    const uint32_t stride = inst->opcode() == SpvOpGroupDecorate ? 1u : 2u;
+    bool was_modified = false;
+    for (uint32_t i = 1u; i < inst->NumInOperands();) {
+      if (inst->GetSingleWordInOperand(i) != id) {
+        i += stride;
+        continue;
+      }
+
+      // Swap-with-last removal: overwrite the matching target with the last
+      // target, then drop the last slot.  |i| is not advanced, so the entry
+      // that was just moved into position |i| is examined on the next pass.
+      const uint32_t last_operand_index = inst->NumInOperands() - stride;
+      if (i < last_operand_index)
+        inst->GetInOperand(i) = inst->GetInOperand(last_operand_index);
+      // Remove the associated literal, if it exists.
+      if (stride == 2u) {
+        if (i < last_operand_index)
+          inst->GetInOperand(i + 1u) =
+              inst->GetInOperand(last_operand_index + 1u);
+        inst->RemoveInOperand(last_operand_index + 1u);
+      }
+      inst->RemoveInOperand(last_operand_index);
+      was_modified = true;
+    }
+
+    // If the instruction has no targets left, remove the instruction
+    // altogether.
+    if (inst->NumInOperands() == 1u) {
+      indirect_decorations_to_remove.emplace(inst);
+      insts_to_kill.push_back(inst);
+    } else if (was_modified) {
+      // NOTE(review): ForgetUses/AnalyzeUses presumably refresh the context's
+      // bookkeeping for the edited operand list -- confirm against IRContext.
+      context->ForgetUses(inst);
+      indirect_decorations_to_remove.emplace(inst);
+      context->AnalyzeUses(inst);
+    }
+
+    // If only some of the decorations should be kept, clone them and apply
+    // them directly to |id|.
+    if (!group_decorations_to_keep.empty()) {
+      for (Instruction* decoration : group_decorations_to_keep) {
+        // simply clone decoration and change |group_id| to |id|
+        std::unique_ptr new_inst(
+            decoration->Clone(module_->context()));
+        new_inst->SetInOperand(0, {id});
+        module_->AddAnnotationInst(std::move(new_inst));
+        // The clone was appended, so it is the last annotation in the module.
+        auto decoration_iter = --module_->annotation_end();
+        context->AnalyzeUses(&*decoration_iter);
+      }
+    }
+  }
+
+  // Unlink from |id| the group applications collected above.
+  auto& indirect_decorations = decorations_info.indirect_decorations;
+  indirect_decorations.erase(
+      std::remove_if(
+          indirect_decorations.begin(), indirect_decorations.end(),
+          [&indirect_decorations_to_remove](const Instruction* inst) {
+            return indirect_decorations_to_remove.count(inst);
+          }),
+      indirect_decorations.end());
+
+  for (Instruction* inst : insts_to_kill) context->KillInst(inst);
+  insts_to_kill.clear();
+
+  // Schedule all instructions applying the group for removal if this group no
+  // longer applies decorations, either directly or indirectly.
+  if (is_group && decorations_info.direct_decorations.empty() &&
+      decorations_info.indirect_decorations.empty()) {
+    for (Instruction* inst : decorations_info.decorate_insts)
+      insts_to_kill.push_back(inst);
+  }
+  for (Instruction* inst : insts_to_kill) context->KillInst(inst);
+
+  // Stop tracking |id| once nothing refers to it any more.
+  if (decorations_info.direct_decorations.empty() &&
+      decorations_info.indirect_decorations.empty() &&
+      decorations_info.decorate_insts.empty()) {
+    id_to_decoration_insts_.erase(ids_iter);
+  }
+}
+
+// Non-const overload: forwards to InternalGetDecorationsFor().
+std::vector DecorationManager::GetDecorationsFor(
+    uint32_t id, bool include_linkage) {
+  return InternalGetDecorationsFor(id, include_linkage);
+}
+
+// Const overload: casts away constness so that the same lookup helper can be
+// reused.
+std::vector DecorationManager::GetDecorationsFor(
+    uint32_t id, bool include_linkage) const {
+  return const_cast(this)
+      ->InternalGetDecorationsFor(id, include_linkage);
+}
+
+// Returns true when |id1| and |id2| carry the same decorations.
+//
+// The decorations of each id are fetched with |include_linkage| set to false
+// and turned into one set of operand payloads per opcode; the ids match when
+// all four pairs of sets are equal.  Order and duplicates therefore do not
+// matter.
+bool DecorationManager::HaveTheSameDecorations(uint32_t id1,
+                                               uint32_t id2) const {
+  using InstructionList = std::vector;
+  using DecorationSet = std::set;
+
+  const InstructionList decorations_for1 = GetDecorationsFor(id1, false);
+  const InstructionList decorations_for2 = GetDecorationsFor(id2, false);
+
+  // This function splits the decoration instructions into different sets,
+  // based on their opcode; only OpDecorate, OpDecorateId,
+  // OpDecorateStringGOOGLE, and OpMemberDecorate are considered, the other
+  // opcodes are ignored.
+  const auto fillDecorationSets =
+      [](const InstructionList& decoration_list, DecorationSet* decorate_set,
+         DecorationSet* decorate_id_set, DecorationSet* decorate_string_set,
+         DecorationSet* member_decorate_set) {
+        for (const Instruction* inst : decoration_list) {
+          // Concatenation of the words of every in-operand after the target.
+          std::u32string decoration_payload;
+          // Ignore the opcode and the target as we do not want them to be
+          // compared.
+          for (uint32_t i = 1u; i < inst->NumInOperands(); ++i) {
+            for (uint32_t word : inst->GetInOperand(i).words) {
+              decoration_payload.push_back(word);
+            }
+          }
+
+          switch (inst->opcode()) {
+            case SpvOpDecorate:
+              decorate_set->emplace(std::move(decoration_payload));
+              break;
+            case SpvOpMemberDecorate:
+              member_decorate_set->emplace(std::move(decoration_payload));
+              break;
+            case SpvOpDecorateId:
+              decorate_id_set->emplace(std::move(decoration_payload));
+              break;
+            case SpvOpDecorateStringGOOGLE:
+              decorate_string_set->emplace(std::move(decoration_payload));
+              break;
+            default:
+              break;
+          }
+        }
+      };
+
+  DecorationSet decorate_set_for1;
+  DecorationSet decorate_id_set_for1;
+  DecorationSet decorate_string_set_for1;
+  DecorationSet member_decorate_set_for1;
+  fillDecorationSets(decorations_for1, &decorate_set_for1,
+                     &decorate_id_set_for1, &decorate_string_set_for1,
+                     &member_decorate_set_for1);
+
+  DecorationSet decorate_set_for2;
+  DecorationSet decorate_id_set_for2;
+  DecorationSet decorate_string_set_for2;
+  DecorationSet member_decorate_set_for2;
+  fillDecorationSets(decorations_for2, &decorate_set_for2,
+                     &decorate_id_set_for2, &decorate_string_set_for2,
+                     &member_decorate_set_for2);
+
+  const bool result = decorate_set_for1 == decorate_set_for2 &&
+                      decorate_id_set_for1 == decorate_id_set_for2 &&
+                      member_decorate_set_for1 == member_decorate_set_for2 &&
+                      // Compare string sets last in case the strings are long.
+                      decorate_string_set_for1 == decorate_string_set_for2;
+  return result;
+}
+
+// TODO(pierremoreau): If OpDecorateId is referencing an OpConstant, one could
+//                     check that the constants are the same rather than just
+//                     looking at the constant ID.
+// Returns true when |inst1| and |inst2| are the same kind of decoration
+// instruction (OpDecorate, OpMemberDecorate, OpDecorateId or
+// OpDecorateStringGOOGLE), have the same number of in-operands and all compared
+// in-operands are equal.  In-operand 0 (the target) is skipped when
+// |ignore_target| is true.  Any other opcode yields false.
+bool DecorationManager::AreDecorationsTheSame(const Instruction* inst1,
+                                              const Instruction* inst2,
+                                              bool ignore_target) const {
+  switch (inst1->opcode()) {
+    case SpvOpDecorate:
+    case SpvOpMemberDecorate:
+    case SpvOpDecorateId:
+    case SpvOpDecorateStringGOOGLE:
+      break;
+    default:
+      return false;
+  }
+
+  if (inst1->opcode() != inst2->opcode() ||
+      inst1->NumInOperands() != inst2->NumInOperands())
+    return false;
+
+  for (uint32_t i = ignore_target ? 1u : 0u; i < inst1->NumInOperands(); ++i)
+    if (inst1->GetInOperand(i) != inst2->GetInOperand(i)) return false;
+
+  return true;
+}
+
+// Registers every annotation instruction of the module with this manager.
+// Does nothing when there is no module.
+void DecorationManager::AnalyzeDecorations() {
+  if (!module_) return;
+
+  // For each group and instruction, collect all their decoration instructions.
+  for (Instruction& inst : module_->annotations()) {
+    AddDecoration(&inst);
+  }
+}
+
+// Starts tracking the decoration instruction |inst|.
+//
+// A direct decoration is recorded under its target id.  An instruction that
+// applies a group is recorded as an indirect decoration of each of its
+// targets and as a decorate instruction of the group id.  Other opcodes are
+// ignored.
+void DecorationManager::AddDecoration(Instruction* inst) {
+  switch (inst->opcode()) {
+    case SpvOpDecorate:
+    case SpvOpDecorateId:
+    case SpvOpDecorateStringGOOGLE:
+    case SpvOpMemberDecorate: {
+      const auto target_id = inst->GetSingleWordInOperand(0u);
+      id_to_decoration_insts_[target_id].direct_decorations.push_back(inst);
+      break;
+    }
+    case SpvOpGroupDecorate:
+    case SpvOpGroupMemberDecorate: {
+      // In-operand 0 is the group; targets start at in-operand 1.
+      // OpGroupDecorate lists bare ids, whereas OpGroupMemberDecorate lists
+      // (id, member literal) pairs, so only every second operand is an id.
+      // Starting at 2 for the member form would register the member literals
+      // as target ids; the traversal must match the one in RemoveDecoration().
+      const uint32_t start = 1u;
+      const uint32_t stride = inst->opcode() == SpvOpGroupDecorate ? 1u : 2u;
+      for (uint32_t i = start; i < inst->NumInOperands(); i += stride) {
+        const auto target_id = inst->GetSingleWordInOperand(i);
+        TargetData& target_data = id_to_decoration_insts_[target_id];
+        target_data.indirect_decorations.push_back(inst);
+      }
+      const auto target_id = inst->GetSingleWordInOperand(0u);
+      id_to_decoration_insts_[target_id].decorate_insts.push_back(inst);
+      break;
+    }
+    default:
+      break;
+  }
+}
+
+// Creates an instruction with |opcode| and operands |opnds| and adds it to the
+// context as an annotation instruction.  The two integer constructor
+// arguments are passed as 0 (presumably "no type id" and "no result id").
+void DecorationManager::AddDecoration(SpvOp opcode,
+                                      std::vector opnds) {
+  IRContext* ctx = module_->context();
+  std::unique_ptr newDecoOp(
+      new Instruction(ctx, opcode, 0, 0, opnds));
+  ctx->AddAnnotationInst(std::move(newDecoOp));
+}
+
+// Adds "OpDecorate |inst_id| |decoration|" to the module.
+void DecorationManager::AddDecoration(uint32_t inst_id, uint32_t decoration) {
+  AddDecoration(
+      SpvOpDecorate,
+      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {inst_id}},
+       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration}}});
+}
+
+// Adds "OpDecorate |inst_id| |decoration| |decoration_value|" to the module.
+void DecorationManager::AddDecorationVal(uint32_t inst_id, uint32_t decoration,
+                                         uint32_t decoration_value) {
+  AddDecoration(
+      SpvOpDecorate,
+      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {inst_id}},
+       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration}},
+       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
+        {decoration_value}}});
+}
+
+// Adds "OpMemberDecorate |inst_id| |member| |decoration| |decoration_value|"
+// to the module.
+void DecorationManager::AddMemberDecoration(uint32_t inst_id, uint32_t member,
+                                            uint32_t decoration,
+                                            uint32_t decoration_value) {
+  AddDecoration(
+      SpvOpMemberDecorate,
+      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {inst_id}},
+       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {member}},
+       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration}},
+       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
+        {decoration_value}}});
+}
+
+// Collects the decoration instructions that apply to |id|: first its direct
+// decorations, then the direct decorations of every group applied to it.
+// OpDecorate instructions whose decoration is LinkageAttributes are skipped
+// unless |include_linkage| is true.  Returns an empty vector when |id| is not
+// tracked.
+template
+std::vector DecorationManager::InternalGetDecorationsFor(
+    uint32_t id, bool include_linkage) {
+  std::vector decorations;
+
+  const auto ids_iter = id_to_decoration_insts_.find(id);
+  // |id| has no decorations
+  if (ids_iter == id_to_decoration_insts_.end()) return decorations;
+
+  const TargetData& target_data = ids_iter->second;
+
+  const auto process_direct_decorations =
+      [include_linkage,
+       &decorations](const std::vector& direct_decorations) {
+        for (Instruction* inst : direct_decorations) {
+          const bool is_linkage = inst->opcode() == SpvOpDecorate &&
+                                  inst->GetSingleWordInOperand(1u) ==
+                                      SpvDecorationLinkageAttributes;
+          if (include_linkage || !is_linkage) decorations.push_back(inst);
+        }
+      };
+
+  // Process |id|'s decorations.
+  process_direct_decorations(ids_iter->second.direct_decorations);
+
+  // Process the decorations of all groups applied to |id|.
+  for (const Instruction* inst : target_data.indirect_decorations) {
+    const uint32_t group_id = inst->GetSingleWordInOperand(0u);
+    const auto group_iter = id_to_decoration_insts_.find(group_id);
+    assert(group_iter != id_to_decoration_insts_.end() && "Unknown group ID");
+    process_direct_decorations(group_iter->second.direct_decorations);
+  }
+
+  return decorations;
+}
+
+// Calls |f| on every decoration of |id| (linkage decorations included) whose
+// decoration operand equals |decoration|.  Stops and returns false as soon as
+// |f| returns false; returns true otherwise.
+bool DecorationManager::WhileEachDecoration(
+    uint32_t id, uint32_t decoration,
+    std::function f) {
+  for (const Instruction* inst : GetDecorationsFor(id, true)) {
+    switch (inst->opcode()) {
+      case SpvOpMemberDecorate:
+        // The decoration follows the target and the member index.
+        if (inst->GetSingleWordInOperand(2) == decoration) {
+          if (!f(*inst)) return false;
+        }
+        break;
+      case SpvOpDecorate:
+      case SpvOpDecorateId:
+      case SpvOpDecorateStringGOOGLE:
+        // The decoration directly follows the target.
+        if (inst->GetSingleWordInOperand(1) == decoration) {
+          if (!f(*inst)) return false;
+        }
+        break;
+      default:
+        assert(false && "Unexpected decoration instruction");
+    }
+  }
+  return true;
+}
+
+// Same as WhileEachDecoration(), except that |f| cannot stop the iteration.
+void DecorationManager::ForEachDecoration(
+    uint32_t id, uint32_t decoration,
+    std::function f) {
+  WhileEachDecoration(id, decoration, [&f](const Instruction& inst) {
+    f(inst);
+    return true;
+  });
+}
+
+// Makes every decoration of |from| also apply to |to|.
+//
+// Direct decorations are cloned with their target rewritten to |to|.  For
+// groups applied to |from|, |to| is appended to the targets of the applying
+// instruction (together with the member literal for OpGroupMemberDecorate).
+// Does nothing when |from| is not tracked.
+void DecorationManager::CloneDecorations(uint32_t from, uint32_t to) {
+  const auto decoration_list = id_to_decoration_insts_.find(from);
+  if (decoration_list == id_to_decoration_insts_.end()) return;
+  auto context = module_->context();
+  for (Instruction* inst : decoration_list->second.direct_decorations) {
+    // simply clone decoration and change |target-id| to |to|
+    std::unique_ptr new_inst(inst->Clone(module_->context()));
+    new_inst->SetInOperand(0, {to});
+    module_->AddAnnotationInst(std::move(new_inst));
+    // The clone was appended, so it is the last annotation in the module.
+    auto decoration_iter = --module_->annotation_end();
+    context->AnalyzeUses(&*decoration_iter);
+  }
+  // We need to copy the list of instructions as ForgetUses and AnalyzeUses are
+  // going to modify it.
+  std::vector indirect_decorations =
+      decoration_list->second.indirect_decorations;
+  for (Instruction* inst : indirect_decorations) {
+    switch (inst->opcode()) {
+      case SpvOpGroupDecorate:
+        context->ForgetUses(inst);
+        // add |to| to list of decorated id's
+        inst->AddOperand(
+            Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {to}));
+        context->AnalyzeUses(inst);
+        break;
+      case SpvOpGroupMemberDecorate: {
+        context->ForgetUses(inst);
+        // for each (id == from), add (to, literal) as operands
+        // The bound is read once so that the pairs appended below are not
+        // visited by this loop.
+        const uint32_t num_operands = inst->NumOperands();
+        for (uint32_t i = 1; i < num_operands; i += 2) {
+          Operand op = inst->GetOperand(i);
+          if (op.words[0] == from) {  // add new pair of operands: (to, literal)
+            inst->AddOperand(
+                Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {to}));
+            op = inst->GetOperand(i + 1);
+            inst->AddOperand(std::move(op));
+          }
+        }
+        context->AnalyzeUses(inst);
+        break;
+      }
+      default:
+        assert(false && "Unexpected decoration instruction");
+    }
+  }
+}
+
+// Same as the two-argument overload, restricted to the decorations listed in
+// |decorations_to_copy|.
+//
+// A direct decoration is cloned onto |to| only when its in-operand 1 is found
+// in |decorations_to_copy|.  For a group applied to |from| through
+// OpGroupDecorate, the group's own decorations are filtered and cloned the
+// same way; OpGroupMemberDecorate is not expected here and asserts.
+void DecorationManager::CloneDecorations(
+    uint32_t from, uint32_t to,
+    const std::vector& decorations_to_copy) {
+  const auto decoration_list = id_to_decoration_insts_.find(from);
+  if (decoration_list == id_to_decoration_insts_.end()) return;
+  auto context = module_->context();
+  for (Instruction* inst : decoration_list->second.direct_decorations) {
+    if (std::find(decorations_to_copy.begin(), decorations_to_copy.end(),
+                  inst->GetSingleWordInOperand(1)) ==
+        decorations_to_copy.end()) {
+      continue;
+    }
+
+    // Clone decoration and change |target-id| to |to|.
+    std::unique_ptr new_inst(inst->Clone(module_->context()));
+    new_inst->SetInOperand(0, {to});
+    module_->AddAnnotationInst(std::move(new_inst));
+    auto decoration_iter = --module_->annotation_end();
+    context->AnalyzeUses(&*decoration_iter);
+  }
+
+  // We need to copy the list of instructions as ForgetUses and AnalyzeUses are
+  // going to modify it.
+  std::vector indirect_decorations =
+      decoration_list->second.indirect_decorations;
+  for (Instruction* inst : indirect_decorations) {
+    switch (inst->opcode()) {
+      case SpvOpGroupDecorate:
+        CloneDecorations(inst->GetSingleWordInOperand(0), to,
+                         decorations_to_copy);
+        break;
+      case SpvOpGroupMemberDecorate: {
+        assert(false && "The source id is not suppose to be a type.");
+        break;
+      }
+      default:
+        assert(false && "Unexpected decoration instruction");
+    }
+  }
+}
+
+// Stops tracking the decoration instruction |inst|.  Only this manager's
+// bookkeeping is updated; the instruction itself is left untouched.
+//
+// A direct decoration is removed from its target's list.  An instruction that
+// applies a group is removed from the indirect list of each of its targets and
+// from the decorate list of the group.  Ids that are not tracked are skipped.
+void DecorationManager::RemoveDecoration(Instruction* inst) {
+  const auto remove_from_container = [inst](std::vector& v) {
+    v.erase(std::remove(v.begin(), v.end(), inst), v.end());
+  };
+
+  switch (inst->opcode()) {
+    case SpvOpDecorate:
+    case SpvOpDecorateId:
+    case SpvOpDecorateStringGOOGLE:
+    case SpvOpMemberDecorate: {
+      const auto target_id = inst->GetSingleWordInOperand(0u);
+      auto const iter = id_to_decoration_insts_.find(target_id);
+      if (iter == id_to_decoration_insts_.end()) return;
+      remove_from_container(iter->second.direct_decorations);
+    } break;
+    case SpvOpGroupDecorate:
+    case SpvOpGroupMemberDecorate: {
+      // Targets start at in-operand 1; the member form stores (id, literal)
+      // pairs, hence the stride of 2.
+      const uint32_t stride = inst->opcode() == SpvOpGroupDecorate ? 1u : 2u;
+      for (uint32_t i = 1u; i < inst->NumInOperands(); i += stride) {
+        const auto target_id = inst->GetSingleWordInOperand(i);
+        auto const iter = id_to_decoration_insts_.find(target_id);
+        if (iter == id_to_decoration_insts_.end()) continue;
+        remove_from_container(iter->second.indirect_decorations);
+      }
+      const auto group_id = inst->GetSingleWordInOperand(0u);
+      auto const iter = id_to_decoration_insts_.find(group_id);
+      if (iter == id_to_decoration_insts_.end()) return;
+      remove_from_container(iter->second.decorate_insts);
+    } break;
+    default:
+      break;
+  }
+}
+
+// Two managers are equal when they track the same decoration data per id.
+bool operator==(const DecorationManager& lhs, const DecorationManager& rhs) {
+  return lhs.id_to_decoration_insts_ == rhs.id_to_decoration_insts_;
+}
+
+}  // namespace analysis
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/decoration_manager.h b/source/opt/decoration_manager.h
new file mode 100644
index 000000000..a5fb4c861
--- /dev/null
+++ b/source/opt/decoration_manager.h
@@ -0,0 +1,192 @@
+// Copyright (c) 2017 Pierre Moreau
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_DECORATION_MANAGER_H_
+#define SOURCE_OPT_DECORATION_MANAGER_H_
+
+#include
+#include
+#include
+#include
+
+#include "source/opt/instruction.h"
+#include "source/opt/module.h"
+
+namespace spvtools {
+namespace opt {
+namespace analysis {
+
+// A class for analyzing and managing decorations in a Module.
+class DecorationManager {
+ public:
+  // Constructs a decoration manager from the given |module| and analyzes its
+  // annotation instructions right away.
+  explicit DecorationManager(Module* module) : module_(module) {
+    AnalyzeDecorations();
+  }
+  DecorationManager() = delete;
+
+  // Changes all of the decorations (direct and through groups) where |pred| is
+  // true and that apply to |id| so that they no longer apply to |id|.
+  //
+  // If |id| is part of a group, it will be removed from the group if it
+  // does not use all of the group's decorations, or, if there are no
+  // decorations that apply to the group.
+  //
+  // If decoration groups become empty, the |OpGroupDecorate| and
+  // |OpGroupMemberDecorate| instructions will be killed.
+  //
+  // Decoration instructions that apply directly to |id| will be killed.
+  //
+  // If |id| is a decoration group and all of the group's decorations are
+  // removed, then the |OpGroupDecorate| and
+  // |OpGroupMemberDecorate| for the group will be killed, but not the defining
+  // |OpDecorationGroup| instruction.
+  void RemoveDecorationsFrom(uint32_t id,
+                             std::function pred =
+                                 [](const Instruction&) { return true; });
+
+  // Stops tracking the decoration instruction |inst|: it is dropped from the
+  // bookkeeping of every id it targets and, for instructions applying a group,
+  // from the bookkeeping of the group as well.
+  //
+  // NOTE: This is only meant to be called from ir_context, as only metadata
+  // will be removed, and no actual instruction.
+  void RemoveDecoration(Instruction* inst);
+
+  // Returns a vector of all decorations affecting |id|. If a group is applied
+  // to |id|, the decorations of that group are returned rather than the group
+  // decoration instruction. If |include_linkage| is not set, linkage
+  // decorations won't be returned.
+  std::vector GetDecorationsFor(uint32_t id,
+                                bool include_linkage);
+  std::vector GetDecorationsFor(uint32_t id,
+                                bool include_linkage) const;
+  // Returns whether two IDs have the same decorations. Two SpvOpGroupDecorate
+  // instructions that apply the same decorations but to different IDs, still
+  // count as being the same.
+  bool HaveTheSameDecorations(uint32_t id1, uint32_t id2) const;
+  // Returns whether the two decoration instructions are the same and are
+  // applying the same decorations; when |ignore_target| is true, the targets
+  // they are applied to do not matter, except for the member part.
+  //
+  // This is only valid for OpDecorate, OpMemberDecorate, OpDecorateId and
+  // OpDecorateStringGOOGLE; it will return false for other opcodes.
+  bool AreDecorationsTheSame(const Instruction* inst1, const Instruction* inst2,
+                             bool ignore_target) const;
+
+  // |f| is run on each decoration instruction for |id| with decoration
+  // |decoration|. Processed are all decorations which target |id| either
+  // directly or indirectly by Decoration Groups.
+  void ForEachDecoration(uint32_t id, uint32_t decoration,
+                         std::function f);
+
+  // |f| is run on each decoration instruction for |id| with decoration
+  // |decoration|. Processes all decoration which target |id| either directly or
+  // indirectly through decoration groups. If |f| returns false, iteration is
+  // terminated and this function returns false.
+  bool WhileEachDecoration(uint32_t id, uint32_t decoration,
+                           std::function f);
+
+  // Clone all decorations from one id |from|.
+  // The cloned decorations are assigned to the given id |to| and are
+  // added to the module. The purpose is to decorate cloned instructions.
+  // This function does not check if the id |to| is already decorated.
+  void CloneDecorations(uint32_t from, uint32_t to);
+
+  // Same as above, but only clone the decoration if the decoration operand is
+  // in |decorations_to_copy|.  This function has the extra restriction that
+  // |from| and |to| must be objects, not types.
+  void CloneDecorations(uint32_t from, uint32_t to,
+                        const std::vector& decorations_to_copy);
+
+  // Informs the decoration manager of a new decoration that it needs to track.
+  void AddDecoration(Instruction* inst);
+
+  // Add decoration with |opcode| and operands |opnds|.
+  void AddDecoration(SpvOp opcode, const std::vector opnds);
+
+  // Add |decoration| of |inst_id| to module.
+  void AddDecoration(uint32_t inst_id, uint32_t decoration);
+
+  // Add |decoration, decoration_value| of |inst_id| to module.
+  void AddDecorationVal(uint32_t inst_id, uint32_t decoration,
+                        uint32_t decoration_value);
+
+  // Add |decoration, decoration_value| of |inst_id, member| to module.
+  // The parameter names follow the definition: the decorated id comes first,
+  // the member index second.
+  void AddMemberDecoration(uint32_t inst_id, uint32_t member,
+                           uint32_t decoration, uint32_t decoration_value);
+
+  friend bool operator==(const DecorationManager&, const DecorationManager&);
+  friend bool operator!=(const DecorationManager& lhs,
+                         const DecorationManager& rhs) {
+    return !(lhs == rhs);
+  }
+
+ private:
+  // Registers every annotation instruction of |module_| and populates the data
+  // structures in this class. Does nothing if |module_| is nullptr.
+  void AnalyzeDecorations();
+
+  // Shared implementation of the two GetDecorationsFor() overloads.
+  template
+  std::vector InternalGetDecorationsFor(uint32_t id, bool include_linkage);
+
+  // Tracks decoration information of an ID.
+  struct TargetData {
+    std::vector direct_decorations;    // All decorate
+                                       // instructions applied
+                                       // to the tracked ID.
+    std::vector indirect_decorations;  // All instructions
+                                       // applying a group to
+                                       // the tracked ID.
+    std::vector decorate_insts;  // All decorate instructions
+                                 // applying the decorations
+                                 // of the tracked ID to
+                                 // targets.
+                                 // It is empty if the
+                                 // tracked ID is not a
+                                 // group.
+  };
+
+  // Two TargetData are equal when each of their three lists holds the same
+  // instructions, in any order.
+  //
+  // The three-iterator form of std::is_permutation assumes that the second
+  // range is as long as the first one, so the sizes are compared first;
+  // otherwise lists of different lengths could compare equal or be read past
+  // their end.
+  friend bool operator==(const TargetData& lhs, const TargetData& rhs) {
+    if (lhs.direct_decorations.size() != rhs.direct_decorations.size() ||
+        !std::is_permutation(lhs.direct_decorations.begin(),
+                             lhs.direct_decorations.end(),
+                             rhs.direct_decorations.begin())) {
+      return false;
+    }
+    if (lhs.indirect_decorations.size() != rhs.indirect_decorations.size() ||
+        !std::is_permutation(lhs.indirect_decorations.begin(),
+                             lhs.indirect_decorations.end(),
+                             rhs.indirect_decorations.begin())) {
+      return false;
+    }
+    if (lhs.decorate_insts.size() != rhs.decorate_insts.size() ||
+        !std::is_permutation(lhs.decorate_insts.begin(),
+                             lhs.decorate_insts.end(),
+                             rhs.decorate_insts.begin())) {
+      return false;
+    }
+    return true;
+  }
+
+  // Mapping from ids to the instructions applying a decoration to those ids.
+  // In other words, for each id you get all decoration instructions
+  // referencing that id, be it directly (SpvOpDecorate, SpvOpMemberDecorate
+  // and SpvOpDecorateId), or indirectly (SpvOpGroupDecorate,
+  // SpvOpGroupMemberDecorate).
+  std::unordered_map id_to_decoration_insts_;
+  // The enclosing module.
+  Module* module_;
+};
+
+}  // namespace analysis
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_DECORATION_MANAGER_H_
diff --git a/source/opt/def_use_manager.cpp b/source/opt/def_use_manager.cpp
new file mode 100644
index 000000000..0ec98cae1
--- /dev/null
+++ b/source/opt/def_use_manager.cpp
@@ -0,0 +1,299 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/def_use_manager.h"
+
+#include
+
+#include "source/opt/log.h"
+#include "source/opt/reflect.h"
+
+namespace spvtools {
+namespace opt {
+namespace analysis {
+
+// Registers |inst| as the definition of its result id, replacing (and
+// clearing the records of) any instruction previously registered for that id.
+// An instruction without a result id only has its existing records cleared.
+void DefUseManager::AnalyzeInstDef(Instruction* inst) {
+  const uint32_t def_id = inst->result_id();
+  if (def_id != 0) {
+    auto iter = id_to_def_.find(def_id);
+    if (iter != id_to_def_.end()) {
+      // Clear the original instruction that defines the same result id as the
+      // new instruction.
+      ClearInst(iter->second);
+    }
+    id_to_def_[def_id] = inst;
+  } else {
+    ClearInst(inst);
+  }
+}
+
+// Records, for every id operand of |inst| other than its result id, that
+// |inst| uses the instruction defining that id.  Use records left over from a
+// previous analysis of |inst| are dropped first.
+void DefUseManager::AnalyzeInstUse(Instruction* inst) {
+  // Create entry for the given instruction. Note that the instruction may
+  // not have any in-operands. In such cases, we still need an entry for those
+  // instructions so this manager knows it has seen the instruction later.
+  auto* used_ids = &inst_to_used_ids_[inst];
+  if (used_ids->size()) {
+    // EraseUseRecordsOfOperandIds() removes the map entry, so it has to be
+    // looked up (and thereby re-created) again.
+    EraseUseRecordsOfOperandIds(inst);
+    used_ids = &inst_to_used_ids_[inst];
+  }
+  used_ids->clear();  // It might have existed before.
+
+  for (uint32_t i = 0; i < inst->NumOperands(); ++i) {
+    switch (inst->GetOperand(i).type) {
+      // For any id type but result id type
+      case SPV_OPERAND_TYPE_ID:
+      case SPV_OPERAND_TYPE_TYPE_ID:
+      case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
+      case SPV_OPERAND_TYPE_SCOPE_ID: {
+        uint32_t use_id = inst->GetSingleWordOperand(i);
+        Instruction* def = GetDef(use_id);
+        assert(def && "Definition is not registered.");
+        id_to_users_.insert(UserEntry(def, inst));
+        used_ids->push_back(use_id);
+      } break;
+      default:
+        break;
+    }
+  }
+}
+
+// Analyzes both the definition and the uses of |inst|.
+void DefUseManager::AnalyzeInstDefUse(Instruction* inst) {
+  AnalyzeInstDef(inst);
+  AnalyzeInstUse(inst);
+}
+
+// Registers the definition of |inst| only if its result id is not known yet,
+// then (re-)analyzes its uses.
+void DefUseManager::UpdateDefUse(Instruction* inst) {
+  const uint32_t def_id = inst->result_id();
+  if (def_id != 0) {
+    auto iter = id_to_def_.find(def_id);
+    if (iter == id_to_def_.end()) {
+      AnalyzeInstDef(inst);
+    }
+  }
+  AnalyzeInstUse(inst);
+}
+
+// Returns the instruction defining |id|, or nullptr if there is none.
+Instruction* DefUseManager::GetDef(uint32_t id) {
+  auto iter = id_to_def_.find(id);
+  if (iter == id_to_def_.end()) return nullptr;
+  return iter->second;
+}
+
+// Const overload of GetDef().
+const Instruction* DefUseManager::GetDef(uint32_t id) const {
+  const auto iter = id_to_def_.find(id);
+  if (iter == id_to_def_.end()) return nullptr;
+  return iter->second;
+}
+
+// Returns an iterator to the first user entry that is not ordered before
+// (|def|, nullptr), which is where the users of |def| start if it has any.
+DefUseManager::IdToUsersMap::const_iterator DefUseManager::UsersBegin(
+    const Instruction* def) const {
+  return id_to_users_.lower_bound(
+      UserEntry(const_cast(def), nullptr));
+}
+
+// Returns true while |iter| has not reached |cached_end| and still points at
+// an entry whose definition is |inst|.
+bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter,
+                                const IdToUsersMap::const_iterator& cached_end,
+                                const Instruction* inst) const {
+  return (iter != cached_end && iter->first == inst);
+}
+
+// Same as above, using the current end of the user entries.
+bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter,
+                                const Instruction* inst) const {
+  return UsersNotEnd(iter, id_to_users_.end(), inst);
+}
+
+// Calls |f| on each user of |def| until |f| returns false.  Returns false if
+// the iteration was stopped by |f|, true otherwise (including when |def| has
+// no result id).
+bool DefUseManager::WhileEachUser(
+    const Instruction* def, const std::function& f) const {
+  // Ensure that |def| has been registered.
+  assert(def && (!def->HasResultId() || def == GetDef(def->result_id())) &&
+         "Definition is not registered.");
+  if (!def->HasResultId()) return true;
+
+  auto end = id_to_users_.end();
+  for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
+    if (!f(iter->second)) return false;
+  }
+  return true;
+}
+
+// Same as above, for the instruction defining |id|.
+bool DefUseManager::WhileEachUser(
+    uint32_t id, const std::function& f) const {
+  return WhileEachUser(GetDef(id), f);
+}
+
+// Calls |f| on every user of |def|.
+void DefUseManager::ForEachUser(
+    const Instruction* def, const std::function& f) const {
+  WhileEachUser(def, [&f](Instruction* user) {
+    f(user);
+    return true;
+  });
+}
+
+// Same as above, for the instruction defining |id|.
+void DefUseManager::ForEachUser(
+    uint32_t id, const std::function& f) const {
+  ForEachUser(GetDef(id), f);
+}
+
+// Calls |f| with (user, operand index) for each operand of each user of |def|
+// that refers to |def|'s result id, until |f| returns false.  Returns false if
+// the iteration was stopped by |f|, true otherwise.
+bool DefUseManager::WhileEachUse(
+    const Instruction* def,
+    const std::function& f) const {
+  // Ensure that |def| has been registered.
+  assert(def && (!def->HasResultId() || def == GetDef(def->result_id())) &&
+         "Definition is not registered.");
+  if (!def->HasResultId()) return true;
+
+  auto end = id_to_users_.end();
+  for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
+    Instruction* user = iter->second;
+    for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) {
+      const Operand& op = user->GetOperand(idx);
+      // A user's own result id is a definition, not a use.
+      if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) {
+        if (def->result_id() == op.words[0]) {
+          if (!f(user, idx)) return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+// Same as above, for the instruction defining |id|.
+bool DefUseManager::WhileEachUse(
+    uint32_t id, const std::function& f) const {
+  return WhileEachUse(GetDef(id), f);
+}
+
+// Calls |f| with (user, operand index) for every use of |def|.
+void DefUseManager::ForEachUse(
+    const Instruction* def,
+    const std::function& f) const {
+  WhileEachUse(def, [&f](Instruction* user, uint32_t index) {
+    f(user, index);
+    return true;
+  });
+}
+
+// Same as above, for the instruction defining |id|.
+void DefUseManager::ForEachUse(
+    uint32_t id, const std::function& f) const {
+  ForEachUse(GetDef(id), f);
+}
+
+// Returns the number of users visited by ForEachUser() for |def|.
+uint32_t DefUseManager::NumUsers(const Instruction* def) const {
+  uint32_t count = 0;
+  ForEachUser(def, [&count](Instruction*) { ++count; });
+  return count;
+}
+
+// Same as above, for the instruction defining |id|.
+uint32_t DefUseManager::NumUsers(uint32_t id) const {
+  return NumUsers(GetDef(id));
+}
+
+// Returns the number of uses visited by ForEachUse() for |def|.
+uint32_t DefUseManager::NumUses(const Instruction* def) const {
+  uint32_t count = 0;
+  ForEachUse(def, [&count](Instruction*, uint32_t) { ++count; });
+  return count;
+}
+
+// Same as above, for the instruction defining |id|.
+uint32_t DefUseManager::NumUses(uint32_t id) const {
+  return NumUses(GetDef(id));
+}
+
+// Returns the users of |id| that are annotation instructions.  The result is
+// empty when |id| has no registered definition.
+std::vector DefUseManager::GetAnnotations(uint32_t id) const {
+  std::vector annos;
+  const Instruction* def = GetDef(id);
+  if (!def) return annos;
+
+  ForEachUser(def, [&annos](Instruction* user) {
+    if (IsAnnotationInst(user->opcode())) {
+      annos.push_back(user);
+    }
+  });
+  return annos;
+}
+
+// Analyzes every instruction of |module|.  Does nothing if |module| is null.
+void DefUseManager::AnalyzeDefUse(Module* module) {
+  if (!module) return;
+  // Analyze all the defs before any uses to catch forward references.
+  module->ForEachInst(
+      std::bind(&DefUseManager::AnalyzeInstDef, this, std::placeholders::_1));
+  module->ForEachInst(
+      std::bind(&DefUseManager::AnalyzeInstUse, this, std::placeholders::_1));
+}
+
+// Forgets everything recorded about |inst|: the uses it makes of other ids
+// and, if it has a result id, its definition and the entries of its users.
+// Does nothing if the uses of |inst| were never analyzed.
+void DefUseManager::ClearInst(Instruction* inst) {
+  auto iter = inst_to_used_ids_.find(inst);
+  if (iter != inst_to_used_ids_.end()) {
+    EraseUseRecordsOfOperandIds(inst);
+    if (inst->result_id() != 0) {
+      // Remove all uses of this inst.
+      auto users_begin = UsersBegin(inst);
+      auto end = id_to_users_.end();
+      auto new_end = users_begin;
+      // Advance |new_end| past the last entry that belongs to |inst|.
+      for (; UsersNotEnd(new_end, end, inst); ++new_end) {
+      }
+      id_to_users_.erase(users_begin, new_end);
+      id_to_def_.erase(inst->result_id());
+    }
+  }
+}
+
+// Removes the user entries that |inst| created for the ids it uses, together
+// with its list of used ids.
+void DefUseManager::EraseUseRecordsOfOperandIds(const Instruction* inst) {
+  // Go through all ids used by this instruction, remove this instruction's
+  // uses of them.
+  auto iter = inst_to_used_ids_.find(inst);
+  if (iter != inst_to_used_ids_.end()) {
+    for (auto use_id : iter->second) {
+      id_to_users_.erase(
+          UserEntry(GetDef(use_id), const_cast(inst)));
+    }
+    inst_to_used_ids_.erase(inst);
+  }
+}
+
+// Two managers are equal when their definitions, their user entries and their
+// per-instruction lists of used ids are all equal.
+//
+// Each mismatch used to be followed by two membership scans whose outcome was
+// ignored (every path returned false); the result is the same without them.
+bool operator==(const DefUseManager& lhs, const DefUseManager& rhs) {
+  if (lhs.id_to_def_ != rhs.id_to_def_) {
+    return false;
+  }
+
+  if (lhs.id_to_users_ != rhs.id_to_users_) {
+    return false;
+  }
+
+  if (lhs.inst_to_used_ids_ != rhs.inst_to_used_ids_) {
+    return false;
+  }
+  return true;
+}
+
+}  // namespace analysis
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/def_use_manager.h b/source/opt/def_use_manager.h
new file mode 100644
index 000000000..0499e82b5
--- /dev/null
+++ b/source/opt/def_use_manager.h
@@ -0,0 +1,256 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#ifndef SOURCE_OPT_DEF_USE_MANAGER_H_ +#define SOURCE_OPT_DEF_USE_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/module.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { +namespace analysis { + +// Class for representing a use of id. Note that: +// * Result type id is a use. +// * Ids referenced in OpSectionMerge & OpLoopMerge are considered as use. +// * Ids referenced in OpPhi's in operands are considered as use. +struct Use { + Instruction* inst; // Instruction using the id. + uint32_t operand_index; // logical operand index of the id use. This can be + // the index of result type id. +}; + +inline bool operator==(const Use& lhs, const Use& rhs) { + return lhs.inst == rhs.inst && lhs.operand_index == rhs.operand_index; +} + +inline bool operator!=(const Use& lhs, const Use& rhs) { return !(lhs == rhs); } + +inline bool operator<(const Use& lhs, const Use& rhs) { + if (lhs.inst < rhs.inst) return true; + if (lhs.inst > rhs.inst) return false; + return lhs.operand_index < rhs.operand_index; +} + +// Definition and user pair. +// +// The first element of the pair is the definition. +// The second element of the pair is the user. +// +// Definition should never be null. User can be null, however, such an entry +// should be used only for searching (e.g. all users of a particular definition) +// and never stored in a container. +using UserEntry = std::pair; + +// Orders UserEntry for use in associative containers (i.e. less than ordering). +// +// The definition of an UserEntry is treated as the major key and the users as +// the minor key so that all the users of a particular definition are +// consecutive in a container. +// +// A null user always compares less than a real user. This is done to provide +// easy values to search for the beginning of the users of a particular +// definition (i.e. using {def, nullptr}). 
+struct UserEntryLess { + bool operator()(const UserEntry& lhs, const UserEntry& rhs) const { + // If lhs.first and rhs.first are both null, fall through to checking the + // second entries. + if (!lhs.first && rhs.first) return true; + if (lhs.first && !rhs.first) return false; + + // If neither definition is null, then compare unique ids. + if (lhs.first && rhs.first) { + if (lhs.first->unique_id() < rhs.first->unique_id()) return true; + if (rhs.first->unique_id() < lhs.first->unique_id()) return false; + } + + // Return false on equality. + if (!lhs.second && !rhs.second) return false; + if (!lhs.second) return true; + if (!rhs.second) return false; + + // If neither user is null then compare unique ids. + return lhs.second->unique_id() < rhs.second->unique_id(); + } +}; + +// A class for analyzing and managing defs and uses in an Module. +class DefUseManager { + public: + using IdToDefMap = std::unordered_map; + using IdToUsersMap = std::set; + + // Constructs a def-use manager from the given |module|. All internal messages + // will be communicated to the outside via the given message |consumer|. This + // instance only keeps a reference to the |consumer|, so the |consumer| should + // outlive this instance. + DefUseManager(Module* module) { AnalyzeDefUse(module); } + + DefUseManager(const DefUseManager&) = delete; + DefUseManager(DefUseManager&&) = delete; + DefUseManager& operator=(const DefUseManager&) = delete; + DefUseManager& operator=(DefUseManager&&) = delete; + + // Analyzes the defs in the given |inst|. + void AnalyzeInstDef(Instruction* inst); + + // Analyzes the uses in the given |inst|. + // + // All operands of |inst| must be analyzed as defs. + void AnalyzeInstUse(Instruction* inst); + + // Analyzes the defs and uses in the given |inst|. + void AnalyzeInstDefUse(Instruction* inst); + + // Returns the def instruction for the given |id|. If there is no instruction + // defining |id|, returns nullptr. 
+ Instruction* GetDef(uint32_t id); + const Instruction* GetDef(uint32_t id) const; + + // Runs the given function |f| on each unique user instruction of |def| (or + // |id|). + // + // If one instruction uses |def| in multiple operands, that instruction will + // only be visited once. + // + // |def| (or |id|) must be registered as a definition. + void ForEachUser(const Instruction* def, + const std::function& f) const; + void ForEachUser(uint32_t id, + const std::function& f) const; + + // Runs the given function |f| on each unique user instruction of |def| (or + // |id|). If |f| returns false, iteration is terminated and this function + // returns false. + // + // If one instruction uses |def| in multiple operands, that instruction will + // be only be visited once. + // + // |def| (or |id|) must be registered as a definition. + bool WhileEachUser(const Instruction* def, + const std::function& f) const; + bool WhileEachUser(uint32_t id, + const std::function& f) const; + + // Runs the given function |f| on each unique use of |def| (or + // |id|). + // + // If one instruction uses |def| in multiple operands, each operand will be + // visited separately. + // + // |def| (or |id|) must be registered as a definition. + void ForEachUse( + const Instruction* def, + const std::function& f) const; + void ForEachUse( + uint32_t id, + const std::function& f) const; + + // Runs the given function |f| on each unique use of |def| (or + // |id|). If |f| returns false, iteration is terminated and this function + // returns false. + // + // If one instruction uses |def| in multiple operands, each operand will be + // visited separately. + // + // |def| (or |id|) must be registered as a definition. + bool WhileEachUse( + const Instruction* def, + const std::function& f) const; + bool WhileEachUse( + uint32_t id, + const std::function& f) const; + + // Returns the number of users of |def| (or |id|). 
+ uint32_t NumUsers(const Instruction* def) const; + uint32_t NumUsers(uint32_t id) const; + + // Returns the number of uses of |def| (or |id|). + uint32_t NumUses(const Instruction* def) const; + uint32_t NumUses(uint32_t id) const; + + // Returns the annotation instrunctions which are a direct use of the given + // |id|. This means when the decorations are applied through decoration + // group(s), this function will just return the OpGroupDecorate + // instrcution(s) which refer to the given id as an operand. The OpDecorate + // instructions which decorate the decoration group will not be returned. + std::vector GetAnnotations(uint32_t id) const; + + // Returns the map from ids to their def instructions. + const IdToDefMap& id_to_defs() const { return id_to_def_; } + // Returns the map from instructions to their users. + const IdToUsersMap& id_to_users() const { return id_to_users_; } + + // Clear the internal def-use record of the given instruction |inst|. This + // method will update the use information of the operand ids of |inst|. The + // record: |inst| uses an |id|, will be removed from the use records of |id|. + // If |inst| defines an result id, the use record of this result id will also + // be removed. Does nothing if |inst| was not analyzed before. + void ClearInst(Instruction* inst); + + // Erases the records that a given instruction uses its operand ids. + void EraseUseRecordsOfOperandIds(const Instruction* inst); + + friend bool operator==(const DefUseManager&, const DefUseManager&); + friend bool operator!=(const DefUseManager& lhs, const DefUseManager& rhs) { + return !(lhs == rhs); + } + + // If |inst| has not already been analysed, then analyses its defintion and + // uses. + void UpdateDefUse(Instruction* inst); + + private: + using InstToUsedIdsMap = + std::unordered_map>; + + // Returns the first location that {|def|, nullptr} could be inserted into the + // users map without violating ordering. 
+ IdToUsersMap::const_iterator UsersBegin(const Instruction* def) const; + + // Returns true if |iter| has not reached the end of |def|'s users. + // + // In the first version |iter| is compared against the end of the map for + // validity before other checks. In the second version, |iter| is compared + // against |cached_end| for validity before other checks. This allows caching + // the map's end which is a performance improvement on some platforms. + bool UsersNotEnd(const IdToUsersMap::const_iterator& iter, + const Instruction* def) const; + bool UsersNotEnd(const IdToUsersMap::const_iterator& iter, + const IdToUsersMap::const_iterator& cached_end, + const Instruction* def) const; + + // Analyzes the defs and uses in the given |module| and populates data + // structures in this class. Does nothing if |module| is nullptr. + void AnalyzeDefUse(Module* module); + + IdToDefMap id_to_def_; // Mapping from ids to their definitions + IdToUsersMap id_to_users_; // Mapping from ids to their users + // Mapping from instructions to the ids used in the instruction. + InstToUsedIdsMap inst_to_used_ids_; +}; + +} // namespace analysis +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DEF_USE_MANAGER_H_ diff --git a/source/opt/dominator_analysis.cpp b/source/opt/dominator_analysis.cpp new file mode 100644 index 000000000..aef43e69f --- /dev/null +++ b/source/opt/dominator_analysis.cpp @@ -0,0 +1,68 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/dominator_analysis.h" + +#include + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +BasicBlock* DominatorAnalysisBase::CommonDominator(BasicBlock* b1, + BasicBlock* b2) const { + if (!b1 || !b2) return nullptr; + + std::unordered_set seen; + BasicBlock* block = b1; + while (block && seen.insert(block).second) { + block = ImmediateDominator(block); + } + + block = b2; + while (block && !seen.count(block)) { + block = ImmediateDominator(block); + } + + return block; +} + +bool DominatorAnalysisBase::Dominates(Instruction* a, Instruction* b) const { + if (!a || !b) { + return false; + } + + if (a == b) { + return true; + } + + BasicBlock* bb_a = a->context()->get_instr_block(a); + BasicBlock* bb_b = b->context()->get_instr_block(b); + + if (bb_a != bb_b) { + return tree_.Dominates(bb_a, bb_b); + } + + Instruction* current_inst = a; + while ((current_inst = current_inst->NextNode())) { + if (current_inst == b) { + return true; + } + } + return false; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/dominator_analysis.h b/source/opt/dominator_analysis.h new file mode 100644 index 000000000..a94120a55 --- /dev/null +++ b/source/opt/dominator_analysis.h @@ -0,0 +1,138 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_DOMINATOR_ANALYSIS_H_ +#define SOURCE_OPT_DOMINATOR_ANALYSIS_H_ + +#include +#include + +#include "source/opt/dominator_tree.h" + +namespace spvtools { +namespace opt { + +// Interface to perform dominator or postdominator analysis on a given function. +class DominatorAnalysisBase { + public: + explicit DominatorAnalysisBase(bool is_post_dom) : tree_(is_post_dom) {} + + // Calculates the dominator (or postdominator) tree for given function |f|. + inline void InitializeTree(const CFG& cfg, const Function* f) { + tree_.InitializeTree(cfg, f); + } + + // Returns true if BasicBlock |a| dominates BasicBlock |b|. + inline bool Dominates(const BasicBlock* a, const BasicBlock* b) const { + if (!a || !b) return false; + return Dominates(a->id(), b->id()); + } + + // Returns true if BasicBlock |a| dominates BasicBlock |b|. Same as above only + // using the BasicBlock IDs. + inline bool Dominates(uint32_t a, uint32_t b) const { + return tree_.Dominates(a, b); + } + + // Returns true if instruction |a| dominates instruction |b|. + bool Dominates(Instruction* a, Instruction* b) const; + + // Returns true if BasicBlock |a| strictly dominates BasicBlock |b|. + inline bool StrictlyDominates(const BasicBlock* a, + const BasicBlock* b) const { + if (!a || !b) return false; + return StrictlyDominates(a->id(), b->id()); + } + + // Returns true if BasicBlock |a| strictly dominates BasicBlock |b|. Same as + // above only using the BasicBlock IDs. + inline bool StrictlyDominates(uint32_t a, uint32_t b) const { + return tree_.StrictlyDominates(a, b); + } + + // Returns the immediate dominator of |node| or returns nullptr if it is has + // no dominator. + inline BasicBlock* ImmediateDominator(const BasicBlock* node) const { + if (!node) return nullptr; + return tree_.ImmediateDominator(node); + } + + // Returns the immediate dominator of |node_id| or returns nullptr if it is + // has no dominator. Same as above but operates on IDs. 
+ inline BasicBlock* ImmediateDominator(uint32_t node_id) const { + return tree_.ImmediateDominator(node_id); + } + + // Returns true if |node| is reachable from the entry. + inline bool IsReachable(const BasicBlock* node) const { + if (!node) return false; + return tree_.ReachableFromRoots(node->id()); + } + + // Returns true if |node_id| is reachable from the entry. + inline bool IsReachable(uint32_t node_id) const { + return tree_.ReachableFromRoots(node_id); + } + + // Dump the tree structure into the given |out| stream in the dot format. + inline void DumpAsDot(std::ostream& out) const { tree_.DumpTreeAsDot(out); } + + // Returns true if this is a postdomiator tree. + inline bool IsPostDominator() const { return tree_.IsPostDominator(); } + + // Returns the tree itself for manual operations, such as traversing the + // roots. + // For normal dominance relationships the methods above should be used. + inline DominatorTree& GetDomTree() { return tree_; } + inline const DominatorTree& GetDomTree() const { return tree_; } + + // Force the dominator tree to be removed + inline void ClearTree() { tree_.ClearTree(); } + + // Applies the std::function |func| to dominator tree nodes in dominator + // order. + void Visit(std::function func) { + tree_.Visit(func); + } + + // Applies the std::function |func| to dominator tree nodes in dominator + // order. + void Visit(std::function func) const { + tree_.Visit(func); + } + + // Returns the most immediate basic block that dominates both |b1| and |b2|. + // If there is no such basic block, nullptr is returned. + BasicBlock* CommonDominator(BasicBlock* b1, BasicBlock* b2) const; + + protected: + DominatorTree tree_; +}; + +// Derived class for normal dominator analysis. +class DominatorAnalysis : public DominatorAnalysisBase { + public: + DominatorAnalysis() : DominatorAnalysisBase(false) {} +}; + +// Derived class for postdominator analysis. 
+class PostDominatorAnalysis : public DominatorAnalysisBase { + public: + PostDominatorAnalysis() : DominatorAnalysisBase(true) {} +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DOMINATOR_ANALYSIS_H_ diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp new file mode 100644 index 000000000..c9346e1c5 --- /dev/null +++ b/source/opt/dominator_tree.cpp @@ -0,0 +1,394 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "source/cfa.h" +#include "source/opt/dominator_tree.h" +#include "source/opt/ir_context.h" + +// Calculates the dominator or postdominator tree for a given function. +// 1 - Compute the successors and predecessors for each BasicBlock. We add a +// dummy node for the start node or for postdominators the exit. This node will +// point to all entry or all exit nodes. +// 2 - Using the CFA::DepthFirstTraversal get a depth first postordered list of +// all BasicBlocks. Using the successors (or for postdominator, predecessors) +// calculated in step 1 to traverse the tree. +// 3 - Pass the list calculated in step 2 to the CFA::CalculateDominators using +// the predecessors list (or for postdominator, successors). This will give us a +// vector of BB pairs. Each BB and its immediate dominator. +// 4 - Using the list from 3 use those edges to build a tree of +// DominatorTreeNodes. 
Each node containing a link to the parent dominator and +// children which are dominated. +// 5 - Using the tree from 4, perform a depth first traversal to calculate the +// preorder and postorder index of each node. We use these indexes to compare +// nodes against each other for domination checks. + +namespace spvtools { +namespace opt { +namespace { + +// Wrapper around CFA::DepthFirstTraversal to provide an interface to perform +// depth first search on generic BasicBlock types. Will call post and pre order +// user defined functions during traversal +// +// BBType - BasicBlock type. Will either be BasicBlock or DominatorTreeNode +// SuccessorLambda - Lamdba matching the signature of 'const +// std::vector*(const BBType *A)'. Will return a vector of the nodes +// succeding BasicBlock A. +// PostLambda - Lamdba matching the signature of 'void (const BBType*)' will be +// called on each node traversed AFTER their children. +// PreLambda - Lamdba matching the signature of 'void (const BBType*)' will be +// called on each node traversed BEFORE their children. +template +static void DepthFirstSearch(const BBType* bb, SuccessorLambda successors, + PreLambda pre, PostLambda post) { + // Ignore backedge operation. + auto nop_backedge = [](const BBType*, const BBType*) {}; + CFA::DepthFirstTraversal(bb, successors, pre, post, nop_backedge); +} + +// Wrapper around CFA::DepthFirstTraversal to provide an interface to perform +// depth first search on generic BasicBlock types. This overload is for only +// performing user defined post order. +// +// BBType - BasicBlock type. Will either be BasicBlock or DominatorTreeNode +// SuccessorLambda - Lamdba matching the signature of 'const +// std::vector*(const BBType *A)'. Will return a vector of the nodes +// succeding BasicBlock A. +// PostLambda - Lamdba matching the signature of 'void (const BBType*)' will be +// called on each node traversed after their children. 
+template +static void DepthFirstSearchPostOrder(const BBType* bb, + SuccessorLambda successors, + PostLambda post) { + // Ignore preorder operation. + auto nop_preorder = [](const BBType*) {}; + DepthFirstSearch(bb, successors, nop_preorder, post); +} + +// Small type trait to get the function class type. +template +struct GetFunctionClass { + using FunctionType = Function; +}; + +// Helper class to compute predecessors and successors for each Basic Block in a +// function. Through GetPredFunctor and GetSuccessorFunctor it provides an +// interface to get the successor and predecessor lists for each basic +// block. This is required by the DepthFirstTraversal and ComputeDominator +// functions which take as parameter an std::function returning the successors +// and predecessors respectively. +// +// When computing the post-dominator tree, all edges are inverted. So successors +// returned by this class will be predecessors in the original CFG. +template +class BasicBlockSuccessorHelper { + // This should eventually become const BasicBlock. + using BasicBlock = BBType; + using Function = typename GetFunctionClass::FunctionType; + + using BasicBlockListTy = std::vector; + using BasicBlockMapTy = std::map; + + public: + // For compliance with the dominance tree computation, entry nodes are + // connected to a single dummy node. + BasicBlockSuccessorHelper(Function& func, const BasicBlock* dummy_start_node, + bool post); + + // CFA::CalculateDominators requires std::vector. + using GetBlocksFunction = + std::function*(const BasicBlock*)>; + + // Returns the list of predecessor functions. + GetBlocksFunction GetPredFunctor() { + return [this](const BasicBlock* bb) { + BasicBlockListTy* v = &this->predecessors_[bb]; + return v; + }; + } + + // Returns a vector of the list of successor nodes from a given node. 
+ GetBlocksFunction GetSuccessorFunctor() { + return [this](const BasicBlock* bb) { + BasicBlockListTy* v = &this->successors_[bb]; + return v; + }; + } + + private: + bool invert_graph_; + BasicBlockMapTy successors_; + BasicBlockMapTy predecessors_; + + // Build the successors and predecessors map for each basic blocks |f|. + // If |invert_graph_| is true, all edges are reversed (successors becomes + // predecessors and vice versa). + // For convenience, the start of the graph is |dummy_start_node|. + // The dominator tree construction requires a unique entry node, which cannot + // be guaranteed for the postdominator graph. The |dummy_start_node| BB is + // here to gather all entry nodes. + void CreateSuccessorMap(Function& f, const BasicBlock* dummy_start_node); +}; + +template +BasicBlockSuccessorHelper::BasicBlockSuccessorHelper( + Function& func, const BasicBlock* dummy_start_node, bool invert) + : invert_graph_(invert) { + CreateSuccessorMap(func, dummy_start_node); +} + +template +void BasicBlockSuccessorHelper::CreateSuccessorMap( + Function& f, const BasicBlock* dummy_start_node) { + std::map id_to_BB_map; + auto GetSuccessorBasicBlock = [&f, &id_to_BB_map](uint32_t successor_id) { + BasicBlock*& Succ = id_to_BB_map[successor_id]; + if (!Succ) { + for (BasicBlock& BBIt : f) { + if (successor_id == BBIt.id()) { + Succ = &BBIt; + break; + } + } + } + return Succ; + }; + + if (invert_graph_) { + // For the post dominator tree, we see the inverted graph. + // successors_ in the inverted graph are the predecessors in the CFG. + // The tree construction requires 1 entry point, so we add a dummy node + // that is connected to all function exiting basic blocks. + // An exiting basic block is a block with an OpKill, OpUnreachable, + // OpReturn or OpReturnValue as terminator instruction. 
+ for (BasicBlock& bb : f) { + if (bb.hasSuccessor()) { + BasicBlockListTy& pred_list = predecessors_[&bb]; + const auto& const_bb = bb; + const_bb.ForEachSuccessorLabel( + [this, &pred_list, &bb, + &GetSuccessorBasicBlock](const uint32_t successor_id) { + BasicBlock* succ = GetSuccessorBasicBlock(successor_id); + // Inverted graph: our successors in the CFG + // are our predecessors in the inverted graph. + this->successors_[succ].push_back(&bb); + pred_list.push_back(succ); + }); + } else { + successors_[dummy_start_node].push_back(&bb); + predecessors_[&bb].push_back(const_cast(dummy_start_node)); + } + } + } else { + successors_[dummy_start_node].push_back(f.entry().get()); + predecessors_[f.entry().get()].push_back( + const_cast(dummy_start_node)); + for (BasicBlock& bb : f) { + BasicBlockListTy& succ_list = successors_[&bb]; + + const auto& const_bb = bb; + const_bb.ForEachSuccessorLabel([&](const uint32_t successor_id) { + BasicBlock* succ = GetSuccessorBasicBlock(successor_id); + succ_list.push_back(succ); + predecessors_[succ].push_back(&bb); + }); + } + } +} + +} // namespace + +bool DominatorTree::StrictlyDominates(uint32_t a, uint32_t b) const { + if (a == b) return false; + return Dominates(a, b); +} + +bool DominatorTree::StrictlyDominates(const BasicBlock* a, + const BasicBlock* b) const { + return DominatorTree::StrictlyDominates(a->id(), b->id()); +} + +bool DominatorTree::StrictlyDominates(const DominatorTreeNode* a, + const DominatorTreeNode* b) const { + if (a == b) return false; + return Dominates(a, b); +} + +bool DominatorTree::Dominates(uint32_t a, uint32_t b) const { + // Check that both of the inputs are actual nodes. 
+ const DominatorTreeNode* a_node = GetTreeNode(a); + const DominatorTreeNode* b_node = GetTreeNode(b); + if (!a_node || !b_node) return false; + + return Dominates(a_node, b_node); +} + +bool DominatorTree::Dominates(const DominatorTreeNode* a, + const DominatorTreeNode* b) const { + // Node A dominates node B if they are the same. + if (a == b) return true; + + return a->dfs_num_pre_ < b->dfs_num_pre_ && + a->dfs_num_post_ > b->dfs_num_post_; +} + +bool DominatorTree::Dominates(const BasicBlock* A, const BasicBlock* B) const { + return Dominates(A->id(), B->id()); +} + +BasicBlock* DominatorTree::ImmediateDominator(const BasicBlock* A) const { + return ImmediateDominator(A->id()); +} + +BasicBlock* DominatorTree::ImmediateDominator(uint32_t a) const { + // Check that A is a valid node in the tree. + auto a_itr = nodes_.find(a); + if (a_itr == nodes_.end()) return nullptr; + + const DominatorTreeNode* node = &a_itr->second; + + if (node->parent_ == nullptr) { + return nullptr; + } + + return node->parent_->bb_; +} + +DominatorTreeNode* DominatorTree::GetOrInsertNode(BasicBlock* bb) { + DominatorTreeNode* dtn = nullptr; + + std::map::iterator node_iter = + nodes_.find(bb->id()); + if (node_iter == nodes_.end()) { + dtn = &nodes_.emplace(std::make_pair(bb->id(), DominatorTreeNode{bb})) + .first->second; + } else { + dtn = &node_iter->second; + } + + return dtn; +} + +void DominatorTree::GetDominatorEdges( + const Function* f, const BasicBlock* dummy_start_node, + std::vector>* edges) { + // Each time the depth first traversal calls the postorder callback + // std::function we push that node into the postorder vector to create our + // postorder list. + std::vector postorder; + auto postorder_function = [&](const BasicBlock* b) { + postorder.push_back(b); + }; + + // CFA::CalculateDominators requires std::vector + // BB are derived from F, so we need to const cast it at some point + // no modification is made on F. 
+ BasicBlockSuccessorHelper helper{ + *const_cast(f), dummy_start_node, postdominator_}; + + // The successor function tells DepthFirstTraversal how to move to successive + // nodes by providing an interface to get a list of successor nodes from any + // given node. + auto successor_functor = helper.GetSuccessorFunctor(); + + // The predecessor functor does the same as the successor functor + // but for all nodes preceding a given node. + auto predecessor_functor = helper.GetPredFunctor(); + + // If we're building a post dominator tree we traverse the tree in reverse + // using the predecessor function in place of the successor function and vice + // versa. + DepthFirstSearchPostOrder(dummy_start_node, successor_functor, + postorder_function); + *edges = CFA::CalculateDominators(postorder, predecessor_functor); +} + +void DominatorTree::InitializeTree(const CFG& cfg, const Function* f) { + ClearTree(); + + // Skip over empty functions. + if (f->cbegin() == f->cend()) { + return; + } + + const BasicBlock* dummy_start_node = + postdominator_ ? cfg.pseudo_exit_block() : cfg.pseudo_entry_block(); + + // Get the immediate dominator for each node. + std::vector> edges; + GetDominatorEdges(f, dummy_start_node, &edges); + + // Transform the vector into the tree structure which we can use to + // efficiently query dominance. 
+ for (auto edge : edges) { + DominatorTreeNode* first = GetOrInsertNode(edge.first); + + if (edge.first == edge.second) { + if (std::find(roots_.begin(), roots_.end(), first) == roots_.end()) + roots_.push_back(first); + continue; + } + + DominatorTreeNode* second = GetOrInsertNode(edge.second); + + first->parent_ = second; + second->children_.push_back(first); + } + ResetDFNumbering(); +} + +void DominatorTree::ResetDFNumbering() { + int index = 0; + auto preFunc = [&index](const DominatorTreeNode* node) { + const_cast(node)->dfs_num_pre_ = ++index; + }; + + auto postFunc = [&index](const DominatorTreeNode* node) { + const_cast(node)->dfs_num_post_ = ++index; + }; + + auto getSucc = [](const DominatorTreeNode* node) { return &node->children_; }; + + for (auto root : roots_) DepthFirstSearch(root, getSucc, preFunc, postFunc); +} + +void DominatorTree::DumpTreeAsDot(std::ostream& out_stream) const { + out_stream << "digraph {\n"; + Visit([&out_stream](const DominatorTreeNode* node) { + // Print the node. + if (node->bb_) { + out_stream << node->bb_->id() << "[label=\"" << node->bb_->id() + << "\"];\n"; + } + + // Print the arrow from the parent to this node. Entry nodes will not have + // parents so draw them as children from the dummy node. + if (node->parent_) { + out_stream << node->parent_->bb_->id() << " -> " << node->bb_->id() + << ";\n"; + } + + // Return true to continue the traversal. + return true; + }); + out_stream << "}\n"; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/dominator_tree.h b/source/opt/dominator_tree.h new file mode 100644 index 000000000..0024bc508 --- /dev/null +++ b/source/opt/dominator_tree.h @@ -0,0 +1,305 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_DOMINATOR_TREE_H_ +#define SOURCE_OPT_DOMINATOR_TREE_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/cfg.h" +#include "source/opt/tree_iterator.h" + +namespace spvtools { +namespace opt { +// This helper struct forms the nodes in the tree, with each node containing its +// children. It also contains two values, for the pre and post indexes in the +// tree which are used to compare two nodes. +struct DominatorTreeNode { + explicit DominatorTreeNode(BasicBlock* bb) + : bb_(bb), + parent_(nullptr), + children_({}), + dfs_num_pre_(-1), + dfs_num_post_(-1) {} + + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + // depth first preorder iterator. + using df_iterator = TreeDFIterator; + using const_df_iterator = TreeDFIterator; + // depth first postorder iterator. + using post_iterator = PostOrderTreeDFIterator; + using const_post_iterator = PostOrderTreeDFIterator; + + iterator begin() { return children_.begin(); } + iterator end() { return children_.end(); } + const_iterator begin() const { return cbegin(); } + const_iterator end() const { return cend(); } + const_iterator cbegin() const { return children_.begin(); } + const_iterator cend() const { return children_.end(); } + + // Depth first preorder iterator using this node as root. 
+ df_iterator df_begin() { return df_iterator(this); } + df_iterator df_end() { return df_iterator(); } + const_df_iterator df_begin() const { return df_cbegin(); } + const_df_iterator df_end() const { return df_cend(); } + const_df_iterator df_cbegin() const { return const_df_iterator(this); } + const_df_iterator df_cend() const { return const_df_iterator(); } + + // Depth first postorder iterator using this node as root. + post_iterator post_begin() { return post_iterator::begin(this); } + post_iterator post_end() { return post_iterator::end(nullptr); } + const_post_iterator post_begin() const { return post_cbegin(); } + const_post_iterator post_end() const { return post_cend(); } + const_post_iterator post_cbegin() const { + return const_post_iterator::begin(this); + } + const_post_iterator post_cend() const { + return const_post_iterator::end(nullptr); + } + + inline uint32_t id() const { return bb_->id(); } + + BasicBlock* bb_; + DominatorTreeNode* parent_; + std::vector children_; + + // These indexes are used to compare two given nodes. A node is a child or + // grandchild of another node if its preorder index is greater than the + // first nodes preorder index AND if its postorder index is less than the + // first nodes postorder index. + int dfs_num_pre_; + int dfs_num_post_; +}; + +// A class representing a tree of BasicBlocks in a given function, where each +// node is dominated by its parent. 
+class DominatorTree { + public: + // Map OpLabel ids to dominator tree nodes + using DominatorTreeNodeMap = std::map; + using iterator = TreeDFIterator; + using const_iterator = TreeDFIterator; + using post_iterator = PostOrderTreeDFIterator; + using const_post_iterator = PostOrderTreeDFIterator; + + // List of DominatorTreeNode to define the list of roots + using DominatorTreeNodeList = std::vector; + using roots_iterator = DominatorTreeNodeList::iterator; + using roots_const_iterator = DominatorTreeNodeList::const_iterator; + + DominatorTree() : postdominator_(false) {} + explicit DominatorTree(bool post) : postdominator_(post) {} + + // Depth first iterators. + // Traverse the dominator tree in a depth first pre-order. + // The pseudo-block is ignored. + iterator begin() { return ++iterator(GetRoot()); } + iterator end() { return iterator(); } + const_iterator begin() const { return cbegin(); } + const_iterator end() const { return cend(); } + const_iterator cbegin() const { return ++const_iterator(GetRoot()); } + const_iterator cend() const { return const_iterator(); } + + // Traverse the dominator tree in a depth first post-order. + // The pseudo-block is ignored. 
+ post_iterator post_begin() { return post_iterator::begin(GetRoot()); } + post_iterator post_end() { return post_iterator::end(GetRoot()); } + const_post_iterator post_begin() const { return post_cbegin(); } + const_post_iterator post_end() const { return post_cend(); } + const_post_iterator post_cbegin() const { + return const_post_iterator::begin(GetRoot()); + } + const_post_iterator post_cend() const { + return const_post_iterator::end(GetRoot()); + } + + roots_iterator roots_begin() { return roots_.begin(); } + roots_iterator roots_end() { return roots_.end(); } + roots_const_iterator roots_begin() const { return roots_cbegin(); } + roots_const_iterator roots_end() const { return roots_cend(); } + roots_const_iterator roots_cbegin() const { return roots_.begin(); } + roots_const_iterator roots_cend() const { return roots_.end(); } + + // Get the unique root of the tree. + // It is guaranteed to work on a dominator tree. + // post-dominator might have a list. + DominatorTreeNode* GetRoot() { + assert(roots_.size() == 1); + return *roots_.begin(); + } + + const DominatorTreeNode* GetRoot() const { + assert(roots_.size() == 1); + return *roots_.begin(); + } + + const DominatorTreeNodeList& Roots() const { return roots_; } + + // Dumps the tree in the graphviz dot format into the |out_stream|. + void DumpTreeAsDot(std::ostream& out_stream) const; + + // Build the (post-)dominator tree for the given control flow graph + // |cfg| and the function |f|. |f| must exist in the |cfg|. Any + // existing data in the dominator tree will be overwritten + void InitializeTree(const CFG& cfg, const Function* f); + + // Check if the basic block |a| dominates the basic block |b|. + bool Dominates(const BasicBlock* a, const BasicBlock* b) const; + + // Check if the basic block id |a| dominates the basic block id |b|. + bool Dominates(uint32_t a, uint32_t b) const; + + // Check if the dominator tree node |a| dominates the dominator tree node |b|. 
+ bool Dominates(const DominatorTreeNode* a, const DominatorTreeNode* b) const; + + // Check if the basic block |a| strictly dominates the basic block |b|. + bool StrictlyDominates(const BasicBlock* a, const BasicBlock* b) const; + + // Check if the basic block id |a| strictly dominates the basic block id |b|. + bool StrictlyDominates(uint32_t a, uint32_t b) const; + + // Check if the dominator tree node |a| strictly dominates the dominator tree + // node |b|. + bool StrictlyDominates(const DominatorTreeNode* a, + const DominatorTreeNode* b) const; + + // Returns the immediate dominator of basic block |A|. + BasicBlock* ImmediateDominator(const BasicBlock* A) const; + + // Returns the immediate dominator of basic block id |a|. + BasicBlock* ImmediateDominator(uint32_t a) const; + + // Returns true if the basic block |a| is reachable by this tree. A node would + // be unreachable if it cannot be reached by traversal from the start node or + // for a postdominator tree, cannot be reached from the exit nodes. + inline bool ReachableFromRoots(const BasicBlock* a) const { + if (!a) return false; + return ReachableFromRoots(a->id()); + } + + // Returns true if the basic block id |a| is reachable by this tree. + bool ReachableFromRoots(uint32_t a) const { + return GetTreeNode(a) != nullptr; + } + + // Returns true if this tree is a post dominator tree. + bool IsPostDominator() const { return postdominator_; } + + // Clean up the tree. + void ClearTree() { + nodes_.clear(); + roots_.clear(); + } + + // Applies |func| to each node in depth first pre-order (no per-node copy is + // made here). Stops and returns false as soon as |func| returns false. + bool Visit(std::function func) { + for (auto&& n : *this) { + if (!func(&n)) return false; + } + return true; + } + + // Applies |func| to each node in depth first pre-order (no per-node copy is + // made here). Stops and returns false as soon as |func| returns false. 
+ bool Visit(std::function func) const { + for (auto&& n : *this) { + if (!func(&n)) return false; + } + return true; + } + + // Applies the std::function |func| to all nodes in the dominator tree from + // |node| downwards. The boolean return from |func| is used to determine + // whether or not the children should also be traversed. Tree nodes are + // visited in a depth first pre-order. + void VisitChildrenIf(std::function func, + iterator node) { + if (func(&*node)) { + for (auto n : *node) { + VisitChildrenIf(func, n->df_begin()); + } + } + } + + // Returns the DominatorTreeNode associated with the basic block |bb|. + // If the |bb| is unknown to the dominator tree, it returns null. + inline DominatorTreeNode* GetTreeNode(BasicBlock* bb) { + return GetTreeNode(bb->id()); + } + // Returns the DominatorTreeNode associated with the basic block |bb|. + // If the |bb| is unknown to the dominator tree, it returns null. + inline const DominatorTreeNode* GetTreeNode(BasicBlock* bb) const { + return GetTreeNode(bb->id()); + } + + // Returns the DominatorTreeNode associated with the basic block id |id|. + // If the id |id| is unknown to the dominator tree, it returns null. + inline DominatorTreeNode* GetTreeNode(uint32_t id) { + DominatorTreeNodeMap::iterator node_iter = nodes_.find(id); + if (node_iter == nodes_.end()) { + return nullptr; + } + return &node_iter->second; + } + // Returns the DominatorTreeNode associated with the basic block id |id|. + // If the id |id| is unknown to the dominator tree, it returns null. + inline const DominatorTreeNode* GetTreeNode(uint32_t id) const { + DominatorTreeNodeMap::const_iterator node_iter = nodes_.find(id); + if (node_iter == nodes_.end()) { + return nullptr; + } + return &node_iter->second; + } + + // Adds the basic block |bb| to the tree structure if it doesn't already + // exist. + DominatorTreeNode* GetOrInsertNode(BasicBlock* bb); + + // Recomputes the DF numbering of the tree. 
+ void ResetDFNumbering(); + + private: + // Wrapper function which gets the list of pairs of each BasicBlock to its + // immediately dominating BasicBlock and stores the result in the |edges| + // parameter. + // + // The |edges| vector will contain the dominator tree as pairs of nodes. + // The first node in the pair is a node in the graph. The second node in the + // pair is its immediate dominator. + // The root of the tree has itself as immediate dominator. + void GetDominatorEdges( + const Function* f, const BasicBlock* dummy_start_node, + std::vector>* edges); + + // The roots of the tree. + std::vector roots_; + + // Pairs each basic block id to the tree node containing that basic block. + DominatorTreeNodeMap nodes_; + + // True if this is a post dominator tree. + bool postdominator_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_DOMINATOR_TREE_H_ diff --git a/source/opt/eliminate_dead_constant_pass.cpp b/source/opt/eliminate_dead_constant_pass.cpp new file mode 100644 index 000000000..d368bd145 --- /dev/null +++ b/source/opt/eliminate_dead_constant_pass.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/eliminate_dead_constant_pass.h" + +#include +#include +#include +#include + +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/log.h" +#include "source/opt/reflect.h" + +namespace spvtools { +namespace opt { + +Pass::Status EliminateDeadConstantPass::Process() { + std::unordered_set working_list; + // Traverse all the instructions to get the initial set of dead constants as + // working list and count number of real uses for constants. Uses in + // annotation and debug instructions do not count. + std::unordered_map use_counts; + std::vector constants = context()->GetConstants(); + for (auto* c : constants) { + uint32_t const_id = c->result_id(); + size_t count = 0; + context()->get_def_use_mgr()->ForEachUse( + const_id, [&count](Instruction* user, uint32_t index) { + (void)index; + SpvOp op = user->opcode(); + if (!(IsAnnotationInst(op) || IsDebug1Inst(op) || IsDebug2Inst(op) || + IsDebug3Inst(op))) { + ++count; + } + }); + use_counts[c] = count; + if (!count) { + working_list.insert(c); + } + } + + // Start from the constants with 0 uses, back trace through the def-use chain + // to find all dead constants. + std::unordered_set dead_consts; + while (!working_list.empty()) { + Instruction* inst = *working_list.begin(); + // Back propagate if the instruction contains IDs in its operands. + switch (inst->opcode()) { + case SpvOp::SpvOpConstantComposite: + case SpvOp::SpvOpSpecConstantComposite: + case SpvOp::SpvOpSpecConstantOp: + for (uint32_t i = 0; i < inst->NumInOperands(); i++) { + // SpecConstantOp instruction contains 'opcode' as its operand. Need + // to exclude such operands when decreasing uses. 
+ if (inst->GetInOperand(i).type != SPV_OPERAND_TYPE_ID) { + continue; + } + uint32_t operand_id = inst->GetSingleWordInOperand(i); + Instruction* def_inst = + context()->get_def_use_mgr()->GetDef(operand_id); + // If the use_count does not have any count for the def_inst, + // def_inst must not be a constant, and should be ignored here. + if (!use_counts.count(def_inst)) { + continue; + } + // The number of uses should never be less than 0, so it can not be + // less than 1 before it decreases. + SPIRV_ASSERT(consumer(), use_counts[def_inst] > 0); + --use_counts[def_inst]; + if (!use_counts[def_inst]) { + working_list.insert(def_inst); + } + } + break; + default: + break; + } + dead_consts.insert(inst); + working_list.erase(inst); + } + + // Turn all dead instructions and uses of them to nop + for (auto* dc : dead_consts) { + context()->KillDef(dc->result_id()); + } + return dead_consts.empty() ? Status::SuccessWithoutChange + : Status::SuccessWithChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/eliminate_dead_constant_pass.h b/source/opt/eliminate_dead_constant_pass.h new file mode 100644 index 000000000..01692dbf4 --- /dev/null +++ b/source/opt/eliminate_dead_constant_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ +#define SOURCE_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class EliminateDeadConstantPass : public Pass { + public: + const char* name() const override { return "eliminate-dead-const"; } + Status Process() override; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_ELIMINATE_DEAD_CONSTANT_PASS_H_ diff --git a/source/opt/eliminate_dead_functions_pass.cpp b/source/opt/eliminate_dead_functions_pass.cpp new file mode 100644 index 000000000..f067be5f9 --- /dev/null +++ b/source/opt/eliminate_dead_functions_pass.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/eliminate_dead_functions_pass.h" + +#include + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +Pass::Status EliminateDeadFunctionsPass::Process() { + // Identify live functions first. Those that are not live + // are dead. 
+ std::unordered_set live_function_set; + ProcessFunction mark_live = [&live_function_set](Function* fp) { + live_function_set.insert(fp); + return false; + }; + context()->ProcessReachableCallTree(mark_live); + + bool modified = false; + for (auto funcIter = get_module()->begin(); + funcIter != get_module()->end();) { + if (live_function_set.count(&*funcIter) == 0) { + modified = true; + EliminateFunction(&*funcIter); + funcIter = funcIter.Erase(); + } else { + ++funcIter; + } + } + + return modified ? Pass::Status::SuccessWithChange + : Pass::Status::SuccessWithoutChange; +} + +void EliminateDeadFunctionsPass::EliminateFunction(Function* func) { + // Remove all of the instruction in the function body + func->ForEachInst([this](Instruction* inst) { context()->KillInst(inst); }, + true); +} +} // namespace opt +} // namespace spvtools diff --git a/source/opt/eliminate_dead_functions_pass.h b/source/opt/eliminate_dead_functions_pass.h new file mode 100644 index 000000000..6ed5c42b0 --- /dev/null +++ b/source/opt/eliminate_dead_functions_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ +#define SOURCE_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ + +#include "source/opt/def_use_manager.h" +#include "source/opt/function.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class EliminateDeadFunctionsPass : public MemPass { + public: + const char* name() const override { return "eliminate-dead-functions"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + void EliminateFunction(Function* func); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_ELIMINATE_DEAD_FUNCTIONS_PASS_H_ diff --git a/source/opt/feature_manager.cpp b/source/opt/feature_manager.cpp new file mode 100644 index 000000000..b7fc16a50 --- /dev/null +++ b/source/opt/feature_manager.cpp @@ -0,0 +1,67 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/feature_manager.h" + +#include +#include +#include + +#include "source/enum_string_mapping.h" + +namespace spvtools { +namespace opt { + +void FeatureManager::Analyze(Module* module) { + AddExtensions(module); + AddCapabilities(module); + AddExtInstImportIds(module); +} + +void FeatureManager::AddExtensions(Module* module) { + for (auto ext : module->extensions()) { + const std::string name = + reinterpret_cast(ext.GetInOperand(0u).words.data()); + Extension extension; + if (GetExtensionFromString(name.c_str(), &extension)) { + extensions_.Add(extension); + } + } +} + +void FeatureManager::AddCapability(SpvCapability cap) { + if (capabilities_.Contains(cap)) return; + + capabilities_.Add(cap); + + spv_operand_desc desc = {}; + if (SPV_SUCCESS == + grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) { + CapabilitySet(desc->numCapabilities, desc->capabilities) + .ForEach([this](SpvCapability c) { AddCapability(c); }); + } +} + +void FeatureManager::AddCapabilities(Module* module) { + for (Instruction& inst : module->capabilities()) { + AddCapability(static_cast(inst.GetSingleWordInOperand(0))); + } +} + +void FeatureManager::AddExtInstImportIds(Module* module) { + extinst_importid_GLSLstd450_ = module->GetExtInstImportId("GLSL.std.450"); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/feature_manager.h b/source/opt/feature_manager.h new file mode 100644 index 000000000..80b2cccf6 --- /dev/null +++ b/source/opt/feature_manager.h @@ -0,0 +1,78 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_FEATURE_MANAGER_H_ +#define SOURCE_OPT_FEATURE_MANAGER_H_ + +#include "source/assembly_grammar.h" +#include "source/extensions.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// Tracks features enabled by a module. The IRContext has a FeatureManager. +class FeatureManager { + public: + explicit FeatureManager(const AssemblyGrammar& grammar) : grammar_(grammar) {} + + // Returns true if |ext| is an enabled extension in the module. + bool HasExtension(Extension ext) const { return extensions_.Contains(ext); } + + // Returns true if |cap| is an enabled capability in the module. + bool HasCapability(SpvCapability cap) const { + return capabilities_.Contains(cap); + } + + // Analyzes |module| and records enabled extensions and capabilities. + void Analyze(Module* module); + + CapabilitySet* GetCapabilities() { return &capabilities_; } + const CapabilitySet* GetCapabilities() const { return &capabilities_; } + + uint32_t GetExtInstImportId_GLSLstd450() const { + return extinst_importid_GLSLstd450_; + } + + private: + // Analyzes |module| and records enabled extensions. + void AddExtensions(Module* module); + + // Adds the given |capability| and all implied capabilities into the current + // FeatureManager. + void AddCapability(SpvCapability capability); + + // Analyzes |module| and records enabled capabilities. + void AddCapabilities(Module* module); + + // Analyzes |module| and records imported external instruction sets. 
+ void AddExtInstImportIds(Module* module); + + // Auxiliary object for querying SPIR-V grammar facts. + const AssemblyGrammar& grammar_; + + // The enabled extensions. + ExtensionSet extensions_; + + // The enabled capabilities. + CapabilitySet capabilities_; + + // Common external instruction import ids, cached for performance. + uint32_t extinst_importid_GLSLstd450_ = 0; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_FEATURE_MANAGER_H_ diff --git a/source/opt/flatten_decoration_pass.cpp b/source/opt/flatten_decoration_pass.cpp new file mode 100644 index 000000000..f4de9116f --- /dev/null +++ b/source/opt/flatten_decoration_pass.cpp @@ -0,0 +1,165 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/flatten_decoration_pass.h" + +#include +#include +#include +#include +#include +#include + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +using Words = std::vector; +using OrderedUsesMap = std::unordered_map; + +Pass::Status FlattenDecorationPass::Process() { + bool modified = false; + + // The target Id of OpDecorationGroup instructions. + // We have to track this separately from its uses, in case it + // has no uses. + std::unordered_set group_ids; + // Maps a decoration group Id to its GroupDecorate targets, in order + // of appearance. 
+ OrderedUsesMap normal_uses; + // Maps a decoration group Id to its GroupMemberDecorate targets and + // their indices, in order of appearance. + OrderedUsesMap member_uses; + + auto annotations = context()->annotations(); + + // On the first pass, record each OpDecorationGroup with its ordered uses. + // Rely on unordered_map::operator[] to create its entries on first access. + for (const auto& inst : annotations) { + switch (inst.opcode()) { + case SpvOp::SpvOpDecorationGroup: + group_ids.insert(inst.result_id()); + break; + case SpvOp::SpvOpGroupDecorate: { + Words& words = normal_uses[inst.GetSingleWordInOperand(0)]; + for (uint32_t i = 1; i < inst.NumInOperandWords(); i++) { + words.push_back(inst.GetSingleWordInOperand(i)); + } + } break; + case SpvOp::SpvOpGroupMemberDecorate: { + Words& words = member_uses[inst.GetSingleWordInOperand(0)]; + for (uint32_t i = 1; i < inst.NumInOperandWords(); i++) { + words.push_back(inst.GetSingleWordInOperand(i)); + } + } break; + default: + break; + } + } + + // On the second pass, replace OpDecorationGroup and its uses with + // equivalent normal and struct member uses. + auto inst_iter = annotations.begin(); + // We have to re-evaluate the end pointer + while (inst_iter != context()->annotations().end()) { + // Should we replace this instruction? + bool replace = false; + switch (inst_iter->opcode()) { + case SpvOp::SpvOpDecorationGroup: + case SpvOp::SpvOpGroupDecorate: + case SpvOp::SpvOpGroupMemberDecorate: + replace = true; + break; + case SpvOp::SpvOpDecorate: { + // If this decoration targets a group, then replace it + // by sets of normal and member decorations. 
+ const uint32_t group = inst_iter->GetSingleWordOperand(0); + const auto normal_uses_iter = normal_uses.find(group); + if (normal_uses_iter != normal_uses.end()) { + for (auto target : normal_uses[group]) { + std::unique_ptr new_inst(inst_iter->Clone(context())); + new_inst->SetInOperand(0, Words{target}); + inst_iter = inst_iter.InsertBefore(std::move(new_inst)); + ++inst_iter; + replace = true; + } + } + const auto member_uses_iter = member_uses.find(group); + if (member_uses_iter != member_uses.end()) { + const Words& member_id_pairs = (*member_uses_iter).second; + // The collection is a sequence of pairs. + assert((member_id_pairs.size() % 2) == 0); + for (size_t i = 0; i < member_id_pairs.size(); i += 2) { + // Make an OpMemberDecorate instruction for each (target, member) + // pair. + const uint32_t target = member_id_pairs[i]; + const uint32_t member = member_id_pairs[i + 1]; + std::vector operands; + operands.push_back(Operand(SPV_OPERAND_TYPE_ID, {target})); + operands.push_back( + Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {member})); + auto decoration_operands_iter = inst_iter->begin(); + decoration_operands_iter++; // Skip the group target. + operands.insert(operands.end(), decoration_operands_iter, + inst_iter->end()); + std::unique_ptr new_inst(new Instruction( + context(), SpvOp::SpvOpMemberDecorate, 0, 0, operands)); + inst_iter = inst_iter.InsertBefore(std::move(new_inst)); + ++inst_iter; + replace = true; + } + } + // If this is an OpDecorate targeting the OpDecorationGroup itself, + // remove it even if that decoration group itself is not the target of + // any OpGroupDecorate or OpGroupMemberDecorate. + if (!replace && group_ids.count(group)) { + replace = true; + } + } break; + default: + break; + } + if (replace) { + inst_iter = inst_iter.Erase(); + modified = true; + } else { + // Handle the case of decorations unrelated to decoration groups. 
+ ++inst_iter; + } + } + + // Remove OpName instructions which reference the removed group decorations. + // An OpDecorationGroup instruction might not have been used by an + // OpGroupDecorate or OpGroupMemberDecorate instruction. + if (!group_ids.empty()) { + for (auto debug_inst_iter = context()->debug2_begin(); + debug_inst_iter != context()->debug2_end();) { + if (debug_inst_iter->opcode() == SpvOp::SpvOpName && + group_ids.count(debug_inst_iter->GetSingleWordOperand(0))) { + // Erase() hands back the iterator to continue from. + debug_inst_iter = debug_inst_iter.Erase(); + modified = true; + } else { + // Every instruction that is kept must be stepped over, or the loop hangs. + ++debug_inst_iter; + } + } + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/flatten_decoration_pass.h b/source/opt/flatten_decoration_pass.h new file mode 100644 index 000000000..6a34f5bb2 --- /dev/null +++ b/source/opt/flatten_decoration_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_FLATTEN_DECORATION_PASS_H_ +#define SOURCE_OPT_FLATTEN_DECORATION_PASS_H_ + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. 
+class FlattenDecorationPass : public Pass { + public: + const char* name() const override { return "flatten-decorations"; } + Status Process() override; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_FLATTEN_DECORATION_PASS_H_ diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp new file mode 100644 index 000000000..944f43870 --- /dev/null +++ b/source/opt/fold.cpp @@ -0,0 +1,706 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/fold.h" + +#include +#include +#include + +#include "source/opt/const_folding_rules.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/folding_rules.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace { + +#ifndef INT32_MIN +#define INT32_MIN (-2147483648) +#endif + +#ifndef INT32_MAX +#define INT32_MAX 2147483647 +#endif + +#ifndef UINT32_MAX +#define UINT32_MAX 0xffffffff /* 4294967295U */ +#endif + +} // namespace + +uint32_t InstructionFolder::UnaryOperate(SpvOp opcode, uint32_t operand) const { + switch (opcode) { + // Arthimetics + case SpvOp::SpvOpSNegate: { + int32_t s_operand = static_cast(operand); + if (s_operand == std::numeric_limits::min()) { + return s_operand; + } + return -s_operand; + } + case SpvOp::SpvOpNot: + return ~operand; + case SpvOp::SpvOpLogicalNot: + return !static_cast(operand); + default: + assert(false && + "Unsupported unary operation for OpSpecConstantOp instruction"); + return 0u; + } +} + +uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a, + uint32_t b) const { + switch (opcode) { + // Arthimetics + case SpvOp::SpvOpIAdd: + return a + b; + case SpvOp::SpvOpISub: + return a - b; + case SpvOp::SpvOpIMul: + return a * b; + case SpvOp::SpvOpUDiv: + if (b != 0) { + return a / b; + } else { + // Dividing by 0 is undefined, so we will just pick 0. + return 0; + } + case SpvOp::SpvOpSDiv: + if (b != 0u) { + return (static_cast(a)) / (static_cast(b)); + } else { + // Dividing by 0 is undefined, so we will just pick 0. + return 0; + } + case SpvOp::SpvOpSRem: { + // The sign of non-zero result comes from the first operand: a. This is + // guaranteed by C++11 rules for integer division operator. The division + // result is rounded toward zero, so the result of '%' has the sign of + // the first operand. 
+ if (b != 0u) { + return static_cast(a) % static_cast(b); + } else { + // Remainder when dividing with 0 is undefined, so we will just pick 0. + return 0; + } + } + case SpvOp::SpvOpSMod: { + // The sign of non-zero result comes from the second operand: b + if (b != 0u) { + int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b); + int32_t b_prim = static_cast(b); + return (rem + b_prim) % b_prim; + } else { + // Mod with 0 is undefined, so we will just pick 0. + return 0; + } + } + case SpvOp::SpvOpUMod: + if (b != 0u) { + return (a % b); + } else { + // Mod with 0 is undefined, so we will just pick 0. + return 0; + } + + // Shifting + case SpvOp::SpvOpShiftRightLogical: + if (b >= 32) { + // This is undefined behaviour when |b| > 32. Choose 0 for consistency. + // When |b| == 32, doing the shift in C++ in undefined, but the result + // will be 0, so just return that value. + return 0; + } + return a >> b; + case SpvOp::SpvOpShiftRightArithmetic: + if (b > 32) { + // This is undefined behaviour. Choose 0 for consistency. + return 0; + } + if (b == 32) { + // Doing the shift in C++ is undefined, but the result is defined in the + // spir-v spec. Find that value another way. + if (static_cast(a) >= 0) { + return 0; + } else { + return static_cast(-1); + } + } + return (static_cast(a)) >> b; + case SpvOp::SpvOpShiftLeftLogical: + if (b >= 32) { + // This is undefined behaviour when |b| > 32. Choose 0 for consistency. + // When |b| == 32, doing the shift in C++ in undefined, but the result + // will be 0, so just return that value. 
+ return 0; + } + return a << b; + + // Bitwise operations + case SpvOp::SpvOpBitwiseOr: + return a | b; + case SpvOp::SpvOpBitwiseAnd: + return a & b; + case SpvOp::SpvOpBitwiseXor: + return a ^ b; + + // Logical + case SpvOp::SpvOpLogicalEqual: + return (static_cast(a)) == (static_cast(b)); + case SpvOp::SpvOpLogicalNotEqual: + return (static_cast(a)) != (static_cast(b)); + case SpvOp::SpvOpLogicalOr: + return (static_cast(a)) || (static_cast(b)); + case SpvOp::SpvOpLogicalAnd: + return (static_cast(a)) && (static_cast(b)); + + // Comparison + case SpvOp::SpvOpIEqual: + return a == b; + case SpvOp::SpvOpINotEqual: + return a != b; + case SpvOp::SpvOpULessThan: + return a < b; + case SpvOp::SpvOpSLessThan: + return (static_cast(a)) < (static_cast(b)); + case SpvOp::SpvOpUGreaterThan: + return a > b; + case SpvOp::SpvOpSGreaterThan: + return (static_cast(a)) > (static_cast(b)); + case SpvOp::SpvOpULessThanEqual: + return a <= b; + case SpvOp::SpvOpSLessThanEqual: + return (static_cast(a)) <= (static_cast(b)); + case SpvOp::SpvOpUGreaterThanEqual: + return a >= b; + case SpvOp::SpvOpSGreaterThanEqual: + return (static_cast(a)) >= (static_cast(b)); + default: + assert(false && + "Unsupported binary operation for OpSpecConstantOp instruction"); + return 0u; + } +} + +uint32_t InstructionFolder::TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, + uint32_t c) const { + switch (opcode) { + case SpvOp::SpvOpSelect: + return (static_cast(a)) ? 
b : c; + default: + assert(false && + "Unsupported ternary operation for OpSpecConstantOp instruction"); + return 0u; + } +} + +uint32_t InstructionFolder::OperateWords( + SpvOp opcode, const std::vector& operand_words) const { + switch (operand_words.size()) { + case 1: + return UnaryOperate(opcode, operand_words.front()); + case 2: + return BinaryOperate(opcode, operand_words.front(), operand_words.back()); + case 3: + return TernaryOperate(opcode, operand_words[0], operand_words[1], + operand_words[2]); + default: + assert(false && "Invalid number of operands"); + return 0; + } +} + +bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const { + auto identity_map = [](uint32_t id) { return id; }; + Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map); + if (folded_inst != nullptr) { + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}}); + return true; + } + + SpvOp opcode = inst->opcode(); + analysis::ConstantManager* const_manager = context_->get_constant_mgr(); + + std::vector constants = + const_manager->GetOperandConstants(inst); + + for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) { + if (rule(context_, inst, constants)) { + return true; + } + } + return false; +} + +// Returns the result of performing an operation on scalar constant operands. +// This function extracts the operand values as 32 bit words and returns the +// result in 32 bit word. Scalar constants with longer than 32-bit width are +// not accepted in this function. 
+uint32_t InstructionFolder::FoldScalars( + SpvOp opcode, + const std::vector& operands) const { + assert(IsFoldableOpcode(opcode) && + "Unhandled instruction opcode in FoldScalars"); + std::vector operand_values_in_raw_words; + for (const auto& operand : operands) { + if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) { + const auto& scalar_words = scalar->words(); + assert(scalar_words.size() == 1 && + "Scalar constants with longer than 32-bit width are not allowed " + "in FoldScalars()"); + operand_values_in_raw_words.push_back(scalar_words.front()); + } else if (operand->AsNullConstant()) { + operand_values_in_raw_words.push_back(0u); + } else { + assert(false && + "FoldScalars() only accepts ScalarConst or NullConst type of " + "constant"); + } + } + return OperateWords(opcode, operand_values_in_raw_words); +} + +bool InstructionFolder::FoldBinaryIntegerOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const { + SpvOp opcode = inst->opcode(); + analysis::ConstantManager* const_manger = context_->get_constant_mgr(); + + uint32_t ids[2]; + const analysis::IntConstant* constants[2]; + for (uint32_t i = 0; i < 2; i++) { + const Operand* operand = &inst->GetInOperand(i); + if (operand->type != SPV_OPERAND_TYPE_ID) { + return false; + } + ids[i] = id_map(operand->words[0]); + const analysis::Constant* constant = + const_manger->FindDeclaredConstant(ids[i]); + constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr); + } + + switch (opcode) { + // Arthimetics + case SpvOp::SpvOpIMul: + for (uint32_t i = 0; i < 2; i++) { + if (constants[i] != nullptr && constants[i]->IsZero()) { + *result = 0; + return true; + } + } + break; + case SpvOp::SpvOpUDiv: + case SpvOp::SpvOpSDiv: + case SpvOp::SpvOpSRem: + case SpvOp::SpvOpSMod: + case SpvOp::SpvOpUMod: + // This changes undefined behaviour (ie divide by 0) into a 0. 
+ for (uint32_t i = 0; i < 2; i++) { + if (constants[i] != nullptr && constants[i]->IsZero()) { + *result = 0; + return true; + } + } + break; + + // Shifting + case SpvOp::SpvOpShiftRightLogical: + case SpvOp::SpvOpShiftLeftLogical: + if (constants[1] != nullptr) { + // When shifting by a value larger than the size of the result, the + // result is undefined. We are setting the undefined behaviour to a + // result of 0. If the shift amount is the same as the size of the + // result, then the result is defined, and it 0. + uint32_t shift_amount = constants[1]->GetU32BitValue(); + if (shift_amount >= 32) { + *result = 0; + return true; + } + } + break; + + // Bitwise operations + case SpvOp::SpvOpBitwiseOr: + for (uint32_t i = 0; i < 2; i++) { + if (constants[i] != nullptr) { + // TODO: Change the mask against a value based on the bit width of the + // instruction result type. This way we can handle say 16-bit values + // as well. + uint32_t mask = constants[i]->GetU32BitValue(); + if (mask == 0xFFFFFFFF) { + *result = 0xFFFFFFFF; + return true; + } + } + } + break; + case SpvOp::SpvOpBitwiseAnd: + for (uint32_t i = 0; i < 2; i++) { + if (constants[i] != nullptr) { + if (constants[i]->IsZero()) { + *result = 0; + return true; + } + } + } + break; + + // Comparison + case SpvOp::SpvOpULessThan: + if (constants[0] != nullptr && + constants[0]->GetU32BitValue() == UINT32_MAX) { + *result = false; + return true; + } + if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) { + *result = false; + return true; + } + break; + case SpvOp::SpvOpSLessThan: + if (constants[0] != nullptr && + constants[0]->GetS32BitValue() == INT32_MAX) { + *result = false; + return true; + } + if (constants[1] != nullptr && + constants[1]->GetS32BitValue() == INT32_MIN) { + *result = false; + return true; + } + break; + case SpvOp::SpvOpUGreaterThan: + if (constants[0] != nullptr && constants[0]->IsZero()) { + *result = false; + return true; + } + if (constants[1] != nullptr && + 
constants[1]->GetU32BitValue() == UINT32_MAX) { + *result = false; + return true; + } + break; + case SpvOp::SpvOpSGreaterThan: + if (constants[0] != nullptr && + constants[0]->GetS32BitValue() == INT32_MIN) { + *result = false; + return true; + } + if (constants[1] != nullptr && + constants[1]->GetS32BitValue() == INT32_MAX) { + *result = false; + return true; + } + break; + case SpvOp::SpvOpULessThanEqual: + if (constants[0] != nullptr && constants[0]->IsZero()) { + *result = true; + return true; + } + if (constants[1] != nullptr && + constants[1]->GetU32BitValue() == UINT32_MAX) { + *result = true; + return true; + } + break; + case SpvOp::SpvOpSLessThanEqual: + if (constants[0] != nullptr && + constants[0]->GetS32BitValue() == INT32_MIN) { + *result = true; + return true; + } + if (constants[1] != nullptr && + constants[1]->GetS32BitValue() == INT32_MAX) { + *result = true; + return true; + } + break; + case SpvOp::SpvOpUGreaterThanEqual: + if (constants[0] != nullptr && + constants[0]->GetU32BitValue() == UINT32_MAX) { + *result = true; + return true; + } + if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) { + *result = true; + return true; + } + break; + case SpvOp::SpvOpSGreaterThanEqual: + if (constants[0] != nullptr && + constants[0]->GetS32BitValue() == INT32_MAX) { + *result = true; + return true; + } + if (constants[1] != nullptr && + constants[1]->GetS32BitValue() == INT32_MIN) { + *result = true; + return true; + } + break; + default: + break; + } + return false; +} + +bool InstructionFolder::FoldBinaryBooleanOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const { + SpvOp opcode = inst->opcode(); + analysis::ConstantManager* const_manger = context_->get_constant_mgr(); + + uint32_t ids[2]; + const analysis::BoolConstant* constants[2]; + for (uint32_t i = 0; i < 2; i++) { + const Operand* operand = &inst->GetInOperand(i); + if (operand->type != SPV_OPERAND_TYPE_ID) { + return false; + } + ids[i] = 
id_map(operand->words[0]); + const analysis::Constant* constant = + const_manger->FindDeclaredConstant(ids[i]); + constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr); + } + + switch (opcode) { + // Logical + case SpvOp::SpvOpLogicalOr: + for (uint32_t i = 0; i < 2; i++) { + if (constants[i] != nullptr) { + if (constants[i]->value()) { + *result = true; + return true; + } + } + } + break; + case SpvOp::SpvOpLogicalAnd: + for (uint32_t i = 0; i < 2; i++) { + if (constants[i] != nullptr) { + if (!constants[i]->value()) { + *result = false; + return true; + } + } + } + break; + + default: + break; + } + return false; +} + +bool InstructionFolder::FoldIntegerOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const { + assert(IsFoldableOpcode(inst->opcode()) && + "Unhandled instruction opcode in FoldScalars"); + switch (inst->NumInOperands()) { + case 2: + return FoldBinaryIntegerOpToConstant(inst, id_map, result) || + FoldBinaryBooleanOpToConstant(inst, id_map, result); + default: + return false; + } +} + +std::vector InstructionFolder::FoldVectors( + SpvOp opcode, uint32_t num_dims, + const std::vector& operands) const { + assert(IsFoldableOpcode(opcode) && + "Unhandled instruction opcode in FoldVectors"); + std::vector result; + for (uint32_t d = 0; d < num_dims; d++) { + std::vector operand_values_for_one_dimension; + for (const auto& operand : operands) { + if (const analysis::VectorConstant* vector_operand = + operand->AsVectorConstant()) { + // Extract the raw value of the scalar component constants + // in 32-bit words here. The reason of not using FoldScalars() here + // is that we do not create temporary null constants as components + // when the vector operand is a NullConstant because Constant creation + // may need extra checks for the validity and that is not manageed in + // here. 
+ if (const analysis::ScalarConstant* scalar_component = + vector_operand->GetComponents().at(d)->AsScalarConstant()) { + const auto& scalar_words = scalar_component->words(); + assert( + scalar_words.size() == 1 && + "Vector components with longer than 32-bit width are not allowed " + "in FoldVectors()"); + operand_values_for_one_dimension.push_back(scalar_words.front()); + } else if (operand->AsNullConstant()) { + operand_values_for_one_dimension.push_back(0u); + } else { + assert(false && + "VectorConst should only has ScalarConst or NullConst as " + "components"); + } + } else if (operand->AsNullConstant()) { + operand_values_for_one_dimension.push_back(0u); + } else { + assert(false && + "FoldVectors() only accepts VectorConst or NullConst type of " + "constant"); + } + } + result.push_back(OperateWords(opcode, operand_values_for_one_dimension)); + } + return result; +} + +bool InstructionFolder::IsFoldableOpcode(SpvOp opcode) const { + // NOTE: Extend to more opcodes as new cases are handled in the folder + // functions. 
+ switch (opcode) { + case SpvOp::SpvOpBitwiseAnd: + case SpvOp::SpvOpBitwiseOr: + case SpvOp::SpvOpBitwiseXor: + case SpvOp::SpvOpIAdd: + case SpvOp::SpvOpIEqual: + case SpvOp::SpvOpIMul: + case SpvOp::SpvOpINotEqual: + case SpvOp::SpvOpISub: + case SpvOp::SpvOpLogicalAnd: + case SpvOp::SpvOpLogicalEqual: + case SpvOp::SpvOpLogicalNot: + case SpvOp::SpvOpLogicalNotEqual: + case SpvOp::SpvOpLogicalOr: + case SpvOp::SpvOpNot: + case SpvOp::SpvOpSDiv: + case SpvOp::SpvOpSelect: + case SpvOp::SpvOpSGreaterThan: + case SpvOp::SpvOpSGreaterThanEqual: + case SpvOp::SpvOpShiftLeftLogical: + case SpvOp::SpvOpShiftRightArithmetic: + case SpvOp::SpvOpShiftRightLogical: + case SpvOp::SpvOpSLessThan: + case SpvOp::SpvOpSLessThanEqual: + case SpvOp::SpvOpSMod: + case SpvOp::SpvOpSNegate: + case SpvOp::SpvOpSRem: + case SpvOp::SpvOpUDiv: + case SpvOp::SpvOpUGreaterThan: + case SpvOp::SpvOpUGreaterThanEqual: + case SpvOp::SpvOpULessThan: + case SpvOp::SpvOpULessThanEqual: + case SpvOp::SpvOpUMod: + return true; + default: + return false; + } +} + +bool InstructionFolder::IsFoldableConstant( + const analysis::Constant* cst) const { + // Currently supported constants are 32-bit values or null constants. + if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant()) + return scalar->words().size() == 1; + else + return cst->AsNullConstant() != nullptr; +} + +Instruction* InstructionFolder::FoldInstructionToConstant( + Instruction* inst, std::function id_map) const { + analysis::ConstantManager* const_mgr = context_->get_constant_mgr(); + + if (!inst->IsFoldableByFoldScalar() && + !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) { + return nullptr; + } + // Collect the values of the constant parameters. 
+ std::vector constants; + bool missing_constants = false; + inst->ForEachInId([&constants, &missing_constants, const_mgr, + &id_map](uint32_t* op_id) { + uint32_t id = id_map(*op_id); + const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id); + if (!const_op) { + constants.push_back(nullptr); + missing_constants = true; + } else { + constants.push_back(const_op); + } + }); + + if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) { + const analysis::Constant* folded_const = nullptr; + for (auto rule : + GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) { + folded_const = rule(context_, inst, constants); + if (folded_const != nullptr) { + Instruction* const_inst = + const_mgr->GetDefiningInstruction(folded_const, inst->type_id()); + assert(const_inst->type_id() == inst->type_id()); + // May be a new instruction that needs to be analysed. + context_->UpdateDefUse(const_inst); + return const_inst; + } + } + } + + uint32_t result_val = 0; + bool successful = false; + // If all parameters are constant, fold the instruction to a constant. + if (!missing_constants && inst->IsFoldableByFoldScalar()) { + result_val = FoldScalars(inst->opcode(), constants); + successful = true; + } + + if (!successful && inst->IsFoldableByFoldScalar()) { + successful = FoldIntegerOpToConstant(inst, id_map, &result_val); + } + + if (successful) { + const analysis::Constant* result_const = + const_mgr->GetConstant(const_mgr->GetType(inst), {result_val}); + Instruction* folded_inst = + const_mgr->GetDefiningInstruction(result_const, inst->type_id()); + return folded_inst; + } + return nullptr; +} + +bool InstructionFolder::IsFoldableType(Instruction* type_inst) const { + // Support 32-bit integers. + if (type_inst->opcode() == SpvOpTypeInt) { + return type_inst->GetSingleWordInOperand(0) == 32; + } + // Support booleans. + if (type_inst->opcode() == SpvOpTypeBool) { + return true; + } + // Nothing else yet. 
+ return false; +} + +bool InstructionFolder::FoldInstruction(Instruction* inst) const { + bool modified = false; + Instruction* folded_inst(inst); + while (folded_inst->opcode() != SpvOpCopyObject && + FoldInstructionInternal(&*folded_inst)) { + modified = true; + } + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/fold.h b/source/opt/fold.h new file mode 100644 index 000000000..0dc7c0ebb --- /dev/null +++ b/source/opt/fold.h @@ -0,0 +1,171 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_FOLD_H_ +#define SOURCE_OPT_FOLD_H_ + +#include +#include + +#include "source/opt/const_folding_rules.h" +#include "source/opt/constants.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/folding_rules.h" + +namespace spvtools { +namespace opt { + +class InstructionFolder { + public: + explicit InstructionFolder(IRContext* context) : context_(context) {} + + // Returns the result of folding a scalar instruction with the given |opcode| + // and |operands|. Each entry in |operands| is a pointer to an + // analysis::Constant instance, which should've been created with the constant + // manager (See IRContext::get_constant_mgr). + // + // It is an error to call this function with an opcode that does not pass the + // IsFoldableOpcode test. If any error occurs during folding, the folder will + // fail with a call to assert. 
+ uint32_t FoldScalars( + SpvOp opcode, + const std::vector& operands) const; + + // Returns the result of performing an operation with the given |opcode| over + // constant vectors with |num_dims| dimensions. Each entry in |operands| is a + // pointer to an analysis::Constant instance, which should've been created + // with the constant manager (See IRContext::get_constant_mgr). + // + // This function iterates through the given vector type constant operands and + // calculates the result for each element of the result vector to return. + // Vectors with longer than 32-bit scalar components are not accepted in this + // function. + // + // It is an error to call this function with an opcode that does not pass the + // IsFoldableOpcode test. If any error occurs during folding, the folder will + // fail with a call to assert. + std::vector FoldVectors( + SpvOp opcode, uint32_t num_dims, + const std::vector& operands) const; + + // Returns true if |opcode| represents an operation handled by FoldScalars or + // FoldVectors. + bool IsFoldableOpcode(SpvOp opcode) const; + + // Returns true if |cst| is supported by FoldScalars and FoldVectors. + bool IsFoldableConstant(const analysis::Constant* cst) const; + + // Returns true if |FoldInstructionToConstant| could fold an instruction whose + // result type is |type_inst|. + bool IsFoldableType(Instruction* type_inst) const; + + // Tries to fold |inst| to a single constant, when the input ids to |inst| + // have been substituted using |id_map|. Returns a pointer to the OpConstant* + // instruction if successful. If necessary, a new constant instruction is + // created and placed in the global values section. + // + // |id_map| is a function that takes one result id and returns another. It + // can be used for things like CCP where it is known that some ids contain a + // constant, but the instruction itself has not been updated yet. This can + // map those ids to the appropriate constants. 
+ Instruction* FoldInstructionToConstant( + Instruction* inst, std::function id_map) const; + // Returns true if |inst| can be folded into a simpler instruction. + // If |inst| can be simplified, |inst| is overwritten with the simplified + // instruction reusing the same result id. + // + // If |inst| is simplified, it is possible that the resulting code in invalid + // because the instruction is in a bad location. Callers of this function + // have to handle the following cases: + // + // 1) An OpPhi becomes and OpCopyObject - If there are OpPhi instruction after + // |inst| in a basic block then this is invalid. The caller must fix this + // up. + bool FoldInstruction(Instruction* inst) const; + + // Return true if this opcode has a const folding rule associtated with it. + bool HasConstFoldingRule(SpvOp opcode) const { + return GetConstantFoldingRules().HasFoldingRule(opcode); + } + + private: + // Returns a reference to the ConstnatFoldingRules instance. + const ConstantFoldingRules& GetConstantFoldingRules() const { + return const_folding_rules; + } + + // Returns a reference to the FoldingRules instance. + const FoldingRules& GetFoldingRules() const { return folding_rules; } + + // Returns the single-word result from performing the given unary operation on + // the operand value which is passed in as a 32-bit word. + uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) const; + + // Returns the single-word result from performing the given binary operation + // on the operand values which are passed in as two 32-bit word. + uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) const; + + // Returns the single-word result from performing the given ternary operation + // on the operand values which are passed in as three 32-bit word. + uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, + uint32_t c) const; + + // Returns the single-word result from performing the given operation on the + // operand words. 
This only works with 32-bit operations and uses boolean + // convention that 0u is false, and anything else is boolean true. + // TODO(qining): Support operands other than 32-bit wide. + uint32_t OperateWords(SpvOp opcode, + const std::vector& operand_words) const; + + bool FoldInstructionInternal(Instruction* inst) const; + + // Returns true if |inst| is a binary operation that takes two integers as + // parameters and folds to a constant that can be represented as an unsigned + // 32-bit value when the ids have been replaced by |id_map|. If |inst| can be + // folded, the resulting value is returned in |*result|. Valid result types + // for the instruction are any integer (signed or unsigned) with 32-bits or + // less, or a boolean value. + bool FoldBinaryIntegerOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const; + + // Returns true if |inst| is a binary operation on two boolean values, and + // folds + // to a constant boolean value when the ids have been replaced using |id_map|. + // If |inst| can be folded, the result value is returned in |*result|. + bool FoldBinaryBooleanOpToConstant( + Instruction* inst, const std::function& id_map, + uint32_t* result) const; + + // Returns true if |inst| can be folded to an constant when the ids have been + // substituted using id_map. If it can, the value is returned in |result|. If + // not, |result| is unchanged. It is assumed that not all operands are + // constant. Those cases are handled by |FoldScalar|. + bool FoldIntegerOpToConstant(Instruction* inst, + const std::function& id_map, + uint32_t* result) const; + + IRContext* context_; + + // Folding rules used by |FoldInstructionToConstant| and |FoldInstruction|. + ConstantFoldingRules const_folding_rules; + + // Folding rules used by |FoldInstruction|. 
+ FoldingRules folding_rules; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_FOLD_H_ diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp new file mode 100644 index 000000000..663d112d4 --- /dev/null +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -0,0 +1,385 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/fold_spec_constant_op_and_composite_pass.h" + +#include +#include +#include + +#include "source/opt/constants.h" +#include "source/opt/fold.h" +#include "source/opt/ir_context.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +Pass::Status FoldSpecConstantOpAndCompositePass::Process() { + bool modified = false; + // Traverse through all the constant defining instructions. For Normal + // Constants whose values are determined and do not depend on OpUndef + // instructions, records their values in two internal maps: id_to_const_val_ + // and const_val_to_id_ so that we can use them to infer the value of Spec + // Constants later. + // For Spec Constants defined with OpSpecConstantComposite instructions, if + // all of their components are Normal Constants, they will be turned into + // Normal Constants too. 
For Spec Constants defined with OpSpecConstantOp + // instructions, we check if they only depends on Normal Constants and fold + // them when possible. The two maps for Normal Constants: id_to_const_val_ + // and const_val_to_id_ will be updated along the traversal so that the new + // Normal Constants generated from folding can be used to fold following Spec + // Constants. + // This algorithm depends on the SSA property of SPIR-V when + // defining constants. The dependent constants must be defined before the + // dependee constants. So a dependent Spec Constant must be defined and + // will be processed before its dependee Spec Constant. When we encounter + // the dependee Spec Constants, all its dependent constants must have been + // processed and all its dependent Spec Constants should have been folded if + // possible. + Module::inst_iterator next_inst = context()->types_values_begin(); + for (Module::inst_iterator inst_iter = next_inst; + // Need to re-evaluate the end iterator since we may modify the list of + // instructions in this section of the module as the process goes. + inst_iter != context()->types_values_end(); inst_iter = next_inst) { + ++next_inst; + Instruction* inst = &*inst_iter; + // Collect constant values of normal constants and process the + // OpSpecConstantOp and OpSpecConstantComposite instructions if possible. + // The constant values will be stored in analysis::Constant instances. + // OpConstantSampler instruction is not collected here because it cannot be + // used in OpSpecConstant{Composite|Op} instructions. + // TODO(qining): If the constant or its type has decoration, we may need + // to skip it. + if (context()->get_constant_mgr()->GetType(inst) && + !context()->get_constant_mgr()->GetType(inst)->decoration_empty()) + continue; + switch (SpvOp opcode = inst->opcode()) { + // Records the values of Normal Constants. 
+ case SpvOp::SpvOpConstantTrue: + case SpvOp::SpvOpConstantFalse: + case SpvOp::SpvOpConstant: + case SpvOp::SpvOpConstantNull: + case SpvOp::SpvOpConstantComposite: + case SpvOp::SpvOpSpecConstantComposite: { + // A Constant instance will be created if the given instruction is a + // Normal Constant whose value(s) are fixed. Note that for a composite + // Spec Constant defined with OpSpecConstantComposite instruction, if + // all of its components are Normal Constants already, the Spec + // Constant will be turned in to a Normal Constant. In that case, a + // Constant instance should also be created successfully and recorded + // in the id_to_const_val_ and const_val_to_id_ mapps. + if (auto const_value = + context()->get_constant_mgr()->GetConstantFromInst(inst)) { + // Need to replace the OpSpecConstantComposite instruction with a + // corresponding OpConstantComposite instruction. + if (opcode == SpvOp::SpvOpSpecConstantComposite) { + inst->SetOpcode(SpvOp::SpvOpConstantComposite); + modified = true; + } + context()->get_constant_mgr()->MapConstantToInst(const_value, inst); + } + break; + } + // For a Spec Constants defined with OpSpecConstantOp instruction, check + // if it only depends on Normal Constants. If so, the Spec Constant will + // be folded. The original Spec Constant defining instruction will be + // replaced by Normal Constant defining instructions, and the new Normal + // Constants will be added to id_to_const_val_ and const_val_to_id_ so + // that we can use the new Normal Constants when folding following Spec + // Constants. + case SpvOp::SpvOpSpecConstantOp: + modified |= ProcessOpSpecConstantOp(&inst_iter); + break; + default: + break; + } + } + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp( + Module::inst_iterator* pos) { + Instruction* inst = &**pos; + Instruction* folded_inst = nullptr; + assert(inst->GetInOperand(0).type == + SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER && + "The first in-operand of OpSpecContantOp instruction must be of " + "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type"); + + switch (static_cast(inst->GetSingleWordInOperand(0))) { + case SpvOp::SpvOpCompositeExtract: + folded_inst = DoCompositeExtract(pos); + break; + case SpvOp::SpvOpVectorShuffle: + folded_inst = DoVectorShuffle(pos); + break; + + case SpvOp::SpvOpCompositeInsert: + // Current Glslang does not generate code with OpSpecConstantOp + // CompositeInsert instruction, so this is not implmented so far. + // TODO(qining): Implement CompositeInsert case. + return false; + + default: + // Component-wise operations. + folded_inst = DoComponentWiseOperation(pos); + break; + } + if (!folded_inst) return false; + + // Replace the original constant with the new folded constant, kill the + // original constant. 
+ uint32_t new_id = folded_inst->result_id(); + uint32_t old_id = inst->result_id(); + context()->ReplaceAllUsesWith(old_id, new_id); + context()->KillDef(old_id); + return true; +} + +uint32_t FoldSpecConstantOpAndCompositePass::GetTypeComponent( + uint32_t typeId, uint32_t element) const { + Instruction* type = context()->get_def_use_mgr()->GetDef(typeId); + uint32_t subtype = type->GetTypeComponent(element); + assert(subtype != 0); + + return subtype; +} + +Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract( + Module::inst_iterator* pos) { + Instruction* inst = &**pos; + assert(inst->NumInOperands() - 1 >= 2 && + "OpSpecConstantOp CompositeExtract requires at least two non-type " + "non-opcode operands."); + assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID && + "The composite operand must have a SPV_OPERAND_TYPE_ID type"); + assert( + inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER && + "The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type"); + + // Note that for OpSpecConstantOp, the second in-operand is the first id + // operand. The first in-operand is the spec opcode. + uint32_t source = inst->GetSingleWordInOperand(1); + uint32_t type = context()->get_def_use_mgr()->GetDef(source)->type_id(); + const analysis::Constant* first_operand_const = + context()->get_constant_mgr()->FindDeclaredConstant(source); + if (!first_operand_const) return nullptr; + + const analysis::Constant* current_const = first_operand_const; + for (uint32_t i = 2; i < inst->NumInOperands(); i++) { + uint32_t literal = inst->GetSingleWordInOperand(i); + type = GetTypeComponent(type, literal); + } + for (uint32_t i = 2; i < inst->NumInOperands(); i++) { + uint32_t literal = inst->GetSingleWordInOperand(i); + if (const analysis::CompositeConstant* composite_const = + current_const->AsCompositeConstant()) { + // Case 1: current constant is a non-null composite type constant. 
+ assert(literal < composite_const->GetComponents().size() && + "Literal index out of bound of the composite constant"); + current_const = composite_const->GetComponents().at(literal); + } else if (current_const->AsNullConstant()) { + // Case 2: current constant is a constant created with OpConstantNull. + // Because components of a NullConstant are always NullConstants, we can + // return early with a NullConstant in the result type. + return context()->get_constant_mgr()->BuildInstructionAndAddToModule( + context()->get_constant_mgr()->GetConstant( + context()->get_constant_mgr()->GetType(inst), {}), + pos, type); + } else { + // Dereferencing a non-composite constant. Invalid case. + return nullptr; + } + } + return context()->get_constant_mgr()->BuildInstructionAndAddToModule( + current_const, pos); +} + +Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle( + Module::inst_iterator* pos) { + Instruction* inst = &**pos; + analysis::Vector* result_vec_type = + context()->get_constant_mgr()->GetType(inst)->AsVector(); + assert(inst->NumInOperands() - 1 > 2 && + "OpSpecConstantOp DoVectorShuffle instruction requires more than 2 " + "operands (2 vector ids and at least one literal operand"); + assert(result_vec_type && + "The result of VectorShuffle must be of type vector"); + + // A temporary null constants that can be used as the components of the result + // vector. This is needed when any one of the vector operands are null + // constant. + const analysis::Constant* null_component_constants = nullptr; + + // Get a concatenated vector of scalar constants. The vector should be built + // with the components from the first and the second operand of VectorShuffle. + std::vector concatenated_components; + // Note that for OpSpecConstantOp, the second in-operand is the first id + // operand. The first in-operand is the spec opcode. 
+ for (uint32_t i : {1, 2}) { + assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID && + "The vector operand must have a SPV_OPERAND_TYPE_ID type"); + uint32_t operand_id = inst->GetSingleWordInOperand(i); + auto operand_const = + context()->get_constant_mgr()->FindDeclaredConstant(operand_id); + if (!operand_const) return nullptr; + const analysis::Type* operand_type = operand_const->type(); + assert(operand_type->AsVector() && + "The first two operand of VectorShuffle must be of vector type"); + if (auto vec_const = operand_const->AsVectorConstant()) { + // case 1: current operand is a non-null vector constant. + concatenated_components.insert(concatenated_components.end(), + vec_const->GetComponents().begin(), + vec_const->GetComponents().end()); + } else if (operand_const->AsNullConstant()) { + // case 2: current operand is a null vector constant. Create a temporary + // null scalar constant as the component. + if (!null_component_constants) { + const analysis::Type* component_type = + operand_type->AsVector()->element_type(); + null_component_constants = + context()->get_constant_mgr()->GetConstant(component_type, {}); + } + // Append the null scalar consts to the concatenated components + // vector. + concatenated_components.insert(concatenated_components.end(), + operand_type->AsVector()->element_count(), + null_component_constants); + } else { + // no other valid cases + return nullptr; + } + } + // Create null component constants if there are any. The component constants + // must be added to the module before the dependee composite constants to + // satisfy SSA def-use dominance. + if (null_component_constants) { + context()->get_constant_mgr()->BuildInstructionAndAddToModule( + null_component_constants, pos); + } + // Create the new vector constant with the selected components. 
+ std::vector selected_components; + for (uint32_t i = 3; i < inst->NumInOperands(); i++) { + assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_LITERAL_INTEGER && + "The literal operand must of type SPV_OPERAND_TYPE_LITERAL_INTEGER"); + uint32_t literal = inst->GetSingleWordInOperand(i); + assert(literal < concatenated_components.size() && + "Literal index out of bound of the concatenated vector"); + selected_components.push_back(concatenated_components[literal]); + } + auto new_vec_const = MakeUnique( + result_vec_type, selected_components); + auto reg_vec_const = + context()->get_constant_mgr()->RegisterConstant(std::move(new_vec_const)); + return context()->get_constant_mgr()->BuildInstructionAndAddToModule( + reg_vec_const, pos); +} + +namespace { +// A helper function to check the type for component wise operations. Returns +// true if the type: +// 1) is bool type; +// 2) is 32-bit int type; +// 3) is vector of bool type; +// 4) is vector of 32-bit integer type. +// Otherwise returns false. +bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) { + if (type->AsBool()) { + return true; + } else if (auto* it = type->AsInteger()) { + if (it->width() == 32) return true; + } else if (auto* vt = type->AsVector()) { + if (vt->element_type()->AsBool()) { + return true; + } else if (auto* vit = vt->element_type()->AsInteger()) { + if (vit->width() == 32) return true; + } + } + return false; +} +} // namespace + +Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( + Module::inst_iterator* pos) { + const Instruction* inst = &**pos; + const analysis::Type* result_type = + context()->get_constant_mgr()->GetType(inst); + SpvOp spec_opcode = static_cast(inst->GetSingleWordInOperand(0)); + // Check and collect operands. + std::vector operands; + + if (!std::all_of( + inst->cbegin(), inst->cend(), [&operands, this](const Operand& o) { + // skip the operands that is not an id. 
+ if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID) return true; + uint32_t id = o.words.front(); + if (auto c = + context()->get_constant_mgr()->FindDeclaredConstant(id)) { + if (IsValidTypeForComponentWiseOperation(c->type())) { + operands.push_back(c); + return true; + } + } + return false; + })) + return nullptr; + + if (result_type->AsInteger() || result_type->AsBool()) { + // Scalar operation + uint32_t result_val = + context()->get_instruction_folder().FoldScalars(spec_opcode, operands); + auto result_const = + context()->get_constant_mgr()->GetConstant(result_type, {result_val}); + return context()->get_constant_mgr()->BuildInstructionAndAddToModule( + result_const, pos); + } else if (result_type->AsVector()) { + // Vector operation + const analysis::Type* element_type = + result_type->AsVector()->element_type(); + uint32_t num_dims = result_type->AsVector()->element_count(); + std::vector result_vec = + context()->get_instruction_folder().FoldVectors(spec_opcode, num_dims, + operands); + std::vector result_vector_components; + for (uint32_t r : result_vec) { + if (auto rc = + context()->get_constant_mgr()->GetConstant(element_type, {r})) { + result_vector_components.push_back(rc); + if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule( + rc, pos)) { + assert(false && + "Failed to build and insert constant declaring instruction " + "for the given vector component constant"); + } + } else { + assert(false && "Failed to create constants with 32-bit word"); + } + } + auto new_vec_const = MakeUnique( + result_type->AsVector(), result_vector_components); + auto reg_vec_const = context()->get_constant_mgr()->RegisterConstant( + std::move(new_vec_const)); + return context()->get_constant_mgr()->BuildInstructionAndAddToModule( + reg_vec_const, pos); + } else { + // Cannot process invalid component wise operation. The result of component + // wise operation must be of integer or bool scalar or vector of + // integer/bool type. 
+ return nullptr; + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.h b/source/opt/fold_spec_constant_op_and_composite_pass.h new file mode 100644 index 000000000..16271251f --- /dev/null +++ b/source/opt/fold_spec_constant_op_and_composite_pass.h @@ -0,0 +1,84 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ +#define SOURCE_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ + +#include +#include +#include + +#include "source/opt/constants.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" +#include "source/opt/type_manager.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class FoldSpecConstantOpAndCompositePass : public Pass { + public: + FoldSpecConstantOpAndCompositePass() = default; + + const char* name() const override { return "fold-spec-const-op-composite"; } + + // Iterates through the types-constants-globals section of the given module, + // finds the Spec Constants defined with OpSpecConstantOp and + // OpSpecConstantComposite instructions. If the result value of those spec + // constants can be folded, fold them to their corresponding normal constants. 
+ Status Process() override; + + private: + // Processes the OpSpecConstantOp instruction pointed by the given + // instruction iterator, folds it to normal constants if possible. Returns + // true if the spec constant is folded to normal constants. New instructions + // will be inserted before the OpSpecConstantOp instruction pointed by the + // instruction iterator. The instruction iterator, which is passed by + // pointer, will still point to the original OpSpecConstantOp instruction. If + // folding is done successfully, the original OpSpecConstantOp instruction + // will be changed to Nop and new folded instruction will be inserted before + // it. + bool ProcessOpSpecConstantOp(Module::inst_iterator* pos); + + // Try to fold the OpSpecConstantOp CompositeExtract instruction pointed by + // the given instruction iterator to a normal constant defining instruction. + // Returns the pointer to the new constant defining instruction if succeeded. + // Otherwise returns nullptr. + Instruction* DoCompositeExtract(Module::inst_iterator* inst_iter_ptr); + + // Try to fold the OpSpecConstantOp VectorShuffle instruction pointed by the + // given instruction iterator to a normal constant defining instruction. + // Returns the pointer to the new constant defining instruction if succeeded. + // Otherwise return nullptr. + Instruction* DoVectorShuffle(Module::inst_iterator* inst_iter_ptr); + + // Try to fold the OpSpecConstantOp instruction + // pointed by the given instruction iterator to a normal constant defining + // instruction. Returns the pointer to the new constant defining instruction + // if succeeded, otherwise return nullptr. + Instruction* DoComponentWiseOperation(Module::inst_iterator* inst_iter_ptr); + + // Returns the |element|'th subtype of |type|. + // + // |type| must be a composite type. 
+ uint32_t GetTypeComponent(uint32_t type, uint32_t element) const; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_FOLD_SPEC_CONSTANT_OP_AND_COMPOSITE_PASS_H_ diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp new file mode 100644 index 000000000..327431969 --- /dev/null +++ b/source/opt/folding_rules.cpp @@ -0,0 +1,2243 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/folding_rules.h" + +#include +#include +#include + +#include "source/latest_version_glsl_std_450_header.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kExtractCompositeIdInIdx = 0; +const uint32_t kInsertObjectIdInIdx = 0; +const uint32_t kInsertCompositeIdInIdx = 1; +const uint32_t kExtInstSetIdInIdx = 0; +const uint32_t kExtInstInstructionInIdx = 1; +const uint32_t kFMixXIdInIdx = 2; +const uint32_t kFMixYIdInIdx = 3; +const uint32_t kFMixAIdInIdx = 4; +const uint32_t kStoreObjectInIdx = 1; + +// Returns the element width of |type|. 
+uint32_t ElementWidth(const analysis::Type* type) { + if (const analysis::Vector* vec_type = type->AsVector()) { + return ElementWidth(vec_type->element_type()); + } else if (const analysis::Float* float_type = type->AsFloat()) { + return float_type->width(); + } else { + assert(type->AsInteger()); + return type->AsInteger()->width(); + } +} + +// Returns true if |type| is Float or a vector of Float. +bool HasFloatingPoint(const analysis::Type* type) { + if (type->AsFloat()) { + return true; + } else if (const analysis::Vector* vec_type = type->AsVector()) { + return vec_type->element_type()->AsFloat() != nullptr; + } + + return false; +} + +// Returns false if |val| is NaN, infinite or subnormal. +template +bool IsValidResult(T val) { + int classified = std::fpclassify(val); + switch (classified) { + case FP_NAN: + case FP_INFINITE: + case FP_SUBNORMAL: + return false; + default: + return true; + } +} + +const analysis::Constant* ConstInput( + const std::vector& constants) { + return constants[0] ? constants[0] : constants[1]; +} + +Instruction* NonConstInput(IRContext* context, const analysis::Constant* c, + Instruction* inst) { + uint32_t in_op = c ? 1u : 0u; + return context->get_def_use_mgr()->GetDef( + inst->GetSingleWordInOperand(in_op)); +} + +// Returns the negation of |c|. |c| must be a 32 or 64 bit floating point +// constant. 
+uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr, + const analysis::Constant* c) { + assert(c); + assert(c->type()->AsFloat()); + uint32_t width = c->type()->AsFloat()->width(); + assert(width == 32 || width == 64); + std::vector words; + if (width == 64) { + utils::FloatProxy result(c->GetDouble() * -1.0); + words = result.GetWords(); + } else { + utils::FloatProxy result(c->GetFloat() * -1.0f); + words = result.GetWords(); + } + + const analysis::Constant* negated_const = + const_mgr->GetConstant(c->type(), std::move(words)); + return const_mgr->GetDefiningInstruction(negated_const)->result_id(); +} + +std::vector ExtractInts(uint64_t val) { + std::vector words; + words.push_back(static_cast(val)); + words.push_back(static_cast(val >> 32)); + return words; +} + +// Negates the integer constant |c|. Returns the id of the defining instruction. +uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr, + const analysis::Constant* c) { + assert(c); + assert(c->type()->AsInteger()); + uint32_t width = c->type()->AsInteger()->width(); + assert(width == 32 || width == 64); + std::vector words; + if (width == 64) { + uint64_t uval = static_cast(0 - c->GetU64()); + words = ExtractInts(uval); + } else { + words.push_back(static_cast(0 - c->GetU32())); + } + + const analysis::Constant* negated_const = + const_mgr->GetConstant(c->type(), std::move(words)); + return const_mgr->GetDefiningInstruction(negated_const)->result_id(); +} + +// Negates the vector constant |c|. Returns the id of the defining instruction. +uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr, + const analysis::Constant* c) { + assert(const_mgr && c); + assert(c->type()->AsVector()); + if (c->AsNullConstant()) { + // 0.0 vs -0.0 shouldn't matter. 
+ return const_mgr->GetDefiningInstruction(c)->result_id(); + } else { + const analysis::Type* component_type = + c->AsVectorConstant()->component_type(); + std::vector words; + for (auto& comp : c->AsVectorConstant()->GetComponents()) { + if (component_type->AsFloat()) { + words.push_back(NegateFloatingPointConstant(const_mgr, comp)); + } else { + assert(component_type->AsInteger()); + words.push_back(NegateIntegerConstant(const_mgr, comp)); + } + } + + const analysis::Constant* negated_const = + const_mgr->GetConstant(c->type(), std::move(words)); + return const_mgr->GetDefiningInstruction(negated_const)->result_id(); + } +} + +// Negates |c|. Returns the id of the defining instruction. +uint32_t NegateConstant(analysis::ConstantManager* const_mgr, + const analysis::Constant* c) { + if (c->type()->AsVector()) { + return NegateVectorConstant(const_mgr, c); + } else if (c->type()->AsFloat()) { + return NegateFloatingPointConstant(const_mgr, c); + } else { + assert(c->type()->AsInteger()); + return NegateIntegerConstant(const_mgr, c); + } +} + +// Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float. +// Returns 0 if the reciprocal is NaN, infinite or subnormal. 
+uint32_t Reciprocal(analysis::ConstantManager* const_mgr, + const analysis::Constant* c) { + assert(const_mgr && c); + assert(c->type()->AsFloat()); + + uint32_t width = c->type()->AsFloat()->width(); + assert(width == 32 || width == 64); + std::vector words; + if (width == 64) { + spvtools::utils::FloatProxy result(1.0 / c->GetDouble()); + if (!IsValidResult(result.getAsFloat())) return 0; + words = result.GetWords(); + } else { + spvtools::utils::FloatProxy result(1.0f / c->GetFloat()); + if (!IsValidResult(result.getAsFloat())) return 0; + words = result.GetWords(); + } + + const analysis::Constant* negated_const = + const_mgr->GetConstant(c->type(), std::move(words)); + return const_mgr->GetDefiningInstruction(negated_const)->result_id(); +} + +// Replaces fdiv where second operand is constant with fmul. +FoldingRule ReciprocalFDiv() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == SpvOpFDiv); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Type* type = + context->get_type_mgr()->GetType(inst->type_id()); + if (!inst->IsFloatingPointFoldingAllowed()) return false; + + uint32_t width = ElementWidth(type); + if (width != 32 && width != 64) return false; + + if (constants[1] != nullptr) { + uint32_t id = 0; + if (const analysis::VectorConstant* vector_const = + constants[1]->AsVectorConstant()) { + std::vector neg_ids; + for (auto& comp : vector_const->GetComponents()) { + id = Reciprocal(const_mgr, comp); + if (id == 0) return false; + neg_ids.push_back(id); + } + const analysis::Constant* negated_const = + const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids)); + id = const_mgr->GetDefiningInstruction(negated_const)->result_id(); + } else if (constants[1]->AsFloatConstant()) { + id = Reciprocal(const_mgr, constants[1]); + if (id == 0) return false; + } else { + // Don't fold a null constant. 
+ return false; + } + inst->SetOpcode(SpvOpFMul); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}}, + {SPV_OPERAND_TYPE_ID, {id}}}); + return true; + } + + return false; + }; +} + +// Elides consecutive negate instructions. +FoldingRule MergeNegateArithmetic() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); + (void)constants; + const analysis::Type* type = + context->get_type_mgr()->GetType(inst->type_id()); + if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) + return false; + + Instruction* op_inst = + context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); + if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) + return false; + + if (op_inst->opcode() == inst->opcode()) { + // Elide negates. + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}}); + return true; + } + + return false; + }; +} + +// Merges negate into a mul or div operation if that operation contains a +// constant operand. 
+// Cases: +// -(x * 2) = x * -2 +// -(2 * x) = x * -2 +// -(x / 2) = x / -2 +// -(2 / x) = -2 / x +FoldingRule MergeNegateMulDivArithmetic() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); + (void)constants; + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Type* type = + context->get_type_mgr()->GetType(inst->type_id()); + if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) + return false; + + Instruction* op_inst = + context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); + if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) + return false; + + uint32_t width = ElementWidth(type); + if (width != 32 && width != 64) return false; + + SpvOp opcode = op_inst->opcode(); + if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul || + opcode == SpvOpSDiv || opcode == SpvOpUDiv) { + std::vector op_constants = + const_mgr->GetOperandConstants(op_inst); + // Merge negate into mul or div if one operand is constant. + if (op_constants[0] || op_constants[1]) { + bool zero_is_variable = op_constants[0] == nullptr; + const analysis::Constant* c = ConstInput(op_constants); + uint32_t neg_id = NegateConstant(const_mgr, c); + uint32_t non_const_id = zero_is_variable + ? op_inst->GetSingleWordInOperand(0u) + : op_inst->GetSingleWordInOperand(1u); + // Change this instruction to a mul/div. + inst->SetOpcode(op_inst->opcode()); + if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) { + uint32_t op0 = zero_is_variable ? non_const_id : neg_id; + uint32_t op1 = zero_is_variable ? 
neg_id : non_const_id; + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); + } else { + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, + {SPV_OPERAND_TYPE_ID, {neg_id}}}); + } + return true; + } + } + + return false; + }; +} + +// Merges negate into a add or sub operation if that operation contains a +// constant operand. +// Cases: +// -(x + 2) = -2 - x +// -(2 + x) = -2 - x +// -(x - 2) = 2 - x +// -(2 - x) = x - 2 +FoldingRule MergeNegateAddSubArithmetic() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); + (void)constants; + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Type* type = + context->get_type_mgr()->GetType(inst->type_id()); + if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) + return false; + + Instruction* op_inst = + context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); + if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) + return false; + + uint32_t width = ElementWidth(type); + if (width != 32 && width != 64) return false; + + if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub || + op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) { + std::vector op_constants = + const_mgr->GetOperandConstants(op_inst); + if (op_constants[0] || op_constants[1]) { + bool zero_is_variable = op_constants[0] == nullptr; + bool is_add = (op_inst->opcode() == SpvOpFAdd) || + (op_inst->opcode() == SpvOpIAdd); + bool swap_operands = !is_add || zero_is_variable; + bool negate_const = is_add; + const analysis::Constant* c = ConstInput(op_constants); + uint32_t const_id = 0; + if (negate_const) { + const_id = NegateConstant(const_mgr, c); + } else { + const_id = zero_is_variable ? 
op_inst->GetSingleWordInOperand(1u) + : op_inst->GetSingleWordInOperand(0u); + } + + // Swap operands if necessary and make the instruction a subtraction. + uint32_t op0 = + zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id; + uint32_t op1 = + zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u); + if (swap_operands) std::swap(op0, op1); + inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); + return true; + } + } + + return false; + }; +} + +// Returns true if |c| has a zero element. +bool HasZero(const analysis::Constant* c) { + if (c->AsNullConstant()) { + return true; + } + if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) { + for (auto& comp : vec_const->GetComponents()) + if (HasZero(comp)) return true; + } else { + assert(c->AsScalarConstant()); + return c->AsScalarConstant()->IsZero(); + } + + return false; +} + +// Performs |input1| |opcode| |input2| and returns the merged constant result +// id. Returns 0 if the result is not a valid value. The input types must be +// Float. 
+uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr, + SpvOp opcode, + const analysis::Constant* input1, + const analysis::Constant* input2) { + const analysis::Type* type = input1->type(); + assert(type->AsFloat()); + uint32_t width = type->AsFloat()->width(); + assert(width == 32 || width == 64); + std::vector words; +#define FOLD_OP(op) \ + if (width == 64) { \ + utils::FloatProxy val = \ + input1->GetDouble() op input2->GetDouble(); \ + double dval = val.getAsFloat(); \ + if (!IsValidResult(dval)) return 0; \ + words = val.GetWords(); \ + } else { \ + utils::FloatProxy val = input1->GetFloat() op input2->GetFloat(); \ + float fval = val.getAsFloat(); \ + if (!IsValidResult(fval)) return 0; \ + words = val.GetWords(); \ + } + switch (opcode) { + case SpvOpFMul: + FOLD_OP(*); + break; + case SpvOpFDiv: + if (HasZero(input2)) return 0; + FOLD_OP(/); + break; + case SpvOpFAdd: + FOLD_OP(+); + break; + case SpvOpFSub: + FOLD_OP(-); + break; + default: + assert(false && "Unexpected operation"); + break; + } +#undef FOLD_OP + const analysis::Constant* merged_const = const_mgr->GetConstant(type, words); + return const_mgr->GetDefiningInstruction(merged_const)->result_id(); +} + +// Performs |input1| |opcode| |input2| and returns the merged constant result +// id. Returns 0 if the result is not a valid value. The input types must be +// Integers. 
+uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr, + SpvOp opcode, const analysis::Constant* input1, + const analysis::Constant* input2) { + assert(input1->type()->AsInteger()); + const analysis::Integer* type = input1->type()->AsInteger(); + uint32_t width = type->AsInteger()->width(); + assert(width == 32 || width == 64); + std::vector words; +#define FOLD_OP(op) \ + if (width == 64) { \ + if (type->IsSigned()) { \ + int64_t val = input1->GetS64() op input2->GetS64(); \ + words = ExtractInts(static_cast(val)); \ + } else { \ + uint64_t val = input1->GetU64() op input2->GetU64(); \ + words = ExtractInts(val); \ + } \ + } else { \ + if (type->IsSigned()) { \ + int32_t val = input1->GetS32() op input2->GetS32(); \ + words.push_back(static_cast(val)); \ + } else { \ + uint32_t val = input1->GetU32() op input2->GetU32(); \ + words.push_back(val); \ + } \ + } + switch (opcode) { + case SpvOpIMul: + FOLD_OP(*); + break; + case SpvOpSDiv: + case SpvOpUDiv: + assert(false && "Should not merge integer division"); + break; + case SpvOpIAdd: + FOLD_OP(+); + break; + case SpvOpISub: + FOLD_OP(-); + break; + default: + assert(false && "Unexpected operation"); + break; + } +#undef FOLD_OP + const analysis::Constant* merged_const = const_mgr->GetConstant(type, words); + return const_mgr->GetDefiningInstruction(merged_const)->result_id(); +} + +// Performs |input1| |opcode| |input2| and returns the merged constant result +// id. Returns 0 if the result is not a valid value. The input types must be +// Integers, Floats or Vectors of such. 
+uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode, + const analysis::Constant* input1, + const analysis::Constant* input2) { + assert(input1 && input2); + assert(input1->type() == input2->type()); + const analysis::Type* type = input1->type(); + std::vector words; + if (const analysis::Vector* vector_type = type->AsVector()) { + const analysis::Type* ele_type = vector_type->element_type(); + for (uint32_t i = 0; i != vector_type->element_count(); ++i) { + uint32_t id = 0; + + const analysis::Constant* input1_comp = nullptr; + if (const analysis::VectorConstant* input1_vector = + input1->AsVectorConstant()) { + input1_comp = input1_vector->GetComponents()[i]; + } else { + assert(input1->AsNullConstant()); + input1_comp = const_mgr->GetConstant(ele_type, {}); + } + + const analysis::Constant* input2_comp = nullptr; + if (const analysis::VectorConstant* input2_vector = + input2->AsVectorConstant()) { + input2_comp = input2_vector->GetComponents()[i]; + } else { + assert(input2->AsNullConstant()); + input2_comp = const_mgr->GetConstant(ele_type, {}); + } + + if (ele_type->AsFloat()) { + id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp, + input2_comp); + } else { + assert(ele_type->AsInteger()); + id = PerformIntegerOperation(const_mgr, opcode, input1_comp, + input2_comp); + } + if (id == 0) return 0; + words.push_back(id); + } + const analysis::Constant* merged_const = + const_mgr->GetConstant(type, words); + return const_mgr->GetDefiningInstruction(merged_const)->result_id(); + } else if (type->AsFloat()) { + return PerformFloatingPointOperation(const_mgr, opcode, input1, input2); + } else { + assert(type->AsInteger()); + return PerformIntegerOperation(const_mgr, opcode, input1, input2); + } +} + +// Merges consecutive multiplies where each contains one constant operand. 
+// Cases:
+// 2 * (x * 2) = x * 4
+// 2 * (2 * x) = x * 4
+// (x * 2) * 2 = x * 4
+// (2 * x) * 2 = x * 4
+FoldingRule MergeMulMulArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    const bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    // Determine the constant input and the variable input in |inst|.
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    // Only a multiply of the same kind feeding this multiply can be merged.
+    if (other_inst->opcode() != inst->opcode()) return false;
+
+    std::vector<const analysis::Constant*> other_constants =
+        const_mgr->GetOperandConstants(other_inst);
+    const analysis::Constant* const_input2 = ConstInput(other_constants);
+    if (!const_input2) return false;
+
+    bool other_first_is_variable = other_constants[0] == nullptr;
+    uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+                                          const_input1, const_input2);
+    if (merged_id == 0) return false;
+
+    // Multiplication is commutative, so the result is always emitted as
+    // variable * merged-constant.
+    uint32_t non_const_id =
+        other_inst->GetSingleWordInOperand(other_first_is_variable ? 0u : 1u);
+    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
+                         {SPV_OPERAND_TYPE_ID, {merged_id}}});
+    return true;
+  };
+}
+
+// Merges divides into subsequent multiplies if each instruction contains one
+// constant operand. Does not support integer operations.
+// Cases:
+// 2 * (x / 2) = x * 1
+// 2 * (2 / x) = 4 / x
+// (x / 2) * 2 = x * 1
+// (2 / x) * 2 = 4 / x
+// (y / x) * x = y
+// x * (y / x) = y
+FoldingRule MergeMulDivArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFMul);
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (!inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    // (y / x) * x and x * (y / x): the divisor cancels against the other
+    // multiplicand, leaving the dividend.
+    for (uint32_t i = 0; i < 2; i++) {
+      uint32_t op_id = inst->GetSingleWordInOperand(i);
+      Instruction* op_inst = def_use_mgr->GetDef(op_id);
+      if (op_inst->opcode() == SpvOpFDiv) {
+        if (op_inst->GetSingleWordInOperand(1) ==
+            inst->GetSingleWordInOperand(1 - i)) {
+          inst->SetOpcode(SpvOpCopyObject);
+          inst->SetInOperands(
+              {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
+          return true;
+        }
+      }
+    }
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
+
+    if (other_inst->opcode() == SpvOpFDiv) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2 || HasZero(const_input2)) return false;
+
+      bool other_first_is_variable = other_constants[0] == nullptr;
+      // If the variable value is the second operand of the divide, multiply
+      // the constants together. Otherwise divide the constants.
+      uint32_t merged_id = PerformOperation(
+          const_mgr,
+          other_first_is_variable ? other_inst->opcode() : inst->opcode(),
+          const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      uint32_t non_const_id = other_first_is_variable
+                                  ? other_inst->GetSingleWordInOperand(0u)
+                                  : other_inst->GetSingleWordInOperand(1u);
+
+      // If the variable value is on the second operand of the div, then this
+      // operation is a div. Otherwise it should be a multiply.
+      inst->SetOpcode(other_first_is_variable ? inst->opcode()
+                                              : other_inst->opcode());
+      if (other_first_is_variable) {
+        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
+                             {SPV_OPERAND_TYPE_ID, {merged_id}}});
+      } else {
+        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
+                             {SPV_OPERAND_TYPE_ID, {non_const_id}}});
+      }
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Merges multiply of constant and negation.
+// Cases:
+// (-x) * 2 = x * -2
+// 2 * (-x) = x * -2
+FoldingRule MergeMulNegateArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFNegate ||
+        other_inst->opcode() == SpvOpSNegate) {
+      // Move the negation onto the constant and multiply the un-negated
+      // value directly.
+      uint32_t neg_id = NegateConstant(const_mgr, const_input1);
+
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
+           {SPV_OPERAND_TYPE_ID, {neg_id}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Merges consecutive divides if each instruction contains one constant operand.
+// Does not support integer division.
+// Cases:
+// 2 / (x / 2) = 4 / x
+// 4 / (2 / x) = 2 * x
+// (4 / x) / 2 = 2 / x
+// (x / 2) / 2 = x / 4
+FoldingRule MergeDivDivArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv);
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (!inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1 || HasZero(const_input1)) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
+
+    bool first_is_variable = constants[0] == nullptr;
+    if (other_inst->opcode() == inst->opcode()) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2 || HasZero(const_input2)) return false;
+
+      bool other_first_is_variable = other_constants[0] == nullptr;
+
+      SpvOp merge_op = inst->opcode();
+      if (other_first_is_variable) {
+        // Constants magnify.
+        merge_op = SpvOpFMul;
+      }
+
+      // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
+      // because it is commutative.
+      if (first_is_variable) std::swap(const_input1, const_input2);
+      uint32_t merged_id =
+          PerformOperation(const_mgr, merge_op, const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      uint32_t non_const_id = other_first_is_variable
+                                  ? other_inst->GetSingleWordInOperand(0u)
+                                  : other_inst->GetSingleWordInOperand(1u);
+
+      SpvOp op = inst->opcode();
+      if (!first_is_variable && !other_first_is_variable) {
+        // Effectively div of 1/x, so change to multiply.
+        op = SpvOpFMul;
+      }
+
+      uint32_t op1 = merged_id;
+      uint32_t op2 = non_const_id;
+      if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
+      inst->SetOpcode(op);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Fold multiplies succeeded by divides where each instruction contains a
+// constant operand. Does not support integer divide.
+// Cases:
+// 4 / (x * 2) = 2 / x
+// 4 / (2 * x) = 2 / x
+// (x * 4) / 2 = x * 2
+// (4 * x) / 2 = x * 2
+// (x * y) / x = y
+// (y * x) / x = y
+FoldingRule MergeDivMulArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv);
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (!inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    uint32_t op_id = inst->GetSingleWordInOperand(0);
+    Instruction* op_inst = def_use_mgr->GetDef(op_id);
+
+    // (x * y) / x and (y * x) / x: the divisor cancels one multiplicand.
+    if (op_inst->opcode() == SpvOpFMul) {
+      for (uint32_t i = 0; i < 2; i++) {
+        if (op_inst->GetSingleWordInOperand(i) ==
+            inst->GetSingleWordInOperand(1)) {
+          inst->SetOpcode(SpvOpCopyObject);
+          inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+                                {op_inst->GetSingleWordInOperand(1 - i)}}});
+          return true;
+        }
+      }
+    }
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1 || HasZero(const_input1)) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
+
+    bool first_is_variable = constants[0] == nullptr;
+    if (other_inst->opcode() == SpvOpFMul) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      bool other_first_is_variable = other_constants[0] == nullptr;
+
+      // This is an x / (*) case. Swap the inputs.
+      if (first_is_variable) std::swap(const_input1, const_input2);
+      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+                                            const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      uint32_t non_const_id = other_first_is_variable
+                                  ? other_inst->GetSingleWordInOperand(0u)
+                                  : other_inst->GetSingleWordInOperand(1u);
+
+      uint32_t op1 = merged_id;
+      uint32_t op2 = non_const_id;
+      if (first_is_variable) std::swap(op1, op2);
+
+      // Convert to multiply
+      if (first_is_variable) inst->SetOpcode(other_inst->opcode());
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Fold divides of a constant and a negation.
+// Cases:
+// (-x) / 2 = x / -2
+// 2 / (-x) = -2 / x
+FoldingRule MergeDivNegateArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
+           inst->opcode() == SpvOpUDiv);
+    // Moving the negation onto the constant is only sound for floating
+    // point. Under OpUDiv the negated operand is reinterpreted as a large
+    // unsigned value, and under OpSDiv the most negative value negates to
+    // itself, so (-x) / c and x / (-c) can differ. Leave integer divides
+    // untouched.
+    if (inst->opcode() != SpvOpFDiv) return false;
+
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    if (!inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
+
+    if (other_inst->opcode() != SpvOpFNegate) return false;
+
+    bool first_is_variable = constants[0] == nullptr;
+    uint32_t neg_id = NegateConstant(const_mgr, const_input1);
+    uint32_t negated_operand_id = other_inst->GetSingleWordInOperand(0u);
+
+    // Keep the operand order of the divide: the negated constant takes the
+    // place of the original constant.
+    if (first_is_variable) {
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {negated_operand_id}},
+                           {SPV_OPERAND_TYPE_ID, {neg_id}}});
+    } else {
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {neg_id}},
+                           {SPV_OPERAND_TYPE_ID, {negated_operand_id}}});
+    }
+    return true;
+  };
+}
+
+// Folds addition of a constant and a negation.
+// Cases:
+// (-x) + 2 = 2 - x
+// 2 + (-x) = 2 - x
+FoldingRule MergeAddNegateArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    const bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() != SpvOpSNegate &&
+        other_inst->opcode() != SpvOpFNegate) {
+      return false;
+    }
+
+    // No constant arithmetic is needed: the existing constant id is reused
+    // as the minuend and the un-negated value becomes the subtrahend.
+    inst->SetOpcode(uses_float ? SpvOpFSub : SpvOpISub);
+    uint32_t const_id = inst->GetSingleWordInOperand(constants[0] ? 0u : 1u);
+    inst->SetInOperands(
+        {{SPV_OPERAND_TYPE_ID, {const_id}},
+         {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
+    return true;
+  };
+}
+
+// Folds subtraction of a constant and a negation.
+// Cases:
+// (-x) - 2 = -2 - x
+// 2 - (-x) = x + 2
+FoldingRule MergeSubNegateArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    const bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() != SpvOpSNegate &&
+        other_inst->opcode() != SpvOpFNegate) {
+      return false;
+    }
+
+    uint32_t op1 = 0;
+    uint32_t op2 = 0;
+    SpvOp opcode = inst->opcode();
+    if (constants[0] != nullptr) {
+      // c - (-x): subtracting a negation is an addition, x + c.
+      op1 = other_inst->GetSingleWordInOperand(0u);
+      op2 = inst->GetSingleWordInOperand(0u);
+      opcode = uses_float ? SpvOpFAdd : SpvOpIAdd;
+    } else {
+      // (-x) - c: negate the constant instead and keep the subtraction,
+      // (-c) - x.
+      op1 = NegateConstant(const_mgr, const_input1);
+      op2 = other_inst->GetSingleWordInOperand(0u);
+    }
+
+    inst->SetOpcode(opcode);
+    inst->SetInOperands(
+        {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+    return true;
+  };
+}
+
+// Folds addition of an addition where each operation has a constant operand.
+// Cases:
+// (x + 2) + 2 = x + 4
+// (2 + x) + 2 = x + 4
+// 2 + (x + 2) = x + 4
+// 2 + (2 + x) = x + 4
+FoldingRule MergeAddAddArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() != SpvOpFAdd &&
+        other_inst->opcode() != SpvOpIAdd) {
+      return false;
+    }
+
+    std::vector<const analysis::Constant*> other_constants =
+        const_mgr->GetOperandConstants(other_inst);
+    const analysis::Constant* const_input2 = ConstInput(other_constants);
+    if (!const_input2) return false;
+
+    Instruction* non_const_input =
+        NonConstInput(context, other_constants[0], other_inst);
+    // Addition is commutative, so the constants are summed in either order
+    // and the result is emitted as variable + merged-constant.
+    uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+                                          const_input1, const_input2);
+    if (merged_id == 0) return false;
+
+    inst->SetInOperands(
+        {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
+         {SPV_OPERAND_TYPE_ID, {merged_id}}});
+    return true;
+  };
+}
+
+// Folds addition of a subtraction where each operation has a constant operand.
+// Cases:
+// (x - 2) + 2 = x + 0
+// (2 - x) + 2 = 4 - x
+// 2 + (x - 2) = x + 0
+// 2 + (2 - x) = 4 - x
+FoldingRule MergeAddSubArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFSub ||
+        other_inst->opcode() == SpvOpISub) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      bool first_is_variable = other_constants[0] == nullptr;
+      SpvOp op = inst->opcode();
+      uint32_t op1 = 0;
+      uint32_t op2 = 0;
+      if (first_is_variable) {
+        // Subtract constants. Non-constant operand is first.
+        op1 = other_inst->GetSingleWordInOperand(0u);
+        op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
+                               const_input2);
+      } else {
+        // Add constants. Constant operand is first. Change the opcode.
+        op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
+                               const_input2);
+        op2 = other_inst->GetSingleWordInOperand(1u);
+        op = other_inst->opcode();
+      }
+      if (op1 == 0 || op2 == 0) return false;
+
+      inst->SetOpcode(op);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+// Folds subtraction of an addition where each operation has a constant
+// operand.
+// Cases:
+// (x + 2) - 2 = x + 0
+// (2 + x) - 2 = x + 0
+// 2 - (x + 2) = 0 - x
+// 2 - (2 + x) = 0 - x
+FoldingRule MergeSubAddArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFAdd ||
+        other_inst->opcode() == SpvOpIAdd) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      Instruction* non_const_input =
+          NonConstInput(context, other_constants[0], other_inst);
+
+      // If the first operand of the sub is not a constant, swap the constants
+      // so the subtraction has the correct operands.
+      if (constants[0] == nullptr) std::swap(const_input1, const_input2);
+      // Subtract the constants.
+      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
+                                            const_input1, const_input2);
+      SpvOp op = inst->opcode();
+      uint32_t op1 = 0;
+      uint32_t op2 = 0;
+      if (constants[0] == nullptr) {
+        // Non-constant operand is first. Change the opcode.
+        op1 = non_const_input->result_id();
+        op2 = merged_id;
+        op = other_inst->opcode();
+      } else {
+        // Constant operand is first.
+        op1 = merged_id;
+        op2 = non_const_input->result_id();
+      }
+      if (op1 == 0 || op2 == 0) return false;
+
+      inst->SetOpcode(op);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+// Folds subtraction of a subtraction where each operation has a constant
+// operand.
+// Cases:
+// (x - 2) - 2 = x - 4
+// (2 - x) - 2 = 0 - x
+// 2 - (x - 2) = 4 - x
+// 2 - (2 - x) = x + 0
+FoldingRule MergeSubSubArithmetic() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
+    const analysis::Type* type =
+        context->get_type_mgr()->GetType(inst->type_id());
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    bool uses_float = HasFloatingPoint(type);
+    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
+
+    uint32_t width = ElementWidth(type);
+    if (width != 32 && width != 64) return false;
+
+    const analysis::Constant* const_input1 = ConstInput(constants);
+    if (!const_input1) return false;
+    Instruction* other_inst = NonConstInput(context, constants[0], inst);
+    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
+      return false;
+
+    if (other_inst->opcode() == SpvOpFSub ||
+        other_inst->opcode() == SpvOpISub) {
+      std::vector<const analysis::Constant*> other_constants =
+          const_mgr->GetOperandConstants(other_inst);
+      const analysis::Constant* const_input2 = ConstInput(other_constants);
+      if (!const_input2) return false;
+
+      Instruction* non_const_input =
+          NonConstInput(context, other_constants[0], other_inst);
+
+      // Merge the constants. When the inner subtraction has the variable
+      // first the constants accumulate (add); otherwise they are subtracted
+      // in the order they appear in the expression.
+      uint32_t merged_id = 0;
+      SpvOp merge_op = inst->opcode();
+      if (other_constants[0] == nullptr) {
+        merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
+      } else if (constants[0] == nullptr) {
+        std::swap(const_input1, const_input2);
+      }
+      merged_id =
+          PerformOperation(const_mgr, merge_op, const_input1, const_input2);
+      if (merged_id == 0) return false;
+
+      SpvOp op = inst->opcode();
+      if (constants[0] != nullptr && other_constants[0] != nullptr) {
+        // Change the operation.
+        op = uses_float ? SpvOpFAdd : SpvOpIAdd;
+      }
+
+      uint32_t op1 = 0;
+      uint32_t op2 = 0;
+      if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
+        op1 = merged_id;
+        op2 = non_const_input->result_id();
+      } else {
+        op1 = non_const_input->result_id();
+        op2 = merged_id;
+      }
+
+      inst->SetOpcode(op);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+FoldingRule IntMultipleBy1() {
+  // An OpIMul with a scalar integer constant operand equal to 1 is replaced
+  // by a copy of the other operand.
+  return [](IRContext*, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpIMul && "Wrong opcode.  Should be OpIMul.");
+    for (uint32_t i = 0; i < 2; i++) {
+      if (constants[i] == nullptr) {
+        continue;
+      }
+      const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
+      if (int_constant) {
+        uint32_t width = ElementWidth(int_constant->type());
+        if (width != 32 && width != 64) return false;
+        bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
+                                    : int_constant->GetU64BitValue() == 1ull;
+        if (is_one) {
+          inst->SetOpcode(SpvOpCopyObject);
+          inst->SetInOperands(
+              {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+}
+
+FoldingRule CompositeConstructFeedingExtract() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    // If the input to an OpCompositeExtract is an OpCompositeConstruct,
+    // then we can simply use the appropriate element in the construction.
+    assert(inst->opcode() == SpvOpCompositeExtract &&
+           "Wrong opcode.  Should be OpCompositeExtract.");
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
+    Instruction* cinst = def_use_mgr->GetDef(cid);
+
+    if (cinst->opcode() != SpvOpCompositeConstruct) {
+      return false;
+    }
+
+    std::vector<Operand> operands;
+    analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
+    if (composite_type->AsVector() == nullptr) {
+      // Get the element being extracted from the OpCompositeConstruct
+      // Since it is not a vector, it is simple to extract the single element.
+      uint32_t element_index = inst->GetSingleWordInOperand(1);
+      uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
+      operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
+
+      // Add the remaining indices for extraction.
+      for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
+        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
+                            {inst->GetSingleWordInOperand(i)}});
+      }
+
+    } else {
+      // With vectors we have to handle the case where it is concatenating
+      // vectors.
+      assert(inst->NumInOperands() == 2 &&
+             "Expecting a vector of scalar values.");
+
+      uint32_t element_index = inst->GetSingleWordInOperand(1);
+      for (uint32_t construct_index = 0;
+           construct_index < cinst->NumInOperands(); ++construct_index) {
+        uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
+        Instruction* element_def = def_use_mgr->GetDef(element_id);
+        analysis::Vector* element_type =
+            type_mgr->GetType(element_def->type_id())->AsVector();
+        if (element_type) {
+          uint32_t vector_size = element_type->element_count();
+          // Valid indices into this constituent are [0, vector_size), so an
+          // index equal to vector_size already belongs to a later
+          // constituent.
+          if (vector_size <= element_index) {
+            // The element we want comes after this vector.
+            element_index -= vector_size;
+          } else {
+            // We want an element of this vector.
+            operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
+            operands.push_back(
+                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
+            break;
+          }
+        } else {
+          if (element_index == 0) {
+            // This is a scalar, and this is the element we are extracting.
+            operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
+            break;
+          } else {
+            // Skip over this scalar value.
+            --element_index;
+          }
+        }
+      }
+    }
+
+    // No constituent covered the requested index. Leave the instruction
+    // alone instead of rewriting it with an empty operand list.
+    if (operands.empty()) {
+      return false;
+    }
+
+    // If there were no extra indices, then we have the final object.  No need
+    // to extract even more.
+    if (operands.size() == 1) {
+      inst->SetOpcode(SpvOpCopyObject);
+    }
+
+    inst->SetInOperands(std::move(operands));
+    return true;
+  };
+}
+
+FoldingRule CompositeExtractFeedingConstruct() {
+  // If the OpCompositeConstruct is simply putting back together elements that
+  // were extracted from the same source, we can simply reuse the source.
+  //
+  // This is a common code pattern because of the way that scalar replacement
+  // works.
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    assert(inst->opcode() == SpvOpCompositeConstruct &&
+           "Wrong opcode.  Should be OpCompositeConstruct.");
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+    uint32_t original_id = 0;
+
+    // Check each element to make sure they are:
+    // - extractions
+    // - extracting the same position they are inserting
+    // - all extract from the same id.
+    for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
+      uint32_t element_id = inst->GetSingleWordInOperand(i);
+      Instruction* element_inst = def_use_mgr->GetDef(element_id);
+
+      if (element_inst->opcode() != SpvOpCompositeExtract) {
+        return false;
+      }
+
+      if (element_inst->NumInOperands() != 2) {
+        return false;
+      }
+
+      if (element_inst->GetSingleWordInOperand(1) != i) {
+        return false;
+      }
+
+      if (i == 0) {
+        original_id =
+            element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
+      } else if (original_id != element_inst->GetSingleWordInOperand(
+                                    kExtractCompositeIdInIdx)) {
+        return false;
+      }
+    }
+
+    // A construct with no constituents never sets |original_id|, so there is
+    // no source object to forward.
+    if (original_id == 0) {
+      return false;
+    }
+
+    // The last check is to see that the object being extracted from is the
+    // correct type.
+    Instruction* original_inst = def_use_mgr->GetDef(original_id);
+    if (original_inst->type_id() != inst->type_id()) {
+      return false;
+    }
+
+    // Simplify by using the original object.
+    inst->SetOpcode(SpvOpCopyObject);
+    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+    return true;
+  };
+}
+
+FoldingRule InsertFeedingExtract() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    assert(inst->opcode() == SpvOpCompositeExtract &&
+           "Wrong opcode.  Should be OpCompositeExtract.");
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+    uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
+    Instruction* cinst = def_use_mgr->GetDef(cid);
+
+    if (cinst->opcode() != SpvOpCompositeInsert) {
+      return false;
+    }
+
+    // Find the first position where the list of insert and extract indices
+    // differ, if at all. The insert carries one more leading id operand than
+    // the extract, hence the |i + 1| offset.
+    uint32_t i;
+    for (i = 1; i < inst->NumInOperands(); ++i) {
+      if (i + 1 >= cinst->NumInOperands()) {
+        break;
+      }
+
+      if (inst->GetSingleWordInOperand(i) !=
+          cinst->GetSingleWordInOperand(i + 1)) {
+        break;
+      }
+    }
+
+    // We are extracting the element that was inserted.
+    if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID,
+            {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
+      return true;
+    }
+
+    // Extracting the value that was inserted along with values for the base
+    // composite.  Cannot do anything.
+    if (i == inst->NumInOperands()) {
+      return false;
+    }
+
+    // Extracting an element of the value that was inserted.  Extract from
+    // that value directly.
+    if (i + 1 == cinst->NumInOperands()) {
+      std::vector<Operand> operands;
+      operands.push_back(
+          {SPV_OPERAND_TYPE_ID,
+           {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
+      for (; i < inst->NumInOperands(); ++i) {
+        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
+                            {inst->GetSingleWordInOperand(i)}});
+      }
+      inst->SetInOperands(std::move(operands));
+      return true;
+    }
+
+    // Extracting a value that is disjoint from the element being inserted.
+    // Rewrite the extract to use the composite input to the insert.
+    std::vector<Operand> operands;
+    operands.push_back(
+        {SPV_OPERAND_TYPE_ID,
+         {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
+    for (i = 1; i < inst->NumInOperands(); ++i) {
+      operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
+                          {inst->GetSingleWordInOperand(i)}});
+    }
+    inst->SetInOperands(std::move(operands));
+    return true;
+  };
+}
+
+// When a VectorShuffle is feeding an Extract, we can extract from one of the
+// operands of the VectorShuffle.  We just need to adjust the index in the
+// extract instruction.
+FoldingRule VectorShuffleFeedingExtract() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    assert(inst->opcode() == SpvOpCompositeExtract &&
+           "Wrong opcode. Should be OpCompositeExtract.");
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+    uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
+    Instruction* cinst = def_use_mgr->GetDef(cid);
+
+    if (cinst->opcode() != SpvOpVectorShuffle) {
+      return false;
+    }
+
+    // Find the size of the first vector operand of the VectorShuffle.
+    Instruction* first_input =
+        def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
+    analysis::Type* first_input_type =
+        type_mgr->GetType(first_input->type_id());
+    assert(first_input_type->AsVector() &&
+           "Input to vector shuffle should be vectors.");
+    uint32_t first_input_size = first_input_type->AsVector()->element_count();
+
+    // Get index of the element the vector shuffle is placing in the position
+    // being extracted.  In-operands 0 and 1 of the shuffle are its two input
+    // vectors, so the component literals start at in-operand 2.
+    uint32_t new_index =
+        cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
+
+    // Extracting an undefined value so fold this extract into an undef.
+    const uint32_t undef_literal_value = 0xffffffff;
+    if (new_index == undef_literal_value) {
+      inst->SetOpcode(SpvOpUndef);
+      inst->SetInOperands({});
+      return true;
+    }
+
+    // Get the id of the vector the element comes from, and update the index
+    // if needed.
+    uint32_t new_vector = 0;
+    if (new_index < first_input_size) {
+      new_vector = cinst->GetSingleWordInOperand(0);
+    } else {
+      new_vector = cinst->GetSingleWordInOperand(1);
+      new_index -= first_input_size;
+    }
+
+    // Update the extract instruction.
+    inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
+    inst->SetInOperand(1, {new_index});
+    return true;
+  };
+}
+
+// When an FMix is feeding an Extract that extracts an element whose
+// corresponding |a| in the FMix is 0 or 1, we can extract from one of the
+// operands of the FMix.
+FoldingRule FMixFeedingExtract() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    assert(inst->opcode() == SpvOpCompositeExtract &&
+           "Wrong opcode. Should be OpCompositeExtract.");
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+
+    uint32_t composite_id =
+        inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
+    Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
+
+    if (composite_inst->opcode() != SpvOpExtInst) {
+      return false;
+    }
+
+    uint32_t inst_set_id =
+        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+
+    if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
+            inst_set_id ||
+        composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
+            GLSLstd450FMix) {
+      return false;
+    }
+
+    // Get the |a| for the FMix instruction.  A clone of the extract is
+    // retargeted at |a| and folded; if that yields a copy of a constant, the
+    // extracted component of |a| is known.
+    uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
+    std::unique_ptr<Instruction> a(inst->Clone(context));
+    a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
+    context->get_instruction_folder().FoldInstruction(a.get());
+
+    if (a->opcode() != SpvOpCopyObject) {
+      return false;
+    }
+
+    const analysis::Constant* a_const =
+        const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
+
+    if (!a_const) {
+      return false;
+    }
+
+    bool use_x = false;
+
+    assert(a_const->type()->AsFloat());
+    double element_value = a_const->GetValueAsDouble();
+    if (element_value == 0.0) {
+      use_x = true;
+    } else if (element_value == 1.0) {
+      use_x = false;
+    } else {
+      return false;
+    }
+
+    // Get the id of the vector the element comes from.
+    uint32_t new_vector = 0;
+    if (use_x) {
+      new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
+    } else {
+      new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
+    }
+
+    // Update the extract instruction.
+    inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
+    return true;
+  };
+}
+
+FoldingRule RedundantPhi() {
+  // An OpPhi instruction where all values are the same or the result of the phi
+  // itself, can be replaced by the value itself.
+  return [](IRContext*, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
+
+    uint32_t incoming_value = 0;
+
+    // In-operands come in (value, predecessor) pairs; only look at the values.
+    for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
+      uint32_t op_id = inst->GetSingleWordInOperand(i);
+      if (op_id == inst->result_id()) {
+        continue;
+      }
+
+      if (incoming_value == 0) {
+        incoming_value = op_id;
+      } else if (op_id != incoming_value) {
+        // Found two possible values.  Can't simplify.
+        return false;
+      }
+    }
+
+    if (incoming_value == 0) {
+      // Code looks invalid.  Don't do anything.
+      return false;
+    }
+
+    // We have a single incoming value.  Simplify using that value.
+    inst->SetOpcode(SpvOpCopyObject);
+    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
+    return true;
+  };
+}
+
+FoldingRule RedundantSelect() {
+  // An OpSelect instruction where both values are the same or the condition is
+  // constant can be replaced by one of the values
+  return [](IRContext*, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpSelect &&
+           "Wrong opcode. Should be OpSelect.");
+    assert(inst->NumInOperands() == 3);
+    assert(constants.size() == 3);
+
+    uint32_t true_id = inst->GetSingleWordInOperand(1);
+    uint32_t false_id = inst->GetSingleWordInOperand(2);
+
+    if (true_id == false_id) {
+      // Both results are the same, condition doesn't matter
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
+      return true;
+    } else if (constants[0]) {
+      const analysis::Type* type = constants[0]->type();
+      if (type->AsBool()) {
+        // Scalar constant value, select the corresponding value.
+        inst->SetOpcode(SpvOpCopyObject);
+        if (constants[0]->AsNullConstant() ||
+            !constants[0]->AsBoolConstant()->value()) {
+          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
+        } else {
+          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
+        }
+        return true;
+      } else {
+        assert(type->AsVector());
+        if (constants[0]->AsNullConstant()) {
+          // All values come from false id.
+          inst->SetOpcode(SpvOpCopyObject);
+          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
+          return true;
+        } else {
+          // Convert to a vector shuffle.
+          std::vector<Operand> ops;
+          ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
+          ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
+          const analysis::VectorConstant* vector_const =
+              constants[0]->AsVectorConstant();
+          uint32_t size =
+              static_cast<uint32_t>(vector_const->GetComponents().size());
+          for (uint32_t i = 0; i != size; ++i) {
+            const analysis::Constant* component =
+                vector_const->GetComponents()[i];
+            if (component->AsNullConstant() ||
+                !component->AsBoolConstant()->value()) {
+              // Selecting from the false vector which is the second input
+              // vector to the shuffle.  Offset the index by |size|.
+              ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
+            } else {
+              // Selecting from true vector which is the first input vector to
+              // the shuffle.
+              ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
+            }
+          }
+
+          inst->SetOpcode(SpvOpVectorShuffle);
+          inst->SetInOperands(std::move(ops));
+          return true;
+        }
+      }
+    }
+
+    return false;
+  };
+}
+
+// Classification of a floating-point constant used by the Redundant* rules.
+enum class FloatConstantKind { Unknown, Zero, One };
+
+// Classifies |constant| as all-zero, all-one, or unknown.  A null |constant|
+// (operand is not a constant) is Unknown.  A vector constant gets a kind only
+// if every component has that same kind.  Only 32- and 64-bit floats are
+// inspected for the value 1.0.
+FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
+  if (constant == nullptr) {
+    return FloatConstantKind::Unknown;
+  }
+
+  assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
+
+  if (constant->AsNullConstant()) {
+    return FloatConstantKind::Zero;
+  } else if (const analysis::VectorConstant* vc =
+                 constant->AsVectorConstant()) {
+    const std::vector<const analysis::Constant*>& components =
+        vc->GetComponents();
+    assert(!components.empty());
+
+    FloatConstantKind kind = getFloatConstantKind(components[0]);
+
+    for (size_t i = 1; i < components.size(); ++i) {
+      if (getFloatConstantKind(components[i]) != kind) {
+        return FloatConstantKind::Unknown;
+      }
+    }
+
+    return kind;
+  } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
+    if (fc->IsZero()) return FloatConstantKind::Zero;
+
+    uint32_t width = fc->type()->AsFloat()->width();
+    if (width != 32 && width != 64) return FloatConstantKind::Unknown;
+
+    double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
+
+    if (value == 0.0) {
+      return FloatConstantKind::Zero;
+    } else if (value == 1.0) {
+      return FloatConstantKind::One;
+    } else {
+      return FloatConstantKind::Unknown;
+    }
+  } else {
+    return FloatConstantKind::Unknown;
+  }
+}
+
+// Folds |x + 0| and |0 + x| to a copy of |x|.
+FoldingRule RedundantFAdd() {
+  return [](IRContext*, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
+    assert(constants.size() == 2);
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+    FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+    if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
+      // Keep the operand that is not the zero.
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+                            {inst->GetSingleWordInOperand(
+                                kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Folds |0 - x| to a negate of |x| and |x - 0| to a copy of |x|.
+FoldingRule RedundantFSub() {
+  return [](IRContext*, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
+    assert(constants.size() == 2);
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+    FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+    if (kind0 == FloatConstantKind::Zero) {
+      inst->SetOpcode(SpvOpFNegate);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
+      return true;
+    }
+
+    if (kind1 == FloatConstantKind::Zero) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Folds a multiply by zero to a copy of the zero operand, and a multiply by
+// one to a copy of the other operand.
+FoldingRule RedundantFMul() {
+  return [](IRContext*, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
+    assert(constants.size() == 2);
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+    FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+    if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
+      // The result is the zero operand itself.
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+                            {inst->GetSingleWordInOperand(
+                                kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
+      return true;
+    }
+
+    if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
+      // The result is the operand that is not the one.
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
+                            {inst->GetSingleWordInOperand(
+                                kind0 == FloatConstantKind::One ? 1 : 0)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Folds |0 / x| to a copy of the zero operand and |x / 1| to a copy of |x|.
+FoldingRule RedundantFDiv() {
+  return [](IRContext*, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
+    assert(constants.size() == 2);
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
+    FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
+
+    if (kind0 == FloatConstantKind::Zero) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+      return true;
+    }
+
+    if (kind1 == FloatConstantKind::One) {
+      inst->SetOpcode(SpvOpCopyObject);
+      inst->SetInOperands(
+          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
+      return true;
+    }
+
+    return false;
+  };
+}
+
+// Folds a GLSL.std.450 FMix whose |a| operand is all zeros or all ones to a
+// copy of its |x| or |y| operand respectively.
+FoldingRule RedundantFMix() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpExtInst &&
+           "Wrong opcode. Should be OpExtInst.");
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    uint32_t instSetId =
+        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+
+    if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
+        inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
+            GLSLstd450FMix) {
+      assert(constants.size() == 5);
+
+      FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
+
+      if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
+        inst->SetOpcode(SpvOpCopyObject);
+        inst->SetInOperands(
+            {{SPV_OPERAND_TYPE_ID,
+              {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
+                                                ? kFMixXIdInIdx
+                                                : kFMixYIdInIdx)}}});
+        return true;
+      }
+    }
+
+    return false;
+  };
+}
+
+// This rule handles addition of zero for integers.
+FoldingRule RedundantIAdd() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
+
+    // The max value doubles as the "no zero operand found" sentinel.
+    uint32_t operand = std::numeric_limits<uint32_t>::max();
+    const analysis::Type* operand_type = nullptr;
+    if (constants[0] && constants[0]->IsZero()) {
+      operand = inst->GetSingleWordInOperand(1);
+      operand_type = constants[0]->type();
+    } else if (constants[1] && constants[1]->IsZero()) {
+      operand = inst->GetSingleWordInOperand(0);
+      operand_type = constants[1]->type();
+    }
+
+    if (operand != std::numeric_limits<uint32_t>::max()) {
+      const analysis::Type* inst_type =
+          context->get_type_mgr()->GetType(inst->type_id());
+      // The result type is compared against the zero constant's type; when
+      // they differ the remaining operand is bitcast instead of copied.
+      if (inst_type->IsSame(operand_type)) {
+        inst->SetOpcode(SpvOpCopyObject);
+      } else {
+        inst->SetOpcode(SpvOpBitcast);
+      }
+      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
+      return true;
+    }
+    return false;
+  };
+}
+
+// This rule looks for a dot with a constant vector containing a single 1 and
+// the rest 0s.  This is the same as doing an extract.
+FoldingRule DotProductDoingExtract() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
+
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+
+    if (!inst->IsFloatingPointFoldingAllowed()) {
+      return false;
+    }
+
+    // Try each operand as the constant "selector" vector; the other operand
+    // (1 - i) is then the vector to extract from.
+    for (int i = 0; i < 2; ++i) {
+      if (!constants[i]) {
+        continue;
+      }
+
+      const analysis::Vector* vector_type = constants[i]->type()->AsVector();
+      assert(vector_type && "Inputs to OpDot must be vectors.");
+      const analysis::Float* element_type =
+          vector_type->element_type()->AsFloat();
+      assert(element_type && "Inputs to OpDot must be vectors of floats.");
+      uint32_t element_width = element_type->width();
+      if (element_width != 32 && element_width != 64) {
+        return false;
+      }
+
+      std::vector<const analysis::Constant*> components;
+      components = constants[i]->GetVectorComponents(const_mgr);
+
+      const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
+
+      // Look for exactly one component equal to 1.0 with every other
+      // component equal to 0.0.
+      uint32_t component_with_one = kNotFound;
+      bool all_others_zero = true;
+      for (uint32_t j = 0; j < components.size(); ++j) {
+        const analysis::Constant* element = components[j];
+        double value =
+            (element_width == 32 ? element->GetFloat() : element->GetDouble());
+        if (value == 0.0) {
+          continue;
+        } else if (value == 1.0) {
+          if (component_with_one == kNotFound) {
+            component_with_one = j;
+          } else {
+            // A second 1.0: the dot is a sum of two elements, not an extract.
+            component_with_one = kNotFound;
+            break;
+          }
+        } else {
+          all_others_zero = false;
+          break;
+        }
+      }
+
+      if (!all_others_zero || component_with_one == kNotFound) {
+        continue;
+      }
+
+      std::vector<Operand> operands;
+      operands.push_back(
+          {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
+      operands.push_back(
+          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
+
+      inst->SetOpcode(SpvOpCompositeExtract);
+      inst->SetInOperands(std::move(operands));
+      return true;
+    }
+    return false;
+  };
+}
+
+// If we are storing an undef, then we can remove the store.
+//
+// TODO: We can do something similar for OpImageWrite, but checking for volatile
+// is complicated.  Waiting to see if it is needed.
+FoldingRule StoringUndef() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
+
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+
+    // If this is a volatile store, the store cannot be removed.
+    if (inst->NumInOperands() == 3) {
+      if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
+        return false;
+      }
+    }
+
+    uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
+    Instruction* object_inst = def_use_mgr->GetDef(object_id);
+    if (object_inst->opcode() == SpvOpUndef) {
+      inst->ToNop();
+      return true;
+    }
+    return false;
+  };
+}
+
+// When one operand of a VectorShuffle is itself the result of a VectorShuffle
+// (the "feeder"), and every component taken from the feeder comes from a single
+// operand of the feeder, the outer shuffle is rewritten to read that operand
+// directly.
+FoldingRule VectorShuffleFeedingShuffle() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>&) {
+    assert(inst->opcode() == SpvOpVectorShuffle &&
+           "Wrong opcode. Should be OpVectorShuffle.");
+
+    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+    analysis::TypeManager* type_mgr = context->get_type_mgr();
+
+    Instruction* feeding_shuffle_inst =
+        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+    analysis::Vector* op0_type =
+        type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
+    uint32_t op0_length = op0_type->element_count();
+
+    // Prefer operand 0 as the feeder; fall back to operand 1.
+    bool feeder_is_op0 = true;
+    if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
+      feeding_shuffle_inst =
+          def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
+      feeder_is_op0 = false;
+    }
+
+    if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
+      return false;
+    }
+
+    Instruction* feeder2 =
+        def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
+    analysis::Vector* feeder_op0_type =
+        type_mgr->GetType(feeder2->type_id())->AsVector();
+    uint32_t feeder_op0_length = feeder_op0_type->element_count();
+
+    uint32_t new_feeder_id = 0;
+    std::vector<Operand> new_operands;
+    new_operands.resize(
+        2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
+    const uint32_t undef_literal = 0xffffffff;
+    for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
+      uint32_t component_index = inst->GetSingleWordInOperand(op);
+
+      // Do not interpret the undefined value literal as coming from operand 1.
+      if (component_index != undef_literal &&
+          feeder_is_op0 == (component_index < op0_length)) {
+        // This component comes from the feeding_shuffle_inst.  Update
+        // |component_index| to be the index into the operand of the feeder.
+
+        // Adjust component_index to get the index into the operands of the
+        // feeding_shuffle_inst.
+        if (component_index >= op0_length) {
+          component_index -= op0_length;
+        }
+        component_index =
+            feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
+
+        // Check if we are using a component from the first or second operand of
+        // the feeding instruction.
+        if (component_index < feeder_op0_length) {
+          if (new_feeder_id == 0) {
+            // First time through, save the id of the operand the element comes
+            // from.
+            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
+          } else if (new_feeder_id !=
+                     feeding_shuffle_inst->GetSingleWordInOperand(0)) {
+            // We need both elements of the feeding_shuffle_inst, so we cannot
+            // fold.
+            return false;
+          }
+        } else {
+          if (new_feeder_id == 0) {
+            // First time through, save the id of the operand the element comes
+            // from.
+            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
+          } else if (new_feeder_id !=
+                     feeding_shuffle_inst->GetSingleWordInOperand(1)) {
+            // We need both elements of the feeding_shuffle_inst, so we cannot
+            // fold.
+            return false;
+          }
+          component_index -= feeder_op0_length;
+        }
+
+        if (!feeder_is_op0) {
+          component_index += op0_length;
+        }
+      }
+      new_operands.push_back(
+          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
+    }
+
+    if (new_feeder_id == 0) {
+      // No component is taken from the feeder, so any value of the feeder's
+      // type will do; use a null constant.
+      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+      const analysis::Type* type =
+          type_mgr->GetType(feeding_shuffle_inst->type_id());
+      const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
+      new_feeder_id =
+          const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
+    }
+
+    if (feeder_is_op0) {
+      // If the size of the first vector operand changed then the indices
+      // referring to the second operand need to be adjusted.
+      Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
+      analysis::Type* new_feeder_type =
+          type_mgr->GetType(new_feeder_inst->type_id());
+      uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
+      int32_t adjustment = op0_length - new_op0_size;
+
+      if (adjustment != 0) {
+        for (uint32_t i = 2; i < new_operands.size(); i++) {
+          uint32_t original_index = inst->GetSingleWordInOperand(i);
+          // The undef literal compares >= |op0_length| but is not an index
+          // into the second operand; it must be left untouched, otherwise it
+          // turns into an out-of-range component index.
+          if (original_index >= op0_length &&
+              original_index != undef_literal) {
+            new_operands[i].words[0] -= adjustment;
+          }
+        }
+      }
+
+      new_operands[0].words[0] = new_feeder_id;
+      new_operands[1] = inst->GetInOperand(1);
+    } else {
+      new_operands[1].words[0] = new_feeder_id;
+      new_operands[0] = inst->GetInOperand(0);
+    }
+
+    inst->SetInOperands(std::move(new_operands));
+    return true;
+  };
+}
+
+}  // namespace
+
+FoldingRules::FoldingRules() {
+  // Add all folding rules to the list for the opcodes to which they apply.
+  // Note that the order in which rules are added to the list matters. If a rule
+  // applies to the instruction, the rest of the rules will not be attempted.
+  // Take that into consideration.
+  rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct());
+
+  rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
+  rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
+  rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
+  rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
+
+  rules_[SpvOpDot].push_back(DotProductDoingExtract());
+
+  rules_[SpvOpExtInst].push_back(RedundantFMix());
+
+  rules_[SpvOpFAdd].push_back(RedundantFAdd());
+  rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
+  rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
+  rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
+
+  rules_[SpvOpFDiv].push_back(RedundantFDiv());
+  rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
+  rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
+  rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
+  rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
+
+  rules_[SpvOpFMul].push_back(RedundantFMul());
+  rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
+  rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
+  rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
+
+  rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
+  rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
+  rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
+
+  rules_[SpvOpFSub].push_back(RedundantFSub());
+  rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
+  rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
+  rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
+
+  rules_[SpvOpIAdd].push_back(RedundantIAdd());
+  rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
+  rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
+  rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
+
+  rules_[SpvOpIMul].push_back(IntMultipleBy1());
+  rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
+  rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
+
+  rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
+  rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
+  rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
+
+  rules_[SpvOpPhi].push_back(RedundantPhi());
+
+  rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
+
+  rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
+  rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
+  rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
+
+  rules_[SpvOpSelect].push_back(RedundantSelect());
+
+  rules_[SpvOpStore].push_back(StoringUndef());
+
+  rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
+
+  rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
+}
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/folding_rules.h b/source/opt/folding_rules.h
new file mode 100644
index 000000000..33fdbffe9
--- /dev/null
+++ b/source/opt/folding_rules.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_FOLDING_RULES_H_
+#define SOURCE_OPT_FOLDING_RULES_H_
+
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "source/opt/constants.h"
+
+namespace spvtools {
+namespace opt {
+
+// Folding Rules:
+//
+// The folding mechanism is built around the concept of a |FoldingRule|.  A
+// folding rule is a function that implements a method of simplifying an
+// instruction.
+//
+// The inputs to a folding rule are:
+//     |context| - the IRContext that owns |inst|.
+//     |inst| - the instruction to be simplified.
+//     |constants| - if an in-operand is an id of a constant, then the
+//                   corresponding value in |constants| contains that
+//                   constant value.  Otherwise, the corresponding entry in
+//                   |constants| is |nullptr|.
+//
+// A folding rule returns true if |inst| can be simplified using this rule.  If
+// the instruction can be simplified, then |inst| is changed to the simplified
+// instruction.  Otherwise, |inst| remains the same.
+//
+// See folding_rules.cpp for examples on how to write a folding rule.  It is
+// important to note that if |inst| can be folded to the result of an
+// instruction that feeds it, then |inst| should be changed to an OpCopyObject
+// that copies that id.
+//
+// Be sure to add new folding rules to the table of folding rules in the
+// constructor for FoldingRules.  The new rule should be added to the list for
+// every opcode that it applies to.  Note that earlier rules in the list are
+// given priority.  That is, if an earlier rule is able to fold an instruction,
+// the later rules will not be attempted.
+
+using FoldingRule = std::function<bool(
+    IRContext* context, Instruction* inst,
+    const std::vector<const analysis::Constant*>& constants)>;
+
+// Table mapping an opcode to the ordered list of folding rules that apply to
+// it.  The table is filled in by the constructor.
+class FoldingRules {
+ public:
+  FoldingRules();
+
+  // Returns the rules registered for |opcode|, in priority order, or an empty
+  // list if no rule is registered for it.
+  const std::vector<FoldingRule>& GetRulesForOpcode(SpvOp opcode) const {
+    auto it = rules_.find(opcode);
+    if (it != rules_.end()) {
+      return it->second;
+    }
+    return empty_vector_;
+  }
+
+ private:
+  std::unordered_map<uint32_t, std::vector<FoldingRule>> rules_;
+  // Returned by reference when an opcode has no rules.
+  std::vector<FoldingRule> empty_vector_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_FOLDING_RULES_H_
diff --git a/source/opt/freeze_spec_constant_value_pass.cpp b/source/opt/freeze_spec_constant_value_pass.cpp
new file mode 100644
index 000000000..10e98fd8b
--- /dev/null
+++ b/source/opt/freeze_spec_constant_value_pass.cpp
@@ -0,0 +1,53 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/freeze_spec_constant_value_pass.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace opt {
+
+// Rewrites every OpSpecConstant / OpSpecConstantTrue / OpSpecConstantFalse to
+// the matching non-spec constant opcode and kills every OpDecorate whose
+// decoration is SpecId.  Reports whether anything was changed.
+Pass::Status FreezeSpecConstantValuePass::Process() {
+  auto ir = context();
+  bool changed = false;
+  ir->module()->ForEachInst([&changed, ir](Instruction* inst) {
+    const SpvOp op = inst->opcode();
+    if (op == SpvOp::SpvOpSpecConstant) {
+      inst->SetOpcode(SpvOp::SpvOpConstant);
+    } else if (op == SpvOp::SpvOpSpecConstantTrue) {
+      inst->SetOpcode(SpvOp::SpvOpConstantTrue);
+    } else if (op == SpvOp::SpvOpSpecConstantFalse) {
+      inst->SetOpcode(SpvOp::SpvOpConstantFalse);
+    } else if (op == SpvOp::SpvOpDecorate &&
+               inst->GetSingleWordInOperand(1) ==
+                   SpvDecoration::SpvDecorationSpecId) {
+      ir->KillInst(inst);
+    } else {
+      // Nothing to freeze on this instruction.
+      return;
+    }
+    changed = true;
+  });
+  if (changed) {
+    return Status::SuccessWithChange;
+  }
+  return Status::SuccessWithoutChange;
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/freeze_spec_constant_value_pass.h b/source/opt/freeze_spec_constant_value_pass.h
new file mode 100644
index 000000000..0663adf40
--- /dev/null
+++ b/source/opt/freeze_spec_constant_value_pass.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_
+#define SOURCE_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Pass that freezes specialization constants.  See optimizer.hpp for
+// documentation.
+class FreezeSpecConstantValuePass : public Pass {
+ public:
+  // Runs the pass over the current module.
+  Status Process() override;
+
+  // Name under which the pass is identified.
+  const char* name() const override { return "freeze-spec-const"; }
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_FREEZE_SPEC_CONSTANT_VALUE_PASS_H_
diff --git a/source/opt/function.cpp b/source/opt/function.cpp
new file mode 100644
index 000000000..9bd46e2a0
--- /dev/null
+++ b/source/opt/function.cpp
@@ -0,0 +1,183 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/function.h"
+#include "function.h"
+#include "ir_context.h"
+
+#include <iostream>
+#include <queue>
+#include <sstream>
+
+namespace spvtools {
+namespace opt {
+
+// Returns a newly allocated deep copy of this function: the defining
+// instruction, the parameters, every basic block (re-parented to the clone) and
+// the end instruction are all cloned into |ctx|.  The caller owns the result.
+Function* Function::Clone(IRContext* ctx) const {
+  Function* clone =
+      new Function(std::unique_ptr<Instruction>(DefInst().Clone(ctx)));
+  clone->params_.reserve(params_.size());
+  ForEachParam(
+      [clone, ctx](const Instruction* inst) {
+        clone->AddParameter(std::unique_ptr<Instruction>(inst->Clone(ctx)));
+      },
+      true);
+
+  clone->blocks_.reserve(blocks_.size());
+  for (const auto& b : blocks_) {
+    std::unique_ptr<BasicBlock> bb(b->Clone(ctx));
+    bb->SetParent(clone);
+    clone->AddBasicBlock(std::move(bb));
+  }
+
+  clone->SetFunctionEnd(std::unique_ptr<Instruction>(EndInst()->Clone(ctx)));
+  return clone;
+}
+
+// Runs |f| on every instruction of the function; implemented on top of
+// WhileEachInst with a callback that never stops the traversal.
+void Function::ForEachInst(const std::function<void(Instruction*)>& f,
+                           bool run_on_debug_line_insts) {
+  WhileEachInst(
+      [&f](Instruction* inst) {
+        f(inst);
+        return true;
+      },
+      run_on_debug_line_insts);
+}
+
+void Function::ForEachInst(const std::function<void(const Instruction*)>& f,
+                           bool run_on_debug_line_insts) const {
+  WhileEachInst(
+      [&f](const Instruction* inst) {
+        f(inst);
+        return true;
+      },
+      run_on_debug_line_insts);
+}
+
+// Runs |f| on the defining instruction, the parameters, the basic blocks and
+// the end instruction, in that order, stopping as soon as |f| returns false.
+// Returns false if the traversal was stopped early, true otherwise.
+bool Function::WhileEachInst(const std::function<bool(Instruction*)>& f,
+                             bool run_on_debug_line_insts) {
+  if (def_inst_) {
+    if (!def_inst_->WhileEachInst(f, run_on_debug_line_insts)) {
+      return false;
+    }
+  }
+
+  for (auto& param : params_) {
+    if (!param->WhileEachInst(f, run_on_debug_line_insts)) {
+      return false;
+    }
+  }
+
+  for (auto& bb : blocks_) {
+    if (!bb->WhileEachInst(f, run_on_debug_line_insts)) {
+      return false;
+    }
+  }
+
+  if (end_inst_) return end_inst_->WhileEachInst(f, run_on_debug_line_insts);
+
+  return true;
+}
+
+// Const overload of the traversal above.  The casts to pointer-to-const select
+// the const overloads of the callee's WhileEachInst.
+bool Function::WhileEachInst(const std::function<bool(const Instruction*)>& f,
+                             bool run_on_debug_line_insts) const {
+  if (def_inst_) {
+    if (!static_cast<const Instruction*>(def_inst_.get())
+             ->WhileEachInst(f, run_on_debug_line_insts)) {
+      return false;
+    }
+  }
+
+  for (const auto& param : params_) {
+    if (!static_cast<const Instruction*>(param.get())
+             ->WhileEachInst(f, run_on_debug_line_insts)) {
+      return false;
+    }
+  }
+
+  for (const auto& bb : blocks_) {
+    if (!static_cast<const BasicBlock*>(bb.get())->WhileEachInst(
+            f, run_on_debug_line_insts)) {
+      return false;
+    }
+  }
+
+  if (end_inst_)
+    return static_cast<const Instruction*>(end_inst_.get())
+        ->WhileEachInst(f, run_on_debug_line_insts);
+
+  return true;
+}
+
+// Runs |f| on every parameter instruction of the function.
+void Function::ForEachParam(const std::function<void(Instruction*)>& f,
+                            bool run_on_debug_line_insts) {
+  for (auto& param : params_)
+    static_cast<Instruction*>(param.get())
+        ->ForEachInst(f, run_on_debug_line_insts);
+}
+
+void Function::ForEachParam(const std::function<void(const Instruction*)>& f,
+                            bool run_on_debug_line_insts) const {
+  for (const auto& param : params_)
+    static_cast<const Instruction*>(param.get())
+        ->ForEachInst(f, run_on_debug_line_insts);
+}
+
+// Inserts |new_block| immediately after |position| and returns a pointer to the
+// inserted block.  Asserts (and returns nullptr) if |position| is not a block
+// of this function.
+BasicBlock* Function::InsertBasicBlockAfter(
+    std::unique_ptr<BasicBlock>&& new_block, BasicBlock* position) {
+  for (auto bb_iter = begin(); bb_iter != end(); ++bb_iter) {
+    if (&*bb_iter == position) {
+      new_block->SetParent(this);
+      ++bb_iter;
+      bb_iter = bb_iter.InsertBefore(std::move(new_block));
+      return &*bb_iter;
+    }
+  }
+  assert(false && "Could not find insertion point.");
+  return nullptr;
+}
+
+// Returns true if this function can reach itself through its call tree.
+bool Function::IsRecursive() const {
+  // A function without a body makes no calls, so it cannot be recursive.
+  // This also keeps |blocks_.front()| below from being evaluated on an empty
+  // container.
+  if (blocks_.empty()) {
+    return false;
+  }
+
+  IRContext* ctx = blocks_.front()->GetLabel()->context();
+  IRContext::ProcessFunction mark_visited = [this](Function* fp) {
+    return fp == this;
+  };
+
+  // Process the call tree from all of the functions called by |this|.  If it
+  // gets back to |this|, then we have a recursive function.
+  std::queue<uint32_t> roots;
+  ctx->AddCalls(this, &roots);
+  return ctx->ProcessCallTreeFromRoots(mark_visited, &roots);
+}
+
+std::ostream& operator<<(std::ostream& str, const Function& func) {
+  str << func.PrettyPrint();
+  return str;
+}
+
+// Writes the function's id and its disassembly to stderr.
+void Function::Dump() const {
+  std::cerr << "Function #" << result_id() << "\n" << *this << "\n";
+}
+
+// Returns the pretty-printed text of every instruction in the function, one per
+// line, with no trailing newline after the OpFunctionEnd.
+std::string Function::PrettyPrint(uint32_t options) const {
+  std::ostringstream str;
+  ForEachInst([&str, options](const Instruction* inst) {
+    str << inst->PrettyPrint(options);
+    if (inst->opcode() != SpvOpFunctionEnd) {
+      str << std::endl;
+    }
+  });
+  return str.str();
+}
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/function.h b/source/opt/function.h
new file mode 100644
index 000000000..c80b078cd
--- /dev/null
+++ b/source/opt/function.h
@@ -0,0 +1,203 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_FUNCTION_H_
+#define SOURCE_OPT_FUNCTION_H_
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "source/opt/basic_block.h"
+#include "source/opt/instruction.h"
+#include "source/opt/iterator.h"
+
+namespace spvtools {
+namespace opt {
+
+class CFG;
+class IRContext;
+class Module;
+
+// A SPIR-V function.
+class Function {
+ public:
+  using iterator = UptrVectorIterator<BasicBlock>;
+  using const_iterator = UptrVectorIterator<BasicBlock, true>;
+
+  // Creates a function instance declared by the given OpFunction instruction
+  // |def_inst|.
+ inline explicit Function(std::unique_ptr def_inst); + + explicit Function(const Function& f) = delete; + + // Creates a clone of the instruction in the given |context| + // + // The parent module will default to null and needs to be explicitly set by + // the user. + Function* Clone(IRContext*) const; + // The OpFunction instruction that begins the definition of this function. + Instruction& DefInst() { return *def_inst_; } + const Instruction& DefInst() const { return *def_inst_; } + + // Appends a parameter to this function. + inline void AddParameter(std::unique_ptr p); + // Appends a basic block to this function. + inline void AddBasicBlock(std::unique_ptr b); + // Appends a basic block to this function at the position |ip|. + inline void AddBasicBlock(std::unique_ptr b, iterator ip); + template + inline void AddBasicBlocks(T begin, T end, iterator ip); + + // Move basic block with |id| to the position after |ip|. Both have to be + // contained in this function. + inline void MoveBasicBlockToAfter(uint32_t id, BasicBlock* ip); + + // Delete all basic blocks that contain no instructions. + inline void RemoveEmptyBlocks(); + + // Saves the given function end instruction. + inline void SetFunctionEnd(std::unique_ptr end_inst); + + // Returns the given function end instruction. + inline Instruction* EndInst() { return end_inst_.get(); } + inline const Instruction* EndInst() const { return end_inst_.get(); } + + // Returns function's id + inline uint32_t result_id() const { return def_inst_->result_id(); } + + // Returns function's return type id + inline uint32_t type_id() const { return def_inst_->type_id(); } + + // Returns the entry basic block for this function. 
+ const std::unique_ptr& entry() const { return blocks_.front(); } + + iterator begin() { return iterator(&blocks_, blocks_.begin()); } + iterator end() { return iterator(&blocks_, blocks_.end()); } + const_iterator begin() const { return cbegin(); } + const_iterator end() const { return cend(); } + const_iterator cbegin() const { + return const_iterator(&blocks_, blocks_.cbegin()); + } + const_iterator cend() const { + return const_iterator(&blocks_, blocks_.cend()); + } + + // Returns an iterator to the basic block |id|. + iterator FindBlock(uint32_t bb_id) { + return std::find_if(begin(), end(), [bb_id](const BasicBlock& it_bb) { + return bb_id == it_bb.id(); + }); + } + + // Runs the given function |f| on each instruction in this function, and + // optionally on debug line instructions that might precede them. + void ForEachInst(const std::function& f, + bool run_on_debug_line_insts = false); + void ForEachInst(const std::function& f, + bool run_on_debug_line_insts = false) const; + bool WhileEachInst(const std::function& f, + bool run_on_debug_line_insts = false); + bool WhileEachInst(const std::function& f, + bool run_on_debug_line_insts = false) const; + + // Runs the given function |f| on each parameter instruction in this function, + // and optionally on debug line instructions that might precede them. + void ForEachParam(const std::function& f, + bool run_on_debug_line_insts = false) const; + void ForEachParam(const std::function& f, + bool run_on_debug_line_insts = false); + + BasicBlock* InsertBasicBlockAfter(std::unique_ptr&& new_block, + BasicBlock* position); + + // Return true if the function calls itself either directly or indirectly. + bool IsRecursive() const; + + // Pretty-prints all the basic blocks in this function into a std::string. + // + // |options| are the disassembly options. SPV_BINARY_TO_TEXT_OPTION_NO_HEADER + // is always added to |options|. + std::string PrettyPrint(uint32_t options = 0u) const; + + // Dump this function on stderr. 
Useful when running interactive + // debuggers. + void Dump() const; + + private: + // The OpFunction instruction that begins the definition of this function. + std::unique_ptr def_inst_; + // All parameters to this function. + std::vector> params_; + // All basic blocks inside this function in specification order + std::vector> blocks_; + // The OpFunctionEnd instruction. + std::unique_ptr end_inst_; +}; + +// Pretty-prints |func| to |str|. Returns |str|. +std::ostream& operator<<(std::ostream& str, const Function& func); + +inline Function::Function(std::unique_ptr def_inst) + : def_inst_(std::move(def_inst)), end_inst_() {} + +inline void Function::AddParameter(std::unique_ptr p) { + params_.emplace_back(std::move(p)); +} + +inline void Function::AddBasicBlock(std::unique_ptr b) { + AddBasicBlock(std::move(b), end()); +} + +inline void Function::AddBasicBlock(std::unique_ptr b, + iterator ip) { + ip.InsertBefore(std::move(b)); +} + +template +inline void Function::AddBasicBlocks(T src_begin, T src_end, iterator ip) { + blocks_.insert(ip.Get(), std::make_move_iterator(src_begin), + std::make_move_iterator(src_end)); +} + +inline void Function::MoveBasicBlockToAfter(uint32_t id, BasicBlock* ip) { + auto block_to_move = std::move(*FindBlock(id).Get()); + + assert(block_to_move->GetParent() == ip->GetParent() && + "Both blocks have to be in the same function."); + + InsertBasicBlockAfter(std::move(block_to_move), ip); + blocks_.erase(std::find(std::begin(blocks_), std::end(blocks_), nullptr)); +} + +inline void Function::RemoveEmptyBlocks() { + auto first_empty = + std::remove_if(std::begin(blocks_), std::end(blocks_), + [](const std::unique_ptr& bb) -> bool { + return bb->GetLabelInst()->opcode() == SpvOpNop; + }); + blocks_.erase(first_empty, std::end(blocks_)); +} + +inline void Function::SetFunctionEnd(std::unique_ptr end_inst) { + end_inst_ = std::move(end_inst); +} + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_FUNCTION_H_ diff --git 
a/source/opt/if_conversion.cpp b/source/opt/if_conversion.cpp new file mode 100644 index 000000000..104182bc3 --- /dev/null +++ b/source/opt/if_conversion.cpp @@ -0,0 +1,285 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/if_conversion.h" + +#include +#include + +#include "source/opt/value_number_table.h" + +namespace spvtools { +namespace opt { + +Pass::Status IfConversion::Process() { + if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) { + return Status::SuccessWithoutChange; + } + + const ValueNumberTable& vn_table = *context()->GetValueNumberTable(); + bool modified = false; + std::vector to_kill; + for (auto& func : *get_module()) { + DominatorAnalysis* dominators = context()->GetDominatorAnalysis(&func); + for (auto& block : func) { + // Check if it is possible for |block| to have phis that can be + // transformed. + BasicBlock* common = nullptr; + if (!CheckBlock(&block, dominators, &common)) continue; + + // Get an insertion point. + auto iter = block.begin(); + while (iter != block.end() && iter->opcode() == SpvOpPhi) { + ++iter; + } + + InstructionBuilder builder( + context(), &*iter, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + block.ForEachPhiInst([this, &builder, &modified, &common, &to_kill, + dominators, &block, &vn_table](Instruction* phi) { + // This phi is not compatible, but subsequent phis might be. 
+ if (!CheckType(phi->type_id())) return; + + // We cannot transform cases where the phi is used by another phi in the + // same block due to instruction ordering restrictions. + // TODO(alan-baker): If all inappropriate uses could also be + // transformed, we could still remove this phi. + if (!CheckPhiUsers(phi, &block)) return; + + // Identify the incoming values associated with the true and false + // branches. If |then_block| dominates |inc0| or if the true edge + // branches straight to this block and |common| is |inc0|, then |inc0| + // is on the true branch. Otherwise the |inc1| is on the true branch. + BasicBlock* inc0 = GetIncomingBlock(phi, 0u); + Instruction* branch = common->terminator(); + uint32_t condition = branch->GetSingleWordInOperand(0u); + BasicBlock* then_block = GetBlock(branch->GetSingleWordInOperand(1u)); + Instruction* true_value = nullptr; + Instruction* false_value = nullptr; + if ((then_block == &block && inc0 == common) || + dominators->Dominates(then_block, inc0)) { + true_value = GetIncomingValue(phi, 0u); + false_value = GetIncomingValue(phi, 1u); + } else { + true_value = GetIncomingValue(phi, 1u); + false_value = GetIncomingValue(phi, 0u); + } + + BasicBlock* true_def_block = context()->get_instr_block(true_value); + BasicBlock* false_def_block = context()->get_instr_block(false_value); + + uint32_t true_vn = vn_table.GetValueNumber(true_value); + uint32_t false_vn = vn_table.GetValueNumber(false_value); + if (true_vn != 0 && true_vn == false_vn) { + Instruction* inst_to_use = nullptr; + + // Try to pick an instruction that is not in a side node. If we can't + // pick either the true for false branch as long as they can be + // legally moved. 
+ if (!true_def_block || + dominators->Dominates(true_def_block, &block)) { + inst_to_use = true_value; + } else if (!false_def_block || + dominators->Dominates(false_def_block, &block)) { + inst_to_use = false_value; + } else if (CanHoistInstruction(true_value, common, dominators)) { + inst_to_use = true_value; + } else if (CanHoistInstruction(false_value, common, dominators)) { + inst_to_use = false_value; + } + + if (inst_to_use != nullptr) { + modified = true; + HoistInstruction(inst_to_use, common, dominators); + context()->KillNamesAndDecorates(phi); + context()->ReplaceAllUsesWith(phi->result_id(), + inst_to_use->result_id()); + } + return; + } + + // If either incoming value is defined in a block that does not dominate + // this phi, then we cannot eliminate the phi with a select. + // TODO(alan-baker): Perform code motion where it makes sense to enable + // the transform in this case. + if (true_def_block && !dominators->Dominates(true_def_block, &block)) + return; + + if (false_def_block && !dominators->Dominates(false_def_block, &block)) + return; + + analysis::Type* data_ty = + context()->get_type_mgr()->GetType(true_value->type_id()); + if (analysis::Vector* vec_data_ty = data_ty->AsVector()) { + condition = SplatCondition(vec_data_ty, condition, &builder); + } + + Instruction* select = builder.AddSelect(phi->type_id(), condition, + true_value->result_id(), + false_value->result_id()); + context()->ReplaceAllUsesWith(phi->result_id(), select->result_id()); + to_kill.push_back(phi); + modified = true; + + return; + }); + } + } + + for (auto inst : to_kill) { + context()->KillInst(inst); + } + + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool IfConversion::CheckBlock(BasicBlock* block, DominatorAnalysis* dominators, + BasicBlock** common) { + const std::vector& preds = cfg()->preds(block->id()); + + // TODO(alan-baker): Extend to more than two predecessors + if (preds.size() != 2) return false; + + BasicBlock* inc0 = context()->get_instr_block(preds[0]); + if (dominators->Dominates(block, inc0)) return false; + + BasicBlock* inc1 = context()->get_instr_block(preds[1]); + if (dominators->Dominates(block, inc1)) return false; + + // All phis will have the same common dominator, so cache the result + // for this block. If there is no common dominator, then we cannot transform + // any phi in this basic block. + *common = dominators->CommonDominator(inc0, inc1); + if (!*common || cfg()->IsPseudoEntryBlock(*common)) return false; + Instruction* branch = (*common)->terminator(); + if (branch->opcode() != SpvOpBranchConditional) return false; + auto merge = (*common)->GetMergeInst(); + if (!merge || merge->opcode() != SpvOpSelectionMerge) return false; + if ((*common)->MergeBlockIdIfAny() != block->id()) return false; + + return true; +} + +bool IfConversion::CheckPhiUsers(Instruction* phi, BasicBlock* block) { + return get_def_use_mgr()->WhileEachUser(phi, [block, + this](Instruction* user) { + if (user->opcode() == SpvOpPhi && context()->get_instr_block(user) == block) + return false; + return true; + }); +} + +uint32_t IfConversion::SplatCondition(analysis::Vector* vec_data_ty, + uint32_t cond, + InstructionBuilder* builder) { + // If the data inputs to OpSelect are vectors, the condition for + // OpSelect must be a boolean vector with the same number of + // components. So splat the condition for the branch into a vector + // type. 
+ analysis::Bool bool_ty; + analysis::Vector bool_vec_ty(&bool_ty, vec_data_ty->element_count()); + uint32_t bool_vec_id = + context()->get_type_mgr()->GetTypeInstruction(&bool_vec_ty); + std::vector ids(vec_data_ty->element_count(), cond); + return builder->AddCompositeConstruct(bool_vec_id, ids)->result_id(); +} + +bool IfConversion::CheckType(uint32_t id) { + Instruction* type = get_def_use_mgr()->GetDef(id); + SpvOp op = type->opcode(); + if (spvOpcodeIsScalarType(op) || op == SpvOpTypePointer || + op == SpvOpTypeVector) + return true; + return false; +} + +BasicBlock* IfConversion::GetBlock(uint32_t id) { + return context()->get_instr_block(get_def_use_mgr()->GetDef(id)); +} + +BasicBlock* IfConversion::GetIncomingBlock(Instruction* phi, + uint32_t predecessor) { + uint32_t in_index = 2 * predecessor + 1; + return GetBlock(phi->GetSingleWordInOperand(in_index)); +} + +Instruction* IfConversion::GetIncomingValue(Instruction* phi, + uint32_t predecessor) { + uint32_t in_index = 2 * predecessor; + return get_def_use_mgr()->GetDef(phi->GetSingleWordInOperand(in_index)); +} + +void IfConversion::HoistInstruction(Instruction* inst, BasicBlock* target_block, + DominatorAnalysis* dominators) { + BasicBlock* inst_block = context()->get_instr_block(inst); + if (!inst_block) { + // This is in the header, and dominates everything. + return; + } + + if (dominators->Dominates(inst_block, target_block)) { + // Already in position. No work to do. + return; + } + + assert(inst->IsOpcodeCodeMotionSafe() && + "Trying to move an instruction that is not safe to move."); + + // First hoist all instructions it depends on. 
+ analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + inst->ForEachInId( + [this, target_block, def_use_mgr, dominators](uint32_t* id) { + Instruction* operand_inst = def_use_mgr->GetDef(*id); + HoistInstruction(operand_inst, target_block, dominators); + }); + + Instruction* insertion_pos = target_block->terminator(); + if ((insertion_pos)->PreviousNode()->opcode() == SpvOpSelectionMerge) { + insertion_pos = insertion_pos->PreviousNode(); + } + inst->RemoveFromList(); + insertion_pos->InsertBefore(std::unique_ptr(inst)); + context()->set_instr_block(inst, target_block); +} + +bool IfConversion::CanHoistInstruction(Instruction* inst, + BasicBlock* target_block, + DominatorAnalysis* dominators) { + BasicBlock* inst_block = context()->get_instr_block(inst); + if (!inst_block) { + // This is in the header, and dominates everything. + return true; + } + + if (dominators->Dominates(inst_block, target_block)) { + // Already in position. No work to do. + return true; + } + + if (!inst->IsOpcodeCodeMotionSafe()) { + return false; + } + + // Check all instruction |inst| depends on. + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + return inst->WhileEachInId( + [this, target_block, def_use_mgr, dominators](uint32_t* id) { + Instruction* operand_inst = def_use_mgr->GetDef(*id); + return CanHoistInstruction(operand_inst, target_block, dominators); + }); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/if_conversion.h b/source/opt/if_conversion.h new file mode 100644 index 000000000..db84e703b --- /dev/null +++ b/source/opt/if_conversion.h @@ -0,0 +1,89 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_IF_CONVERSION_H_ +#define SOURCE_OPT_IF_CONVERSION_H_ + +#include "source/opt/basic_block.h" +#include "source/opt/ir_builder.h" +#include "source/opt/pass.h" +#include "source/opt/types.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class IfConversion : public Pass { + public: + const char* name() const override { return "if-conversion"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisCFG | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + // Returns true if |id| is a valid type for use with OpSelect. OpSelect only + // allows scalars, vectors and pointers as valid inputs. + bool CheckType(uint32_t id); + + // Returns the basic block containing |id|. + BasicBlock* GetBlock(uint32_t id); + + // Returns the basic block for the |predecessor|'th index predecessor of + // |phi|. + BasicBlock* GetIncomingBlock(Instruction* phi, uint32_t predecessor); + + // Returns the instruction defining the |predecessor|'th index of |phi|. + Instruction* GetIncomingValue(Instruction* phi, uint32_t predecessor); + + // Returns the id of a OpCompositeConstruct boolean vector. The composite has + // the same number of elements as |vec_data_ty| and each member is |cond|. + // |where| indicates the location in |block| to insert the composite + // construct. 
If necessary, this function will also construct the necessary + // type instructions for the boolean vector. + uint32_t SplatCondition(analysis::Vector* vec_data_ty, uint32_t cond, + InstructionBuilder* builder); + + // Returns true if none of |phi|'s users are in |block|. + bool CheckPhiUsers(Instruction* phi, BasicBlock* block); + + // Returns |false| if |block| is not appropriate to transform. Only + // transforms blocks with two predecessors. Neither incoming block can be + // dominated by |block|. Both predecessors must share a common dominator that + // is terminated by a conditional branch. + bool CheckBlock(BasicBlock* block, DominatorAnalysis* dominators, + BasicBlock** common); + + // Moves |inst| to |target_block| if it does not already dominate the block. + // Any instructions that |inst| depends on are move if necessary. It is + // assumed that |inst| can be hoisted to |target_block| as defined by + // |CanHoistInstruction|. |dominators| is the dominator analysis for the + // function that contains |target_block|. + void HoistInstruction(Instruction* inst, BasicBlock* target_block, + DominatorAnalysis* dominators); + + // Returns true if it is legal to move |inst| and the instructions it depends + // on to |target_block| if they do not already dominate |target_block|. + bool CanHoistInstruction(Instruction* inst, BasicBlock* target_block, + DominatorAnalysis* dominators); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_IF_CONVERSION_H_ diff --git a/source/opt/inline_exhaustive_pass.cpp b/source/opt/inline_exhaustive_pass.cpp new file mode 100644 index 000000000..24f4e7364 --- /dev/null +++ b/source/opt/inline_exhaustive_pass.cpp @@ -0,0 +1,85 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/inline_exhaustive_pass.h" + +#include + +namespace spvtools { +namespace opt { + +Pass::Status InlineExhaustivePass::InlineExhaustive(Function* func) { + bool modified = false; + // Using block iterators here because of block erasures and insertions. + for (auto bi = func->begin(); bi != func->end(); ++bi) { + for (auto ii = bi->begin(); ii != bi->end();) { + if (IsInlinableFunctionCall(&*ii)) { + // Inline call. + std::vector> newBlocks; + std::vector> newVars; + if (!GenInlineCode(&newBlocks, &newVars, ii, bi)) { + return Status::Failure; + } + // If call block is replaced with more than one block, point + // succeeding phis at new last block. + if (newBlocks.size() > 1) UpdateSucceedingPhis(newBlocks); + // Replace old calling block with new block(s). + + // We need to kill the name and decorations for the call, which + // will be deleted. Other instructions in the block will be moved to + // newBlocks. We don't need to do anything with those. + context()->KillNamesAndDecorates(&*ii); + + bi = bi.Erase(); + + for (auto& bb : newBlocks) { + bb->SetParent(func); + } + bi = bi.InsertBefore(&newBlocks); + // Insert new function variables. + if (newVars.size() > 0) + func->begin()->begin().InsertBefore(std::move(newVars)); + // Restart inlining at beginning of calling block. + ii = bi->begin(); + modified = true; + } else { + ++ii; + } + } + } + return (modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +Pass::Status InlineExhaustivePass::ProcessImpl() { + Status status = Status::SuccessWithoutChange; + // Attempt exhaustive inlining on each entry point function in module + ProcessFunction pfn = [&status, this](Function* fp) { + status = CombineStatus(status, InlineExhaustive(fp)); + return false; + }; + context()->ProcessEntryPointCallTree(pfn); + return status; +} + +InlineExhaustivePass::InlineExhaustivePass() = default; + +Pass::Status InlineExhaustivePass::Process() { + InitializeInline(); + return ProcessImpl(); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/inline_exhaustive_pass.h b/source/opt/inline_exhaustive_pass.h new file mode 100644 index 000000000..c2e854731 --- /dev/null +++ b/source/opt/inline_exhaustive_pass.h @@ -0,0 +1,53 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_INLINE_EXHAUSTIVE_PASS_H_ +#define SOURCE_OPT_INLINE_EXHAUSTIVE_PASS_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/def_use_manager.h" +#include "source/opt/inline_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. 
+class InlineExhaustivePass : public InlinePass { + public: + InlineExhaustivePass(); + Status Process() override; + + const char* name() const override { return "inline-entry-points-exhaustive"; } + + private: + // Exhaustively inline all function calls in func as well as in + // all code that is inlined into func. Returns the status. + Status InlineExhaustive(Function* func); + + void Initialize(); + Pass::Status ProcessImpl(); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_INLINE_EXHAUSTIVE_PASS_H_ diff --git a/source/opt/inline_opaque_pass.cpp b/source/opt/inline_opaque_pass.cpp new file mode 100644 index 000000000..6ccaf9087 --- /dev/null +++ b/source/opt/inline_opaque_pass.cpp @@ -0,0 +1,120 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/inline_opaque_pass.h" + +#include + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kTypePointerTypeIdInIdx = 1; + +} // anonymous namespace + +bool InlineOpaquePass::IsOpaqueType(uint32_t typeId) { + const Instruction* typeInst = get_def_use_mgr()->GetDef(typeId); + switch (typeInst->opcode()) { + case SpvOpTypeSampler: + case SpvOpTypeImage: + case SpvOpTypeSampledImage: + return true; + case SpvOpTypePointer: + return IsOpaqueType( + typeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx)); + default: + break; + } + // TODO(greg-lunarg): Handle arrays containing opaque type + if (typeInst->opcode() != SpvOpTypeStruct) return false; + // Return true if any member is opaque + return !typeInst->WhileEachInId([this](const uint32_t* tid) { + if (IsOpaqueType(*tid)) return false; + return true; + }); +} + +bool InlineOpaquePass::HasOpaqueArgsOrReturn(const Instruction* callInst) { + // Check return type + if (IsOpaqueType(callInst->type_id())) return true; + // Check args + int icnt = 0; + return !callInst->WhileEachInId([&icnt, this](const uint32_t* iid) { + if (icnt > 0) { + const Instruction* argInst = get_def_use_mgr()->GetDef(*iid); + if (IsOpaqueType(argInst->type_id())) return false; + } + ++icnt; + return true; + }); +} + +Pass::Status InlineOpaquePass::InlineOpaque(Function* func) { + bool modified = false; + // Using block iterators here because of block erasures and insertions. + for (auto bi = func->begin(); bi != func->end(); ++bi) { + for (auto ii = bi->begin(); ii != bi->end();) { + if (IsInlinableFunctionCall(&*ii) && HasOpaqueArgsOrReturn(&*ii)) { + // Inline call. + std::vector> newBlocks; + std::vector> newVars; + if (!GenInlineCode(&newBlocks, &newVars, ii, bi)) { + return Status::Failure; + } + + // If call block is replaced with more than one block, point + // succeeding phis at new last block. 
+ if (newBlocks.size() > 1) UpdateSucceedingPhis(newBlocks); + // Replace old calling block with new block(s). + bi = bi.Erase(); + bi = bi.InsertBefore(&newBlocks); + // Insert new function variables. + if (newVars.size() > 0) + func->begin()->begin().InsertBefore(std::move(newVars)); + // Restart inlining at beginning of calling block. + ii = bi->begin(); + modified = true; + } else { + ++ii; + } + } + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +void InlineOpaquePass::Initialize() { InitializeInline(); } + +Pass::Status InlineOpaquePass::ProcessImpl() { + Status status = Status::SuccessWithoutChange; + // Do opaque inlining on each function in entry point call tree + ProcessFunction pfn = [&status, this](Function* fp) { + status = CombineStatus(status, InlineOpaque(fp)); + return false; + }; + context()->ProcessEntryPointCallTree(pfn); + return status; +} + +InlineOpaquePass::InlineOpaquePass() = default; + +Pass::Status InlineOpaquePass::Process() { + Initialize(); + return ProcessImpl(); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/inline_opaque_pass.h b/source/opt/inline_opaque_pass.h new file mode 100644 index 000000000..1e3081d22 --- /dev/null +++ b/source/opt/inline_opaque_pass.h @@ -0,0 +1,60 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_INLINE_OPAQUE_PASS_H_ +#define SOURCE_OPT_INLINE_OPAQUE_PASS_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/def_use_manager.h" +#include "source/opt/inline_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class InlineOpaquePass : public InlinePass { + public: + InlineOpaquePass(); + Status Process() override; + + const char* name() const override { return "inline-entry-points-opaque"; } + + private: + // Return true if |typeId| is or contains opaque type + bool IsOpaqueType(uint32_t typeId); + + // Return true if function call |callInst| has opaque argument or return type + bool HasOpaqueArgsOrReturn(const Instruction* callInst); + + // Inline all function calls in |func| that have opaque params or return + // type. Inline similarly all code that is inlined into func. Return true + // if func is modified. + Status InlineOpaque(Function* func); + + void Initialize(); + Pass::Status ProcessImpl(); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_INLINE_OPAQUE_PASS_H_ diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp new file mode 100644 index 000000000..f348bbe3e --- /dev/null +++ b/source/opt/inline_pass.cpp @@ -0,0 +1,750 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/inline_pass.h" + +#include +#include + +#include "source/cfa.h" +#include "source/util/make_unique.h" + +// Indices of operands in SPIR-V instructions + +static const int kSpvFunctionCallFunctionId = 2; +static const int kSpvFunctionCallArgumentId = 3; +static const int kSpvReturnValueId = 0; +static const int kSpvLoopMergeContinueTargetIdInIdx = 1; + +namespace spvtools { +namespace opt { + +uint32_t InlinePass::AddPointerToType(uint32_t type_id, + SpvStorageClass storage_class) { + uint32_t resultId = context()->TakeNextId(); + if (resultId == 0) { + return resultId; + } + + std::unique_ptr type_inst( + new Instruction(context(), SpvOpTypePointer, 0, resultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, + {uint32_t(storage_class)}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); + context()->AddType(std::move(type_inst)); + analysis::Type* pointeeTy; + std::unique_ptr pointerTy; + std::tie(pointeeTy, pointerTy) = + context()->get_type_mgr()->GetTypeAndPointerType(type_id, + SpvStorageClassFunction); + context()->get_type_mgr()->RegisterType(resultId, *pointerTy); + return resultId; +} + +void InlinePass::AddBranch(uint32_t label_id, + std::unique_ptr* block_ptr) { + std::unique_ptr newBranch( + new Instruction(context(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}})); + (*block_ptr)->AddInstruction(std::move(newBranch)); +} + +void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id, + uint32_t false_id, + std::unique_ptr* block_ptr) { + std::unique_ptr newBranch( + new Instruction(context(), SpvOpBranchConditional, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}})); + (*block_ptr)->AddInstruction(std::move(newBranch)); +} + +void 
InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id, + std::unique_ptr* block_ptr) { + std::unique_ptr newLoopMerge(new Instruction( + context(), SpvOpLoopMerge, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {continue_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LOOP_CONTROL, {0}}})); + (*block_ptr)->AddInstruction(std::move(newLoopMerge)); +} + +void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id, + std::unique_ptr* block_ptr) { + std::unique_ptr newStore( + new Instruction(context(), SpvOpStore, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}})); + (*block_ptr)->AddInstruction(std::move(newStore)); +} + +void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id, + std::unique_ptr* block_ptr) { + std::unique_ptr newLoad( + new Instruction(context(), SpvOpLoad, type_id, resultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}})); + (*block_ptr)->AddInstruction(std::move(newLoad)); +} + +std::unique_ptr InlinePass::NewLabel(uint32_t label_id) { + std::unique_ptr newLabel( + new Instruction(context(), SpvOpLabel, 0, label_id, {})); + return newLabel; +} + +uint32_t InlinePass::GetFalseId() { + if (false_id_ != 0) return false_id_; + false_id_ = get_module()->GetGlobalValue(SpvOpConstantFalse); + if (false_id_ != 0) return false_id_; + uint32_t boolId = get_module()->GetGlobalValue(SpvOpTypeBool); + if (boolId == 0) { + boolId = context()->TakeNextId(); + if (boolId == 0) { + return 0; + } + get_module()->AddGlobalValue(SpvOpTypeBool, boolId, 0); + } + false_id_ = context()->TakeNextId(); + if (false_id_ == 0) { + return 0; + } + get_module()->AddGlobalValue(SpvOpConstantFalse, false_id_, boolId); + return false_id_; +} + +void InlinePass::MapParams( + Function* calleeFn, BasicBlock::iterator call_inst_itr, + std::unordered_map* callee2caller) { + int param_idx = 0; + calleeFn->ForEachParam( 
+ [&call_inst_itr, &param_idx, &callee2caller](const Instruction* cpi) { + const uint32_t pid = cpi->result_id(); + (*callee2caller)[pid] = call_inst_itr->GetSingleWordOperand( + kSpvFunctionCallArgumentId + param_idx); + ++param_idx; + }); +} + +bool InlinePass::CloneAndMapLocals( + Function* calleeFn, std::vector>* new_vars, + std::unordered_map* callee2caller) { + auto callee_block_itr = calleeFn->begin(); + auto callee_var_itr = callee_block_itr->begin(); + while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) { + std::unique_ptr var_inst(callee_var_itr->Clone(context())); + uint32_t newId = context()->TakeNextId(); + if (newId == 0) { + return false; + } + get_decoration_mgr()->CloneDecorations(callee_var_itr->result_id(), newId); + var_inst->SetResultId(newId); + (*callee2caller)[callee_var_itr->result_id()] = newId; + new_vars->push_back(std::move(var_inst)); + ++callee_var_itr; + } + return true; +} + +uint32_t InlinePass::CreateReturnVar( + Function* calleeFn, std::vector>* new_vars) { + uint32_t returnVarId = 0; + const uint32_t calleeTypeId = calleeFn->type_id(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + assert(type_mgr->GetType(calleeTypeId)->AsVoid() == nullptr && + "Cannot create a return variable of type void."); + // Find or create ptr to callee return type. + uint32_t returnVarTypeId = + type_mgr->FindPointerToType(calleeTypeId, SpvStorageClassFunction); + + if (returnVarTypeId == 0) { + returnVarTypeId = AddPointerToType(calleeTypeId, SpvStorageClassFunction); + if (returnVarTypeId == 0) { + return 0; + } + } + + // Add return var to new function scope variables.
+ returnVarId = context()->TakeNextId(); + if (returnVarId == 0) { + return 0; + } + + std::unique_ptr var_inst( + new Instruction(context(), SpvOpVariable, returnVarTypeId, returnVarId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, + {SpvStorageClassFunction}}})); + new_vars->push_back(std::move(var_inst)); + get_decoration_mgr()->CloneDecorations(calleeFn->result_id(), returnVarId); + return returnVarId; +} + +bool InlinePass::IsSameBlockOp(const Instruction* inst) const { + return inst->opcode() == SpvOpSampledImage || inst->opcode() == SpvOpImage; +} + +bool InlinePass::CloneSameBlockOps( + std::unique_ptr* inst, + std::unordered_map* postCallSB, + std::unordered_map* preCallSB, + std::unique_ptr* block_ptr) { + return (*inst)->WhileEachInId([&postCallSB, &preCallSB, &block_ptr, + this](uint32_t* iid) { + const auto mapItr = (*postCallSB).find(*iid); + if (mapItr == (*postCallSB).end()) { + const auto mapItr2 = (*preCallSB).find(*iid); + if (mapItr2 != (*preCallSB).end()) { + // Clone pre-call same-block ops, map result id. + const Instruction* inInst = mapItr2->second; + std::unique_ptr sb_inst(inInst->Clone(context())); + if (!CloneSameBlockOps(&sb_inst, postCallSB, preCallSB, block_ptr)) { + return false; + } + + const uint32_t rid = sb_inst->result_id(); + const uint32_t nid = context()->TakeNextId(); + if (nid == 0) { + return false; + } + get_decoration_mgr()->CloneDecorations(rid, nid); + sb_inst->SetResultId(nid); + (*postCallSB)[rid] = nid; + *iid = nid; + (*block_ptr)->AddInstruction(std::move(sb_inst)); + } + } else { + // Reset same-block op operand. + *iid = mapItr->second; + } + return true; + }); +} + +bool InlinePass::GenInlineCode( + std::vector>* new_blocks, + std::vector>* new_vars, + BasicBlock::iterator call_inst_itr, + UptrVectorIterator call_block_itr) { + // Map from all ids in the callee to their equivalent id in the caller + // as callee instructions are copied into caller. 
+ std::unordered_map callee2caller; + // Pre-call same-block insts + std::unordered_map preCallSB; + // Post-call same-block op ids + std::unordered_map postCallSB; + + // Invalidate the def-use chains. They are not kept up to date while + // inlining. However, certain calls try to keep them up-to-date if they are + // valid. These operations can fail. + context()->InvalidateAnalyses(IRContext::kAnalysisDefUse); + + Function* calleeFn = id2function_[call_inst_itr->GetSingleWordOperand( + kSpvFunctionCallFunctionId)]; + + // Check for multiple returns in the callee. + auto fi = early_return_funcs_.find(calleeFn->result_id()); + const bool earlyReturn = fi != early_return_funcs_.end(); + + // Map parameters to actual arguments. + MapParams(calleeFn, call_inst_itr, &callee2caller); + + // Define caller local variables for all callee variables and create map to + // them. + if (!CloneAndMapLocals(calleeFn, new_vars, &callee2caller)) { + return false; + } + + // Create return var if needed. + const uint32_t calleeTypeId = calleeFn->type_id(); + uint32_t returnVarId = 0; + analysis::Type* calleeType = context()->get_type_mgr()->GetType(calleeTypeId); + if (calleeType->AsVoid() == nullptr) { + returnVarId = CreateReturnVar(calleeFn, new_vars); + if (returnVarId == 0) { + return false; + } + } + + // Create set of callee result ids. Used to detect forward references + std::unordered_set callee_result_ids; + calleeFn->ForEachInst([&callee_result_ids](const Instruction* cpi) { + const uint32_t rid = cpi->result_id(); + if (rid != 0) callee_result_ids.insert(rid); + }); + + // If the caller is in a single-block loop, and the callee has multiple + // blocks, then the normal inlining logic will place the OpLoopMerge in + // the last of several blocks in the loop. Instead, it should be placed + // at the end of the first block. First determine if the caller is in a + // single block loop. 
We'll wait to move the OpLoopMerge until the end + // of the regular inlining logic, and only if necessary. + bool caller_is_single_block_loop = false; + bool caller_is_loop_header = false; + if (auto* loop_merge = call_block_itr->GetLoopMergeInst()) { + caller_is_loop_header = true; + caller_is_single_block_loop = + call_block_itr->id() == + loop_merge->GetSingleWordInOperand(kSpvLoopMergeContinueTargetIdInIdx); + } + + bool callee_begins_with_structured_header = + (*(calleeFn->begin())).GetMergeInst() != nullptr; + + // Clone and map callee code. Copy caller block code to beginning of + // first block and end of last block. + bool prevInstWasReturn = false; + uint32_t singleTripLoopHeaderId = 0; + uint32_t singleTripLoopContinueId = 0; + uint32_t returnLabelId = 0; + bool multiBlocks = false; + // new_blk_ptr is a new basic block in the caller. New instructions are + // written to it. It is created when we encounter the OpLabel + // of the first callee block. It is appended to new_blocks only when + // it is complete. + std::unique_ptr new_blk_ptr; + bool successful = calleeFn->WhileEachInst( + [&new_blocks, &callee2caller, &call_block_itr, &call_inst_itr, + &new_blk_ptr, &prevInstWasReturn, &returnLabelId, &returnVarId, + caller_is_loop_header, callee_begins_with_structured_header, + &calleeTypeId, &multiBlocks, &postCallSB, &preCallSB, earlyReturn, + &singleTripLoopHeaderId, &singleTripLoopContinueId, &callee_result_ids, + this](const Instruction* cpi) { + switch (cpi->opcode()) { + case SpvOpFunction: + case SpvOpFunctionParameter: + // Already processed + break; + case SpvOpVariable: + if (cpi->NumInOperands() == 2) { + assert(callee2caller.count(cpi->result_id()) && + "Expected the variable to have already been mapped."); + uint32_t new_var_id = callee2caller.at(cpi->result_id()); + + // The initializer must be a constant or global value. No mapped + // should be used. 
+ uint32_t val_id = cpi->GetSingleWordInOperand(1); + AddStore(new_var_id, val_id, &new_blk_ptr); + } + break; + case SpvOpUnreachable: + case SpvOpKill: { + // Generate a return label so that we split the block with the + // function call. Copy the terminator into the new block. + if (returnLabelId == 0) { + returnLabelId = context()->TakeNextId(); + if (returnLabelId == 0) { + return false; + } + } + std::unique_ptr terminator( + new Instruction(context(), cpi->opcode(), 0, 0, {})); + new_blk_ptr->AddInstruction(std::move(terminator)); + break; + } + case SpvOpLabel: { + // If previous instruction was early return, insert branch + // instruction to return block. + if (prevInstWasReturn) { + if (returnLabelId == 0) { + returnLabelId = context()->TakeNextId(); + if (returnLabelId == 0) { + return false; + } + } + AddBranch(returnLabelId, &new_blk_ptr); + prevInstWasReturn = false; + } + // Finish current block (if it exists) and get label for next block. + uint32_t labelId; + bool firstBlock = false; + if (new_blk_ptr != nullptr) { + new_blocks->push_back(std::move(new_blk_ptr)); + // If result id is already mapped, use it, otherwise get a new + // one. + const uint32_t rid = cpi->result_id(); + const auto mapItr = callee2caller.find(rid); + labelId = (mapItr != callee2caller.end()) + ? mapItr->second + : context()->TakeNextId(); + if (labelId == 0) { + return false; + } + } else { + // First block needs to use label of original block + // but map callee label in case of phi reference. + labelId = call_block_itr->id(); + callee2caller[cpi->result_id()] = labelId; + firstBlock = true; + } + // Create first/next block. + new_blk_ptr = MakeUnique(NewLabel(labelId)); + if (firstBlock) { + // Copy contents of original caller block up to call instruction. 
+ for (auto cii = call_block_itr->begin(); cii != call_inst_itr; + cii = call_block_itr->begin()) { + Instruction* inst = &*cii; + inst->RemoveFromList(); + std::unique_ptr cp_inst(inst); + // Remember same-block ops for possible regeneration. + if (IsSameBlockOp(&*cp_inst)) { + auto* sb_inst_ptr = cp_inst.get(); + preCallSB[cp_inst->result_id()] = sb_inst_ptr; + } + new_blk_ptr->AddInstruction(std::move(cp_inst)); + } + if (caller_is_loop_header && + callee_begins_with_structured_header) { + // We can't place both the caller's merge instruction and + // another merge instruction in the same block. So split the + // calling block. Insert an unconditional branch to a new guard + // block. Later, once we know the ID of the last block, we + // will move the caller's OpLoopMerge from the last generated + // block into the first block. We also wait to avoid + // invalidating various iterators. + const auto guard_block_id = context()->TakeNextId(); + if (guard_block_id == 0) { + return false; + } + AddBranch(guard_block_id, &new_blk_ptr); + new_blocks->push_back(std::move(new_blk_ptr)); + // Start the next block. + new_blk_ptr = MakeUnique(NewLabel(guard_block_id)); + // Reset the mapping of the callee's entry block to point to + // the guard block. Do this so we can fix up phis later on to + // satisfy dominance. + callee2caller[cpi->result_id()] = guard_block_id; + } + // If callee has early return, insert a header block for + // single-trip loop that will encompass callee code. Start + // postheader block. + // + // Note: Consider the following combination: + // - the caller is a single block loop + // - the callee does not begin with a structure header + // - the callee has multiple returns. + // We still need to split the caller block and insert a guard + // block. But we only need to do it once. We haven't done it yet, + // but the single-trip loop header will serve the same purpose. 
+ if (earlyReturn) { + singleTripLoopHeaderId = context()->TakeNextId(); + if (singleTripLoopHeaderId == 0) { + return false; + } + AddBranch(singleTripLoopHeaderId, &new_blk_ptr); + new_blocks->push_back(std::move(new_blk_ptr)); + new_blk_ptr = + MakeUnique(NewLabel(singleTripLoopHeaderId)); + returnLabelId = context()->TakeNextId(); + singleTripLoopContinueId = context()->TakeNextId(); + if (returnLabelId == 0 || singleTripLoopContinueId == 0) { + return false; + } + AddLoopMerge(returnLabelId, singleTripLoopContinueId, + &new_blk_ptr); + uint32_t postHeaderId = context()->TakeNextId(); + if (postHeaderId == 0) { + return false; + } + AddBranch(postHeaderId, &new_blk_ptr); + new_blocks->push_back(std::move(new_blk_ptr)); + new_blk_ptr = MakeUnique(NewLabel(postHeaderId)); + multiBlocks = true; + // Reset the mapping of the callee's entry block to point to + // the post-header block. Do this so we can fix up phis later + // on to satisfy dominance. + callee2caller[cpi->result_id()] = postHeaderId; + } + } else { + multiBlocks = true; + } + } break; + case SpvOpReturnValue: { + // Store return value to return variable. + assert(returnVarId != 0); + uint32_t valId = cpi->GetInOperand(kSpvReturnValueId).words[0]; + const auto mapItr = callee2caller.find(valId); + if (mapItr != callee2caller.end()) { + valId = mapItr->second; + } + AddStore(returnVarId, valId, &new_blk_ptr); + + // Remember we saw a return; if followed by a label, will need to + // insert branch. + prevInstWasReturn = true; + } break; + case SpvOpReturn: { + // Remember we saw a return; if followed by a label, will need to + // insert branch. + prevInstWasReturn = true; + } break; + case SpvOpFunctionEnd: { + // If there was an early return, we generated a return label id + // for it. Now we have to generate the return block with that Id. + if (returnLabelId != 0) { + // If previous instruction was return, insert branch instruction + // to return block. 
+ if (prevInstWasReturn) AddBranch(returnLabelId, &new_blk_ptr); + if (earlyReturn) { + // If we generated a loop header for the single-trip loop + // to accommodate early returns, insert the continue + // target block now, with a false branch back to the loop + // header. + new_blocks->push_back(std::move(new_blk_ptr)); + new_blk_ptr = + MakeUnique(NewLabel(singleTripLoopContinueId)); + uint32_t false_id = GetFalseId(); + if (false_id == 0) { + return false; + } + AddBranchCond(false_id, singleTripLoopHeaderId, returnLabelId, + &new_blk_ptr); + } + // Generate the return block. + new_blocks->push_back(std::move(new_blk_ptr)); + new_blk_ptr = MakeUnique(NewLabel(returnLabelId)); + multiBlocks = true; + } + // Load return value into result id of call, if it exists. + if (returnVarId != 0) { + const uint32_t resId = call_inst_itr->result_id(); + assert(resId != 0); + AddLoad(calleeTypeId, resId, returnVarId, &new_blk_ptr); + } + // Copy remaining instructions from caller block. + for (Instruction* inst = call_inst_itr->NextNode(); inst; + inst = call_inst_itr->NextNode()) { + inst->RemoveFromList(); + std::unique_ptr cp_inst(inst); + // If multiple blocks generated, regenerate any same-block + // instruction that has not been seen in this last block. + if (multiBlocks) { + if (!CloneSameBlockOps(&cp_inst, &postCallSB, &preCallSB, + &new_blk_ptr)) { + return false; + } + + // Remember same-block ops in this block. + if (IsSameBlockOp(&*cp_inst)) { + const uint32_t rid = cp_inst->result_id(); + postCallSB[rid] = rid; + } + } + new_blk_ptr->AddInstruction(std::move(cp_inst)); + } + // Finalize inline code. + new_blocks->push_back(std::move(new_blk_ptr)); + } break; + default: { + // Copy callee instruction and remap all input Ids. 
+ std::unique_ptr cp_inst(cpi->Clone(context())); + bool succeeded = cp_inst->WhileEachInId( + [&callee2caller, &callee_result_ids, this](uint32_t* iid) { + const auto mapItr = callee2caller.find(*iid); + if (mapItr != callee2caller.end()) { + *iid = mapItr->second; + } else if (callee_result_ids.find(*iid) != + callee_result_ids.end()) { + // Forward reference. Allocate a new id, map it, + // use it and check for it when remapping result ids + const uint32_t nid = context()->TakeNextId(); + if (nid == 0) { + return false; + } + callee2caller[*iid] = nid; + *iid = nid; + } + return true; + }); + if (!succeeded) { + return false; + } + // If result id is non-zero, remap it. If already mapped, use mapped + // value, else use next id. + const uint32_t rid = cp_inst->result_id(); + if (rid != 0) { + const auto mapItr = callee2caller.find(rid); + uint32_t nid; + if (mapItr != callee2caller.end()) { + nid = mapItr->second; + } else { + nid = context()->TakeNextId(); + if (nid == 0) { + return false; + } + callee2caller[rid] = nid; + } + cp_inst->SetResultId(nid); + get_decoration_mgr()->CloneDecorations(rid, nid); + } + new_blk_ptr->AddInstruction(std::move(cp_inst)); + } break; + } + return true; + }); + + if (!successful) { + return false; + } + + if (caller_is_loop_header && (new_blocks->size() > 1)) { + // Move the OpLoopMerge from the last block back to the first, where + // it belongs. + auto& first = new_blocks->front(); + auto& last = new_blocks->back(); + assert(first != last); + + // Insert a modified copy of the loop merge into the first block. + auto loop_merge_itr = last->tail(); + --loop_merge_itr; + assert(loop_merge_itr->opcode() == SpvOpLoopMerge); + std::unique_ptr cp_inst(loop_merge_itr->Clone(context())); + if (caller_is_single_block_loop) { + // Also, update its continue target to point to the last block. 
+ cp_inst->SetInOperand(kSpvLoopMergeContinueTargetIdInIdx, {last->id()}); + } + first->tail().InsertBefore(std::move(cp_inst)); + + // Remove the loop merge from the last block. + loop_merge_itr->RemoveFromList(); + delete &*loop_merge_itr; + } + + // Update block map given replacement blocks. + for (auto& blk : *new_blocks) { + id2block_[blk->id()] = &*blk; + } + return true; +} + +bool InlinePass::IsInlinableFunctionCall(const Instruction* inst) { + if (inst->opcode() != SpvOp::SpvOpFunctionCall) return false; + const uint32_t calleeFnId = + inst->GetSingleWordOperand(kSpvFunctionCallFunctionId); + const auto ci = inlinable_.find(calleeFnId); + return ci != inlinable_.cend(); +} + +void InlinePass::UpdateSucceedingPhis( + std::vector>& new_blocks) { + const auto firstBlk = new_blocks.begin(); + const auto lastBlk = new_blocks.end() - 1; + const uint32_t firstId = (*firstBlk)->id(); + const uint32_t lastId = (*lastBlk)->id(); + const BasicBlock& const_last_block = *lastBlk->get(); + const_last_block.ForEachSuccessorLabel( + [&firstId, &lastId, this](const uint32_t succ) { + BasicBlock* sbp = this->id2block_[succ]; + sbp->ForEachPhiInst([&firstId, &lastId](Instruction* phi) { + phi->ForEachInId([&firstId, &lastId](uint32_t* id) { + if (*id == firstId) *id = lastId; + }); + }); + }); +} + +bool InlinePass::HasNoReturnInStructuredConstruct(Function* func) { + // If control not structured, do not do loop/return analysis + // TODO: Analyze returns in non-structured control flow + if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) + return false; + const auto structured_analysis = context()->GetStructuredCFGAnalysis(); + // Search for returns in structured construct. 
+ bool return_in_construct = false; + for (auto& blk : *func) { + auto terminal_ii = blk.cend(); + --terminal_ii; + if (spvOpcodeIsReturn(terminal_ii->opcode()) && + structured_analysis->ContainingConstruct(blk.id()) != 0) { + return_in_construct = true; + break; + } + } + return !return_in_construct; +} + +bool InlinePass::HasNoReturnInLoop(Function* func) { + // If control not structured, do not do loop/return analysis + // TODO: Analyze returns in non-structured control flow + if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) + return false; + const auto structured_analysis = context()->GetStructuredCFGAnalysis(); + // Search for returns in loops. + bool return_in_loop = false; + for (auto& blk : *func) { + auto terminal_ii = blk.cend(); + --terminal_ii; + if (spvOpcodeIsReturn(terminal_ii->opcode()) && + structured_analysis->ContainingLoop(blk.id()) != 0) { + return_in_loop = true; + break; + } + } + return !return_in_loop; +} + +void InlinePass::AnalyzeReturns(Function* func) { + if (HasNoReturnInLoop(func)) { + no_return_in_loop_.insert(func->result_id()); + if (!HasNoReturnInStructuredConstruct(func)) + early_return_funcs_.insert(func->result_id()); + } +} + +bool InlinePass::IsInlinableFunction(Function* func) { + // We can only inline a function if it has blocks. + if (func->cbegin() == func->cend()) return false; + // Do not inline functions with returns in loops. Currently early return + // functions are inlined by wrapping them in a one trip loop and implementing + // the returns as a branch to the loop's merge block. However, this can only + // be done validly if the return was not in a loop in the original function. + // Also remember functions with multiple (early) returns.
+ AnalyzeReturns(func); + if (no_return_in_loop_.find(func->result_id()) == no_return_in_loop_.cend()) { + return false; + } + + if (func->IsRecursive()) { + return false; + } + + return true; +} + +void InlinePass::InitializeInline() { + false_id_ = 0; + + // clear collections + id2function_.clear(); + id2block_.clear(); + inlinable_.clear(); + no_return_in_loop_.clear(); + early_return_funcs_.clear(); + + for (auto& fn : *get_module()) { + // Initialize function and block maps. + id2function_[fn.result_id()] = &fn; + for (auto& blk : fn) { + id2block_[blk.id()] = &blk; + } + // Compute inlinability + if (IsInlinableFunction(&fn)) inlinable_.insert(fn.result_id()); + } +} + +InlinePass::InlinePass() {} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h new file mode 100644 index 000000000..ecfe964f1 --- /dev/null +++ b/source/opt/inline_pass.h @@ -0,0 +1,172 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_INLINE_PASS_H_ +#define SOURCE_OPT_INLINE_PASS_H_ + +#include +#include +#include +#include +#include +#include + +#include "source/opt/decoration_manager.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. 
+class InlinePass : public Pass { + using cbb_ptr = const BasicBlock*; + + public: + virtual ~InlinePass() = default; + + protected: + InlinePass(); + + // Add pointer to type to module and return resultId. Returns 0 if the type + // could not be created. + uint32_t AddPointerToType(uint32_t type_id, SpvStorageClass storage_class); + + // Add unconditional branch to labelId to end of block block_ptr. + void AddBranch(uint32_t labelId, std::unique_ptr* block_ptr); + + // Add conditional branch to end of block |block_ptr|. + void AddBranchCond(uint32_t cond_id, uint32_t true_id, uint32_t false_id, + std::unique_ptr* block_ptr); + + // Add OpLoopMerge (merge_id, continue_id) to end of block block_ptr. + void AddLoopMerge(uint32_t merge_id, uint32_t continue_id, + std::unique_ptr* block_ptr); + + // Add store of valId to ptrId to end of block block_ptr. + void AddStore(uint32_t ptrId, uint32_t valId, + std::unique_ptr* block_ptr); + + // Add load of ptrId into resultId to end of block block_ptr. + void AddLoad(uint32_t typeId, uint32_t resultId, uint32_t ptrId, + std::unique_ptr* block_ptr); + + // Return new OpLabel instruction with result id label_id. + std::unique_ptr NewLabel(uint32_t label_id); + + // Returns the id for the boolean false value. Looks in the module first + // and creates it if not found. Remembers it for future calls. Returns 0 if + // the value could not be created. + uint32_t GetFalseId(); + + // Map callee params to caller args + void MapParams(Function* calleeFn, BasicBlock::iterator call_inst_itr, + std::unordered_map* callee2caller); + + // Clone and map callee locals. Return true if successful. + bool CloneAndMapLocals(Function* calleeFn, + std::vector>* new_vars, + std::unordered_map* callee2caller); + + // Create return variable for callee clone code. The return type of + // |calleeFn| must not be void. Returns the id of the return variable if + // created. Returns 0 if the return variable could not be created.
+ uint32_t CreateReturnVar(Function* calleeFn, + std::vector>* new_vars); + + // Return true if instruction must be in the same block that its result + // is used. + bool IsSameBlockOp(const Instruction* inst) const; + + // Clone operands which must be in same block as consumer instructions. + // Look in preCallSB for instructions that need cloning. Look in + // postCallSB for instructions already cloned. Add cloned instruction + // to postCallSB. Return false if a needed id could not be allocated. + bool CloneSameBlockOps(std::unique_ptr* inst, + std::unordered_map* postCallSB, + std::unordered_map* preCallSB, + std::unique_ptr* block_ptr); + + // Return in new_blocks the result of inlining the call at call_inst_itr + // within its block at call_block_itr. The block at call_block_itr can + // just be replaced with the blocks in new_blocks. Any additional branches + // are avoided. Debug instructions are cloned along with their callee + // instructions. Early returns are replaced by a store to a local return + // variable and a branch to a (created) exit block where the local variable + // is returned. Formal parameters are trivially mapped to their actual + // parameters. Note that the first block in new_blocks retains the label + // of the original calling block. Also note that if an exit block is + // created, it is the last block of new_blocks. + // + // Also return in new_vars additional OpVariable instructions required by + // and to be inserted into the caller function after the block at + // call_block_itr is replaced with new_blocks. + // + // Returns true if successful. + bool GenInlineCode(std::vector>* new_blocks, + std::vector>* new_vars, + BasicBlock::iterator call_inst_itr, + UptrVectorIterator call_block_itr); + + // Return true if |inst| is a function call that can be inlined. + bool IsInlinableFunctionCall(const Instruction* inst); + + // Return true if |func| does not have a return that is + // nested in a structured if, switch or loop.
+ bool HasNoReturnInStructuredConstruct(Function* func); + + // Return true if |func| has no return in a loop. The current analysis + // requires structured control flow, so return false if control flow not + // structured, i.e. module is not a shader. + bool HasNoReturnInLoop(Function* func); + + // Record |func| in no_return_in_loop_ and early_return_funcs_ as applicable + void AnalyzeReturns(Function* func); + + // Return true if |func| is a function that can be inlined. + bool IsInlinableFunction(Function* func); + + // Update phis in succeeding blocks to point to new last block + void UpdateSucceedingPhis( + std::vector>& new_blocks); + + // Initialize state for optimization of |module| + void InitializeInline(); + + // Map from function's result id to function. + std::unordered_map id2function_; + + // Map from block's label id to block. TODO(dnovillo): This is superfluous wrt + // CFG. It has functionality not present in CFG. Consolidate. + std::unordered_map id2block_; + + // Set of ids of functions with early return. + std::set early_return_funcs_; + + // Set of ids of functions with no returns in loop + std::set no_return_in_loop_; + + // Set of ids of inlinable functions + std::set inlinable_; + + // result id for OpConstantFalse + uint32_t false_id_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_INLINE_PASS_H_ diff --git a/source/opt/inst_bindless_check_pass.cpp b/source/opt/inst_bindless_check_pass.cpp new file mode 100644 index 000000000..1901f763b --- /dev/null +++ b/source/opt/inst_bindless_check_pass.cpp @@ -0,0 +1,263 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "inst_bindless_check_pass.h"
+
+#include <cassert>
+#include <cstring>
+
+namespace {
+
+// Input Operand Indices
+static const int kSpvImageSampleImageIdInIdx = 0;
+static const int kSpvSampledImageImageIdInIdx = 0;
+static const int kSpvSampledImageSamplerIdInIdx = 1;
+static const int kSpvImageSampledImageIdInIdx = 0;
+static const int kSpvLoadPtrIdInIdx = 0;
+static const int kSpvAccessChainBaseIdInIdx = 0;
+static const int kSpvAccessChainIndex0IdInIdx = 1;
+static const int kSpvTypePointerTypeIdInIdx = 1;
+static const int kSpvTypeArrayLengthIdInIdx = 1;
+static const int kSpvConstantValueInIdx = 0;
+
+}  // anonymous namespace
+
+namespace spvtools {
+namespace opt {
+
+// Instruments the image reference at |ref_inst_itr| with a runtime bounds
+// check on its descriptor array index. Returns with |new_blocks| untouched
+// when the instruction is not a recognized reference through an
+// OpLoad / OpAccessChain / OpVariable chain into an OpTypeArray, or when
+// index and length are both constants with index < length.
+void InstBindlessCheckPass::GenBindlessCheckCode(
+    BasicBlock::iterator ref_inst_itr,
+    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t instruction_idx,
+    uint32_t stage_idx, std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
+  // Look for reference through bindless descriptor. If not, return.
+  std::unique_ptr<BasicBlock> new_blk_ptr;
+  uint32_t image_id;
+  switch (ref_inst_itr->opcode()) {
+    case SpvOp::SpvOpImageSampleImplicitLod:
+    case SpvOp::SpvOpImageSampleExplicitLod:
+    case SpvOp::SpvOpImageSampleDrefImplicitLod:
+    case SpvOp::SpvOpImageSampleDrefExplicitLod:
+    case SpvOp::SpvOpImageSampleProjImplicitLod:
+    case SpvOp::SpvOpImageSampleProjExplicitLod:
+    case SpvOp::SpvOpImageSampleProjDrefImplicitLod:
+    case SpvOp::SpvOpImageSampleProjDrefExplicitLod:
+    case SpvOp::SpvOpImageGather:
+    case SpvOp::SpvOpImageDrefGather:
+    case SpvOp::SpvOpImageQueryLod:
+    case SpvOp::SpvOpImageSparseSampleImplicitLod:
+    case SpvOp::SpvOpImageSparseSampleExplicitLod:
+    case SpvOp::SpvOpImageSparseSampleDrefImplicitLod:
+    case SpvOp::SpvOpImageSparseSampleDrefExplicitLod:
+    case SpvOp::SpvOpImageSparseSampleProjImplicitLod:
+    case SpvOp::SpvOpImageSparseSampleProjExplicitLod:
+    case SpvOp::SpvOpImageSparseSampleProjDrefImplicitLod:
+    case SpvOp::SpvOpImageSparseSampleProjDrefExplicitLod:
+    case SpvOp::SpvOpImageSparseGather:
+    case SpvOp::SpvOpImageSparseDrefGather:
+    case SpvOp::SpvOpImageFetch:
+    case SpvOp::SpvOpImageRead:
+    case SpvOp::SpvOpImageQueryFormat:
+    case SpvOp::SpvOpImageQueryOrder:
+    case SpvOp::SpvOpImageQuerySizeLod:
+    case SpvOp::SpvOpImageQuerySize:
+    case SpvOp::SpvOpImageQueryLevels:
+    case SpvOp::SpvOpImageQuerySamples:
+    case SpvOp::SpvOpImageSparseFetch:
+    case SpvOp::SpvOpImageSparseRead:
+    case SpvOp::SpvOpImageWrite:
+      image_id =
+          ref_inst_itr->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
+      break;
+    default:
+      return;
+  }
+  Instruction* image_inst = get_def_use_mgr()->GetDef(image_id);
+  uint32_t load_id;
+  Instruction* load_inst;
+  if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
+    load_id = image_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx);
+    load_inst = get_def_use_mgr()->GetDef(load_id);
+  } else if (image_inst->opcode() == SpvOp::SpvOpImage) {
+    load_id =
+        image_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx);
+    load_inst = get_def_use_mgr()->GetDef(load_id);
+  } else {
+    // The reference uses the loaded descriptor directly. A zero |image_id|
+    // marks that no OpImage/OpSampledImage needs to be cloned below.
+    load_id = image_id;
+    load_inst = image_inst;
+    image_id = 0;
+  }
+  if (load_inst->opcode() != SpvOp::SpvOpLoad) {
+    // TODO(greg-lunarg): Handle additional possibilities
+    return;
+  }
+  uint32_t ptr_id = load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
+  Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
+  if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return;
+  if (ptr_inst->NumInOperands() != 2) {
+    assert(false && "unexpected bindless index number");
+    return;
+  }
+  uint32_t index_id =
+      ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
+  ptr_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
+  ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
+  if (ptr_inst->opcode() != SpvOpVariable) {
+    assert(false && "unexpected bindless base");
+    return;
+  }
+  uint32_t var_type_id = ptr_inst->type_id();
+  Instruction* var_type_inst = get_def_use_mgr()->GetDef(var_type_id);
+  uint32_t ptr_type_id =
+      var_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
+  Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
+  // TODO(greg-lunarg): Handle RuntimeArray. Will need to pull length
+  // out of debug input buffer.
+  if (ptr_type_inst->opcode() != SpvOpTypeArray) return;
+  // If index and bound both compile-time constants and index < bound,
+  // return without changing
+  uint32_t length_id =
+      ptr_type_inst->GetSingleWordInOperand(kSpvTypeArrayLengthIdInIdx);
+  Instruction* index_inst = get_def_use_mgr()->GetDef(index_id);
+  Instruction* length_inst = get_def_use_mgr()->GetDef(length_id);
+  if (index_inst->opcode() == SpvOpConstant &&
+      length_inst->opcode() == SpvOpConstant &&
+      index_inst->GetSingleWordInOperand(kSpvConstantValueInIdx) <
+          length_inst->GetSingleWordInOperand(kSpvConstantValueInIdx))
+    return;
+  // Generate full runtime bounds test code with true branch
+  // being full reference and false branch being debug output and zero
+  // for the referenced value.
+  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
+  InstructionBuilder builder(
+      context(), &*new_blk_ptr,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessBounds);
+  Instruction* ult_inst =
+      builder.AddBinaryOp(GetBoolId(), SpvOpULessThan, index_id, length_id);
+  uint32_t merge_blk_id = TakeNextId();
+  uint32_t valid_blk_id = TakeNextId();
+  uint32_t invalid_blk_id = TakeNextId();
+  std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
+  std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
+  std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
+  (void)builder.AddConditionalBranch(ult_inst->result_id(), valid_blk_id,
+                                     invalid_blk_id, merge_blk_id,
+                                     SpvSelectionControlMaskNone);
+  // Close selection block and gen valid reference block
+  new_blocks->push_back(std::move(new_blk_ptr));
+  new_blk_ptr.reset(new BasicBlock(std::move(valid_label)));
+  builder.SetInsertPoint(&*new_blk_ptr);
+  // Clone descriptor load
+  Instruction* new_load_inst =
+      builder.AddLoad(load_inst->type_id(),
+                      load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx));
+  uint32_t new_load_id = new_load_inst->result_id();
+  get_decoration_mgr()->CloneDecorations(load_inst->result_id(), new_load_id);
+  uint32_t new_image_id = new_load_id;
+  // Clone Image/SampledImage with new load, if needed
+  if (image_id != 0) {
+    if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
+      Instruction* new_image_inst = builder.AddBinaryOp(
+          image_inst->type_id(), SpvOpSampledImage, new_load_id,
+          image_inst->GetSingleWordInOperand(kSpvSampledImageSamplerIdInIdx));
+      new_image_id = new_image_inst->result_id();
+    } else {
+      assert(image_inst->opcode() == SpvOp::SpvOpImage && "expecting OpImage");
+      Instruction* new_image_inst =
+          builder.AddUnaryOp(image_inst->type_id(), SpvOpImage, new_load_id);
+      new_image_id = new_image_inst->result_id();
+    }
+    get_decoration_mgr()->CloneDecorations(image_id, new_image_id);
+  }
+  // Clone original reference using new image code
+  std::unique_ptr<Instruction> new_ref_inst(ref_inst_itr->Clone(context()));
+  uint32_t ref_result_id = ref_inst_itr->result_id();
+  uint32_t new_ref_id = 0;
+  if (ref_result_id != 0) {
+    new_ref_id = TakeNextId();
+    new_ref_inst->SetResultId(new_ref_id);
+  }
+  new_ref_inst->SetInOperand(kSpvImageSampleImageIdInIdx, {new_image_id});
+  // Register new reference and add to new block
+  builder.AddInstruction(std::move(new_ref_inst));
+  if (new_ref_id != 0)
+    get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
+  // Close valid block and gen invalid block
+  (void)builder.AddBranch(merge_blk_id);
+  new_blocks->push_back(std::move(new_blk_ptr));
+  new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
+  builder.SetInsertPoint(&*new_blk_ptr);
+  uint32_t u_index_id = GenUintCastCode(index_id, &builder);
+  GenDebugStreamWrite(instruction_idx, stage_idx,
+                      {error_id, u_index_id, length_id}, &builder);
+  // Remember last invalid block id
+  uint32_t last_invalid_blk_id = new_blk_ptr->GetLabelInst()->result_id();
+  // Gen zero for invalid reference
+  uint32_t ref_type_id = ref_inst_itr->type_id();
+  // Close invalid block and gen merge block
+  (void)builder.AddBranch(merge_blk_id);
+  new_blocks->push_back(std::move(new_blk_ptr));
+  new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
+  builder.SetInsertPoint(&*new_blk_ptr);
+  // Gen phi of new reference and zero, if necessary, and replace the
+  // result id of the original reference with that of the Phi. Kill original
+  // reference and move in remainder of original block.
+  if (new_ref_id != 0) {
+    Instruction* phi_inst = builder.AddPhi(
+        ref_type_id, {new_ref_id, valid_blk_id, builder.GetNullId(ref_type_id),
+                      last_invalid_blk_id});
+    context()->ReplaceAllUsesWith(ref_result_id, phi_inst->result_id());
+  }
+  context()->KillInst(&*ref_inst_itr);
+  MovePostludeCode(ref_block_itr, &new_blk_ptr);
+  // Add remainder/merge block to new blocks
+  new_blocks->push_back(std::move(new_blk_ptr));
+}
+
+void InstBindlessCheckPass::InitializeInstBindlessCheck() {
+  // Initialize base class
+  InitializeInstrument();
+  // Look for related extensions
+  ext_descriptor_indexing_defined_ = false;
+  for (auto& ei : get_module()->extensions()) {
+    // The extension name is read as a C string straight out of the words of
+    // the instruction's first in-operand.
+    const char* ext_name =
+        reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
+    if (strcmp(ext_name, "SPV_EXT_descriptor_indexing") == 0) {
+      ext_descriptor_indexing_defined_ = true;
+      break;
+    }
+  }
+}
+
+Pass::Status InstBindlessCheckPass::ProcessImpl() {
+  // Perform instrumentation on each entry point function in module
+  InstProcessFunction pfn =
+      [this](BasicBlock::iterator ref_inst_itr,
+             UptrVectorIterator<BasicBlock> ref_block_itr,
+             uint32_t instruction_idx, uint32_t stage_idx,
+             std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
+        return GenBindlessCheckCode(ref_inst_itr, ref_block_itr,
+                                    instruction_idx, stage_idx, new_blocks);
+      };
+  bool modified = InstProcessEntryPointCallTree(pfn);
+  // This pass does not update inst->blk info
+  context()->InvalidateAnalyses(IRContext::kAnalysisInstrToBlockMapping);
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+Pass::Status InstBindlessCheckPass::Process() {
+  InitializeInstBindlessCheck();
+  return ProcessImpl();
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/inst_bindless_check_pass.h b/source/opt/inst_bindless_check_pass.h
new file mode 100644
index 000000000..3ab5ab7cf
--- /dev/null
+++ b/source/opt/inst_bindless_check_pass.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2018 The Khronos Group Inc.
+// Copyright (c) 2018 Valve Corporation
+// Copyright (c) 2018 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_INST_BINDLESS_CHECK_PASS_H_
+#define LIBSPIRV_OPT_INST_BINDLESS_CHECK_PASS_H_
+
+#include "instrument_pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// This class/pass is designed to support the bindless (descriptor indexing)
+// GPU-assisted validation layer of
+// https://github.com/KhronosGroup/Vulkan-ValidationLayers. Its internal and
+// external design may change as the layer evolves.
+class InstBindlessCheckPass : public InstrumentPass {
+ public:
+  // For test harness only
+  InstBindlessCheckPass() : InstrumentPass(7, 23, kInstValidationIdBindless) {}
+  // For all other interfaces
+  InstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id)
+      : InstrumentPass(desc_set, shader_id, kInstValidationIdBindless) {}
+
+  ~InstBindlessCheckPass() override = default;
+
+  // See optimizer.hpp for pass user documentation.
+  Status Process() override;
+
+  const char* name() const override { return "inst-bindless-check-pass"; }
+
+ private:
+  // Initialize state for instrumenting bindless checking
+  void InitializeInstBindlessCheck();
+
+  // This function does bindless checking instrumentation on a single
+  // instruction. It is designed to be passed to
+  // InstrumentPass::InstProcessEntryPointCallTree(), which applies the
+  // function to each instruction in a module and replaces the instruction
+  // if warranted.
+  //
+  // If |ref_inst_itr| is a bindless reference, return in |new_blocks| the
+  // result of instrumenting it with validation code within its block at
+  // |ref_block_itr|. Specifically, generate code to check that the index
+  // into the descriptor array is in-bounds. If the check passes, execute
+  // the remainder of the reference, otherwise write a record to the debug
+  // output buffer stream including |instruction_idx| and |stage_idx|
+  // and replace the reference with the null value of the original type. The
+  // block at |ref_block_itr| can just be replaced with the blocks in
+  // |new_blocks|, which will contain at least two blocks. The last block will
+  // comprise all instructions following |ref_inst_itr|,
+  // preceded by a phi instruction when the reference produces a result.
+  //
+  // This instrumentation pass utilizes GenDebugStreamWrite() to write its
+  // error records. The validation-specific part of the error record will
+  // have the format:
+  //
+  //    Validation Error Code (=kInstErrorBindlessBounds)
+  //    Descriptor Index
+  //    Descriptor Array Size
+  //
+  // The Descriptor Index is the index which has been determined to be
+  // out-of-bounds.
+  //
+  // The Descriptor Array Size is the size of the descriptor array which was
+  // indexed.
+  void GenBindlessCheckCode(
+      BasicBlock::iterator ref_inst_itr,
+      UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t instruction_idx,
+      uint32_t stage_idx, std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+
+  Pass::Status ProcessImpl();
+
+  // True if the module declares the SPV_EXT_descriptor_indexing extension
+  bool ext_descriptor_indexing_defined_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_INST_BINDLESS_CHECK_PASS_H_
diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp
new file mode 100644
index 000000000..5f3c5a889
--- /dev/null
+++ b/source/opt/instruction.cpp
@@ -0,0 +1,753 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/instruction.h"
+
+#include <initializer_list>
+
+#include "source/disassemble.h"
+#include "source/opt/fold.h"
+#include "source/opt/ir_context.h"
+#include "source/opt/reflect.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+// Indices used to get particular operands out of instructions using InOperand.
+const uint32_t kTypeImageDimIndex = 1;
+const uint32_t kLoadBaseIndex = 0;
+const uint32_t kVariableStorageClassIndex = 0;
+const uint32_t kTypeImageSampledIndex = 5;
+}  // namespace
+
+// Creates an OpNop with no operands and a fresh unique id taken from |c|.
+Instruction::Instruction(IRContext* c)
+    : utils::IntrusiveNodeBase<Instruction>(),
+      context_(c),
+      opcode_(SpvOpNop),
+      has_type_id_(false),
+      has_result_id_(false),
+      unique_id_(c->TakeNextUniqueId()) {}
+
+// Creates an instruction with opcode |op| and no operands.
+Instruction::Instruction(IRContext* c, SpvOp op)
+    : utils::IntrusiveNodeBase<Instruction>(),
+      context_(c),
+      opcode_(op),
+      has_type_id_(false),
+      has_result_id_(false),
+      unique_id_(c->TakeNextUniqueId()) {}
+
+// Creates an instruction from the parsed form |inst|, copying every operand's
+// words, and takes ownership of the debug line instructions in |dbg_line|.
+Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst,
+                         std::vector<Instruction>&& dbg_line)
+    : context_(c),
+      opcode_(static_cast<SpvOp>(inst.opcode)),
+      has_type_id_(inst.type_id != 0),
+      has_result_id_(inst.result_id != 0),
+      unique_id_(c->TakeNextUniqueId()),
+      dbg_line_insts_(std::move(dbg_line)) {
+  // |dbg_line| was moved from in the initializer list, so the check has to
+  // look at the member that now owns the line instructions.
+  assert((!IsDebugLineInst(opcode_) || dbg_line_insts_.empty()) &&
+         "Op(No)Line attaching to Op(No)Line found");
+  for (uint32_t i = 0; i < inst.num_operands; ++i) {
+    const auto& current_payload = inst.operands[i];
+    std::vector<uint32_t> words(
+        inst.words + current_payload.offset,
+        inst.words + current_payload.offset + current_payload.num_words);
+    operands_.emplace_back(current_payload.type, std::move(words));
+  }
+}
+
+// Creates an instruction with opcode |op|. A non-zero |ty_id| / |res_id| is
+// stored as the leading type-id / result-id operand, followed by a copy of
+// |in_operands|.
+Instruction::Instruction(IRContext* c, SpvOp op, uint32_t ty_id,
+                         uint32_t res_id, const OperandList& in_operands)
+    : utils::IntrusiveNodeBase<Instruction>(),
+      context_(c),
+      opcode_(op),
+      has_type_id_(ty_id != 0),
+      has_result_id_(res_id != 0),
+      unique_id_(c->TakeNextUniqueId()),
+      operands_() {
+  if (has_type_id_) {
+    operands_.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_TYPE_ID,
+                           std::initializer_list<uint32_t>{ty_id});
+  }
+  if (has_result_id_) {
+    operands_.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_RESULT_ID,
+                           std::initializer_list<uint32_t>{res_id});
+  }
+  operands_.insert(operands_.end(), in_operands.begin(), in_operands.end());
+}
+
+// Move constructor. The list linkage is not transferred: the new node starts
+// with a default-constructed IntrusiveNodeBase. The owning context is carried
+// over so that context() is usable on the new object.
+Instruction::Instruction(Instruction&& that)
+    : utils::IntrusiveNodeBase<Instruction>(),
+      context_(that.context_),
+      opcode_(that.opcode_),
+      has_type_id_(that.has_type_id_),
+      has_result_id_(that.has_result_id_),
+      unique_id_(that.unique_id_),
+      operands_(std::move(that.operands_)),
+      dbg_line_insts_(std::move(that.dbg_line_insts_)) {}
+
+// Move assignment. Takes over the context, opcode, ids, operands and debug
+// line instructions of |that|.
+Instruction& Instruction::operator=(Instruction&& that) {
+  context_ = that.context_;
+  opcode_ = that.opcode_;
+  has_type_id_ = that.has_type_id_;
+  has_result_id_ = that.has_result_id_;
+  unique_id_ = that.unique_id_;
+  operands_ = std::move(that.operands_);
+  dbg_line_insts_ = std::move(that.dbg_line_insts_);
+  return *this;
+}
+
+// Returns a heap-allocated copy of this instruction that belongs to |c| and
+// carries a new unique id. The caller owns the result.
+Instruction* Instruction::Clone(IRContext* c) const {
+  Instruction* clone = new Instruction(c);
+  clone->opcode_ = opcode_;
+  clone->has_type_id_ = has_type_id_;
+  clone->has_result_id_ = has_result_id_;
+  // NOTE(review): the constructor above already drew a unique id from |c|;
+  // this draws a second one. Kept as-is so id numbering does not change.
+  clone->unique_id_ = c->TakeNextUniqueId();
+  clone->operands_ = operands_;
+  clone->dbg_line_insts_ = dbg_line_insts_;
+  return clone;
+}
+
+// Returns the single word of the operand at |index|; asserts that the operand
+// is exactly one word long.
+uint32_t Instruction::GetSingleWordOperand(uint32_t index) const {
+  const auto& words = GetOperand(index).words;
+  assert(words.size() == 1 && "expected the operand only taking one word");
+  return words.front();
+}
+
+// Returns the total number of words in the operands that follow the type-id
+// and result-id operands.
+uint32_t Instruction::NumInOperandWords() const {
+  uint32_t size = 0;
+  for (uint32_t i = TypeResultIdCount(); i < operands_.size(); ++i)
+    size += static_cast<uint32_t>(operands_[i].words.size());
+  return size;
+}
+
+// Appends the binary encoding of this instruction (word count and opcode in
+// the first word, then all operand words) to |binary|.
+void Instruction::ToBinaryWithoutAttachedDebugInsts(
+    std::vector<uint32_t>* binary) const {
+  const uint32_t num_words = 1 + NumOperandWords();
+  binary->push_back((num_words << 16) | static_cast<uint32_t>(opcode_));
+  for (const auto& operand : operands_)
+    binary->insert(binary->end(), operand.words.begin(), operand.words.end());
+}
+
+// Replaces the whole operand list with a copy of |new_operands|.
+void Instruction::ReplaceOperands(const OperandList& new_operands) {
+  operands_.clear();
+  operands_.insert(operands_.begin(), new_operands.begin(), new_operands.end());
+}
+
+// Returns true if this is a load whose base address is an OpVariable that
+// IsReadOnlyVariable() reports as read-only.
+bool Instruction::IsReadOnlyLoad() const {
+  if (IsLoad()) {
+    Instruction* address_def = GetBaseAddress();
+    if (!address_def || address_def->opcode() != SpvOpVariable) {
+      return false;
+    }
+    return address_def->IsReadOnlyVariable();
+  }
+  return false;
+}
+
+// Follows in-operand 0 through access chains, OpImageTexelPointer and
+// OpCopyObject and returns the first defining instruction that is none of
+// those.
+Instruction* Instruction::GetBaseAddress() const {
+  assert((IsLoad() || opcode() == SpvOpStore || opcode() == SpvOpAccessChain ||
+          opcode() == SpvOpPtrAccessChain ||
+          opcode() == SpvOpInBoundsAccessChain || opcode() == SpvOpCopyObject ||
+          opcode() == SpvOpImageTexelPointer) &&
+         "GetBaseAddress should only be called on instructions that take a "
+         "pointer or image.");
+  uint32_t base = GetSingleWordInOperand(kLoadBaseIndex);
+  Instruction* base_inst = context()->get_def_use_mgr()->GetDef(base);
+  bool done = false;
+  while (!done) {
+    switch (base_inst->opcode()) {
+      case SpvOpAccessChain:
+      case SpvOpInBoundsAccessChain:
+      case SpvOpPtrAccessChain:
+      case SpvOpInBoundsPtrAccessChain:
+      case SpvOpImageTexelPointer:
+      case SpvOpCopyObject:
+        // All of these instructions take their base pointer in
+        // in-operand 0.
+        base = base_inst->GetSingleWordInOperand(0);
+        base_inst = context()->get_def_use_mgr()->GetDef(base);
+        break;
+      default:
+        done = true;
+        break;
+    }
+  }
+
+  switch (opcode()) {
+    case SpvOpLoad:
+    case SpvOpStore:
+    case SpvOpAccessChain:
+    case SpvOpInBoundsAccessChain:
+    case SpvOpPtrAccessChain:
+    case SpvOpImageTexelPointer:
+    case SpvOpCopyObject:
+      // A load or store through a pointer.
+      assert(base_inst->IsValidBasePointer() &&
+             "We cannot have a base pointer come from this load");
+      break;
+    default:
+      // A load or store of an image.
+      assert(base_inst->IsValidBaseImage() && "We are expecting an image.");
+      break;
+  }
+  return base_inst;
+}
+
+// Dispatches to the shader or kernel rules depending on whether the module
+// declares the Shader capability.
+bool Instruction::IsReadOnlyVariable() const {
+  if (context()->get_feature_mgr()->HasCapability(SpvCapabilityShader))
+    return IsReadOnlyVariableShaders();
+  else
+    return IsReadOnlyVariableKernel();
+}
+
+// Returns true if this is a UniformConstant pointer to a non-Buffer OpTypeImage
+// whose Sampled operand is not 1.
+bool Instruction::IsVulkanStorageImage() const {
+  if (opcode() != SpvOpTypePointer) {
+    return false;
+  }
+
+  uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex);
+  if (storage_class != SpvStorageClassUniformConstant) {
+    return false;
+  }
+
+  Instruction* base_type =
+      context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1));
+  if (base_type->opcode() != SpvOpTypeImage) {
+    return false;
+  }
+
+  if (base_type->GetSingleWordInOperand(kTypeImageDimIndex) == SpvDimBuffer) {
+    return false;
+  }
+
+  // Check if the image is sampled.  If we do not know for sure that it is,
+  // then assume it is a storage image.
+  auto s = base_type->GetSingleWordInOperand(kTypeImageSampledIndex);
+  return s != 1;
+}
+
+// Returns true if this is a UniformConstant pointer to a non-Buffer OpTypeImage
+// whose Sampled operand is 1.
+bool Instruction::IsVulkanSampledImage() const {
+  if (opcode() != SpvOpTypePointer) {
+    return false;
+  }
+
+  uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex);
+  if (storage_class != SpvStorageClassUniformConstant) {
+    return false;
+  }
+
+  Instruction* base_type =
+      context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1));
+  if (base_type->opcode() != SpvOpTypeImage) {
+    return false;
+  }
+
+  if (base_type->GetSingleWordInOperand(kTypeImageDimIndex) == SpvDimBuffer) {
+    return false;
+  }
+
+  // Check if the image is sampled.  If we know for sure that it is,
+  // then return true.
+  auto s = base_type->GetSingleWordInOperand(kTypeImageSampledIndex);
+  return s == 1;
+}
+
+// Returns true if this is a UniformConstant pointer to an OpTypeImage with
+// Dim Buffer whose Sampled operand is not 1.
+bool Instruction::IsVulkanStorageTexelBuffer() const {
+  if (opcode() != SpvOpTypePointer) {
+    return false;
+  }
+
+  uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex);
+  if (storage_class != SpvStorageClassUniformConstant) {
+    return false;
+  }
+
+  Instruction* base_type =
+      context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1));
+  if (base_type->opcode() != SpvOpTypeImage) {
+    return false;
+  }
+
+  if (base_type->GetSingleWordInOperand(kTypeImageDimIndex) != SpvDimBuffer) {
+    return false;
+  }
+
+  // Check if the image is sampled.  If we do not know for sure that it is,
+  // then assume it is a storage texel buffer.
+  return base_type->GetSingleWordInOperand(kTypeImageSampledIndex) != 1;
+}
+
+// Returns true if this is a pointer to a struct that is either in the Uniform
+// storage class and decorated BufferBlock, or in the StorageBuffer storage
+// class and decorated Block.
+bool Instruction::IsVulkanStorageBuffer() const {
+  // Is there a difference between a "Storage buffer" and a "dynamic storage
+  // buffer" in SPIR-V and do we care about the difference?
+  if (opcode() != SpvOpTypePointer) {
+    return false;
+  }
+
+  Instruction* base_type =
+      context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1));
+
+  if (base_type->opcode() != SpvOpTypeStruct) {
+    return false;
+  }
+
+  uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex);
+  if (storage_class == SpvStorageClassUniform) {
+    bool is_buffer_block = false;
+    context()->get_decoration_mgr()->ForEachDecoration(
+        base_type->result_id(), SpvDecorationBufferBlock,
+        [&is_buffer_block](const Instruction&) { is_buffer_block = true; });
+    return is_buffer_block;
+  } else if (storage_class == SpvStorageClassStorageBuffer) {
+    bool is_block = false;
+    context()->get_decoration_mgr()->ForEachDecoration(
+        base_type->result_id(), SpvDecorationBlock,
+        [&is_block](const Instruction&) { is_block = true; });
+    return is_block;
+  }
+  return false;
+}
+
+// Returns true if this is a Uniform pointer to a struct decorated Block.
+bool Instruction::IsVulkanUniformBuffer() const {
+  if (opcode() != SpvOpTypePointer) {
+    return false;
+  }
+
+  uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex);
+  if (storage_class != SpvStorageClassUniform) {
+    return false;
+  }
+
+  Instruction* base_type =
+      context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(1));
+  if (base_type->opcode() != SpvOpTypeStruct) {
+    return false;
+  }
+
+  bool is_block = false;
+  context()->get_decoration_mgr()->ForEachDecoration(
+      base_type->result_id(), SpvDecorationBlock,
+      [&is_block](const Instruction&) { is_block = true; });
+  return is_block;
+}
+
+// Shader rules: read-only by storage class (UniformConstant that is not a
+// storage image or storage texel buffer, Uniform that is not a storage buffer,
+// PushConstant, Input), otherwise read-only only if decorated NonWritable.
+bool Instruction::IsReadOnlyVariableShaders() const {
+  uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex);
+  Instruction* type_def = context()->get_def_use_mgr()->GetDef(type_id());
+
+  switch (storage_class) {
+    case SpvStorageClassUniformConstant:
+      if (!type_def->IsVulkanStorageImage() &&
+          !type_def->IsVulkanStorageTexelBuffer()) {
+        return true;
+      }
+      break;
+    case SpvStorageClassUniform:
+      if (!type_def->IsVulkanStorageBuffer()) {
+        return true;
+      }
+      break;
+    case SpvStorageClassPushConstant:
+    case SpvStorageClassInput:
+      return true;
+    default:
+      break;
+  }
+
+  bool is_nonwritable = false;
+  context()->get_decoration_mgr()->ForEachDecoration(
+      result_id(), SpvDecorationNonWritable,
+      [&is_nonwritable](const Instruction&) { is_nonwritable = true; });
+  return is_nonwritable;
+}
+
+// Kernel rules: only UniformConstant variables are read-only.
+bool Instruction::IsReadOnlyVariableKernel() const {
+  uint32_t storage_class = GetSingleWordInOperand(kVariableStorageClassIndex);
+  return storage_class == SpvStorageClassUniformConstant;
+}
+
+// Returns the type id of component |element| of this type instruction, or 0
+// if the opcode is not a struct, array, runtime array, vector or matrix type.
+uint32_t Instruction::GetTypeComponent(uint32_t element) const {
+  uint32_t subtype = 0;
+  switch (opcode()) {
+    case SpvOpTypeStruct:
+      subtype = GetSingleWordInOperand(element);
+      break;
+    case SpvOpTypeArray:
+    case SpvOpTypeRuntimeArray:
+    case SpvOpTypeVector:
+    case SpvOpTypeMatrix:
+      // These types all have uniform subtypes.
+      subtype = GetSingleWordInOperand(0u);
+      break;
+    default:
+      break;
+  }
+
+  return subtype;
+}
+
+// Inserts every instruction in |list| before this instruction, releasing
+// ownership from the unique_ptrs, and returns the first one. Returns nullptr
+// and inserts nothing when |list| is empty.
+Instruction* Instruction::InsertBefore(
+    std::vector<std::unique_ptr<Instruction>>&& list) {
+  // front() on an empty vector is undefined behaviour.
+  if (list.empty()) {
+    return nullptr;
+  }
+  Instruction* first_node = list.front().get();
+  for (auto& i : list) {
+    i.release()->InsertBefore(this);
+  }
+  list.clear();
+  return first_node;
+}
+
+// Inserts |i| before this instruction, releasing ownership from the
+// unique_ptr, and returns the inserted instruction.
+Instruction* Instruction::InsertBefore(std::unique_ptr<Instruction>&& i) {
+  i.get()->InsertBefore(this);
+  return i.release();
+}
+
+// Returns true if the result of this instruction may serve as the base pointer
+// of a memory access under the rules checked below.
+bool Instruction::IsValidBasePointer() const {
+  uint32_t tid = type_id();
+  if (tid == 0) {
+    return false;
+  }
+
+  Instruction* type = context()->get_def_use_mgr()->GetDef(tid);
+  if (type->opcode() != SpvOpTypePointer) {
+    return false;
+  }
+
+  auto feature_mgr = context()->get_feature_mgr();
+  if (feature_mgr->HasCapability(SpvCapabilityAddresses)) {
+    // TODO: The rules here could be more restrictive.
+    return true;
+  }
+
+  if (opcode() == SpvOpVariable || opcode() == SpvOpFunctionParameter) {
+    return true;
+  }
+
+  // With variable pointers, there are more valid base pointer objects.
+  // Variable pointers implicitly declares Variable pointers storage buffer.
+  SpvStorageClass storage_class =
+      static_cast<SpvStorageClass>(type->GetSingleWordInOperand(0));
+  if ((feature_mgr->HasCapability(SpvCapabilityVariablePointersStorageBuffer) &&
+       storage_class == SpvStorageClassStorageBuffer) ||
+      (feature_mgr->HasCapability(SpvCapabilityVariablePointers) &&
+       storage_class == SpvStorageClassWorkgroup)) {
+    switch (opcode()) {
+      case SpvOpPhi:
+      case SpvOpSelect:
+      case SpvOpFunctionCall:
+      case SpvOpConstantNull:
+        return true;
+      default:
+        break;
+    }
+  }
+
+  uint32_t pointee_type_id = type->GetSingleWordInOperand(1);
+  Instruction* pointee_type_inst =
+      context()->get_def_use_mgr()->GetDef(pointee_type_id);
+
+  if (pointee_type_inst->IsOpaqueType()) {
+    return true;
+  }
+  return false;
+}
+
+// Returns true if the result type of this instruction is OpTypeImage or
+// OpTypeSampledImage.
+bool Instruction::IsValidBaseImage() const {
+  uint32_t tid = type_id();
+  if (tid == 0) {
+    return false;
+  }
+
+  Instruction* type = context()->get_def_use_mgr()->GetDef(tid);
+  return (type->opcode() == SpvOpTypeImage ||
+          type->opcode() == SpvOpTypeSampledImage);
+}
+
+// Returns true if this type is a base opaque type or a runtime array, or an
+// array / struct that contains such a type.
+bool Instruction::IsOpaqueType() const {
+  if (opcode() == SpvOpTypeStruct) {
+    bool is_opaque = false;
+    ForEachInOperand([&is_opaque, this](const uint32_t* op_id) {
+      Instruction* type_inst = context()->get_def_use_mgr()->GetDef(*op_id);
+      is_opaque |= type_inst->IsOpaqueType();
+    });
+    return is_opaque;
+  } else if (opcode() == SpvOpTypeArray) {
+    uint32_t sub_type_id = GetSingleWordInOperand(0);
+    Instruction* sub_type_inst =
+        context()->get_def_use_mgr()->GetDef(sub_type_id);
+    return sub_type_inst->IsOpaqueType();
+  } else {
+    return opcode() == SpvOpTypeRuntimeArray ||
+           spvOpcodeIsBaseOpaqueType(opcode());
+  }
+}
+
+// Returns true if the instruction can be folded by FoldScalar or has a
+// constant folding rule registered for its opcode.
+bool Instruction::IsFoldable() const {
+  return IsFoldableByFoldScalar() ||
+         context()->get_instruction_folder().HasConstFoldingRule(opcode());
+}
+
+// Returns true if both the opcode and the result type are accepted by the
+// context's instruction folder.
+bool Instruction::IsFoldableByFoldScalar() const {
+  const InstructionFolder& folder = context()->get_instruction_folder();
+  if (!folder.IsFoldableOpcode(opcode())) {
+    return false;
+  }
+  Instruction* type = context()->get_def_use_mgr()->GetDef(type_id());
+  return folder.IsFoldableType(type);
+}
+
+// Returns true if the module is a shader and the result is not decorated
+// NoContraction.
+bool Instruction::IsFloatingPointFoldingAllowed() const {
+  // TODO: Add the rules for kernels.  For now it will be pessimistic.
+  if (!context_->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
+    return false;
+  }
+
+  bool is_nocontract = false;
+  context_->get_decoration_mgr()->WhileEachDecoration(
+      result_id(), SpvDecorationNoContraction,
+      [&is_nocontract](const Instruction&) {
+        is_nocontract = true;
+        return false;
+      });
+  return !is_nocontract;
+}
+
+// Returns the disassembly of this instruction. |options| is forwarded to the
+// disassembler with SPV_BINARY_TO_TEXT_OPTION_NO_HEADER added.
+std::string Instruction::PrettyPrint(uint32_t options) const {
+  // Convert the module to binary.
+  std::vector<uint32_t> module_binary;
+  context()->module()->ToBinary(&module_binary, /* skip_nop = */ false);
+
+  // Convert the instruction to binary. This is used to identify the correct
+  // stream of words to output from the module.
+  std::vector<uint32_t> inst_binary;
+  ToBinaryWithoutAttachedDebugInsts(&inst_binary);
+
+  // Do not generate a header.
+  return spvInstructionBinaryToText(
+      context()->grammar().target_env(), inst_binary.data(), inst_binary.size(),
+      module_binary.data(), module_binary.size(),
+      options | SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+}
+
+// Streams the PrettyPrint() form of |inst| to |str|.
+std::ostream& operator<<(std::ostream& str, const Instruction& inst) {
+  str << inst.PrettyPrint();
+  return str;
+}
+
+// Writes the unique id and the disassembly of this instruction to std::cerr.
+void Instruction::Dump() const {
+  std::cerr << "Instruction #" << unique_id() << "\n" << *this << "\n";
+}
+
+// Returns true if the opcode is in the list below of opcodes considered safe
+// for code motion.
+bool Instruction::IsOpcodeCodeMotionSafe() const {
+  switch (opcode_) {
+    case SpvOpNop:
+    case SpvOpUndef:
+    case SpvOpLoad:
+    case SpvOpAccessChain:
+    case SpvOpInBoundsAccessChain:
+    case SpvOpArrayLength:
+    case SpvOpVectorExtractDynamic:
+    case SpvOpVectorInsertDynamic:
+    case SpvOpVectorShuffle:
+    case SpvOpCompositeConstruct:
+    case SpvOpCompositeExtract:
+    case SpvOpCompositeInsert:
+    case SpvOpCopyObject:
+    case SpvOpTranspose:
+    case SpvOpConvertFToU:
+    case SpvOpConvertFToS:
+    case SpvOpConvertSToF:
+    case SpvOpConvertUToF:
+    case SpvOpUConvert:
+    case SpvOpSConvert:
+    case SpvOpFConvert:
+    case SpvOpQuantizeToF16:
+    case SpvOpBitcast:
+    case SpvOpSNegate:
+    case SpvOpFNegate:
+    case SpvOpIAdd:
+    case SpvOpFAdd:
+    case SpvOpISub:
+    case SpvOpFSub:
+    case SpvOpIMul:
+    case SpvOpFMul:
+    case SpvOpUDiv:
+    case SpvOpSDiv:
+    case SpvOpFDiv:
+    case SpvOpUMod:
+    case SpvOpSRem:
+    case SpvOpSMod:
+    case SpvOpFRem:
+    case SpvOpFMod:
+    case SpvOpVectorTimesScalar:
+    case SpvOpMatrixTimesScalar:
+    case SpvOpVectorTimesMatrix:
+    case SpvOpMatrixTimesVector:
+    case SpvOpMatrixTimesMatrix:
+    case SpvOpOuterProduct:
+    case SpvOpDot:
+    case SpvOpIAddCarry:
+    case SpvOpISubBorrow:
+    case SpvOpUMulExtended:
+    case SpvOpSMulExtended:
+    case SpvOpAny:
+    case SpvOpAll:
+    case SpvOpIsNan:
+    case SpvOpIsInf:
+    case SpvOpLogicalEqual:
+    case SpvOpLogicalNotEqual:
+    case SpvOpLogicalOr:
+    case SpvOpLogicalAnd:
+    case SpvOpLogicalNot:
+    case SpvOpSelect:
+    case SpvOpIEqual:
+    case SpvOpINotEqual:
+    case SpvOpUGreaterThan:
+    case SpvOpSGreaterThan:
+    case SpvOpUGreaterThanEqual:
+    case SpvOpSGreaterThanEqual:
+    case SpvOpULessThan:
+    case SpvOpSLessThan:
+    case SpvOpULessThanEqual:
+    case SpvOpSLessThanEqual:
+    case SpvOpFOrdEqual:
+    case SpvOpFUnordEqual:
+    case SpvOpFOrdNotEqual:
+    case SpvOpFUnordNotEqual:
+    case SpvOpFOrdLessThan:
+    case SpvOpFUnordLessThan:
+    case SpvOpFOrdGreaterThan:
+    case SpvOpFUnordGreaterThan:
+    case SpvOpFOrdLessThanEqual:
+    case SpvOpFUnordLessThanEqual:
+    case SpvOpFOrdGreaterThanEqual:
+    case SpvOpFUnordGreaterThanEqual:
+    case SpvOpShiftRightLogical:
+    case SpvOpShiftRightArithmetic:
+    case SpvOpShiftLeftLogical:
+    case SpvOpBitwiseOr:
+    case SpvOpBitwiseXor:
+    case SpvOpBitwiseAnd:
+    case SpvOpNot:
+    case SpvOpBitFieldInsert:
+    case SpvOpBitFieldSExtract:
+    case SpvOpBitFieldUExtract:
+    case SpvOpBitReverse:
+    case SpvOpBitCount:
+    case SpvOpSizeOf:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Returns true if spvOpcodeIsScalarizable() accepts the opcode, or if this is
+// an OpExtInst from the GLSL.std.450 set whose instruction is listed below.
+bool Instruction::IsScalarizable() const {
+  if (spvOpcodeIsScalarizable(opcode())) {
+    return true;
+  }
+
+  const uint32_t kExtInstSetIdInIdx = 0;
+  const uint32_t kExtInstInstructionInIdx = 1;
+
+  if (opcode() == SpvOpExtInst) {
+    uint32_t instSetId =
+        context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+
+    if (GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId) {
+      switch (GetSingleWordInOperand(kExtInstInstructionInIdx)) {
+        case GLSLstd450Round:
+        case GLSLstd450RoundEven:
+        case GLSLstd450Trunc:
+        case GLSLstd450FAbs:
+        case GLSLstd450SAbs:
+        case GLSLstd450FSign:
+        case GLSLstd450SSign:
+        case GLSLstd450Floor:
+        case GLSLstd450Ceil:
+        case GLSLstd450Fract:
+        case GLSLstd450Radians:
+        case GLSLstd450Degrees:
+        case GLSLstd450Sin:
+        case GLSLstd450Cos:
+        case GLSLstd450Tan:
+        case GLSLstd450Asin:
+        case GLSLstd450Acos:
+        case GLSLstd450Atan:
+        case GLSLstd450Sinh:
+        case GLSLstd450Cosh:
+        case GLSLstd450Tanh:
+        case GLSLstd450Asinh:
+        case GLSLstd450Acosh:
+        case GLSLstd450Atanh:
+        case GLSLstd450Atan2:
+        case GLSLstd450Pow:
+        case GLSLstd450Exp:
+        case GLSLstd450Log:
+        case GLSLstd450Exp2:
+        case GLSLstd450Log2:
+        case GLSLstd450Sqrt:
+        case GLSLstd450InverseSqrt:
+        case GLSLstd450Modf:
+        case GLSLstd450FMin:
+        case GLSLstd450UMin:
+        case GLSLstd450SMin:
+        case GLSLstd450FMax:
+        case GLSLstd450UMax:
+        case GLSLstd450SMax:
+        case GLSLstd450FClamp:
+        case GLSLstd450UClamp:
+        case GLSLstd450SClamp:
+        case GLSLstd450FMix:
+        case GLSLstd450Step:
+        case GLSLstd450SmoothStep:
+        case GLSLstd450Fma:
+        case GLSLstd450Frexp:
+        case GLSLstd450Ldexp:
+        case GLSLstd450FindILsb:
+        case GLSLstd450FindSMsb:
+        case GLSLstd450FindUMsb:
+        case GLSLstd450NMin:
+        case GLSLstd450NMax:
+        case GLSLstd450NClamp:
+          return true;
+        default:
+          return false;
+      }
+    }
+  }
+  return false;
+}
+
+// Returns true if the context classifies this as a combinator instruction, or
+// if the opcode is one of the derivative / OpImageQueryLod opcodes below.
+bool Instruction::IsOpcodeSafeToDelete() const {
+  if (context()->IsCombinatorInstruction(this)) {
+    return true;
+  }
+
+  switch (opcode()) {
+    case SpvOpDPdx:
+    case SpvOpDPdy:
+    case SpvOpFwidth:
+    case SpvOpDPdxFine:
+    case SpvOpDPdyFine:
+    case SpvOpFwidthFine:
+    case SpvOpDPdxCoarse:
+    case SpvOpDPdyCoarse:
+    case SpvOpFwidthCoarse:
+    case SpvOpImageQueryLod:
+      return true;
+    default:
+      return false;
+  }
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/instruction.h b/source/opt/instruction.h
new file mode 100644
index 000000000..034da76f4
--- /dev/null
+++ b/source/opt/instruction.h
@@ -0,0 +1,733 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_INSTRUCTION_H_
+#define SOURCE_OPT_INSTRUCTION_H_
+
+#include <cassert>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "source/opcode.h"
+#include "source/operand.h"
+#include "source/util/ilist_node.h"
+#include "source/util/small_vector.h"
+
+#include "source/latest_version_glsl_std_450_header.h"
+#include "source/latest_version_spirv_header.h"
+#include "source/opt/reflect.h"
+#include "spirv-tools/libspirv.h"
+
+namespace spvtools {
+namespace opt {
+
+class Function;
+class IRContext;
+class Module;
+class InstructionList;
+
+// Relaxed logical addressing:
+//
+// In the logical addressing model, pointers cannot be stored or loaded.  This
+// is a useful assumption because it simplifies the aliasing significantly.
+// However, for the purpose of legalizing code generated from HLSL, we will have
+// to allow storing and loading of pointers to opaque objects and runtime
+// arrays.
This relaxation of the rule still implies that function and private +// scope variables do not have any aliasing, so we can treat them as before. +// This will be called the relaxed logical addressing model. +// +// This relaxation of the rule will be allowed by |GetBaseAddress|, but it will +// enforce that no other pointers are stored or loaded. + +// About operand: +// +// In the SPIR-V specification, the term "operand" is used to mean any single +// SPIR-V word following the leading wordcount-opcode word. Here, the term +// "operand" is used to mean a *logical* operand. A logical operand may consist +// of multiple SPIR-V words, which together make up the same component. For +// example, a logical operand of a 64-bit integer needs two words to express. +// +// Further, we categorize logical operands into *in* and *out* operands. +// In operands are operands that serve as input to operations, while out +// operands are operands that represent ids generated from operations (result +// type id or result id). For example, for "OpIAdd %rtype %rid %inop1 %inop2", +// "%inop1" and "%inop2" are in operands, while "%rtype" and "%rid" are out +// operands. + +// A *logical* operand to a SPIR-V instruction. It can be the type id, result +// id, or other additional operands carried in an instruction. +struct Operand { + using OperandData = utils::SmallVector; + Operand(spv_operand_type_t t, OperandData&& w) + : type(t), words(std::move(w)) {} + + Operand(spv_operand_type_t t, const OperandData& w) : type(t), words(w) {} + + spv_operand_type_t type; // Type of this logical operand. + OperandData words; // Binary segments of this logical operand. + + friend bool operator==(const Operand& o1, const Operand& o2) { + return o1.type == o2.type && o1.words == o2.words; + } + + // TODO(antiagainst): create fields for literal number kind, width, etc. +}; + +inline bool operator!=(const Operand& o1, const Operand& o2) { + return !(o1 == o2); +} + +// A SPIR-V instruction.
It contains the opcode and any additional logical +// operand, including the result id (if any) and result type id (if any). It +// may also contain line-related debug instruction (OpLine, OpNoLine) directly +// appearing before this instruction. Note that the result id of an instruction +// should never change after the instruction being built. If the result id +// needs to change, the user should create a new instruction instead. +class Instruction : public utils::IntrusiveNodeBase { + public: + using OperandList = std::vector; + using iterator = OperandList::iterator; + using const_iterator = OperandList::const_iterator; + + // Creates a default OpNop instruction. + // This exists solely for containers that can't do without. Should be removed. + Instruction() + : utils::IntrusiveNodeBase(), + context_(nullptr), + opcode_(SpvOpNop), + has_type_id_(false), + has_result_id_(false), + unique_id_(0) {} + + // Creates a default OpNop instruction. + Instruction(IRContext*); + // Creates an instruction with the given opcode |op| and no additional logical + // operands. + Instruction(IRContext*, SpvOp); + // Creates an instruction using the given spv_parsed_instruction_t |inst|. All + // the data inside |inst| will be copied and owned in this instance. And keep + // record of line-related debug instructions |dbg_line| ahead of this + // instruction, if any. + Instruction(IRContext* c, const spv_parsed_instruction_t& inst, + std::vector&& dbg_line = {}); + + // Creates an instruction with the given opcode |op|, type id: |ty_id|, + // result id: |res_id| and input operands: |in_operands|. + Instruction(IRContext* c, SpvOp op, uint32_t ty_id, uint32_t res_id, + const OperandList& in_operands); + + // TODO: I will want to remove these, but will first have to remove the use of + // std::vector. 
+ Instruction(const Instruction&) = default; + Instruction& operator=(const Instruction&) = default; + + Instruction(Instruction&&); + Instruction& operator=(Instruction&&); + + virtual ~Instruction() = default; + + // Returns a newly allocated instruction that has the same operands, result, + // and type as |this|. The new instruction is not linked into any list. + // It is the responsibility of the caller to make sure that the storage is + // removed. It is the caller's responsibility to make sure that there is only + // one instruction for each result id. + Instruction* Clone(IRContext* c) const; + + IRContext* context() const { return context_; } + + SpvOp opcode() const { return opcode_; } + // Sets the opcode of this instruction to a specific opcode. Note this may + // invalidate the instruction. + // TODO(qining): Remove this function when instruction building and insertion + // is well implemented. + void SetOpcode(SpvOp op) { opcode_ = op; } + uint32_t type_id() const { + return has_type_id_ ? GetSingleWordOperand(0) : 0; + } + uint32_t result_id() const { + return has_result_id_ ? GetSingleWordOperand(has_type_id_ ? 1 : 0) : 0; + } + uint32_t unique_id() const { + assert(unique_id_ != 0); + return unique_id_; + } + // Returns the vector of line-related debug instructions attached to this + // instruction and the caller can directly modify them. + std::vector& dbg_line_insts() { return dbg_line_insts_; } + const std::vector& dbg_line_insts() const { + return dbg_line_insts_; + } + + // Same semantics as in the base class except the list the InstructionList + // containing |pos| will now assume ownership of |this|. + // inline void MoveBefore(Instruction* pos); + // inline void InsertAfter(Instruction* pos); + + // Begin and end iterators for operands. 
+ iterator begin() { return operands_.begin(); } + iterator end() { return operands_.end(); } + const_iterator begin() const { return operands_.cbegin(); } + const_iterator end() const { return operands_.cend(); } + // Const begin and end iterators for operands. + const_iterator cbegin() const { return operands_.cbegin(); } + const_iterator cend() const { return operands_.cend(); } + + // Gets the number of logical operands. + uint32_t NumOperands() const { + return static_cast(operands_.size()); + } + // Gets the number of SPIR-V words occupied by all logical operands. + uint32_t NumOperandWords() const { + return NumInOperandWords() + TypeResultIdCount(); + } + // Gets the |index|-th logical operand. + inline Operand& GetOperand(uint32_t index); + inline const Operand& GetOperand(uint32_t index) const; + // Adds |operand| to the list of operands of this instruction. + // It is the responsibility of the caller to make sure + // that the instruction remains valid. + inline void AddOperand(Operand&& operand); + // Gets the |index|-th logical operand as a single SPIR-V word. This method is + // not expected to be used with logical operands consisting of multiple SPIR-V + // words. + uint32_t GetSingleWordOperand(uint32_t index) const; + // Sets the |index|-th in-operand's data to the given |data|. + inline void SetInOperand(uint32_t index, Operand::OperandData&& data); + // Sets the |index|-th operand's data to the given |data|. + // This is for in-operands modification only, but with |index| expressed in + // terms of operand index rather than in-operand index. + inline void SetOperand(uint32_t index, Operand::OperandData&& data); + // Replace all of the in operands with those in |new_operands|. + inline void SetInOperands(OperandList&& new_operands); + // Sets the result type id. 
+ inline void SetResultType(uint32_t ty_id); + // Sets the result id + inline void SetResultId(uint32_t res_id); + inline bool HasResultId() const { return has_result_id_; } + // Remove the |index|-th operand + void RemoveOperand(uint32_t index) { + operands_.erase(operands_.begin() + index); + } + + // The following methods are similar to the above, but are for in operands. + uint32_t NumInOperands() const { + return static_cast(operands_.size() - TypeResultIdCount()); + } + uint32_t NumInOperandWords() const; + Operand& GetInOperand(uint32_t index) { + return GetOperand(index + TypeResultIdCount()); + } + const Operand& GetInOperand(uint32_t index) const { + return GetOperand(index + TypeResultIdCount()); + } + uint32_t GetSingleWordInOperand(uint32_t index) const { + return GetSingleWordOperand(index + TypeResultIdCount()); + } + void RemoveInOperand(uint32_t index) { + operands_.erase(operands_.begin() + index + TypeResultIdCount()); + } + + // Returns true if this instruction is OpNop. + inline bool IsNop() const; + // Turns this instruction to OpNop. This does not clear out all preceding + // line-related debug instructions. + inline void ToNop(); + + // Runs the given function |f| on this instruction and optionally on the + // preceding debug line instructions. The function will always be run + // if this is itself a debug line instruction. + inline void ForEachInst(const std::function& f, + bool run_on_debug_line_insts = false); + inline void ForEachInst(const std::function& f, + bool run_on_debug_line_insts = false) const; + + // Runs the given function |f| on this instruction and optionally on the + // preceding debug line instructions. The function will always be run + // if this is itself a debug line instruction. If |f| returns false, + // iteration is terminated and this function returns false. 
+ inline bool WhileEachInst(const std::function& f, + bool run_on_debug_line_insts = false); + inline bool WhileEachInst(const std::function& f, + bool run_on_debug_line_insts = false) const; + + // Runs the given function |f| on all operand ids. + // + // |f| should not transform an ID into 0, as 0 is an invalid ID. + inline void ForEachId(const std::function& f); + inline void ForEachId(const std::function& f) const; + + // Runs the given function |f| on all "in" operand ids. + inline void ForEachInId(const std::function& f); + inline void ForEachInId(const std::function& f) const; + + // Runs the given function |f| on all "in" operand ids. If |f| returns false, + // iteration is terminated and this function returns false. + inline bool WhileEachInId(const std::function& f); + inline bool WhileEachInId( + const std::function& f) const; + + // Runs the given function |f| on all "in" operands. + inline void ForEachInOperand(const std::function& f); + inline void ForEachInOperand( + const std::function& f) const; + + // Runs the given function |f| on all "in" operands. If |f| returns false, + // iteration is terminated and this function return false. + inline bool WhileEachInOperand(const std::function& f); + inline bool WhileEachInOperand( + const std::function& f) const; + + // Returns true if any operands can be labels + inline bool HasLabels() const; + + // Pushes the binary segments for this instruction into the back of *|binary|. + void ToBinaryWithoutAttachedDebugInsts(std::vector* binary) const; + + // Replaces the operands to the instruction with |new_operands|. The caller + // is responsible for building a complete and valid list of operands for + // this instruction. + void ReplaceOperands(const OperandList& new_operands); + + // Returns true if the instruction annotates an id with a decoration. + inline bool IsDecoration() const; + + // Returns true if the instruction is known to be a load from read-only + // memory. 
+ bool IsReadOnlyLoad() const; + + // Returns the instruction that gives the base address of an address + // calculation. The instruction must be a load, as defined by |IsLoad|, + // store, copy, or access chain instruction. In logical addressing mode, will + // return an OpVariable or OpFunctionParameter instruction. For relaxed + // logical addressing, it would also return a load of a pointer to an opaque + // object. For physical addressing mode, could return other types of + // instructions. + Instruction* GetBaseAddress() const; + + // Returns true if the instruction loads from memory or samples an image, and + // stores the result into an id. It considers only core instructions. + // Memory-to-memory instructions are not considered loads. + inline bool IsLoad() const; + + // Returns true if the instruction declares a variable that is read-only. + bool IsReadOnlyVariable() const; + + // The following functions check for the various descriptor types defined in + // the Vulkan specification section 13.1. + + // Returns true if the instruction defines a pointer type that points to a + // storage image. + bool IsVulkanStorageImage() const; + + // Returns true if the instruction defines a pointer type that points to a + // sampled image. + bool IsVulkanSampledImage() const; + + // Returns true if the instruction defines a pointer type that points to a + // storage texel buffer. + bool IsVulkanStorageTexelBuffer() const; + + // Returns true if the instruction defines a pointer type that points to a + // storage buffer. + bool IsVulkanStorageBuffer() const; + + // Returns true if the instruction defines a pointer type that points to a + // uniform buffer. + bool IsVulkanUniformBuffer() const; + + // Returns true if the instruction is an atomic operation that uses original + // value. + inline bool IsAtomicWithLoad() const; + + // Returns true if the instruction is an atomic operation.
+ inline bool IsAtomicOp() const; + + // Returns true if this instruction is a branch or switch instruction (either + // conditional or not). + bool IsBranch() const { return spvOpcodeIsBranch(opcode()); } + + // Returns true if this instruction causes the function to finish execution + // and return to its caller + bool IsReturn() const { return spvOpcodeIsReturn(opcode()); } + + // Returns true if this instruction exits this function or aborts execution. + bool IsReturnOrAbort() const { return spvOpcodeIsReturnOrAbort(opcode()); } + + // Returns the id for the |element|'th subtype. If |this| is not a + // composite type, this function returns 0. + uint32_t GetTypeComponent(uint32_t element) const; + + // Returns true if this instruction is a basic block terminator. + bool IsBlockTerminator() const { + return spvOpcodeIsBlockTerminator(opcode()); + } + + // Returns true if |this| is an instruction that defines an opaque type. Since + // runtime arrays have similar characteristics they are included as opaque + // types. + bool IsOpaqueType() const; + + // Returns true if |this| is an instruction which could be folded into a + // constant value. + bool IsFoldable() const; + + // Returns true if |this| is an instruction which could be folded into a + // constant value by |FoldScalar|. + bool IsFoldableByFoldScalar() const; + + // Returns true if we are allowed to fold or otherwise manipulate the + // instruction that defines |id| in the given context. This includes not + // handling NaN values. + bool IsFloatingPointFoldingAllowed() const; + + inline bool operator==(const Instruction&) const; + inline bool operator!=(const Instruction&) const; + inline bool operator<(const Instruction&) const; + + Instruction* InsertBefore(std::vector>&& list); + Instruction* InsertBefore(std::unique_ptr&& i); + using utils::IntrusiveNodeBase::InsertBefore; + + // Returns true if |this| is an instruction defining a constant, but not a + // Spec constant.
+ inline bool IsConstant() const; + + // Returns true if |this| is an instruction with an opcode safe to move + bool IsOpcodeCodeMotionSafe() const; + + // Pretty-prints |inst|. + // + // Provides the disassembly of a specific instruction. Utilizes |inst|'s + // context to provide the correct interpretation of types, constants, etc. + // + // |options| are the disassembly options. SPV_BINARY_TO_TEXT_OPTION_NO_HEADER + // is always added to |options|. + std::string PrettyPrint(uint32_t options = 0u) const; + + // Returns true if the result can be a vector and the result of each component + // depends on the corresponding component of any vector inputs. + bool IsScalarizable() const; + + // Return true if the only effect of this instructions is the result. + bool IsOpcodeSafeToDelete() const; + + // Returns true if it is valid to use the result of |inst| as the base + // pointer for a load or store. In this case, valid is defined by the relaxed + // logical addressing rules when using logical addressing. Normal validation + // rules for physical addressing. + bool IsValidBasePointer() const; + + // Dump this instruction on stderr. Useful when running interactive + // debuggers. + void Dump() const; + + private: + // Returns the total count of result type id and result id. + uint32_t TypeResultIdCount() const { + if (has_type_id_ && has_result_id_) return 2; + if (has_type_id_ || has_result_id_) return 1; + return 0; + } + + // Returns true if the instruction declares a variable that is read-only. The + // first version assumes the module is a shader module. The second assumes a + // kernel. + bool IsReadOnlyVariableShaders() const; + bool IsReadOnlyVariableKernel() const; + + // Returns true if the result of |inst| can be used as the base image for an + // instruction that samples a image, reads an image, or writes to an image. 
+ bool IsValidBaseImage() const; + + IRContext* context_; // IR Context + SpvOp opcode_; // Opcode + bool has_type_id_; // True if the instruction has a type id + bool has_result_id_; // True if the instruction has a result id + uint32_t unique_id_; // Unique instruction id + // All logical operands, including result type id and result id. + OperandList operands_; + // OpLine and OpNoLine instructions preceding this instruction. Note that for + // Instructions representing OpLine or OpNoLine itself, this field should be + // empty. + std::vector dbg_line_insts_; + + friend InstructionList; +}; + +// Pretty-prints |inst| to |str| and returns |str|. +// +// Provides the disassembly of a specific instruction. Utilizes |inst|'s context +// to provide the correct interpretation of types, constants, etc. +// +// Disassembly uses raw ids (not pretty printed names). +std::ostream& operator<<(std::ostream& str, const Instruction& inst); + +inline bool Instruction::operator==(const Instruction& other) const { + return unique_id() == other.unique_id(); +} + +inline bool Instruction::operator!=(const Instruction& other) const { + return !(*this == other); +} + +inline bool Instruction::operator<(const Instruction& other) const { + return unique_id() < other.unique_id(); +} + +inline Operand& Instruction::GetOperand(uint32_t index) { + assert(index < operands_.size() && "operand index out of bound"); + return operands_[index]; +} + +inline const Operand& Instruction::GetOperand(uint32_t index) const { + assert(index < operands_.size() && "operand index out of bound"); + return operands_[index]; +} + +inline void Instruction::AddOperand(Operand&& operand) { + operands_.push_back(std::move(operand)); +} + +inline void Instruction::SetInOperand(uint32_t index, + Operand::OperandData&& data) { + SetOperand(index + TypeResultIdCount(), std::move(data)); +} + +inline void Instruction::SetOperand(uint32_t index, + Operand::OperandData&& data) { + assert(index < operands_.size() &&
"operand index out of bound"); + assert(index >= TypeResultIdCount() && "operand is not a in-operand"); + operands_[index].words = std::move(data); +} + +inline void Instruction::SetInOperands(OperandList&& new_operands) { + // Remove the old in operands. + operands_.erase(operands_.begin() + TypeResultIdCount(), operands_.end()); + // Add the new in operands. + operands_.insert(operands_.end(), new_operands.begin(), new_operands.end()); +} + +inline void Instruction::SetResultId(uint32_t res_id) { + // TODO(dsinclair): Allow setting a result id if there wasn't one + // previously. Need to make room in the operands_ array to place the result, + // and update the has_result_id_ flag. + assert(has_result_id_); + + // TODO(dsinclair): Allow removing the result id. This needs to make sure, + // if there was a result id previously to remove it from the operands_ array + // and reset the has_result_id_ flag. + assert(res_id != 0); + + auto ridx = has_type_id_ ? 1 : 0; + operands_[ridx].words = {res_id}; +} + +inline void Instruction::SetResultType(uint32_t ty_id) { + // TODO(dsinclair): Allow setting a type id if there wasn't one + // previously. Need to make room in the operands_ array to place the result, + // and update the has_type_id_ flag. + assert(has_type_id_); + + // TODO(dsinclair): Allow removing the type id. This needs to make sure, + // if there was a type id previously to remove it from the operands_ array + // and reset the has_type_id_ flag. 
+ assert(ty_id != 0); + + operands_.front().words = {ty_id}; +} + +inline bool Instruction::IsNop() const { + return opcode_ == SpvOpNop && !has_type_id_ && !has_result_id_ && + operands_.empty(); +} + +inline void Instruction::ToNop() { + opcode_ = SpvOpNop; + has_type_id_ = false; + has_result_id_ = false; + operands_.clear(); +} + +inline bool Instruction::WhileEachInst( + const std::function& f, bool run_on_debug_line_insts) { + if (run_on_debug_line_insts) { + for (auto& dbg_line : dbg_line_insts_) { + if (!f(&dbg_line)) return false; + } + } + return f(this); +} + +inline bool Instruction::WhileEachInst( + const std::function& f, + bool run_on_debug_line_insts) const { + if (run_on_debug_line_insts) { + for (auto& dbg_line : dbg_line_insts_) { + if (!f(&dbg_line)) return false; + } + } + return f(this); +} + +inline void Instruction::ForEachInst(const std::function& f, + bool run_on_debug_line_insts) { + WhileEachInst( + [&f](Instruction* inst) { + f(inst); + return true; + }, + run_on_debug_line_insts); +} + +inline void Instruction::ForEachInst( + const std::function& f, + bool run_on_debug_line_insts) const { + WhileEachInst( + [&f](const Instruction* inst) { + f(inst); + return true; + }, + run_on_debug_line_insts); +} + +inline void Instruction::ForEachId(const std::function& f) { + for (auto& opnd : operands_) + if (spvIsIdType(opnd.type)) f(&opnd.words[0]); +} + +inline void Instruction::ForEachId( + const std::function& f) const { + for (const auto& opnd : operands_) + if (spvIsIdType(opnd.type)) f(&opnd.words[0]); +} + +inline bool Instruction::WhileEachInId( + const std::function& f) { + for (auto& opnd : operands_) { + if (spvIsInIdType(opnd.type)) { + if (!f(&opnd.words[0])) return false; + } + } + return true; +} + +inline bool Instruction::WhileEachInId( + const std::function& f) const { + for (const auto& opnd : operands_) { + if (spvIsInIdType(opnd.type)) { + if (!f(&opnd.words[0])) return false; + } + } + return true; +} + +inline void 
Instruction::ForEachInId(const std::function& f) { + WhileEachInId([&f](uint32_t* id) { + f(id); + return true; + }); +} + +inline void Instruction::ForEachInId( + const std::function& f) const { + WhileEachInId([&f](const uint32_t* id) { + f(id); + return true; + }); +} + +inline bool Instruction::WhileEachInOperand( + const std::function& f) { + for (auto& opnd : operands_) { + switch (opnd.type) { + case SPV_OPERAND_TYPE_RESULT_ID: + case SPV_OPERAND_TYPE_TYPE_ID: + break; + default: + if (!f(&opnd.words[0])) return false; + break; + } + } + return true; +} + +inline bool Instruction::WhileEachInOperand( + const std::function& f) const { + for (const auto& opnd : operands_) { + switch (opnd.type) { + case SPV_OPERAND_TYPE_RESULT_ID: + case SPV_OPERAND_TYPE_TYPE_ID: + break; + default: + if (!f(&opnd.words[0])) return false; + break; + } + } + return true; +} + +inline void Instruction::ForEachInOperand( + const std::function& f) { + WhileEachInOperand([&f](uint32_t* op) { + f(op); + return true; + }); +} + +inline void Instruction::ForEachInOperand( + const std::function& f) const { + WhileEachInOperand([&f](const uint32_t* op) { + f(op); + return true; + }); +} + +inline bool Instruction::HasLabels() const { + switch (opcode_) { + case SpvOpSelectionMerge: + case SpvOpBranch: + case SpvOpLoopMerge: + case SpvOpBranchConditional: + case SpvOpSwitch: + case SpvOpPhi: + return true; + break; + default: + break; + } + return false; +} + +bool Instruction::IsDecoration() const { + return spvOpcodeIsDecoration(opcode()); +} + +bool Instruction::IsLoad() const { return spvOpcodeIsLoad(opcode()); } + +bool Instruction::IsAtomicWithLoad() const { + return spvOpcodeIsAtomicWithLoad(opcode()); +} + +bool Instruction::IsAtomicOp() const { return spvOpcodeIsAtomicOp(opcode()); } + +bool Instruction::IsConstant() const { + return IsCompileTimeConstantInst(opcode()); +} +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_INSTRUCTION_H_ diff --git 
a/source/opt/instruction_list.cpp b/source/opt/instruction_list.cpp new file mode 100644 index 000000000..385a136ec --- /dev/null +++ b/source/opt/instruction_list.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/instruction_list.h" + +namespace spvtools { +namespace opt { + +InstructionList::iterator InstructionList::iterator::InsertBefore( + std::vector>&& list) { + Instruction* first_node = list.front().get(); + for (auto& i : list) { + i.release()->InsertBefore(node_); + } + list.clear(); + return iterator(first_node); +} + +InstructionList::iterator InstructionList::iterator::InsertBefore( + std::unique_ptr&& i) { + i.get()->InsertBefore(node_); + return iterator(i.release()); +} +} // namespace opt +} // namespace spvtools diff --git a/source/opt/instruction_list.h b/source/opt/instruction_list.h new file mode 100644 index 000000000..ea1cc7c46 --- /dev/null +++ b/source/opt/instruction_list.h @@ -0,0 +1,130 @@ + +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_INSTRUCTION_LIST_H_ +#define SOURCE_OPT_INSTRUCTION_LIST_H_ + +#include +#include +#include +#include +#include + +#include "source/latest_version_spirv_header.h" +#include "source/operand.h" +#include "source/opt/instruction.h" +#include "source/util/ilist.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace opt { + +// This class is intended to be the container for Instructions. This container +// owns the instructions that are in it. When removing an Instruction from the +// list, the caller is assuming responsibility for deleting the storage. +// +// TODO: Because there are a number of other data structures that will want +// pointers to instruction, ownership should probably be moved to the module. +// Because of that I have not made the ownership passing in this class fully +// explicit. For example, RemoveFromList takes ownership from the list, but +// does not return an std::unique_ptr to signal that. When we fully decide on +// ownership, this will have to be fixed up one way or the other. +class InstructionList : public utils::IntrusiveList { + public: + InstructionList() = default; + InstructionList(InstructionList&& that) + : utils::IntrusiveList(std::move(that)) {} + InstructionList& operator=(InstructionList&& that) { + auto p = static_cast*>(this); + *p = std::move(that); + return *this; + } + + // Destroy this list and any instructions in the list. 
+ inline virtual ~InstructionList(); + + class iterator : public utils::IntrusiveList::iterator { + public: + iterator(const utils::IntrusiveList::iterator& i) + : utils::IntrusiveList::iterator(i) {} + iterator(Instruction* i) : utils::IntrusiveList::iterator(i) {} + + // DEPRECATED: Please use MoveBefore with an InstructionList instead. + // + // Moves the nodes in |list| to the list that |this| points to. The + // positions of the nodes will be immediately before the element pointed to + // by the iterator. The return value will be an iterator pointing to the + // first of the newly inserted elements. Ownership of the elements in + // |list| is now passed on to |*this|. + iterator InsertBefore(std::vector>&& list); + + // The node |i| will be inserted immediately before |this|. The return value + // will be an iterator pointing to the newly inserted node. The owner of + // |*i| becomes |*this| + iterator InsertBefore(std::unique_ptr&& i); + + // Removes the node from the list, and deletes the storage. Returns a valid + // iterator to the next node. + iterator Erase() { + iterator_template next_node = *this; + ++next_node; + node_->RemoveFromList(); + delete node_; + return next_node; + } + }; + + iterator begin() { return utils::IntrusiveList::begin(); } + iterator end() { return utils::IntrusiveList::end(); } + const_iterator begin() const { + return utils::IntrusiveList::begin(); + } + const_iterator end() const { + return utils::IntrusiveList::end(); + } + + void push_back(std::unique_ptr&& inst) { + utils::IntrusiveList::push_back(inst.release()); + } + + // Same as in the base class, except it will delete the data as well. + inline void clear(); + + // Runs the given function |f| on the instructions in the list and optionally + // on the preceding debug line instructions. 
+ inline void ForEachInst(const std::function& f, + bool run_on_debug_line_insts) { + auto next = begin(); + for (auto i = next; i != end(); i = next) { + ++next; + i->ForEachInst(f, run_on_debug_line_insts); + } + } +}; + +InstructionList::~InstructionList() { clear(); } + +void InstructionList::clear() { + while (!empty()) { + Instruction* inst = &front(); + inst->RemoveFromList(); + delete inst; + } +} + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_INSTRUCTION_LIST_H_ diff --git a/source/opt/instrument_pass.cpp b/source/opt/instrument_pass.cpp new file mode 100644 index 000000000..6935a43dc --- /dev/null +++ b/source/opt/instrument_pass.cpp @@ -0,0 +1,712 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "instrument_pass.h" + +#include "source/cfa.h" + +namespace { + +// Common Parameter Positions +static const int kInstCommonParamInstIdx = 0; +static const int kInstCommonParamCnt = 1; + +// Indices of operands in SPIR-V instructions +static const int kEntryPointExecutionModelInIdx = 0; +static const int kEntryPointFunctionIdInIdx = 1; + +} // anonymous namespace + +namespace spvtools { +namespace opt { + +void InstrumentPass::MovePreludeCode( + BasicBlock::iterator ref_inst_itr, + UptrVectorIterator ref_block_itr, + std::unique_ptr* new_blk_ptr) { + same_block_pre_.clear(); + same_block_post_.clear(); + // Initialize new block. Reuse label from original block. + new_blk_ptr->reset(new BasicBlock(std::move(ref_block_itr->GetLabel()))); + // Move contents of original ref block up to ref instruction. + for (auto cii = ref_block_itr->begin(); cii != ref_inst_itr; + cii = ref_block_itr->begin()) { + Instruction* inst = &*cii; + inst->RemoveFromList(); + std::unique_ptr mv_ptr(inst); + // Remember same-block ops for possible regeneration. + if (IsSameBlockOp(&*mv_ptr)) { + auto* sb_inst_ptr = mv_ptr.get(); + same_block_pre_[mv_ptr->result_id()] = sb_inst_ptr; + } + (*new_blk_ptr)->AddInstruction(std::move(mv_ptr)); + } +} + +void InstrumentPass::MovePostludeCode( + UptrVectorIterator ref_block_itr, + std::unique_ptr* new_blk_ptr) { + // new_blk_ptr->reset(new BasicBlock(NewLabel(ref_block_itr->id()))); + // Move contents of original ref block. + for (auto cii = ref_block_itr->begin(); cii != ref_block_itr->end(); + cii = ref_block_itr->begin()) { + Instruction* inst = &*cii; + inst->RemoveFromList(); + std::unique_ptr mv_inst(inst); + // Regenerate any same-block instruction that has not been seen in the + // current block. + if (same_block_pre_.size() > 0) { + CloneSameBlockOps(&mv_inst, &same_block_post_, &same_block_pre_, + new_blk_ptr); + // Remember same-block ops in this block. 
+ if (IsSameBlockOp(&*mv_inst)) { + const uint32_t rid = mv_inst->result_id(); + same_block_post_[rid] = rid; + } + } + (*new_blk_ptr)->AddInstruction(std::move(mv_inst)); + } +} + +std::unique_ptr InstrumentPass::NewLabel(uint32_t label_id) { + std::unique_ptr newLabel( + new Instruction(context(), SpvOpLabel, 0, label_id, {})); + get_def_use_mgr()->AnalyzeInstDefUse(&*newLabel); + return newLabel; +} + +uint32_t InstrumentPass::GenUintCastCode(uint32_t val_id, + InstructionBuilder* builder) { + // Cast value to 32-bit unsigned if necessary + if (get_def_use_mgr()->GetDef(val_id)->type_id() == GetUintId()) + return val_id; + return builder->AddUnaryOp(GetUintId(), SpvOpBitcast, val_id)->result_id(); +} + +void InstrumentPass::GenDebugOutputFieldCode(uint32_t base_offset_id, + uint32_t field_offset, + uint32_t field_value_id, + InstructionBuilder* builder) { + // Cast value to 32-bit unsigned if necessary + uint32_t val_id = GenUintCastCode(field_value_id, builder); + // Store value + Instruction* data_idx_inst = + builder->AddBinaryOp(GetUintId(), SpvOpIAdd, base_offset_id, + builder->GetUintConstantId(field_offset)); + uint32_t buf_id = GetOutputBufferId(); + uint32_t buf_uint_ptr_id = GetOutputBufferUintPtrId(); + Instruction* achain_inst = + builder->AddTernaryOp(buf_uint_ptr_id, SpvOpAccessChain, buf_id, + builder->GetUintConstantId(kDebugOutputDataOffset), + data_idx_inst->result_id()); + (void)builder->AddBinaryOp(0, SpvOpStore, achain_inst->result_id(), val_id); +} + +void InstrumentPass::GenCommonStreamWriteCode(uint32_t record_sz, + uint32_t inst_id, + uint32_t stage_idx, + uint32_t base_offset_id, + InstructionBuilder* builder) { + // Store record size + GenDebugOutputFieldCode(base_offset_id, kInstCommonOutSize, + builder->GetUintConstantId(record_sz), builder); + // Store Shader Id + GenDebugOutputFieldCode(base_offset_id, kInstCommonOutShaderId, + builder->GetUintConstantId(shader_id_), builder); + // Store Instruction Idx + 
GenDebugOutputFieldCode(base_offset_id, kInstCommonOutInstructionIdx, inst_id, + builder); + // Store Stage Idx + GenDebugOutputFieldCode(base_offset_id, kInstCommonOutStageIdx, + builder->GetUintConstantId(stage_idx), builder); +} + +void InstrumentPass::GenFragCoordEltDebugOutputCode( + uint32_t base_offset_id, uint32_t uint_frag_coord_id, uint32_t element, + InstructionBuilder* builder) { + Instruction* element_val_inst = builder->AddIdLiteralOp( + GetUintId(), SpvOpCompositeExtract, uint_frag_coord_id, element); + GenDebugOutputFieldCode(base_offset_id, kInstFragOutFragCoordX + element, + element_val_inst->result_id(), builder); +} + +void InstrumentPass::GenBuiltinOutputCode(uint32_t builtin_id, + uint32_t builtin_off, + uint32_t base_offset_id, + InstructionBuilder* builder) { + // Load and store builtin + Instruction* load_inst = + builder->AddUnaryOp(GetUintId(), SpvOpLoad, builtin_id); + GenDebugOutputFieldCode(base_offset_id, builtin_off, load_inst->result_id(), + builder); +} + +void InstrumentPass::GenUintNullOutputCode(uint32_t field_off, + uint32_t base_offset_id, + InstructionBuilder* builder) { + GenDebugOutputFieldCode(base_offset_id, field_off, + builder->GetNullId(GetUintId()), builder); +} + +void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx, + uint32_t base_offset_id, + InstructionBuilder* builder) { + // TODO(greg-lunarg): Add support for all stages + switch (stage_idx) { + case SpvExecutionModelVertex: { + // Load and store VertexId and InstanceId + GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInVertexIndex), + kInstVertOutVertexIndex, base_offset_id, builder); + GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInstanceIndex), + kInstVertOutInstanceIndex, base_offset_id, builder); + } break; + case SpvExecutionModelGLCompute: { + // Load and store GlobalInvocationId. Second word is unused; store zero. 
+ GenBuiltinOutputCode( + context()->GetBuiltinVarId(SpvBuiltInGlobalInvocationId), + kInstCompOutGlobalInvocationId, base_offset_id, builder); + GenUintNullOutputCode(kInstCompOutUnused, base_offset_id, builder); + } break; + case SpvExecutionModelGeometry: { + // Load and store PrimitiveId and InvocationId. + GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInPrimitiveId), + kInstGeomOutPrimitiveId, base_offset_id, builder); + GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInvocationId), + kInstGeomOutInvocationId, base_offset_id, builder); + } break; + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: { + // Load and store InvocationId. Second word is unused; store zero. + GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInvocationId), + kInstTessOutInvocationId, base_offset_id, builder); + GenUintNullOutputCode(kInstTessOutUnused, base_offset_id, builder); + } break; + case SpvExecutionModelFragment: { + // Load FragCoord and convert to Uint + Instruction* frag_coord_inst = + builder->AddUnaryOp(GetVec4FloatId(), SpvOpLoad, + context()->GetBuiltinVarId(SpvBuiltInFragCoord)); + Instruction* uint_frag_coord_inst = builder->AddUnaryOp( + GetVec4UintId(), SpvOpBitcast, frag_coord_inst->result_id()); + for (uint32_t u = 0; u < 2u; ++u) + GenFragCoordEltDebugOutputCode( + base_offset_id, uint_frag_coord_inst->result_id(), u, builder); + } break; + default: { assert(false && "unsupported stage"); } break; + } +} + +void InstrumentPass::GenDebugStreamWrite( + uint32_t instruction_idx, uint32_t stage_idx, + const std::vector& validation_ids, InstructionBuilder* builder) { + // Call debug output function. Pass func_idx, instruction_idx and + // validation ids as args. 
+ uint32_t val_id_cnt = static_cast(validation_ids.size()); + uint32_t output_func_id = GetStreamWriteFunctionId(stage_idx, val_id_cnt); + std::vector args = {output_func_id, + builder->GetUintConstantId(instruction_idx)}; + (void)args.insert(args.end(), validation_ids.begin(), validation_ids.end()); + (void)builder->AddNaryOp(GetVoidId(), SpvOpFunctionCall, args); +} + +bool InstrumentPass::IsSameBlockOp(const Instruction* inst) const { + return inst->opcode() == SpvOpSampledImage || inst->opcode() == SpvOpImage; +} + +void InstrumentPass::CloneSameBlockOps( + std::unique_ptr* inst, + std::unordered_map* same_blk_post, + std::unordered_map* same_blk_pre, + std::unique_ptr* block_ptr) { + (*inst)->ForEachInId( + [&same_blk_post, &same_blk_pre, &block_ptr, this](uint32_t* iid) { + const auto map_itr = (*same_blk_post).find(*iid); + if (map_itr == (*same_blk_post).end()) { + const auto map_itr2 = (*same_blk_pre).find(*iid); + if (map_itr2 != (*same_blk_pre).end()) { + // Clone pre-call same-block ops, map result id. + const Instruction* in_inst = map_itr2->second; + std::unique_ptr sb_inst(in_inst->Clone(context())); + CloneSameBlockOps(&sb_inst, same_blk_post, same_blk_pre, block_ptr); + const uint32_t rid = sb_inst->result_id(); + const uint32_t nid = this->TakeNextId(); + get_decoration_mgr()->CloneDecorations(rid, nid); + sb_inst->SetResultId(nid); + (*same_blk_post)[rid] = nid; + *iid = nid; + (*block_ptr)->AddInstruction(std::move(sb_inst)); + } + } else { + // Reset same-block op operand. 
+ *iid = map_itr->second; + } + }); +} + +void InstrumentPass::UpdateSucceedingPhis( + std::vector>& new_blocks) { + const auto first_blk = new_blocks.begin(); + const auto last_blk = new_blocks.end() - 1; + const uint32_t first_id = (*first_blk)->id(); + const uint32_t last_id = (*last_blk)->id(); + const BasicBlock& const_last_block = *last_blk->get(); + const_last_block.ForEachSuccessorLabel( + [&first_id, &last_id, this](const uint32_t succ) { + BasicBlock* sbp = this->id2block_[succ]; + sbp->ForEachPhiInst([&first_id, &last_id, this](Instruction* phi) { + bool changed = false; + phi->ForEachInId([&first_id, &last_id, &changed](uint32_t* id) { + if (*id == first_id) { + *id = last_id; + changed = true; + } + }); + if (changed) get_def_use_mgr()->AnalyzeInstUse(phi); + }); + }); +} + +// Return id for output buffer uint ptr type +uint32_t InstrumentPass::GetOutputBufferUintPtrId() { + if (output_buffer_uint_ptr_id_ == 0) { + output_buffer_uint_ptr_id_ = context()->get_type_mgr()->FindPointerToType( + GetUintId(), SpvStorageClassStorageBuffer); + } + return output_buffer_uint_ptr_id_; +} + +uint32_t InstrumentPass::GetOutputBufferBinding() { + switch (validation_id_) { + case kInstValidationIdBindless: + return kDebugOutputBindingStream; + default: + assert(false && "unexpected validation id"); + } + return 0; +} + +// Return id for output buffer +uint32_t InstrumentPass::GetOutputBufferId() { + if (output_buffer_id_ == 0) { + // If not created yet, create one + analysis::DecorationManager* deco_mgr = get_decoration_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Integer uint_ty(32, false); + analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty); + analysis::RuntimeArray uint_rarr_ty(reg_uint_ty); + analysis::Type* reg_uint_rarr_ty = + type_mgr->GetRegisteredType(&uint_rarr_ty); + uint32_t uint_arr_ty_id = type_mgr->GetTypeInstruction(reg_uint_rarr_ty); + deco_mgr->AddDecorationVal(uint_arr_ty_id, 
SpvDecorationArrayStride, 4u); + analysis::Struct obuf_ty({reg_uint_ty, reg_uint_rarr_ty}); + analysis::Type* reg_obuf_ty = type_mgr->GetRegisteredType(&obuf_ty); + uint32_t obufTyId = type_mgr->GetTypeInstruction(reg_obuf_ty); + deco_mgr->AddDecoration(obufTyId, SpvDecorationBlock); + deco_mgr->AddMemberDecoration(obufTyId, kDebugOutputSizeOffset, + SpvDecorationOffset, 0); + deco_mgr->AddMemberDecoration(obufTyId, kDebugOutputDataOffset, + SpvDecorationOffset, 4); + uint32_t obufTyPtrId_ = + type_mgr->FindPointerToType(obufTyId, SpvStorageClassStorageBuffer); + output_buffer_id_ = TakeNextId(); + std::unique_ptr newVarOp(new Instruction( + context(), SpvOpVariable, obufTyPtrId_, output_buffer_id_, + {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, + {SpvStorageClassStorageBuffer}}})); + context()->AddGlobalValue(std::move(newVarOp)); + deco_mgr->AddDecorationVal(output_buffer_id_, SpvDecorationDescriptorSet, + desc_set_); + deco_mgr->AddDecorationVal(output_buffer_id_, SpvDecorationBinding, + GetOutputBufferBinding()); + // Look for storage buffer extension. If none, create one. + if (!get_feature_mgr()->HasExtension( + kSPV_KHR_storage_buffer_storage_class)) { + const std::string ext_name("SPV_KHR_storage_buffer_storage_class"); + const auto num_chars = ext_name.size(); + // Compute num words, accommodate the terminating null character. 
+ const auto num_words = (num_chars + 1 + 3) / 4; + std::vector ext_words(num_words, 0u); + std::memcpy(ext_words.data(), ext_name.data(), num_chars); + context()->AddExtension(std::unique_ptr( + new Instruction(context(), SpvOpExtension, 0u, 0u, + {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}}))); + } + } + return output_buffer_id_; +} + +uint32_t InstrumentPass::GetVec4FloatId() { + if (v4float_id_ == 0) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Float float_ty(32); + analysis::Type* reg_float_ty = type_mgr->GetRegisteredType(&float_ty); + analysis::Vector v4float_ty(reg_float_ty, 4); + analysis::Type* reg_v4float_ty = type_mgr->GetRegisteredType(&v4float_ty); + v4float_id_ = type_mgr->GetTypeInstruction(reg_v4float_ty); + } + return v4float_id_; +} + +uint32_t InstrumentPass::GetUintId() { + if (uint_id_ == 0) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Integer uint_ty(32, false); + analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty); + uint_id_ = type_mgr->GetTypeInstruction(reg_uint_ty); + } + return uint_id_; +} + +uint32_t InstrumentPass::GetVec4UintId() { + if (v4uint_id_ == 0) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Integer uint_ty(32, false); + analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty); + analysis::Vector v4uint_ty(reg_uint_ty, 4); + analysis::Type* reg_v4uint_ty = type_mgr->GetRegisteredType(&v4uint_ty); + v4uint_id_ = type_mgr->GetTypeInstruction(reg_v4uint_ty); + } + return v4uint_id_; +} + +uint32_t InstrumentPass::GetBoolId() { + if (bool_id_ == 0) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Bool bool_ty; + analysis::Type* reg_bool_ty = type_mgr->GetRegisteredType(&bool_ty); + bool_id_ = type_mgr->GetTypeInstruction(reg_bool_ty); + } + return bool_id_; +} + +uint32_t InstrumentPass::GetVoidId() { + if (void_id_ == 0) { + analysis::TypeManager* type_mgr = 
context()->get_type_mgr(); + analysis::Void void_ty; + analysis::Type* reg_void_ty = type_mgr->GetRegisteredType(&void_ty); + void_id_ = type_mgr->GetTypeInstruction(reg_void_ty); + } + return void_id_; +} + +uint32_t InstrumentPass::GetStreamWriteFunctionId(uint32_t stage_idx, + uint32_t val_spec_param_cnt) { + // Total param count is common params plus validation-specific + // params + uint32_t param_cnt = kInstCommonParamCnt + val_spec_param_cnt; + if (output_func_id_ == 0) { + // Create function + output_func_id_ = TakeNextId(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + std::vector param_types; + for (uint32_t c = 0; c < param_cnt; ++c) + param_types.push_back(type_mgr->GetType(GetUintId())); + analysis::Function func_ty(type_mgr->GetType(GetVoidId()), param_types); + analysis::Type* reg_func_ty = type_mgr->GetRegisteredType(&func_ty); + std::unique_ptr func_inst(new Instruction( + get_module()->context(), SpvOpFunction, GetVoidId(), output_func_id_, + {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, + {SpvFunctionControlMaskNone}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {type_mgr->GetTypeInstruction(reg_func_ty)}}})); + get_def_use_mgr()->AnalyzeInstDefUse(&*func_inst); + std::unique_ptr output_func = + MakeUnique(std::move(func_inst)); + // Add parameters + std::vector param_vec; + for (uint32_t c = 0; c < param_cnt; ++c) { + uint32_t pid = TakeNextId(); + param_vec.push_back(pid); + std::unique_ptr param_inst( + new Instruction(get_module()->context(), SpvOpFunctionParameter, + GetUintId(), pid, {})); + get_def_use_mgr()->AnalyzeInstDefUse(&*param_inst); + output_func->AddParameter(std::move(param_inst)); + } + // Create first block + uint32_t test_blk_id = TakeNextId(); + std::unique_ptr test_label(NewLabel(test_blk_id)); + std::unique_ptr new_blk_ptr = + MakeUnique(std::move(test_label)); + InstructionBuilder builder( + context(), &*new_blk_ptr, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + 
// Gen test if debug output buffer size will not be exceeded. + uint32_t obuf_record_sz = kInstStageOutCnt + val_spec_param_cnt; + uint32_t buf_id = GetOutputBufferId(); + uint32_t buf_uint_ptr_id = GetOutputBufferUintPtrId(); + Instruction* obuf_curr_sz_ac_inst = + builder.AddBinaryOp(buf_uint_ptr_id, SpvOpAccessChain, buf_id, + builder.GetUintConstantId(kDebugOutputSizeOffset)); + // Fetch the current debug buffer written size atomically, adding the + // size of the record to be written. + uint32_t obuf_record_sz_id = builder.GetUintConstantId(obuf_record_sz); + uint32_t mask_none_id = builder.GetUintConstantId(SpvMemoryAccessMaskNone); + uint32_t scope_invok_id = builder.GetUintConstantId(SpvScopeInvocation); + Instruction* obuf_curr_sz_inst = builder.AddQuadOp( + GetUintId(), SpvOpAtomicIAdd, obuf_curr_sz_ac_inst->result_id(), + scope_invok_id, mask_none_id, obuf_record_sz_id); + uint32_t obuf_curr_sz_id = obuf_curr_sz_inst->result_id(); + // Compute new written size + Instruction* obuf_new_sz_inst = + builder.AddBinaryOp(GetUintId(), SpvOpIAdd, obuf_curr_sz_id, + builder.GetUintConstantId(obuf_record_sz)); + // Fetch the data bound + Instruction* obuf_bnd_inst = + builder.AddIdLiteralOp(GetUintId(), SpvOpArrayLength, + GetOutputBufferId(), kDebugOutputDataOffset); + // Test that new written size is less than or equal to debug output + // data bound + Instruction* obuf_safe_inst = builder.AddBinaryOp( + GetBoolId(), SpvOpULessThanEqual, obuf_new_sz_inst->result_id(), + obuf_bnd_inst->result_id()); + uint32_t merge_blk_id = TakeNextId(); + uint32_t write_blk_id = TakeNextId(); + std::unique_ptr merge_label(NewLabel(merge_blk_id)); + std::unique_ptr write_label(NewLabel(write_blk_id)); + (void)builder.AddConditionalBranch(obuf_safe_inst->result_id(), + write_blk_id, merge_blk_id, merge_blk_id, + SpvSelectionControlMaskNone); + // Close safety test block and gen write block + new_blk_ptr->SetParent(&*output_func); + 
output_func->AddBasicBlock(std::move(new_blk_ptr)); + new_blk_ptr = MakeUnique(std::move(write_label)); + builder.SetInsertPoint(&*new_blk_ptr); + // Generate common and stage-specific debug record members + GenCommonStreamWriteCode(obuf_record_sz, param_vec[kInstCommonParamInstIdx], + stage_idx, obuf_curr_sz_id, &builder); + GenStageStreamWriteCode(stage_idx, obuf_curr_sz_id, &builder); + // Gen writes of validation specific data + for (uint32_t i = 0; i < val_spec_param_cnt; ++i) { + GenDebugOutputFieldCode(obuf_curr_sz_id, kInstStageOutCnt + i, + param_vec[kInstCommonParamCnt + i], &builder); + } + // Close write block and gen merge block + (void)builder.AddBranch(merge_blk_id); + new_blk_ptr->SetParent(&*output_func); + output_func->AddBasicBlock(std::move(new_blk_ptr)); + new_blk_ptr = MakeUnique(std::move(merge_label)); + builder.SetInsertPoint(&*new_blk_ptr); + // Close merge block and function and add function to module + (void)builder.AddNullaryOp(0, SpvOpReturn); + new_blk_ptr->SetParent(&*output_func); + output_func->AddBasicBlock(std::move(new_blk_ptr)); + std::unique_ptr func_end_inst( + new Instruction(get_module()->context(), SpvOpFunctionEnd, 0, 0, {})); + get_def_use_mgr()->AnalyzeInstDefUse(&*func_end_inst); + output_func->SetFunctionEnd(std::move(func_end_inst)); + context()->AddFunction(std::move(output_func)); + output_func_param_cnt_ = param_cnt; + } + assert(param_cnt == output_func_param_cnt_ && "bad arg count"); + return output_func_id_; +} + +bool InstrumentPass::InstrumentFunction(Function* func, uint32_t stage_idx, + InstProcessFunction& pfn) { + bool modified = false; + // Compute function index + uint32_t function_idx = 0; + for (auto fii = get_module()->begin(); fii != get_module()->end(); ++fii) { + if (&*fii == func) break; + ++function_idx; + } + std::vector> new_blks; + // Start count after function instruction + uint32_t instruction_idx = funcIdx2offset_[function_idx] + 1; + // Using block iterators here because of block erasures 
and insertions. + for (auto bi = func->begin(); bi != func->end(); ++bi) { + // Count block's label + ++instruction_idx; + for (auto ii = bi->begin(); ii != bi->end(); ++instruction_idx) { + // Bump instruction count if debug instructions + instruction_idx += static_cast(ii->dbg_line_insts().size()); + // Generate instrumentation if warranted + pfn(ii, bi, instruction_idx, stage_idx, &new_blks); + if (new_blks.size() == 0) { + ++ii; + continue; + } + // If there are new blocks we know there will always be two or + // more, so update succeeding phis with label of new last block. + size_t newBlocksSize = new_blks.size(); + assert(newBlocksSize > 1); + UpdateSucceedingPhis(new_blks); + // Replace original block with new block(s) + bi = bi.Erase(); + for (auto& bb : new_blks) { + bb->SetParent(func); + } + bi = bi.InsertBefore(&new_blks); + // Reset block iterator to last new block + for (size_t i = 0; i < newBlocksSize - 1; i++) ++bi; + modified = true; + // Restart instrumenting at beginning of last new block, + // but skip over any new phi or copy instruction. + ii = bi->begin(); + if (ii->opcode() == SpvOpPhi || ii->opcode() == SpvOpCopyObject) ++ii; + new_blks.clear(); + } + } + return modified; +} + +bool InstrumentPass::InstProcessCallTreeFromRoots(InstProcessFunction& pfn, + std::queue* roots, + uint32_t stage_idx) { + bool modified = false; + std::unordered_set done; + // Process all functions from roots + while (!roots->empty()) { + const uint32_t fi = roots->front(); + roots->pop(); + if (done.insert(fi).second) { + Function* fn = id2function_.at(fi); + // Add calls first so we don't add new output function + context()->AddCalls(fn, roots); + modified = InstrumentFunction(fn, stage_idx, pfn) || modified; + } + } + return modified; +} + +bool InstrumentPass::InstProcessEntryPointCallTree(InstProcessFunction& pfn) { + // Make sure all entry points have the same execution model. Do not + // instrument if they do not. + // TODO(greg-lunarg): Handle mixed stages. 
Technically, a shader module + // can contain entry points with different execution models, although + // such modules will likely be rare as GLSL and HLSL are geared toward + // one model per module. In such cases we will need + // to clone any functions which are in the call trees of entrypoints + // with differing execution models. + uint32_t ecnt = 0; + uint32_t stage = SpvExecutionModelMax; + for (auto& e : get_module()->entry_points()) { + if (ecnt == 0) + stage = e.GetSingleWordInOperand(kEntryPointExecutionModelInIdx); + else if (e.GetSingleWordInOperand(kEntryPointExecutionModelInIdx) != stage) + return false; + ++ecnt; + } + // Only vertex, tessellation, geometry, fragment and compute stages are + // supported at the moment. TODO(greg-lunarg): Handle all stages. + if (stage != SpvExecutionModelVertex && stage != SpvExecutionModelFragment && + stage != SpvExecutionModelGeometry && + stage != SpvExecutionModelGLCompute && + stage != SpvExecutionModelTessellationControl && + stage != SpvExecutionModelTessellationEvaluation) + return false; + // Add together the roots of all entry points + std::queue roots; + for (auto& e : get_module()->entry_points()) { + roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)); + } + bool modified = InstProcessCallTreeFromRoots(pfn, &roots, stage); + return modified; +} + +void InstrumentPass::InitializeInstrument() { + output_buffer_id_ = 0; + output_buffer_uint_ptr_id_ = 0; + output_func_id_ = 0; + output_func_param_cnt_ = 0; + v4float_id_ = 0; + uint_id_ = 0; + v4uint_id_ = 0; + bool_id_ = 0; + void_id_ = 0; + + // clear collections + id2function_.clear(); + id2block_.clear(); + + // Initialize function and block maps.
+ for (auto& fn : *get_module()) { + id2function_[fn.result_id()] = &fn; + for (auto& blk : fn) { + id2block_[blk.id()] = &blk; + } + } + + // Calculate instruction offset of first function + uint32_t pre_func_size = 0; + Module* module = get_module(); + for (auto& i : context()->capabilities()) { + (void)i; + ++pre_func_size; + } + for (auto& i : module->extensions()) { + (void)i; + ++pre_func_size; + } + for (auto& i : module->ext_inst_imports()) { + (void)i; + ++pre_func_size; + } + ++pre_func_size; // memory_model + for (auto& i : module->entry_points()) { + (void)i; + ++pre_func_size; + } + for (auto& i : module->execution_modes()) { + (void)i; + ++pre_func_size; + } + for (auto& i : module->debugs1()) { + (void)i; + ++pre_func_size; + } + for (auto& i : module->debugs2()) { + (void)i; + ++pre_func_size; + } + for (auto& i : module->debugs3()) { + (void)i; + ++pre_func_size; + } + for (auto& i : module->annotations()) { + (void)i; + ++pre_func_size; + } + for (auto& i : module->types_values()) { + pre_func_size += 1; + pre_func_size += static_cast(i.dbg_line_insts().size()); + } + funcIdx2offset_[0] = pre_func_size; + + // Set cumulative instruction offsets for all other functions. + uint32_t func_idx = 1; + auto prev_fn = get_module()->begin(); + auto curr_fn = prev_fn; + for (++curr_fn; curr_fn != get_module()->end(); ++curr_fn) { + // Count function and end instructions + uint32_t func_size = 2; + for (auto& blk : *prev_fn) { + // Count label + func_size += 1; + for (auto& inst : blk) { + func_size += 1; + func_size += static_cast(inst.dbg_line_insts().size()); + } + } + funcIdx2offset_[func_idx] = funcIdx2offset_[func_idx - 1] + func_size; + ++prev_fn; + ++func_idx; + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/instrument_pass.h b/source/opt/instrument_pass.h new file mode 100644 index 000000000..cfa76b149 --- /dev/null +++ b/source/opt/instrument_pass.h @@ -0,0 +1,357 @@ +// Copyright (c) 2018 The Khronos Group Inc.
+// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIBSPIRV_OPT_INSTRUMENT_PASS_H_ +#define LIBSPIRV_OPT_INSTRUMENT_PASS_H_ + +#include +#include +#include + +#include "source/opt/ir_builder.h" +#include "source/opt/pass.h" +#include "spirv-tools/instrument.hpp" + +// This is a base class to assist in the creation of passes which instrument +// shader modules. More specifically, passes which replace instructions with a +// larger and more capable set of instructions. Commonly, these new +// instructions will add testing of operands and execute different +// instructions depending on the outcome, including outputting of debug +// information into a buffer created especially for that purpose. +// +// This class contains helper functions to create an InstProcessFunction, +// which is the heart of any derived class implementing a specific +// instrumentation pass. It takes an instruction as an argument, decides +// if it should be instrumented, and generates code to replace it. This class +// also supplies function InstProcessEntryPointCallTree which applies the +// InstProcessFunction to every reachable instruction in a module and replaces +// the instruction with new instructions if generated. 
+// +// Chief among the helper functions are output code generation functions, +// used to generate code in the shader which writes data to output buffers +// associated with that validation. Currently one such function, +// GenDebugStreamWrite, exists. Other such functions may be added in the +// future. Each is accompanied by documentation describing the format of +// its output buffer. +// +// A validation pass may read or write multiple buffers. All such buffers +// are located in a single debug descriptor set whose index is passed at the +// creation of the instrumentation pass. The bindings of the buffers used by +// a validation pass are permanently assigned and fixed and documented by +// the kDebugOutput* static consts. + +namespace spvtools { +namespace opt { + +// Validation Ids +// These are used to identify the general validation being done and map to +// its output buffers. +static const uint32_t kInstValidationIdBindless = 0; + +class InstrumentPass : public Pass { + using cbb_ptr = const BasicBlock*; + + public: + using InstProcessFunction = std::function, uint32_t, uint32_t, + std::vector>*)>; + + ~InstrumentPass() override = default; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisNameMap | IRContext::kAnalysisBuiltinVarId | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + protected: + // Create instrumentation pass which utilizes descriptor set |desc_set| + // for debug input and output buffers and writes |shader_id| into debug + // output records. + InstrumentPass(uint32_t desc_set, uint32_t shader_id, uint32_t validation_id) + : Pass(), + desc_set_(desc_set), + shader_id_(shader_id), + validation_id_(validation_id) {} + + // Initialize state for instrumentation of module by |validation_id|.
+ void InitializeInstrument(); + + // Call |pfn| on all instructions in all functions in the call tree of the + // entry points in |module|. If code is generated for an instruction, replace + // the instruction's block with the new blocks that are generated. Continue + // processing at the top of the last new block. + bool InstProcessEntryPointCallTree(InstProcessFunction& pfn); + + // Move all code in |ref_block_itr| preceding the instruction |ref_inst_itr| + // to be instrumented into block |new_blk_ptr|. + void MovePreludeCode(BasicBlock::iterator ref_inst_itr, + UptrVectorIterator ref_block_itr, + std::unique_ptr* new_blk_ptr); + + // Move all code in |ref_block_itr| succeeding the instruction |ref_inst_itr| + // to be instrumented into block |new_blk_ptr|. + void MovePostludeCode(UptrVectorIterator ref_block_itr, + std::unique_ptr* new_blk_ptr); + + // Generate instructions in |builder| which will atomically fetch and + // increment the size of the debug output buffer stream of the current + // validation and write a record to the end of the stream, if enough space + // in the buffer remains. The record will contain the index of the function + // and instruction within that function |func_idx, instruction_idx| which + // generated the record. It will also contain additional information to + // identify the instance of the shader, depending on the stage |stage_idx| + // of the shader. Finally, the record will contain validation-specific + // data contained in |validation_ids| which will identify the validation + // error as well as the values involved in the error. + // + // The output buffer binding written to by the code generated by the function + // is determined by the validation id specified when each specific + // instrumentation pass is created. 
+ // + // The output buffer is a sequence of 32-bit values with the following + // format (where all elements are unsigned 32-bit unless otherwise noted): + // + // Size + // Record0 + // Record1 + // Record2 + // ... + // + // Size is the number of 32-bit values that have been written or + // attempted to be written to the output buffer, excluding the Size. It is + // initialized to 0. If the size of attempts to write the buffer exceeds + // the actual size of the buffer, it is possible that this field can exceed + // the actual size of the buffer. + // + // Each Record* is a variable-length sequence of 32-bit values with the + // following format defined using static const offsets in the .cpp file: + // + // Record Size + // Shader ID + // Instruction Index + // Stage + // Stage-specific Word 0 + // Stage-specific Word 1 + // Validation Error Code + // Validation-specific Word 0 + // Validation-specific Word 1 + // Validation-specific Word 2 + // ... + // + // Each record consists of three subsections: members common across all + // validation, members specific to the stage, and members specific to a + // validation. + // + // The Record Size is the number of 32-bit words in the record, including + // the Record Size word. + // + // Shader ID is a value that identifies which shader has generated the + // validation error. It is passed when the instrumentation pass is created. + // + // The Instruction Index is the position of the instruction within the + // SPIR-V file which is in error. + // + // The Stage is the pipeline stage which has generated the error as defined + // by the SpvExecutionModel_ enumeration. This is used to interpret the + // following Stage-specific words. + // + // The Stage-specific Words identify which invocation of the shader generated + // the error. Every stage will write two words, although in some cases the + // second word is unused and so zero is written. Vertex shaders will write + // the Vertex and Instance ID. 
Fragment shaders will write FragCoord.xy. + // Compute shaders will write the Global Invocation ID and zero (unused). + // Both tessellation shaders will write the Invocation Id and zero (unused). + // The geometry shader will write the Primitive ID and Invocation ID. + // + // The Validation Error Code specifies the exact error which has occurred. + // These are enumerated with the kInstError* static consts. This allows + // multiple validation layers to use the same, single output buffer. + // + // The Validation-specific Words are a validation-specific number of 32-bit + // words which give further information on the validation error that + // occurred. These are documented further in each file containing the + // validation-specific class which derives from this base class. + // + // Because the code that is generated checks against the size of the buffer + // before writing, the size of the debug out buffer can be used by the + // validation layer to control the number of error records that are written. + void GenDebugStreamWrite(uint32_t instruction_idx, uint32_t stage_idx, + const std::vector& validation_ids, + InstructionBuilder* builder); + + // Generate code to cast |value_id| to unsigned, if needed. Return + // an id to the unsigned equivalent. + uint32_t GenUintCastCode(uint32_t value_id, InstructionBuilder* builder); + + // Return new label. + std::unique_ptr NewLabel(uint32_t label_id); + + // Return id for 32-bit unsigned type + uint32_t GetUintId(); + + // Return id for bool type + uint32_t GetBoolId(); + + // Return id for void type + uint32_t GetVoidId(); + + // Return id for output buffer uint type + uint32_t GetOutputBufferUintPtrId(); + + // Return binding for output buffer for current validation.
+ uint32_t GetOutputBufferBinding(); + + // Return id for debug output buffer + uint32_t GetOutputBufferId(); + + // Return id for v4float type + uint32_t GetVec4FloatId(); + + // Return id for v4uint type + uint32_t GetVec4UintId(); + + // Return id for output function. Define if it doesn't exist with + // |val_spec_param_cnt| validation-specific uint32 arguments. + uint32_t GetStreamWriteFunctionId(uint32_t stage_idx, + uint32_t val_spec_param_cnt); + + // Apply instrumentation function |pfn| to every instruction in |func|. + // If code is generated for an instruction, replace the instruction's + // block with the new blocks that are generated. Continue processing at the + // top of the last new block. + bool InstrumentFunction(Function* func, uint32_t stage_idx, + InstProcessFunction& pfn); + + // Call |pfn| on all functions in the call tree of the function + // ids in |roots|. + bool InstProcessCallTreeFromRoots(InstProcessFunction& pfn, + std::queue* roots, + uint32_t stage_idx); + + // Gen code into |builder| to write |field_value_id| into debug output + // buffer at |base_offset_id| + |field_offset|. + void GenDebugOutputFieldCode(uint32_t base_offset_id, uint32_t field_offset, + uint32_t field_value_id, + InstructionBuilder* builder); + + // Generate instructions into |builder| which will write the members + // of the debug output record common for all stages and validations at + // |base_off|. + void GenCommonStreamWriteCode(uint32_t record_sz, uint32_t instruction_idx, + uint32_t stage_idx, uint32_t base_off, + InstructionBuilder* builder); + + // Generate instructions into |builder| which will write + // |uint_frag_coord_id| at |component| of the record at |base_offset_id| of + // the debug output buffer.
+ void GenFragCoordEltDebugOutputCode(uint32_t base_offset_id, + uint32_t uint_frag_coord_id, + uint32_t component, + InstructionBuilder* builder); + + // Generate instructions into |builder| which will load the uint |builtin_id| + // and write it into the debug output buffer at |base_off| + |builtin_off|. + void GenBuiltinOutputCode(uint32_t builtin_id, uint32_t builtin_off, + uint32_t base_off, InstructionBuilder* builder); + + // Generate instructions into |builder| which will write a uint null into + // the debug output buffer at |base_off| + |field_off|. + void GenUintNullOutputCode(uint32_t field_off, uint32_t base_off, + InstructionBuilder* builder); + + // Generate instructions into |builder| which will write the |stage_idx|- + // specific members of the debug output stream at |base_off|. + void GenStageStreamWriteCode(uint32_t stage_idx, uint32_t base_off, + InstructionBuilder* builder); + + // Return true if instruction must be in the same block that its result + // is used. + bool IsSameBlockOp(const Instruction* inst) const; + + // Clone operands which must be in same block as consumer instructions. + // Look in same_blk_pre for instructions that need cloning. Look in + // same_blk_post for instructions already cloned. Add cloned instruction + // to same_blk_post. + void CloneSameBlockOps( + std::unique_ptr* inst, + std::unordered_map* same_blk_post, + std::unordered_map* same_blk_pre, + std::unique_ptr* block_ptr); + + // Update phis in succeeding blocks to point to new last block + void UpdateSucceedingPhis( + std::vector>& new_blocks); + + // Debug descriptor set index + uint32_t desc_set_; + + // Shader module ID written into output record + uint32_t shader_id_; + + // Map from function id to function pointer. + std::unordered_map id2function_; + + // Map from block's label id to block. TODO(dnovillo): This is superfluous wrt + // CFG. It has functionality not present in CFG. Consolidate.
+ std::unordered_map id2block_; + + // Map from function's position index to the offset of its first instruction + std::unordered_map funcIdx2offset_; + + // id of the validation being done by this pass + uint32_t validation_id_; + + // id for output buffer variable + uint32_t output_buffer_id_; + + // type id for output buffer element + uint32_t output_buffer_uint_ptr_id_; + + // id for debug output function + uint32_t output_func_id_; + + // param count for output function + uint32_t output_func_param_cnt_; + + // id for v4float type + uint32_t v4float_id_; + + // id for v4uint type + uint32_t v4uint_id_; + + // id for 32-bit unsigned type + uint32_t uint_id_; + + // id for bool type + uint32_t bool_id_; + + // id for void type + uint32_t void_id_; + + // Pre-instrumentation same-block insts + std::unordered_map same_block_pre_; + + // Post-instrumentation same-block op ids + std::unordered_map same_block_post_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // LIBSPIRV_OPT_INSTRUMENT_PASS_H_ diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h new file mode 100644 index 000000000..da7405512 --- /dev/null +++ b/source/opt/ir_builder.h @@ -0,0 +1,542 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef SOURCE_OPT_IR_BUILDER_H_ +#define SOURCE_OPT_IR_BUILDER_H_ + +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/constants.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +// In SPIR-V, ids are encoded as uint16_t, this id is guaranteed to be always +// invalid. +const uint32_t kInvalidId = std::numeric_limits::max(); + +// Helper class to abstract instruction construction and insertion. +// The instruction builder can preserve the following analyses (specified via +// the constructors): +// - Def-use analysis +// - Instruction to block analysis +class InstructionBuilder { + public: + using InsertionPointTy = BasicBlock::iterator; + + // Creates an InstructionBuilder, all new instructions will be inserted before + // the instruction |insert_before|. + InstructionBuilder( + IRContext* context, Instruction* insert_before, + IRContext::Analysis preserved_analyses = IRContext::kAnalysisNone) + : InstructionBuilder(context, context->get_instr_block(insert_before), + InsertionPointTy(insert_before), + preserved_analyses) {} + + // Creates an InstructionBuilder, all new instructions will be inserted at the + // end of the basic block |parent_block|. + InstructionBuilder( + IRContext* context, BasicBlock* parent_block, + IRContext::Analysis preserved_analyses = IRContext::kAnalysisNone) + : InstructionBuilder(context, parent_block, parent_block->end(), + preserved_analyses) {} + + Instruction* AddNullaryOp(uint32_t type_id, SpvOp opcode) { + // TODO(1841): Handle id overflow. + std::unique_ptr newUnOp(new Instruction( + GetContext(), opcode, type_id, + opcode == SpvOpReturn ? 0 : GetContext()->TakeNextId(), {})); + return AddInstruction(std::move(newUnOp)); + } + + Instruction* AddUnaryOp(uint32_t type_id, SpvOp opcode, uint32_t operand1) { + // TODO(1841): Handle id overflow. 
+ std::unique_ptr newUnOp(new Instruction( + GetContext(), opcode, type_id, GetContext()->TakeNextId(), + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}}})); + return AddInstruction(std::move(newUnOp)); + } + + Instruction* AddBinaryOp(uint32_t type_id, SpvOp opcode, uint32_t operand1, + uint32_t operand2) { + // TODO(1841): Handle id overflow. + std::unique_ptr newBinOp(new Instruction( + GetContext(), opcode, type_id, + opcode == SpvOpStore ? 0 : GetContext()->TakeNextId(), + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand2}}})); + return AddInstruction(std::move(newBinOp)); + } + + Instruction* AddTernaryOp(uint32_t type_id, SpvOp opcode, uint32_t operand1, + uint32_t operand2, uint32_t operand3) { + // TODO(1841): Handle id overflow. + std::unique_ptr newTernOp(new Instruction( + GetContext(), opcode, type_id, GetContext()->TakeNextId(), + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand2}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand3}}})); + return AddInstruction(std::move(newTernOp)); + } + + Instruction* AddQuadOp(uint32_t type_id, SpvOp opcode, uint32_t operand1, + uint32_t operand2, uint32_t operand3, + uint32_t operand4) { + // TODO(1841): Handle id overflow. + std::unique_ptr newQuadOp(new Instruction( + GetContext(), opcode, type_id, GetContext()->TakeNextId(), + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand2}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand3}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand4}}})); + return AddInstruction(std::move(newQuadOp)); + } + + Instruction* AddIdLiteralOp(uint32_t type_id, SpvOp opcode, uint32_t operand1, + uint32_t operand2) { + // TODO(1841): Handle id overflow. 
+ std::unique_ptr newBinOp(new Instruction( + GetContext(), opcode, type_id, GetContext()->TakeNextId(), + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {operand2}}})); + return AddInstruction(std::move(newBinOp)); + } + + // Creates an N-ary instruction of |opcode|. + // |type_id| must be the id of the instruction's type. + // |operands| must be a sequence of operand ids. + // Use |result| for the result id if non-zero. + Instruction* AddNaryOp(uint32_t type_id, SpvOp opcode, + const std::vector& operands, + uint32_t result = 0) { + std::vector ops; + for (size_t i = 0; i < operands.size(); i++) { + ops.push_back({SPV_OPERAND_TYPE_ID, {operands[i]}}); + } + // TODO(1841): Handle id overflow. + std::unique_ptr new_inst(new Instruction( + GetContext(), opcode, type_id, + result != 0 ? result : GetContext()->TakeNextId(), ops)); + return AddInstruction(std::move(new_inst)); + } + + // Creates a new selection merge instruction. + // The id |merge_id| is the merge basic block id. + Instruction* AddSelectionMerge( + uint32_t merge_id, + uint32_t selection_control = SpvSelectionControlMaskNone) { + std::unique_ptr new_branch_merge(new Instruction( + GetContext(), SpvOpSelectionMerge, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_SELECTION_CONTROL, + {selection_control}}})); + return AddInstruction(std::move(new_branch_merge)); + } + + // Creates a new loop merge instruction. + // The id |merge_id| is the basic block id of the merge block. + // |continue_id| is the id of the continue block. + // |loop_control| are the loop control flags to be added to the instruction.
+ Instruction* AddLoopMerge(uint32_t merge_id, uint32_t continue_id, + uint32_t loop_control = SpvLoopControlMaskNone) { + std::unique_ptr new_branch_merge(new Instruction( + GetContext(), SpvOpLoopMerge, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {continue_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LOOP_CONTROL, {loop_control}}})); + return AddInstruction(std::move(new_branch_merge)); + } + + // Creates a new branch instruction to |label_id|. + // Note that the user must make sure the final basic block is + // well formed. + Instruction* AddBranch(uint32_t label_id) { + std::unique_ptr new_branch(new Instruction( + GetContext(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}})); + return AddInstruction(std::move(new_branch)); + } + + // Creates a new conditional instruction and the associated selection merge + // instruction if requested. + // The id |cond_id| is the id of the condition instruction, must be of + // type bool. + // The id |true_id| is the id of the basic block to branch to if the condition + // is true. + // The id |false_id| is the id of the basic block to branch to if the + // condition is false. + // The id |merge_id| is the id of the merge basic block for the selection + // merge instruction. If |merge_id| equals kInvalidId then no selection merge + // instruction will be created. + // The value |selection_control| is the selection control flag for the + // selection merge instruction. + // Note that the user must make sure the final basic block is + // well formed. 
+ Instruction* AddConditionalBranch( + uint32_t cond_id, uint32_t true_id, uint32_t false_id, + uint32_t merge_id = kInvalidId, + uint32_t selection_control = SpvSelectionControlMaskNone) { + if (merge_id != kInvalidId) { + AddSelectionMerge(merge_id, selection_control); + } + std::unique_ptr new_branch(new Instruction( + GetContext(), SpvOpBranchConditional, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}})); + return AddInstruction(std::move(new_branch)); + } + + // Creates a new switch instruction and the associated selection merge + // instruction if requested. + // The id |selector_id| is the id of the selector instruction, must be of + // type int. + // The id |default_id| is the id of the default basic block to branch to. + // The vector |targets| is the pair of literal/branch id. + // The id |merge_id| is the id of the merge basic block for the selection + // merge instruction. If |merge_id| equals kInvalidId then no selection merge + // instruction will be created. + // The value |selection_control| is the selection control flag for the + // selection merge instruction. + // Note that the user must make sure the final basic block is + // well formed. 
+ Instruction* AddSwitch( + uint32_t selector_id, uint32_t default_id, + const std::vector>& targets, + uint32_t merge_id = kInvalidId, + uint32_t selection_control = SpvSelectionControlMaskNone) { + if (merge_id != kInvalidId) { + AddSelectionMerge(merge_id, selection_control); + } + std::vector operands; + operands.emplace_back( + Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {selector_id}}); + operands.emplace_back( + Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {default_id}}); + for (auto& target : targets) { + operands.emplace_back( + Operand{spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, + target.first}); + operands.emplace_back( + Operand{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {target.second}}); + } + std::unique_ptr new_switch( + new Instruction(GetContext(), SpvOpSwitch, 0, 0, operands)); + return AddInstruction(std::move(new_switch)); + } + + // Creates a phi instruction. + // The id |type| must be the id of the phi instruction's type. + // The vector |incomings| must be a sequence of pairs of (definition id, parent id). + Instruction* AddPhi(uint32_t type, const std::vector& incomings, + uint32_t result = 0) { + assert(incomings.size() % 2 == 0 && "A sequence of pairs is expected"); + return AddNaryOp(type, SpvOpPhi, incomings, result); + } + + // Creates an addition instruction. + // The id |type| must be the id of the instruction's type, must be the same as + // |op1| and |op2| types. + // The id |op1| is the left hand side of the operation. + // The id |op2| is the right hand side of the operation. + Instruction* AddIAdd(uint32_t type, uint32_t op1, uint32_t op2) { + // TODO(1841): Handle id overflow. + std::unique_ptr inst(new Instruction( + GetContext(), SpvOpIAdd, type, GetContext()->TakeNextId(), + {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}})); + return AddInstruction(std::move(inst)); + } + + // Creates a less than instruction for unsigned integer. + // The id |op1| is the left hand side of the operation.
+ // The id |op2| is the right hand side of the operation. + // It is assumed that |op1| and |op2| have the same underlying type. + Instruction* AddULessThan(uint32_t op1, uint32_t op2) { + analysis::Bool bool_type; + uint32_t type = GetContext()->get_type_mgr()->GetId(&bool_type); + // TODO(1841): Handle id overflow. + std::unique_ptr inst(new Instruction( + GetContext(), SpvOpULessThan, type, GetContext()->TakeNextId(), + {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}})); + return AddInstruction(std::move(inst)); + } + + // Creates a less than instruction for signed integer. + // The id |op1| is the left hand side of the operation. + // The id |op2| is the right hand side of the operation. + // It is assumed that |op1| and |op2| have the same underlying type. + Instruction* AddSLessThan(uint32_t op1, uint32_t op2) { + analysis::Bool bool_type; + uint32_t type = GetContext()->get_type_mgr()->GetId(&bool_type); + // TODO(1841): Handle id overflow. + std::unique_ptr inst(new Instruction( + GetContext(), SpvOpSLessThan, type, GetContext()->TakeNextId(), + {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}})); + return AddInstruction(std::move(inst)); + } + + // Creates an OpSLessThan or OpULessThan instruction depending on the sign of + // |op1|. The id |op1| is the left hand side of the operation. The id |op2| is + // the right hand side of the operation. It is assumed that |op1| and |op2| + // have the same underlying type. + Instruction* AddLessThan(uint32_t op1, uint32_t op2) { + Instruction* op1_insn = context_->get_def_use_mgr()->GetDef(op1); + analysis::Type* type = + GetContext()->get_type_mgr()->GetType(op1_insn->type_id()); + analysis::Integer* int_type = type->AsInteger(); + assert(int_type && "Operand is not of int type"); + + if (int_type->IsSigned()) + return AddSLessThan(op1, op2); + else + return AddULessThan(op1, op2); + } + + // Creates a select instruction. + // |type| must match the types of |true_value| and |false_value|.
It is up to + // the caller to ensure that |cond| is a correct type (bool or vector of + // bool) for |type|. + Instruction* AddSelect(uint32_t type, uint32_t cond, uint32_t true_value, + uint32_t false_value) { + // TODO(1841): Handle id overflow. + std::unique_ptr select(new Instruction( + GetContext(), SpvOpSelect, type, GetContext()->TakeNextId(), + std::initializer_list{{SPV_OPERAND_TYPE_ID, {cond}}, + {SPV_OPERAND_TYPE_ID, {true_value}}, + {SPV_OPERAND_TYPE_ID, {false_value}}})); + return AddInstruction(std::move(select)); + } + + // Adds a signed int32 constant to the binary. + // The |value| parameter is the constant value to be added. + Instruction* GetSintConstant(int32_t value) { + return GetIntConstant(value, true); + } + + // Create a composite construct. + // |type| should be a composite type and the number of elements it has should + // match the size of |ids|. + Instruction* AddCompositeConstruct(uint32_t type, + const std::vector& ids) { + std::vector ops; + for (auto id : ids) { + ops.emplace_back(SPV_OPERAND_TYPE_ID, + std::initializer_list{id}); + } + // TODO(1841): Handle id overflow. + std::unique_ptr construct( + new Instruction(GetContext(), SpvOpCompositeConstruct, type, + GetContext()->TakeNextId(), ops)); + return AddInstruction(std::move(construct)); + } + // Adds an unsigned int32 constant to the binary. + // The |value| parameter is the constant value to be added.
+ Instruction* GetUintConstant(uint32_t value) { + return GetIntConstant(value, false); + } + + uint32_t GetUintConstantId(uint32_t value) { + Instruction* uint_inst = GetUintConstant(value); + return uint_inst->result_id(); + } + + uint32_t GetNullId(uint32_t type_id) { + analysis::TypeManager* type_mgr = GetContext()->get_type_mgr(); + analysis::ConstantManager* const_mgr = GetContext()->get_constant_mgr(); + const analysis::Type* type = type_mgr->GetType(type_id); + const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); + Instruction* null_inst = + const_mgr->GetDefiningInstruction(null_const, type_id); + return null_inst->result_id(); + } + + // Adds either a signed or unsigned 32 bit integer constant to the binary + // depending on the |sign|. If |sign| is true then the value is added as a + // signed constant otherwise as an unsigned constant. If |sign| is false the + // value must not be a negative number. + template + Instruction* GetIntConstant(T value, bool sign) { + // Assert that we are not trying to store a negative number in an unsigned + // type. + if (!sign) + assert(value >= 0 && + "Trying to add a signed integer with an unsigned type!"); + + analysis::Integer int_type{32, sign}; + + // Get or create the integer type. This rebuilds the type and manages the + // memory for the rebuilt type. + uint32_t type_id = + GetContext()->get_type_mgr()->GetTypeInstruction(&int_type); + + // Get the memory managed type so that it is safe to be stored by + // GetConstant. + analysis::Type* rebuilt_type = + GetContext()->get_type_mgr()->GetType(type_id); + + // Even if the value is negative we need to pass the bit pattern as a + // uint32_t to GetConstant. + uint32_t word = value; + + // Create the constant value. + const analysis::Constant* constant = + GetContext()->get_constant_mgr()->GetConstant(rebuilt_type, {word}); + + // Create the OpConstant instruction using the type and the value.
+ return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant); + } + + Instruction* AddCompositeExtract(uint32_t type, uint32_t id_of_composite, + const std::vector& index_list) { + std::vector operands; + operands.push_back({SPV_OPERAND_TYPE_ID, {id_of_composite}}); + + for (uint32_t index : index_list) { + operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}); + } + + // TODO(1841): Handle id overflow. + std::unique_ptr new_inst( + new Instruction(GetContext(), SpvOpCompositeExtract, type, + GetContext()->TakeNextId(), operands)); + return AddInstruction(std::move(new_inst)); + } + + // Creates an unreachable instruction. + Instruction* AddUnreachable() { + std::unique_ptr select( + new Instruction(GetContext(), SpvOpUnreachable, 0, 0, + std::initializer_list{})); + return AddInstruction(std::move(select)); + } + + Instruction* AddAccessChain(uint32_t type_id, uint32_t base_ptr_id, + std::vector ids) { + std::vector operands; + operands.push_back({SPV_OPERAND_TYPE_ID, {base_ptr_id}}); + + for (uint32_t index_id : ids) { + operands.push_back({SPV_OPERAND_TYPE_ID, {index_id}}); + } + + // TODO(1841): Handle id overflow. + std::unique_ptr new_inst( + new Instruction(GetContext(), SpvOpAccessChain, type_id, + GetContext()->TakeNextId(), operands)); + return AddInstruction(std::move(new_inst)); + } + + Instruction* AddLoad(uint32_t type_id, uint32_t base_ptr_id) { + std::vector operands; + operands.push_back({SPV_OPERAND_TYPE_ID, {base_ptr_id}}); + + // TODO(1841): Handle id overflow. 
+ std::unique_ptr new_inst( + new Instruction(GetContext(), SpvOpLoad, type_id, + GetContext()->TakeNextId(), operands)); + return AddInstruction(std::move(new_inst)); + } + + Instruction* AddStore(uint32_t ptr_id, uint32_t obj_id) { + std::vector operands; + operands.push_back({SPV_OPERAND_TYPE_ID, {ptr_id}}); + operands.push_back({SPV_OPERAND_TYPE_ID, {obj_id}}); + + std::unique_ptr new_inst( + new Instruction(GetContext(), SpvOpStore, 0, 0, operands)); + return AddInstruction(std::move(new_inst)); + } + + // Inserts the new instruction before the insertion point. + Instruction* AddInstruction(std::unique_ptr&& insn) { + Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn)); + UpdateInstrToBlockMapping(insn_ptr); + UpdateDefUseMgr(insn_ptr); + return insn_ptr; + } + + // Returns the insertion point iterator. + InsertionPointTy GetInsertPoint() { return insert_before_; } + + // Change the insertion point to insert before the instruction + // |insert_before|. + void SetInsertPoint(Instruction* insert_before) { + parent_ = context_->get_instr_block(insert_before); + insert_before_ = InsertionPointTy(insert_before); + } + + // Change the insertion point to insert at the end of the basic block + // |parent_block|. + void SetInsertPoint(BasicBlock* parent_block) { + parent_ = parent_block; + insert_before_ = parent_block->end(); + } + + // Returns the context which instructions are constructed for. + IRContext* GetContext() const { return context_; } + + // Returns the set of preserved analyses. 
+ inline IRContext::Analysis GetPreservedAnalysis() const { + return preserved_analyses_; + } + + private: + InstructionBuilder(IRContext* context, BasicBlock* parent, + InsertionPointTy insert_before, + IRContext::Analysis preserved_analyses) + : context_(context), + parent_(parent), + insert_before_(insert_before), + preserved_analyses_(preserved_analyses) { + assert(!(preserved_analyses_ & ~(IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping))); + } + + // Returns true if the users requested to update |analysis|. + inline bool IsAnalysisUpdateRequested(IRContext::Analysis analysis) const { + return preserved_analyses_ & analysis; + } + + // Updates the def/use manager if the user requested it. If he did not request + // an update, this function does nothing. + inline void UpdateDefUseMgr(Instruction* insn) { + if (IsAnalysisUpdateRequested(IRContext::kAnalysisDefUse)) + GetContext()->get_def_use_mgr()->AnalyzeInstDefUse(insn); + } + + // Updates the instruction to block analysis if the user requested it. If he + // did not request an update, this function does nothing. + inline void UpdateInstrToBlockMapping(Instruction* insn) { + if (IsAnalysisUpdateRequested(IRContext::kAnalysisInstrToBlockMapping) && + parent_) + GetContext()->set_instr_block(insn, parent_); + } + + IRContext* context_; + BasicBlock* parent_; + InsertionPointTy insert_before_; + const IRContext::Analysis preserved_analyses_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_IR_BUILDER_H_ diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp new file mode 100644 index 000000000..a31349f45 --- /dev/null +++ b/source/opt/ir_context.cpp @@ -0,0 +1,850 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/ir_context.h" + +#include + +#include "source/latest_version_glsl_std_450_header.h" +#include "source/opt/log.h" +#include "source/opt/mem_pass.h" +#include "source/opt/reflect.h" + +namespace { + +static const int kSpvDecorateTargetIdInIdx = 0; +static const int kSpvDecorateDecorationInIdx = 1; +static const int kSpvDecorateBuiltinInIdx = 2; +static const int kEntryPointInterfaceInIdx = 3; +static const int kEntryPointFunctionIdInIdx = 1; + +} // anonymous namespace + +namespace spvtools { +namespace opt { + +void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) { + if (set & kAnalysisDefUse) { + BuildDefUseManager(); + } + if (set & kAnalysisInstrToBlockMapping) { + BuildInstrToBlockMapping(); + } + if (set & kAnalysisDecorations) { + BuildDecorationManager(); + } + if (set & kAnalysisCFG) { + BuildCFG(); + } + if (set & kAnalysisDominatorAnalysis) { + ResetDominatorAnalysis(); + } + if (set & kAnalysisLoopAnalysis) { + ResetLoopAnalysis(); + } + if (set & kAnalysisBuiltinVarId) { + ResetBuiltinAnalysis(); + } + if (set & kAnalysisNameMap) { + BuildIdToNameMap(); + } + if (set & kAnalysisScalarEvolution) { + BuildScalarEvolutionAnalysis(); + } + if (set & kAnalysisRegisterPressure) { + BuildRegPressureAnalysis(); + } + if (set & kAnalysisValueNumberTable) { + BuildValueNumberTable(); + } + if (set & kAnalysisStructuredCFG) { + BuildStructuredCFGAnalysis(); + } + if (set & kAnalysisIdToFuncMapping) { + BuildIdToFuncMapping(); + } + if (set & kAnalysisConstants) { + BuildConstantManager(); + } + if (set & 
kAnalysisTypes) { + BuildTypeManager(); + } +} + +void IRContext::InvalidateAnalysesExceptFor( + IRContext::Analysis preserved_analyses) { + uint32_t analyses_to_invalidate = valid_analyses_ & (~preserved_analyses); + InvalidateAnalyses(static_cast(analyses_to_invalidate)); +} + +void IRContext::InvalidateAnalyses(IRContext::Analysis analyses_to_invalidate) { + if (analyses_to_invalidate & kAnalysisDefUse) { + def_use_mgr_.reset(nullptr); + } + if (analyses_to_invalidate & kAnalysisInstrToBlockMapping) { + instr_to_block_.clear(); + } + if (analyses_to_invalidate & kAnalysisDecorations) { + decoration_mgr_.reset(nullptr); + } + if (analyses_to_invalidate & kAnalysisCombinators) { + combinator_ops_.clear(); + } + if (analyses_to_invalidate & kAnalysisBuiltinVarId) { + builtin_var_id_map_.clear(); + } + if (analyses_to_invalidate & kAnalysisCFG) { + cfg_.reset(nullptr); + } + if (analyses_to_invalidate & kAnalysisDominatorAnalysis) { + dominator_trees_.clear(); + post_dominator_trees_.clear(); + } + if (analyses_to_invalidate & kAnalysisNameMap) { + id_to_name_.reset(nullptr); + } + if (analyses_to_invalidate & kAnalysisValueNumberTable) { + vn_table_.reset(nullptr); + } + if (analyses_to_invalidate & kAnalysisStructuredCFG) { + struct_cfg_analysis_.reset(nullptr); + } + if (analyses_to_invalidate & kAnalysisIdToFuncMapping) { + id_to_func_.clear(); + } + if (analyses_to_invalidate & kAnalysisConstants) { + constant_mgr_.reset(nullptr); + } + if (analyses_to_invalidate & kAnalysisTypes) { + // The ConstantManager contains Type pointers. If the TypeManager goes + // away, the ConstantManager has to go away. 
+    constant_mgr_.reset(nullptr);
+    type_mgr_.reset(nullptr);
+    // Resetting constant_mgr_ above destroys the constant analysis as a side
+    // effect, so its valid bit has to be cleared too. Otherwise
+    // get_constant_mgr() would still see kAnalysisConstants as valid, skip the
+    // rebuild, and hand back a null pointer.
+    analyses_to_invalidate =
+        Analysis(analyses_to_invalidate | kAnalysisConstants);
+  }
+
+  valid_analyses_ = Analysis(valid_analyses_ & ~analyses_to_invalidate);
+}
+
+// Removes |inst| from the IR and from every analysis that is currently valid.
+// Names and decorations attached to the result id of |inst| are killed first.
+// Returns the instruction that followed |inst| in its list, or nullptr when
+// |inst| is null or is not part of an instruction list.
+Instruction* IRContext::KillInst(Instruction* inst) {
+  if (!inst) {
+    return nullptr;
+  }
+
+  KillNamesAndDecorates(inst);
+
+  if (AreAnalysesValid(kAnalysisDefUse)) {
+    get_def_use_mgr()->ClearInst(inst);
+  }
+  if (AreAnalysesValid(kAnalysisInstrToBlockMapping)) {
+    instr_to_block_.erase(inst);
+  }
+  if (AreAnalysesValid(kAnalysisDecorations)) {
+    if (inst->IsDecoration()) {
+      decoration_mgr_->RemoveDecoration(inst);
+    }
+  }
+
+  // The type and constant managers are updated whenever they exist, whether
+  // or not their analysis bit is currently set.
+  if (type_mgr_ && IsTypeInst(inst->opcode())) {
+    type_mgr_->RemoveId(inst->result_id());
+  }
+
+  if (constant_mgr_ && IsConstantInst(inst->opcode())) {
+    constant_mgr_->RemoveId(inst->result_id());
+  }
+
+  RemoveFromIdToName(inst);
+
+  Instruction* next_instruction = nullptr;
+  if (inst->IsInAList()) {
+    // Capture the successor before unlinking; |inst| is freed below.
+    next_instruction = inst->NextNode();
+    inst->RemoveFromList();
+    delete inst;
+  } else {
+    // Needed for instructions that are not part of a list like OpLabels,
+    // OpFunction, OpFunctionEnd, etc..
+    inst->ToNop();
+  }
+  return next_instruction;
+}
+
+// Kills the instruction that defines |id|. Returns false when the def-use
+// manager has no definition registered for |id|.
+bool IRContext::KillDef(uint32_t id) {
+  Instruction* def = get_def_use_mgr()->GetDef(id);
+  if (def != nullptr) {
+    KillInst(def);
+    return true;
+  }
+  return false;
+}
+
+bool IRContext::ReplaceAllUsesWith(uint32_t before, uint32_t after) {
+  if (before == after) return false;
+
+  // Ensure that |after| has been registered as def.
+ assert(get_def_use_mgr()->GetDef(after) && + "'after' is not a registered def."); + + std::vector> uses_to_update; + get_def_use_mgr()->ForEachUse( + before, [&uses_to_update](Instruction* user, uint32_t index) { + uses_to_update.emplace_back(user, index); + }); + + Instruction* prev = nullptr; + for (auto p : uses_to_update) { + Instruction* user = p.first; + uint32_t index = p.second; + if (prev == nullptr || prev != user) { + ForgetUses(user); + prev = user; + } + const uint32_t type_result_id_count = + (user->result_id() != 0) + (user->type_id() != 0); + + if (index < type_result_id_count) { + // Update the type_id. Note that result id is immutable so it should + // never be updated. + if (user->type_id() != 0 && index == 0) { + user->SetResultType(after); + } else if (user->type_id() == 0) { + SPIRV_ASSERT(consumer_, false, + "Result type id considered as use while the instruction " + "doesn't have a result type id."); + (void)consumer_; // Makes the compiler happy for release build. + } else { + SPIRV_ASSERT(consumer_, false, + "Trying setting the immutable result id."); + } + } else { + // Update an in-operand. + uint32_t in_operand_pos = index - type_result_id_count; + // Make the modification in the instruction. 
+ user->SetInOperand(in_operand_pos, {after}); + } + AnalyzeUses(user); + } + + return true; +} + +bool IRContext::IsConsistent() { +#ifndef SPIRV_CHECK_CONTEXT + return true; +#endif + + if (AreAnalysesValid(kAnalysisDefUse)) { + analysis::DefUseManager new_def_use(module()); + if (*get_def_use_mgr() != new_def_use) { + return false; + } + } + + if (AreAnalysesValid(kAnalysisInstrToBlockMapping)) { + for (auto& func : *module()) { + for (auto& block : func) { + if (!block.WhileEachInst([this, &block](Instruction* inst) { + if (get_instr_block(inst) != &block) { + return false; + } + return true; + })) + return false; + } + } + } + + if (!CheckCFG()) { + return false; + } + + if (AreAnalysesValid(kAnalysisDecorations)) { + analysis::DecorationManager* dec_mgr = get_decoration_mgr(); + analysis::DecorationManager current(module()); + + if (*dec_mgr != current) { + return false; + } + } + return true; +} + +void IRContext::ForgetUses(Instruction* inst) { + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->EraseUseRecordsOfOperandIds(inst); + } + if (AreAnalysesValid(kAnalysisDecorations)) { + if (inst->IsDecoration()) { + get_decoration_mgr()->RemoveDecoration(inst); + } + } + RemoveFromIdToName(inst); +} + +void IRContext::AnalyzeUses(Instruction* inst) { + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->AnalyzeInstUse(inst); + } + if (AreAnalysesValid(kAnalysisDecorations)) { + if (inst->IsDecoration()) { + get_decoration_mgr()->AddDecoration(inst); + } + } + if (id_to_name_ && + (inst->opcode() == SpvOpName || inst->opcode() == SpvOpMemberName)) { + id_to_name_->insert({inst->GetSingleWordInOperand(0), inst}); + } +} + +void IRContext::KillNamesAndDecorates(uint32_t id) { + analysis::DecorationManager* dec_mgr = get_decoration_mgr(); + dec_mgr->RemoveDecorationsFrom(id); + + std::vector name_to_kill; + for (auto name : GetNames(id)) { + name_to_kill.push_back(name.second); + } + for (Instruction* name_inst : name_to_kill) { + 
KillInst(name_inst); + } +} + +void IRContext::KillNamesAndDecorates(Instruction* inst) { + const uint32_t rId = inst->result_id(); + if (rId == 0) return; + KillNamesAndDecorates(rId); +} + +void IRContext::AddCombinatorsForCapability(uint32_t capability) { + if (capability == SpvCapabilityShader) { + combinator_ops_[0].insert({SpvOpNop, + SpvOpUndef, + SpvOpConstant, + SpvOpConstantTrue, + SpvOpConstantFalse, + SpvOpConstantComposite, + SpvOpConstantSampler, + SpvOpConstantNull, + SpvOpTypeVoid, + SpvOpTypeBool, + SpvOpTypeInt, + SpvOpTypeFloat, + SpvOpTypeVector, + SpvOpTypeMatrix, + SpvOpTypeImage, + SpvOpTypeSampler, + SpvOpTypeSampledImage, + SpvOpTypeAccelerationStructureNV, + SpvOpTypeArray, + SpvOpTypeRuntimeArray, + SpvOpTypeStruct, + SpvOpTypeOpaque, + SpvOpTypePointer, + SpvOpTypeFunction, + SpvOpTypeEvent, + SpvOpTypeDeviceEvent, + SpvOpTypeReserveId, + SpvOpTypeQueue, + SpvOpTypePipe, + SpvOpTypeForwardPointer, + SpvOpVariable, + SpvOpImageTexelPointer, + SpvOpLoad, + SpvOpAccessChain, + SpvOpInBoundsAccessChain, + SpvOpArrayLength, + SpvOpVectorExtractDynamic, + SpvOpVectorInsertDynamic, + SpvOpVectorShuffle, + SpvOpCompositeConstruct, + SpvOpCompositeExtract, + SpvOpCompositeInsert, + SpvOpCopyObject, + SpvOpTranspose, + SpvOpSampledImage, + SpvOpImageSampleImplicitLod, + SpvOpImageSampleExplicitLod, + SpvOpImageSampleDrefImplicitLod, + SpvOpImageSampleDrefExplicitLod, + SpvOpImageSampleProjImplicitLod, + SpvOpImageSampleProjExplicitLod, + SpvOpImageSampleProjDrefImplicitLod, + SpvOpImageSampleProjDrefExplicitLod, + SpvOpImageFetch, + SpvOpImageGather, + SpvOpImageDrefGather, + SpvOpImageRead, + SpvOpImage, + SpvOpImageQueryFormat, + SpvOpImageQueryOrder, + SpvOpImageQuerySizeLod, + SpvOpImageQuerySize, + SpvOpImageQueryLevels, + SpvOpImageQuerySamples, + SpvOpConvertFToU, + SpvOpConvertFToS, + SpvOpConvertSToF, + SpvOpConvertUToF, + SpvOpUConvert, + SpvOpSConvert, + SpvOpFConvert, + SpvOpQuantizeToF16, + SpvOpBitcast, + SpvOpSNegate, + 
SpvOpFNegate, + SpvOpIAdd, + SpvOpFAdd, + SpvOpISub, + SpvOpFSub, + SpvOpIMul, + SpvOpFMul, + SpvOpUDiv, + SpvOpSDiv, + SpvOpFDiv, + SpvOpUMod, + SpvOpSRem, + SpvOpSMod, + SpvOpFRem, + SpvOpFMod, + SpvOpVectorTimesScalar, + SpvOpMatrixTimesScalar, + SpvOpVectorTimesMatrix, + SpvOpMatrixTimesVector, + SpvOpMatrixTimesMatrix, + SpvOpOuterProduct, + SpvOpDot, + SpvOpIAddCarry, + SpvOpISubBorrow, + SpvOpUMulExtended, + SpvOpSMulExtended, + SpvOpAny, + SpvOpAll, + SpvOpIsNan, + SpvOpIsInf, + SpvOpLogicalEqual, + SpvOpLogicalNotEqual, + SpvOpLogicalOr, + SpvOpLogicalAnd, + SpvOpLogicalNot, + SpvOpSelect, + SpvOpIEqual, + SpvOpINotEqual, + SpvOpUGreaterThan, + SpvOpSGreaterThan, + SpvOpUGreaterThanEqual, + SpvOpSGreaterThanEqual, + SpvOpULessThan, + SpvOpSLessThan, + SpvOpULessThanEqual, + SpvOpSLessThanEqual, + SpvOpFOrdEqual, + SpvOpFUnordEqual, + SpvOpFOrdNotEqual, + SpvOpFUnordNotEqual, + SpvOpFOrdLessThan, + SpvOpFUnordLessThan, + SpvOpFOrdGreaterThan, + SpvOpFUnordGreaterThan, + SpvOpFOrdLessThanEqual, + SpvOpFUnordLessThanEqual, + SpvOpFOrdGreaterThanEqual, + SpvOpFUnordGreaterThanEqual, + SpvOpShiftRightLogical, + SpvOpShiftRightArithmetic, + SpvOpShiftLeftLogical, + SpvOpBitwiseOr, + SpvOpBitwiseXor, + SpvOpBitwiseAnd, + SpvOpNot, + SpvOpBitFieldInsert, + SpvOpBitFieldSExtract, + SpvOpBitFieldUExtract, + SpvOpBitReverse, + SpvOpBitCount, + SpvOpPhi, + SpvOpImageSparseSampleImplicitLod, + SpvOpImageSparseSampleExplicitLod, + SpvOpImageSparseSampleDrefImplicitLod, + SpvOpImageSparseSampleDrefExplicitLod, + SpvOpImageSparseSampleProjImplicitLod, + SpvOpImageSparseSampleProjExplicitLod, + SpvOpImageSparseSampleProjDrefImplicitLod, + SpvOpImageSparseSampleProjDrefExplicitLod, + SpvOpImageSparseFetch, + SpvOpImageSparseGather, + SpvOpImageSparseDrefGather, + SpvOpImageSparseTexelsResident, + SpvOpImageSparseRead, + SpvOpSizeOf}); + } +} + +void IRContext::AddCombinatorsForExtension(Instruction* extension) { + assert(extension->opcode() == SpvOpExtInstImport && + 
"Expecting an import of an extension's instruction set."); + const char* extension_name = + reinterpret_cast(&extension->GetInOperand(0).words[0]); + if (!strcmp(extension_name, "GLSL.std.450")) { + combinator_ops_[extension->result_id()] = {GLSLstd450Round, + GLSLstd450RoundEven, + GLSLstd450Trunc, + GLSLstd450FAbs, + GLSLstd450SAbs, + GLSLstd450FSign, + GLSLstd450SSign, + GLSLstd450Floor, + GLSLstd450Ceil, + GLSLstd450Fract, + GLSLstd450Radians, + GLSLstd450Degrees, + GLSLstd450Sin, + GLSLstd450Cos, + GLSLstd450Tan, + GLSLstd450Asin, + GLSLstd450Acos, + GLSLstd450Atan, + GLSLstd450Sinh, + GLSLstd450Cosh, + GLSLstd450Tanh, + GLSLstd450Asinh, + GLSLstd450Acosh, + GLSLstd450Atanh, + GLSLstd450Atan2, + GLSLstd450Pow, + GLSLstd450Exp, + GLSLstd450Log, + GLSLstd450Exp2, + GLSLstd450Log2, + GLSLstd450Sqrt, + GLSLstd450InverseSqrt, + GLSLstd450Determinant, + GLSLstd450MatrixInverse, + GLSLstd450ModfStruct, + GLSLstd450FMin, + GLSLstd450UMin, + GLSLstd450SMin, + GLSLstd450FMax, + GLSLstd450UMax, + GLSLstd450SMax, + GLSLstd450FClamp, + GLSLstd450UClamp, + GLSLstd450SClamp, + GLSLstd450FMix, + GLSLstd450IMix, + GLSLstd450Step, + GLSLstd450SmoothStep, + GLSLstd450Fma, + GLSLstd450FrexpStruct, + GLSLstd450Ldexp, + GLSLstd450PackSnorm4x8, + GLSLstd450PackUnorm4x8, + GLSLstd450PackSnorm2x16, + GLSLstd450PackUnorm2x16, + GLSLstd450PackHalf2x16, + GLSLstd450PackDouble2x32, + GLSLstd450UnpackSnorm2x16, + GLSLstd450UnpackUnorm2x16, + GLSLstd450UnpackHalf2x16, + GLSLstd450UnpackSnorm4x8, + GLSLstd450UnpackUnorm4x8, + GLSLstd450UnpackDouble2x32, + GLSLstd450Length, + GLSLstd450Distance, + GLSLstd450Cross, + GLSLstd450Normalize, + GLSLstd450FaceForward, + GLSLstd450Reflect, + GLSLstd450Refract, + GLSLstd450FindILsb, + GLSLstd450FindSMsb, + GLSLstd450FindUMsb, + GLSLstd450InterpolateAtCentroid, + GLSLstd450InterpolateAtSample, + GLSLstd450InterpolateAtOffset, + GLSLstd450NMin, + GLSLstd450NMax, + GLSLstd450NClamp}; + } else { + // Map the result id to the empty set. 
+    combinator_ops_[extension->result_id()];
+  }
+}
+
+// Populates |combinator_ops_| from the module's capabilities and extended
+// instruction set imports, then marks the combinator analysis as valid.
+void IRContext::InitializeCombinators() {
+  get_feature_mgr()->GetCapabilities()->ForEach(
+      [this](SpvCapability cap) { AddCombinatorsForCapability(cap); });
+
+  for (auto& extension : module()->ext_inst_imports()) {
+    AddCombinatorsForExtension(&extension);
+  }
+
+  valid_analyses_ |= kAnalysisCombinators;
+}
+
+// Removes the entry for the OpName/OpMemberName instruction |inst| from the
+// id-to-name map, if that map has been built. Other instructions are ignored.
+void IRContext::RemoveFromIdToName(const Instruction* inst) {
+  if (id_to_name_ &&
+      (inst->opcode() == SpvOpName || inst->opcode() == SpvOpMemberName)) {
+    auto range = id_to_name_->equal_range(inst->GetSingleWordInOperand(0));
+    for (auto it = range.first; it != range.second; ++it) {
+      if (it->second == inst) {
+        // Stop right after erasing: |it| must not be advanced afterwards.
+        id_to_name_->erase(it);
+        break;
+      }
+    }
+  }
+}
+
+// Returns the loop descriptor of |f|, creating and caching it on first use.
+LoopDescriptor* IRContext::GetLoopDescriptor(const Function* f) {
+  if (!AreAnalysesValid(kAnalysisLoopAnalysis)) {
+    ResetLoopAnalysis();
+  }
+
+  auto it = loop_descriptors_.find(f);
+  if (it == loop_descriptors_.end()) {
+    return &loop_descriptors_
+                .emplace(std::make_pair(f, LoopDescriptor(this, f)))
+                .first->second;
+  }
+
+  return &it->second;
+}
+
+// Returns the id of the first OpVariable that carries an OpDecorate BuiltIn
+// decoration equal to |builtin|, or 0 if the module has none.
+uint32_t IRContext::FindBuiltinVar(uint32_t builtin) {
+  for (auto& a : module_->annotations()) {
+    if (a.opcode() != SpvOpDecorate) continue;
+    if (a.GetSingleWordInOperand(kSpvDecorateDecorationInIdx) !=
+        SpvDecorationBuiltIn)
+      continue;
+    if (a.GetSingleWordInOperand(kSpvDecorateBuiltinInIdx) != builtin) continue;
+    uint32_t target_id = a.GetSingleWordInOperand(kSpvDecorateTargetIdInIdx);
+    Instruction* b_var = get_def_use_mgr()->GetDef(target_id);
+    // GetDef() yields nullptr for an id without a registered definition; skip
+    // such decorations instead of dereferencing a null pointer.
+    if (b_var == nullptr || b_var->opcode() != SpvOpVariable) continue;
+    return target_id;
+  }
+  return 0;
+}
+
+// Appends |var_id| to the interface list of every entry point that does not
+// already list it, and re-analyzes each entry point that was changed.
+void IRContext::AddVarToEntryPoints(uint32_t var_id) {
+  for (auto& e : module()->entry_points()) {
+    // The operand counter must restart at zero for every entry point. If it
+    // kept counting across entry points, the leading non-interface operands
+    // of the second and later entry points would be compared against
+    // |var_id| as well.
+    uint32_t ocnt = 0;
+    bool found = false;
+    e.ForEachInOperand([&ocnt, &found, &var_id](const uint32_t* idp) {
+      if (ocnt >= kEntryPointInterfaceInIdx) {
+        if (*idp == var_id) found = true;
+      }
+      ++ocnt;
+    });
+    if (!found) {
+      e.AddOperand({SPV_OPERAND_TYPE_ID, {var_id}});
+      get_def_use_mgr()->AnalyzeInstDefUse(&e);
+    }
+  }
+}
+
+// Returns the id of the input variable decorated with |builtin|. The result
+// is cached; if the module has no such variable one is created. Returns 0
+// for builtins that are not supported yet.
+uint32_t IRContext::GetBuiltinVarId(uint32_t builtin) {
+  if (!AreAnalysesValid(kAnalysisBuiltinVarId)) ResetBuiltinAnalysis();
+  // If cached, return it.
+  auto it = builtin_var_id_map_.find(builtin);
+  if (it != builtin_var_id_map_.end()) return it->second;
+  // Look for one in shader
+  uint32_t var_id = FindBuiltinVar(builtin);
+  if (var_id == 0) {
+    // If not found, create it
+    // TODO(greg-lunarg): Add support for all builtins
+    analysis::TypeManager* type_mgr = get_type_mgr();
+    analysis::Type* reg_type = nullptr;
+    switch (builtin) {
+      case SpvBuiltInFragCoord: {
+        analysis::Float float_ty(32);
+        analysis::Type* reg_float_ty = type_mgr->GetRegisteredType(&float_ty);
+        analysis::Vector v4float_ty(reg_float_ty, 4);
+        reg_type = type_mgr->GetRegisteredType(&v4float_ty);
+        break;
+      }
+      case SpvBuiltInVertexIndex:
+      case SpvBuiltInInstanceIndex:
+      case SpvBuiltInPrimitiveId:
+      case SpvBuiltInInvocationId:
+      // NOTE(review): GlobalInvocationId is normally a 3-component vector of
+      // 32-bit integers; confirm that a scalar is intended here.
+      case SpvBuiltInGlobalInvocationId: {
+        analysis::Integer uint_ty(32, false);
+        reg_type = type_mgr->GetRegisteredType(&uint_ty);
+        break;
+      }
+      default: {
+        assert(false && "unhandled builtin");
+        return 0;
+      }
+    }
+    uint32_t type_id = type_mgr->GetTypeInstruction(reg_type);
+    uint32_t varTyPtrId =
+        type_mgr->FindPointerToType(type_id, SpvStorageClassInput);
+    // TODO(1841): Handle id overflow.
+ var_id = TakeNextId(); + std::unique_ptr newVarOp( + new Instruction(this, SpvOpVariable, varTyPtrId, var_id, + {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, + {SpvStorageClassInput}}})); + get_def_use_mgr()->AnalyzeInstDefUse(&*newVarOp); + module()->AddGlobalValue(std::move(newVarOp)); + get_decoration_mgr()->AddDecorationVal(var_id, SpvDecorationBuiltIn, + builtin); + AddVarToEntryPoints(var_id); + } + builtin_var_id_map_[builtin] = var_id; + return var_id; +} + +void IRContext::AddCalls(const Function* func, std::queue* todo) { + for (auto bi = func->begin(); bi != func->end(); ++bi) + for (auto ii = bi->begin(); ii != bi->end(); ++ii) + if (ii->opcode() == SpvOpFunctionCall) + todo->push(ii->GetSingleWordInOperand(0)); +} + +bool IRContext::ProcessEntryPointCallTree(ProcessFunction& pfn) { + // Collect all of the entry points as the roots. + std::queue roots; + for (auto& e : module()->entry_points()) { + roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)); + } + return ProcessCallTreeFromRoots(pfn, &roots); +} + +bool IRContext::ProcessReachableCallTree(ProcessFunction& pfn) { + std::queue roots; + + // Add all entry points since they can be reached from outside the module. + for (auto& e : module()->entry_points()) + roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)); + + // Add all exported functions since they can be reached from outside the + // module. + for (auto& a : annotations()) { + // TODO: Handle group decorations as well. Currently not generate by any + // front-end, but could be coming. 
+    if (a.opcode() == SpvOp::SpvOpDecorate) {
+      if (a.GetSingleWordOperand(1) ==
+          SpvDecoration::SpvDecorationLinkageAttributes) {
+        // The linkage type is the last operand of the decoration.
+        uint32_t lastOperand = a.NumOperands() - 1;
+        if (a.GetSingleWordOperand(lastOperand) ==
+            SpvLinkageType::SpvLinkageTypeExport) {
+          uint32_t id = a.GetSingleWordOperand(0);
+          if (GetFunction(id)) {
+            roots.push(id);
+          }
+        }
+      }
+    }
+  }
+
+  return ProcessCallTreeFromRoots(pfn, &roots);
+}
+
+// Applies |pfn| once to every function reachable through OpFunctionCall from
+// the function ids queued in |roots|. |roots| is consumed as a work list.
+// Returns true if any invocation of |pfn| returned true.
+bool IRContext::ProcessCallTreeFromRoots(ProcessFunction& pfn,
+                                         std::queue<uint32_t>* roots) {
+  // Process call tree
+  bool modified = false;
+  std::unordered_set<uint32_t> done;
+
+  while (!roots->empty()) {
+    const uint32_t fi = roots->front();
+    roots->pop();
+    if (done.insert(fi).second) {
+      Function* fn = GetFunction(fi);
+      // GetFunction() returns nullptr for an id that does not name a function
+      // in this module. Skip it instead of handing a null pointer to |pfn|
+      // and dereferencing it in AddCalls().
+      if (fn == nullptr) continue;
+      modified = pfn(fn) || modified;
+      AddCalls(fn, roots);
+    }
+  }
+  return modified;
+}
+
+// Gets the dominator analysis for function |f|.
+DominatorAnalysis* IRContext::GetDominatorAnalysis(const Function* f) {
+  if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) {
+    ResetDominatorAnalysis();
+  }
+
+  // Build the tree lazily the first time |f| is requested.
+  if (dominator_trees_.find(f) == dominator_trees_.end()) {
+    dominator_trees_[f].InitializeTree(*cfg(), f);
+  }
+
+  return &dominator_trees_[f];
+}
+
+// Gets the postdominator analysis for function |f|.
+PostDominatorAnalysis* IRContext::GetPostDominatorAnalysis(const Function* f) { + if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) { + ResetDominatorAnalysis(); + } + + if (post_dominator_trees_.find(f) == post_dominator_trees_.end()) { + post_dominator_trees_[f].InitializeTree(*cfg(), f); + } + + return &post_dominator_trees_[f]; +} + +bool IRContext::CheckCFG() { + std::unordered_map> real_preds; + if (!AreAnalysesValid(kAnalysisCFG)) { + return true; + } + + for (Function& function : *module()) { + for (const auto& bb : function) { + bb.ForEachSuccessorLabel([&bb, &real_preds](const uint32_t lab_id) { + real_preds[lab_id].push_back(bb.id()); + }); + } + + for (auto& bb : function) { + std::vector preds = cfg()->preds(bb.id()); + std::vector real = real_preds[bb.id()]; + std::sort(preds.begin(), preds.end()); + std::sort(real.begin(), real.end()); + + bool same = true; + if (preds.size() != real.size()) { + same = false; + } + + for (size_t i = 0; i < real.size() && same; i++) { + if (preds[i] != real[i]) { + same = false; + } + } + + if (!same) { + std::cerr << "Predecessors for " << bb.id() << " are different:\n"; + + std::cerr << "Real:"; + for (uint32_t i : real) { + std::cerr << ' ' << i; + } + std::cerr << std::endl; + + std::cerr << "Recorded:"; + for (uint32_t i : preds) { + std::cerr << ' ' << i; + } + std::cerr << std::endl; + } + if (!same) return false; + } + } + + return true; +} +} // namespace opt +} // namespace spvtools diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h new file mode 100644 index 000000000..185b4944e --- /dev/null +++ b/source/opt/ir_context.h @@ -0,0 +1,992 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_IR_CONTEXT_H_ +#define SOURCE_OPT_IR_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/opt/cfg.h" +#include "source/opt/constants.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/feature_manager.h" +#include "source/opt/fold.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/module.h" +#include "source/opt/register_pressure.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/struct_cfg_analysis.h" +#include "source/opt/type_manager.h" +#include "source/opt/value_number_table.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +class IRContext { + public: + // Available analyses. + // + // When adding a new analysis: + // + // 1. Enum values should be powers of 2. These are cast into uint32_t + // bitmasks, so we can have at most 31 analyses represented. + // + // 2. Make sure it gets invalidated or preserved by IRContext methods that add + // or remove IR elements (e.g., KillDef, KillInst, ReplaceAllUsesWith). + // + // 3. 
Add handling code in BuildInvalidAnalyses and InvalidateAnalyses + enum Analysis { + kAnalysisNone = 0 << 0, + kAnalysisBegin = 1 << 0, + kAnalysisDefUse = kAnalysisBegin, + kAnalysisInstrToBlockMapping = 1 << 1, + kAnalysisDecorations = 1 << 2, + kAnalysisCombinators = 1 << 3, + kAnalysisCFG = 1 << 4, + kAnalysisDominatorAnalysis = 1 << 5, + kAnalysisLoopAnalysis = 1 << 6, + kAnalysisNameMap = 1 << 7, + kAnalysisScalarEvolution = 1 << 8, + kAnalysisRegisterPressure = 1 << 9, + kAnalysisValueNumberTable = 1 << 10, + kAnalysisStructuredCFG = 1 << 11, + kAnalysisBuiltinVarId = 1 << 12, + kAnalysisIdToFuncMapping = 1 << 13, + kAnalysisConstants = 1 << 14, + kAnalysisTypes = 1 << 15, + kAnalysisEnd = 1 << 16 + }; + + using ProcessFunction = std::function; + + friend inline Analysis operator|(Analysis lhs, Analysis rhs); + friend inline Analysis& operator|=(Analysis& lhs, Analysis rhs); + friend inline Analysis operator<<(Analysis a, int shift); + friend inline Analysis& operator<<=(Analysis& a, int shift); + + // Creates an |IRContext| that contains an owned |Module| + IRContext(spv_target_env env, MessageConsumer c) + : syntax_context_(spvContextCreate(env)), + grammar_(syntax_context_), + unique_id_(0), + module_(new Module()), + consumer_(std::move(c)), + def_use_mgr_(nullptr), + valid_analyses_(kAnalysisNone), + constant_mgr_(nullptr), + type_mgr_(nullptr), + id_to_name_(nullptr), + max_id_bound_(kDefaultMaxIdBound) { + SetContextMessageConsumer(syntax_context_, consumer_); + module_->SetContext(this); + } + + IRContext(spv_target_env env, std::unique_ptr&& m, MessageConsumer c) + : syntax_context_(spvContextCreate(env)), + grammar_(syntax_context_), + unique_id_(0), + module_(std::move(m)), + consumer_(std::move(c)), + def_use_mgr_(nullptr), + valid_analyses_(kAnalysisNone), + type_mgr_(nullptr), + id_to_name_(nullptr), + max_id_bound_(kDefaultMaxIdBound) { + SetContextMessageConsumer(syntax_context_, consumer_); + module_->SetContext(this); + 
InitializeCombinators(); + } + + ~IRContext() { spvContextDestroy(syntax_context_); } + + Module* module() const { return module_.get(); } + + // Returns a vector of pointers to constant-creation instructions in this + // context. + inline std::vector GetConstants(); + inline std::vector GetConstants() const; + + // Iterators for annotation instructions contained in this context. + inline Module::inst_iterator annotation_begin(); + inline Module::inst_iterator annotation_end(); + inline IteratorRange annotations(); + inline IteratorRange annotations() const; + + // Iterators for capabilities instructions contained in this module. + inline Module::inst_iterator capability_begin(); + inline Module::inst_iterator capability_end(); + inline IteratorRange capabilities(); + inline IteratorRange capabilities() const; + + // Iterators for types, constants and global variables instructions. + inline Module::inst_iterator types_values_begin(); + inline Module::inst_iterator types_values_end(); + inline IteratorRange types_values(); + inline IteratorRange types_values() const; + + // Iterators for extension instructions contained in this module. + inline Module::inst_iterator ext_inst_import_begin(); + inline Module::inst_iterator ext_inst_import_end(); + inline IteratorRange ext_inst_imports(); + inline IteratorRange ext_inst_imports() const; + + // There are several kinds of debug instructions, according to where they can + // appear in the logical layout of a module: + // - Section 7a: OpString, OpSourceExtension, OpSource, OpSourceContinued + // - Section 7b: OpName, OpMemberName + // - Section 7c: OpModuleProcessed + // - Mostly anywhere: OpLine and OpNoLine + // + + // Iterators for debug 1 instructions (excluding OpLine & OpNoLine) contained + // in this module. These are for layout section 7a. 
+ inline Module::inst_iterator debug1_begin(); + inline Module::inst_iterator debug1_end(); + inline IteratorRange debugs1(); + inline IteratorRange debugs1() const; + + // Iterators for debug 2 instructions (excluding OpLine & OpNoLine) contained + // in this module. These are for layout section 7b. + inline Module::inst_iterator debug2_begin(); + inline Module::inst_iterator debug2_end(); + inline IteratorRange debugs2(); + inline IteratorRange debugs2() const; + + // Iterators for debug 3 instructions (excluding OpLine & OpNoLine) contained + // in this module. These are for layout section 7c. + inline Module::inst_iterator debug3_begin(); + inline Module::inst_iterator debug3_end(); + inline IteratorRange debugs3(); + inline IteratorRange debugs3() const; + + // Clears all debug instructions (excluding OpLine & OpNoLine). + inline void debug_clear(); + + // Appends a capability instruction to this module. + inline void AddCapability(std::unique_ptr&& c); + // Appends an extension instruction to this module. + inline void AddExtension(std::unique_ptr&& e); + // Appends an extended instruction set instruction to this module. + inline void AddExtInstImport(std::unique_ptr&& e); + // Set the memory model for this module. + inline void SetMemoryModel(std::unique_ptr&& m); + // Appends an entry point instruction to this module. + inline void AddEntryPoint(std::unique_ptr&& e); + // Appends an execution mode instruction to this module. + inline void AddExecutionMode(std::unique_ptr&& e); + // Appends a debug 1 instruction (excluding OpLine & OpNoLine) to this module. + // "debug 1" instructions are the ones in layout section 7.a), see section + // 2.4 Logical Layout of a Module from the SPIR-V specification. + inline void AddDebug1Inst(std::unique_ptr&& d); + // Appends a debug 2 instruction (excluding OpLine & OpNoLine) to this module. 
+ // "debug 2" instructions are the ones in layout section 7.b), see section + // 2.4 Logical Layout of a Module from the SPIR-V specification. + inline void AddDebug2Inst(std::unique_ptr&& d); + // Appends a debug 3 instruction (OpModuleProcessed) to this module. + // This is due to a decision by the SPIR Working Group, pending publication. + inline void AddDebug3Inst(std::unique_ptr&& d); + // Appends an annotation instruction to this module. + inline void AddAnnotationInst(std::unique_ptr&& a); + // Appends a type-declaration instruction to this module. + inline void AddType(std::unique_ptr&& t); + // Appends a constant, global variable, or OpUndef instruction to this module. + inline void AddGlobalValue(std::unique_ptr&& v); + // Appends a function to this module. + inline void AddFunction(std::unique_ptr&& f); + + // Returns a pointer to a def-use manager. If the def-use manager is + // invalid, it is rebuilt first. + analysis::DefUseManager* get_def_use_mgr() { + if (!AreAnalysesValid(kAnalysisDefUse)) { + BuildDefUseManager(); + } + return def_use_mgr_.get(); + } + + // Returns a pointer to a value number table. If the value number table is + // invalid, it is rebuilt first. + ValueNumberTable* GetValueNumberTable() { + if (!AreAnalysesValid(kAnalysisValueNumberTable)) { + BuildValueNumberTable(); + } + return vn_table_.get(); + } + + // Returns a pointer to a StructuredCFGAnalysis. If the analysis is invalid, + // it is rebuilt first. + StructuredCFGAnalysis* GetStructuredCFGAnalysis() { + if (!AreAnalysesValid(kAnalysisStructuredCFG)) { + BuildStructuredCFGAnalysis(); + } + return struct_cfg_analysis_.get(); + } + + // Returns a pointer to a liveness analysis. If the liveness analysis is + // invalid, it is rebuilt first. + LivenessAnalysis* GetLivenessAnalysis() { + if (!AreAnalysesValid(kAnalysisRegisterPressure)) { + BuildRegPressureAnalysis(); + } + return reg_pressure_.get(); + } + + // Returns the basic block for instruction |instr|.
Re-builds the instruction + // block map, if needed. + BasicBlock* get_instr_block(Instruction* instr) { + if (!AreAnalysesValid(kAnalysisInstrToBlockMapping)) { + BuildInstrToBlockMapping(); + } + auto entry = instr_to_block_.find(instr); + return (entry != instr_to_block_.end()) ? entry->second : nullptr; + } + + // Returns the basic block for |id|. Re-builds the instruction block map, if + // needed. + // + // |id| must be a registered definition. + BasicBlock* get_instr_block(uint32_t id) { + Instruction* def = get_def_use_mgr()->GetDef(id); + return get_instr_block(def); + } + + // Sets the basic block for |inst|. The mapping is only updated while it is + // valid; if it has been invalidated this call does nothing. + void set_instr_block(Instruction* inst, BasicBlock* block) { + if (AreAnalysesValid(kAnalysisInstrToBlockMapping)) { + instr_to_block_[inst] = block; + } + } + + // Returns a pointer to the decoration manager. If the decoration manager is + // invalid, it is rebuilt first. + analysis::DecorationManager* get_decoration_mgr() { + if (!AreAnalysesValid(kAnalysisDecorations)) { + BuildDecorationManager(); + } + return decoration_mgr_.get(); + } + + // Returns a pointer to the constant manager. If the constant manager is + // invalid, it is rebuilt first. + analysis::ConstantManager* get_constant_mgr() { + if (!AreAnalysesValid(kAnalysisConstants)) { + BuildConstantManager(); + } + return constant_mgr_.get(); + } + + // Returns a pointer to the type manager. If the type manager is invalid, it + // is rebuilt first. + analysis::TypeManager* get_type_mgr() { + if (!AreAnalysesValid(kAnalysisTypes)) { + BuildTypeManager(); + } + return type_mgr_.get(); + } + + // Returns a pointer to the scalar evolution analysis. If it is invalid it + // will be rebuilt first.
+ ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() { + if (!AreAnalysesValid(kAnalysisScalarEvolution)) { + BuildScalarEvolutionAnalysis(); + } + return scalar_evolution_analysis_.get(); + } + + // Builds the map from each id to the OpName and OpMemberName instructions + // associated with it. + inline void BuildIdToNameMap(); + + // Returns a range of instructions that contains all of the OpName and + // OpMemberName instructions associated with the given id. + inline IteratorRange::iterator> + GetNames(uint32_t id); + + // Sets the message consumer to |c|, which will be invoked every time there + // is a message to be communicated to the outside. + void SetMessageConsumer(MessageConsumer c) { consumer_ = std::move(c); } + + // Returns the reference to the message consumer for this context. + const MessageConsumer& consumer() const { return consumer_; } + + // Rebuilds the analyses in |set| that are invalid. + void BuildInvalidAnalyses(Analysis set); + + // Invalidates all of the analyses except for those in |preserved_analyses|. + void InvalidateAnalysesExceptFor(Analysis preserved_analyses); + + // Invalidates the analyses marked in |analyses_to_invalidate|. + void InvalidateAnalyses(Analysis analyses_to_invalidate); + + // Deletes the instruction defining the given |id|. Returns true on + // success, false if the given |id| is not defined at all. This method also + // erases the name, decorations, and definition of |id|. + // + // Pointers and iterators pointing to the deleted instructions become invalid. + // However other pointers and iterators are still valid. + bool KillDef(uint32_t id); + + // Deletes the given instruction |inst|. This method erases the + // information of the given instruction's uses of its operands. If |inst| + // defines a result id, its name and decorations will also be deleted. + // + // Pointers and iterators pointing to the deleted instruction become invalid. + // However other pointers and iterators are still valid.
+ // + // Note that if an instruction is not in an instruction list, the memory may + // not be safe to delete, so the instruction is turned into an OpNop instead. + // This can happen with OpLabel. + // + // Returns a pointer to the instruction after |inst| or |nullptr| if no such + // instruction exists. + Instruction* KillInst(Instruction* inst); + + // Returns true if all of the given analyses are valid. + bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; } + + // Replaces all uses of |before| id with |after| id. If |after| is the same + // as |before|, does nothing and returns false; otherwise returns true, even + // when |before| has no uses. This method does not kill the definition of the + // |before| id. + // + // |before| and |after| must be registered definitions in the DefUseManager. + bool ReplaceAllUsesWith(uint32_t before, uint32_t after); + + // Returns true if all of the analyses that are supposed to be valid are + // actually valid. Always returns true unless SPIRV_CHECK_CONTEXT is defined. + bool IsConsistent(); + + // The IRContext will look at the def and uses of |inst| and update any valid + // analyses accordingly. + inline void AnalyzeDefUse(Instruction* inst); + + // Informs the IRContext that the uses of |inst| are going to change, and that + // it should forget everything it knows about the current uses. Any valid + // analyses will be updated accordingly. + void ForgetUses(Instruction* inst); + + // The IRContext will look at the uses of |inst| and update any valid analyses + // accordingly. + void AnalyzeUses(Instruction* inst); + + // Kill all name and decorate ops targeting |id|. + void KillNamesAndDecorates(uint32_t id); + + // Kill all name and decorate ops targeting the result id of |inst|. + void KillNamesAndDecorates(Instruction* inst); + + // Returns the next unique id for use by an instruction. + inline uint32_t TakeNextUniqueId() { + assert(unique_id_ != std::numeric_limits::max()); + + // Skip zero.
+ return ++unique_id_; + } + + // Returns true if |inst| is a combinator in the current context. + // |combinator_ops_| is built if it has not been already. + inline bool IsCombinatorInstruction(const Instruction* inst) { + if (!AreAnalysesValid(kAnalysisCombinators)) { + InitializeCombinators(); + } + const uint32_t kExtInstSetIdInIndx = 0; + const uint32_t kExtInstInstructionInIndx = 1; + + if (inst->opcode() != SpvOpExtInst) { + return combinator_ops_[0].count(inst->opcode()) != 0; + } else { + uint32_t set = inst->GetSingleWordInOperand(kExtInstSetIdInIndx); + uint32_t op = inst->GetSingleWordInOperand(kExtInstInstructionInIndx); + return combinator_ops_[set].count(op) != 0; + } + } + + // Returns a pointer to the CFG for all the functions in |module_|. + CFG* cfg() { + if (!AreAnalysesValid(kAnalysisCFG)) { + BuildCFG(); + } + return cfg_.get(); + } + + // Gets the loop descriptor for function |f|. + LoopDescriptor* GetLoopDescriptor(const Function* f); + + // Gets the dominator analysis for function |f|. + DominatorAnalysis* GetDominatorAnalysis(const Function* f); + + // Gets the postdominator analysis for function |f|. + PostDominatorAnalysis* GetPostDominatorAnalysis(const Function* f); + + // Remove the dominator tree of |f| from the cache. + inline void RemoveDominatorAnalysis(const Function* f) { + dominator_trees_.erase(f); + } + + // Remove the postdominator tree of |f| from the cache. + inline void RemovePostDominatorAnalysis(const Function* f) { + post_dominator_trees_.erase(f); + } + + // Return the next available SSA id and increment it. Returns 0 if the + // maximum SSA id has been reached. + inline uint32_t TakeNextId() { return module()->TakeNextIdBound(); } + + FeatureManager* get_feature_mgr() { + if (!feature_mgr_.get()) { + AnalyzeFeatures(); + } + return feature_mgr_.get(); + } + + // Returns the grammar for this context. 
+ const AssemblyGrammar& grammar() const { return grammar_; } + + // If |inst| has not yet been analysed by the def-use manager, then analyse + // its definitions and uses. + inline void UpdateDefUse(Instruction* inst); + + const InstructionFolder& get_instruction_folder() { + if (!inst_folder_) { + inst_folder_ = MakeUnique(this); + } + return *inst_folder_; + } + + uint32_t max_id_bound() const { return max_id_bound_; } + void set_max_id_bound(uint32_t new_bound) { max_id_bound_ = new_bound; } + + // Return id of variable only decorated with |builtin|, if in module. + // Create variable and return its id otherwise. If builtin not currently + // supported, return 0. + uint32_t GetBuiltinVarId(uint32_t builtin); + + // Returns the function whose id is |id|, if one exists. Returns |nullptr| + // otherwise. + Function* GetFunction(uint32_t id) { + if (!AreAnalysesValid(kAnalysisIdToFuncMapping)) { + BuildIdToFuncMapping(); + } + auto entry = id_to_func_.find(id); + return (entry != id_to_func_.end()) ? entry->second : nullptr; + } + + Function* GetFunction(Instruction* inst) { + if (inst->opcode() != SpvOpFunction) { + return nullptr; + } + return GetFunction(inst->result_id()); + } + + // Add to |todo| all ids of functions called in |func|. + void AddCalls(const Function* func, std::queue* todo); + + // Applies |pfn| to every function in the call trees that are rooted at the + // entry points. Returns true if any call |pfn| returns true. By convention + // |pfn| should return true if it modified the module. + bool ProcessEntryPointCallTree(ProcessFunction& pfn); + + // Applies |pfn| to every function in the call trees rooted at the entry + // points and exported functions. Returns true if any call |pfn| returns + // true. By convention |pfn| should return true if it modified the module. + bool ProcessReachableCallTree(ProcessFunction& pfn); + + // Applies |pfn| to every function in the call trees rooted at the elements of + // |roots|. 
Returns true if any call to |pfn| returns true. By convention + // |pfn| should return true if it modified the module. After returning + // |roots| will be empty. + bool ProcessCallTreeFromRoots(ProcessFunction& pfn, + std::queue* roots); + + private: + // Builds the def-use manager from scratch, even if it was already valid. + void BuildDefUseManager() { + def_use_mgr_ = MakeUnique(module()); + valid_analyses_ = valid_analyses_ | kAnalysisDefUse; + } + + // Builds the instruction-block map for the whole module. + void BuildInstrToBlockMapping() { + instr_to_block_.clear(); + for (auto& fn : *module_) { + for (auto& block : fn) { + block.ForEachInst([this, &block](Instruction* inst) { + instr_to_block_[inst] = █ + }); + } + } + valid_analyses_ = valid_analyses_ | kAnalysisInstrToBlockMapping; + } + + // Builds the instruction-function map for the whole module. + void BuildIdToFuncMapping() { + id_to_func_.clear(); + for (auto& fn : *module_) { + id_to_func_[fn.result_id()] = &fn; + } + valid_analyses_ = valid_analyses_ | kAnalysisIdToFuncMapping; + } + + void BuildDecorationManager() { + decoration_mgr_ = MakeUnique(module()); + valid_analyses_ = valid_analyses_ | kAnalysisDecorations; + } + + void BuildCFG() { + cfg_ = MakeUnique(module()); + valid_analyses_ = valid_analyses_ | kAnalysisCFG; + } + + void BuildScalarEvolutionAnalysis() { + scalar_evolution_analysis_ = MakeUnique(this); + valid_analyses_ = valid_analyses_ | kAnalysisScalarEvolution; + } + + // Builds the liveness analysis from scratch, even if it was already valid. + void BuildRegPressureAnalysis() { + reg_pressure_ = MakeUnique(this); + valid_analyses_ = valid_analyses_ | kAnalysisRegisterPressure; + } + + // Builds the value number table analysis from scratch, even if it was already + // valid. 
+ void BuildValueNumberTable() { + vn_table_ = MakeUnique(this); + valid_analyses_ = valid_analyses_ | kAnalysisValueNumberTable; + } + + // Builds the structured CFG analysis from scratch, even if it was already + // valid. + void BuildStructuredCFGAnalysis() { + struct_cfg_analysis_ = MakeUnique(this); + valid_analyses_ = valid_analyses_ | kAnalysisStructuredCFG; + } + + // Builds the constant manager from scratch, even if it was already + // valid. + void BuildConstantManager() { + constant_mgr_ = MakeUnique(this); + valid_analyses_ = valid_analyses_ | kAnalysisConstants; + } + + // Builds the type manager from scratch, even if it was already + // valid. + void BuildTypeManager() { + type_mgr_ = MakeUnique(consumer(), this); + valid_analyses_ = valid_analyses_ | kAnalysisTypes; + } + + // Removes all computed dominator and post-dominator trees. This will force + // the context to rebuild the trees on demand. + void ResetDominatorAnalysis() { + // Clear the cache. + dominator_trees_.clear(); + post_dominator_trees_.clear(); + valid_analyses_ = valid_analyses_ | kAnalysisDominatorAnalysis; + } + + // Removes all computed loop descriptors. + void ResetLoopAnalysis() { + // Clear the cache. + loop_descriptors_.clear(); + valid_analyses_ = valid_analyses_ | kAnalysisLoopAnalysis; + } + + // Removes all computed loop descriptors. + void ResetBuiltinAnalysis() { + // Clear the cache. + builtin_var_id_map_.clear(); + valid_analyses_ = valid_analyses_ | kAnalysisBuiltinVarId; + } + + // Analyzes the features in the owned module. Builds the manager if required. + void AnalyzeFeatures() { + feature_mgr_ = MakeUnique(grammar_); + feature_mgr_->Analyze(module()); + } + + // Scans a module looking for it capabilities, and initializes combinator_ops_ + // accordingly. + void InitializeCombinators(); + + // Add the combinator opcode for the given capability to combinator_ops_. 
+ void AddCombinatorsForCapability(uint32_t capability); + + // Add the combinator opcode for the given extension to combinator_ops_. + void AddCombinatorsForExtension(Instruction* extension); + + // Remove |inst| from |id_to_name_| if it is in map. + void RemoveFromIdToName(const Instruction* inst); + + // Returns true if it is suppose to be valid but it is incorrect. Returns + // true if the cfg is invalidated. + bool CheckCFG(); + + // Return id of variable only decorated with |builtin|, if in module. + // Return 0 otherwise. + uint32_t FindBuiltinVar(uint32_t builtin); + + // Add |var_id| to all entry points in module. + void AddVarToEntryPoints(uint32_t var_id); + + // The SPIR-V syntax context containing grammar tables for opcodes and + // operands. + spv_context syntax_context_; + + // Auxiliary object for querying SPIR-V grammar facts. + AssemblyGrammar grammar_; + + // An unique identifier for instructions in |module_|. Can be used to order + // instructions in a container. + // + // This member is initialized to 0, but always issues this value plus one. + // Therefore, 0 is not a valid unique id for an instruction. + uint32_t unique_id_; + + // The module being processed within this IR context. + std::unique_ptr module_; + + // A message consumer for diagnostics. + MessageConsumer consumer_; + + // The def-use manager for |module_|. + std::unique_ptr def_use_mgr_; + + // The instruction decoration manager for |module_|. + std::unique_ptr decoration_mgr_; + std::unique_ptr feature_mgr_; + + // A map from instructions to the basic block they belong to. This mapping is + // built on-demand when get_instr_block() is called. + // + // NOTE: Do not traverse this map. Ever. Use the function and basic block + // iterators to traverse instructions. + std::unordered_map instr_to_block_; + + // A map from ids to the function they define. This mapping is + // built on-demand when GetFunction() is called. + // + // NOTE: Do not traverse this map. Ever. 
Use the function and basic block + // iterators to traverse instructions. + std::unordered_map id_to_func_; + + // A bitset indicating which analyes are currently valid. + Analysis valid_analyses_; + + // Opcodes of shader capability core executable instructions + // without side-effect. + std::unordered_map> combinator_ops_; + + // Opcodes of shader capability core executable instructions + // without side-effect. + std::unordered_map builtin_var_id_map_; + + // The CFG for all the functions in |module_|. + std::unique_ptr cfg_; + + // Each function in the module will create its own dominator tree. We cache + // the result so it doesn't need to be rebuilt each time. + std::map dominator_trees_; + std::map post_dominator_trees_; + + // Cache of loop descriptors for each function. + std::unordered_map loop_descriptors_; + + // Constant manager for |module_|. + std::unique_ptr constant_mgr_; + + // Type manager for |module_|. + std::unique_ptr type_mgr_; + + // A map from an id to its corresponding OpName and OpMemberName instructions. + std::unique_ptr> id_to_name_; + + // The cache scalar evolution analysis node. + std::unique_ptr scalar_evolution_analysis_; + + // The liveness analysis |module_|. + std::unique_ptr reg_pressure_; + + std::unique_ptr vn_table_; + + std::unique_ptr inst_folder_; + + std::unique_ptr struct_cfg_analysis_; + + // The maximum legal value for the id bound. 
+ uint32_t max_id_bound_; +}; + +inline IRContext::Analysis operator|(IRContext::Analysis lhs, + IRContext::Analysis rhs) { + return static_cast(static_cast(lhs) | + static_cast(rhs)); +} + +inline IRContext::Analysis& operator|=(IRContext::Analysis& lhs, + IRContext::Analysis rhs) { + lhs = static_cast(static_cast(lhs) | + static_cast(rhs)); + return lhs; +} + +inline IRContext::Analysis operator<<(IRContext::Analysis a, int shift) { + return static_cast(static_cast(a) << shift); +} + +inline IRContext::Analysis& operator<<=(IRContext::Analysis& a, int shift) { + a = static_cast(static_cast(a) << shift); + return a; +} + +std::vector IRContext::GetConstants() { + return module()->GetConstants(); +} + +std::vector IRContext::GetConstants() const { + return ((const Module*)module())->GetConstants(); +} + +Module::inst_iterator IRContext::annotation_begin() { + return module()->annotation_begin(); +} + +Module::inst_iterator IRContext::annotation_end() { + return module()->annotation_end(); +} + +IteratorRange IRContext::annotations() { + return module_->annotations(); +} + +IteratorRange IRContext::annotations() const { + return ((const Module*)module_.get())->annotations(); +} + +Module::inst_iterator IRContext::capability_begin() { + return module()->capability_begin(); +} + +Module::inst_iterator IRContext::capability_end() { + return module()->capability_end(); +} + +IteratorRange IRContext::capabilities() { + return module()->capabilities(); +} + +IteratorRange IRContext::capabilities() const { + return ((const Module*)module())->capabilities(); +} + +Module::inst_iterator IRContext::types_values_begin() { + return module()->types_values_begin(); +} + +Module::inst_iterator IRContext::types_values_end() { + return module()->types_values_end(); +} + +IteratorRange IRContext::types_values() { + return module()->types_values(); +} + +IteratorRange IRContext::types_values() const { + return ((const Module*)module_.get())->types_values(); +} + +Module::inst_iterator 
IRContext::ext_inst_import_begin() { + return module()->ext_inst_import_begin(); +} + +Module::inst_iterator IRContext::ext_inst_import_end() { + return module()->ext_inst_import_end(); +} + +IteratorRange IRContext::ext_inst_imports() { + return module()->ext_inst_imports(); +} + +IteratorRange IRContext::ext_inst_imports() const { + return ((const Module*)module_.get())->ext_inst_imports(); +} + +Module::inst_iterator IRContext::debug1_begin() { + return module()->debug1_begin(); +} + +Module::inst_iterator IRContext::debug1_end() { return module()->debug1_end(); } + +IteratorRange IRContext::debugs1() { + return module()->debugs1(); +} + +IteratorRange IRContext::debugs1() const { + return ((const Module*)module_.get())->debugs1(); +} + +Module::inst_iterator IRContext::debug2_begin() { + return module()->debug2_begin(); +} +Module::inst_iterator IRContext::debug2_end() { return module()->debug2_end(); } + +IteratorRange IRContext::debugs2() { + return module()->debugs2(); +} + +IteratorRange IRContext::debugs2() const { + return ((const Module*)module_.get())->debugs2(); +} + +Module::inst_iterator IRContext::debug3_begin() { + return module()->debug3_begin(); +} + +Module::inst_iterator IRContext::debug3_end() { return module()->debug3_end(); } + +IteratorRange IRContext::debugs3() { + return module()->debugs3(); +} + +IteratorRange IRContext::debugs3() const { + return ((const Module*)module_.get())->debugs3(); +} + +void IRContext::debug_clear() { module_->debug_clear(); } + +void IRContext::AddCapability(std::unique_ptr&& c) { + AddCombinatorsForCapability(c->GetSingleWordInOperand(0)); + module()->AddCapability(std::move(c)); +} + +void IRContext::AddExtension(std::unique_ptr&& e) { + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->AnalyzeInstDefUse(e.get()); + } + module()->AddExtension(std::move(e)); +} + +void IRContext::AddExtInstImport(std::unique_ptr&& e) { + AddCombinatorsForExtension(e.get()); + 
module()->AddExtInstImport(std::move(e)); +} + +void IRContext::SetMemoryModel(std::unique_ptr&& m) { + module()->SetMemoryModel(std::move(m)); +} + +void IRContext::AddEntryPoint(std::unique_ptr&& e) { + module()->AddEntryPoint(std::move(e)); +} + +void IRContext::AddExecutionMode(std::unique_ptr&& e) { + module()->AddExecutionMode(std::move(e)); +} + +void IRContext::AddDebug1Inst(std::unique_ptr&& d) { + module()->AddDebug1Inst(std::move(d)); +} + +void IRContext::AddDebug2Inst(std::unique_ptr&& d) { + if (AreAnalysesValid(kAnalysisNameMap)) { + if (d->opcode() == SpvOpName || d->opcode() == SpvOpMemberName) { + id_to_name_->insert({d->result_id(), d.get()}); + } + } + module()->AddDebug2Inst(std::move(d)); +} + +void IRContext::AddDebug3Inst(std::unique_ptr&& d) { + module()->AddDebug3Inst(std::move(d)); +} + +void IRContext::AddAnnotationInst(std::unique_ptr&& a) { + if (AreAnalysesValid(kAnalysisDecorations)) { + get_decoration_mgr()->AddDecoration(a.get()); + } + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->AnalyzeInstDefUse(a.get()); + } + module()->AddAnnotationInst(std::move(a)); +} + +void IRContext::AddType(std::unique_ptr&& t) { + module()->AddType(std::move(t)); + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->AnalyzeInstDefUse(&*(--types_values_end())); + } +} + +void IRContext::AddGlobalValue(std::unique_ptr&& v) { + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->AnalyzeInstDefUse(&*v); + } + module()->AddGlobalValue(std::move(v)); +} + +void IRContext::AddFunction(std::unique_ptr&& f) { + module()->AddFunction(std::move(f)); +} + +void IRContext::AnalyzeDefUse(Instruction* inst) { + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->AnalyzeInstDefUse(inst); + } +} + +void IRContext::UpdateDefUse(Instruction* inst) { + if (AreAnalysesValid(kAnalysisDefUse)) { + get_def_use_mgr()->UpdateDefUse(inst); + } +} + +void IRContext::BuildIdToNameMap() { + id_to_name_ = MakeUnique>(); + for 
(Instruction& debug_inst : debugs2()) { + if (debug_inst.opcode() == SpvOpMemberName || + debug_inst.opcode() == SpvOpName) { + id_to_name_->insert({debug_inst.GetSingleWordInOperand(0), &debug_inst}); + } + } + valid_analyses_ = valid_analyses_ | kAnalysisNameMap; +} + +IteratorRange::iterator> +IRContext::GetNames(uint32_t id) { + if (!AreAnalysesValid(kAnalysisNameMap)) { + BuildIdToNameMap(); + } + auto result = id_to_name_->equal_range(id); + return make_range(std::move(result.first), std::move(result.second)); +} + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_IR_CONTEXT_H_ diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp new file mode 100644 index 000000000..46e2bee42 --- /dev/null +++ b/source/opt/ir_loader.cpp @@ -0,0 +1,163 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/ir_loader.h" + +#include + +#include "source/opt/log.h" +#include "source/opt/reflect.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +IrLoader::IrLoader(const MessageConsumer& consumer, Module* m) + : consumer_(consumer), + module_(m), + source_(""), + inst_index_(0) {} + +bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) { + ++inst_index_; + const auto opcode = static_cast(inst->opcode); + if (IsDebugLineInst(opcode)) { + dbg_line_info_.push_back(Instruction(module()->context(), *inst)); + return true; + } + + std::unique_ptr spv_inst( + new Instruction(module()->context(), *inst, std::move(dbg_line_info_))); + dbg_line_info_.clear(); + + const char* src = source_.c_str(); + spv_position_t loc = {inst_index_, 0, 0}; + + // Handle function and basic block boundaries first, then normal + // instructions. + if (opcode == SpvOpFunction) { + if (function_ != nullptr) { + Error(consumer_, src, loc, "function inside function"); + return false; + } + function_ = MakeUnique(std::move(spv_inst)); + } else if (opcode == SpvOpFunctionEnd) { + if (function_ == nullptr) { + Error(consumer_, src, loc, + "OpFunctionEnd without corresponding OpFunction"); + return false; + } + if (block_ != nullptr) { + Error(consumer_, src, loc, "OpFunctionEnd inside basic block"); + return false; + } + function_->SetFunctionEnd(std::move(spv_inst)); + module_->AddFunction(std::move(function_)); + function_ = nullptr; + } else if (opcode == SpvOpLabel) { + if (function_ == nullptr) { + Error(consumer_, src, loc, "OpLabel outside function"); + return false; + } + if (block_ != nullptr) { + Error(consumer_, src, loc, "OpLabel inside basic block"); + return false; + } + block_ = MakeUnique(std::move(spv_inst)); + } else if (IsTerminatorInst(opcode)) { + if (function_ == nullptr) { + Error(consumer_, src, loc, "terminator instruction outside function"); + return false; + } + if (block_ == nullptr) { + Error(consumer_, src, 
loc, "terminator instruction outside basic block"); + return false; + } + block_->AddInstruction(std::move(spv_inst)); + function_->AddBasicBlock(std::move(block_)); + block_ = nullptr; + } else { + if (function_ == nullptr) { // Outside function definition + SPIRV_ASSERT(consumer_, block_ == nullptr); + if (opcode == SpvOpCapability) { + module_->AddCapability(std::move(spv_inst)); + } else if (opcode == SpvOpExtension) { + module_->AddExtension(std::move(spv_inst)); + } else if (opcode == SpvOpExtInstImport) { + module_->AddExtInstImport(std::move(spv_inst)); + } else if (opcode == SpvOpMemoryModel) { + module_->SetMemoryModel(std::move(spv_inst)); + } else if (opcode == SpvOpEntryPoint) { + module_->AddEntryPoint(std::move(spv_inst)); + } else if (opcode == SpvOpExecutionMode) { + module_->AddExecutionMode(std::move(spv_inst)); + } else if (IsDebug1Inst(opcode)) { + module_->AddDebug1Inst(std::move(spv_inst)); + } else if (IsDebug2Inst(opcode)) { + module_->AddDebug2Inst(std::move(spv_inst)); + } else if (IsDebug3Inst(opcode)) { + module_->AddDebug3Inst(std::move(spv_inst)); + } else if (IsAnnotationInst(opcode)) { + module_->AddAnnotationInst(std::move(spv_inst)); + } else if (IsTypeInst(opcode)) { + module_->AddType(std::move(spv_inst)); + } else if (IsConstantInst(opcode) || opcode == SpvOpVariable || + opcode == SpvOpUndef) { + module_->AddGlobalValue(std::move(spv_inst)); + } else { + SPIRV_UNIMPLEMENTED(consumer_, + "unhandled inst type outside function definition"); + } + } else { + if (block_ == nullptr) { // Inside function but outside blocks + if (opcode != SpvOpFunctionParameter) { + Errorf(consumer_, src, loc, + "Non-OpFunctionParameter (opcode: %d) found inside " + "function but outside basic block", + opcode); + return false; + } + function_->AddParameter(std::move(spv_inst)); + } else { + block_->AddInstruction(std::move(spv_inst)); + } + } + } + return true; +} + +// Resolves internal references among the module, functions, basic blocks, etc. 
+// This function should be called after adding all instructions. +void IrLoader::EndModule() { + if (block_ && function_) { + // We're in the middle of a basic block, but the terminator is missing. + // Register the block anyway. This lets us write tests with less + // boilerplate. + function_->AddBasicBlock(std::move(block_)); + block_ = nullptr; + } + if (function_) { + // We're in the middle of a function, but the OpFunctionEnd is missing. + // Register the function anyway. This lets us write tests with less + // boilerplate. + module_->AddFunction(std::move(function_)); + function_ = nullptr; + } + for (auto& function : *module_) { + for (auto& bb : function) bb.SetParent(&function); + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/ir_loader.h b/source/opt/ir_loader.h new file mode 100644 index 000000000..940d7b0db --- /dev/null +++ b/source/opt/ir_loader.h @@ -0,0 +1,86 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_IR_LOADER_H_ +#define SOURCE_OPT_IR_LOADER_H_ + +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/module.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { + +// Loader class for constructing SPIR-V in-memory IR representation. 
Methods in +// this class are designed to work with the interface for spvBinaryParse() in +// libspirv.h so that we can leverage the syntax checks implemented behind it. +// +// The user is expected to call SetModuleHeader() to fill in the module's +// header, and then AddInstruction() for each decoded instruction, and finally +// EndModule() to finalize the module. The instructions processed in sequence +// by AddInstruction() should comprise a valid SPIR-V module. +class IrLoader { + public: + // Instantiates a builder to construct the given |module| gradually. + // All internal messages will be communicated to the outside via the given + // message |consumer|. This instance only keeps a reference to the |consumer|, + // so the |consumer| should outlive this instance. + IrLoader(const MessageConsumer& consumer, Module* m); + + // Sets the source name of the module. + void SetSource(const std::string& src) { source_ = src; } + + Module* module() const { return module_; } + + // Sets the fields in the module's header to the given parameters. + void SetModuleHeader(uint32_t magic, uint32_t version, uint32_t generator, + uint32_t bound, uint32_t reserved) { + module_->SetHeader({magic, version, generator, bound, reserved}); + } + // Adds an instruction to the module. Returns true if no error occurs. This + // method will properly capture and store the data provided in |inst| so that + // |inst| is no longer needed after returning. + bool AddInstruction(const spv_parsed_instruction_t* inst); + // Finalizes the module construction. This must be called after the module + // header has been set and all instructions have been added. This is + // forgiving in the case of a missing terminator instruction on a basic block, + // or a missing OpFunctionEnd. Resolves internal bookkeeping. + void EndModule(); + + private: + // Consumer for communicating messages to outside. + const MessageConsumer& consumer_; + // The module to be built. 
+ Module* module_; + // The source name of the module. + std::string source_; + // The last used instruction index. + uint32_t inst_index_; + // The current Function under construction. + std::unique_ptr function_; + // The current BasicBlock under construction. + std::unique_ptr block_; + // Line related debug instructions accumulated thus far. + std::vector dbg_line_info_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_IR_LOADER_H_ diff --git a/source/opt/iterator.h b/source/opt/iterator.h new file mode 100644 index 000000000..444d457c5 --- /dev/null +++ b/source/opt/iterator.h @@ -0,0 +1,358 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_ITERATOR_H_ +#define SOURCE_OPT_ITERATOR_H_ + +#include // for ptrdiff_t +#include +#include +#include +#include +#include + +namespace spvtools { +namespace opt { + +// An ad hoc iterator class for std::vector>. The +// purpose of this iterator class is to provide transparent access to those +// std::unique_ptr managed elements in the vector, behaving like we are using +// std::vector<|ValueType|>. 
// |ValueType| is the element type owned by the unique_ptrs; |IsConst| selects
// the read-only flavour of the iterator.
template <typename ValueType, bool IsConst = false>
class UptrVectorIterator
    : public std::iterator<
          std::random_access_iterator_tag,
          typename std::conditional<IsConst, const ValueType,
                                    ValueType>::type> {
 public:
  using super = std::iterator<
      std::random_access_iterator_tag,
      typename std::conditional<IsConst, const ValueType, ValueType>::type>;

  using pointer = typename super::pointer;
  using reference = typename super::reference;
  using difference_type = typename super::difference_type;

  // Type aliases. We need to apply constness properly if |IsConst| is true.
  using Uptr = std::unique_ptr<ValueType>;
  using UptrVector =
      typename std::conditional<IsConst, const std::vector<Uptr>,
                                std::vector<Uptr>>::type;
  using UnderlyingIterator =
      typename std::conditional<IsConst, typename UptrVector::const_iterator,
                                typename UptrVector::iterator>::type;

  // Creates a new iterator from the given |container| and its raw iterator
  // |it|.
  UptrVectorIterator(UptrVector* container, const UnderlyingIterator& it)
      : container_(container), iterator_(it) {}
  UptrVectorIterator(const UptrVectorIterator&) = default;
  UptrVectorIterator& operator=(const UptrVectorIterator&) = default;

  inline UptrVectorIterator& operator++();
  inline UptrVectorIterator operator++(int);
  inline UptrVectorIterator& operator--();
  inline UptrVectorIterator operator--(int);

  // Element access looks through the unique_ptr to the owned object. None of
  // these modify the iterator, so all are callable on a const iterator.
  reference operator*() const { return **iterator_; }
  pointer operator->() const { return (*iterator_).get(); }
  reference operator[](ptrdiff_t index) const { return **(iterator_ + index); }

  inline bool operator==(const UptrVectorIterator& that) const;
  inline bool operator!=(const UptrVectorIterator& that) const;

  inline ptrdiff_t operator-(const UptrVectorIterator& that) const;
  inline bool operator<(const UptrVectorIterator& that) const;

  // Inserts the given |value| to the position pointed to by this iterator
  // and returns an iterator to the newly inserted |value|.
  // If the underlying vector changes capacity, all previous iterators will be
  // invalidated. Otherwise, those previous iterators pointing to after the
  // insertion point will be invalidated.
  // Only available when the iterator is not const.
  template <bool IsConstForMethod = IsConst>
  inline typename std::enable_if<!IsConstForMethod, UptrVectorIterator>::type
  InsertBefore(Uptr value);

  // Inserts the given |valueVector| to the position pointed to by this iterator
  // and returns an iterator to the first newly inserted value.
  // If the underlying vector changes capacity, all previous iterators will be
  // invalidated. Otherwise, those previous iterators pointing to after the
  // insertion point will be invalidated.
  // Only available when the iterator is not const.
  template <bool IsConstForMethod = IsConst>
  inline typename std::enable_if<!IsConstForMethod, UptrVectorIterator>::type
  InsertBefore(UptrVector* valueVector);

  // Erases the value at the position pointed to by this iterator
  // and returns an iterator to the following value.
  // If the underlying vector changes capacity, all previous iterators will be
  // invalidated. Otherwise, those previous iterators pointing to after the
  // erasure point will be invalidated.
  // Only available when the iterator is not const.
  template <bool IsConstForMethod = IsConst>
  inline typename std::enable_if<!IsConstForMethod, UptrVectorIterator>::type
  Erase();

  // Returns the underlying iterator.
  UnderlyingIterator Get() const { return iterator_; }

  // Returns a valid end iterator for the underlying container.
  UptrVectorIterator End() const {
    return UptrVectorIterator(container_, container_->end());
  }

 private:
  UptrVector* container_;        // The container we are manipulating.
  UnderlyingIterator iterator_;  // The raw iterator from the container.
};

// Handy class for a (begin, end) iterator pair.
template <typename IteratorType>
class IteratorRange {
 public:
  IteratorRange(const IteratorType& b, const IteratorType& e)
      : begin_(b), end_(e) {}
  IteratorRange(IteratorType&& b, IteratorType&& e)
      : begin_(std::move(b)), end_(std::move(e)) {}

  IteratorType begin() const { return begin_; }
  IteratorType end() const { return end_; }

  // True when the range contains no elements.
  bool empty() const { return begin_ == end_; }
  // Number of elements in the range; requires |IteratorType| to support
  // subtraction.
  size_t size() const { return static_cast<size_t>(end_ - begin_); }

 private:
  IteratorType begin_;
  IteratorType end_;
};

// Returns a (begin, end) iterator pair for the given iterators.
// The iterators must belong to the same container.
+template +inline IteratorRange make_range(const IteratorType& begin, + const IteratorType& end) { + return {begin, end}; +} + +// Returns a (begin, end) iterator pair for the given iterators. +// The iterators must belong to the same container. +template +inline IteratorRange make_range(IteratorType&& begin, + IteratorType&& end) { + return {std::move(begin), std::move(end)}; +} + +// Returns a (begin, end) iterator pair for the given container. +template > +inline IteratorRange make_range( + std::vector>& container) { + return {IteratorType(&container, container.begin()), + IteratorType(&container, container.end())}; +} + +// Returns a const (begin, end) iterator pair for the given container. +template > +inline IteratorRange make_const_range( + const std::vector>& container) { + return {IteratorType(&container, container.cbegin()), + IteratorType(&container, container.cend())}; +} + +// Wrapping iterator class that only consider elements that satisfy the given +// predicate |Predicate|. When moving to the next element of the iterator, the +// FilterIterator will iterate over the range until it finds an element that +// satisfies |Predicate| or reaches the end of the iterator. +// +// Currently this iterator is always an input iterator. +template +class FilterIterator + : public std::iterator< + std::input_iterator_tag, typename SubIterator::value_type, + typename SubIterator::difference_type, typename SubIterator::pointer, + typename SubIterator::reference> { + public: + // Iterator interface. 
+ using iterator_category = typename SubIterator::iterator_category; + using value_type = typename SubIterator::value_type; + using pointer = typename SubIterator::pointer; + using reference = typename SubIterator::reference; + using difference_type = typename SubIterator::difference_type; + + using Range = IteratorRange; + + FilterIterator(const IteratorRange& iteration_range, + Predicate predicate) + : cur_(iteration_range.begin()), + end_(iteration_range.end()), + predicate_(predicate) { + if (!IsPredicateSatisfied()) { + MoveToNextPosition(); + } + } + + FilterIterator(const SubIterator& end, Predicate predicate) + : FilterIterator({end, end}, predicate) {} + + inline FilterIterator& operator++() { + MoveToNextPosition(); + return *this; + } + inline FilterIterator operator++(int) { + FilterIterator old = *this; + MoveToNextPosition(); + return old; + } + + reference operator*() const { return *cur_; } + pointer operator->() { return &*cur_; } + + inline bool operator==(const FilterIterator& rhs) const { + return cur_ == rhs.cur_ && end_ == rhs.end_; + } + inline bool operator!=(const FilterIterator& rhs) const { + return !(*this == rhs); + } + + // Returns the underlying iterator. + SubIterator Get() const { return cur_; } + + // Returns the sentinel iterator. + FilterIterator GetEnd() const { return FilterIterator(end_, predicate_); } + + private: + // Returns true if the predicate is satisfied or the current iterator reached + // the end. 
+ bool IsPredicateSatisfied() { return cur_ == end_ || predicate_(*cur_); } + + void MoveToNextPosition() { + if (cur_ == end_) return; + + do { + ++cur_; + } while (!IsPredicateSatisfied()); + } + + SubIterator cur_; + SubIterator end_; + Predicate predicate_; +}; + +template +FilterIterator MakeFilterIterator( + const IteratorRange& sub_iterator_range, Predicate predicate) { + return FilterIterator(sub_iterator_range, predicate); +} + +template +FilterIterator MakeFilterIterator( + const SubIterator& begin, const SubIterator& end, Predicate predicate) { + return MakeFilterIterator(make_range(begin, end), predicate); +} + +template +typename FilterIterator::Range MakeFilterIteratorRange( + const SubIterator& begin, const SubIterator& end, Predicate predicate) { + return typename FilterIterator::Range( + MakeFilterIterator(begin, end, predicate), + MakeFilterIterator(end, end, predicate)); +} + +template +inline UptrVectorIterator& UptrVectorIterator::operator++() { + ++iterator_; + return *this; +} + +template +inline UptrVectorIterator UptrVectorIterator::operator++(int) { + auto it = *this; + ++(*this); + return it; +} + +template +inline UptrVectorIterator& UptrVectorIterator::operator--() { + --iterator_; + return *this; +} + +template +inline UptrVectorIterator UptrVectorIterator::operator--(int) { + auto it = *this; + --(*this); + return it; +} + +template +inline bool UptrVectorIterator::operator==( + const UptrVectorIterator& that) const { + return container_ == that.container_ && iterator_ == that.iterator_; +} + +template +inline bool UptrVectorIterator::operator!=( + const UptrVectorIterator& that) const { + return !(*this == that); +} + +template +inline ptrdiff_t UptrVectorIterator::operator-( + const UptrVectorIterator& that) const { + assert(container_ == that.container_); + return iterator_ - that.iterator_; +} + +template +inline bool UptrVectorIterator::operator<( + const UptrVectorIterator& that) const { + assert(container_ == that.container_); 
+ return iterator_ < that.iterator_; +} + +template +template +inline + typename std::enable_if>::type + UptrVectorIterator::InsertBefore(Uptr value) { + auto index = iterator_ - container_->begin(); + container_->insert(iterator_, std::move(value)); + return UptrVectorIterator(container_, container_->begin() + index); +} + +template +template +inline + typename std::enable_if>::type + UptrVectorIterator::InsertBefore(UptrVector* values) { + const auto pos = iterator_ - container_->begin(); + const auto origsz = container_->size(); + container_->resize(origsz + values->size()); + std::move_backward(container_->begin() + pos, container_->begin() + origsz, + container_->end()); + std::move(values->begin(), values->end(), container_->begin() + pos); + return UptrVectorIterator(container_, container_->begin() + pos); +} + +template +template +inline + typename std::enable_if>::type + UptrVectorIterator::Erase() { + auto index = iterator_ - container_->begin(); + (void)container_->erase(iterator_); + return UptrVectorIterator(container_, container_->begin() + index); +} + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_ITERATOR_H_ diff --git a/source/opt/licm_pass.cpp b/source/opt/licm_pass.cpp new file mode 100644 index 000000000..82851fd27 --- /dev/null +++ b/source/opt/licm_pass.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/licm_pass.h" + +#include +#include + +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +Pass::Status LICMPass::Process() { return ProcessIRContext(); } + +Pass::Status LICMPass::ProcessIRContext() { + Status status = Status::SuccessWithoutChange; + Module* module = get_module(); + + // Process each function in the module + for (auto func = module->begin(); + func != module->end() && status != Status::Failure; ++func) { + status = CombineStatus(status, ProcessFunction(&*func)); + } + return status; +} + +Pass::Status LICMPass::ProcessFunction(Function* f) { + Status status = Status::SuccessWithoutChange; + LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f); + + // Process each loop in the function + for (auto it = loop_descriptor->begin(); + it != loop_descriptor->end() && status != Status::Failure; ++it) { + Loop& loop = *it; + // Ignore nested loops, as we will process them in order in ProcessLoop + if (loop.IsNested()) { + continue; + } + status = CombineStatus(status, ProcessLoop(&loop, f)); + } + return status; +} + +Pass::Status LICMPass::ProcessLoop(Loop* loop, Function* f) { + Status status = Status::SuccessWithoutChange; + + // Process all nested loops first + for (auto nl = loop->begin(); nl != loop->end() && status != Status::Failure; + ++nl) { + Loop* nested_loop = *nl; + status = CombineStatus(status, ProcessLoop(nested_loop, f)); + } + + std::vector loop_bbs{}; + status = CombineStatus( + status, + AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs)); + + for (size_t i = 0; i < loop_bbs.size() && status != Status::Failure; ++i) { + BasicBlock* bb = loop_bbs[i]; + // do not delete the element + status = + CombineStatus(status, AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs)); + } + + return status; +} + +Pass::Status LICMPass::AnalyseAndHoistFromBB( + Loop* loop, Function* f, BasicBlock* bb, + std::vector* loop_bbs) { + bool modified = false; + 
std::function hoist_inst = + [this, &loop, &modified](Instruction* inst) { + if (loop->ShouldHoistInstruction(this->context(), inst)) { + if (!HoistInstruction(loop, inst)) { + return false; + } + modified = true; + } + return true; + }; + + if (IsImmediatelyContainedInLoop(loop, f, bb)) { + if (!bb->WhileEachInst(hoist_inst, false)) { + return Status::Failure; + } + } + + DominatorAnalysis* dom_analysis = context()->GetDominatorAnalysis(f); + DominatorTree& dom_tree = dom_analysis->GetDomTree(); + + for (DominatorTreeNode* child_dom_tree_node : *dom_tree.GetTreeNode(bb)) { + if (loop->IsInsideLoop(child_dom_tree_node->bb_)) { + loop_bbs->push_back(child_dom_tree_node->bb_); + } + } + + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool LICMPass::IsImmediatelyContainedInLoop(Loop* loop, Function* f, + BasicBlock* bb) { + LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f); + return loop == (*loop_descriptor)[bb->id()]; +} + +bool LICMPass::HoistInstruction(Loop* loop, Instruction* inst) { + // TODO(1841): Handle failure to create pre-header. + BasicBlock* pre_header_bb = loop->GetOrCreatePreHeaderBlock(); + if (!pre_header_bb) { + return false; + } + Instruction* insertion_point = &*pre_header_bb->tail(); + Instruction* previous_node = insertion_point->PreviousNode(); + if (previous_node && (previous_node->opcode() == SpvOpLoopMerge || + previous_node->opcode() == SpvOpSelectionMerge)) { + insertion_point = previous_node; + } + + inst->InsertBefore(insertion_point); + context()->set_instr_block(inst, pre_header_bb); + return true; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/licm_pass.h b/source/opt/licm_pass.h new file mode 100644 index 000000000..597fe920a --- /dev/null +++ b/source/opt/licm_pass.h @@ -0,0 +1,72 @@ +// Copyright (c) 2018 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LICM_PASS_H_ +#define SOURCE_OPT_LICM_PASS_H_ + +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +class LICMPass : public Pass { + public: + LICMPass() {} + + const char* name() const override { return "loop-invariant-code-motion"; } + Status Process() override; + + private: + // Searches the IRContext for functions and processes each, moving invariants + // outside loops within the function where possible. + // Returns the status depending on whether or not there was a failure or + // change. + Pass::Status ProcessIRContext(); + + // Checks the function for loops, calling ProcessLoop on each one found. + // Returns the status depending on whether or not there was a failure or + // change. + Pass::Status ProcessFunction(Function* f); + + // Checks for invariants in the loop and attempts to move them to the loops + // preheader. Works from inner loop to outer when nested loops are found. + // Returns the status depending on whether or not there was a failure or + // change. + Pass::Status ProcessLoop(Loop* loop, Function* f); + + // Analyses each instruction in |bb|, hoisting invariants to |pre_header_bb|. 
+ // Each child of |bb| wrt to |dom_tree| is pushed to |loop_bbs| + // Returns the status depending on whether or not there was a failure or + // change. + Pass::Status AnalyseAndHoistFromBB(Loop* loop, Function* f, BasicBlock* bb, + std::vector* loop_bbs); + + // Returns true if |bb| is immediately contained in |loop| + bool IsImmediatelyContainedInLoop(Loop* loop, Function* f, BasicBlock* bb); + + // Move the instruction to the preheader of |loop|. + // This method will update the instruction to block mapping for the context + bool HoistInstruction(Loop* loop, Instruction* inst); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LICM_PASS_H_ diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp new file mode 100644 index 000000000..5b976a11e --- /dev/null +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -0,0 +1,353 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/local_access_chain_convert_pass.h" + +#include "ir_builder.h" +#include "ir_context.h" +#include "iterator.h" + +namespace spvtools { +namespace opt { + +namespace { + +const uint32_t kStoreValIdInIdx = 1; +const uint32_t kAccessChainPtrIdInIdx = 0; +const uint32_t kConstantValueInIdx = 0; +const uint32_t kTypeIntWidthInIdx = 0; + +} // anonymous namespace + +void LocalAccessChainConvertPass::BuildAndAppendInst( + SpvOp opcode, uint32_t typeId, uint32_t resultId, + const std::vector& in_opnds, + std::vector>* newInsts) { + std::unique_ptr newInst( + new Instruction(context(), opcode, typeId, resultId, in_opnds)); + get_def_use_mgr()->AnalyzeInstDefUse(&*newInst); + newInsts->emplace_back(std::move(newInst)); +} + +uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad( + const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId, + std::vector>* newInsts) { + const uint32_t ldResultId = TakeNextId(); + *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx); + const Instruction* varInst = get_def_use_mgr()->GetDef(*varId); + assert(varInst->opcode() == SpvOpVariable); + *varPteTypeId = GetPointeeTypeId(varInst); + BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}}, + newInsts); + return ldResultId; +} + +void LocalAccessChainConvertPass::AppendConstantOperands( + const Instruction* ptrInst, std::vector* in_opnds) { + uint32_t iidIdx = 0; + ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) { + if (iidIdx > 0) { + const Instruction* cInst = get_def_use_mgr()->GetDef(*iid); + uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx); + in_opnds->push_back( + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}}); + } + ++iidIdx; + }); +} + +void LocalAccessChainConvertPass::ReplaceAccessChainLoad( + const Instruction* address_inst, Instruction* original_load) { + // Build and append load of variable in ptrInst + std::vector> 
new_inst; + uint32_t varId; + uint32_t varPteTypeId; + const uint32_t ldResultId = + BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst); + context()->get_decoration_mgr()->CloneDecorations( + original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision}); + original_load->InsertBefore(std::move(new_inst)); + + // Rewrite |original_load| into an extract. + Instruction::OperandList new_operands; + + // copy the result id and the type id to the new operand list. + new_operands.emplace_back(original_load->GetOperand(0)); + new_operands.emplace_back(original_load->GetOperand(1)); + + new_operands.emplace_back( + Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}})); + AppendConstantOperands(address_inst, &new_operands); + original_load->SetOpcode(SpvOpCompositeExtract); + original_load->ReplaceOperands(new_operands); + context()->UpdateDefUse(original_load); +} + +void LocalAccessChainConvertPass::GenAccessChainStoreReplacement( + const Instruction* ptrInst, uint32_t valId, + std::vector>* newInsts) { + // Build and append load of variable in ptrInst + uint32_t varId; + uint32_t varPteTypeId; + const uint32_t ldResultId = + BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts); + context()->get_decoration_mgr()->CloneDecorations( + varId, ldResultId, {SpvDecorationRelaxedPrecision}); + + // Build and append Insert + const uint32_t insResultId = TakeNextId(); + std::vector ins_in_opnds = { + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}}; + AppendConstantOperands(ptrInst, &ins_in_opnds); + BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId, + ins_in_opnds, newInsts); + + context()->get_decoration_mgr()->CloneDecorations( + varId, insResultId, {SpvDecorationRelaxedPrecision}); + + // Build and append Store + BuildAndAppendInst(SpvOpStore, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, 
{insResultId}}}, + newInsts); +} + +bool LocalAccessChainConvertPass::IsConstantIndexAccessChain( + const Instruction* acp) const { + uint32_t inIdx = 0; + return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) { + if (inIdx > 0) { + Instruction* opInst = get_def_use_mgr()->GetDef(*tid); + if (opInst->opcode() != SpvOpConstant) return false; + } + ++inIdx; + return true; + }); +} + +bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) { + if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true; + if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) { + SpvOp op = user->opcode(); + if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) { + if (!HasOnlySupportedRefs(user->result_id())) { + return false; + } + } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName && + !IsNonTypeDecorate(op)) { + return false; + } + return true; + })) { + supported_ref_ptrs_.insert(ptrId); + return true; + } + return false; +} + +void LocalAccessChainConvertPass::FindTargetVars(Function* func) { + for (auto bi = func->begin(); bi != func->end(); ++bi) { + for (auto ii = bi->begin(); ii != bi->end(); ++ii) { + switch (ii->opcode()) { + case SpvOpStore: + case SpvOpLoad: { + uint32_t varId; + Instruction* ptrInst = GetPtr(&*ii, &varId); + if (!IsTargetVar(varId)) break; + const SpvOp op = ptrInst->opcode(); + // Rule out variables with non-supported refs eg function calls + if (!HasOnlySupportedRefs(varId)) { + seen_non_target_vars_.insert(varId); + seen_target_vars_.erase(varId); + break; + } + // Rule out variables with nested access chains + // TODO(): Convert nested access chains + if (IsNonPtrAccessChain(op) && ptrInst->GetSingleWordInOperand( + kAccessChainPtrIdInIdx) != varId) { + seen_non_target_vars_.insert(varId); + seen_target_vars_.erase(varId); + break; + } + // Rule out variables accessed with non-constant indices + if (!IsConstantIndexAccessChain(ptrInst)) { + seen_non_target_vars_.insert(varId); + 
seen_target_vars_.erase(varId); + break; + } + } break; + default: + break; + } + } + } +} + +bool LocalAccessChainConvertPass::ConvertLocalAccessChains(Function* func) { + FindTargetVars(func); + // Replace access chains of all targeted variables with equivalent + // extract and insert sequences + bool modified = false; + for (auto bi = func->begin(); bi != func->end(); ++bi) { + std::vector dead_instructions; + for (auto ii = bi->begin(); ii != bi->end(); ++ii) { + switch (ii->opcode()) { + case SpvOpLoad: { + uint32_t varId; + Instruction* ptrInst = GetPtr(&*ii, &varId); + if (!IsNonPtrAccessChain(ptrInst->opcode())) break; + if (!IsTargetVar(varId)) break; + std::vector> newInsts; + ReplaceAccessChainLoad(ptrInst, &*ii); + modified = true; + } break; + case SpvOpStore: { + uint32_t varId; + Instruction* ptrInst = GetPtr(&*ii, &varId); + if (!IsNonPtrAccessChain(ptrInst->opcode())) break; + if (!IsTargetVar(varId)) break; + std::vector> newInsts; + uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx); + GenAccessChainStoreReplacement(ptrInst, valId, &newInsts); + dead_instructions.push_back(&*ii); + ++ii; + ii = ii.InsertBefore(std::move(newInsts)); + ++ii; + ++ii; + modified = true; + } break; + default: + break; + } + } + + while (!dead_instructions.empty()) { + Instruction* inst = dead_instructions.back(); + dead_instructions.pop_back(); + DCEInst(inst, [&dead_instructions](Instruction* other_inst) { + auto i = std::find(dead_instructions.begin(), dead_instructions.end(), + other_inst); + if (i != dead_instructions.end()) { + dead_instructions.erase(i); + } + }); + } + } + return modified; +} + +void LocalAccessChainConvertPass::Initialize() { + // Initialize Target Variable Caches + seen_target_vars_.clear(); + seen_non_target_vars_.clear(); + + // Initialize collections + supported_ref_ptrs_.clear(); + + // Initialize extension whitelist + InitExtensions(); +} + +bool LocalAccessChainConvertPass::AllExtensionsSupported() const { + // If any 
extension not in whitelist, return false + for (auto& ei : get_module()->extensions()) { + const char* extName = + reinterpret_cast(&ei.GetInOperand(0).words[0]); + if (extensions_whitelist_.find(extName) == extensions_whitelist_.end()) + return false; + } + return true; +} + +Pass::Status LocalAccessChainConvertPass::ProcessImpl() { + // If non-32-bit integer type in module, terminate processing + // TODO(): Handle non-32-bit integer constants in access chains + for (const Instruction& inst : get_module()->types_values()) + if (inst.opcode() == SpvOpTypeInt && + inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) + return Status::SuccessWithoutChange; + // Do not process if module contains OpGroupDecorate. Additional + // support required in KillNamesAndDecorates(). + // TODO(greg-lunarg): Add support for OpGroupDecorate + for (auto& ai : get_module()->annotations()) + if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; + // Do not process if any disallowed extensions are enabled + if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; + // Process all entry point functions. + ProcessFunction pfn = [this](Function* fp) { + return ConvertLocalAccessChains(fp); + }; + bool modified = context()->ProcessEntryPointCallTree(pfn); + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +LocalAccessChainConvertPass::LocalAccessChainConvertPass() {} + +Pass::Status LocalAccessChainConvertPass::Process() { + Initialize(); + return ProcessImpl(); +} + +void LocalAccessChainConvertPass::InitExtensions() { + extensions_whitelist_.clear(); + extensions_whitelist_.insert({ + "SPV_AMD_shader_explicit_vertex_parameter", + "SPV_AMD_shader_trinary_minmax", + "SPV_AMD_gcn_shader", + "SPV_KHR_shader_ballot", + "SPV_AMD_shader_ballot", + "SPV_AMD_gpu_shader_half_float", + "SPV_KHR_shader_draw_parameters", + "SPV_KHR_subgroup_vote", + "SPV_KHR_16bit_storage", + "SPV_KHR_device_group", + "SPV_KHR_multiview", + "SPV_NVX_multiview_per_view_attributes", + "SPV_NV_viewport_array2", + "SPV_NV_stereo_view_rendering", + "SPV_NV_sample_mask_override_coverage", + "SPV_NV_geometry_shader_passthrough", + "SPV_AMD_texture_gather_bias_lod", + "SPV_KHR_storage_buffer_storage_class", + // SPV_KHR_variable_pointers + // Currently do not support extended pointer expressions + "SPV_AMD_gpu_shader_int16", + "SPV_KHR_post_depth_coverage", + "SPV_KHR_shader_atomic_counter_ops", + "SPV_EXT_shader_stencil_export", + "SPV_EXT_shader_viewport_index_layer", + "SPV_AMD_shader_image_load_store_lod", + "SPV_AMD_shader_fragment_mask", + "SPV_EXT_fragment_fully_covered", + "SPV_AMD_gpu_shader_half_float_fetch", + "SPV_GOOGLE_decorate_string", + "SPV_GOOGLE_hlsl_functionality1", + "SPV_NV_shader_subgroup_partitioned", + "SPV_EXT_descriptor_indexing", + "SPV_NV_fragment_shader_barycentric", + "SPV_NV_compute_shader_derivatives", + "SPV_NV_shader_image_footprint", + "SPV_NV_shading_rate", + "SPV_NV_mesh_shader", + "SPV_NV_ray_tracing", + "SPV_EXT_fragment_invocation_density", + }); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/local_access_chain_convert_pass.h b/source/opt/local_access_chain_convert_pass.h new file mode 100644 index 000000000..1c7e3d53b --- /dev/null +++ 
b/source/opt/local_access_chain_convert_pass.h @@ -0,0 +1,131 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ +#define SOURCE_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class LocalAccessChainConvertPass : public MemPass { + public: + LocalAccessChainConvertPass(); + + const char* name() const override { return "convert-local-access-chains"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + using ProcessFunction = std::function; + + private: + // Return true if all refs through |ptrId| are only loads or stores and + // cache ptrId in supported_ref_ptrs_. TODO(dnovillo): This function is + // replicated in other passes and it's slightly different in every pass. Is it + // possible to make one common implementation? 
+ bool HasOnlySupportedRefs(uint32_t ptrId); + + // Search |func| and cache function scope variables of target type that are + // not accessed with non-constant-index access chains. Also cache non-target + // variables. + void FindTargetVars(Function* func); + + // Build instruction from |opcode|, |typeId|, |resultId|, and |in_opnds|. + // Append to |newInsts|. + void BuildAndAppendInst(SpvOp opcode, uint32_t typeId, uint32_t resultId, + const std::vector& in_opnds, + std::vector>* newInsts); + + // Build load of variable in |ptrInst| and append to |newInsts|. + // Return var in |varId| and its pointee type in |varPteTypeId|. + uint32_t BuildAndAppendVarLoad( + const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId, + std::vector>* newInsts); + + // Append literal integer operands to |in_opnds| corresponding to constant + // integer operands from access chain |ptrInst|. Assumes all indices in + // access chains are OpConstant. + void AppendConstantOperands(const Instruction* ptrInst, + std::vector* in_opnds); + + // Create a load/insert/store equivalent to a store of + // |valId| through (constant index) access chaing |ptrInst|. + // Append to |newInsts|. + void GenAccessChainStoreReplacement( + const Instruction* ptrInst, uint32_t valId, + std::vector>* newInsts); + + // For the (constant index) access chain |address_inst|, create an + // equivalent load and extract that replaces |original_load|. The result id + // of the extract will be the same as the original result id of + // |original_load|. + void ReplaceAccessChainLoad(const Instruction* address_inst, + Instruction* original_load); + + // Return true if all indices of access chain |acp| are OpConstant integers + bool IsConstantIndexAccessChain(const Instruction* acp) const; + + // Identify all function scope variables of target type which are + // accessed only with loads, stores and access chains with constant + // indices. 
Convert all loads and stores of such variables into equivalent + // loads, stores, extracts and inserts. This unifies access to these + // variables to a single mode and simplifies analysis and optimization. + // See IsTargetType() for targeted types. + // + // Nested access chains and pointer access chains are not currently + // converted. + bool ConvertLocalAccessChains(Function* func); + + // Initialize extensions whitelist + void InitExtensions(); + + // Return true if all extensions in this module are allowed by this pass. + bool AllExtensionsSupported() const; + + void Initialize(); + Pass::Status ProcessImpl(); + + // Variables with only supported references, ie. loads and stores using + // variable directly or through non-ptr access chains. + std::unordered_set supported_ref_ptrs_; + + // Extensions supported by this pass. + std::unordered_set extensions_whitelist_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_ diff --git a/source/opt/local_redundancy_elimination.cpp b/source/opt/local_redundancy_elimination.cpp new file mode 100644 index 000000000..9539e6556 --- /dev/null +++ b/source/opt/local_redundancy_elimination.cpp @@ -0,0 +1,67 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/local_redundancy_elimination.h" + +#include "source/opt/value_number_table.h" + +namespace spvtools { +namespace opt { + +Pass::Status LocalRedundancyEliminationPass::Process() { + bool modified = false; + ValueNumberTable vnTable(context()); + + for (auto& func : *get_module()) { + for (auto& bb : func) { + // Keeps track of all ids that contain a given value number. We keep + // track of multiple values because they could have the same value, but + // different decorations. + std::map value_to_ids; + if (EliminateRedundanciesInBB(&bb, vnTable, &value_to_ids)) + modified = true; + } + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool LocalRedundancyEliminationPass::EliminateRedundanciesInBB( + BasicBlock* block, const ValueNumberTable& vnTable, + std::map* value_to_ids) { + bool modified = false; + + auto func = [this, &vnTable, &modified, value_to_ids](Instruction* inst) { + if (inst->result_id() == 0) { + return; + } + + uint32_t value = vnTable.GetValueNumber(inst); + + if (value == 0) { + return; + } + + auto candidate = value_to_ids->insert({value, inst->result_id()}); + if (!candidate.second) { + context()->KillNamesAndDecorates(inst); + context()->ReplaceAllUsesWith(inst->result_id(), candidate.first->second); + context()->KillInst(inst); + modified = true; + } + }; + block->ForEachInst(func); + return modified; +} +} // namespace opt +} // namespace spvtools diff --git a/source/opt/local_redundancy_elimination.h b/source/opt/local_redundancy_elimination.h new file mode 100644 index 000000000..770457a32 --- /dev/null +++ b/source/opt/local_redundancy_elimination.h @@ -0,0 +1,68 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ +#define SOURCE_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ + +#include + +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" +#include "source/opt/value_number_table.h" + +namespace spvtools { +namespace opt { + +// This pass implements local redundancy elimination. Its goal is to reduce the +// number of times the same value is computed. It works on each basic block +// independently, ie local. For each instruction in a basic block, it gets the +// value number for the result id, |id|, of the instruction. If that value +// number has already been computed in the basic block, it tries to replace the +// uses of |id| by the id that already contains the same value. Then the +// current instruction is deleted. +class LocalRedundancyEliminationPass : public Pass { + public: + const char* name() const override { return "local-redundancy-elimination"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + protected: + // Deletes instructions in |block| whose value is in |value_to_ids| or is + // computed earlier in |block|. + // + // |vnTable| must have computed a value number for every result id defined + // in |bb|. 
+ // + // |value_to_ids| is a map from value number to ids. If {vn, id} is in + // |value_to_ids| then vn is the value number of id, and the definition of id + // dominates |bb|. + // + // Returns true if the module is changed. + bool EliminateRedundanciesInBB(BasicBlock* block, + const ValueNumberTable& vnTable, + std::map* value_to_ids); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOCAL_REDUNDANCY_ELIMINATION_H_ diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp new file mode 100644 index 000000000..9330ab7ac --- /dev/null +++ b/source/opt/local_single_block_elim_pass.cpp @@ -0,0 +1,262 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/local_single_block_elim_pass.h" + +#include + +#include "source/opt/iterator.h" + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kStoreValIdInIdx = 1; + +} // anonymous namespace + +bool LocalSingleBlockLoadStoreElimPass::HasOnlySupportedRefs(uint32_t ptrId) { + if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true; + if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) { + SpvOp op = user->opcode(); + if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) { + if (!HasOnlySupportedRefs(user->result_id())) { + return false; + } + } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName && + !IsNonTypeDecorate(op)) { + return false; + } + return true; + })) { + supported_ref_ptrs_.insert(ptrId); + return true; + } + return false; +} + +bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim( + Function* func) { + // Perform local store/load, load/load and store/store elimination + // on each block + bool modified = false; + std::vector instructions_to_kill; + std::unordered_set instructions_to_save; + for (auto bi = func->begin(); bi != func->end(); ++bi) { + var2store_.clear(); + var2load_.clear(); + auto next = bi->begin(); + for (auto ii = next; ii != bi->end(); ii = next) { + ++next; + switch (ii->opcode()) { + case SpvOpStore: { + // Verify store variable is target type + uint32_t varId; + Instruction* ptrInst = GetPtr(&*ii, &varId); + if (!IsTargetVar(varId)) continue; + if (!HasOnlySupportedRefs(varId)) continue; + // If a store to the whole variable, remember it for succeeding + // loads and stores. Otherwise forget any previous store to that + // variable. + if (ptrInst->opcode() == SpvOpVariable) { + // If a previous store to same variable, mark the store + // for deletion if not still used. 
+ auto prev_store = var2store_.find(varId); + if (prev_store != var2store_.end() && + instructions_to_save.count(prev_store->second) == 0) { + instructions_to_kill.push_back(prev_store->second); + modified = true; + } + + bool kill_store = false; + auto li = var2load_.find(varId); + if (li != var2load_.end()) { + if (ii->GetSingleWordInOperand(kStoreValIdInIdx) == + li->second->result_id()) { + // We are storing the same value that already exists in the + // memory location. The store does nothing. + kill_store = true; + } + } + + if (!kill_store) { + var2store_[varId] = &*ii; + var2load_.erase(varId); + } else { + instructions_to_kill.push_back(&*ii); + modified = true; + } + } else { + assert(IsNonPtrAccessChain(ptrInst->opcode())); + var2store_.erase(varId); + var2load_.erase(varId); + } + } break; + case SpvOpLoad: { + // Verify store variable is target type + uint32_t varId; + Instruction* ptrInst = GetPtr(&*ii, &varId); + if (!IsTargetVar(varId)) continue; + if (!HasOnlySupportedRefs(varId)) continue; + uint32_t replId = 0; + if (ptrInst->opcode() == SpvOpVariable) { + // If a load from a variable, look for a previous store or + // load from that variable and use its value. + auto si = var2store_.find(varId); + if (si != var2store_.end()) { + replId = si->second->GetSingleWordInOperand(kStoreValIdInIdx); + } else { + auto li = var2load_.find(varId); + if (li != var2load_.end()) { + replId = li->second->result_id(); + } + } + } else { + // If a partial load of a previously seen store, remember + // not to delete the store. 
+ auto si = var2store_.find(varId); + if (si != var2store_.end()) instructions_to_save.insert(si->second); + } + if (replId != 0) { + // replace load's result id and delete load + context()->KillNamesAndDecorates(&*ii); + context()->ReplaceAllUsesWith(ii->result_id(), replId); + instructions_to_kill.push_back(&*ii); + modified = true; + } else { + if (ptrInst->opcode() == SpvOpVariable) + var2load_[varId] = &*ii; // register load + } + } break; + case SpvOpFunctionCall: { + // Conservatively assume all locals are redefined for now. + // TODO(): Handle more optimally + var2store_.clear(); + var2load_.clear(); + } break; + default: + break; + } + } + } + + for (Instruction* inst : instructions_to_kill) { + context()->KillInst(inst); + } + + return modified; +} + +void LocalSingleBlockLoadStoreElimPass::Initialize() { + // Initialize Target Type Caches + seen_target_vars_.clear(); + seen_non_target_vars_.clear(); + + // Clear collections + supported_ref_ptrs_.clear(); + + // Initialize extensions whitelist + InitExtensions(); +} + +bool LocalSingleBlockLoadStoreElimPass::AllExtensionsSupported() const { + // If any extension not in whitelist, return false + for (auto& ei : get_module()->extensions()) { + const char* extName = + reinterpret_cast(&ei.GetInOperand(0).words[0]); + if (extensions_whitelist_.find(extName) == extensions_whitelist_.end()) + return false; + } + return true; +} + +Pass::Status LocalSingleBlockLoadStoreElimPass::ProcessImpl() { + // Assumes relaxed logical addressing only (see instruction.h). + if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) + return Status::SuccessWithoutChange; + // Do not process if module contains OpGroupDecorate. Additional + // support required in KillNamesAndDecorates(). 
+ // TODO(greg-lunarg): Add support for OpGroupDecorate + for (auto& ai : get_module()->annotations()) + if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; + // If any extensions in the module are not explicitly supported, + // return unmodified. + if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; + // Process all entry point functions + ProcessFunction pfn = [this](Function* fp) { + return LocalSingleBlockLoadStoreElim(fp); + }; + + bool modified = context()->ProcessEntryPointCallTree(pfn); + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElimPass() = + default; + +Pass::Status LocalSingleBlockLoadStoreElimPass::Process() { + Initialize(); + return ProcessImpl(); +} + +void LocalSingleBlockLoadStoreElimPass::InitExtensions() { + extensions_whitelist_.clear(); + extensions_whitelist_.insert({ + "SPV_AMD_shader_explicit_vertex_parameter", + "SPV_AMD_shader_trinary_minmax", + "SPV_AMD_gcn_shader", + "SPV_KHR_shader_ballot", + "SPV_AMD_shader_ballot", + "SPV_AMD_gpu_shader_half_float", + "SPV_KHR_shader_draw_parameters", + "SPV_KHR_subgroup_vote", + "SPV_KHR_16bit_storage", + "SPV_KHR_device_group", + "SPV_KHR_multiview", + "SPV_NVX_multiview_per_view_attributes", + "SPV_NV_viewport_array2", + "SPV_NV_stereo_view_rendering", + "SPV_NV_sample_mask_override_coverage", + "SPV_NV_geometry_shader_passthrough", + "SPV_AMD_texture_gather_bias_lod", + "SPV_KHR_storage_buffer_storage_class", + // SPV_KHR_variable_pointers + // Currently do not support extended pointer expressions + "SPV_AMD_gpu_shader_int16", + "SPV_KHR_post_depth_coverage", + "SPV_KHR_shader_atomic_counter_ops", + "SPV_EXT_shader_stencil_export", + "SPV_EXT_shader_viewport_index_layer", + "SPV_AMD_shader_image_load_store_lod", + "SPV_AMD_shader_fragment_mask", + "SPV_EXT_fragment_fully_covered", + "SPV_AMD_gpu_shader_half_float_fetch", + "SPV_GOOGLE_decorate_string", + 
"SPV_GOOGLE_hlsl_functionality1", + "SPV_NV_shader_subgroup_partitioned", + "SPV_EXT_descriptor_indexing", + "SPV_NV_fragment_shader_barycentric", + "SPV_NV_compute_shader_derivatives", + "SPV_NV_shader_image_footprint", + "SPV_NV_shading_rate", + "SPV_NV_mesh_shader", + "SPV_NV_ray_tracing", + "SPV_EXT_fragment_invocation_density", + }); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/local_single_block_elim_pass.h b/source/opt/local_single_block_elim_pass.h new file mode 100644 index 000000000..0fe7732a8 --- /dev/null +++ b/source/opt/local_single_block_elim_pass.h @@ -0,0 +1,107 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ +#define SOURCE_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. 
+class LocalSingleBlockLoadStoreElimPass : public MemPass { + public: + LocalSingleBlockLoadStoreElimPass(); + + const char* name() const override { return "eliminate-local-single-block"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Return true if all uses of |varId| are only through supported reference + // operations ie. loads and store. Also cache in supported_ref_ptrs_. + // TODO(dnovillo): This function is replicated in other passes and it's + // slightly different in every pass. Is it possible to make one common + // implementation? + bool HasOnlySupportedRefs(uint32_t varId); + + // On all entry point functions, within each basic block, eliminate + // loads and stores to function variables where possible. For + // loads, if previous load or store to same variable, replace + // load id with previous id and delete load. Finally, check if + // remaining stores are useless, and delete store and variable + // where possible. Assumes logical addressing. + bool LocalSingleBlockLoadStoreElim(Function* func); + + // Initialize extensions whitelist + void InitExtensions(); + + // Return true if all extensions in this module are supported by this pass. + bool AllExtensionsSupported() const; + + void Initialize(); + Pass::Status ProcessImpl(); + + // Map from function scope variable to a store of that variable in the + // current block whose value is currently valid. This map is cleared + // at the start of each block and incrementally updated as the block + // is scanned. The stores are candidates for elimination. The map is + // conservatively cleared when a function call is encountered. + std::unordered_map var2store_; + + // Map from function scope variable to a load of that variable in the + // current block whose value is currently valid. 
This map is cleared + // at the start of each block and incrementally updated as the block + // is scanned. The stores are candidates for elimination. The map is + // conservatively cleared when a function call is encountered. + std::unordered_map var2load_; + + // Set of variables whose most recent store in the current block cannot be + // deleted, for example, if there is a load of the variable which is + // dependent on the store and is not replaced and deleted by this pass, + // for example, a load through an access chain. A variable is removed + // from this set each time a new store of that variable is encountered. + std::unordered_set pinned_vars_; + + // Extensions supported by this pass. + std::unordered_set extensions_whitelist_; + + // Variables that are only referenced by supported operations for this + // pass ie. loads and stores. + std::unordered_set supported_ref_ptrs_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOCAL_SINGLE_BLOCK_ELIM_PASS_H_ diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp new file mode 100644 index 000000000..6c09deced --- /dev/null +++ b/source/opt/local_single_store_elim_pass.cpp @@ -0,0 +1,248 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/local_single_store_elim_pass.h" + +#include "source/cfa.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/opt/iterator.h" + +namespace spvtools { +namespace opt { + +namespace { + +const uint32_t kStoreValIdInIdx = 1; +const uint32_t kVariableInitIdInIdx = 1; + +} // anonymous namespace + +bool LocalSingleStoreElimPass::LocalSingleStoreElim(Function* func) { + bool modified = false; + + // Check all function scope variables in |func|. + BasicBlock* entry_block = &*func->begin(); + for (Instruction& inst : *entry_block) { + if (inst.opcode() != SpvOpVariable) { + break; + } + + modified |= ProcessVariable(&inst); + } + return modified; +} + +bool LocalSingleStoreElimPass::AllExtensionsSupported() const { + // If any extension not in whitelist, return false + for (auto& ei : get_module()->extensions()) { + const char* extName = + reinterpret_cast(&ei.GetInOperand(0).words[0]); + if (extensions_whitelist_.find(extName) == extensions_whitelist_.end()) + return false; + } + return true; +} + +Pass::Status LocalSingleStoreElimPass::ProcessImpl() { + // Assumes relaxed logical addressing only (see instruction.h) + if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) + return Status::SuccessWithoutChange; + + // Do not process if any disallowed extensions are enabled + if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; + // Process all entry point functions + ProcessFunction pfn = [this](Function* fp) { + return LocalSingleStoreElim(fp); + }; + bool modified = context()->ProcessEntryPointCallTree(pfn); + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +LocalSingleStoreElimPass::LocalSingleStoreElimPass() = default; + +Pass::Status LocalSingleStoreElimPass::Process() { + InitExtensionWhiteList(); + return ProcessImpl(); +} + +void LocalSingleStoreElimPass::InitExtensionWhiteList() { + extensions_whitelist_.insert({ + "SPV_AMD_shader_explicit_vertex_parameter", + "SPV_AMD_shader_trinary_minmax", + "SPV_AMD_gcn_shader", + "SPV_KHR_shader_ballot", + "SPV_AMD_shader_ballot", + "SPV_AMD_gpu_shader_half_float", + "SPV_KHR_shader_draw_parameters", + "SPV_KHR_subgroup_vote", + "SPV_KHR_16bit_storage", + "SPV_KHR_device_group", + "SPV_KHR_multiview", + "SPV_NVX_multiview_per_view_attributes", + "SPV_NV_viewport_array2", + "SPV_NV_stereo_view_rendering", + "SPV_NV_sample_mask_override_coverage", + "SPV_NV_geometry_shader_passthrough", + "SPV_AMD_texture_gather_bias_lod", + "SPV_KHR_storage_buffer_storage_class", + // SPV_KHR_variable_pointers + // Currently do not support extended pointer expressions + "SPV_AMD_gpu_shader_int16", + "SPV_KHR_post_depth_coverage", + "SPV_KHR_shader_atomic_counter_ops", + "SPV_EXT_shader_stencil_export", + "SPV_EXT_shader_viewport_index_layer", + "SPV_AMD_shader_image_load_store_lod", + "SPV_AMD_shader_fragment_mask", + "SPV_EXT_fragment_fully_covered", + "SPV_AMD_gpu_shader_half_float_fetch", + "SPV_GOOGLE_decorate_string", + "SPV_GOOGLE_hlsl_functionality1", + "SPV_NV_shader_subgroup_partitioned", + "SPV_EXT_descriptor_indexing", + "SPV_NV_fragment_shader_barycentric", + "SPV_NV_compute_shader_derivatives", + "SPV_NV_shader_image_footprint", + "SPV_NV_shading_rate", + "SPV_NV_mesh_shader", + "SPV_NV_ray_tracing", + "SPV_EXT_fragment_invocation_density", + }); +} +bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) { + std::vector users; + FindUses(var_inst, &users); + + Instruction* store_inst = FindSingleStoreAndCheckUses(var_inst, users); + + if (store_inst == nullptr) { + return false; + } + + return 
RewriteLoads(store_inst, users); +} + +Instruction* LocalSingleStoreElimPass::FindSingleStoreAndCheckUses( + Instruction* var_inst, const std::vector& users) const { + // Make sure there is exactly 1 store. + Instruction* store_inst = nullptr; + + // If |var_inst| has an initializer, then that will count as a store. + if (var_inst->NumInOperands() > 1) { + store_inst = var_inst; + } + + for (Instruction* user : users) { + switch (user->opcode()) { + case SpvOpStore: + // Since we are in the relaxed addressing mode, the use has to be the + // base address of the store, and not the value being store. Otherwise, + // we would have a pointer to a pointer to function scope memory, which + // is not allowed. + if (store_inst == nullptr) { + store_inst = user; + } else { + // More than 1 store. + return nullptr; + } + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + if (FeedsAStore(user)) { + // Has a partial store. Cannot propagate that. + return nullptr; + } + break; + case SpvOpLoad: + case SpvOpImageTexelPointer: + case SpvOpName: + case SpvOpCopyObject: + break; + default: + if (!user->IsDecoration()) { + // Don't know if this instruction modifies the variable. + // Conservatively assume it is a store. 
+ return nullptr; + } + break; + } + } + return store_inst; +} + +void LocalSingleStoreElimPass::FindUses( + const Instruction* var_inst, std::vector* users) const { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + def_use_mgr->ForEachUser(var_inst, [users, this](Instruction* user) { + users->push_back(user); + if (user->opcode() == SpvOpCopyObject) { + FindUses(user, users); + } + }); +} + +bool LocalSingleStoreElimPass::FeedsAStore(Instruction* inst) const { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + return !def_use_mgr->WhileEachUser(inst, [this](Instruction* user) { + switch (user->opcode()) { + case SpvOpStore: + return false; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpCopyObject: + return !FeedsAStore(user); + case SpvOpLoad: + case SpvOpImageTexelPointer: + case SpvOpName: + return true; + default: + // Don't know if this instruction modifies the variable. + // Conservatively assume it is a store. + return user->IsDecoration(); + } + }); +} + +bool LocalSingleStoreElimPass::RewriteLoads( + Instruction* store_inst, const std::vector& uses) { + BasicBlock* store_block = context()->get_instr_block(store_inst); + DominatorAnalysis* dominator_analysis = + context()->GetDominatorAnalysis(store_block->GetParent()); + + uint32_t stored_id; + if (store_inst->opcode() == SpvOpStore) + stored_id = store_inst->GetSingleWordInOperand(kStoreValIdInIdx); + else + stored_id = store_inst->GetSingleWordInOperand(kVariableInitIdInIdx); + + std::vector uses_in_store_block; + bool modified = false; + for (Instruction* use : uses) { + if (use->opcode() == SpvOpLoad) { + if (dominator_analysis->Dominates(store_inst, use)) { + modified = true; + context()->KillNamesAndDecorates(use->result_id()); + context()->ReplaceAllUsesWith(use->result_id(), stored_id); + context()->KillInst(use); + } + } + } + + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git 
a/source/opt/local_single_store_elim_pass.h b/source/opt/local_single_store_elim_pass.h new file mode 100644 index 000000000..4cf8bbb86 --- /dev/null +++ b/source/opt/local_single_store_elim_pass.h @@ -0,0 +1,103 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ +#define SOURCE_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class LocalSingleStoreElimPass : public Pass { + using cbb_ptr = const BasicBlock*; + + public: + LocalSingleStoreElimPass(); + + const char* name() const override { return "eliminate-local-single-store"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Do "single-store" optimization of function variables defined only + // with a single non-access-chain store in |func|. 
Replace all their + // non-access-chain loads with the value that is stored and eliminate + // any resulting dead code. + bool LocalSingleStoreElim(Function* func); + + // Initialize extensions whitelist + void InitExtensionWhiteList(); + + // Return true if all extensions in this module are allowed by this pass. + bool AllExtensionsSupported() const; + + Pass::Status ProcessImpl(); + + // If there is a single store to |var_inst|, and it covers the entire + // variable, then replace all of the loads of the entire variable that are + // dominated by the store by the value that was stored. Returns true if the + // module was changed. + bool ProcessVariable(Instruction* var_inst); + + // Collects all of the uses of |var_inst| into |uses|. This looks through + // OpObjectCopy's that copy the address of the variable, and collects those + // uses as well. + void FindUses(const Instruction* var_inst, + std::vector* uses) const; + + // Returns a store to |var_inst| if + // - it is a store to the entire variable, + // - and there are no other instructions that may modify |var_inst|. + Instruction* FindSingleStoreAndCheckUses( + Instruction* var_inst, const std::vector& users) const; + + // Returns true if the address that results from |inst| may be used as a base + // address in a store instruction or may be used to compute the base address + // of a store instruction. + bool FeedsAStore(Instruction* inst) const; + + // Replaces all of the loads in |uses| by the value stored in |store_inst|. + // The load instructions are then killed. + bool RewriteLoads(Instruction* store_inst, + const std::vector& uses); + + // Extensions supported by this pass. 
+ std::unordered_set extensions_whitelist_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_ diff --git a/source/opt/local_ssa_elim_pass.cpp b/source/opt/local_ssa_elim_pass.cpp new file mode 100644 index 000000000..8209aa405 --- /dev/null +++ b/source/opt/local_ssa_elim_pass.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/local_ssa_elim_pass.h" + +#include "source/cfa.h" +#include "source/opt/iterator.h" +#include "source/opt/ssa_rewrite_pass.h" + +namespace spvtools { +namespace opt { + +bool LocalMultiStoreElimPass::AllExtensionsSupported() const { + // If any extension not in whitelist, return false + for (auto& ei : get_module()->extensions()) { + const char* extName = + reinterpret_cast(&ei.GetInOperand(0).words[0]); + if (extensions_whitelist_.find(extName) == extensions_whitelist_.end()) + return false; + } + return true; +} + +Pass::Status LocalMultiStoreElimPass::ProcessImpl() { + // Assumes relaxed logical addressing only (see instruction.h) + // TODO(greg-lunarg): Add support for physical addressing + if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) + return Status::SuccessWithoutChange; + // Do not process if module contains OpGroupDecorate. 
Additional + // support required in KillNamesAndDecorates(). + // TODO(greg-lunarg): Add support for OpGroupDecorate + for (auto& ai : get_module()->annotations()) + if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; + // Do not process if any disallowed extensions are enabled + if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; + // Process functions + ProcessFunction pfn = [this](Function* fp) { + return SSARewriter(this).RewriteFunctionIntoSSA(fp); + }; + bool modified = context()->ProcessEntryPointCallTree(pfn); + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +LocalMultiStoreElimPass::LocalMultiStoreElimPass() = default; + +Pass::Status LocalMultiStoreElimPass::Process() { + // Initialize extension whitelist + InitExtensions(); + return ProcessImpl(); +} + +void LocalMultiStoreElimPass::InitExtensions() { + extensions_whitelist_.clear(); + extensions_whitelist_.insert({ + "SPV_AMD_shader_explicit_vertex_parameter", + "SPV_AMD_shader_trinary_minmax", + "SPV_AMD_gcn_shader", + "SPV_KHR_shader_ballot", + "SPV_AMD_shader_ballot", + "SPV_AMD_gpu_shader_half_float", + "SPV_KHR_shader_draw_parameters", + "SPV_KHR_subgroup_vote", + "SPV_KHR_16bit_storage", + "SPV_KHR_device_group", + "SPV_KHR_multiview", + "SPV_NVX_multiview_per_view_attributes", + "SPV_NV_viewport_array2", + "SPV_NV_stereo_view_rendering", + "SPV_NV_sample_mask_override_coverage", + "SPV_NV_geometry_shader_passthrough", + "SPV_AMD_texture_gather_bias_lod", + "SPV_KHR_storage_buffer_storage_class", + // SPV_KHR_variable_pointers + // Currently do not support extended pointer expressions + "SPV_AMD_gpu_shader_int16", + "SPV_KHR_post_depth_coverage", + "SPV_KHR_shader_atomic_counter_ops", + "SPV_EXT_shader_stencil_export", + "SPV_EXT_shader_viewport_index_layer", + "SPV_AMD_shader_image_load_store_lod", + "SPV_AMD_shader_fragment_mask", + "SPV_EXT_fragment_fully_covered", + "SPV_AMD_gpu_shader_half_float_fetch", + 
"SPV_GOOGLE_decorate_string", + "SPV_GOOGLE_hlsl_functionality1", + "SPV_NV_shader_subgroup_partitioned", + "SPV_EXT_descriptor_indexing", + "SPV_NV_fragment_shader_barycentric", + "SPV_NV_compute_shader_derivatives", + "SPV_NV_shader_image_footprint", + "SPV_NV_shading_rate", + "SPV_NV_mesh_shader", + "SPV_NV_ray_tracing", + "SPV_EXT_fragment_invocation_density", + }); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/local_ssa_elim_pass.h b/source/opt/local_ssa_elim_pass.h new file mode 100644 index 000000000..de80d5a86 --- /dev/null +++ b/source/opt/local_ssa_elim_pass.h @@ -0,0 +1,72 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOCAL_SSA_ELIM_PASS_H_ +#define SOURCE_OPT_LOCAL_SSA_ELIM_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. 
+class LocalMultiStoreElimPass : public MemPass { + using cbb_ptr = const BasicBlock*; + + public: + using GetBlocksFunction = + std::function*(const BasicBlock*)>; + + LocalMultiStoreElimPass(); + + const char* name() const override { return "eliminate-local-multi-store"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Initialize extensions whitelist + void InitExtensions(); + + // Return true if all extensions in this module are allowed by this pass. + bool AllExtensionsSupported() const; + + Pass::Status ProcessImpl(); + + // Extensions supported by this pass. + std::unordered_set extensions_whitelist_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOCAL_SSA_ELIM_PASS_H_ diff --git a/source/opt/log.h b/source/opt/log.h new file mode 100644 index 000000000..f87cbf381 --- /dev/null +++ b/source/opt/log.h @@ -0,0 +1,231 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOG_H_ +#define SOURCE_OPT_LOG_H_ + +#include +#include +#include +#include + +#include "spirv-tools/libspirv.hpp" + +// Asserts the given condition is true. Otherwise, sends a message to the +// consumer and exits the problem with failure code. 
Accepts the following +// formats: +// +// SPIRV_ASSERT(, ); +// SPIRV_ASSERT(, , ); +// SPIRV_ASSERT(, , +// , ); +// +// In the third format, the number of cannot exceed (5 - +// 2). If more arguments are wanted, grow PP_ARG_N and PP_NARGS in the below. +#if !defined(NDEBUG) +#define SPIRV_ASSERT(consumer, ...) SPIRV_ASSERT_IMPL(consumer, __VA_ARGS__) +#else +#define SPIRV_ASSERT(consumer, ...) +#endif + +// Logs a debug message to the consumer. Accepts the following formats: +// +// SPIRV_DEBUG(, ); +// SPIRV_DEBUG(, , ); +// +// In the second format, the number of cannot exceed (5 - +// 1). If more arguments are wanted, grow PP_ARG_N and PP_NARGS in the below. +#if !defined(NDEBUG) && defined(SPIRV_LOG_DEBUG) +#define SPIRV_DEBUG(consumer, ...) SPIRV_DEBUG_IMPL(consumer, __VA_ARGS__) +#else +#define SPIRV_DEBUG(consumer, ...) +#endif + +// Logs an error message to the consumer saying the given feature is +// unimplemented. +#define SPIRV_UNIMPLEMENTED(consumer, feature) \ + do { \ + spvtools::Log(consumer, SPV_MSG_INTERNAL_ERROR, __FILE__, \ + {__LINE__, 0, 0}, "unimplemented: " feature); \ + } while (0) + +// Logs an error message to the consumer saying the code location +// should be unreachable. +#define SPIRV_UNREACHABLE(consumer) \ + do { \ + spvtools::Log(consumer, SPV_MSG_INTERNAL_ERROR, __FILE__, \ + {__LINE__, 0, 0}, "unreachable"); \ + } while (0) + +// Helper macros for concatenating arguments. +#define SPIRV_CONCATENATE(a, b) SPIRV_CONCATENATE_(a, b) +#define SPIRV_CONCATENATE_(a, b) a##b + +// Helper macro to force expanding __VA_ARGS__ to satisfy MSVC compiler. +#define PP_EXPAND(x) x + +namespace spvtools { + +// Calls the given |consumer| by supplying the |message|. The |message| is from +// the given |source| and |location| and of the given severity |level|. 
+inline void Log(const MessageConsumer& consumer, spv_message_level_t level, + const char* source, const spv_position_t& position, + const char* message) { + if (consumer != nullptr) consumer(level, source, position, message); +} + +// Calls the given |consumer| by supplying the message composed according to the +// given |format|. The |message| is from the given |source| and |location| and +// of the given severity |level|. +template +void Logf(const MessageConsumer& consumer, spv_message_level_t level, + const char* source, const spv_position_t& position, + const char* format, Args&&... args) { +#if defined(_MSC_VER) && _MSC_VER < 1900 +// Sadly, snprintf() is not supported until Visual Studio 2015! +#define snprintf _snprintf +#endif + + enum { kInitBufferSize = 256 }; + + char message[kInitBufferSize]; + const int size = + snprintf(message, kInitBufferSize, format, std::forward(args)...); + + if (size >= 0 && size < kInitBufferSize) { + Log(consumer, level, source, position, message); + return; + } + + if (size >= 0) { + // The initial buffer is insufficient. Allocate a buffer of a larger size, + // and write to it instead. Force the size to be unsigned to avoid a + // warning in GCC 7.1. + std::vector longer_message(size + 1u); + snprintf(longer_message.data(), longer_message.size(), format, + std::forward(args)...); + Log(consumer, level, source, position, longer_message.data()); + return; + } + + Log(consumer, level, source, position, "cannot compose log message"); + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#undef snprintf +#endif +} + +// Calls the given |consumer| by supplying the given error |message|. The +// |message| is from the given |source| and |location|. 
+inline void Error(const MessageConsumer& consumer, const char* source, + const spv_position_t& position, const char* message) { + Log(consumer, SPV_MSG_ERROR, source, position, message); +} + +// Calls the given |consumer| by supplying the error message composed according +// to the given |format|. The |message| is from the given |source| and +// |location|. +template +inline void Errorf(const MessageConsumer& consumer, const char* source, + const spv_position_t& position, const char* format, + Args&&... args) { + Logf(consumer, SPV_MSG_ERROR, source, position, format, + std::forward(args)...); +} + +} // namespace spvtools + +#define SPIRV_ASSERT_IMPL(consumer, ...) \ + PP_EXPAND(SPIRV_CONCATENATE(SPIRV_ASSERT_, PP_NARGS(__VA_ARGS__))( \ + consumer, __VA_ARGS__)) + +#define SPIRV_DEBUG_IMPL(consumer, ...) \ + PP_EXPAND(SPIRV_CONCATENATE(SPIRV_DEBUG_, PP_NARGS(__VA_ARGS__))( \ + consumer, __VA_ARGS__)) + +#define SPIRV_ASSERT_1(consumer, condition) \ + do { \ + if (!(condition)) { \ + spvtools::Log(consumer, SPV_MSG_INTERNAL_ERROR, __FILE__, \ + {__LINE__, 0, 0}, "assertion failed: " #condition); \ + std::exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define SPIRV_ASSERT_2(consumer, condition, message) \ + do { \ + if (!(condition)) { \ + spvtools::Log(consumer, SPV_MSG_INTERNAL_ERROR, __FILE__, \ + {__LINE__, 0, 0}, "assertion failed: " message); \ + std::exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define SPIRV_ASSERT_more(consumer, condition, format, ...) \ + do { \ + if (!(condition)) { \ + spvtools::Logf(consumer, SPV_MSG_INTERNAL_ERROR, __FILE__, \ + {__LINE__, 0, 0}, "assertion failed: " format, \ + __VA_ARGS__); \ + std::exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define SPIRV_ASSERT_3(consumer, condition, format, ...) \ + SPIRV_ASSERT_more(consumer, condition, format, __VA_ARGS__) + +#define SPIRV_ASSERT_4(consumer, condition, format, ...) 
\ + SPIRV_ASSERT_more(consumer, condition, format, __VA_ARGS__) + +#define SPIRV_ASSERT_5(consumer, condition, format, ...) \ + SPIRV_ASSERT_more(consumer, condition, format, __VA_ARGS__) + +#define SPIRV_DEBUG_1(consumer, message) \ + do { \ + spvtools::Log(consumer, SPV_MSG_DEBUG, __FILE__, {__LINE__, 0, 0}, \ + message); \ + } while (0) + +#define SPIRV_DEBUG_more(consumer, format, ...) \ + do { \ + spvtools::Logf(consumer, SPV_MSG_DEBUG, __FILE__, {__LINE__, 0, 0}, \ + format, __VA_ARGS__); \ + } while (0) + +#define SPIRV_DEBUG_2(consumer, format, ...) \ + SPIRV_DEBUG_more(consumer, format, __VA_ARGS__) + +#define SPIRV_DEBUG_3(consumer, format, ...) \ + SPIRV_DEBUG_more(consumer, format, __VA_ARGS__) + +#define SPIRV_DEBUG_4(consumer, format, ...) \ + SPIRV_DEBUG_more(consumer, format, __VA_ARGS__) + +#define SPIRV_DEBUG_5(consumer, format, ...) \ + SPIRV_DEBUG_more(consumer, format, __VA_ARGS__) + +// Macros for counting the number of arguments passed in. +#define PP_NARGS(...) PP_EXPAND(PP_ARG_N(__VA_ARGS__, 5, 4, 3, 2, 1, 0)) +#define PP_ARG_N(_1, _2, _3, _4, _5, N, ...) N + +// Tests for making sure that PP_NARGS() behaves as expected. +static_assert(PP_NARGS(0) == 1, "PP_NARGS macro error"); +static_assert(PP_NARGS(0, 0) == 2, "PP_NARGS macro error"); +static_assert(PP_NARGS(0, 0, 0) == 3, "PP_NARGS macro error"); +static_assert(PP_NARGS(0, 0, 0, 0) == 4, "PP_NARGS macro error"); +static_assert(PP_NARGS(0, 0, 0, 0, 0) == 5, "PP_NARGS macro error"); +static_assert(PP_NARGS(1 + 1, 2, 3 / 3) == 3, "PP_NARGS macro error"); +static_assert(PP_NARGS((1, 1), 2, (3, 3)) == 3, "PP_NARGS macro error"); + +#endif // SOURCE_OPT_LOG_H_ diff --git a/source/opt/loop_dependence.cpp b/source/opt/loop_dependence.cpp new file mode 100644 index 000000000..d8de699bf --- /dev/null +++ b/source/opt/loop_dependence.cpp @@ -0,0 +1,1675 @@ +// Copyright (c) 2018 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_dependence.h" + +#include +#include +#include +#include +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/scalar_analysis_nodes.h" + +namespace spvtools { +namespace opt { + +using SubscriptPair = std::pair; + +namespace { + +// Calculate the greatest common divisor of a & b using Stein's algorithm. +// https://en.wikipedia.org/wiki/Binary_GCD_algorithm +int64_t GreatestCommonDivisor(int64_t a, int64_t b) { + // Simple cases + if (a == b) { + return a; + } else if (a == 0) { + return b; + } else if (b == 0) { + return a; + } + + // Both even + if (a % 2 == 0 && b % 2 == 0) { + return 2 * GreatestCommonDivisor(a / 2, b / 2); + } + + // Even a, odd b + if (a % 2 == 0 && b % 2 == 1) { + return GreatestCommonDivisor(a / 2, b); + } + + // Odd a, even b + if (a % 2 == 1 && b % 2 == 0) { + return GreatestCommonDivisor(a, b / 2); + } + + // Both odd, reduce the larger argument + if (a > b) { + return GreatestCommonDivisor((a - b) / 2, b); + } else { + return GreatestCommonDivisor((b - a) / 2, a); + } +} + +// Check if node is affine, ie in the form: a0*i0 + a1*i1 + ... 
an*in + c +// and contains only the following types of nodes: SERecurrentNode, SEAddNode +// and SEConstantNode +bool IsInCorrectFormForGCDTest(SENode* node) { + bool children_ok = true; + + if (auto add_node = node->AsSEAddNode()) { + for (auto child : add_node->GetChildren()) { + children_ok &= IsInCorrectFormForGCDTest(child); + } + } + + bool this_ok = node->AsSERecurrentNode() || node->AsSEAddNode() || + node->AsSEConstantNode(); + + return children_ok && this_ok; +} + +// If |node| is an SERecurrentNode then returns |node| or if |node| is an +// SEAddNode returns a vector of SERecurrentNode that are its children. +std::vector GetAllTopLevelRecurrences(SENode* node) { + auto nodes = std::vector{}; + if (auto recurrent_node = node->AsSERecurrentNode()) { + nodes.push_back(recurrent_node); + } + + if (auto add_node = node->AsSEAddNode()) { + for (auto child : add_node->GetChildren()) { + auto child_nodes = GetAllTopLevelRecurrences(child); + nodes.insert(nodes.end(), child_nodes.begin(), child_nodes.end()); + } + } + + return nodes; +} + +// If |node| is an SEConstantNode then returns |node| or if |node| is an +// SEAddNode returns a vector of SEConstantNode that are its children. +std::vector GetAllTopLevelConstants(SENode* node) { + auto nodes = std::vector{}; + if (auto recurrent_node = node->AsSEConstantNode()) { + nodes.push_back(recurrent_node); + } + + if (auto add_node = node->AsSEAddNode()) { + for (auto child : add_node->GetChildren()) { + auto child_nodes = GetAllTopLevelConstants(child); + nodes.insert(nodes.end(), child_nodes.begin(), child_nodes.end()); + } + } + + return nodes; +} + +bool AreOffsetsAndCoefficientsConstant( + const std::vector& nodes) { + for (auto node : nodes) { + if (!node->GetOffset()->AsSEConstantNode() || + !node->GetOffset()->AsSEConstantNode()) { + return false; + } + } + return true; +} + +// Fold all SEConstantNode that appear in |recurrences| and |constants| into a +// single integer value. 
+int64_t CalculateConstantTerm(const std::vector& recurrences, + const std::vector& constants) { + int64_t constant_term = 0; + for (auto recurrence : recurrences) { + constant_term += + recurrence->GetOffset()->AsSEConstantNode()->FoldToSingleValue(); + } + + for (auto constant : constants) { + constant_term += constant->FoldToSingleValue(); + } + + return constant_term; +} + +int64_t CalculateGCDFromCoefficients( + const std::vector& recurrences, int64_t running_gcd) { + for (SERecurrentNode* recurrence : recurrences) { + auto coefficient = recurrence->GetCoefficient()->AsSEConstantNode(); + + running_gcd = GreatestCommonDivisor( + running_gcd, std::abs(coefficient->FoldToSingleValue())); + } + + return running_gcd; +} + +// Compare 2 fractions while first normalizing them, e.g. 2/4 and 4/8 will both +// be simplified to 1/2 and then determined to be equal. +bool NormalizeAndCompareFractions(int64_t numerator_0, int64_t denominator_0, + int64_t numerator_1, int64_t denominator_1) { + auto gcd_0 = + GreatestCommonDivisor(std::abs(numerator_0), std::abs(denominator_0)); + auto gcd_1 = + GreatestCommonDivisor(std::abs(numerator_1), std::abs(denominator_1)); + + auto normalized_numerator_0 = numerator_0 / gcd_0; + auto normalized_denominator_0 = denominator_0 / gcd_0; + auto normalized_numerator_1 = numerator_1 / gcd_1; + auto normalized_denominator_1 = denominator_1 / gcd_1; + + return normalized_numerator_0 == normalized_numerator_1 && + normalized_denominator_0 == normalized_denominator_1; +} + +} // namespace + +bool LoopDependenceAnalysis::GetDependence(const Instruction* source, + const Instruction* destination, + DistanceVector* distance_vector) { + // Start off by finding and marking all the loops in |loops_| that are + // irrelevant to the dependence analysis. 
+ MarkUnsusedDistanceEntriesAsIrrelevant(source, destination, distance_vector); + + Instruction* source_access_chain = GetOperandDefinition(source, 0); + Instruction* destination_access_chain = GetOperandDefinition(destination, 0); + + auto num_access_chains = + (source_access_chain->opcode() == SpvOpAccessChain) + + (destination_access_chain->opcode() == SpvOpAccessChain); + + // If neither is an access chain, then they are load/store to a variable. + if (num_access_chains == 0) { + if (source_access_chain != destination_access_chain) { + // Not the same location, report independence + return true; + } else { + // Accessing the same variable + for (auto& entry : distance_vector->GetEntries()) { + entry = DistanceEntry(); + } + return false; + } + } + + // If only one is an access chain, it could be accessing a part of a struct + if (num_access_chains == 1) { + auto source_is_chain = source_access_chain->opcode() == SpvOpAccessChain; + auto access_chain = + source_is_chain ? source_access_chain : destination_access_chain; + auto variable = + source_is_chain ? destination_access_chain : source_access_chain; + + auto location_in_chain = GetOperandDefinition(access_chain, 0); + + if (variable != location_in_chain) { + // Not the same location, report independence + return true; + } else { + // Accessing the same variable + for (auto& entry : distance_vector->GetEntries()) { + entry = DistanceEntry(); + } + return false; + } + } + + // If the access chains aren't collecting from the same structure there is no + // dependence. + Instruction* source_array = GetOperandDefinition(source_access_chain, 0); + Instruction* destination_array = + GetOperandDefinition(destination_access_chain, 0); + + // Nested access chains are not supported yet, bail out. 
+ if (source_array->opcode() == SpvOpAccessChain || + destination_array->opcode() == SpvOpAccessChain) { + for (auto& entry : distance_vector->GetEntries()) { + entry = DistanceEntry(); + } + return false; + } + + if (source_array != destination_array) { + PrintDebug("Proved independence through different arrays."); + return true; + } + + // To handle multiple subscripts we must get every operand in the access + // chains past the first. + std::vector source_subscripts = GetSubscripts(source); + std::vector destination_subscripts = GetSubscripts(destination); + + auto sets_of_subscripts = + PartitionSubscripts(source_subscripts, destination_subscripts); + + auto first_coupled = std::partition( + std::begin(sets_of_subscripts), std::end(sets_of_subscripts), + [](const std::set>& set) { + return set.size() == 1; + }); + + // Go through each subscript testing for independence. + // If any subscript results in independence, we prove independence between the + // load and store. + // If we can't prove independence we store what information we can gather in + // a DistanceVector. + for (auto it = std::begin(sets_of_subscripts); it < first_coupled; ++it) { + auto source_subscript = std::get<0>(*(*it).begin()); + auto destination_subscript = std::get<1>(*(*it).begin()); + + SENode* source_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(source_subscript)); + SENode* destination_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(destination_subscript)); + + // Check the loops are in a form we support. + auto subscript_pair = std::make_pair(source_node, destination_node); + + const Loop* loop = GetLoopForSubscriptPair(subscript_pair); + if (loop) { + if (!IsSupportedLoop(loop)) { + PrintDebug( + "GetDependence found an unsupported loop form. 
Assuming <=> for " + "loop."); + DistanceEntry* distance_entry = + GetDistanceEntryForSubscriptPair(subscript_pair, distance_vector); + if (distance_entry) { + distance_entry->direction = DistanceEntry::Directions::ALL; + } + continue; + } + } + + // If either node is simplified to a CanNotCompute we can't perform any + // analysis so must assume <=> dependence and return. + if (source_node->GetType() == SENode::CanNotCompute || + destination_node->GetType() == SENode::CanNotCompute) { + // Record the <=> dependence if we can get a DistanceEntry + PrintDebug( + "GetDependence found source_node || destination_node as " + "CanNotCompute. Abandoning evaluation for this subscript."); + DistanceEntry* distance_entry = + GetDistanceEntryForSubscriptPair(subscript_pair, distance_vector); + if (distance_entry) { + distance_entry->direction = DistanceEntry::Directions::ALL; + } + continue; + } + + // We have no induction variables so can apply a ZIV test. + if (IsZIV(subscript_pair)) { + PrintDebug("Found a ZIV subscript pair"); + if (ZIVTest(subscript_pair)) { + PrintDebug("Proved independence with ZIVTest."); + return true; + } + } + + // We have only one induction variable so should attempt an SIV test. + if (IsSIV(subscript_pair)) { + PrintDebug("Found a SIV subscript pair."); + if (SIVTest(subscript_pair, distance_vector)) { + PrintDebug("Proved independence with SIVTest."); + return true; + } + } + + // We have multiple induction variables so should attempt an MIV test. 
+ if (IsMIV(subscript_pair)) { + PrintDebug("Found a MIV subscript pair."); + if (GCDMIVTest(subscript_pair)) { + PrintDebug("Proved independence with the GCD test."); + auto current_loops = CollectLoops(source_node, destination_node); + + for (auto current_loop : current_loops) { + auto distance_entry = + GetDistanceEntryForLoop(current_loop, distance_vector); + distance_entry->direction = DistanceEntry::Directions::NONE; + } + return true; + } + } + } + + for (auto it = first_coupled; it < std::end(sets_of_subscripts); ++it) { + auto coupled_instructions = *it; + std::vector coupled_subscripts{}; + + for (const auto& elem : coupled_instructions) { + auto source_subscript = std::get<0>(elem); + auto destination_subscript = std::get<1>(elem); + + SENode* source_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(source_subscript)); + SENode* destination_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(destination_subscript)); + + coupled_subscripts.push_back({source_node, destination_node}); + } + + auto supported = true; + + for (const auto& subscript : coupled_subscripts) { + auto loops = CollectLoops(std::get<0>(subscript), std::get<1>(subscript)); + + auto is_subscript_supported = + std::all_of(std::begin(loops), std::end(loops), + [this](const Loop* l) { return IsSupportedLoop(l); }); + + supported = supported && is_subscript_supported; + } + + if (DeltaTest(coupled_subscripts, distance_vector)) { + return true; + } + } + + // We were unable to prove independence so must gather all of the direction + // information we found. 
+  PrintDebug(
+      "Couldn't prove independence.\n"
+      "All possible direction information has been collected in the input "
+      "DistanceVector.");
+
+  return false;
+}
+
+// Zero-induction-variable test on one subscript pair. Returns true
+// (independence proved) when the two expressions are different SENode
+// pointers, and false when they are the same pointer.
+// NOTE(review): the std::pair template arguments here and in SIVTest were
+// stripped from this text; they are restored from the SENode* values read out
+// of the pair in SIVTest -- confirm against upstream.
+bool LoopDependenceAnalysis::ZIVTest(
+    const std::pair<SENode*, SENode*>& subscript_pair) {
+  auto source = std::get<0>(subscript_pair);
+  auto destination = std::get<1>(subscript_pair);
+
+  PrintDebug("Performing ZIVTest");
+  // If source == destination, dependence with direction = and distance 0.
+  if (source == destination) {
+    PrintDebug("ZIVTest found EQ dependence.");
+    return false;
+  } else {
+    PrintDebug("ZIVTest found independence.");
+    // Otherwise we prove independence.
+    return true;
+  }
+}
+
+// Single-induction-variable test on one subscript pair. Returns true when
+// independence is proved; otherwise returns false.
+bool LoopDependenceAnalysis::SIVTest(
+    const std::pair<SENode*, SENode*>& subscript_pair,
+    DistanceVector* distance_vector) {
+  DistanceEntry* distance_entry =
+      GetDistanceEntryForSubscriptPair(subscript_pair, distance_vector);
+  if (!distance_entry) {
+    PrintDebug(
+        "SIVTest could not find a DistanceEntry for subscript_pair. Exiting");
+    // Every path below writes through |distance_entry|, so stop here as the
+    // message says. Returning false is the conservative answer: independence
+    // was not proved.
+    return false;
+  }
+
+  SENode* source_node = std::get<0>(subscript_pair);
+  SENode* destination_node = std::get<1>(subscript_pair);
+
+  int64_t source_induction_count = CountInductionVariables(source_node);
+  int64_t destination_induction_count =
+      CountInductionVariables(destination_node);
+
+  // If the source node has no induction variables we can apply a
+  // WeakZeroSrcTest.
+  if (source_induction_count == 0) {
+    PrintDebug("Found source has no induction variable.");
+    if (WeakZeroSourceSIVTest(
+            source_node, destination_node->AsSERecurrentNode(),
+            destination_node->AsSERecurrentNode()->GetCoefficient(),
+            distance_entry)) {
+      PrintDebug("Proved independence with WeakZeroSourceSIVTest.");
+      distance_entry->dependence_information =
+          DistanceEntry::DependenceInformation::DIRECTION;
+      distance_entry->direction = DistanceEntry::Directions::NONE;
+      return true;
+    }
+  }
+
+  // If the destination has no induction variables we can apply a
+  // WeakZeroDestTest.
+ if (destination_induction_count == 0) { + PrintDebug("Found destination has no induction variable."); + if (WeakZeroDestinationSIVTest( + source_node->AsSERecurrentNode(), destination_node, + source_node->AsSERecurrentNode()->GetCoefficient(), + distance_entry)) { + PrintDebug("Proved independence with WeakZeroDestinationSIVTest."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + } + + // We now need to collect the SERecurrentExpr nodes from source and + // destination. We do not handle cases where source or destination have + // multiple SERecurrentExpr nodes. + std::vector source_recurrent_nodes = + source_node->CollectRecurrentNodes(); + std::vector destination_recurrent_nodes = + destination_node->CollectRecurrentNodes(); + + if (source_recurrent_nodes.size() == 1 && + destination_recurrent_nodes.size() == 1) { + PrintDebug("Found source and destination have 1 induction variable."); + SERecurrentNode* source_recurrent_expr = *source_recurrent_nodes.begin(); + SERecurrentNode* destination_recurrent_expr = + *destination_recurrent_nodes.begin(); + + // If the coefficients are identical we can apply a StrongSIVTest. + if (source_recurrent_expr->GetCoefficient() == + destination_recurrent_expr->GetCoefficient()) { + PrintDebug("Found source and destination share coefficient."); + if (StrongSIVTest(source_node, destination_node, + source_recurrent_expr->GetCoefficient(), + distance_entry)) { + PrintDebug("Proved independence with StrongSIVTest"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + } + + // If the coefficients are of equal magnitude and opposite sign we can + // apply a WeakCrossingSIVTest. 
+ if (source_recurrent_expr->GetCoefficient() == + scalar_evolution_.CreateNegation( + destination_recurrent_expr->GetCoefficient())) { + PrintDebug("Found source coefficient = -destination coefficient."); + if (WeakCrossingSIVTest(source_node, destination_node, + source_recurrent_expr->GetCoefficient(), + distance_entry)) { + PrintDebug("Proved independence with WeakCrossingSIVTest"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + } + } + + return false; +} + +bool LoopDependenceAnalysis::StrongSIVTest(SENode* source, SENode* destination, + SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing StrongSIVTest."); + // If both source and destination are SERecurrentNodes we can perform tests + // based on distance. + // If either source or destination contain value unknown nodes or if one or + // both are not SERecurrentNodes we must attempt a symbolic test. + std::vector source_value_unknown_nodes = + source->CollectValueUnknownNodes(); + std::vector destination_value_unknown_nodes = + destination->CollectValueUnknownNodes(); + if (source_value_unknown_nodes.size() > 0 || + destination_value_unknown_nodes.size() > 0) { + PrintDebug( + "StrongSIVTest found symbolics. Will attempt SymbolicStrongSIVTest."); + return SymbolicStrongSIVTest(source, destination, coefficient, + distance_entry); + } + + if (!source->AsSERecurrentNode() || !destination->AsSERecurrentNode()) { + PrintDebug( + "StrongSIVTest could not simplify source and destination to " + "SERecurrentNodes so will exit."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; + } + + // Build an SENode for distance. 
+ std::pair subscript_pair = + std::make_pair(source, destination); + const Loop* subscript_loop = GetLoopForSubscriptPair(subscript_pair); + SENode* source_constant_term = + GetConstantTerm(subscript_loop, source->AsSERecurrentNode()); + SENode* destination_constant_term = + GetConstantTerm(subscript_loop, destination->AsSERecurrentNode()); + if (!source_constant_term || !destination_constant_term) { + PrintDebug( + "StrongSIVTest could not collect the constant terms of either source " + "or destination so will exit."); + return false; + } + SENode* constant_term_delta = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateSubtraction( + destination_constant_term, source_constant_term)); + + // Scalar evolution doesn't perform division, so we must fold to constants and + // do it manually. + // We must check the offset delta and coefficient are constants. + int64_t distance = 0; + SEConstantNode* delta_constant = constant_term_delta->AsSEConstantNode(); + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (delta_constant && coefficient_constant) { + int64_t delta_value = delta_constant->FoldToSingleValue(); + int64_t coefficient_value = coefficient_constant->FoldToSingleValue(); + PrintDebug( + "StrongSIVTest found delta value and coefficient value as constants " + "with values:\n" + "\tdelta value: " + + ToString(delta_value) + + "\n\tcoefficient value: " + ToString(coefficient_value) + "\n"); + // Check if the distance is not integral to try to prove independence. 
+    if (delta_value % coefficient_value != 0) {
+      PrintDebug(
+          "StrongSIVTest proved independence through distance not being an "
+          "integer.");
+      distance_entry->dependence_information =
+          DistanceEntry::DependenceInformation::DIRECTION;
+      distance_entry->direction = DistanceEntry::Directions::NONE;
+      return true;
+    } else {
+      distance = delta_value / coefficient_value;
+      PrintDebug("StrongSIV test found distance as " + ToString(distance));
+    }
+  } else {
+    // If we can't fold delta and coefficient to single values we can't produce
+    // distance.
+    // As a result we can't perform the rest of the pass and must assume
+    // dependence in all directions.
+    PrintDebug("StrongSIVTest could not produce a distance. Must exit.");
+    // Record "all directions" in |direction|, as the other give-up paths of
+    // this function do. This used to be written to |distance|, which holds a
+    // numeric distance rather than a Directions value.
+    distance_entry->direction = DistanceEntry::Directions::ALL;
+    return false;
+  }
+
+  // Next we gather the upper and lower bounds as constants if possible. If
+  // distance > upper_bound - lower_bound we prove independence.
+  SENode* lower_bound = GetLowerBound(subscript_loop);
+  SENode* upper_bound = GetUpperBound(subscript_loop);
+  if (lower_bound && upper_bound) {
+    PrintDebug("StrongSIVTest found bounds.");
+    SENode* bounds = scalar_evolution_.SimplifyExpression(
+        scalar_evolution_.CreateSubtraction(upper_bound, lower_bound));
+
+    if (bounds->GetType() == SENode::SENodeType::Constant) {
+      int64_t bounds_value = bounds->AsSEConstantNode()->FoldToSingleValue();
+      PrintDebug(
+          "StrongSIVTest found upper_bound - lower_bound as a constant with "
+          "value " +
+          ToString(bounds_value));
+
+      // If the absolute value of the distance is > upper bound - lower bound
+      // then we prove independence.
+ if (llabs(distance) > llabs(bounds_value)) { + PrintDebug( + "StrongSIVTest proved independence through distance escaping the " + "loop bounds."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::NONE; + distance_entry->distance = distance; + return true; + } + } + } else { + PrintDebug("StrongSIVTest was unable to gather lower and upper bounds."); + } + + // Otherwise we can get a direction as follows + // { < if distance > 0 + // direction = { = if distance == 0 + // { > if distance < 0 + PrintDebug( + "StrongSIVTest could not prove independence. Gathering direction " + "information."); + if (distance > 0) { + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::LT; + distance_entry->distance = distance; + return false; + } + if (distance == 0) { + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::EQ; + distance_entry->distance = 0; + return false; + } + if (distance < 0) { + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::GT; + distance_entry->distance = distance; + return false; + } + + // We were unable to prove independence or discern any additional information + // Must assume <=> direction. 
+ PrintDebug( + "StrongSIVTest was unable to determine any dependence information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +bool LoopDependenceAnalysis::SymbolicStrongSIVTest( + SENode* source, SENode* destination, SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing SymbolicStrongSIVTest."); + SENode* source_destination_delta = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(source, destination)); + // By cancelling out the induction variables by subtracting the source and + // destination we can produce an expression of symbolics and constants. This + // expression can be compared to the loop bounds to find if the offset is + // outwith the bounds. + std::pair subscript_pair = + std::make_pair(source, destination); + const Loop* subscript_loop = GetLoopForSubscriptPair(subscript_pair); + if (IsProvablyOutsideOfLoopBounds(subscript_loop, source_destination_delta, + coefficient)) { + PrintDebug( + "SymbolicStrongSIVTest proved independence through loop bounds."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } + // We were unable to prove independence or discern any additional information. + // Must assume <=> direction. + PrintDebug( + "SymbolicStrongSIVTest was unable to determine any dependence " + "information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +bool LoopDependenceAnalysis::WeakZeroSourceSIVTest( + SENode* source, SERecurrentNode* destination, SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing WeakZeroSourceSIVTest."); + std::pair subscript_pair = + std::make_pair(source, destination); + const Loop* subscript_loop = GetLoopForSubscriptPair(subscript_pair); + // Build an SENode for distance. 
+ SENode* destination_constant_term = + GetConstantTerm(subscript_loop, destination); + SENode* delta = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(source, destination_constant_term)); + + // Scalar evolution doesn't perform division, so we must fold to constants and + // do it manually. + int64_t distance = 0; + SEConstantNode* delta_constant = delta->AsSEConstantNode(); + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (delta_constant && coefficient_constant) { + PrintDebug( + "WeakZeroSourceSIVTest folding delta and coefficient to constants."); + int64_t delta_value = delta_constant->FoldToSingleValue(); + int64_t coefficient_value = coefficient_constant->FoldToSingleValue(); + // Check if the distance is not integral. + if (delta_value % coefficient_value != 0) { + PrintDebug( + "WeakZeroSourceSIVTest proved independence through distance not " + "being an integer."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } else { + distance = delta_value / coefficient_value; + PrintDebug( + "WeakZeroSourceSIVTest calculated distance with the following " + "values\n" + "\tdelta value: " + + ToString(delta_value) + + "\n\tcoefficient value: " + ToString(coefficient_value) + + "\n\tdistance: " + ToString(distance) + "\n"); + } + } else { + PrintDebug( + "WeakZeroSourceSIVTest was unable to fold delta and coefficient to " + "constants."); + } + + // If we can prove the distance is outside the bounds we prove independence. 
+  SEConstantNode* lower_bound =
+      GetLowerBound(subscript_loop)->AsSEConstantNode();
+  SEConstantNode* upper_bound =
+      GetUpperBound(subscript_loop)->AsSEConstantNode();
+  // |distance| only holds a computed value when both delta and coefficient
+  // were folded to constants. Otherwise it is still its initial 0, and
+  // comparing that placeholder against the bounds would report independence
+  // for any loop whose constant bounds do not contain 0.
+  const bool distance_is_known = delta_constant && coefficient_constant;
+  if (!distance_is_known) {
+    PrintDebug(
+        "WeakZeroSourceSIVTest has no constant distance to compare against "
+        "the loop bounds.");
+  } else if (lower_bound && upper_bound) {
+    PrintDebug("WeakZeroSourceSIVTest found bounds as SEConstantNodes.");
+    int64_t lower_bound_value = lower_bound->FoldToSingleValue();
+    int64_t upper_bound_value = upper_bound->FoldToSingleValue();
+    if (!IsWithinBounds(llabs(distance), lower_bound_value,
+                        upper_bound_value)) {
+      PrintDebug(
+          "WeakZeroSourceSIVTest proved independence through distance escaping "
+          "the loop bounds.");
+      PrintDebug(
+          "Bound values were as follow\n"
+          "\tlower bound value: " +
+          ToString(lower_bound_value) +
+          "\n\tupper bound value: " + ToString(upper_bound_value) +
+          "\n\tdistance value: " + ToString(distance) + "\n");
+      distance_entry->dependence_information =
+          DistanceEntry::DependenceInformation::DISTANCE;
+      distance_entry->direction = DistanceEntry::Directions::NONE;
+      distance_entry->distance = distance;
+      return true;
+    }
+  } else {
+    PrintDebug(
+        "WeakZeroSourceSIVTest was unable to find lower and upper bound as "
+        "SEConstantNodes.");
+  }
+
+  // Now we want to see if we can detect to peel the first or last iterations.
+
+  // We get the FirstTripValue as GetFirstTripInductionNode() +
+  // GetConstantTerm(destination)
+  SENode* first_trip_SENode =
+      scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode(
+          GetFirstTripInductionNode(subscript_loop),
+          GetConstantTerm(subscript_loop, destination)));
+
+  // If source == FirstTripValue, peel_first.
+ if (first_trip_SENode) { + PrintDebug("WeakZeroSourceSIVTest built first_trip_SENode."); + if (first_trip_SENode->AsSEConstantNode()) { + PrintDebug( + "WeakZeroSourceSIVTest has found first_trip_SENode as an " + "SEConstantNode with value: " + + ToString(first_trip_SENode->AsSEConstantNode()->FoldToSingleValue()) + + "\n"); + } + if (source == first_trip_SENode) { + // We have found that peeling the first iteration will break dependency. + PrintDebug( + "WeakZeroSourceSIVTest has found peeling first iteration will break " + "dependency"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::PEEL; + distance_entry->peel_first = true; + return false; + } + } else { + PrintDebug("WeakZeroSourceSIVTest was unable to build first_trip_SENode"); + } + + // We get the LastTripValue as GetFinalTripInductionNode(coefficient) + + // GetConstantTerm(destination) + SENode* final_trip_SENode = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + GetFinalTripInductionNode(subscript_loop, coefficient), + GetConstantTerm(subscript_loop, destination))); + + // If source == LastTripValue, peel_last. + if (final_trip_SENode) { + PrintDebug("WeakZeroSourceSIVTest built final_trip_SENode."); + if (first_trip_SENode->AsSEConstantNode()) { + PrintDebug( + "WeakZeroSourceSIVTest has found final_trip_SENode as an " + "SEConstantNode with value: " + + ToString(final_trip_SENode->AsSEConstantNode()->FoldToSingleValue()) + + "\n"); + } + if (source == final_trip_SENode) { + // We have found that peeling the last iteration will break dependency. 
+ PrintDebug( + "WeakZeroSourceSIVTest has found peeling final iteration will break " + "dependency"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::PEEL; + distance_entry->peel_last = true; + return false; + } + } else { + PrintDebug("WeakZeroSourceSIVTest was unable to build final_trip_SENode"); + } + + // We were unable to prove independence or discern any additional information. + // Must assume <=> direction. + PrintDebug( + "WeakZeroSourceSIVTest was unable to determine any dependence " + "information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +bool LoopDependenceAnalysis::WeakZeroDestinationSIVTest( + SERecurrentNode* source, SENode* destination, SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing WeakZeroDestinationSIVTest."); + // Build an SENode for distance. + std::pair subscript_pair = + std::make_pair(source, destination); + const Loop* subscript_loop = GetLoopForSubscriptPair(subscript_pair); + SENode* source_constant_term = GetConstantTerm(subscript_loop, source); + SENode* delta = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(destination, source_constant_term)); + + // Scalar evolution doesn't perform division, so we must fold to constants and + // do it manually. + int64_t distance = 0; + SEConstantNode* delta_constant = delta->AsSEConstantNode(); + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (delta_constant && coefficient_constant) { + PrintDebug( + "WeakZeroDestinationSIVTest folding delta and coefficient to " + "constants."); + int64_t delta_value = delta_constant->FoldToSingleValue(); + int64_t coefficient_value = coefficient_constant->FoldToSingleValue(); + // Check if the distance is not integral. 
+ if (delta_value % coefficient_value != 0) { + PrintDebug( + "WeakZeroDestinationSIVTest proved independence through distance not " + "being an integer."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } else { + distance = delta_value / coefficient_value; + PrintDebug( + "WeakZeroDestinationSIVTest calculated distance with the following " + "values\n" + "\tdelta value: " + + ToString(delta_value) + + "\n\tcoefficient value: " + ToString(coefficient_value) + + "\n\tdistance: " + ToString(distance) + "\n"); + } + } else { + PrintDebug( + "WeakZeroDestinationSIVTest was unable to fold delta and coefficient " + "to constants."); + } + + // If we can prove the distance is outside the bounds we prove independence. + SEConstantNode* lower_bound = + GetLowerBound(subscript_loop)->AsSEConstantNode(); + SEConstantNode* upper_bound = + GetUpperBound(subscript_loop)->AsSEConstantNode(); + if (lower_bound && upper_bound) { + PrintDebug("WeakZeroDestinationSIVTest found bounds as SEConstantNodes."); + int64_t lower_bound_value = lower_bound->FoldToSingleValue(); + int64_t upper_bound_value = upper_bound->FoldToSingleValue(); + if (!IsWithinBounds(llabs(distance), lower_bound_value, + upper_bound_value)) { + PrintDebug( + "WeakZeroDestinationSIVTest proved independence through distance " + "escaping the loop bounds."); + PrintDebug( + "Bound values were as follows\n" + "\tlower bound value: " + + ToString(lower_bound_value) + + "\n\tupper bound value: " + ToString(upper_bound_value) + + "\n\tdistance value: " + ToString(distance)); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::NONE; + distance_entry->distance = distance; + return true; + } + } else { + PrintDebug( + "WeakZeroDestinationSIVTest was unable to find lower and upper bound " + "as 
SEConstantNodes."); + } + + // Now we want to see if we can detect to peel the first or last iterations. + + // We get the FirstTripValue as GetFirstTripInductionNode() + + // GetConstantTerm(source) + SENode* first_trip_SENode = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateAddNode(GetFirstTripInductionNode(subscript_loop), + GetConstantTerm(subscript_loop, source))); + + // If destination == FirstTripValue, peel_first. + if (first_trip_SENode) { + PrintDebug("WeakZeroDestinationSIVTest built first_trip_SENode."); + if (first_trip_SENode->AsSEConstantNode()) { + PrintDebug( + "WeakZeroDestinationSIVTest has found first_trip_SENode as an " + "SEConstantNode with value: " + + ToString(first_trip_SENode->AsSEConstantNode()->FoldToSingleValue()) + + "\n"); + } + if (destination == first_trip_SENode) { + // We have found that peeling the first iteration will break dependency. + PrintDebug( + "WeakZeroDestinationSIVTest has found peeling first iteration will " + "break dependency"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::PEEL; + distance_entry->peel_first = true; + return false; + } + } else { + PrintDebug( + "WeakZeroDestinationSIVTest was unable to build first_trip_SENode"); + } + + // We get the LastTripValue as GetFinalTripInductionNode(coefficient) + + // GetConstantTerm(source) + SENode* final_trip_SENode = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + GetFinalTripInductionNode(subscript_loop, coefficient), + GetConstantTerm(subscript_loop, source))); + + // If destination == LastTripValue, peel_last. 
+ if (final_trip_SENode) { + PrintDebug("WeakZeroDestinationSIVTest built final_trip_SENode."); + if (final_trip_SENode->AsSEConstantNode()) { + PrintDebug( + "WeakZeroDestinationSIVTest has found final_trip_SENode as an " + "SEConstantNode with value: " + + ToString(final_trip_SENode->AsSEConstantNode()->FoldToSingleValue()) + + "\n"); + } + if (destination == final_trip_SENode) { + // We have found that peeling the last iteration will break dependency. + PrintDebug( + "WeakZeroDestinationSIVTest has found peeling final iteration will " + "break dependency"); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::PEEL; + distance_entry->peel_last = true; + return false; + } + } else { + PrintDebug( + "WeakZeroDestinationSIVTest was unable to build final_trip_SENode"); + } + + // We were unable to prove independence or discern any additional information. + // Must assume <=> direction. + PrintDebug( + "WeakZeroDestinationSIVTest was unable to determine any dependence " + "information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +bool LoopDependenceAnalysis::WeakCrossingSIVTest( + SENode* source, SENode* destination, SENode* coefficient, + DistanceEntry* distance_entry) { + PrintDebug("Performing WeakCrossingSIVTest."); + // We currently can't handle symbolic WeakCrossingSIVTests. If either source + // or destination are not SERecurrentNodes we must exit. + if (!source->AsSERecurrentNode() || !destination->AsSERecurrentNode()) { + PrintDebug( + "WeakCrossingSIVTest found source or destination != SERecurrentNode. " + "Exiting"); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; + } + + // Build an SENode for distance. 
+ SENode* offset_delta = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateSubtraction( + destination->AsSERecurrentNode()->GetOffset(), + source->AsSERecurrentNode()->GetOffset())); + + // Scalar evolution doesn't perform division, so we must fold to constants and + // do it manually. + int64_t distance = 0; + SEConstantNode* delta_constant = offset_delta->AsSEConstantNode(); + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (delta_constant && coefficient_constant) { + PrintDebug( + "WeakCrossingSIVTest folding offset_delta and coefficient to " + "constants."); + int64_t delta_value = delta_constant->FoldToSingleValue(); + int64_t coefficient_value = coefficient_constant->FoldToSingleValue(); + // Check if the distance is not integral or if it has a non-integral part + // equal to 1/2. + if (delta_value % (2 * coefficient_value) != 0 && + static_cast(delta_value % (2 * coefficient_value)) / + static_cast(2 * coefficient_value) != + 0.5) { + PrintDebug( + "WeakCrossingSIVTest proved independence through distance escaping " + "the loop bounds."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DIRECTION; + distance_entry->direction = DistanceEntry::Directions::NONE; + return true; + } else { + distance = delta_value / (2 * coefficient_value); + } + + if (distance == 0) { + PrintDebug("WeakCrossingSIVTest found EQ dependence."); + distance_entry->dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + distance_entry->direction = DistanceEntry::Directions::EQ; + distance_entry->distance = 0; + return false; + } + } else { + PrintDebug( + "WeakCrossingSIVTest was unable to fold offset_delta and coefficient " + "to constants."); + } + + // We were unable to prove independence or discern any additional information. + // Must assume <=> direction. 
+ PrintDebug( + "WeakCrossingSIVTest was unable to determine any dependence " + "information."); + distance_entry->direction = DistanceEntry::Directions::ALL; + return false; +} + +// Perform the GCD test if both, the source and the destination nodes, are in +// the form a0*i0 + a1*i1 + ... an*in + c. +bool LoopDependenceAnalysis::GCDMIVTest( + const std::pair& subscript_pair) { + auto source = std::get<0>(subscript_pair); + auto destination = std::get<1>(subscript_pair); + + // Bail out if source/destination is in an unexpected form. + if (!IsInCorrectFormForGCDTest(source) || + !IsInCorrectFormForGCDTest(destination)) { + return false; + } + + auto source_recurrences = GetAllTopLevelRecurrences(source); + auto dest_recurrences = GetAllTopLevelRecurrences(destination); + + // Bail out if all offsets and coefficients aren't constant. + if (!AreOffsetsAndCoefficientsConstant(source_recurrences) || + !AreOffsetsAndCoefficientsConstant(dest_recurrences)) { + return false; + } + + // Calculate the GCD of all coefficients. 
+ auto source_constants = GetAllTopLevelConstants(source); + int64_t source_constant = + CalculateConstantTerm(source_recurrences, source_constants); + + auto dest_constants = GetAllTopLevelConstants(destination); + int64_t destination_constant = + CalculateConstantTerm(dest_recurrences, dest_constants); + + int64_t delta = std::abs(source_constant - destination_constant); + + int64_t running_gcd = 0; + + running_gcd = CalculateGCDFromCoefficients(source_recurrences, running_gcd); + running_gcd = CalculateGCDFromCoefficients(dest_recurrences, running_gcd); + + return delta % running_gcd != 0; +} + +using PartitionedSubscripts = + std::vector>>; +PartitionedSubscripts LoopDependenceAnalysis::PartitionSubscripts( + const std::vector& source_subscripts, + const std::vector& destination_subscripts) { + PartitionedSubscripts partitions{}; + + auto num_subscripts = source_subscripts.size(); + + // Create initial partitions with one subscript pair per partition. + for (size_t i = 0; i < num_subscripts; ++i) { + partitions.push_back({{source_subscripts[i], destination_subscripts[i]}}); + } + + // Iterate over the loops to create all partitions + for (auto loop : loops_) { + int64_t k = -1; + + for (size_t j = 0; j < partitions.size(); ++j) { + auto& current_partition = partitions[j]; + + // Does |loop| appear in |current_partition| + auto it = std::find_if( + current_partition.begin(), current_partition.end(), + [loop, + this](const std::pair& elem) -> bool { + auto source_recurrences = + scalar_evolution_.AnalyzeInstruction(std::get<0>(elem)) + ->CollectRecurrentNodes(); + auto destination_recurrences = + scalar_evolution_.AnalyzeInstruction(std::get<1>(elem)) + ->CollectRecurrentNodes(); + + source_recurrences.insert(source_recurrences.end(), + destination_recurrences.begin(), + destination_recurrences.end()); + + auto loops_in_pair = CollectLoops(source_recurrences); + auto end_it = loops_in_pair.end(); + + return std::find(loops_in_pair.begin(), end_it, loop) != 
end_it; + }); + + auto has_loop = it != current_partition.end(); + + if (has_loop) { + if (k == -1) { + k = j; + } else { + // Add |partitions[j]| to |partitions[k]| and discard |partitions[j]| + partitions[static_cast(k)].insert(current_partition.begin(), + current_partition.end()); + current_partition.clear(); + } + } + } + } + + // Remove empty (discarded) partitions + partitions.erase( + std::remove_if( + partitions.begin(), partitions.end(), + [](const std::set>& partition) { + return partition.empty(); + }), + partitions.end()); + + return partitions; +} + +Constraint* LoopDependenceAnalysis::IntersectConstraints( + Constraint* constraint_0, Constraint* constraint_1, + const SENode* lower_bound, const SENode* upper_bound) { + if (constraint_0->AsDependenceNone()) { + return constraint_1; + } else if (constraint_1->AsDependenceNone()) { + return constraint_0; + } + + // Both constraints are distances. Either the same distance or independent. + if (constraint_0->AsDependenceDistance() && + constraint_1->AsDependenceDistance()) { + auto dist_0 = constraint_0->AsDependenceDistance(); + auto dist_1 = constraint_1->AsDependenceDistance(); + + if (*dist_0->GetDistance() == *dist_1->GetDistance()) { + return constraint_0; + } else { + return make_constraint(); + } + } + + // Both constraints are points. Either the same point or independent. + if (constraint_0->AsDependencePoint() && constraint_1->AsDependencePoint()) { + auto point_0 = constraint_0->AsDependencePoint(); + auto point_1 = constraint_1->AsDependencePoint(); + + if (*point_0->GetSource() == *point_1->GetSource() && + *point_0->GetDestination() == *point_1->GetDestination()) { + return constraint_0; + } else { + return make_constraint(); + } + } + + // Both constraints are lines/distances. 
+ if ((constraint_0->AsDependenceDistance() || + constraint_0->AsDependenceLine()) && + (constraint_1->AsDependenceDistance() || + constraint_1->AsDependenceLine())) { + auto is_distance_0 = constraint_0->AsDependenceDistance() != nullptr; + auto is_distance_1 = constraint_1->AsDependenceDistance() != nullptr; + + auto a0 = is_distance_0 ? scalar_evolution_.CreateConstant(1) + : constraint_0->AsDependenceLine()->GetA(); + auto b0 = is_distance_0 ? scalar_evolution_.CreateConstant(-1) + : constraint_0->AsDependenceLine()->GetB(); + auto c0 = + is_distance_0 + ? scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateNegation( + constraint_0->AsDependenceDistance()->GetDistance())) + : constraint_0->AsDependenceLine()->GetC(); + + auto a1 = is_distance_1 ? scalar_evolution_.CreateConstant(1) + : constraint_1->AsDependenceLine()->GetA(); + auto b1 = is_distance_1 ? scalar_evolution_.CreateConstant(-1) + : constraint_1->AsDependenceLine()->GetB(); + auto c1 = + is_distance_1 + ? scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateNegation( + constraint_1->AsDependenceDistance()->GetDistance())) + : constraint_1->AsDependenceLine()->GetC(); + + if (a0->AsSEConstantNode() && b0->AsSEConstantNode() && + c0->AsSEConstantNode() && a1->AsSEConstantNode() && + b1->AsSEConstantNode() && c1->AsSEConstantNode()) { + auto constant_a0 = a0->AsSEConstantNode()->FoldToSingleValue(); + auto constant_b0 = b0->AsSEConstantNode()->FoldToSingleValue(); + auto constant_c0 = c0->AsSEConstantNode()->FoldToSingleValue(); + + auto constant_a1 = a1->AsSEConstantNode()->FoldToSingleValue(); + auto constant_b1 = b1->AsSEConstantNode()->FoldToSingleValue(); + auto constant_c1 = c1->AsSEConstantNode()->FoldToSingleValue(); + + // a & b can't both be zero, otherwise it wouldn't be line. + if (NormalizeAndCompareFractions(constant_a0, constant_b0, constant_a1, + constant_b1)) { + // Slopes are equal, either parallel lines or the same line. 
+ + if (constant_b0 == 0 && constant_b1 == 0) { + if (NormalizeAndCompareFractions(constant_c0, constant_a0, + constant_c1, constant_a1)) { + return constraint_0; + } + + return make_constraint(); + } else if (NormalizeAndCompareFractions(constant_c0, constant_b0, + constant_c1, constant_b1)) { + // Same line. + return constraint_0; + } else { + // Parallel lines can't intersect, report independence. + return make_constraint(); + } + + } else { + // Lines are not parallel, therefore, they must intersect. + + // Calculate intersection. + if (upper_bound->AsSEConstantNode() && + lower_bound->AsSEConstantNode()) { + auto constant_lower_bound = + lower_bound->AsSEConstantNode()->FoldToSingleValue(); + auto constant_upper_bound = + upper_bound->AsSEConstantNode()->FoldToSingleValue(); + + auto up = constant_b1 * constant_c0 - constant_b0 * constant_c1; + // Both b or both a can't be 0, so down is never 0 + // otherwise would have entered the parallel line section. + auto down = constant_b1 * constant_a0 - constant_b0 * constant_a1; + + auto x_coord = up / down; + + int64_t y_coord = 0; + int64_t arg1 = 0; + int64_t const_b_to_use = 0; + + if (constant_b1 != 0) { + arg1 = constant_c1 - constant_a1 * x_coord; + y_coord = arg1 / constant_b1; + const_b_to_use = constant_b1; + } else if (constant_b0 != 0) { + arg1 = constant_c0 - constant_a0 * x_coord; + y_coord = arg1 / constant_b0; + const_b_to_use = constant_b0; + } + + if (up % down == 0 && + arg1 % const_b_to_use == 0 && // Coordinates are integers. + constant_lower_bound <= + x_coord && // x_coord is within loop bounds. + x_coord <= constant_upper_bound && + constant_lower_bound <= + y_coord && // y_coord is within loop bounds. + y_coord <= constant_upper_bound) { + // Lines intersect at integer coordinates. 
+ return make_constraint( + scalar_evolution_.CreateConstant(x_coord), + scalar_evolution_.CreateConstant(y_coord), + constraint_0->GetLoop()); + + } else { + return make_constraint(); + } + + } else { + // Not constants, bail out. + return make_constraint(); + } + } + + } else { + // Not constants, bail out. + return make_constraint(); + } + } + + // One constraint is a line/distance and the other is a point. + if ((constraint_0->AsDependencePoint() && + (constraint_1->AsDependenceLine() || + constraint_1->AsDependenceDistance())) || + (constraint_1->AsDependencePoint() && + (constraint_0->AsDependenceLine() || + constraint_0->AsDependenceDistance()))) { + auto point_0 = constraint_0->AsDependencePoint() != nullptr; + + auto point = point_0 ? constraint_0->AsDependencePoint() + : constraint_1->AsDependencePoint(); + + auto line_or_distance = point_0 ? constraint_1 : constraint_0; + + auto is_distance = line_or_distance->AsDependenceDistance() != nullptr; + + auto a = is_distance ? scalar_evolution_.CreateConstant(1) + : line_or_distance->AsDependenceLine()->GetA(); + auto b = is_distance ? scalar_evolution_.CreateConstant(-1) + : line_or_distance->AsDependenceLine()->GetB(); + auto c = + is_distance + ? 
scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateNegation( + line_or_distance->AsDependenceDistance()->GetDistance())) + : line_or_distance->AsDependenceLine()->GetC(); + + auto x = point->GetSource(); + auto y = point->GetDestination(); + + if (a->AsSEConstantNode() && b->AsSEConstantNode() && + c->AsSEConstantNode() && x->AsSEConstantNode() && + y->AsSEConstantNode()) { + auto constant_a = a->AsSEConstantNode()->FoldToSingleValue(); + auto constant_b = b->AsSEConstantNode()->FoldToSingleValue(); + auto constant_c = c->AsSEConstantNode()->FoldToSingleValue(); + + auto constant_x = x->AsSEConstantNode()->FoldToSingleValue(); + auto constant_y = y->AsSEConstantNode()->FoldToSingleValue(); + + auto left_hand_side = constant_a * constant_x + constant_b * constant_y; + + if (left_hand_side == constant_c) { + // Point is on line, return point + return point_0 ? constraint_0 : constraint_1; + } else { + // Point not on line, report independence (empty constraint). + return make_constraint(); + } + + } else { + // Not constants, bail out. + return make_constraint(); + } + } + + return nullptr; +} + +// Propagate constraints function as described in section 5 of Practical +// Dependence Testing, Goff, Kennedy, Tseng, 1991. +SubscriptPair LoopDependenceAnalysis::PropagateConstraints( + const SubscriptPair& subscript_pair, + const std::vector& constraints) { + SENode* new_first = subscript_pair.first; + SENode* new_second = subscript_pair.second; + + for (auto& constraint : constraints) { + // In the paper this is a[k]. We're extracting the coefficient ('a') of a + // recurrent expression with respect to the loop 'k'. + SENode* coefficient_of_recurrent = + scalar_evolution_.GetCoefficientFromRecurrentTerm( + new_first, constraint->GetLoop()); + + // In the paper this is a'[k]. 
+ SENode* coefficient_of_recurrent_prime = + scalar_evolution_.GetCoefficientFromRecurrentTerm( + new_second, constraint->GetLoop()); + + if (constraint->GetType() == Constraint::Distance) { + DependenceDistance* as_distance = constraint->AsDependenceDistance(); + + // In the paper this is a[k]*d + SENode* rhs = scalar_evolution_.CreateMultiplyNode( + coefficient_of_recurrent, as_distance->GetDistance()); + + // In the paper this is a[k] <- 0 + SENode* zeroed_coefficient = + scalar_evolution_.BuildGraphWithoutRecurrentTerm( + new_first, constraint->GetLoop()); + + // In the paper this is e <- e - a[k]*d. + new_first = scalar_evolution_.CreateSubtraction(zeroed_coefficient, rhs); + new_first = scalar_evolution_.SimplifyExpression(new_first); + + // In the paper this is a'[k] - a[k]. + SENode* new_child = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(coefficient_of_recurrent_prime, + coefficient_of_recurrent)); + + // In the paper this is a'[k]'i[k]. + SERecurrentNode* prime_recurrent = + scalar_evolution_.GetRecurrentTerm(new_second, constraint->GetLoop()); + + if (!prime_recurrent) continue; + + // As we hash the nodes we need to create a new node when we update a + // child. + SENode* new_recurrent = scalar_evolution_.CreateRecurrentExpression( + constraint->GetLoop(), prime_recurrent->GetOffset(), new_child); + // In the paper this is a'[k] <- a'[k] - a[k]. 
+ new_second = scalar_evolution_.UpdateChildNode( + new_second, prime_recurrent, new_recurrent); + } + } + + new_second = scalar_evolution_.SimplifyExpression(new_second); + return std::make_pair(new_first, new_second); +} + +bool LoopDependenceAnalysis::DeltaTest( + const std::vector& coupled_subscripts, + DistanceVector* dv_entry) { + std::vector constraints(loops_.size()); + + std::vector loop_appeared(loops_.size()); + + std::generate(std::begin(constraints), std::end(constraints), + [this]() { return make_constraint(); }); + + // Separate SIV and MIV subscripts + std::vector siv_subscripts{}; + std::vector miv_subscripts{}; + + for (const auto& subscript_pair : coupled_subscripts) { + if (IsSIV(subscript_pair)) { + siv_subscripts.push_back(subscript_pair); + } else { + miv_subscripts.push_back(subscript_pair); + } + } + + // Delta Test + while (!siv_subscripts.empty()) { + std::vector results(siv_subscripts.size()); + + std::vector current_distances( + siv_subscripts.size(), DistanceVector(loops_.size())); + + // Apply SIV test to all SIV subscripts, report independence if any of them + // is independent + std::transform( + std::begin(siv_subscripts), std::end(siv_subscripts), + std::begin(current_distances), std::begin(results), + [this](SubscriptPair& p, DistanceVector& d) { return SIVTest(p, &d); }); + + if (std::accumulate(std::begin(results), std::end(results), false, + std::logical_or{})) { + return true; + } + + // Derive new constraint vector. + std::vector> all_new_constrants{}; + + for (size_t i = 0; i < siv_subscripts.size(); ++i) { + auto loop = GetLoopForSubscriptPair(siv_subscripts[i]); + + auto loop_id = + std::distance(std::begin(loops_), + std::find(std::begin(loops_), std::end(loops_), loop)); + + loop_appeared[loop_id] = true; + auto distance_entry = current_distances[i].GetEntries()[loop_id]; + + if (distance_entry.dependence_information == + DistanceEntry::DependenceInformation::DISTANCE) { + // Construct a DependenceDistance. 
+ auto node = scalar_evolution_.CreateConstant(distance_entry.distance); + + all_new_constrants.push_back( + {make_constraint(node, loop), loop_id}); + } else { + // Construct a DependenceLine. + const auto& subscript_pair = siv_subscripts[i]; + SENode* source_node = std::get<0>(subscript_pair); + SENode* destination_node = std::get<1>(subscript_pair); + + int64_t source_induction_count = CountInductionVariables(source_node); + int64_t destination_induction_count = + CountInductionVariables(destination_node); + + SENode* a = nullptr; + SENode* b = nullptr; + SENode* c = nullptr; + + if (destination_induction_count != 0) { + a = destination_node->AsSERecurrentNode()->GetCoefficient(); + c = scalar_evolution_.CreateNegation( + destination_node->AsSERecurrentNode()->GetOffset()); + } else { + a = scalar_evolution_.CreateConstant(0); + c = scalar_evolution_.CreateNegation(destination_node); + } + + if (source_induction_count != 0) { + b = scalar_evolution_.CreateNegation( + source_node->AsSERecurrentNode()->GetCoefficient()); + c = scalar_evolution_.CreateAddNode( + c, source_node->AsSERecurrentNode()->GetOffset()); + } else { + b = scalar_evolution_.CreateConstant(0); + c = scalar_evolution_.CreateAddNode(c, source_node); + } + + a = scalar_evolution_.SimplifyExpression(a); + b = scalar_evolution_.SimplifyExpression(b); + c = scalar_evolution_.SimplifyExpression(c); + + all_new_constrants.push_back( + {make_constraint(a, b, c, loop), loop_id}); + } + } + + // Calculate the intersection between the new and existing constraints. + std::vector intersection = constraints; + for (const auto& constraint_to_intersect : all_new_constrants) { + auto loop_id = std::get<1>(constraint_to_intersect); + auto loop = loops_[loop_id]; + intersection[loop_id] = IntersectConstraints( + intersection[loop_id], std::get<0>(constraint_to_intersect), + GetLowerBound(loop), GetUpperBound(loop)); + } + + // Report independence if an empty constraint (DependenceEmpty) is found. 
+ auto first_empty = + std::find_if(std::begin(intersection), std::end(intersection), + [](Constraint* constraint) { + return constraint->AsDependenceEmpty() != nullptr; + }); + if (first_empty != std::end(intersection)) { + return true; + } + std::vector new_siv_subscripts{}; + std::vector new_miv_subscripts{}; + + auto equal = + std::equal(std::begin(constraints), std::end(constraints), + std::begin(intersection), + [](Constraint* a, Constraint* b) { return *a == *b; }); + + // If any constraints have changed, propagate them into the rest of the + // subscripts possibly creating new ZIV/SIV subscripts. + if (!equal) { + std::vector new_subscripts(miv_subscripts.size()); + + // Propagate constraints into MIV subscripts + std::transform(std::begin(miv_subscripts), std::end(miv_subscripts), + std::begin(new_subscripts), + [this, &intersection](SubscriptPair& subscript_pair) { + return PropagateConstraints(subscript_pair, + intersection); + }); + + // If a ZIV subscript is returned, apply test, otherwise, update untested + // subscripts. + for (auto& subscript : new_subscripts) { + if (IsZIV(subscript) && ZIVTest(subscript)) { + return true; + } else if (IsSIV(subscript)) { + new_siv_subscripts.push_back(subscript); + } else { + new_miv_subscripts.push_back(subscript); + } + } + } + + // Set new constraints and subscripts to test. + std::swap(siv_subscripts, new_siv_subscripts); + std::swap(miv_subscripts, new_miv_subscripts); + std::swap(constraints, intersection); + } + + // Create the dependence vector from the constraints. + for (size_t i = 0; i < loops_.size(); ++i) { + // Don't touch entries for loops that weren't tested. 
+ if (loop_appeared[i]) { + auto current_constraint = constraints[i]; + auto& current_distance_entry = (*dv_entry).GetEntries()[i]; + + if (auto dependence_distance = + current_constraint->AsDependenceDistance()) { + if (auto constant_node = + dependence_distance->GetDistance()->AsSEConstantNode()) { + current_distance_entry.dependence_information = + DistanceEntry::DependenceInformation::DISTANCE; + + current_distance_entry.distance = constant_node->FoldToSingleValue(); + if (current_distance_entry.distance == 0) { + current_distance_entry.direction = DistanceEntry::Directions::EQ; + } else if (current_distance_entry.distance < 0) { + current_distance_entry.direction = DistanceEntry::Directions::GT; + } else { + current_distance_entry.direction = DistanceEntry::Directions::LT; + } + } + } else if (auto dependence_point = + current_constraint->AsDependencePoint()) { + auto source = dependence_point->GetSource(); + auto destination = dependence_point->GetDestination(); + + if (source->AsSEConstantNode() && destination->AsSEConstantNode()) { + current_distance_entry = DistanceEntry( + source->AsSEConstantNode()->FoldToSingleValue(), + destination->AsSEConstantNode()->FoldToSingleValue()); + } + } + } + } + + // Test any remaining MIV subscripts and report independence if found. + std::vector results(miv_subscripts.size()); + + std::transform(std::begin(miv_subscripts), std::end(miv_subscripts), + std::begin(results), + [this](const SubscriptPair& p) { return GCDMIVTest(p); }); + + return std::accumulate(std::begin(results), std::end(results), false, + std::logical_or{}); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_dependence.h b/source/opt/loop_dependence.h new file mode 100644 index 000000000..582c8d0ac --- /dev/null +++ b/source/opt/loop_dependence.h @@ -0,0 +1,558 @@ +// Copyright (c) 2018 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_DEPENDENCE_H_ +#define SOURCE_OPT_LOOP_DEPENDENCE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/scalar_analysis.h" + +namespace spvtools { +namespace opt { + +// Stores information about dependence between a load and a store wrt a single +// loop in a loop nest. +// DependenceInformation +// * UNKNOWN if no dependence information can be gathered or is gathered +// for it. +// * DIRECTION if a dependence direction could be found, but not a +// distance. +// * DISTANCE if a dependence distance could be found. +// * PEEL if peeling either the first or last iteration will break +// dependence between the given load and store. +// * IRRELEVANT if it has no effect on the dependence between the given +// load and store. +// +// If peel_first == true, the analysis has found that peeling the first +// iteration of this loop will break dependence. +// +// If peel_last == true, the analysis has found that peeling the last iteration +// of this loop will break dependence. 
class DistanceEntry {
 public:
  // What kind of information this entry carries.
  enum DependenceInformation {
    UNKNOWN = 0,
    DIRECTION = 1,
    DISTANCE = 2,
    PEEL = 3,
    IRRELEVANT = 4,
    POINT = 5
  };

  // Dependence direction. Numerically LT, EQ and GT are the single bits 1, 2
  // and 4, and LE, NE, GE and ALL are the corresponding bitwise unions.
  enum Directions {
    NONE = 0,
    LT = 1,
    EQ = 2,
    LE = 3,
    GT = 4,
    NE = 5,
    GE = 6,
    ALL = 7
  };

  // Every field has an in-class default; each constructor below overrides only
  // the fields that differ from these defaults.
  DependenceInformation dependence_information = DependenceInformation::UNKNOWN;
  Directions direction = Directions::ALL;
  int64_t distance = 0;
  bool peel_first = false;
  bool peel_last = false;
  int64_t point_x = 0;
  int64_t point_y = 0;

  // An UNKNOWN entry with direction ALL.
  DistanceEntry() {}

  // A DIRECTION entry with the given direction.
  explicit DistanceEntry(Directions direction_)
      : dependence_information(DependenceInformation::DIRECTION),
        direction(direction_) {}

  // A DISTANCE entry with the given direction and distance.
  DistanceEntry(Directions direction_, int64_t distance_)
      : dependence_information(DependenceInformation::DISTANCE),
        direction(direction_),
        distance(distance_) {}

  // A POINT entry at (x, y); the direction stays ALL.
  DistanceEntry(int64_t x, int64_t y)
      : dependence_information(DependenceInformation::POINT),
        point_x(x),
        point_y(y) {}

  // Field-wise comparison. Note that |dependence_information| does not take
  // part in the comparison.
  bool operator==(const DistanceEntry& rhs) const {
    if (direction != rhs.direction) {
      return false;
    }
    if (peel_first != rhs.peel_first || peel_last != rhs.peel_last) {
      return false;
    }
    if (distance != rhs.distance) {
      return false;
    }
    return point_x == rhs.point_x && point_y == rhs.point_y;
  }

  bool operator!=(const DistanceEntry& rhs) const { return !operator==(rhs); }
};

// Stores a vector of DistanceEntrys, one per loop in the analysis.
// A DistanceVector holds all of the information gathered in a dependence
// analysis wrt the loops stored in the LoopDependenceAnalysis performing the
// analysis.
+class DistanceVector { + public: + explicit DistanceVector(size_t size) : entries(size, DistanceEntry{}) {} + + explicit DistanceVector(std::vector entries_) + : entries(entries_) {} + + DistanceEntry& GetEntry(size_t index) { return entries[index]; } + const DistanceEntry& GetEntry(size_t index) const { return entries[index]; } + + std::vector& GetEntries() { return entries; } + const std::vector& GetEntries() const { return entries; } + + bool operator==(const DistanceVector& rhs) const { + if (entries.size() != rhs.entries.size()) { + return false; + } + for (size_t i = 0; i < entries.size(); ++i) { + if (entries[i] != rhs.entries[i]) { + return false; + } + } + return true; + } + bool operator!=(const DistanceVector& rhs) const { return !(*this == rhs); } + + private: + std::vector entries; +}; + +class DependenceLine; +class DependenceDistance; +class DependencePoint; +class DependenceNone; +class DependenceEmpty; + +class Constraint { + public: + explicit Constraint(const Loop* loop) : loop_(loop) {} + enum ConstraintType { Line, Distance, Point, None, Empty }; + + virtual ConstraintType GetType() const = 0; + + virtual ~Constraint() {} + + // Get the loop this constraint belongs to. 
+ const Loop* GetLoop() const { return loop_; } + + bool operator==(const Constraint& other) const; + + bool operator!=(const Constraint& other) const; + +#define DeclareCastMethod(target) \ + virtual target* As##target() { return nullptr; } \ + virtual const target* As##target() const { return nullptr; } + DeclareCastMethod(DependenceLine); + DeclareCastMethod(DependenceDistance); + DeclareCastMethod(DependencePoint); + DeclareCastMethod(DependenceNone); + DeclareCastMethod(DependenceEmpty); +#undef DeclareCastMethod + + protected: + const Loop* loop_; +}; + +class DependenceLine : public Constraint { + public: + DependenceLine(SENode* a, SENode* b, SENode* c, const Loop* loop) + : Constraint(loop), a_(a), b_(b), c_(c) {} + + ConstraintType GetType() const final { return Line; } + + DependenceLine* AsDependenceLine() final { return this; } + const DependenceLine* AsDependenceLine() const final { return this; } + + SENode* GetA() const { return a_; } + SENode* GetB() const { return b_; } + SENode* GetC() const { return c_; } + + private: + SENode* a_; + SENode* b_; + SENode* c_; +}; + +class DependenceDistance : public Constraint { + public: + DependenceDistance(SENode* distance, const Loop* loop) + : Constraint(loop), distance_(distance) {} + + ConstraintType GetType() const final { return Distance; } + + DependenceDistance* AsDependenceDistance() final { return this; } + const DependenceDistance* AsDependenceDistance() const final { return this; } + + SENode* GetDistance() const { return distance_; } + + private: + SENode* distance_; +}; + +class DependencePoint : public Constraint { + public: + DependencePoint(SENode* source, SENode* destination, const Loop* loop) + : Constraint(loop), source_(source), destination_(destination) {} + + ConstraintType GetType() const final { return Point; } + + DependencePoint* AsDependencePoint() final { return this; } + const DependencePoint* AsDependencePoint() const final { return this; } + + SENode* GetSource() const { return 
source_; } + SENode* GetDestination() const { return destination_; } + + private: + SENode* source_; + SENode* destination_; +}; + +class DependenceNone : public Constraint { + public: + DependenceNone() : Constraint(nullptr) {} + ConstraintType GetType() const final { return None; } + + DependenceNone* AsDependenceNone() final { return this; } + const DependenceNone* AsDependenceNone() const final { return this; } +}; + +class DependenceEmpty : public Constraint { + public: + DependenceEmpty() : Constraint(nullptr) {} + ConstraintType GetType() const final { return Empty; } + + DependenceEmpty* AsDependenceEmpty() final { return this; } + const DependenceEmpty* AsDependenceEmpty() const final { return this; } +}; + +// Provides dependence information between a store instruction and a load +// instruction inside the same loop in a loop nest. +// +// The analysis can only check dependence between stores and loads with regard +// to the loop nest it is created with. +// +// The analysis can output debugging information to a stream. The output +// describes the control flow of the analysis and what information it can deduce +// at each step. +// SetDebugStream and ClearDebugStream are provided for this functionality. +// +// The dependency algorithm is based on the 1990 Paper +// Practical Dependence Testing +// Gina Goff, Ken Kennedy, Chau-Wen Tseng +// +// The algorithm first identifies subscript pairs between the load and store. +// Each pair is tested until all have been tested or independence is found. +// The number of induction variables in a pair determines which test to perform +// on it; +// Zero Index Variable (ZIV) is used when no induction variables are present +// in the pair. +// Single Index Variable (SIV) is used when only one induction variable is +// present, but may occur multiple times in the pair. +// Multiple Index Variable (MIV) is used when more than one induction variable +// is present in the pair. 
+class LoopDependenceAnalysis { + public: + LoopDependenceAnalysis(IRContext* context, std::vector loops) + : context_(context), + loops_(loops), + scalar_evolution_(context), + debug_stream_(nullptr), + constraints_{} {} + + // Finds the dependence between |source| and |destination|. + // |source| should be an OpLoad. + // |destination| should be an OpStore. + // Any direction and distance information found will be stored in + // |distance_vector|. + // Returns true if independence is found, false otherwise. + bool GetDependence(const Instruction* source, const Instruction* destination, + DistanceVector* distance_vector); + + // Returns true if |subscript_pair| represents a Zero Index Variable pair + // (ZIV) + bool IsZIV(const std::pair& subscript_pair); + + // Returns true if |subscript_pair| represents a Single Index Variable + // (SIV) pair + bool IsSIV(const std::pair& subscript_pair); + + // Returns true if |subscript_pair| represents a Multiple Index Variable + // (MIV) pair + bool IsMIV(const std::pair& subscript_pair); + + // Finds the lower bound of |loop| as an SENode* and returns the result. + // The lower bound is the starting value of the loop's induction variable. + // If the bound can not be determined a nullptr is returned. + SENode* GetLowerBound(const Loop* loop); + + // Finds the upper bound of |loop| as an SENode* and returns the result. + // The upper bound is the last value before the loop exit condition is met. + // If the bound can not be determined a nullptr is returned. + SENode* GetUpperBound(const Loop* loop); + + // Returns true if |value| is between |bound_one| and |bound_two| (inclusive). + // The bounds may be given in either order. + bool IsWithinBounds(int64_t value, int64_t bound_one, int64_t bound_two); + + // Finds the number of iterations of |loop| and returns it as a constant + // SENode. + // If the operations can not be completed a nullptr is returned. + SENode* GetTripCount(const Loop* loop); + + // Returns the SENode* produced by building an SENode from the result of + // calling GetInductionInitValue on |loop|. + // If the operation can not be completed a nullptr is returned.
+ SENode* GetFirstTripInductionNode(const Loop* loop); + + // Returns the SENode* produced by building an SENode from the result of + // GetFirstTripInductionNode + (GetTripCount - 1) * induction_coefficient. + // If the operation can not be completed a nullptr is returned. + SENode* GetFinalTripInductionNode(const Loop* loop, + SENode* induction_coefficient); + + // Returns all the distinct loops that appear in |nodes|. + std::set CollectLoops( + const std::vector& nodes); + + // Returns all the distinct loops that appear in |source| and |destination|. + std::set CollectLoops(SENode* source, SENode* destination); + + // Returns true if |distance| is provably outside the loop bounds. + // |coefficient| must be an SENode representing the coefficient of the + // induction variable of |loop|. + // This method is able to handle some symbolic cases which IsWithinBounds + // can't handle. + bool IsProvablyOutsideOfLoopBounds(const Loop* loop, SENode* distance, + SENode* coefficient); + + // Sets the ostream for debug information for the analysis. + void SetDebugStream(std::ostream& debug_stream) { + debug_stream_ = &debug_stream; + } + + // Clears the stored ostream to stop debug information printing. + void ClearDebugStream() { debug_stream_ = nullptr; } + + // Returns the ScalarEvolutionAnalysis used by this analysis. + ScalarEvolutionAnalysis* GetScalarEvolution() { return &scalar_evolution_; } + + // Creates a new constraint of type |T| and returns the pointer to it. + // The constraint is owned by this analysis (it is stored in |constraints_|). + template + Constraint* make_constraint(Args&&... args) { + constraints_.push_back( + std::unique_ptr(new T(std::forward(args)...))); + + return constraints_.back().get(); + } + + // Subscript partitioning as described in Figure 1 of 'Practical Dependence + // Testing' by Gina Goff, Ken Kennedy, and Chau-Wen Tseng from PLDI '91. + // Partitions the subscripts into independent subscripts and minimally coupled + // sets of subscripts. + // Returns the partitioning of subscript pairs.
Sets of size 1 indicates an + // independent subscript-pair and others indicate coupled sets. + using PartitionedSubscripts = + std::vector>>; + PartitionedSubscripts PartitionSubscripts( + const std::vector& source_subscripts, + const std::vector& destination_subscripts); + + // Returns the Loop* matching the loop for |subscript_pair|. + // |subscript_pair| must be an SIV pair. + // Returns nullptr if the pair does not reference exactly one loop. + const Loop* GetLoopForSubscriptPair( + const std::pair& subscript_pair); + + // Returns the DistanceEntry matching the loop for |subscript_pair|. + // |subscript_pair| must be an SIV pair. + // Returns nullptr if no matching entry is found. + DistanceEntry* GetDistanceEntryForSubscriptPair( + const std::pair& subscript_pair, + DistanceVector* distance_vector); + + // Returns the DistanceEntry matching |loop|. + // Returns nullptr if |loop| is nullptr or is not one of |loops_|. + DistanceEntry* GetDistanceEntryForLoop(const Loop* loop, + DistanceVector* distance_vector); + + // Returns a vector of Instruction* which form the subscripts of the array + // access defined by the access chain |instruction|. + std::vector GetSubscripts(const Instruction* instruction); + + // Delta test as described in Figure 3 of 'Practical Dependence + // Testing' by Gina Goff, Ken Kennedy, and Chau-Wen Tseng from PLDI '91. + bool DeltaTest( + const std::vector>& coupled_subscripts, + DistanceVector* dv_entry); + + // Constraint propagation as described in Figure 5 of 'Practical Dependence + // Testing' by Gina Goff, Ken Kennedy, and Chau-Wen Tseng from PLDI '91. + std::pair PropagateConstraints( + const std::pair& subscript_pair, + const std::vector& constraints); + + // Constraint intersection as described in Figure 4 of 'Practical Dependence + // Testing' by Gina Goff, Ken Kennedy, and Chau-Wen Tseng from PLDI '91. + Constraint* IntersectConstraints(Constraint* constraint_0, + Constraint* constraint_1, + const SENode* lower_bound, + const SENode* upper_bound); + + // Returns true if each loop in |loops| is in a form supported by this + // analysis.
+ // A loop is supported if it has a single induction variable and that + // induction variable has a step of +1 or -1 per loop iteration. + bool CheckSupportedLoops(std::vector loops); + + // Returns true if |loop| is in a form supported by this analysis. + // A loop is supported if it has a single induction variable and that + // induction variable has a step of +1 or -1 per loop iteration. + bool IsSupportedLoop(const Loop* loop); + + private: + IRContext* context_; + + // The loop nest we are analysing the dependence of. + std::vector loops_; + + // The ScalarEvolutionAnalysis used by this analysis to store and perform much + // of its logic. + ScalarEvolutionAnalysis scalar_evolution_; + + // The ostream debug information for the analysis to print to. + std::ostream* debug_stream_; + + // Stores all the constraints created by the analysis. + std::list> constraints_; + + // Returns true if independence can be proven and false if it can't be proven. + bool ZIVTest(const std::pair& subscript_pair); + + // Analyzes the subscript pair to find an applicable SIV test. + // Returns true if independence can be proven and false if it can't be proven. + bool SIVTest(const std::pair& subscript_pair, + DistanceVector* distance_vector); + + // Takes the form a*i + c1, a*i + c2 + // When c1 and c2 are loop invariant and a is constant + // distance = (c1 - c2)/a + // < if distance > 0 + // direction = = if distance = 0 + // > if distance < 0 + // Returns true if independence is proven and false if it can't be proven. + bool StrongSIVTest(SENode* source, SENode* destination, SENode* coeff, + DistanceEntry* distance_entry); + + // Takes the form a*i + c1, a*i + c2 + // where c1 and c2 are loop invariant and a is constant. + // c1 and/or c2 contain one or more SEValueUnknown nodes.
+ bool SymbolicStrongSIVTest(SENode* source, SENode* destination, + SENode* coefficient, + DistanceEntry* distance_entry); + + // Takes the form a1*i + c1, a2*i + c2 + // where a1 = 0 + // distance = (c1 - c2) / a2 + // Returns true if independence is proven and false if it can't be proven. + bool WeakZeroSourceSIVTest(SENode* source, SERecurrentNode* destination, + SENode* coefficient, + DistanceEntry* distance_entry); + + // Takes the form a1*i + c1, a2*i + c2 + // where a2 = 0 + // distance = (c2 - c1) / a1 + // Returns true if independence is proven and false if it can't be proven. + bool WeakZeroDestinationSIVTest(SERecurrentNode* source, SENode* destination, + SENode* coefficient, + DistanceEntry* distance_entry); + + // Takes the form a1*i + c1, a2*i + c2 + // where a1 = -a2 + // distance = (c2 - c1) / 2*a1 + // Returns true if independence is proven and false if it can't be proven. + bool WeakCrossingSIVTest(SENode* source, SENode* destination, + SENode* coefficient, DistanceEntry* distance_entry); + + // Uses the def_use_mgr to get the instruction referenced by + // SingleWordInOperand(|id|) when called on |instruction|. + Instruction* GetOperandDefinition(const Instruction* instruction, int id); + + // Perform the GCD test if both, the source and the destination nodes, are in + // the form a0*i0 + a1*i1 + ... an*in + c. + bool GCDMIVTest(const std::pair& subscript_pair); + + // Finds the number of induction variables in |node|. + // Returns -1 on failure. + int64_t CountInductionVariables(SENode* node); + + // Finds the number of induction variables shared between |source| and + // |destination|. + // Returns -1 on failure. + int64_t CountInductionVariables(SENode* source, SENode* destination); + + // Takes the offset from the induction variable and subtracts the lower bound + // from it to get the constant term added to the induction. + // Returns the resulting constant term, or nullptr if it could not be produced.
+ SENode* GetConstantTerm(const Loop* loop, SERecurrentNode* induction); + + // Marks all the distance entries in |distance_vector| that relate to + // loops in |loops_| but were not used in any subscripts as irrelevant to the + // dependence test. + void MarkUnsusedDistanceEntriesAsIrrelevant(const Instruction* source, + const Instruction* destination, + DistanceVector* distance_vector); + + // Converts |value| to a std::string and returns the result. + // This is required because Android does not compile std::to_string. + template + std::string ToString(valueT value) { + std::ostringstream string_stream; + string_stream << value; + return string_stream.str(); + } + + // Prints |debug_msg| and "\n" to the ostream pointed to by |debug_stream_|. + // Won't print anything if |debug_stream_| is nullptr. + void PrintDebug(std::string debug_msg); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_DEPENDENCE_H_ diff --git a/source/opt/loop_dependence_helpers.cpp b/source/opt/loop_dependence_helpers.cpp new file mode 100644 index 000000000..de27a0a72 --- /dev/null +++ b/source/opt/loop_dependence_helpers.cpp @@ -0,0 +1,541 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "source/opt/loop_dependence.h" + +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/scalar_analysis_nodes.h" + +namespace spvtools { +namespace opt { + +bool LoopDependenceAnalysis::IsZIV( + const std::pair& subscript_pair) { + return CountInductionVariables(subscript_pair.first, subscript_pair.second) == + 0; +} + +bool LoopDependenceAnalysis::IsSIV( + const std::pair& subscript_pair) { + return CountInductionVariables(subscript_pair.first, subscript_pair.second) == + 1; +} + +bool LoopDependenceAnalysis::IsMIV( + const std::pair& subscript_pair) { + return CountInductionVariables(subscript_pair.first, subscript_pair.second) > + 1; +} + +SENode* LoopDependenceAnalysis::GetLowerBound(const Loop* loop) { + Instruction* cond_inst = loop->GetConditionInst(); + if (!cond_inst) { + return nullptr; + } + Instruction* lower_inst = GetOperandDefinition(cond_inst, 0); + switch (cond_inst->opcode()) { + case SpvOpULessThan: + case SpvOpSLessThan: + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: { + // If we have a phi we are looking at the induction variable. We look + // through the phi to the initial value of the phi upon entering the loop. + if (lower_inst->opcode() == SpvOpPhi) { + lower_inst = GetOperandDefinition(lower_inst, 0); + // We don't handle looking through multiple phis. 
+ if (lower_inst->opcode() == SpvOpPhi) { + return nullptr; + } + } + return scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(lower_inst)); + } + default: + return nullptr; + } +} + +SENode* LoopDependenceAnalysis::GetUpperBound(const Loop* loop) { + Instruction* cond_inst = loop->GetConditionInst(); + if (!cond_inst) { + return nullptr; + } + Instruction* upper_inst = GetOperandDefinition(cond_inst, 1); + switch (cond_inst->opcode()) { + case SpvOpULessThan: + case SpvOpSLessThan: { + // When we have a < condition we must subtract 1 from the analyzed upper + // instruction. + SENode* upper_bound = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction( + scalar_evolution_.AnalyzeInstruction(upper_inst), + scalar_evolution_.CreateConstant(1))); + return upper_bound; + } + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: { + // When we have a > condition we must add 1 to the analyzed upper + // instruction. + SENode* upper_bound = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + scalar_evolution_.AnalyzeInstruction(upper_inst), + scalar_evolution_.CreateConstant(1))); + return upper_bound; + } + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: { + // We don't need to modify the results of analyzing when we have <= or >=. + SENode* upper_bound = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(upper_inst)); + return upper_bound; + } + default: + return nullptr; + } +} + +bool LoopDependenceAnalysis::IsWithinBounds(int64_t value, int64_t bound_one, + int64_t bound_two) { + if (bound_one < bound_two) { + // If |bound_one| is the lower bound. + return (value >= bound_one && value <= bound_two); + } else if (bound_one > bound_two) { + // If |bound_two| is the lower bound. + return (value >= bound_two && value <= bound_one); + } else { + // Both bounds have the same value. 
+ return value == bound_one; + } +} + +bool LoopDependenceAnalysis::IsProvablyOutsideOfLoopBounds( + const Loop* loop, SENode* distance, SENode* coefficient) { + // We test to see if we can reduce the coefficient to an integral constant. + SEConstantNode* coefficient_constant = coefficient->AsSEConstantNode(); + if (!coefficient_constant) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds could not reduce coefficient to a " + "SEConstantNode so must exit."); + return false; + } + + SENode* lower_bound = GetLowerBound(loop); + SENode* upper_bound = GetUpperBound(loop); + if (!lower_bound || !upper_bound) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds could not get both the lower and upper " + "bounds so must exit."); + return false; + } + // If the coefficient is positive we calculate bounds as upper - lower + // If the coefficient is negative we calculate bounds as lower - upper + SENode* bounds = nullptr; + if (coefficient_constant->FoldToSingleValue() >= 0) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds found coefficient >= 0.\n" + "Using bounds as upper - lower."); + bounds = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(upper_bound, lower_bound)); + } else { + PrintDebug( + "IsProvablyOutsideOfLoopBounds found coefficient < 0.\n" + "Using bounds as lower - upper."); + bounds = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(lower_bound, upper_bound)); + } + + // We can attempt to deal with symbolic cases by subtracting |distance| and + // the bound nodes. If we can subtract, simplify and produce a SEConstantNode + // we can produce some information. 
+ SEConstantNode* distance_minus_bounds = + scalar_evolution_ + .SimplifyExpression( + scalar_evolution_.CreateSubtraction(distance, bounds)) + ->AsSEConstantNode(); + if (distance_minus_bounds) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds found distance - bounds as a " + "SEConstantNode with value " + + ToString(distance_minus_bounds->FoldToSingleValue())); + // If distance - bounds > 0 we prove the distance is outwith the loop + // bounds. + if (distance_minus_bounds->FoldToSingleValue() > 0) { + PrintDebug( + "IsProvablyOutsideOfLoopBounds found distance escaped the loop " + "bounds."); + return true; + } + } + + return false; +} + +const Loop* LoopDependenceAnalysis::GetLoopForSubscriptPair( + const std::pair& subscript_pair) { + // Collect all the SERecurrentNodes. + std::vector source_nodes = + std::get<0>(subscript_pair)->CollectRecurrentNodes(); + std::vector destination_nodes = + std::get<1>(subscript_pair)->CollectRecurrentNodes(); + + // Collect all the loops stored by the SERecurrentNodes. + std::unordered_set loops{}; + for (auto source_nodes_it = source_nodes.begin(); + source_nodes_it != source_nodes.end(); ++source_nodes_it) { + loops.insert((*source_nodes_it)->GetLoop()); + } + for (auto destination_nodes_it = destination_nodes.begin(); + destination_nodes_it != destination_nodes.end(); + ++destination_nodes_it) { + loops.insert((*destination_nodes_it)->GetLoop()); + } + + // If we didn't find 1 loop |subscript_pair| is a subscript over multiple or 0 + // loops. We don't handle this so return nullptr. 
+ if (loops.size() != 1) { + PrintDebug("GetLoopForSubscriptPair found loops.size() != 1."); + return nullptr; + } + return *loops.begin(); +} + +DistanceEntry* LoopDependenceAnalysis::GetDistanceEntryForLoop( + const Loop* loop, DistanceVector* distance_vector) { + if (!loop) { + return nullptr; + } + + DistanceEntry* distance_entry = nullptr; + for (size_t loop_index = 0; loop_index < loops_.size(); ++loop_index) { + if (loop == loops_[loop_index]) { + distance_entry = &(distance_vector->GetEntries()[loop_index]); + break; + } + } + + return distance_entry; +} + +DistanceEntry* LoopDependenceAnalysis::GetDistanceEntryForSubscriptPair( + const std::pair& subscript_pair, + DistanceVector* distance_vector) { + const Loop* loop = GetLoopForSubscriptPair(subscript_pair); + + return GetDistanceEntryForLoop(loop, distance_vector); +} + +SENode* LoopDependenceAnalysis::GetTripCount(const Loop* loop) { + BasicBlock* condition_block = loop->FindConditionBlock(); + if (!condition_block) { + return nullptr; + } + Instruction* induction_instr = loop->FindConditionVariable(condition_block); + if (!induction_instr) { + return nullptr; + } + Instruction* cond_instr = loop->GetConditionInst(); + if (!cond_instr) { + return nullptr; + } + + size_t iteration_count = 0; + + // We have to check the instruction type here. If the condition instruction + // isn't a supported type we can't calculate the trip count. 
+ if (loop->IsSupportedCondition(cond_instr->opcode())) { + if (loop->FindNumberOfIterations(induction_instr, &*condition_block->tail(), + &iteration_count)) { + return scalar_evolution_.CreateConstant( + static_cast(iteration_count)); + } + } + + return nullptr; +} + +SENode* LoopDependenceAnalysis::GetFirstTripInductionNode(const Loop* loop) { + BasicBlock* condition_block = loop->FindConditionBlock(); + if (!condition_block) { + return nullptr; + } + Instruction* induction_instr = loop->FindConditionVariable(condition_block); + if (!induction_instr) { + return nullptr; + } + int64_t induction_initial_value = 0; + if (!loop->GetInductionInitValue(induction_instr, &induction_initial_value)) { + return nullptr; + } + + SENode* induction_init_SENode = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateConstant(induction_initial_value)); + return induction_init_SENode; +} + +SENode* LoopDependenceAnalysis::GetFinalTripInductionNode( + const Loop* loop, SENode* induction_coefficient) { + SENode* first_trip_induction_node = GetFirstTripInductionNode(loop); + if (!first_trip_induction_node) { + return nullptr; + } + // Get trip_count as GetTripCount - 1 + // This is because the induction variable is not stepped on the first + // iteration of the loop + SENode* trip_count = + scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateSubtraction( + GetTripCount(loop), scalar_evolution_.CreateConstant(1))); + // Return first_trip_induction_node + trip_count * induction_coefficient + return scalar_evolution_.SimplifyExpression(scalar_evolution_.CreateAddNode( + first_trip_induction_node, + scalar_evolution_.CreateMultiplyNode(trip_count, induction_coefficient))); +} + +std::set LoopDependenceAnalysis::CollectLoops( + const std::vector& recurrent_nodes) { + // We don't handle loops with more than one induction variable. Therefore we + // can identify the number of induction variables by collecting all of the + // loops the collected recurrent nodes belong to. 
+ std::set loops{}; + for (auto recurrent_nodes_it = recurrent_nodes.begin(); + recurrent_nodes_it != recurrent_nodes.end(); ++recurrent_nodes_it) { + loops.insert((*recurrent_nodes_it)->GetLoop()); + } + + return loops; +} + +int64_t LoopDependenceAnalysis::CountInductionVariables(SENode* node) { + if (!node) { + return -1; + } + + std::vector recurrent_nodes = node->CollectRecurrentNodes(); + + // We don't handle loops with more than one induction variable. Therefore we + // can identify the number of induction variables by collecting all of the + // loops the collected recurrent nodes belong to. + std::set loops = CollectLoops(recurrent_nodes); + + return static_cast(loops.size()); +} + +std::set LoopDependenceAnalysis::CollectLoops( + SENode* source, SENode* destination) { + if (!source || !destination) { + return std::set{}; + } + + std::vector source_nodes = source->CollectRecurrentNodes(); + std::vector destination_nodes = + destination->CollectRecurrentNodes(); + + std::set loops = CollectLoops(source_nodes); + std::set destination_loops = CollectLoops(destination_nodes); + + loops.insert(std::begin(destination_loops), std::end(destination_loops)); + + return loops; +} + +int64_t LoopDependenceAnalysis::CountInductionVariables(SENode* source, + SENode* destination) { + if (!source || !destination) { + return -1; + } + + std::set loops = CollectLoops(source, destination); + + return static_cast(loops.size()); +} + +Instruction* LoopDependenceAnalysis::GetOperandDefinition( + const Instruction* instruction, int id) { + return context_->get_def_use_mgr()->GetDef( + instruction->GetSingleWordInOperand(id)); +} + +std::vector LoopDependenceAnalysis::GetSubscripts( + const Instruction* instruction) { + Instruction* access_chain = GetOperandDefinition(instruction, 0); + + std::vector subscripts; + + for (auto i = 1u; i < access_chain->NumInOperandWords(); ++i) { + subscripts.push_back(GetOperandDefinition(access_chain, i)); + } + + return subscripts; +} + +SENode* 
LoopDependenceAnalysis::GetConstantTerm(const Loop* loop, + SERecurrentNode* induction) { + SENode* offset = induction->GetOffset(); + SENode* lower_bound = GetLowerBound(loop); + if (!offset || !lower_bound) { + return nullptr; + } + SENode* constant_term = scalar_evolution_.SimplifyExpression( + scalar_evolution_.CreateSubtraction(offset, lower_bound)); + return constant_term; +} + +bool LoopDependenceAnalysis::CheckSupportedLoops( + std::vector loops) { + for (auto loop : loops) { + if (!IsSupportedLoop(loop)) { + return false; + } + } + return true; +} + +void LoopDependenceAnalysis::MarkUnsusedDistanceEntriesAsIrrelevant( + const Instruction* source, const Instruction* destination, + DistanceVector* distance_vector) { + std::vector source_subscripts = GetSubscripts(source); + std::vector destination_subscripts = GetSubscripts(destination); + + std::set used_loops{}; + + for (Instruction* source_inst : source_subscripts) { + SENode* source_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(source_inst)); + std::vector recurrent_nodes = + source_node->CollectRecurrentNodes(); + for (SERecurrentNode* recurrent_node : recurrent_nodes) { + used_loops.insert(recurrent_node->GetLoop()); + } + } + + for (Instruction* destination_inst : destination_subscripts) { + SENode* destination_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(destination_inst)); + std::vector recurrent_nodes = + destination_node->CollectRecurrentNodes(); + for (SERecurrentNode* recurrent_node : recurrent_nodes) { + used_loops.insert(recurrent_node->GetLoop()); + } + } + + for (size_t i = 0; i < loops_.size(); ++i) { + if (used_loops.find(loops_[i]) == used_loops.end()) { + distance_vector->GetEntries()[i].dependence_information = + DistanceEntry::DependenceInformation::IRRELEVANT; + } + } +} + +bool LoopDependenceAnalysis::IsSupportedLoop(const Loop* loop) { + std::vector inductions{}; + loop->GetInductionVariables(inductions); + 
if (inductions.size() != 1) { + return false; + } + Instruction* induction = inductions[0]; + SENode* induction_node = scalar_evolution_.SimplifyExpression( + scalar_evolution_.AnalyzeInstruction(induction)); + if (!induction_node->AsSERecurrentNode()) { + return false; + } + SENode* induction_step = + induction_node->AsSERecurrentNode()->GetCoefficient(); + if (!induction_step->AsSEConstantNode()) { + return false; + } + if (!(induction_step->AsSEConstantNode()->FoldToSingleValue() == 1 || + induction_step->AsSEConstantNode()->FoldToSingleValue() == -1)) { + return false; + } + return true; +} + +void LoopDependenceAnalysis::PrintDebug(std::string debug_msg) { + if (debug_stream_) { + (*debug_stream_) << debug_msg << "\n"; + } +} + +bool Constraint::operator==(const Constraint& other) const { + // A distance of |d| is equivalent to a line |x - y = -d| + if ((GetType() == ConstraintType::Distance && + other.GetType() == ConstraintType::Line) || + (GetType() == ConstraintType::Line && + other.GetType() == ConstraintType::Distance)) { + auto is_distance = AsDependenceLine() != nullptr; + + auto as_distance = + is_distance ? 
AsDependenceDistance() : other.AsDependenceDistance(); + auto distance = as_distance->GetDistance(); + + auto line = other.AsDependenceLine(); + + auto scalar_evolution = distance->GetParentAnalysis(); + + auto neg_distance = scalar_evolution->SimplifyExpression( + scalar_evolution->CreateNegation(distance)); + + return *scalar_evolution->CreateConstant(1) == *line->GetA() && + *scalar_evolution->CreateConstant(-1) == *line->GetB() && + *neg_distance == *line->GetC(); + } + + if (GetType() != other.GetType()) { + return false; + } + + if (AsDependenceDistance()) { + return *AsDependenceDistance()->GetDistance() == + *other.AsDependenceDistance()->GetDistance(); + } + + if (AsDependenceLine()) { + auto this_line = AsDependenceLine(); + auto other_line = other.AsDependenceLine(); + return *this_line->GetA() == *other_line->GetA() && + *this_line->GetB() == *other_line->GetB() && + *this_line->GetC() == *other_line->GetC(); + } + + if (AsDependencePoint()) { + auto this_point = AsDependencePoint(); + auto other_point = other.AsDependencePoint(); + + return *this_point->GetSource() == *other_point->GetSource() && + *this_point->GetDestination() == *other_point->GetDestination(); + } + + return true; +} + +bool Constraint::operator!=(const Constraint& other) const { + return !(*this == other); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp new file mode 100644 index 000000000..11f7e9cfa --- /dev/null +++ b/source/opt/loop_descriptor.cpp @@ -0,0 +1,1012 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_descriptor.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/cfg.h" +#include "source/opt/constants.h" +#include "source/opt/dominator_tree.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" +#include "source/opt/tree_iterator.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +// Takes in a phi instruction |induction| and the loop |header| and returns the +// step operation of the loop. +Instruction* Loop::GetInductionStepOperation( + const Instruction* induction) const { + // Induction must be a phi instruction. + assert(induction->opcode() == SpvOpPhi); + + Instruction* step = nullptr; + + analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + + // Traverse the incoming operands of the phi instruction. + for (uint32_t operand_id = 1; operand_id < induction->NumInOperands(); + operand_id += 2) { + // Incoming edge. + BasicBlock* incoming_block = + context_->cfg()->block(induction->GetSingleWordInOperand(operand_id)); + + // Check if the block is dominated by header, and thus coming from within + // the loop. + if (IsInsideLoop(incoming_block)) { + step = def_use_manager->GetDef( + induction->GetSingleWordInOperand(operand_id - 1)); + break; + } + } + + if (!step || !IsSupportedStepOp(step->opcode())) { + return nullptr; + } + + // The induction variable which binds the loop must only be modified once. 
+ uint32_t lhs = step->GetSingleWordInOperand(0); + uint32_t rhs = step->GetSingleWordInOperand(1); + + // One of the left hand side or right hand side of the step instruction must + // be the induction phi and the other must be an OpConstant. + if (lhs != induction->result_id() && rhs != induction->result_id()) { + return nullptr; + } + + if (def_use_manager->GetDef(lhs)->opcode() != SpvOp::SpvOpConstant && + def_use_manager->GetDef(rhs)->opcode() != SpvOp::SpvOpConstant) { + return nullptr; + } + + return step; +} + +// Returns true if the |step| operation is an induction variable step operation +// which is currently handled. +bool Loop::IsSupportedStepOp(SpvOp step) const { + switch (step) { + case SpvOp::SpvOpISub: + case SpvOp::SpvOpIAdd: + return true; + default: + return false; + } +} + +bool Loop::IsSupportedCondition(SpvOp condition) const { + switch (condition) { + // < + case SpvOp::SpvOpULessThan: + case SpvOp::SpvOpSLessThan: + // > + case SpvOp::SpvOpUGreaterThan: + case SpvOp::SpvOpSGreaterThan: + + // >= + case SpvOp::SpvOpSGreaterThanEqual: + case SpvOp::SpvOpUGreaterThanEqual: + // <= + case SpvOp::SpvOpSLessThanEqual: + case SpvOp::SpvOpULessThanEqual: + + return true; + default: + return false; + } +} + +int64_t Loop::GetResidualConditionValue(SpvOp condition, int64_t initial_value, + int64_t step_value, + size_t number_of_iterations, + size_t factor) { + int64_t remainder = + initial_value + (number_of_iterations % factor) * step_value; + + // We subtract or add one as the above formula calculates the remainder if the + // loop where just less than or greater than. Adding or subtracting one should + // give a functionally equivalent value. 
+ switch (condition) { + case SpvOp::SpvOpSGreaterThanEqual: + case SpvOp::SpvOpUGreaterThanEqual: { + remainder -= 1; + break; + } + case SpvOp::SpvOpSLessThanEqual: + case SpvOp::SpvOpULessThanEqual: { + remainder += 1; + break; + } + + default: + break; + } + return remainder; +} + +Instruction* Loop::GetConditionInst() const { + BasicBlock* condition_block = FindConditionBlock(); + if (!condition_block) { + return nullptr; + } + Instruction* branch_conditional = &*condition_block->tail(); + if (!branch_conditional || + branch_conditional->opcode() != SpvOpBranchConditional) { + return nullptr; + } + Instruction* condition_inst = context_->get_def_use_mgr()->GetDef( + branch_conditional->GetSingleWordInOperand(0)); + if (IsSupportedCondition(condition_inst->opcode())) { + return condition_inst; + } + + return nullptr; +} + +// Extract the initial value from the |induction| OpPhi instruction and store it +// in |value|. If the function couldn't find the initial value of |induction| +// return false. 
+bool Loop::GetInductionInitValue(const Instruction* induction,
+                                 int64_t* value) const {
+  Instruction* constant_instruction = nullptr;
+  analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+  for (uint32_t operand_id = 0; operand_id < induction->NumInOperands();
+       operand_id += 2) {
+    BasicBlock* bb = context_->cfg()->block(
+        induction->GetSingleWordInOperand(operand_id + 1));
+
+    if (!IsInsideLoop(bb)) {
+      constant_instruction = def_use_manager->GetDef(
+          induction->GetSingleWordInOperand(operand_id));
+    }
+  }
+
+  if (!constant_instruction) return false;
+
+  const analysis::Constant* constant =
+      context_->get_constant_mgr()->FindDeclaredConstant(
+          constant_instruction->result_id());
+  if (!constant) return false;
+
+  if (value) {
+    // Only an integer constant provides a usable initial value.
+    if (!constant->AsIntConstant()) return false;
+    const auto* type = constant->AsIntConstant()->type()->AsInteger();
+    if (type->IsSigned()) {
+      *value = constant->AsIntConstant()->GetS32BitValue();
+    } else {
+      *value = constant->AsIntConstant()->GetU32BitValue();
+    }
+  }
+
+  return true;
+}
+
+Loop::Loop(IRContext* context, DominatorAnalysis* dom_analysis,
+           BasicBlock* header, BasicBlock* continue_target,
+           BasicBlock* merge_target)
+    : context_(context),
+      loop_header_(header),
+      loop_continue_(continue_target),
+      loop_merge_(merge_target),
+      loop_preheader_(nullptr),
+      parent_(nullptr),
+      loop_is_marked_for_removal_(false) {
+  assert(context);
+  assert(dom_analysis);
+  loop_preheader_ = FindLoopPreheader(dom_analysis);
+  loop_latch_ = FindLatchBlock();
+}
+
+BasicBlock* Loop::FindLoopPreheader(DominatorAnalysis* dom_analysis) {
+  CFG* cfg = context_->cfg();
+  DominatorTree& dom_tree = dom_analysis->GetDomTree();
+  DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_);
+
+  // The loop predecessor.
+  BasicBlock* loop_pred = nullptr;
+
+  auto header_pred = cfg->preds(loop_header_->id());
+  for (uint32_t p_id : header_pred) {
+    DominatorTreeNode* node = dom_tree.GetTreeNode(p_id);
+    if (node && !dom_tree.Dominates(header_node, node)) {
+      // The predecessor is not part of the loop, so potential loop preheader.
+      if (loop_pred && node->bb_ != loop_pred) {
+        // If we saw 2 distinct predecessors that are outside the loop, we don't
+        // have a loop preheader.
+        return nullptr;
+      }
+      loop_pred = node->bb_;
+    }
+  }
+  // Safe guard against invalid code, SPIR-V spec forbids loop with the entry
+  // node as header.
+  assert(loop_pred && "The header node is the entry block ?");
+
+  // So we have a unique basic block that can enter this loop.
+  // If this loop is the unique successor of this block, then it is a loop
+  // preheader.
+  bool is_preheader = true;
+  uint32_t loop_header_id = loop_header_->id();
+  const auto* const_loop_pred = loop_pred;
+  const_loop_pred->ForEachSuccessorLabel(
+      [&is_preheader, loop_header_id](const uint32_t id) {
+        if (id != loop_header_id) is_preheader = false;
+      });
+  if (is_preheader) return loop_pred;
+  return nullptr;
+}
+
+bool Loop::IsInsideLoop(Instruction* inst) const {
+  const BasicBlock* parent_block = context_->get_instr_block(inst);
+  if (!parent_block) return false;
+  return IsInsideLoop(parent_block);
+}
+
+bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
+  assert(bb->GetParent() && "The basic block does not belong to a function");
+  DominatorAnalysis* dom_analysis =
+      context_->GetDominatorAnalysis(bb->GetParent());
+  if (dom_analysis->IsReachable(bb) &&
+      !dom_analysis->Dominates(GetHeaderBlock(), bb))
+    return false;
+
+  return true;
+}
+
+BasicBlock* Loop::GetOrCreatePreHeaderBlock() {
+  if (loop_preheader_) return loop_preheader_;
+
+  CFG* cfg = context_->cfg();
+  loop_header_ = cfg->SplitLoopHeader(loop_header_);
+  return loop_preheader_;
+}
+
+void Loop::SetContinueBlock(BasicBlock* continue_block) {
+  assert(IsInsideLoop(continue_block));
+  loop_continue_ = continue_block;
+}
+
+void Loop::SetLatchBlock(BasicBlock* latch) {
+#ifndef NDEBUG
+  assert(latch->GetParent() && "The basic block does not belong to a function");
+
+  const auto* const_latch = latch;
+  const_latch->ForEachSuccessorLabel([this](uint32_t id) {
+    assert((!IsInsideLoop(id) || id == GetHeaderBlock()->id()) &&
+           "A predecessor of the continue block does not belong to the loop");
+  });
+#endif  // NDEBUG
+  assert(IsInsideLoop(latch) && "The continue block is not in the loop");
+
+  SetLatchBlockImpl(latch);
+}
+
+void Loop::SetMergeBlock(BasicBlock* merge) {
+#ifndef NDEBUG
+  assert(merge->GetParent() && "The basic block does not belong to a function");
+#endif  // NDEBUG
+  assert(!IsInsideLoop(merge) && "The merge block is in the loop");
+
+  SetMergeBlockImpl(merge);
+  if (GetHeaderBlock()->GetLoopMergeInst()) {
+    UpdateLoopMergeInst();
+  }
+}
+
+void Loop::SetPreHeaderBlock(BasicBlock* preheader) {
+  if (preheader) {
+    assert(!IsInsideLoop(preheader) && "The preheader block is in the loop");
+    assert(preheader->tail()->opcode() == SpvOpBranch &&
+           "The preheader block does not unconditionally branch to the header "
+           "block");
+    assert(preheader->tail()->GetSingleWordOperand(0) ==
+               GetHeaderBlock()->id() &&
+           "The preheader block does not unconditionally branch to the header "
+           "block");
+  }
+  loop_preheader_ = preheader;
+}
+
+BasicBlock* Loop::FindLatchBlock() {
+  CFG* cfg = context_->cfg();
+
+  DominatorAnalysis* dominator_analysis =
+      context_->GetDominatorAnalysis(loop_header_->GetParent());
+
+  // Look at the predecessors of the loop header to find a predecessor block
+  // which is dominated by the loop continue target. There should only be one
+  // block which meets this criteria and this is the latch block, as per the
+  // SPIR-V spec.
+  for (uint32_t block_id : cfg->preds(loop_header_->id())) {
+    if (dominator_analysis->Dominates(loop_continue_->id(), block_id)) {
+      return cfg->block(block_id);
+    }
+  }
+
+  assert(
+      false &&
+      "Every loop should have a latch block dominated by the continue target");
+  return nullptr;
+}
+
+void Loop::GetExitBlocks(std::unordered_set<uint32_t>* exit_blocks) const {
+  CFG* cfg = context_->cfg();
+  exit_blocks->clear();
+
+  for (uint32_t bb_id : GetBlocks()) {
+    const BasicBlock* bb = cfg->block(bb_id);
+    bb->ForEachSuccessorLabel([exit_blocks, this](uint32_t succ) {
+      if (!IsInsideLoop(succ)) {
+        exit_blocks->insert(succ);
+      }
+    });
+  }
+}
+
+void Loop::GetMergingBlocks(
+    std::unordered_set<uint32_t>* merging_blocks) const {
+  assert(GetMergeBlock() && "This loop is not structured");
+  CFG* cfg = context_->cfg();
+  merging_blocks->clear();
+
+  std::stack<const BasicBlock*> to_visit;
+  to_visit.push(GetMergeBlock());
+  while (!to_visit.empty()) {
+    const BasicBlock* bb = to_visit.top();
+    to_visit.pop();
+    merging_blocks->insert(bb->id());
+    for (uint32_t pred_id : cfg->preds(bb->id())) {
+      if (!IsInsideLoop(pred_id) && !merging_blocks->count(pred_id)) {
+        to_visit.push(cfg->block(pred_id));
+      }
+    }
+  }
+}
+
+namespace {
+
+static inline bool IsBasicBlockSafeToClone(IRContext* context, BasicBlock* bb) {
+  for (Instruction& inst : *bb) {
+    if (!inst.IsBranch() && !context->IsCombinatorInstruction(&inst))
+      return false;
+  }
+
+  return true;
+}
+
+}  // namespace
+
+bool Loop::IsSafeToClone() const {
+  CFG& cfg = *context_->cfg();
+
+  for (uint32_t bb_id : GetBlocks()) {
+    BasicBlock* bb = cfg.block(bb_id);
+    assert(bb);
+    if (!IsBasicBlockSafeToClone(context_, bb)) return false;
+  }
+
+  // Look at the merge construct.
+  if (GetHeaderBlock()->GetLoopMergeInst()) {
+    std::unordered_set<uint32_t> blocks;
+    GetMergingBlocks(&blocks);
+    blocks.erase(GetMergeBlock()->id());
+    for (uint32_t bb_id : blocks) {
+      BasicBlock* bb = cfg.block(bb_id);
+      assert(bb);
+      if (!IsBasicBlockSafeToClone(context_, bb)) return false;
+    }
+  }
+
+  return true;
+}
+
+bool Loop::IsLCSSA() const {
+  CFG* cfg = context_->cfg();
+  analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+
+  std::unordered_set<uint32_t> exit_blocks;
+  GetExitBlocks(&exit_blocks);
+
+  // Declare ir_context so we can capture context_ in the below lambda
+  IRContext* ir_context = context_;
+
+  for (uint32_t bb_id : GetBlocks()) {
+    for (Instruction& insn : *cfg->block(bb_id)) {
+      // All uses must be either:
+      //  - In the loop;
+      //  - In an exit block and in a phi instruction.
+      if (!def_use_mgr->WhileEachUser(
+              &insn,
+              [&exit_blocks, ir_context, this](Instruction* use) -> bool {
+                BasicBlock* parent = ir_context->get_instr_block(use);
+                assert(parent && "Invalid analysis");
+                if (IsInsideLoop(parent)) return true;
+                if (use->opcode() != SpvOpPhi) return false;
+                return exit_blocks.count(parent->id());
+              }))
+        return false;
+    }
+  }
+  return true;
+}
+
+bool Loop::ShouldHoistInstruction(IRContext* context, Instruction* inst) {
+  return AreAllOperandsOutsideLoop(context, inst) &&
+         inst->IsOpcodeCodeMotionSafe();
+}
+
+bool Loop::AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst) {
+  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+  bool all_outside_loop = true;
+
+  const std::function<void(uint32_t*)> operand_outside_loop =
+      [this, &def_use_mgr, &all_outside_loop](uint32_t* id) {
+        if (this->IsInsideLoop(def_use_mgr->GetDef(*id))) {
+          all_outside_loop = false;
+          return;
+        }
+      };
+
+  inst->ForEachInId(operand_outside_loop);
+  return all_outside_loop;
+}
+
+void Loop::ComputeLoopStructuredOrder(
+    std::vector<BasicBlock*>* ordered_loop_blocks, bool include_pre_header,
+    bool include_merge) const {
+  CFG& cfg = *context_->cfg();
+
+  // Reserve the memory: all blocks in the loop + extra if needed.
+  ordered_loop_blocks->reserve(GetBlocks().size() + include_pre_header +
+                               include_merge);
+
+  if (include_pre_header && GetPreHeaderBlock())
+    ordered_loop_blocks->push_back(loop_preheader_);
+  cfg.ForEachBlockInReversePostOrder(
+      loop_header_, [ordered_loop_blocks, this](BasicBlock* bb) {
+        if (IsInsideLoop(bb)) ordered_loop_blocks->push_back(bb);
+      });
+  if (include_merge && GetMergeBlock())
+    ordered_loop_blocks->push_back(loop_merge_);
+}
+
+LoopDescriptor::LoopDescriptor(IRContext* context, const Function* f)
+    : loops_(), dummy_top_loop_(nullptr) {
+  PopulateList(context, f);
+}
+
+LoopDescriptor::~LoopDescriptor() { ClearLoops(); }
+
+void LoopDescriptor::PopulateList(IRContext* context, const Function* f) {
+  DominatorAnalysis* dom_analysis = context->GetDominatorAnalysis(f);
+
+  ClearLoops();
+
+  // Post-order traversal of the dominator tree to find all the OpLoopMerge
+  // instructions.
+  DominatorTree& dom_tree = dom_analysis->GetDomTree();
+  for (DominatorTreeNode& node :
+       make_range(dom_tree.post_begin(), dom_tree.post_end())) {
+    Instruction* merge_inst = node.bb_->GetLoopMergeInst();
+    if (merge_inst) {
+      bool all_backedge_unreachable = true;
+      for (uint32_t pid : context->cfg()->preds(node.bb_->id())) {
+        if (dom_analysis->IsReachable(pid) &&
+            dom_analysis->Dominates(node.bb_->id(), pid)) {
+          all_backedge_unreachable = false;
+          break;
+        }
+      }
+      if (all_backedge_unreachable)
+        continue;  // ignore this one, we actually never branch back.
+
+      // The id of the merge basic block of this loop.
+      uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0);
+
+      // The id of the continue basic block of this loop.
+      uint32_t continue_bb_id = merge_inst->GetSingleWordOperand(1);
+
+      // The merge target of this loop.
+      BasicBlock* merge_bb = context->cfg()->block(merge_bb_id);
+
+      // The continue target of this loop.
+      BasicBlock* continue_bb = context->cfg()->block(continue_bb_id);
+
+      // The basic block containing the merge instruction.
+      BasicBlock* header_bb = context->get_instr_block(merge_inst);
+
+      // Add the loop to the list of all the loops in the function.
+      Loop* current_loop =
+          new Loop(context, dom_analysis, header_bb, continue_bb, merge_bb);
+      loops_.push_back(current_loop);
+
+      // We have a bottom-up construction, so if this loop has nested-loops,
+      // they are by construction at the tail of the loop list.
+      for (auto itr = loops_.rbegin() + 1; itr != loops_.rend(); ++itr) {
+        Loop* previous_loop = *itr;
+
+        // If the loop already has a parent, then it has been processed.
+        if (previous_loop->HasParent()) continue;
+
+        // If the current loop does not dominate the previous loop then it is
+        // not a nested loop.
+        if (!dom_analysis->Dominates(header_bb,
+                                     previous_loop->GetHeaderBlock()))
+          continue;
+        // If the current loop merge dominates the previous loop then it is
+        // not a nested loop.
+        if (dom_analysis->Dominates(merge_bb, previous_loop->GetHeaderBlock()))
+          continue;
+
+        current_loop->AddNestedLoop(previous_loop);
+      }
+      DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb);
+      for (DominatorTreeNode& loop_node :
+           make_range(node.df_begin(), node.df_end())) {
+        // Check if we are in the loop.
+        if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue;
+        current_loop->AddBasicBlock(loop_node.bb_);
+        basic_block_to_loop_.insert(
+            std::make_pair(loop_node.bb_->id(), current_loop));
+      }
+    }
+  }
+  for (Loop* loop : loops_) {
+    if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop);
+  }
+}
+
+std::vector<Loop*> LoopDescriptor::GetLoopsInBinaryLayoutOrder() {
+  std::vector<uint32_t> ids{};
+
+  for (size_t i = 0; i < NumLoops(); ++i) {
+    ids.push_back(GetLoopByIndex(i).GetHeaderBlock()->id());
+  }
+
+  std::vector<Loop*> loops{};
+  if (!ids.empty()) {
+    auto function = GetLoopByIndex(0).GetHeaderBlock()->GetParent();
+    for (const auto& block : *function) {
+      auto block_id = block.id();
+
+      auto element = std::find(std::begin(ids), std::end(ids), block_id);
+      if (element != std::end(ids)) {
+        loops.push_back(&GetLoopByIndex(element - std::begin(ids)));
+      }
+    }
+  }
+
+  return loops;
+}
+
+BasicBlock* Loop::FindConditionBlock() const {
+  if (!loop_merge_) {
+    return nullptr;
+  }
+  BasicBlock* condition_block = nullptr;
+
+  uint32_t in_loop_pred = 0;
+  for (uint32_t p : context_->cfg()->preds(loop_merge_->id())) {
+    if (IsInsideLoop(p)) {
+      if (in_loop_pred) {
+        // 2 in-loop predecessors.
+        return nullptr;
+      }
+      in_loop_pred = p;
+    }
+  }
+  if (!in_loop_pred) {
+    // Merge block is unreachable.
+    return nullptr;
+  }
+
+  BasicBlock* bb = context_->cfg()->block(in_loop_pred);
+
+  if (!bb) return nullptr;
+
+  const Instruction& branch = *bb->ctail();
+
+  // Make sure the branch is a conditional branch.
+  if (branch.opcode() != SpvOpBranchConditional) return nullptr;
+
+  // Make sure one of the two possible branches is to the merge block.
+  if (branch.GetSingleWordInOperand(1) == loop_merge_->id() ||
+      branch.GetSingleWordInOperand(2) == loop_merge_->id()) {
+    condition_block = bb;
+  }
+
+  return condition_block;
+}
+
+bool Loop::FindNumberOfIterations(const Instruction* induction,
+                                  const Instruction* branch_inst,
+                                  size_t* iterations_out,
+                                  int64_t* step_value_out,
+                                  int64_t* init_value_out) const {
+  // From the branch instruction find the branch condition.
+  analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+  // Condition instruction from the OpConditionalBranch.
+  Instruction* condition =
+      def_use_manager->GetDef(branch_inst->GetSingleWordOperand(0));
+
+  assert(IsSupportedCondition(condition->opcode()));
+
+  // Get the constant manager from the ir context.
+  analysis::ConstantManager* const_manager = context_->get_constant_mgr();
+
+  // Find the constant value used by the condition variable. Exit out if it
+  // isn't a constant int.
+  const analysis::Constant* upper_bound =
+      const_manager->FindDeclaredConstant(condition->GetSingleWordOperand(3));
+  if (!upper_bound || !upper_bound->AsIntConstant()) return false;
+
+  // Must be integer because of the opcode on the condition.
+  int64_t condition_value = 0;
+
+  const analysis::Integer* type =
+      upper_bound->AsIntConstant()->type()->AsInteger();
+
+  if (type->width() > 32) {
+    return false;
+  }
+
+  if (type->IsSigned()) {
+    condition_value = upper_bound->AsIntConstant()->GetS32BitValue();
+  } else {
+    condition_value = upper_bound->AsIntConstant()->GetU32BitValue();
+  }
+
+  // Find the instruction which is stepping through the loop.
+  Instruction* step_inst = GetInductionStepOperation(induction);
+  if (!step_inst) return false;
+
+  // Find the constant value used by the step instruction.
+  const analysis::Constant* step_constant =
+      const_manager->FindDeclaredConstant(step_inst->GetSingleWordOperand(3));
+  if (!step_constant || !step_constant->AsIntConstant()) return false;
+
+  // Must be integer because of the opcode on the step instruction.
+  int64_t step_value = 0;
+
+  const analysis::Integer* step_type =
+      step_constant->AsIntConstant()->type()->AsInteger();
+
+  if (step_type->IsSigned()) {
+    step_value = step_constant->AsIntConstant()->GetS32BitValue();
+  } else {
+    step_value = step_constant->AsIntConstant()->GetU32BitValue();
+  }
+
+  // If this is a subtraction step we should negate the step value.
+  if (step_inst->opcode() == SpvOp::SpvOpISub) {
+    step_value = -step_value;
+  }
+
+  // Find the initial value of the loop and make sure it is a constant integer.
+  int64_t init_value = 0;
+  if (!GetInductionInitValue(induction, &init_value)) return false;
+
+  // Compute the iteration count from the condition, bounds and step.
+  int64_t num_itrs = GetIterations(condition->opcode(), condition_value,
+                                   init_value, step_value);
+
+  // If the loop body will not be reached return false.
+  if (num_itrs <= 0) {
+    return false;
+  }
+
+  if (iterations_out) {
+    assert(static_cast<size_t>(num_itrs) <= std::numeric_limits<size_t>::max());
+    *iterations_out = static_cast<size_t>(num_itrs);
+  }
+
+  if (step_value_out) {
+    *step_value_out = step_value;
+  }
+
+  if (init_value_out) {
+    *init_value_out = init_value;
+  }
+
+  return true;
+}
+
+// We retrieve the number of iterations using the following formula, diff /
+// |step_value| where diff is calculated differently according to the
+// |condition| and uses the |condition_value| and |init_value|. If diff /
+// |step_value| is NOT cleanly divisible then we add one to the sum.
+int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
+                            int64_t init_value, int64_t step_value) const {
+  int64_t diff = 0;
+  if (step_value == 0) return 0;  // Avoid dividing by a zero step below.
+  switch (condition) {
+    case SpvOp::SpvOpSLessThan:
+    case SpvOp::SpvOpULessThan: {
+      // If the condition is not met to begin with the loop will never iterate.
+      if (!(init_value < condition_value)) return 0;
+
+      diff = condition_value - init_value;
+
+      // If the operation is a less than operation then the diff and step must
+      // have the same sign otherwise the induction will never cross the
+      // condition (either never true or always true).
+      if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
+        return 0;
+      }
+
+      break;
+    }
+    case SpvOp::SpvOpSGreaterThan:
+    case SpvOp::SpvOpUGreaterThan: {
+      // If the condition is not met to begin with the loop will never iterate.
+      if (!(init_value > condition_value)) return 0;
+
+      diff = init_value - condition_value;
+
+      // If the operation is a greater than operation then the diff and step
+      // must have opposite signs. Otherwise the condition will always be true
+      // or will never be true.
+      if ((diff < 0 && step_value < 0) || (diff > 0 && step_value > 0)) {
+        return 0;
+      }
+
+      break;
+    }
+
+    case SpvOp::SpvOpSGreaterThanEqual:
+    case SpvOp::SpvOpUGreaterThanEqual: {
+      // If the condition is not met to begin with the loop will never iterate.
+      if (!(init_value >= condition_value)) return 0;
+
+      // We subtract one to make it the same as SpvOpGreaterThan as it is
+      // functionally equivalent.
+      diff = init_value - (condition_value - 1);
+
+      // If the operation is a greater than operation then the diff and step
+      // must have opposite signs. Otherwise the condition will always be true
+      // or will never be true.
+      if ((diff > 0 && step_value > 0) || (diff < 0 && step_value < 0)) {
+        return 0;
+      }
+
+      break;
+    }
+
+    case SpvOp::SpvOpSLessThanEqual:
+    case SpvOp::SpvOpULessThanEqual: {
+      // If the condition is not met to begin with the loop will never iterate.
+      if (!(init_value <= condition_value)) return 0;
+
+      // We add one to make it the same as SpvOpLessThan as it is functionally
+      // equivalent.
+      diff = (condition_value + 1) - init_value;
+
+      // If the operation is a less than operation then the diff and step must
+      // have the same sign otherwise the induction will never cross the
+      // condition (either never true or always true).
+      if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
+        return 0;
+      }
+
+      break;
+    }
+
+    default:
+      assert(false &&
+             "Could not retrieve number of iterations from the loop condition. "
+             "Condition is not supported.");
+  }
+
+  // Take the abs of - step values.
+  step_value = llabs(step_value);
+  diff = llabs(diff);
+  int64_t result = diff / step_value;
+
+  if (diff % step_value != 0) {
+    result += 1;
+  }
+  return result;
+}
+
+// Returns the list of induction variables within the loop.
+void Loop::GetInductionVariables(
+    std::vector<Instruction*>& induction_variables) const {
+  for (Instruction& inst : *loop_header_) {
+    if (inst.opcode() == SpvOp::SpvOpPhi) {
+      induction_variables.push_back(&inst);
+    }
+  }
+}
+
+Instruction* Loop::FindConditionVariable(
+    const BasicBlock* condition_block) const {
+  // Find the branch instruction.
+  const Instruction& branch_inst = *condition_block->ctail();
+
+  Instruction* induction = nullptr;
+  // Verify that the branch instruction is a conditional branch.
+  if (branch_inst.opcode() == SpvOp::SpvOpBranchConditional) {
+    // From the branch instruction find the branch condition.
+    analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+    // Find the instruction representing the condition used in the conditional
+    // branch.
+    Instruction* condition =
+        def_use_manager->GetDef(branch_inst.GetSingleWordOperand(0));
+
+    // Ensure that the condition is a supported comparison operation.
+    if (condition && IsSupportedCondition(condition->opcode())) {
+      // The left hand side operand of the operation.
+      Instruction* variable_inst =
+          def_use_manager->GetDef(condition->GetSingleWordOperand(2));
+
+      // Make sure the variable instruction used is a phi.
+      if (!variable_inst || variable_inst->opcode() != SpvOpPhi) return nullptr;
+
+      // Make sure the phi instruction only has two incoming blocks. Each
+      // incoming block will be represented by two in operands in the phi
+      // instruction, the value and the block which that value came from. We
+      // assume the canonical phi will have two incoming values, one from the
+      // preheader and one from the continue block.
+      size_t max_supported_operands = 4;
+      if (variable_inst->NumInOperands() == max_supported_operands) {
+        // The operand index of the first incoming block label.
+        uint32_t operand_label_1 = 1;
+
+        // The operand index of the second incoming block label.
+        uint32_t operand_label_2 = 3;
+
+        // Make sure one of them is the preheader.
+        if (!IsInsideLoop(
+                variable_inst->GetSingleWordInOperand(operand_label_1)) &&
+            !IsInsideLoop(
+                variable_inst->GetSingleWordInOperand(operand_label_2))) {
+          return nullptr;
+        }
+
+        // And make sure that the other is the latch block.
+        if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
+                loop_latch_->id() &&
+            variable_inst->GetSingleWordInOperand(operand_label_2) !=
+                loop_latch_->id()) {
+          return nullptr;
+        }
+      } else {
+        return nullptr;
+      }
+
+      if (!FindNumberOfIterations(variable_inst, &branch_inst, nullptr))
+        return nullptr;
+      induction = variable_inst;
+    }
+  }
+
+  return induction;
+}
+
+bool LoopDescriptor::CreatePreHeaderBlocksIfMissing() {
+  auto modified = false;
+
+  for (auto& loop : *this) {
+    if (!loop.GetPreHeaderBlock()) {
+      modified = true;
+      // TODO(1841): Handle failure to create pre-header.
+      loop.GetOrCreatePreHeaderBlock();
+    }
+  }
+
+  return modified;
+}
+
+// Add and remove loops which have been marked for addition and removal to
+// maintain the state of the loop descriptor class.
+void LoopDescriptor::PostModificationCleanup() { + // Collect the doomed loops first so that |loops_| is not modified while it is + // being iterated. + // NOTE(review): only a real parent is told about the removal; nothing here + // erases the loop from the dummy root's children or from + // basic_block_to_loop_ - confirm callers keep those in sync. + LoopContainerType loops_to_remove; + for (Loop* loop : loops_) { + if (loop->IsMarkedForRemoval()) { + loops_to_remove.push_back(loop); + if (loop->HasParent()) { + loop->GetParent()->RemoveChildLoop(loop); + } + } + } + + // |loops_| owns its elements through raw pointers, so each removed loop is + // deleted by hand. + for (Loop* loop : loops_to_remove) { + loops_.erase(std::find(loops_.begin(), loops_.end(), loop)); + delete loop; + } + + for (auto& pair : loops_to_add_) { + Loop* parent = pair.first; + std::unique_ptr<Loop> loop = std::move(pair.second); + + if (parent) { + // AddNestedLoop asserts that the loop has no parent yet, so clear it + // before registering the loop under |parent|. + loop->SetParent(nullptr); + parent->AddNestedLoop(loop.get()); + + // AddBasicBlock also inserts the block into every ancestor of |parent|. + for (uint32_t block_id : loop->GetBlocks()) { + parent->AddBasicBlock(block_id); + } + } + + // Ownership moves from the unique_ptr to |loops_|. + loops_.emplace_back(loop.release()); + } + + loops_to_add_.clear(); +} + +// Deletes every loop owned by this descriptor and empties |loops_|. +void LoopDescriptor::ClearLoops() { + for (Loop* loop : loops_) { + delete loop; + } + loops_.clear(); +} + +// Adds a new loop nest to the descriptor set. The descriptor takes ownership +// of |new_loop|; the returned raw pointer refers to that same loop. +Loop* LoopDescriptor::AddLoopNest(std::unique_ptr<Loop> new_loop) { + Loop* loop = new_loop.release(); + if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop); + // Iterate from inner to outer most loop, adding basic block to loop mapping + // as we go. insert() keeps an existing entry, so a block already claimed by + // an inner loop stays mapped to that inner loop. + for (Loop& current_loop : + make_range(iterator::begin(loop), iterator::end(nullptr))) { + loops_.push_back(&current_loop); + for (uint32_t bb_id : current_loop.GetBlocks()) + basic_block_to_loop_.insert(std::make_pair(bb_id, &current_loop)); + } + + return loop; +} + +void LoopDescriptor::RemoveLoop(Loop* loop) { + Loop* parent = loop->GetParent() ? 
loop->GetParent() : &dummy_top_loop_; + parent->nested_loops_.erase(std::find(parent->nested_loops_.begin(), + parent->nested_loops_.end(), loop)); + std::for_each( + loop->nested_loops_.begin(), loop->nested_loops_.end(), + [loop](Loop* sub_loop) { sub_loop->SetParent(loop->GetParent()); }); + parent->nested_loops_.insert(parent->nested_loops_.end(), + loop->nested_loops_.begin(), + loop->nested_loops_.end()); + for (uint32_t bb_id : loop->GetBlocks()) { + Loop* l = FindLoopForBasicBlock(bb_id); + if (l == loop) { + SetBasicBlockToLoop(bb_id, l->GetParent()); + } else { + ForgetBasicBlock(bb_id); + } + } + + LoopContainerType::iterator it = + std::find(loops_.begin(), loops_.end(), loop); + assert(it != loops_.end()); + delete loop; + loops_.erase(it); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_descriptor.h b/source/opt/loop_descriptor.h new file mode 100644 index 000000000..6e2b82896 --- /dev/null +++ b/source/opt/loop_descriptor.h @@ -0,0 +1,574 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_LOOP_DESCRIPTOR_H_ +#define SOURCE_OPT_LOOP_DESCRIPTOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/module.h" +#include "source/opt/tree_iterator.h" + +namespace spvtools { +namespace opt { + +class IRContext; +class CFG; +class LoopDescriptor; + +// A class to represent and manipulate a loop in structured control flow. +class Loop { + // The type used to represent nested child loops. + using ChildrenList = std::vector; + + public: + using iterator = ChildrenList::iterator; + using const_iterator = ChildrenList::const_iterator; + using BasicBlockListTy = std::unordered_set; + + explicit Loop(IRContext* context) + : context_(context), + loop_header_(nullptr), + loop_continue_(nullptr), + loop_merge_(nullptr), + loop_preheader_(nullptr), + loop_latch_(nullptr), + parent_(nullptr), + loop_is_marked_for_removal_(false) {} + + Loop(IRContext* context, DominatorAnalysis* analysis, BasicBlock* header, + BasicBlock* continue_target, BasicBlock* merge_target); + + // Iterators over the immediate sub-loops. + inline iterator begin() { return nested_loops_.begin(); } + inline iterator end() { return nested_loops_.end(); } + inline const_iterator begin() const { return cbegin(); } + inline const_iterator end() const { return cend(); } + inline const_iterator cbegin() const { return nested_loops_.begin(); } + inline const_iterator cend() const { return nested_loops_.end(); } + + // Returns the header (first basic block of the loop). This block contains the + // OpLoopMerge instruction. + inline BasicBlock* GetHeaderBlock() { return loop_header_; } + inline const BasicBlock* GetHeaderBlock() const { return loop_header_; } + inline void SetHeaderBlock(BasicBlock* header) { loop_header_ = header; } + + // Updates the OpLoopMerge instruction to reflect the current state of the + // loop. 
+ inline void UpdateLoopMergeInst() { + assert(GetHeaderBlock()->GetLoopMergeInst() && + "The loop is not structured"); + Instruction* merge_inst = GetHeaderBlock()->GetLoopMergeInst(); + merge_inst->SetInOperand(0, {GetMergeBlock()->id()}); + } + + // Returns the continue target basic block. This is the block designated as + // the continue target by the OpLoopMerge instruction. + inline BasicBlock* GetContinueBlock() { return loop_continue_; } + inline const BasicBlock* GetContinueBlock() const { return loop_continue_; } + + // Returns the latch basic block (basic block that holds the back-edge). + // These functions return nullptr if the loop is not structured (i.e. if it + // has more than one backedge). + inline BasicBlock* GetLatchBlock() { return loop_latch_; } + inline const BasicBlock* GetLatchBlock() const { return loop_latch_; } + + // Sets |latch| as the loop unique block branching back to the header. + // A latch block must have the following properties: + // - |latch| must be in the loop; + // - must be the only block branching back to the header block. + void SetLatchBlock(BasicBlock* latch); + + // Sets |continue_block| as the continue block of the loop. This should be the + // continue target of the OpLoopMerge and should dominate the latch block. + void SetContinueBlock(BasicBlock* continue_block); + + // Returns the basic block which marks the end of the loop. + // These functions return nullptr if the loop is not structured. + inline BasicBlock* GetMergeBlock() { return loop_merge_; } + inline const BasicBlock* GetMergeBlock() const { return loop_merge_; } + // Sets |merge| as the loop merge block. A merge block must have the following + // properties: + // - |merge| must not be in the loop; + // - all its predecessors must be in the loop. + // - it must not be already used as merge block. + // If the loop has an OpLoopMerge in its header, this instruction is also + // updated. 
+ void SetMergeBlock(BasicBlock* merge); + + // Returns the loop pre-header, nullptr means that the loop predecessor does + // not qualify as a preheader. + // The preheader is the unique predecessor that: + // - Dominates the loop header; + // - Has only the loop header as successor. + inline BasicBlock* GetPreHeaderBlock() { return loop_preheader_; } + + // Returns the loop pre-header (const overload); nullptr if none is recorded. + inline const BasicBlock* GetPreHeaderBlock() const { return loop_preheader_; } + // Sets |preheader| as the loop preheader block. A preheader block must have + // the following properties: + // - |preheader| must not be in the loop; + // - have an unconditional branch to the loop header. + void SetPreHeaderBlock(BasicBlock* preheader); + + // Returns the loop pre-header, if there is no suitable preheader it will be + // created. Returns |nullptr| if it fails to create the preheader. + BasicBlock* GetOrCreatePreHeaderBlock(); + + // Returns true if this loop contains any nested loops. + inline bool HasNestedLoops() const { return nested_loops_.size() != 0; } + + // Clears and fills |exit_blocks| with all basic blocks that are not in the + // loop and have at least one predecessor in the loop. + void GetExitBlocks(std::unordered_set* exit_blocks) const; + + // Clears and fills |merging_blocks| with all basic blocks that are + // post-dominated by the merge block. The merge block must exist. + // The set |merging_blocks| will only contain the merge block if it is + // unreachable. + void GetMergingBlocks(std::unordered_set* merging_blocks) const; + + // Returns true if the loop is in a Loop Closed SSA form. + // In LCSSA form, all in-loop definitions are used in the loop or in phi + // instructions in the loop exit blocks. + bool IsLCSSA() const; + + // Returns the depth of this loop in the loop nest. + // The outer-most loop has a depth of 1. 
+ inline size_t GetDepth() const { + size_t lvl = 1; + for (const Loop* loop = GetParent(); loop; loop = loop->GetParent()) lvl++; + return lvl; + } + + inline size_t NumImmediateChildren() const { return nested_loops_.size(); } + + inline bool HasChildren() const { return !nested_loops_.empty(); } + // Adds |nested| as a nested loop of this loop. Automatically register |this| + // as the parent of |nested|. + inline void AddNestedLoop(Loop* nested) { + assert(!nested->GetParent() && "The loop has another parent."); + nested_loops_.push_back(nested); + nested->SetParent(this); + } + + inline Loop* GetParent() { return parent_; } + inline const Loop* GetParent() const { return parent_; } + + inline bool HasParent() const { return parent_; } + + // Returns true if this loop is itself nested within another loop. + inline bool IsNested() const { return parent_ != nullptr; } + + // Returns the set of all basic blocks contained within the loop. Will be all + // BasicBlocks dominated by the header which are not also dominated by the + // loop merge block. + inline const BasicBlockListTy& GetBlocks() const { + return loop_basic_blocks_; + } + + // Returns true if the basic block |bb| is inside this loop. + inline bool IsInsideLoop(const BasicBlock* bb) const { + return IsInsideLoop(bb->id()); + } + + // Returns true if the basic block id |bb_id| is inside this loop. + inline bool IsInsideLoop(uint32_t bb_id) const { + return loop_basic_blocks_.count(bb_id); + } + + // Returns true if the instruction |inst| is inside this loop. + bool IsInsideLoop(Instruction* inst) const; + + // Adds the Basic Block |bb| to this loop and its parents. + void AddBasicBlock(const BasicBlock* bb) { AddBasicBlock(bb->id()); } + + // Adds the Basic Block with |id| to this loop and its parents. 
+ void AddBasicBlock(uint32_t id) { + for (Loop* loop = this; loop != nullptr; loop = loop->parent_) { + loop->loop_basic_blocks_.insert(id); + } + } + + // Removes the Basic Block id |bb_id| from this loop and its parents. + // It the user responsibility to make sure the removed block is not a merge, + // header or continue block. + void RemoveBasicBlock(uint32_t bb_id) { + for (Loop* loop = this; loop != nullptr; loop = loop->parent_) { + loop->loop_basic_blocks_.erase(bb_id); + } + } + + // Removes all the basic blocks from the set of basic blocks within the loop. + // This does not affect any of the stored pointers to the header, preheader, + // merge, or continue blocks. + void ClearBlocks() { loop_basic_blocks_.clear(); } + + // Adds the Basic Block |bb| this loop and its parents. + void AddBasicBlockToLoop(const BasicBlock* bb) { + assert(IsBasicBlockInLoopSlow(bb) && + "Basic block does not belong to the loop"); + + AddBasicBlock(bb); + } + + // Returns the list of induction variables within the loop. + void GetInductionVariables(std::vector& inductions) const; + + // This function uses the |condition| to find the induction variable which is + // used by the loop condition within the loop. This only works if the loop is + // bound by a single condition and single induction variable. + Instruction* FindConditionVariable(const BasicBlock* condition) const; + + // Returns the number of iterations within a loop when given the |induction| + // variable and the loop |condition| check. It stores the found number of + // iterations in the output parameter |iterations| and optionally, the step + // value in |step_value| and the initial value of the induction variable in + // |init_value|. + bool FindNumberOfIterations(const Instruction* induction, + const Instruction* condition, size_t* iterations, + int64_t* step_amount = nullptr, + int64_t* init_value = nullptr) const; + + // Returns the value of the OpLoopMerge control operand as a bool. 
Loop + // control can be None(0), Unroll(1), or DontUnroll(2). This function returns + // true if it is set to Unroll. + inline bool HasUnrollLoopControl() const { + assert(loop_header_); + if (!loop_header_->GetLoopMergeInst()) return false; + + return loop_header_->GetLoopMergeInst()->GetSingleWordOperand(2) == 1; + } + + // Finds the conditional block with a branch to the merge and continue blocks + // within the loop body. + BasicBlock* FindConditionBlock() const; + + // Remove the child loop form this loop. + inline void RemoveChildLoop(Loop* loop) { + nested_loops_.erase( + std::find(nested_loops_.begin(), nested_loops_.end(), loop)); + loop->SetParent(nullptr); + } + + // Mark this loop to be removed later by a call to + // LoopDescriptor::PostModificationCleanup. + inline void MarkLoopForRemoval() { loop_is_marked_for_removal_ = true; } + + // Returns whether or not this loop has been marked for removal. + inline bool IsMarkedForRemoval() const { return loop_is_marked_for_removal_; } + + // Returns true if all nested loops have been marked for removal. + inline bool AreAllChildrenMarkedForRemoval() const { + for (const Loop* child : nested_loops_) { + if (!child->IsMarkedForRemoval()) { + return false; + } + } + return true; + } + + // Checks if the loop contains any instruction that will prevent it from being + // cloned. If the loop is structured, the merge construct is also considered. + bool IsSafeToClone() const; + + // Sets the parent loop of this loop, that is, a loop which contains this loop + // as a nested child loop. 
+ inline void SetParent(Loop* parent) { parent_ = parent; } + + // Returns true is the instruction is invariant and safe to move wrt loop + bool ShouldHoistInstruction(IRContext* context, Instruction* inst); + + // Returns true if all operands of inst are in basic blocks not contained in + // loop + bool AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst); + + // Extract the initial value from the |induction| variable and store it in + // |value|. If the function couldn't find the initial value of |induction| + // return false. + bool GetInductionInitValue(const Instruction* induction, + int64_t* value) const; + + // Takes in a phi instruction |induction| and the loop |header| and returns + // the step operation of the loop. + Instruction* GetInductionStepOperation(const Instruction* induction) const; + + // Returns true if we can deduce the number of loop iterations in the step + // operation |step|. IsSupportedCondition must also be true for the condition + // instruction. + bool IsSupportedStepOp(SpvOp step) const; + + // Returns true if we can deduce the number of loop iterations in the + // condition operation |condition|. IsSupportedStepOp must also be true for + // the step instruction. + bool IsSupportedCondition(SpvOp condition) const; + + // Creates the list of the loop's basic block in structured order and store + // the result in |ordered_loop_blocks|. If |include_pre_header| is true, the + // pre-header block will also be included at the beginning of the list if it + // exist. If |include_merge| is true, the merge block will also be included at + // the end of the list if it exist. + void ComputeLoopStructuredOrder(std::vector* ordered_loop_blocks, + bool include_pre_header = false, + bool include_merge = false) const; + + // Given the loop |condition|, |initial_value|, |step_value|, the trip count + // |number_of_iterations|, and the |unroll_factor| requested, get the new + // condition value for the residual loop. 
+ static int64_t GetResidualConditionValue(SpvOp condition, + int64_t initial_value, + int64_t step_value, + size_t number_of_iterations, + size_t unroll_factor); + + // Returns the condition instruction for entry into the loop + // Returns nullptr if it can't be found. + Instruction* GetConditionInst() const; + + // Returns the context associated this loop. + IRContext* GetContext() const { return context_; } + + // Looks at all the blocks with a branch to the header block to find one + // which is also dominated by the loop continue block. This block is the latch + // block. The specification mandates that this block should exist, therefore + // this function will assert if it is not found. + BasicBlock* FindLatchBlock(); + + private: + IRContext* context_; + // The block which marks the start of the loop. + BasicBlock* loop_header_; + + // The block which begins the body of the loop. + BasicBlock* loop_continue_; + + // The block which marks the end of the loop. + BasicBlock* loop_merge_; + + // The block immediately before the loop header. + BasicBlock* loop_preheader_; + + // The block containing the backedge to the loop header. + BasicBlock* loop_latch_; + + // A parent of a loop is the loop which contains it as a nested child loop. + Loop* parent_; + + // Nested child loops of this loop. + ChildrenList nested_loops_; + + // A set of all the basic blocks which comprise the loop structure. Will be + // computed only when needed on demand. + BasicBlockListTy loop_basic_blocks_; + + // Check that |bb| is inside the loop using domination property. + // Note: this is for assertion purposes only, IsInsideLoop should be used + // instead. + bool IsBasicBlockInLoopSlow(const BasicBlock* bb); + + // Returns the loop preheader if it exists, returns nullptr otherwise. + BasicBlock* FindLoopPreheader(DominatorAnalysis* dom_analysis); + + // Sets |latch| as the loop unique latch block. No checks are performed + // here. 
+ inline void SetLatchBlockImpl(BasicBlock* latch) { loop_latch_ = latch; } + // Sets |merge| as the loop merge block. No checks are performed here. + inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; } + + // Each differnt loop |condition| affects how we calculate the number of + // iterations using the |condition_value|, |init_value|, and |step_values| of + // the induction variable. This method will return the number of iterations in + // a loop with those values for a given |condition|. + int64_t GetIterations(SpvOp condition, int64_t condition_value, + int64_t init_value, int64_t step_value) const; + + // This is to allow for loops to be removed mid iteration without invalidating + // the iterators. + bool loop_is_marked_for_removal_; + + // This is only to allow LoopDescriptor::dummy_top_loop_ to add top level + // loops as child. + friend class LoopDescriptor; + friend class LoopUtils; +}; + +// Loop descriptions class for a given function. +// For a given function, the class builds loop nests information. +// The analysis expects a structured control flow. +class LoopDescriptor { + public: + // Iterator interface (depth first postorder traversal). + using iterator = PostOrderTreeDFIterator; + using const_iterator = PostOrderTreeDFIterator; + + using pre_iterator = TreeDFIterator; + using const_pre_iterator = TreeDFIterator; + + // Creates a loop object for all loops found in |f|. + LoopDescriptor(IRContext* context, const Function* f); + + // Disable copy constructor, to avoid double-free on destruction. + LoopDescriptor(const LoopDescriptor&) = delete; + // Move constructor. + LoopDescriptor(LoopDescriptor&& other) : dummy_top_loop_(nullptr) { + // We need to take ownership of the Loop objects in the other + // LoopDescriptor, to avoid double-free. 
+ loops_ = std::move(other.loops_); + other.loops_.clear(); + basic_block_to_loop_ = std::move(other.basic_block_to_loop_); + other.basic_block_to_loop_.clear(); + dummy_top_loop_ = std::move(other.dummy_top_loop_); + } + + // Destructor + ~LoopDescriptor(); + + // Returns the number of loops found in the function. + inline size_t NumLoops() const { return loops_.size(); } + + // Returns the loop at a particular |index|. The |index| must be in bounds, + // check with NumLoops before calling. + inline Loop& GetLoopByIndex(size_t index) const { + assert(loops_.size() > index && + "Index out of range (larger than loop count)"); + return *loops_[index]; + } + + // Returns the loops in |this| in the order their headers appear in the + // binary. + std::vector GetLoopsInBinaryLayoutOrder(); + + // Returns the inner most loop that contains the basic block id |block_id|. + inline Loop* operator[](uint32_t block_id) const { + return FindLoopForBasicBlock(block_id); + } + + // Returns the inner most loop that contains the basic block |bb|. + inline Loop* operator[](const BasicBlock* bb) const { + return (*this)[bb->id()]; + } + + // Iterators for post order depth first traversal of the loops. + // Inner most loops will be visited first. + inline iterator begin() { return iterator::begin(&dummy_top_loop_); } + inline iterator end() { return iterator::end(&dummy_top_loop_); } + inline const_iterator begin() const { return cbegin(); } + inline const_iterator end() const { return cend(); } + inline const_iterator cbegin() const { + return const_iterator::begin(&dummy_top_loop_); + } + inline const_iterator cend() const { + return const_iterator::end(&dummy_top_loop_); + } + + // Iterators for pre-order depth first traversal of the loops. + // Inner most loops will be visited first. 
+ // The pre-order iterators below are created on the dummy root and advanced + // once; presumably this is what keeps the dummy root itself out of the + // traversal. + inline pre_iterator pre_begin() { return ++pre_iterator(&dummy_top_loop_); } + inline pre_iterator pre_end() { return pre_iterator(); } + inline const_pre_iterator pre_begin() const { return pre_cbegin(); } + inline const_pre_iterator pre_end() const { return pre_cend(); } + inline const_pre_iterator pre_cbegin() const { + return ++const_pre_iterator(&dummy_top_loop_); + } + inline const_pre_iterator pre_cend() const { return const_pre_iterator(); } + + // Records |loop| as the loop associated with the basic block id |bb_id| in + // the block to loop mapping, replacing any previous entry for |bb_id|. + inline void SetBasicBlockToLoop(uint32_t bb_id, Loop* loop) { + basic_block_to_loop_[bb_id] = loop; + } + + // Mark the loop |loop_to_add| as needing to be added when the user calls + // PostModificationCleanup. |parent| may be null. Ownership of |loop_to_add| + // is moved into the pending list. + inline void AddLoop(std::unique_ptr<Loop>&& loop_to_add, Loop* parent) { + loops_to_add_.emplace_back(std::make_pair(parent, std::move(loop_to_add))); + } + + // Checks all loops in |this| and will create pre-headers for all loops + // that don't have one. Returns |true| if any loop was missing a pre-header. + bool CreatePreHeaderBlocksIfMissing(); + + // Should be called to preserve the LoopAnalysis after loops have been marked + // for addition with AddLoop or MarkLoopForRemoval. + void PostModificationCleanup(); + + // Removes the basic block id |bb_id| from the block to loop mapping. + inline void ForgetBasicBlock(uint32_t bb_id) { + basic_block_to_loop_.erase(bb_id); + } + + // Adds the loop |new_loop| and all its nested loops to the descriptor set. + // The object takes ownership of all the loops. + Loop* AddLoopNest(std::unique_ptr<Loop> new_loop); + + // Remove the loop |loop|. 
+ void RemoveLoop(Loop* loop); + + void SetAsTopLoop(Loop* loop) { + assert(std::find(dummy_top_loop_.begin(), dummy_top_loop_.end(), loop) == + dummy_top_loop_.end() && + "already registered"); + dummy_top_loop_.nested_loops_.push_back(loop); + } + + Loop* GetDummyRootLoop() { return &dummy_top_loop_; } + const Loop* GetDummyRootLoop() const { return &dummy_top_loop_; } + + private: + // TODO(dneto): This should be a vector of unique_ptr. But VisualStudio 2013 + // is unable to compile it. + using LoopContainerType = std::vector; + + using LoopsToAddContainerType = + std::vector>>; + + // Creates loop descriptors for the function |f|. + void PopulateList(IRContext* context, const Function* f); + + // Returns the inner most loop that contains the basic block id |block_id|. + inline Loop* FindLoopForBasicBlock(uint32_t block_id) const { + std::unordered_map::const_iterator it = + basic_block_to_loop_.find(block_id); + return it != basic_block_to_loop_.end() ? it->second : nullptr; + } + + // Erase all the loop information. + void ClearLoops(); + + // A list of all the loops in the function. This variable owns the Loop + // objects. + LoopContainerType loops_; + + // Dummy root: this "loop" is only there to help iterators creation. + Loop dummy_top_loop_; + + std::unordered_map basic_block_to_loop_; + + // List of the loops marked for addition when PostModificationCleanup is + // called. + LoopsToAddContainerType loops_to_add_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_DESCRIPTOR_H_ diff --git a/source/opt/loop_fission.cpp b/source/opt/loop_fission.cpp new file mode 100644 index 000000000..0678113c4 --- /dev/null +++ b/source/opt/loop_fission.cpp @@ -0,0 +1,513 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_fission.h" + +#include + +#include "source/opt/register_pressure.h" + +// Implement loop fission with an optional parameter to split only +// if the register pressure in a given loop meets a certain criteria. This is +// controlled via the constructors of LoopFissionPass. +// +// 1 - Build a list of loops to be split, these are top level loops (loops +// without child loops themselves) which meet the register pressure criteria, as +// determined by the ShouldSplitLoop method of LoopFissionPass. +// +// 2 - For each loop in the list, group each instruction into a set of related +// instructions by traversing each instructions users and operands recursively. +// We stop if we encounter an instruction we have seen before or an instruction +// which we don't consider relevent (i.e OpLoopMerge). We then group these +// groups into two different sets, one for the first loop and one for the +// second. +// +// 3 - We then run CanPerformSplit to check that it would be legal to split a +// loop using those two sets. We check that we haven't altered the relative +// order load/stores appear in the binary and that we aren't breaking any +// dependency between load/stores by splitting them into two loops. We also +// check that none of the OpBranch instructions are dependent on a load as we +// leave control flow structure intact and move only instructions in the body so +// we want to avoid any loads with side affects or aliasing. +// +// 4 - We then split the loop by calling SplitLoop. 
This function clones the +// loop and attaches it to the preheader and connects the new loops merge block +// to the current loop header block. We then use the two sets built in step 2 to +// remove instructions from each loop. If an instruction appears in the first +// set it is removed from the second loop and vice versa. +// +// 5 - If the multiple split passes flag is set we check if each of the loops +// still meet the register pressure criteria. If they do then we add them to the +// list of loops to be split (created in step one) to allow for loops to be +// split multiple times. +// + +namespace spvtools { +namespace opt { + +class LoopFissionImpl { + public: + LoopFissionImpl(IRContext* context, Loop* loop) + : context_(context), loop_(loop), load_used_in_condition_(false) {} + + // Group each instruction in the loop into sets of instructions related by + // their usedef chains. An instruction which uses another will appear in the + // same set. Then merge those sets into just two sets. Returns false if there + // was one or less sets created. + bool GroupInstructionsByUseDef(); + + // Check if the sets built by GroupInstructionsByUseDef violate any data + // dependence rules. + bool CanPerformSplit(); + + // Split the loop and return a pointer to the new loop. + Loop* SplitLoop(); + + // Checks if |inst| is safe to move. We can only move instructions which don't + // have any side effects and OpLoads and OpStores. + bool MovableInstruction(const Instruction& inst) const; + + private: + // Traverse the def use chain of |inst| and add the users and uses of |inst| + // which are in the same loop to the |returned_set|. + void TraverseUseDef(Instruction* inst, std::set* returned_set, + bool ignore_phi_users = false, bool report_loads = false); + + // We group the instructions in the block into two different groups, the + // instructions to be kept in the original loop and the ones to be cloned into + // the new loop. 
As the cloned loop is attached to the preheader it will be + // the first loop and the second loop will be the original. + std::set cloned_loop_instructions_; + std::set original_loop_instructions_; + + // We need a set of all the instructions to be seen so we can break any + // recursion and also so we can ignore certain instructions by preemptively + // adding them to this set. + std::set seen_instructions_; + + // A map of instructions to their relative position in the function. + std::map instruction_order_; + + IRContext* context_; + + Loop* loop_; + + // This is set to true by TraverseUseDef when traversing the instructions + // related to the loop condition and any if conditions should any of those + // instructions be a load. + bool load_used_in_condition_; +}; + +bool LoopFissionImpl::MovableInstruction(const Instruction& inst) const { + return inst.opcode() == SpvOp::SpvOpLoad || + inst.opcode() == SpvOp::SpvOpStore || + inst.opcode() == SpvOp::SpvOpSelectionMerge || + inst.opcode() == SpvOp::SpvOpPhi || inst.IsOpcodeCodeMotionSafe(); +} + +void LoopFissionImpl::TraverseUseDef(Instruction* inst, + std::set* returned_set, + bool ignore_phi_users, bool report_loads) { + assert(returned_set && "Set to be returned cannot be null."); + + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + std::set& inst_set = *returned_set; + + // We create this functor to traverse the use def chain to build the + // grouping of related instructions. The lambda captures the std::function + // to allow it to recurse. + std::function traverser_functor; + traverser_functor = [this, def_use, &inst_set, &traverser_functor, + ignore_phi_users, report_loads](Instruction* user) { + // If we've seen the instruction before or it is not inside the loop end the + // traversal. 
+ if (!user || seen_instructions_.count(user) != 0 || + !context_->get_instr_block(user) || + !loop_->IsInsideLoop(context_->get_instr_block(user))) { + return; + } + + // Don't include labels or loop merge instructions in the instruction sets. + // Including them would mean we group instructions related only by using the + // same labels (i.e phis). We already preempt the inclusion of + // OpSelectionMerge by adding related instructions to the seen_instructions_ + // set. + if (user->opcode() == SpvOp::SpvOpLoopMerge || + user->opcode() == SpvOp::SpvOpLabel) + return; + + // If the |report_loads| flag is set and the instruction is an OpLoad, set + // the class field load_used_in_condition_ to true. This is used to check + // that none of the condition checks in the loop rely on loads. + if (user->opcode() == SpvOp::SpvOpLoad && report_loads) { + load_used_in_condition_ = true; + } + + // Add the instruction to the set of instructions already seen, this breaks + // recursion and allows us to ignore certain instructions. + seen_instructions_.insert(user); + + inst_set.insert(user); + + // Wrapper functor to traverse the operands of each instruction. + auto traverse_operand = [&traverser_functor, def_use](const uint32_t* id) { + traverser_functor(def_use->GetDef(*id)); + }; + user->ForEachInOperand(traverse_operand); + + // For the first traversal we want to ignore the users of the phi. + if (ignore_phi_users && user->opcode() == SpvOp::SpvOpPhi) return; + + // Traverse each user with this lambda. + def_use->ForEachUser(user, traverser_functor); + + // Wrapper functor for the use traversal. Instructions already visited + // through ForEachUser above are skipped by the seen_instructions_ check. + auto traverse_use = [&traverser_functor](Instruction* use, uint32_t) { + traverser_functor(use); + }; + def_use->ForEachUse(user, traverse_use); + + }; + + // We start the traversal of the use def graph by invoking the above + // lambda with the |inst| parameter. 
+ traverser_functor(inst); +} + +bool LoopFissionImpl::GroupInstructionsByUseDef() { + std::vector> sets{}; + + // We want to ignore all the instructions stemming from the loop condition + // instruction. + BasicBlock* condition_block = loop_->FindConditionBlock(); + + if (!condition_block) return false; + Instruction* condition = &*condition_block->tail(); + + // We iterate over the blocks via iterating over all the blocks in the + // function, we do this so we are iterating in the same order which the blocks + // appear in the binary. + Function& function = *loop_->GetHeaderBlock()->GetParent(); + + // Create a temporary set to ignore certain groups of instructions within the + // loop. We don't want any instructions related to control flow to be removed + // from either loop only instructions within the control flow bodies. + std::set instructions_to_ignore{}; + TraverseUseDef(condition, &instructions_to_ignore, true, true); + + // Traverse control flow instructions to ensure they are added to the + // seen_instructions_ set and will be ignored when it it called with actual + // sets. + for (BasicBlock& block : function) { + if (!loop_->IsInsideLoop(block.id())) continue; + + for (Instruction& inst : block) { + // Ignore all instructions related to control flow. + if (inst.opcode() == SpvOp::SpvOpSelectionMerge || inst.IsBranch()) { + TraverseUseDef(&inst, &instructions_to_ignore, true, true); + } + } + } + + // Traverse the instructions and generate the sets, automatically ignoring any + // instructions in instructions_to_ignore. + for (BasicBlock& block : function) { + if (!loop_->IsInsideLoop(block.id()) || + loop_->GetHeaderBlock()->id() == block.id()) + continue; + + for (Instruction& inst : block) { + // Record the order that each load/store is seen. + if (inst.opcode() == SpvOp::SpvOpLoad || + inst.opcode() == SpvOp::SpvOpStore) { + instruction_order_[&inst] = instruction_order_.size(); + } + + // Ignore instructions already seen in a traversal. 
+ if (seen_instructions_.count(&inst) != 0) { + continue; + } + + // Build the set. + std::set inst_set{}; + TraverseUseDef(&inst, &inst_set); + if (!inst_set.empty()) sets.push_back(std::move(inst_set)); + } + } + + // If we have one or zero sets return false to indicate that due to + // insufficient instructions we couldn't split the loop into two groups and + // thus the loop can't be split any further. + if (sets.size() < 2) { + return false; + } + + // Merge the loop sets into two different sets. In CanPerformSplit we will + // validate that we don't break the relative ordering of loads/stores by doing + // this. + for (size_t index = 0; index < sets.size() / 2; ++index) { + cloned_loop_instructions_.insert(sets[index].begin(), sets[index].end()); + } + for (size_t index = sets.size() / 2; index < sets.size(); ++index) { + original_loop_instructions_.insert(sets[index].begin(), sets[index].end()); + } + + return true; +} + +bool LoopFissionImpl::CanPerformSplit() { + // Return false if any of the condition instructions in the loop depend on a + // load. + if (load_used_in_condition_) { + return false; + } + + // Build a list of all parent loops of this loop. Loop dependence analysis + // needs this structure. + std::vector loops; + Loop* parent_loop = loop_; + while (parent_loop) { + loops.push_back(parent_loop); + parent_loop = parent_loop->GetParent(); + } + + LoopDependenceAnalysis analysis{context_, loops}; + + // A list of all the stores in the cloned loop. + std::vector set_one_stores{}; + + // A list of all the loads in the cloned loop. + std::vector set_one_loads{}; + + // Populate the above lists. + for (Instruction* inst : cloned_loop_instructions_) { + if (inst->opcode() == SpvOp::SpvOpStore) { + set_one_stores.push_back(inst); + } else if (inst->opcode() == SpvOp::SpvOpLoad) { + set_one_loads.push_back(inst); + } + + // If we find any instruction which we can't move (such as a barrier), + // return false. 
+ if (!MovableInstruction(*inst)) return false; + } + + // We need to calculate the depth of the loop to create the loop dependency + // distance vectors. + const size_t loop_depth = loop_->GetDepth(); + + // Check the dependencies between loads in the cloned loop and stores in the + // original and vice versa. + for (Instruction* inst : original_loop_instructions_) { + // If we find any instruction which we can't move (such as a barrier), + // return false. + if (!MovableInstruction(*inst)) return false; + + // Look at the dependency between the loads in the original and stores in + // the cloned loops. + if (inst->opcode() == SpvOp::SpvOpLoad) { + for (Instruction* store : set_one_stores) { + DistanceVector vec{loop_depth}; + + // If the store actually should appear after the load, return false. + // This means the store has been placed in the wrong grouping. + if (instruction_order_[store] > instruction_order_[inst]) { + return false; + } + // If not independent check the distance vector. + if (!analysis.GetDependence(store, inst, &vec)) { + for (DistanceEntry& entry : vec.GetEntries()) { + // A distance greater than zero means that the store in the cloned + // loop has a dependency on the load in the original loop. + if (entry.distance > 0) return false; + } + } + } + } else if (inst->opcode() == SpvOp::SpvOpStore) { + for (Instruction* load : set_one_loads) { + DistanceVector vec{loop_depth}; + + // If the load actually should appear after the store, return false. + if (instruction_order_[load] > instruction_order_[inst]) { + return false; + } + + // If not independent check the distance vector. + if (!analysis.GetDependence(inst, load, &vec)) { + for (DistanceEntry& entry : vec.GetEntries()) { + // A distance less than zero means the load in the cloned loop is + // dependent on the store instruction in the original loop. + if (entry.distance < 0) return false; + } + } + } + } + } + return true; +} + +Loop* LoopFissionImpl::SplitLoop() { + // Clone the loop. 
+ LoopUtils util{context_, loop_}; + LoopUtils::LoopCloningResult clone_results; + Loop* cloned_loop = util.CloneAndAttachLoopToHeader(&clone_results); + + // Update the OpLoopMerge in the cloned loop. + cloned_loop->UpdateLoopMergeInst(); + + // Add the loop_ to the module. + // TODO(1841): Handle failure to create pre-header. + Function::iterator it = + util.GetFunction()->FindBlock(loop_->GetOrCreatePreHeaderBlock()->id()); + util.GetFunction()->AddBasicBlocks(clone_results.cloned_bb_.begin(), + clone_results.cloned_bb_.end(), ++it); + loop_->SetPreHeaderBlock(cloned_loop->GetMergeBlock()); + + std::vector instructions_to_kill{}; + + // Kill all the instructions which should appear in the cloned loop but not in + // the original loop. + for (uint32_t id : loop_->GetBlocks()) { + BasicBlock* block = context_->cfg()->block(id); + + for (Instruction& inst : *block) { + // If the instruction appears in the cloned loop instruction group, kill + // it. + if (cloned_loop_instructions_.count(&inst) == 1 && + original_loop_instructions_.count(&inst) == 0) { + instructions_to_kill.push_back(&inst); + if (inst.opcode() == SpvOp::SpvOpPhi) { + context_->ReplaceAllUsesWith( + inst.result_id(), clone_results.value_map_[inst.result_id()]); + } + } + } + } + + // Kill all instructions which should appear in the original loop and not in + // the cloned loop. + for (uint32_t id : cloned_loop->GetBlocks()) { + BasicBlock* block = context_->cfg()->block(id); + for (Instruction& inst : *block) { + Instruction* old_inst = clone_results.ptr_map_[&inst]; + // If the instruction belongs to the original loop instruction group, kill + // it. 
+ if (cloned_loop_instructions_.count(old_inst) == 0 && + original_loop_instructions_.count(old_inst) == 1) { + instructions_to_kill.push_back(&inst); + } + } + } + + for (Instruction* i : instructions_to_kill) { + context_->KillInst(i); + } + + return cloned_loop; +} + +LoopFissionPass::LoopFissionPass(const size_t register_threshold_to_split, + bool split_multiple_times) + : split_multiple_times_(split_multiple_times) { + // Split if the number of registers in the loop exceeds + // |register_threshold_to_split|. + split_criteria_ = + [register_threshold_to_split]( + const RegisterLiveness::RegionRegisterLiveness& liveness) { + return liveness.used_registers_ > register_threshold_to_split; + }; +} + +LoopFissionPass::LoopFissionPass() : split_multiple_times_(false) { + // Split by default. + split_criteria_ = [](const RegisterLiveness::RegionRegisterLiveness&) { + return true; + }; +} + +bool LoopFissionPass::ShouldSplitLoop(const Loop& loop, IRContext* c) { + LivenessAnalysis* analysis = c->GetLivenessAnalysis(); + + RegisterLiveness::RegionRegisterLiveness liveness{}; + + Function* function = loop.GetHeaderBlock()->GetParent(); + analysis->Get(function)->ComputeLoopRegisterPressure(loop, &liveness); + + return split_criteria_(liveness); +} + +Pass::Status LoopFissionPass::Process() { + bool changed = false; + + for (Function& f : *context()->module()) { + // We collect all the inner most loops in the function and run the loop + // splitting util on each. The reason we do this is to allow us to iterate + // over each, as creating new loops will invalidate the the loop iterator. + std::vector inner_most_loops{}; + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(&f); + for (Loop& loop : loop_descriptor) { + if (!loop.HasChildren() && ShouldSplitLoop(loop, context())) { + inner_most_loops.push_back(&loop); + } + } + + // List of new loops which meet the criteria to be split again. 
+ std::vector new_loops_to_split{}; + + while (!inner_most_loops.empty()) { + for (Loop* loop : inner_most_loops) { + LoopFissionImpl impl{context(), loop}; + + // Group the instructions in the loop into two different sets of related + // instructions. If we can't group the instructions into the two sets + // then we can't split the loop any further. + if (!impl.GroupInstructionsByUseDef()) { + continue; + } + + if (impl.CanPerformSplit()) { + Loop* second_loop = impl.SplitLoop(); + changed = true; + context()->InvalidateAnalysesExceptFor( + IRContext::kAnalysisLoopAnalysis); + + // If the newly created loop meets the criteria to be split, split it + // again. + if (ShouldSplitLoop(*second_loop, context())) + new_loops_to_split.push_back(second_loop); + + // If the original loop (now split) still meets the criteria to be + // split, split it again. + if (ShouldSplitLoop(*loop, context())) + new_loops_to_split.push_back(loop); + } + } + + // If the split multiple times flag has been set add the new loops which + // meet the splitting criteria into the list of loops to be split on the + // next iteration. + if (split_multiple_times_) { + inner_most_loops = std::move(new_loops_to_split); + } else { + break; + } + } + } + + return changed ? Pass::Status::SuccessWithChange + : Pass::Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_fission.h b/source/opt/loop_fission.h new file mode 100644 index 000000000..e7a59c185 --- /dev/null +++ b/source/opt/loop_fission.h @@ -0,0 +1,78 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_FISSION_H_ +#define SOURCE_OPT_LOOP_FISSION_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/cfg.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_utils.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" +#include "source/opt/tree_iterator.h" + +namespace spvtools { +namespace opt { + +class LoopFissionPass : public Pass { + public: + // Fuction used to determine if a given loop should be split. Takes register + // pressure region for that loop as a parameter and returns true if the loop + // should be split. + using FissionCriteriaFunction = + std::function; + + // Pass built with this constructor will split all loops regardless of + // register pressure. Will not split loops more than once. + LoopFissionPass(); + + // Split the loop if the number of registers used in the loop exceeds + // |register_threshold_to_split|. |split_multiple_times| flag determines + // whether or not the pass should split loops after already splitting them + // once. + LoopFissionPass(size_t register_threshold_to_split, + bool split_multiple_times = true); + + // Split loops whose register pressure meets the criteria of |functor|. + LoopFissionPass(FissionCriteriaFunction functor, + bool split_multiple_times = true) + : split_criteria_(functor), split_multiple_times_(split_multiple_times) {} + + const char* name() const override { return "loop-fission"; } + + Pass::Status Process() override; + + // Checks if |loop| meets the register pressure criteria to be split. 
+ bool ShouldSplitLoop(const Loop& loop, IRContext* context); + + private: + // Functor to run in ShouldSplitLoop to determine if the register pressure + // criteria is met for splitting the loop. + FissionCriteriaFunction split_criteria_; + + // Flag designating whether or not we should also split the result of + // previously split loops if they meet the register presure criteria. + bool split_multiple_times_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_FISSION_H_ diff --git a/source/opt/loop_fusion.cpp b/source/opt/loop_fusion.cpp new file mode 100644 index 000000000..07d171a0a --- /dev/null +++ b/source/opt/loop_fusion.cpp @@ -0,0 +1,730 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_fusion.h" + +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" + +namespace spvtools { +namespace opt { + +namespace { + +// Append all the loops nested in |loop| to |loops|. +void CollectChildren(Loop* loop, std::vector* loops) { + for (auto child : *loop) { + loops->push_back(child); + if (child->NumImmediateChildren() != 0) { + CollectChildren(child, loops); + } + } +} + +// Return the set of locations accessed by |stores| and |loads|. 
+std::set GetLocationsAccessed( + const std::map>& stores, + const std::map>& loads) { + std::set locations{}; + + for (const auto& kv : stores) { + locations.insert(std::get<0>(kv)); + } + + for (const auto& kv : loads) { + locations.insert(std::get<0>(kv)); + } + + return locations; +} + +// Append all dependences from |sources| to |destinations| to |dependences|. +void GetDependences(std::vector* dependences, + LoopDependenceAnalysis* analysis, + const std::vector& sources, + const std::vector& destinations, + size_t num_entries) { + for (auto source : sources) { + for (auto destination : destinations) { + DistanceVector dist(num_entries); + if (!analysis->GetDependence(source, destination, &dist)) { + dependences->push_back(dist); + } + } + } +} + +// Apped all instructions in |block| to |instructions|. +void AddInstructionsInBlock(std::vector* instructions, + BasicBlock* block) { + for (auto& inst : *block) { + instructions->push_back(&inst); + } + + instructions->push_back(block->GetLabelInst()); +} + +} // namespace + +bool LoopFusion::UsedInContinueOrConditionBlock(Instruction* phi_instruction, + Loop* loop) { + auto condition_block = loop->FindConditionBlock()->id(); + auto continue_block = loop->GetContinueBlock()->id(); + auto not_used = context_->get_def_use_mgr()->WhileEachUser( + phi_instruction, + [this, condition_block, continue_block](Instruction* instruction) { + auto block_id = context_->get_instr_block(instruction)->id(); + return block_id != condition_block && block_id != continue_block; + }); + + return !not_used; +} + +void LoopFusion::RemoveIfNotUsedContinueOrConditionBlock( + std::vector* instructions, Loop* loop) { + instructions->erase( + std::remove_if(std::begin(*instructions), std::end(*instructions), + [this, loop](Instruction* instruction) { + return !UsedInContinueOrConditionBlock(instruction, + loop); + }), + std::end(*instructions)); +} + +bool LoopFusion::AreCompatible() { + // Check that the loops are in the same function. 
+ if (loop_0_->GetHeaderBlock()->GetParent() != + loop_1_->GetHeaderBlock()->GetParent()) { + return false; + } + + // Check that both loops have pre-header blocks. + if (!loop_0_->GetPreHeaderBlock() || !loop_1_->GetPreHeaderBlock()) { + return false; + } + + // Check there are no breaks. + if (context_->cfg()->preds(loop_0_->GetMergeBlock()->id()).size() != 1 || + context_->cfg()->preds(loop_1_->GetMergeBlock()->id()).size() != 1) { + return false; + } + + // Check there are no continues. + if (context_->cfg()->preds(loop_0_->GetContinueBlock()->id()).size() != 1 || + context_->cfg()->preds(loop_1_->GetContinueBlock()->id()).size() != 1) { + return false; + } + + // |GetInductionVariables| returns all OpPhi in the header. Check that both + // loops have exactly one that is used in the continue and condition blocks. + std::vector inductions_0{}, inductions_1{}; + loop_0_->GetInductionVariables(inductions_0); + RemoveIfNotUsedContinueOrConditionBlock(&inductions_0, loop_0_); + + if (inductions_0.size() != 1) { + return false; + } + + induction_0_ = inductions_0.front(); + + loop_1_->GetInductionVariables(inductions_1); + RemoveIfNotUsedContinueOrConditionBlock(&inductions_1, loop_1_); + + if (inductions_1.size() != 1) { + return false; + } + + induction_1_ = inductions_1.front(); + + if (!CheckInit()) { + return false; + } + + if (!CheckCondition()) { + return false; + } + + if (!CheckStep()) { + return false; + } + + // Check adjacency, |loop_0_| should come just before |loop_1_|. + // There is always at least one block between loops, even if it's empty. + // We'll check at most 2 preceeding blocks. + + auto pre_header_1 = loop_1_->GetPreHeaderBlock(); + + std::vector block_to_check{}; + block_to_check.push_back(pre_header_1); + + if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { + // Follow CFG for one more block. 
+ auto preds = context_->cfg()->preds(pre_header_1->id()); + if (preds.size() == 1) { + auto block = &*containing_function_->FindBlock(preds.front()); + if (block == loop_0_->GetMergeBlock()) { + block_to_check.push_back(block); + } else { + return false; + } + } else { + return false; + } + } + + // Check that the separating blocks are either empty or only contains a store + // to a local variable that is never read (left behind by + // '--eliminate-local-multi-store'). Also allow OpPhi, since the loop could be + // in LCSSA form. + for (auto block : block_to_check) { + for (auto& inst : *block) { + if (inst.opcode() == SpvOpStore) { + // Get the definition of the target to check it's function scope so + // there are no observable side effects. + auto variable = + context_->get_def_use_mgr()->GetDef(inst.GetSingleWordInOperand(0)); + + if (variable->opcode() != SpvOpVariable || + variable->GetSingleWordInOperand(0) != SpvStorageClassFunction) { + return false; + } + + // Check the target is never loaded. 
+ auto is_used = false; + context_->get_def_use_mgr()->ForEachUse( + inst.GetSingleWordInOperand(0), + [&is_used](Instruction* use_inst, uint32_t) { + if (use_inst->opcode() == SpvOpLoad) { + is_used = true; + } + }); + + if (is_used) { + return false; + } + } else if (inst.opcode() == SpvOpPhi) { + if (inst.NumInOperands() != 2) { + return false; + } + } else if (inst.opcode() != SpvOpBranch) { + return false; + } + } + } + + return true; +} // namespace opt + +bool LoopFusion::ContainsBarriersOrFunctionCalls(Loop* loop) { + for (const auto& block : loop->GetBlocks()) { + for (const auto& inst : *containing_function_->FindBlock(block)) { + auto opcode = inst.opcode(); + if (opcode == SpvOpFunctionCall || opcode == SpvOpControlBarrier || + opcode == SpvOpMemoryBarrier || opcode == SpvOpTypeNamedBarrier || + opcode == SpvOpNamedBarrierInitialize || + opcode == SpvOpMemoryNamedBarrier) { + return true; + } + } + } + + return false; +} + +bool LoopFusion::CheckInit() { + int64_t loop_0_init; + if (!loop_0_->GetInductionInitValue(induction_0_, &loop_0_init)) { + return false; + } + + int64_t loop_1_init; + if (!loop_1_->GetInductionInitValue(induction_1_, &loop_1_init)) { + return false; + } + + if (loop_0_init != loop_1_init) { + return false; + } + + return true; +} + +bool LoopFusion::CheckCondition() { + auto condition_0 = loop_0_->GetConditionInst(); + auto condition_1 = loop_1_->GetConditionInst(); + + if (!loop_0_->IsSupportedCondition(condition_0->opcode()) || + !loop_1_->IsSupportedCondition(condition_1->opcode())) { + return false; + } + + if (condition_0->opcode() != condition_1->opcode()) { + return false; + } + + for (uint32_t i = 0; i < condition_0->NumInOperandWords(); ++i) { + auto arg_0 = context_->get_def_use_mgr()->GetDef( + condition_0->GetSingleWordInOperand(i)); + auto arg_1 = context_->get_def_use_mgr()->GetDef( + condition_1->GetSingleWordInOperand(i)); + + if (arg_0 == induction_0_ && arg_1 == induction_1_) { + continue; + } + + if (arg_0 == 
induction_0_ && arg_1 != induction_1_) { + return false; + } + + if (arg_1 == induction_1_ && arg_0 != induction_0_) { + return false; + } + + if (arg_0 != arg_1) { + return false; + } + } + + return true; +} + +bool LoopFusion::CheckStep() { + auto scalar_analysis = context_->GetScalarEvolutionAnalysis(); + SENode* induction_node_0 = scalar_analysis->SimplifyExpression( + scalar_analysis->AnalyzeInstruction(induction_0_)); + if (!induction_node_0->AsSERecurrentNode()) { + return false; + } + + SENode* induction_step_0 = + induction_node_0->AsSERecurrentNode()->GetCoefficient(); + if (!induction_step_0->AsSEConstantNode()) { + return false; + } + + SENode* induction_node_1 = scalar_analysis->SimplifyExpression( + scalar_analysis->AnalyzeInstruction(induction_1_)); + if (!induction_node_1->AsSERecurrentNode()) { + return false; + } + + SENode* induction_step_1 = + induction_node_1->AsSERecurrentNode()->GetCoefficient(); + if (!induction_step_1->AsSEConstantNode()) { + return false; + } + + if (*induction_step_0 != *induction_step_1) { + return false; + } + + return true; +} + +std::map> LoopFusion::LocationToMemOps( + const std::vector& mem_ops) { + std::map> location_map{}; + + for (auto instruction : mem_ops) { + auto access_location = context_->get_def_use_mgr()->GetDef( + instruction->GetSingleWordInOperand(0)); + + while (access_location->opcode() == SpvOpAccessChain) { + access_location = context_->get_def_use_mgr()->GetDef( + access_location->GetSingleWordInOperand(0)); + } + + location_map[access_location].push_back(instruction); + } + + return location_map; +} + +std::pair, std::vector> +LoopFusion::GetLoadsAndStoresInLoop(Loop* loop) { + std::vector loads{}; + std::vector stores{}; + + for (auto block_id : loop->GetBlocks()) { + if (block_id == loop->GetContinueBlock()->id()) { + continue; + } + + for (auto& instruction : *containing_function_->FindBlock(block_id)) { + if (instruction.opcode() == SpvOpLoad) { + loads.push_back(&instruction); + } else if 
(instruction.opcode() == SpvOpStore) { + stores.push_back(&instruction); + } + } + } + + return std::make_pair(loads, stores); +} + +bool LoopFusion::IsUsedInLoop(Instruction* instruction, Loop* loop) { + auto not_used = context_->get_def_use_mgr()->WhileEachUser( + instruction, [this, loop](Instruction* user) { + auto block_id = context_->get_instr_block(user)->id(); + return !loop->IsInsideLoop(block_id); + }); + + return !not_used; +} + +bool LoopFusion::IsLegal() { + assert(AreCompatible() && "Fusion can't be legal, loops are not compatible."); + + // Bail out if there are function calls as they could have side-effects that + // cause dependencies or if there are any barriers. + if (ContainsBarriersOrFunctionCalls(loop_0_) || + ContainsBarriersOrFunctionCalls(loop_1_)) { + return false; + } + + std::vector phi_instructions{}; + loop_0_->GetInductionVariables(phi_instructions); + + // Check no OpPhi in |loop_0_| is used in |loop_1_|. + for (auto phi_instruction : phi_instructions) { + if (IsUsedInLoop(phi_instruction, loop_1_)) { + return false; + } + } + + // Check no LCSSA OpPhi in merge block of |loop_0_| is used in |loop_1_|. + auto phi_used = false; + loop_0_->GetMergeBlock()->ForEachPhiInst( + [this, &phi_used](Instruction* phi_instruction) { + phi_used |= IsUsedInLoop(phi_instruction, loop_1_); + }); + + if (phi_used) { + return false; + } + + // Grab loads & stores from both loops. + auto loads_stores_0 = GetLoadsAndStoresInLoop(loop_0_); + auto loads_stores_1 = GetLoadsAndStoresInLoop(loop_1_); + + // Build memory location to operation maps. + auto load_locs_0 = LocationToMemOps(std::get<0>(loads_stores_0)); + auto store_locs_0 = LocationToMemOps(std::get<1>(loads_stores_0)); + + auto load_locs_1 = LocationToMemOps(std::get<0>(loads_stores_1)); + auto store_locs_1 = LocationToMemOps(std::get<1>(loads_stores_1)); + + // Get the locations accessed in both loops. 
+ auto locations_0 = GetLocationsAccessed(store_locs_0, load_locs_0); + auto locations_1 = GetLocationsAccessed(store_locs_1, load_locs_1); + + std::vector potential_clashes{}; + + std::set_intersection(std::begin(locations_0), std::end(locations_0), + std::begin(locations_1), std::end(locations_1), + std::back_inserter(potential_clashes)); + + // If the loops don't access the same variables, the fusion is legal. + if (potential_clashes.empty()) { + return true; + } + + // Find variables that have at least one store. + std::vector potential_clashes_with_stores{}; + for (auto location : potential_clashes) { + if (store_locs_0.find(location) != std::end(store_locs_0) || + store_locs_1.find(location) != std::end(store_locs_1)) { + potential_clashes_with_stores.push_back(location); + } + } + + // If there are only loads to the same variables, the fusion is legal. + if (potential_clashes_with_stores.empty()) { + return true; + } + + // Else if loads and at least one store (across loops) to the same variable + // there is a potential dependence and we need to check the dependence + // distance. + + // Find all the loops in this loop nest for the dependency analysis. + std::vector loops{}; + + // Find the parents. + for (auto current_loop = loop_0_; current_loop != nullptr; + current_loop = current_loop->GetParent()) { + loops.push_back(current_loop); + } + + auto this_loop_position = loops.size() - 1; + std::reverse(std::begin(loops), std::end(loops)); + + // Find the children. + CollectChildren(loop_0_, &loops); + CollectChildren(loop_1_, &loops); + + // Check that any dependes created are legal. That means the fused loops do + // not have any dependencies with dependence distance greater than 0 that did + // not exist in the original loops. 
+ + LoopDependenceAnalysis analysis(context_, loops); + + analysis.GetScalarEvolution()->AddLoopsToPretendAreTheSame( + {loop_0_, loop_1_}); + + for (auto location : potential_clashes_with_stores) { + // Analyse dependences from |loop_0_| to |loop_1_|. + std::vector dependences; + // Read-After-Write. + GetDependences(&dependences, &analysis, store_locs_0[location], + load_locs_1[location], loops.size()); + // Write-After-Read. + GetDependences(&dependences, &analysis, load_locs_0[location], + store_locs_1[location], loops.size()); + // Write-After-Write. + GetDependences(&dependences, &analysis, store_locs_0[location], + store_locs_1[location], loops.size()); + + // Check that the induction variables either don't appear in the subscripts + // or the dependence distance is negative. + for (const auto& dependence : dependences) { + const auto& entry = dependence.GetEntries()[this_loop_position]; + if ((entry.dependence_information == + DistanceEntry::DependenceInformation::DISTANCE && + entry.distance < 1) || + (entry.dependence_information == + DistanceEntry::DependenceInformation::IRRELEVANT)) { + continue; + } else { + return false; + } + } + } + + return true; +} + +void ReplacePhiParentWith(Instruction* inst, uint32_t orig_block, + uint32_t new_block) { + if (inst->GetSingleWordInOperand(1) == orig_block) { + inst->SetInOperand(1, {new_block}); + } else { + inst->SetInOperand(3, {new_block}); + } +} + +void LoopFusion::Fuse() { + assert(AreCompatible() && "Can't fuse, loops aren't compatible"); + assert(IsLegal() && "Can't fuse, illegal"); + + // Save the pointers/ids, won't be found in the middle of doing modifications. + auto header_1 = loop_1_->GetHeaderBlock()->id(); + auto condition_1 = loop_1_->FindConditionBlock()->id(); + auto continue_1 = loop_1_->GetContinueBlock()->id(); + auto continue_0 = loop_0_->GetContinueBlock()->id(); + auto condition_block_of_0 = loop_0_->FindConditionBlock(); + + // Find the blocks whose branches need updating. 
+ auto first_block_of_1 = &*(++containing_function_->FindBlock(condition_1)); + auto last_block_of_1 = &*(--containing_function_->FindBlock(continue_1)); + auto last_block_of_0 = &*(--containing_function_->FindBlock(continue_0)); + + // Update the branch for |last_block_of_loop_0| to go to |first_block_of_1|. + last_block_of_0->ForEachSuccessorLabel( + [first_block_of_1](uint32_t* succ) { *succ = first_block_of_1->id(); }); + + // Update the branch for the |last_block_of_loop_1| to go to the continue + // block of |loop_0_|. + last_block_of_1->ForEachSuccessorLabel( + [this](uint32_t* succ) { *succ = loop_0_->GetContinueBlock()->id(); }); + + // Update merge block id in the header of |loop_0_| to the merge block of + // |loop_1_|. + loop_0_->GetHeaderBlock()->ForEachInst([this](Instruction* inst) { + if (inst->opcode() == SpvOpLoopMerge) { + inst->SetInOperand(0, {loop_1_->GetMergeBlock()->id()}); + } + }); + + // Update condition branch target in |loop_0_| to the merge block of + // |loop_1_|. + condition_block_of_0->ForEachInst([this](Instruction* inst) { + if (inst->opcode() == SpvOpBranchConditional) { + auto loop_0_merge_block_id = loop_0_->GetMergeBlock()->id(); + + if (inst->GetSingleWordInOperand(1) == loop_0_merge_block_id) { + inst->SetInOperand(1, {loop_1_->GetMergeBlock()->id()}); + } else { + inst->SetInOperand(2, {loop_1_->GetMergeBlock()->id()}); + } + } + }); + + // Move OpPhi instructions not corresponding to the induction variable from + // the header of |loop_1_| to the header of |loop_0_|. + std::vector instructions_to_move{}; + for (auto& instruction : *loop_1_->GetHeaderBlock()) { + if (instruction.opcode() == SpvOpPhi && &instruction != induction_1_) { + instructions_to_move.push_back(&instruction); + } + } + + for (auto& it : instructions_to_move) { + it->RemoveFromList(); + it->InsertBefore(induction_0_); + } + + // Update the OpPhi parents to the correct blocks in |loop_0_|. 
+ loop_0_->GetHeaderBlock()->ForEachPhiInst([this](Instruction* i) { + ReplacePhiParentWith(i, loop_1_->GetPreHeaderBlock()->id(), + loop_0_->GetPreHeaderBlock()->id()); + + ReplacePhiParentWith(i, loop_1_->GetContinueBlock()->id(), + loop_0_->GetContinueBlock()->id()); + }); + + // Update instruction to block mapping & DefUseManager. + for (auto& phi_instruction : instructions_to_move) { + context_->set_instr_block(phi_instruction, loop_0_->GetHeaderBlock()); + context_->get_def_use_mgr()->AnalyzeInstUse(phi_instruction); + } + + // Replace the uses of the induction variable of |loop_1_| with that the + // induction variable of |loop_0_|. + context_->ReplaceAllUsesWith(induction_1_->result_id(), + induction_0_->result_id()); + + // Replace LCSSA OpPhi in merge block of |loop_0_|. + loop_0_->GetMergeBlock()->ForEachPhiInst([this](Instruction* instruction) { + context_->ReplaceAllUsesWith(instruction->result_id(), + instruction->GetSingleWordInOperand(0)); + }); + + // Update LCSSA OpPhi in merge block of |loop_1_|. + loop_1_->GetMergeBlock()->ForEachPhiInst( + [condition_block_of_0](Instruction* instruction) { + instruction->SetInOperand(1, {condition_block_of_0->id()}); + }); + + // Move the continue block of |loop_0_| after the last block of |loop_1_|. + containing_function_->MoveBasicBlockToAfter(continue_0, last_block_of_1); + + // Gather all instructions to be killed from |loop_1_| (induction variable + // initialisation, header, condition and continue blocks). + std::vector instr_to_delete{}; + AddInstructionsInBlock(&instr_to_delete, loop_1_->GetPreHeaderBlock()); + AddInstructionsInBlock(&instr_to_delete, loop_1_->GetHeaderBlock()); + AddInstructionsInBlock(&instr_to_delete, loop_1_->FindConditionBlock()); + AddInstructionsInBlock(&instr_to_delete, loop_1_->GetContinueBlock()); + + // There was an additional empty block between the loops, kill that too. 
+ if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { + AddInstructionsInBlock(&instr_to_delete, loop_0_->GetMergeBlock()); + } + + // Update the CFG, so it wouldn't need invalidating. + auto cfg = context_->cfg(); + + cfg->ForgetBlock(loop_1_->GetPreHeaderBlock()); + cfg->ForgetBlock(loop_1_->GetHeaderBlock()); + cfg->ForgetBlock(loop_1_->FindConditionBlock()); + cfg->ForgetBlock(loop_1_->GetContinueBlock()); + + if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { + cfg->ForgetBlock(loop_0_->GetMergeBlock()); + } + + cfg->RemoveEdge(last_block_of_0->id(), loop_0_->GetContinueBlock()->id()); + cfg->AddEdge(last_block_of_0->id(), first_block_of_1->id()); + + cfg->AddEdge(last_block_of_1->id(), loop_0_->GetContinueBlock()->id()); + + cfg->AddEdge(loop_0_->GetContinueBlock()->id(), + loop_1_->GetHeaderBlock()->id()); + + cfg->AddEdge(condition_block_of_0->id(), loop_1_->GetMergeBlock()->id()); + + // Update DefUseManager. + auto def_use_mgr = context_->get_def_use_mgr(); + + // Uses of labels that are in updated branches need analysing. + def_use_mgr->AnalyzeInstUse(last_block_of_0->terminator()); + def_use_mgr->AnalyzeInstUse(last_block_of_1->terminator()); + def_use_mgr->AnalyzeInstUse(loop_0_->GetHeaderBlock()->GetLoopMergeInst()); + def_use_mgr->AnalyzeInstUse(condition_block_of_0->terminator()); + + // Update the LoopDescriptor, so it wouldn't need invalidating. + auto ld = context_->GetLoopDescriptor(containing_function_); + + // Create a copy, so the iterator wouldn't be invalidated. 
+ std::vector loops_to_add_remove{}; + for (auto child_loop : *loop_1_) { + loops_to_add_remove.push_back(child_loop); + } + + for (auto child_loop : loops_to_add_remove) { + loop_1_->RemoveChildLoop(child_loop); + loop_0_->AddNestedLoop(child_loop); + } + + auto loop_1_blocks = loop_1_->GetBlocks(); + + for (auto block : loop_1_blocks) { + loop_1_->RemoveBasicBlock(block); + if (block != header_1 && block != condition_1 && block != continue_1) { + loop_0_->AddBasicBlock(block); + if ((*ld)[block] == loop_1_) { + ld->SetBasicBlockToLoop(block, loop_0_); + } + } + + if ((*ld)[block] == loop_1_) { + ld->ForgetBasicBlock(block); + } + } + + loop_1_->RemoveBasicBlock(loop_1_->GetPreHeaderBlock()->id()); + ld->ForgetBasicBlock(loop_1_->GetPreHeaderBlock()->id()); + + if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { + loop_0_->RemoveBasicBlock(loop_0_->GetMergeBlock()->id()); + ld->ForgetBasicBlock(loop_0_->GetMergeBlock()->id()); + } + + loop_0_->SetMergeBlock(loop_1_->GetMergeBlock()); + + loop_1_->ClearBlocks(); + + ld->RemoveLoop(loop_1_); + + // Kill unnecessary instructions and remove all empty blocks. + for (auto inst : instr_to_delete) { + context_->KillInst(inst); + } + + containing_function_->RemoveEmptyBlocks(); + + // Invalidate analyses. + context_->InvalidateAnalysesExceptFor( + IRContext::Analysis::kAnalysisInstrToBlockMapping | + IRContext::Analysis::kAnalysisLoopAnalysis | + IRContext::Analysis::kAnalysisDefUse | IRContext::Analysis::kAnalysisCFG); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_fusion.h b/source/opt/loop_fusion.h new file mode 100644 index 000000000..d61d6783c --- /dev/null +++ b/source/opt/loop_fusion.h @@ -0,0 +1,114 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_FUSION_H_ +#define SOURCE_OPT_LOOP_FUSION_H_ + +#include +#include +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_utils.h" +#include "source/opt/scalar_analysis.h" + +namespace spvtools { +namespace opt { + +class LoopFusion { + public: + LoopFusion(IRContext* context, Loop* loop_0, Loop* loop_1) + : context_(context), + loop_0_(loop_0), + loop_1_(loop_1), + containing_function_(loop_0->GetHeaderBlock()->GetParent()) {} + + // Checks if the |loop_0| and |loop_1| are compatible for fusion. + // That means: + // * they both have one induction variable + // * they have the same upper and lower bounds + // - same initial value + // - same condition + // * they have the same update step + // * they are adjacent, with |loop_0| appearing before |loop_1| + // * there are no break/continue in either of them + // * they both have pre-header blocks (required for ScalarEvolutionAnalysis + // and dependence checking). + bool AreCompatible(); + + // Checks if compatible |loop_0| and |loop_1| are legal to fuse. + // * fused loops do not have any dependencies with dependence distance greater + // than 0 that did not exist in the original loops. + // * there are no function calls in the loops (could have side-effects) + bool IsLegal(); + + // Perform the actual fusion of |loop_0_| and |loop_1_|. The loops have to be + // compatible and the fusion has to be legal. + void Fuse(); + + private: + // Check that the initial values are the same.
+ bool CheckInit(); + + // Check that the conditions are the same. + bool CheckCondition(); + + // Check that the steps are the same. + bool CheckStep(); + + // Returns |true| if |instruction| is used in the continue or condition block + // of |loop|. + bool UsedInContinueOrConditionBlock(Instruction* instruction, Loop* loop); + + // Remove entries in |instructions| that are not used in the continue or + // condition block of |loop|. + void RemoveIfNotUsedContinueOrConditionBlock( + std::vector* instructions, Loop* loop); + + // Returns |true| if |instruction| is used in |loop|. + bool IsUsedInLoop(Instruction* instruction, Loop* loop); + + // Returns |true| if |loop| has at least one barrier or function call. + bool ContainsBarriersOrFunctionCalls(Loop* loop); + + // Get all load and store instructions in the |loop| (except in the latch + // block), returned as a pair of vectors. + std::pair, std::vector> + GetLoadsAndStoresInLoop(Loop* loop); + + // Given a vector of memory operations (OpLoad/OpStore), constructs a map from + // variables to the loads/stores that access those variables. + std::map> LocationToMemOps( + const std::vector& mem_ops); + + IRContext* context_; + + // The original loops to be fused. + Loop* loop_0_; + Loop* loop_1_; + + // The function that contains |loop_0_| and |loop_1_|. + Function* containing_function_ = nullptr; + + // The induction variables for |loop_0_| and |loop_1_|. + Instruction* induction_0_ = nullptr; + Instruction* induction_1_ = nullptr; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_FUSION_H_ diff --git a/source/opt/loop_fusion_pass.cpp b/source/opt/loop_fusion_pass.cpp new file mode 100644 index 000000000..bd8444ae5 --- /dev/null +++ b/source/opt/loop_fusion_pass.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_fusion_pass.h" + +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_fusion.h" +#include "source/opt/register_pressure.h" + +namespace spvtools { +namespace opt { + +Pass::Status LoopFusionPass::Process() { + bool modified = false; + Module* module = context()->module(); + + // Process each function in the module + for (Function& f : *module) { + modified |= ProcessFunction(&f); + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool LoopFusionPass::ProcessFunction(Function* function) { + LoopDescriptor& ld = *context()->GetLoopDescriptor(function); + + // If a loop doesn't have a preheader then it needs to be created. Make + // sure to return Status::SuccessWithChange in that case. + auto modified = ld.CreatePreHeaderBlocksIfMissing(); + + // TODO(tremmelg): Could the only loop that |loop| could possibly be fused be + // picked out so don't have to check every loop + for (auto& loop_0 : ld) { + for (auto& loop_1 : ld) { + LoopFusion fusion(context(), &loop_0, &loop_1); + + if (fusion.AreCompatible() && fusion.IsLegal()) { + RegisterLiveness liveness(context(), function); + RegisterLiveness::RegionRegisterLiveness reg_pressure{}; + liveness.SimulateFusion(loop_0, loop_1, &reg_pressure); + + if (reg_pressure.used_registers_ <= max_registers_per_loop_) { + fusion.Fuse(); + // Recurse, as the current iterators will have been invalidated.
+ ProcessFunction(function); + return true; + } + } + } + } + + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_fusion_pass.h b/source/opt/loop_fusion_pass.h new file mode 100644 index 000000000..3a0be6000 --- /dev/null +++ b/source/opt/loop_fusion_pass.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_FUSION_PASS_H_ +#define SOURCE_OPT_LOOP_FUSION_PASS_H_ + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// Implements a loop fusion pass. +// This pass will look for adjacent loops that are compatible and legal to be +// fused. It will fuse all such loops as long as the register usage for the +// fused loop stays under the threshold defined by |max_registers_per_loop|. +class LoopFusionPass : public Pass { + public: + explicit LoopFusionPass(size_t max_registers_per_loop) + : Pass(), max_registers_per_loop_(max_registers_per_loop) {} + + const char* name() const override { return "loop-fusion"; } + + // Processes the given |module|. Returns Status::Failure if errors occur when + // processing. Returns the corresponding Status::Success if processing is + // successful to indicate whether changes have been made to the module. + Status Process() override; + + private: + // Fuse loops in |function| if compatible, legal and the fused loop won't use + // too many registers.
+ bool ProcessFunction(Function* function); + + // The maximum number of registers a fused loop is allowed to use. + size_t max_registers_per_loop_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_FUSION_PASS_H_ diff --git a/source/opt/loop_peeling.cpp b/source/opt/loop_peeling.cpp new file mode 100644 index 000000000..b640542d3 --- /dev/null +++ b/source/opt/loop_peeling.cpp @@ -0,0 +1,1086 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_peeling.h" +#include "source/opt/loop_utils.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/scalar_analysis_nodes.h" + +namespace spvtools { +namespace opt { +size_t LoopPeelingPass::code_grow_threshold_ = 1000; + +void LoopPeeling::DuplicateAndConnectLoop( + LoopUtils::LoopCloningResult* clone_results) { + CFG& cfg = *context_->cfg(); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + assert(CanPeelLoop() && "Cannot peel loop!"); + + std::vector ordered_loop_blocks; + // TODO(1841): Handle failure to create pre-header. 
+ BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock(); + + loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks); + + cloned_loop_ = loop_utils_.CloneLoop(clone_results, ordered_loop_blocks); + + // Add the basic block to the function. + Function::iterator it = + loop_utils_.GetFunction()->FindBlock(pre_header->id()); + assert(it != loop_utils_.GetFunction()->end() && + "Pre-header not found in the function."); + loop_utils_.GetFunction()->AddBasicBlocks( + clone_results->cloned_bb_.begin(), clone_results->cloned_bb_.end(), ++it); + + // Make the |loop_|'s preheader the |cloned_loop_| one. + BasicBlock* cloned_header = cloned_loop_->GetHeaderBlock(); + pre_header->ForEachSuccessorLabel( + [cloned_header](uint32_t* succ) { *succ = cloned_header->id(); }); + + // Update cfg. + cfg.RemoveEdge(pre_header->id(), loop_->GetHeaderBlock()->id()); + cloned_loop_->SetPreHeaderBlock(pre_header); + loop_->SetPreHeaderBlock(nullptr); + + // When cloning the loop, we didn't cloned the merge block, so currently + // |cloned_loop_| shares the same block as |loop_|. + // We mutate all branches from |cloned_loop_| block to |loop_|'s merge into a + // branch to |loop_|'s header (so header will also be the merge of + // |cloned_loop_|). + uint32_t cloned_loop_exit = 0; + for (uint32_t pred_id : cfg.preds(loop_->GetMergeBlock()->id())) { + if (loop_->IsInsideLoop(pred_id)) continue; + BasicBlock* bb = cfg.block(pred_id); + assert(cloned_loop_exit == 0 && "The loop has multiple exits."); + cloned_loop_exit = bb->id(); + bb->ForEachSuccessorLabel([this](uint32_t* succ) { + if (*succ == loop_->GetMergeBlock()->id()) + *succ = loop_->GetHeaderBlock()->id(); + }); + } + + // Update cfg. 
+ cfg.RemoveNonExistingEdges(loop_->GetMergeBlock()->id()); + cfg.AddEdge(cloned_loop_exit, loop_->GetHeaderBlock()->id()); + + // Patch the phi of the original loop header: + // - Set the loop entry branch to come from the cloned loop exit block; + // - Set the initial value of the phi using the corresponding cloned loop + // exit values. + // + // We patch the iterating value initializers of the original loop using the + // corresponding cloned loop exit values. Connects the cloned loop iterating + // values to the original loop. This make sure that the initial value of the + // second loop starts with the last value of the first loop. + // + // For example, loops like: + // + // int z = 0; + // for (int i = 0; i++ < M; i += cst1) { + // if (cond) + // z += cst2; + // } + // + // Will become: + // + // int z = 0; + // int i = 0; + // for (; i++ < M; i += cst1) { + // if (cond) + // z += cst2; + // } + // for (; i++ < M; i += cst1) { + // if (cond) + // z += cst2; + // } + loop_->GetHeaderBlock()->ForEachPhiInst([cloned_loop_exit, def_use_mgr, + clone_results, + this](Instruction* phi) { + for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { + if (!loop_->IsInsideLoop(phi->GetSingleWordInOperand(i + 1))) { + phi->SetInOperand(i, + {clone_results->value_map_.at( + exit_value_.at(phi->result_id())->result_id())}); + phi->SetInOperand(i + 1, {cloned_loop_exit}); + def_use_mgr->AnalyzeInstUse(phi); + return; + } + } + }); + + // Force the creation of a new preheader for the original loop and set it as + // the merge block for the cloned loop. + // TODO(1841): Handle failure to create pre-header. 
+ cloned_loop_->SetMergeBlock(loop_->GetOrCreatePreHeaderBlock()); +} + +void LoopPeeling::InsertCanonicalInductionVariable( + LoopUtils::LoopCloningResult* clone_results) { + if (original_loop_canonical_induction_variable_) { + canonical_induction_variable_ = + context_->get_def_use_mgr()->GetDef(clone_results->value_map_.at( + original_loop_canonical_induction_variable_->result_id())); + return; + } + + BasicBlock::iterator insert_point = GetClonedLoop()->GetLatchBlock()->tail(); + if (GetClonedLoop()->GetLatchBlock()->GetMergeInst()) { + --insert_point; + } + InstructionBuilder builder( + context_, &*insert_point, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* uint_1_cst = + builder.GetIntConstant(1, int_type_->IsSigned()); + // Create the increment. + // Note that we do "1 + 1" here, one of the operands should be the phi + // value but we don't have it yet. The operand will be set later. + Instruction* iv_inc = builder.AddIAdd( + uint_1_cst->type_id(), uint_1_cst->result_id(), uint_1_cst->result_id()); + + builder.SetInsertPoint(&*GetClonedLoop()->GetHeaderBlock()->begin()); + + canonical_induction_variable_ = builder.AddPhi( + uint_1_cst->type_id(), + {builder.GetIntConstant(0, int_type_->IsSigned())->result_id(), + GetClonedLoop()->GetPreHeaderBlock()->id(), iv_inc->result_id(), + GetClonedLoop()->GetLatchBlock()->id()}); + // Connect everything. + iv_inc->SetInOperand(0, {canonical_induction_variable_->result_id()}); + + // Update def/use manager. + context_->get_def_use_mgr()->AnalyzeInstUse(iv_inc); + + // If do-while form, use the incremented value.
+ if (do_while_form_) { + canonical_induction_variable_ = iv_inc; + } +} + +void LoopPeeling::GetIteratorUpdateOperations( + const Loop* loop, Instruction* iterator, + std::unordered_set* operations) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + operations->insert(iterator); + iterator->ForEachInId([def_use_mgr, loop, operations, this](uint32_t* id) { + Instruction* insn = def_use_mgr->GetDef(*id); + if (insn->opcode() == SpvOpLabel) { + return; + } + if (operations->count(insn)) { + return; + } + if (!loop->IsInsideLoop(insn)) { + return; + } + GetIteratorUpdateOperations(loop, insn, operations); + }); +} + +// Gather the set of blocks for all the path from |entry| to |root|. +static void GetBlocksInPath(uint32_t block, uint32_t entry, + std::unordered_set* blocks_in_path, + const CFG& cfg) { + for (uint32_t pid : cfg.preds(block)) { + if (blocks_in_path->insert(pid).second) { + if (pid != entry) { + GetBlocksInPath(pid, entry, blocks_in_path, cfg); + } + } + } +} + +bool LoopPeeling::IsConditionCheckSideEffectFree() const { + CFG& cfg = *context_->cfg(); + + // The "do-while" form does not cause issues, the algorithm takes into account + // the first iteration. 
+ if (!do_while_form_) { + uint32_t condition_block_id = cfg.preds(loop_->GetMergeBlock()->id())[0]; + + std::unordered_set blocks_in_path; + + blocks_in_path.insert(condition_block_id); + GetBlocksInPath(condition_block_id, loop_->GetHeaderBlock()->id(), + &blocks_in_path, cfg); + + for (uint32_t bb_id : blocks_in_path) { + BasicBlock* bb = cfg.block(bb_id); + if (!bb->WhileEachInst([this](Instruction* insn) { + if (insn->IsBranch()) return true; + switch (insn->opcode()) { + case SpvOpLabel: + case SpvOpSelectionMerge: + case SpvOpLoopMerge: + return true; + default: + break; + } + return context_->IsCombinatorInstruction(insn); + })) { + return false; + } + } + } + + return true; +} + +void LoopPeeling::GetIteratingExitValues() { + CFG& cfg = *context_->cfg(); + + loop_->GetHeaderBlock()->ForEachPhiInst( + [this](Instruction* phi) { exit_value_[phi->result_id()] = nullptr; }); + + if (!loop_->GetMergeBlock()) { + return; + } + if (cfg.preds(loop_->GetMergeBlock()->id()).size() != 1) { + return; + } + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + uint32_t condition_block_id = cfg.preds(loop_->GetMergeBlock()->id())[0]; + + auto& header_pred = cfg.preds(loop_->GetHeaderBlock()->id()); + do_while_form_ = std::find(header_pred.begin(), header_pred.end(), + condition_block_id) != header_pred.end(); + if (do_while_form_) { + loop_->GetHeaderBlock()->ForEachPhiInst( + [condition_block_id, def_use_mgr, this](Instruction* phi) { + std::unordered_set operations; + + for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { + if (condition_block_id == phi->GetSingleWordInOperand(i + 1)) { + exit_value_[phi->result_id()] = + def_use_mgr->GetDef(phi->GetSingleWordInOperand(i)); + } + } + }); + } else { + DominatorTree* dom_tree = + &context_->GetDominatorAnalysis(loop_utils_.GetFunction()) + ->GetDomTree(); + BasicBlock* condition_block = cfg.block(condition_block_id); + + loop_->GetHeaderBlock()->ForEachPhiInst( + [dom_tree, condition_block, 
this](Instruction* phi) { + std::unordered_set operations; + + // Not the back-edge value, check if the phi instruction is the only + // possible candidate. + GetIteratorUpdateOperations(loop_, phi, &operations); + + for (Instruction* insn : operations) { + if (insn == phi) { + continue; + } + if (dom_tree->Dominates(context_->get_instr_block(insn), + condition_block)) { + return; + } + } + exit_value_[phi->result_id()] = phi; + }); + } +} + +void LoopPeeling::FixExitCondition( + const std::function& condition_builder) { + CFG& cfg = *context_->cfg(); + + uint32_t condition_block_id = 0; + for (uint32_t id : cfg.preds(GetClonedLoop()->GetMergeBlock()->id())) { + if (GetClonedLoop()->IsInsideLoop(id)) { + condition_block_id = id; + break; + } + } + assert(condition_block_id != 0 && "2nd loop in improperly connected"); + + BasicBlock* condition_block = cfg.block(condition_block_id); + Instruction* exit_condition = condition_block->terminator(); + assert(exit_condition->opcode() == SpvOpBranchConditional); + BasicBlock::iterator insert_point = condition_block->tail(); + if (condition_block->GetMergeInst()) { + --insert_point; + } + + exit_condition->SetInOperand(0, {condition_builder(&*insert_point)}); + + uint32_t to_continue_block_idx = + GetClonedLoop()->IsInsideLoop(exit_condition->GetSingleWordInOperand(1)) + ? 1 + : 2; + exit_condition->SetInOperand( + 1, {exit_condition->GetSingleWordInOperand(to_continue_block_idx)}); + exit_condition->SetInOperand(2, {GetClonedLoop()->GetMergeBlock()->id()}); + + // Update def/use manager. + context_->get_def_use_mgr()->AnalyzeInstUse(exit_condition); +} + +BasicBlock* LoopPeeling::CreateBlockBefore(BasicBlock* bb) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + CFG& cfg = *context_->cfg(); + assert(cfg.preds(bb->id()).size() == 1 && "More than one predecessor"); + + // TODO(1841): Handle id overflow. 
+ std::unique_ptr new_bb = + MakeUnique(std::unique_ptr(new Instruction( + context_, SpvOpLabel, 0, context_->TakeNextId(), {}))); + new_bb->SetParent(loop_utils_.GetFunction()); + // Update the loop descriptor. + Loop* in_loop = (*loop_utils_.GetLoopDescriptor())[bb]; + if (in_loop) { + in_loop->AddBasicBlock(new_bb.get()); + loop_utils_.GetLoopDescriptor()->SetBasicBlockToLoop(new_bb->id(), in_loop); + } + + context_->set_instr_block(new_bb->GetLabelInst(), new_bb.get()); + def_use_mgr->AnalyzeInstDefUse(new_bb->GetLabelInst()); + + BasicBlock* bb_pred = cfg.block(cfg.preds(bb->id())[0]); + bb_pred->tail()->ForEachInId([bb, &new_bb](uint32_t* id) { + if (*id == bb->id()) { + *id = new_bb->id(); + } + }); + cfg.RemoveEdge(bb_pred->id(), bb->id()); + cfg.AddEdge(bb_pred->id(), new_bb->id()); + def_use_mgr->AnalyzeInstUse(&*bb_pred->tail()); + + // Update the incoming branch. + bb->ForEachPhiInst([&new_bb, def_use_mgr](Instruction* phi) { + phi->SetInOperand(1, {new_bb->id()}); + def_use_mgr->AnalyzeInstUse(phi); + }); + InstructionBuilder( + context_, new_bb.get(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping) + .AddBranch(bb->id()); + cfg.RegisterBlock(new_bb.get()); + + // Add the basic block to the function. + Function::iterator it = loop_utils_.GetFunction()->FindBlock(bb->id()); + assert(it != loop_utils_.GetFunction()->end() && + "Basic block not found in the function."); + BasicBlock* ret = new_bb.get(); + loop_utils_.GetFunction()->AddBasicBlock(std::move(new_bb), it); + return ret; +} + +BasicBlock* LoopPeeling::ProtectLoop(Loop* loop, Instruction* condition, + BasicBlock* if_merge) { + // TODO(1841): Handle failure to create pre-header. + BasicBlock* if_block = loop->GetOrCreatePreHeaderBlock(); + // Will no longer be a pre-header because of the if. + loop->SetPreHeaderBlock(nullptr); + // Kill the branch to the header. 
+ context_->KillInst(&*if_block->tail()); + + InstructionBuilder builder( + context_, if_block, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + builder.AddConditionalBranch(condition->result_id(), + loop->GetHeaderBlock()->id(), if_merge->id(), + if_merge->id()); + + return if_block; +} + +void LoopPeeling::PeelBefore(uint32_t peel_factor) { + assert(CanPeelLoop() && "Cannot peel loop"); + LoopUtils::LoopCloningResult clone_results; + + // Clone the loop and insert the cloned one before the loop. + DuplicateAndConnectLoop(&clone_results); + + // Add a canonical induction variable "canonical_induction_variable_". + InsertCanonicalInductionVariable(&clone_results); + + InstructionBuilder builder( + context_, &*cloned_loop_->GetPreHeaderBlock()->tail(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* factor = + builder.GetIntConstant(peel_factor, int_type_->IsSigned()); + + Instruction* has_remaining_iteration = builder.AddLessThan( + factor->result_id(), loop_iteration_count_->result_id()); + Instruction* max_iteration = builder.AddSelect( + factor->type_id(), has_remaining_iteration->result_id(), + factor->result_id(), loop_iteration_count_->result_id()); + + // Change the exit condition of the cloned loop to be (exit when become + // false): + // "canonical_induction_variable_" < min("factor", "loop_iteration_count_") + FixExitCondition([max_iteration, this](Instruction* insert_before_point) { + return InstructionBuilder(context_, insert_before_point, + IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping) + .AddLessThan(canonical_induction_variable_->result_id(), + max_iteration->result_id()) + ->result_id(); + }); + + // "Protect" the second loop: the second loop can only be executed if + // |has_remaining_iteration| is true (i.e. factor < loop_iteration_count_). 
+ BasicBlock* if_merge_block = loop_->GetMergeBlock(); + loop_->SetMergeBlock(CreateBlockBefore(loop_->GetMergeBlock())); + // Prevent the second loop from being executed if we already executed all the + // required iterations. + BasicBlock* if_block = + ProtectLoop(loop_, has_remaining_iteration, if_merge_block); + // Patch the phi of the merge block. + if_merge_block->ForEachPhiInst( + [&clone_results, if_block, this](Instruction* phi) { + // if_merge_block had previously only 1 predecessor. + uint32_t incoming_value = phi->GetSingleWordInOperand(0); + auto def_in_loop = clone_results.value_map_.find(incoming_value); + if (def_in_loop != clone_results.value_map_.end()) + incoming_value = def_in_loop->second; + phi->AddOperand( + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {incoming_value}}); + phi->AddOperand( + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {if_block->id()}}); + context_->get_def_use_mgr()->AnalyzeInstUse(phi); + }); + + context_->InvalidateAnalysesExceptFor( + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisCFG); +} + +void LoopPeeling::PeelAfter(uint32_t peel_factor) { + assert(CanPeelLoop() && "Cannot peel loop"); + LoopUtils::LoopCloningResult clone_results; + + // Clone the loop and insert the cloned one before the loop. + DuplicateAndConnectLoop(&clone_results); + + // Add a canonical induction variable "canonical_induction_variable_". 
+ InsertCanonicalInductionVariable(&clone_results); + + InstructionBuilder builder( + context_, &*cloned_loop_->GetPreHeaderBlock()->tail(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* factor = + builder.GetIntConstant(peel_factor, int_type_->IsSigned()); + + Instruction* has_remaining_iteration = builder.AddLessThan( + factor->result_id(), loop_iteration_count_->result_id()); + + // Change the exit condition of the cloned loop to be (exit when become + // false): + // "canonical_induction_variable_" + "factor" < "loop_iteration_count_" + FixExitCondition([factor, this](Instruction* insert_before_point) { + InstructionBuilder cond_builder( + context_, insert_before_point, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + // Build the following check: canonical_induction_variable_ + factor < + // iteration_count + return cond_builder + .AddLessThan(cond_builder + .AddIAdd(canonical_induction_variable_->type_id(), + canonical_induction_variable_->result_id(), + factor->result_id()) + ->result_id(), + loop_iteration_count_->result_id()) + ->result_id(); + }); + + // "Protect" the first loop: the first loop can only be executed if + // factor < loop_iteration_count_. + + // The original loop's pre-header was the cloned loop merge block. + GetClonedLoop()->SetMergeBlock( + CreateBlockBefore(GetOriginalLoop()->GetPreHeaderBlock())); + // Use the second loop preheader as if merge block. + + // Prevent the first loop if only the peeled loop needs it. + BasicBlock* if_block = ProtectLoop(cloned_loop_, has_remaining_iteration, + GetOriginalLoop()->GetPreHeaderBlock()); + + // Patch the phi of the header block. + // We added an if to enclose the first loop and because the phi nodes are + // connected to the exit value of the first loop, the definitions no longer + // dominate the preheader. + // We add to the preheader (our if merge block) the required phi instruction + // and patch the header phi.
+ GetOriginalLoop()->GetHeaderBlock()->ForEachPhiInst( + [&clone_results, if_block, this](Instruction* phi) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + auto find_value_idx = [](Instruction* phi_inst, Loop* loop) { + uint32_t preheader_value_idx = + !loop->IsInsideLoop(phi_inst->GetSingleWordInOperand(1)) ? 0 : 2; + return preheader_value_idx; + }; + + Instruction* cloned_phi = + def_use_mgr->GetDef(clone_results.value_map_.at(phi->result_id())); + uint32_t cloned_preheader_value = cloned_phi->GetSingleWordInOperand( + find_value_idx(cloned_phi, GetClonedLoop())); + + Instruction* new_phi = + InstructionBuilder(context_, + &*GetOriginalLoop()->GetPreHeaderBlock()->tail(), + IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping) + .AddPhi(phi->type_id(), + {phi->GetSingleWordInOperand( + find_value_idx(phi, GetOriginalLoop())), + GetClonedLoop()->GetMergeBlock()->id(), + cloned_preheader_value, if_block->id()}); + + phi->SetInOperand(find_value_idx(phi, GetOriginalLoop()), + {new_phi->result_id()}); + def_use_mgr->AnalyzeInstUse(phi); + }); + + context_->InvalidateAnalysesExceptFor( + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisCFG); +} + +Pass::Status LoopPeelingPass::Process() { + bool modified = false; + Module* module = context()->module(); + + // Process each function in the module + for (Function& f : *module) { + modified |= ProcessFunction(&f); + } + + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool LoopPeelingPass::ProcessFunction(Function* f) { + bool modified = false; + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f); + + std::vector to_process_loop; + to_process_loop.reserve(loop_descriptor.NumLoops()); + for (Loop& l : loop_descriptor) { + to_process_loop.push_back(&l); + } + + ScalarEvolutionAnalysis scev_analysis(context()); + + for (Loop* loop : to_process_loop) { + CodeMetrics loop_size; + loop_size.Analyze(*loop); + + auto try_peel = [&loop_size, &modified, this](Loop* loop_to_peel) -> Loop* { + if (!loop_to_peel->IsLCSSA()) { + LoopUtils(context(), loop_to_peel).MakeLoopClosedSSA(); + } + + bool peeled_loop; + Loop* still_peelable_loop; + std::tie(peeled_loop, still_peelable_loop) = + ProcessLoop(loop_to_peel, &loop_size); + + if (peeled_loop) { + modified = true; + } + + return still_peelable_loop; + }; + + Loop* still_peelable_loop = try_peel(loop); + // The pass is working out the maximum factor by which a loop can be peeled. + // If the loop can potentially be peeled again, then there is only one + // possible direction, so only one call is still needed. + if (still_peelable_loop) { + try_peel(loop); + } + } + + return modified; +} + +std::pair LoopPeelingPass::ProcessLoop(Loop* loop, + CodeMetrics* loop_size) { + ScalarEvolutionAnalysis* scev_analysis = + context()->GetScalarEvolutionAnalysis(); + // Default values for bailing out. 
+ std::pair bail_out{false, nullptr}; + + BasicBlock* exit_block = loop->FindConditionBlock(); + if (!exit_block) { + return bail_out; + } + + Instruction* exiting_iv = loop->FindConditionVariable(exit_block); + if (!exiting_iv) { + return bail_out; + } + size_t iterations = 0; + if (!loop->FindNumberOfIterations(exiting_iv, &*exit_block->tail(), + &iterations)) { + return bail_out; + } + if (!iterations) { + return bail_out; + } + + Instruction* canonical_induction_variable = nullptr; + + loop->GetHeaderBlock()->WhileEachPhiInst([&canonical_induction_variable, + scev_analysis, + this](Instruction* insn) { + if (const SERecurrentNode* iv = + scev_analysis->AnalyzeInstruction(insn)->AsSERecurrentNode()) { + const SEConstantNode* offset = iv->GetOffset()->AsSEConstantNode(); + const SEConstantNode* coeff = iv->GetCoefficient()->AsSEConstantNode(); + if (offset && coeff && offset->FoldToSingleValue() == 0 && + coeff->FoldToSingleValue() == 1) { + if (context()->get_type_mgr()->GetType(insn->type_id())->AsInteger()) { + canonical_induction_variable = insn; + return false; + } + } + } + return true; + }); + + bool is_signed = canonical_induction_variable + ? context() + ->get_type_mgr() + ->GetType(canonical_induction_variable->type_id()) + ->AsInteger() + ->IsSigned() + : false; + + LoopPeeling peeler( + loop, + InstructionBuilder( + context(), loop->GetHeaderBlock(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping) + .GetIntConstant(static_cast(iterations), + is_signed), + canonical_induction_variable); + + if (!peeler.CanPeelLoop()) { + return bail_out; + } + + // For each basic block in the loop, check if it can be peeled. If it + // can, get the direction (before/after) and by which factor. 
+ LoopPeelingInfo peel_info(loop, iterations, scev_analysis); + + uint32_t peel_before_factor = 0; + uint32_t peel_after_factor = 0; + + for (uint32_t block : loop->GetBlocks()) { + if (block == exit_block->id()) { + continue; + } + BasicBlock* bb = cfg()->block(block); + PeelDirection direction; + uint32_t factor; + std::tie(direction, factor) = peel_info.GetPeelingInfo(bb); + + if (direction == PeelDirection::kNone) { + continue; + } + if (direction == PeelDirection::kBefore) { + peel_before_factor = std::max(peel_before_factor, factor); + } else { + assert(direction == PeelDirection::kAfter); + peel_after_factor = std::max(peel_after_factor, factor); + } + } + PeelDirection direction = PeelDirection::kNone; + uint32_t factor = 0; + + // Find which direction we should peel. + if (peel_before_factor) { + factor = peel_before_factor; + direction = PeelDirection::kBefore; + } + if (peel_after_factor) { + if (peel_before_factor < peel_after_factor) { + // Favor a peel after here and give the peel before another shot later. + factor = peel_after_factor; + direction = PeelDirection::kAfter; + } + } + + // Do the peel if we can. + if (direction == PeelDirection::kNone) return bail_out; + + // This does not take into account branch elimination opportunities and + // the unrolling. It assumes the peeled loop will be unrolled as well. + if (factor * loop_size->roi_size_ > code_grow_threshold_) { + return bail_out; + } + loop_size->roi_size_ *= factor; + + // Find if a loop should be peeled again. + Loop* extra_opportunity = nullptr; + + if (direction == PeelDirection::kBefore) { + peeler.PeelBefore(factor); + if (stats_) { + stats_->peeled_loops_.emplace_back(loop, PeelDirection::kBefore, factor); + } + if (peel_after_factor) { + // We could have peeled after, give it another try. 
+ extra_opportunity = peeler.GetOriginalLoop(); + } + } else { + peeler.PeelAfter(factor); + if (stats_) { + stats_->peeled_loops_.emplace_back(loop, PeelDirection::kAfter, factor); + } + if (peel_before_factor) { + // We could have peeled before, give it another try. + extra_opportunity = peeler.GetClonedLoop(); + } + } + + return {true, extra_opportunity}; +} + +uint32_t LoopPeelingPass::LoopPeelingInfo::GetFirstLoopInvariantOperand( + Instruction* condition) const { + for (uint32_t i = 0; i < condition->NumInOperands(); i++) { + BasicBlock* bb = + context_->get_instr_block(condition->GetSingleWordInOperand(i)); + if (bb && loop_->IsInsideLoop(bb)) { + return condition->GetSingleWordInOperand(i); + } + } + + return 0; +} + +uint32_t LoopPeelingPass::LoopPeelingInfo::GetFirstNonLoopInvariantOperand( + Instruction* condition) const { + for (uint32_t i = 0; i < condition->NumInOperands(); i++) { + BasicBlock* bb = + context_->get_instr_block(condition->GetSingleWordInOperand(i)); + if (!bb || !loop_->IsInsideLoop(bb)) { + return condition->GetSingleWordInOperand(i); + } + } + + return 0; +} + +static bool IsHandledCondition(SpvOp opcode) { + switch (opcode) { + case SpvOpIEqual: + case SpvOpINotEqual: + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: + case SpvOpULessThan: + case SpvOpSLessThan: + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: + return true; + default: + return false; + } +} + +LoopPeelingPass::LoopPeelingInfo::Direction +LoopPeelingPass::LoopPeelingInfo::GetPeelingInfo(BasicBlock* bb) const { + if (bb->terminator()->opcode() != SpvOpBranchConditional) { + return GetNoneDirection(); + } + + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + Instruction* condition = + def_use_mgr->GetDef(bb->terminator()->GetSingleWordInOperand(0)); + + if (!IsHandledCondition(condition->opcode())) { + return GetNoneDirection(); + } + + if 
(!GetFirstLoopInvariantOperand(condition)) { + // No loop invariant, it cannot be peeled by this pass. + return GetNoneDirection(); + } + if (!GetFirstNonLoopInvariantOperand(condition)) { + // Seems to be a job for the unswitch pass. + return GetNoneDirection(); + } + + // Left hand-side. + SExpression lhs = scev_analysis_->AnalyzeInstruction( + def_use_mgr->GetDef(condition->GetSingleWordInOperand(0))); + if (lhs->GetType() == SENode::CanNotCompute) { + // Can't make any conclusion. + return GetNoneDirection(); + } + + // Right hand-side. + SExpression rhs = scev_analysis_->AnalyzeInstruction( + def_use_mgr->GetDef(condition->GetSingleWordInOperand(1))); + if (rhs->GetType() == SENode::CanNotCompute) { + // Can't make any conclusion. + return GetNoneDirection(); + } + + // Only take into account recurrent expression over the current loop. + bool is_lhs_rec = !scev_analysis_->IsLoopInvariant(loop_, lhs); + bool is_rhs_rec = !scev_analysis_->IsLoopInvariant(loop_, rhs); + + if ((is_lhs_rec && is_rhs_rec) || (!is_lhs_rec && !is_rhs_rec)) { + return GetNoneDirection(); + } + + if (is_lhs_rec) { + if (!lhs->AsSERecurrentNode() || + lhs->AsSERecurrentNode()->GetLoop() != loop_) { + return GetNoneDirection(); + } + } + if (is_rhs_rec) { + if (!rhs->AsSERecurrentNode() || + rhs->AsSERecurrentNode()->GetLoop() != loop_) { + return GetNoneDirection(); + } + } + + // If the op code is ==, then we try a peel before or after. + // If opcode is not <, >, <= or >=, we bail out. + // + // For the remaining cases, we canonicalize the expression so that the + // constant expression is on the left hand side and the recurring expression + // is on the right hand side. If we swap hand side, then < becomes >, <= + // becomes >= etc. + // If the opcode is <=, then we add 1 to the right hand side and do the peel + // check on <. + // If the opcode is >=, then we add 1 to the left hand side and do the peel + // check on >. 
+ + CmpOperator cmp_operator; + switch (condition->opcode()) { + default: + return GetNoneDirection(); + case SpvOpIEqual: + case SpvOpINotEqual: + return HandleEquality(lhs, rhs); + case SpvOpUGreaterThan: + case SpvOpSGreaterThan: { + cmp_operator = CmpOperator::kGT; + break; + } + case SpvOpULessThan: + case SpvOpSLessThan: { + cmp_operator = CmpOperator::kLT; + break; + } + // We add one to transform >= into > and <= into <. + case SpvOpUGreaterThanEqual: + case SpvOpSGreaterThanEqual: { + cmp_operator = CmpOperator::kGE; + break; + } + case SpvOpULessThanEqual: + case SpvOpSLessThanEqual: { + cmp_operator = CmpOperator::kLE; + break; + } + } + + // Force the left hand side to be the non recurring expression. + if (is_lhs_rec) { + std::swap(lhs, rhs); + switch (cmp_operator) { + case CmpOperator::kLT: { + cmp_operator = CmpOperator::kGT; + break; + } + case CmpOperator::kGT: { + cmp_operator = CmpOperator::kLT; + break; + } + case CmpOperator::kLE: { + cmp_operator = CmpOperator::kGE; + break; + } + case CmpOperator::kGE: { + cmp_operator = CmpOperator::kLE; + break; + } + } + } + return HandleInequality(cmp_operator, lhs, rhs->AsSERecurrentNode()); +} + +SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtFirstIteration( + SERecurrentNode* rec) const { + return rec->GetOffset(); +} + +SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtIteration( + SERecurrentNode* rec, int64_t iteration) const { + SExpression coeff = rec->GetCoefficient(); + SExpression offset = rec->GetOffset(); + + return (coeff * iteration) + offset; +} + +SExpression LoopPeelingPass::LoopPeelingInfo::GetValueAtLastIteration( + SERecurrentNode* rec) const { + return GetValueAtIteration(rec, loop_max_iterations_ - 1); +} + +bool LoopPeelingPass::LoopPeelingInfo::EvalOperator(CmpOperator cmp_op, + SExpression lhs, + SExpression rhs, + bool* result) const { + assert(scev_analysis_->IsLoopInvariant(loop_, lhs)); + assert(scev_analysis_->IsLoopInvariant(loop_, rhs)); + // We perform the 
test: 0 cmp_op rhs - lhs + // What is left is then to determine the sign of the expression. + switch (cmp_op) { + case CmpOperator::kLT: { + return scev_analysis_->IsAlwaysGreaterThanZero(rhs - lhs, result); + } + case CmpOperator::kGT: { + return scev_analysis_->IsAlwaysGreaterThanZero(lhs - rhs, result); + } + case CmpOperator::kLE: { + return scev_analysis_->IsAlwaysGreaterOrEqualToZero(rhs - lhs, result); + } + case CmpOperator::kGE: { + return scev_analysis_->IsAlwaysGreaterOrEqualToZero(lhs - rhs, result); + } + } + return false; +} + +LoopPeelingPass::LoopPeelingInfo::Direction +LoopPeelingPass::LoopPeelingInfo::HandleEquality(SExpression lhs, + SExpression rhs) const { + { + // Try peel before opportunity. + SExpression lhs_cst = lhs; + if (SERecurrentNode* rec_node = lhs->AsSERecurrentNode()) { + lhs_cst = rec_node->GetOffset(); + } + SExpression rhs_cst = rhs; + if (SERecurrentNode* rec_node = rhs->AsSERecurrentNode()) { + rhs_cst = rec_node->GetOffset(); + } + + if (lhs_cst == rhs_cst) { + return Direction{LoopPeelingPass::PeelDirection::kBefore, 1}; + } + } + + { + // Try peel after opportunity. 
+ SExpression lhs_cst = lhs; + if (SERecurrentNode* rec_node = lhs->AsSERecurrentNode()) { + // rec_node(x) = a * x + b + // assign to lhs: a * (loop_max_iterations_ - 1) + b + lhs_cst = GetValueAtLastIteration(rec_node); + } + SExpression rhs_cst = rhs; + if (SERecurrentNode* rec_node = rhs->AsSERecurrentNode()) { + // rec_node(x) = a * x + b + // assign to lhs: a * (loop_max_iterations_ - 1) + b + rhs_cst = GetValueAtLastIteration(rec_node); + } + + if (lhs_cst == rhs_cst) { + return Direction{LoopPeelingPass::PeelDirection::kAfter, 1}; + } + } + + return GetNoneDirection(); +} + +LoopPeelingPass::LoopPeelingInfo::Direction +LoopPeelingPass::LoopPeelingInfo::HandleInequality(CmpOperator cmp_op, + SExpression lhs, + SERecurrentNode* rhs) const { + SExpression offset = rhs->GetOffset(); + SExpression coefficient = rhs->GetCoefficient(); + // Compute (cst - B) / A. + std::pair flip_iteration = (lhs - offset) / coefficient; + if (!flip_iteration.first->AsSEConstantNode()) { + return GetNoneDirection(); + } + // note: !!flip_iteration.second normalize to 0/1 (via bool cast). + int64_t iteration = + flip_iteration.first->AsSEConstantNode()->FoldToSingleValue() + + !!flip_iteration.second; + if (iteration <= 0 || + loop_max_iterations_ <= static_cast(iteration)) { + // Always true or false within the loop bounds. + return GetNoneDirection(); + } + // If this is a <= or >= operator and the iteration, make sure |iteration| is + // the one flipping the condition. + // If (cst - B) and A are not divisible, this equivalent to a < or > check, so + // we skip this test. + if (!flip_iteration.second && + (cmp_op == CmpOperator::kLE || cmp_op == CmpOperator::kGE)) { + bool first_iteration; + bool current_iteration; + if (!EvalOperator(cmp_op, lhs, offset, &first_iteration) || + !EvalOperator(cmp_op, lhs, GetValueAtIteration(rhs, iteration), + ¤t_iteration)) { + return GetNoneDirection(); + } + // If the condition did not flip the next will. 
+ if (first_iteration == current_iteration) { + iteration++; + } + } + + uint32_t cast_iteration = 0; + // sanity check: can we fit |iteration| in a uint32_t ? + if (static_cast(iteration) < std::numeric_limits::max()) { + cast_iteration = static_cast(iteration); + } + + if (cast_iteration) { + // Peel before if we are closer to the start, after if closer to the end. + if (loop_max_iterations_ / 2 > cast_iteration) { + return Direction{LoopPeelingPass::PeelDirection::kBefore, cast_iteration}; + } else { + return Direction{ + LoopPeelingPass::PeelDirection::kAfter, + static_cast(loop_max_iterations_ - cast_iteration)}; + } + } + + return GetNoneDirection(); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_peeling.h b/source/opt/loop_peeling.h new file mode 100644 index 000000000..413f896f2 --- /dev/null +++ b/source/opt/loop_peeling.h @@ -0,0 +1,336 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_PEELING_H_ +#define SOURCE_OPT_LOOP_PEELING_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" + +namespace spvtools { +namespace opt { + +// Utility class to perform the peeling of a given loop. 
+// The loop peeling transformation make a certain amount of a loop iterations to +// be executed either before (peel before) or after (peel after) the transformed +// loop. +// +// For peeling cases the transformation does the following steps: +// - It clones the loop and inserts the cloned loop before the original loop; +// - It connects all iterating values of the cloned loop with the +// corresponding original loop values so that the second loop starts with +// the appropriate values. +// - It inserts a new induction variable "i" is inserted into the cloned that +// starts with the value 0 and increment by step of one. +// +// The last step is specific to each case: +// - Peel before: the transformation is to peel the "N" first iterations. +// The exit condition of the cloned loop is changed so that the loop +// exits when "i < N" becomes false. The original loop is then protected to +// only execute if there is any iteration left to do. +// - Peel after: the transformation is to peel the "N" last iterations, +// then the exit condition of the cloned loop is changed so that the loop +// exits when "i + N < max_iteration" becomes false, where "max_iteration" +// is the upper bound of the loop. The cloned loop is then protected to +// only execute if there is any iteration left to do no covered by the +// second. +// +// To be peelable: +// - The loop must be in LCSSA form; +// - The loop must not contain any breaks; +// - The loop must not have any ambiguous iterators updates (see +// "CanPeelLoop"). +// The method "CanPeelLoop" checks that those constrained are met. +class LoopPeeling { + public: + // LoopPeeling constructor. + // |loop| is the loop to peel. + // |loop_iteration_count| is the instruction holding the |loop| iteration + // count, must be invariant for |loop| and must be of an int 32 type (signed + // or unsigned). 
+ // |canonical_induction_variable| is an induction variable that can be used to + // count the number of iterations, must be of the same type as + // |loop_iteration_count| and start at 0 and increase by step of one at each + // iteration. The value nullptr is interpreted as no suitable variable exists + // and one will be created. + LoopPeeling(Loop* loop, Instruction* loop_iteration_count, + Instruction* canonical_induction_variable = nullptr) + : context_(loop->GetContext()), + loop_utils_(loop->GetContext(), loop), + loop_(loop), + loop_iteration_count_(!loop->IsInsideLoop(loop_iteration_count) + ? loop_iteration_count + : nullptr), + int_type_(nullptr), + original_loop_canonical_induction_variable_( + canonical_induction_variable), + canonical_induction_variable_(nullptr) { + if (loop_iteration_count_) { + int_type_ = context_->get_type_mgr() + ->GetType(loop_iteration_count_->type_id()) + ->AsInteger(); + if (canonical_induction_variable_) { + assert(canonical_induction_variable_->type_id() == + loop_iteration_count_->type_id() && + "loop_iteration_count and canonical_induction_variable do not " + "have the same type"); + } + } + GetIteratingExitValues(); + } + + // Returns true if the loop can be peeled. + // To be peelable, all operation involved in the update of the loop iterators + // must not dominates the exit condition. This restriction is a work around to + // not miss compile code like: + // + // for (int i = 0; i + 1 < N; i++) {} + // for (int i = 0; ++i < N; i++) {} + // + // The increment will happen before the test on the exit condition leading to + // very look-a-like code. + // + // This restriction will not apply if a loop rotate is applied before (i.e. + // becomes a do-while loop). 
+ bool CanPeelLoop() const { + CFG& cfg = *context_->cfg(); + + if (!loop_iteration_count_) { + return false; + } + if (!int_type_) { + return false; + } + if (int_type_->width() != 32) { + return false; + } + if (!loop_->IsLCSSA()) { + return false; + } + if (!loop_->GetMergeBlock()) { + return false; + } + if (cfg.preds(loop_->GetMergeBlock()->id()).size() != 1) { + return false; + } + if (!IsConditionCheckSideEffectFree()) { + return false; + } + + return !std::any_of(exit_value_.cbegin(), exit_value_.cend(), + [](std::pair it) { + return it.second == nullptr; + }); + } + + // Moves the execution of the |factor| first iterations of the loop into a + // dedicated loop. + void PeelBefore(uint32_t factor); + + // Moves the execution of the |factor| last iterations of the loop into a + // dedicated loop. + void PeelAfter(uint32_t factor); + + // Returns the cloned loop. + Loop* GetClonedLoop() { return cloned_loop_; } + // Returns the original loop. + Loop* GetOriginalLoop() { return loop_; } + + private: + IRContext* context_; + LoopUtils loop_utils_; + // The original loop. + Loop* loop_; + // The initial |loop_| upper bound. + Instruction* loop_iteration_count_; + // The int type to use for the canonical_induction_variable_. + analysis::Integer* int_type_; + // The cloned loop. + Loop* cloned_loop_; + // This is set to true when the exit and back-edge branch instruction is the + // same. + bool do_while_form_; + // The canonical induction variable from the original loop if it exists. + Instruction* original_loop_canonical_induction_variable_; + // The canonical induction variable of the cloned loop. The induction variable + // is initialized to 0 and incremented by step of 1. + Instruction* canonical_induction_variable_; + // Map between loop iterators and exit values. Loop iterators + std::unordered_map exit_value_; + + // Duplicate |loop_| and place the new loop before the cloned loop. 
Iterating + // values from the cloned loop are then connected to the original loop as + // initializer. + void DuplicateAndConnectLoop(LoopUtils::LoopCloningResult* clone_results); + + // Insert the canonical induction variable into the first loop as a simplified + // counter. + void InsertCanonicalInductionVariable( + LoopUtils::LoopCloningResult* clone_results); + + // Fixes the exit condition of the before loop. The function calls + // |condition_builder| to get the condition to use in the conditional branch + // of the loop exit. The loop will be exited if the condition evaluate to + // true. |condition_builder| takes an Instruction* that represent the + // insertion point. + void FixExitCondition( + const std::function& condition_builder); + + // Gathers all operations involved in the update of |iterator| into + // |operations|. + void GetIteratorUpdateOperations( + const Loop* loop, Instruction* iterator, + std::unordered_set* operations); + + // Gathers exiting iterator values. The function builds a map between each + // iterating value in the loop (a phi instruction in the loop header) and its + // SSA value when it exit the loop. If no exit value can be accurately found, + // it is map to nullptr (see comment on CanPeelLoop). + void GetIteratingExitValues(); + + // Returns true if a for-loop has no instruction with effects before the + // condition check. + bool IsConditionCheckSideEffectFree() const; + + // Creates a new basic block and insert it between |bb| and the predecessor of + // |bb|. + BasicBlock* CreateBlockBefore(BasicBlock* bb); + + // Inserts code to only execute |loop| only if the given |condition| is true. + // |if_merge| is a suitable basic block to be used by the if condition as + // merge block. + // The function returns the if block protecting the loop. + BasicBlock* ProtectLoop(Loop* loop, Instruction* condition, + BasicBlock* if_merge); +}; + +// Implements a loop peeling optimization. 
+// For each loop, the pass will try to peel it if there is conditions that +// are true for the "N" first or last iterations of the loop. +// To avoid code size explosion, too large loops will not be peeled. +class LoopPeelingPass : public Pass { + public: + // Describes the peeling direction. + enum class PeelDirection { + kNone, // Cannot peel + kBefore, // Can peel before + kAfter // Can peel last + }; + + // Holds some statistics about peeled function. + struct LoopPeelingStats { + std::vector> peeled_loops_; + }; + + LoopPeelingPass(LoopPeelingStats* stats = nullptr) : stats_(stats) {} + + // Sets the loop peeling growth threshold. If the code size increase is above + // |code_grow_threshold|, the loop will not be peeled. The code size is + // measured in terms of SPIR-V instructions. + static void SetLoopPeelingThreshold(size_t code_grow_threshold) { + code_grow_threshold_ = code_grow_threshold; + } + + // Returns the loop peeling code growth threshold. + static size_t GetLoopPeelingThreshold() { return code_grow_threshold_; } + + const char* name() const override { return "loop-peeling"; } + + // Processes the given |module|. Returns Status::Failure if errors occur when + // processing. Returns the corresponding Status::Success if processing is + // succesful to indicate whether changes have been made to the modue. + Pass::Status Process() override; + + private: + // Describes the peeling direction. 
+ enum class CmpOperator { + kLT, // less than + kGT, // greater than + kLE, // less than or equal + kGE, // greater than or equal + }; + + class LoopPeelingInfo { + public: + using Direction = std::pair; + + LoopPeelingInfo(Loop* loop, size_t loop_max_iterations, + ScalarEvolutionAnalysis* scev_analysis) + : context_(loop->GetContext()), + loop_(loop), + scev_analysis_(scev_analysis), + loop_max_iterations_(loop_max_iterations) {} + + // Returns by how much and to which direction a loop should be peeled to + // make the conditional branch of the basic block |bb| an unconditional + // branch. If |bb|'s terminator is not a conditional branch or the condition + // is not workable then it returns PeelDirection::kNone and a 0 factor. + Direction GetPeelingInfo(BasicBlock* bb) const; + + private: + // Returns the id of the loop invariant operand of the conditional + // expression |condition|. It returns if no operand is invariant. + uint32_t GetFirstLoopInvariantOperand(Instruction* condition) const; + // Returns the id of the non loop invariant operand of the conditional + // expression |condition|. It returns if all operands are invariant. + uint32_t GetFirstNonLoopInvariantOperand(Instruction* condition) const; + + // Returns the value of |rec| at the first loop iteration. + SExpression GetValueAtFirstIteration(SERecurrentNode* rec) const; + // Returns the value of |rec| at the given |iteration|. + SExpression GetValueAtIteration(SERecurrentNode* rec, + int64_t iteration) const; + // Returns the value of |rec| at the last loop iteration. 
+ SExpression GetValueAtLastIteration(SERecurrentNode* rec) const; + + bool EvalOperator(CmpOperator cmp_op, SExpression lhs, SExpression rhs, + bool* result) const; + + Direction HandleEquality(SExpression lhs, SExpression rhs) const; + Direction HandleInequality(CmpOperator cmp_op, SExpression lhs, + SERecurrentNode* rhs) const; + + static Direction GetNoneDirection() { + return Direction{LoopPeelingPass::PeelDirection::kNone, 0}; + } + IRContext* context_; + Loop* loop_; + ScalarEvolutionAnalysis* scev_analysis_; + size_t loop_max_iterations_; + }; + // Peel profitable loops in |f|. + bool ProcessFunction(Function* f); + // Peel |loop| if profitable. + std::pair ProcessLoop(Loop* loop, CodeMetrics* loop_size); + + static size_t code_grow_threshold_; + LoopPeelingStats* stats_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_PEELING_H_ diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp new file mode 100644 index 000000000..10fac0433 --- /dev/null +++ b/source/opt/loop_unroller.cpp @@ -0,0 +1,1092 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_unroller.h" + +#include +#include +#include +#include +#include +#include + +#include "source/opt/ir_builder.h" +#include "source/opt/loop_utils.h" + +// Implements loop util unrolling functionality for fully and partially +// unrolling loops. 
Given a factor it will duplicate the loop that many times, +// appending each one to the end of the old loop and removing backedges, to +// create a new unrolled loop. +// +// 1 - User calls LoopUtils::FullyUnroll or LoopUtils::PartiallyUnroll with a +// loop they wish to unroll. LoopUtils::CanPerformUnroll is used to +// validate that a given loop can be unrolled. That method (along with the +// constructor of loop) checks that the IR is in the expected canonicalised +// format. +// +// 2 - The LoopUtils methods create a LoopUnrollerUtilsImpl object to actually +// perform the unrolling. This implements helper methods to copy the loop basic +// blocks and remap the ids of instructions used inside them. +// +// 3 - The core of LoopUnrollerUtilsImpl is the Unroll method, this method +// actually performs the loop duplication. It does this by creating a +// LoopUnrollState object and then copying the loop as given by the factor +// parameter. The LoopUnrollState object retains the state of the unroller +// between the loop body copies as each iteration needs information on the last +// to adjust the phi induction variable, adjust the OpLoopMerge instruction in +// the main loop header, and change the previous continue block to point to the +// new header and the new continue block to the main loop header. +// +// 4 - If the loop is to be fully unrolled then it is simply closed after step +// 3, with the OpLoopMerge being deleted, the backedge removed, and the +// condition blocks folded. +// +// 5 - If it is being partially unrolled: if the unrolling factor leaves the +// loop with an even number of bodies with respect to the number of loop +// iterations then step 3 is all that is needed. If it is uneven then we need to +// duplicate the loop completely and unroll the duplicated loop to cover the +// residual part and adjust the first loop to cover only the "even" part. 
For +// instance if you request an unroll factor of 3 on a loop with 10 iterations +// then copying the body three times would leave you with three bodies in the +// loop +// where the loop still iterates over each 4 times. So we make two loops one +// iterating once then a second loop of three iterating 3 times. + +namespace spvtools { +namespace opt { +namespace { + +// Loop control constant value for DontUnroll flag. +static const uint32_t kLoopControlDontUnrollIndex = 2; + +// Operand index of the loop control parameter of the OpLoopMerge. +static const uint32_t kLoopControlIndex = 2; + +// This utility class encapsulates some of the state we need to maintain between +// loop unrolls. Specifically it maintains key blocks and the induction variable +// in the current loop duplication step and the blocks from the previous one. +// This is because each step of the unroll needs to use data from both the +// preceding step and the original loop. +struct LoopUnrollState { + LoopUnrollState() + : previous_phi_(nullptr), + previous_latch_block_(nullptr), + previous_condition_block_(nullptr), + new_phi(nullptr), + new_continue_block(nullptr), + new_condition_block(nullptr), + new_header_block(nullptr) {} + + // Initialize from the loop descriptor class. + LoopUnrollState(Instruction* induction, BasicBlock* latch_block, + BasicBlock* condition, std::vector&& phis) + : previous_phi_(induction), + previous_latch_block_(latch_block), + previous_condition_block_(condition), + new_phi(nullptr), + new_continue_block(nullptr), + new_condition_block(nullptr), + new_header_block(nullptr) { + previous_phis_ = std::move(phis); + } + + // Swap the state so that the new nodes are now the previous nodes. + void NextIterationState() { + previous_phi_ = new_phi; + previous_latch_block_ = new_latch_block; + previous_condition_block_ = new_condition_block; + previous_phis_ = std::move(new_phis_); + + // Clear new nodes. 
+ new_phi = nullptr; + new_continue_block = nullptr; + new_condition_block = nullptr; + new_header_block = nullptr; + new_latch_block = nullptr; + + // Clear new block/instruction maps. + new_blocks.clear(); + new_inst.clear(); + ids_to_new_inst.clear(); + } + + // The induction variable from the immediately preceding loop body. + Instruction* previous_phi_; + + // All the phi nodes from the previous loop iteration. + std::vector previous_phis_; + + std::vector new_phis_; + + // The previous latch block. The backedge will be removed from this and + // added to the new latch block. + BasicBlock* previous_latch_block_; + + // The previous condition block. This may be folded to flatten the loop. + BasicBlock* previous_condition_block_; + + // The new induction variable. + Instruction* new_phi; + + // The new continue block. + BasicBlock* new_continue_block; + + // The new condition block. + BasicBlock* new_condition_block; + + // The new header block. + BasicBlock* new_header_block; + + // The new latch block. + BasicBlock* new_latch_block; + + // A mapping of new block ids to the original blocks which they were copied + // from. + std::unordered_map new_blocks; + + // A mapping of the original instruction ids to the instruction ids to their + // copies. + std::unordered_map new_inst; + + std::unordered_map ids_to_new_inst; +}; + +// This class implements the actual unrolling. It uses a LoopUnrollState to +// maintain the state of the unrolling inbetween steps. +class LoopUnrollerUtilsImpl { + public: + using BasicBlockListTy = std::vector>; + + LoopUnrollerUtilsImpl(IRContext* c, Function* function) + : context_(c), + function_(*function), + loop_condition_block_(nullptr), + loop_induction_variable_(nullptr), + number_of_loop_iterations_(0), + loop_step_value_(0), + loop_init_value_(0) {} + + // Unroll the |loop| by given |factor| by copying the whole body |factor| + // times. The resulting basicblock structure will remain a loop. 
+ void PartiallyUnroll(Loop*, size_t factor); + + // If partially unrolling the |loop| would leave the loop with too many bodies + // for its number of iterations then this method should be used. This method + // will duplicate the |loop| completely, making the duplicated loop the + // successor of the original's merge block. The original loop will have its + // condition changed to loop over the residual part and the duplicate will be + // partially unrolled. The resulting structure will be two loops. + void PartiallyUnrollResidualFactor(Loop* loop, size_t factor); + + // Fully unroll the |loop| by copying the full body by the total number of + // loop iterations, folding all conditions, and removing the backedge from the + // continue block to the header. + void FullyUnroll(Loop* loop); + + // Get the ID of the variable in the |phi| paired with |label|. + uint32_t GetPhiDefID(const Instruction* phi, uint32_t label) const; + + // Close the loop by removing the OpLoopMerge from the |loop| header block and + // making the backedge point to the merge block. + void CloseUnrolledLoop(Loop* loop); + + // Remove the OpConditionalBranch instruction inside |conditional_block| used + // to branch to either exit or continue the loop and replace it with an + // unconditional OpBranch to block |new_target|. + void FoldConditionBlock(BasicBlock* condtion_block, uint32_t new_target); + + // Add all blocks_to_add_ to function_ at the |insert_point|. + void AddBlocksToFunction(const BasicBlock* insert_point); + + // Duplicates the |old_loop|, cloning each body and remaping the ids without + // removing instructions or changing relative structure. Result will be stored + // in |new_loop|. + void DuplicateLoop(Loop* old_loop, Loop* new_loop); + + inline size_t GetLoopIterationCount() const { + return number_of_loop_iterations_; + } + + // Extracts the initial state information from the |loop|. 
+ void Init(Loop* loop); + + // Replace the uses of each induction variable outside the loop with the final + // value of the induction variable before the loop exit. To reflect the proper + // state of a fully unrolled loop. + void ReplaceInductionUseWithFinalValue(Loop* loop); + + // Remove all the instructions in the invalidated_instructions_ vector. + void RemoveDeadInstructions(); + + // Replace any use of induction variables outwith the loop with the final + // value of the induction variable in the unrolled loop. + void ReplaceOutsideLoopUseWithFinalValue(Loop* loop); + + // Set the LoopControl operand of the OpLoopMerge instruction to be + // DontUnroll. + void MarkLoopControlAsDontUnroll(Loop* loop) const; + + private: + // Remap all the in |basic_block| to new IDs and keep the mapping of new ids + // to old + // ids. |loop| is used to identify special loop blocks (header, continue, + // ect). + void AssignNewResultIds(BasicBlock* basic_block); + + // Using the map built by AssignNewResultIds, replace the uses in |inst| + // by the id that the use maps to. + void RemapOperands(Instruction* inst); + + // Using the map built by AssignNewResultIds, for each instruction in + // |basic_block| use + // that map to substitute the IDs used by instructions (in the operands) with + // the new ids. + void RemapOperands(BasicBlock* basic_block); + + // Copy the whole body of the loop, all blocks dominated by the |loop| header + // and not dominated by the |loop| merge. The copied body will be linked to by + // the old |loop| continue block and the new body will link to the |loop| + // header via the new continue block. |eliminate_conditions| is used to decide + // whether or not to fold all the condition blocks other than the last one. + void CopyBody(Loop* loop, bool eliminate_conditions); + + // Copy a given |block_to_copy| in the |loop| and record the mapping of the + // old/new ids. 
|preserve_instructions| determines whether or not the method + // will modify (other than result_id) instructions which are copied. + void CopyBasicBlock(Loop* loop, const BasicBlock* block_to_copy, + bool preserve_instructions); + + // The actual implementation of the unroll step. Unrolls |loop| by given + // |factor| by copying the body by |factor| times. Also propagates the + // induction variable value throughout the copies. + void Unroll(Loop* loop, size_t factor); + + // Fills the loop_blocks_inorder_ field with the ordered list of basic blocks + // as computed by the method ComputeLoopOrderedBlocks. + void ComputeLoopOrderedBlocks(Loop* loop); + + // Adds the blocks_to_add_ to both the |loop| and to the parent of |loop| if + // the parent exists. + void AddBlocksToLoop(Loop* loop) const; + + // After the partially unroll step the phi instructions in the header block + // will be in an illegal format. This function makes the phis legal by making + // the edge from the latch block come from the new latch block and the value + // to be the actual value of the phi at that point. + void LinkLastPhisToStart(Loop* loop) const; + + // A pointer to the IRContext. Used to add/remove instructions and for usedef + // chains. + IRContext* context_; + + // A reference the function the loop is within. + Function& function_; + + // A list of basic blocks to be added to the loop at the end of an unroll + // step. + BasicBlockListTy blocks_to_add_; + + // List of instructions which are now dead and can be removed. + std::vector invalidated_instructions_; + + // Maintains the current state of the transform between calls to unroll. + LoopUnrollState state_; + + // An ordered list containing the loop basic blocks. + std::vector loop_blocks_inorder_; + + // The block containing the condition check which contains a conditional + // branch to the merge and continue block. + BasicBlock* loop_condition_block_; + + // The induction variable of the loop. 
+ Instruction* loop_induction_variable_; + + // Phis used in the loop need to be remapped to use the actual result values + // and then be remapped at the end. + std::vector loop_phi_instructions_; + + // The number of loop iterations that the loop would preform pre-unroll. + size_t number_of_loop_iterations_; + + // The amount that the loop steps each iteration. + int64_t loop_step_value_; + + // The value the loop starts stepping from. + int64_t loop_init_value_; +}; + +/* + * Static helper functions. + */ + +// Retrieve the index of the OpPhi instruction |phi| which corresponds to the +// incoming |block| id. +static uint32_t GetPhiIndexFromLabel(const BasicBlock* block, + const Instruction* phi) { + for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { + if (block->id() == phi->GetSingleWordInOperand(i)) { + return i; + } + } + assert(false && "Could not find operand in instruction."); + return 0; +} + +void LoopUnrollerUtilsImpl::Init(Loop* loop) { + loop_condition_block_ = loop->FindConditionBlock(); + + // When we reinit the second loop during PartiallyUnrollResidualFactor we need + // to use the cached value from the duplicate step as the dominator tree + // basded solution, loop->FindConditionBlock, requires all the nodes to be + // connected up with the correct branches. They won't be at this point. + if (!loop_condition_block_) { + loop_condition_block_ = state_.new_condition_block; + } + assert(loop_condition_block_); + + loop_induction_variable_ = loop->FindConditionVariable(loop_condition_block_); + assert(loop_induction_variable_); + + bool found = loop->FindNumberOfIterations( + loop_induction_variable_, &*loop_condition_block_->ctail(), + &number_of_loop_iterations_, &loop_step_value_, &loop_init_value_); + (void)found; // To silence unused variable warning on release builds. + assert(found); + + // Blocks are stored in an unordered set of ids in the loop class, we need to + // create the dominator ordered list. 
+ ComputeLoopOrderedBlocks(loop); +} + +// This function is used to partially unroll the loop when the factor provided +// would normally lead to an illegal optimization. Instead of just unrolling the +// loop it creates two loops and unrolls one and adjusts the condition on the +// other. The end result being that the new loop pair iterates over the correct +// number of bodies. +void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(Loop* loop, + size_t factor) { + // TODO(1841): Handle id overflow. + std::unique_ptr new_label{new Instruction( + context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})}; + std::unique_ptr new_exit_bb{new BasicBlock(std::move(new_label))}; + + // Save the id of the block before we move it. + uint32_t new_merge_id = new_exit_bb->id(); + + // Add the block the list of blocks to add, we want this merge block to be + // right at the start of the new blocks. + blocks_to_add_.push_back(std::move(new_exit_bb)); + BasicBlock* new_exit_bb_raw = blocks_to_add_[0].get(); + Instruction& original_conditional_branch = *loop_condition_block_->tail(); + // Duplicate the loop, providing access to the blocks of both loops. + // This is a naked new due to the VS2013 requirement of not having unique + // pointers in vectors, as it will be inserted into a vector with + // loop_descriptor.AddLoop. + std::unique_ptr new_loop = MakeUnique(*loop); + + // Clear the basic blocks of the new loop. + new_loop->ClearBlocks(); + + DuplicateLoop(loop, new_loop.get()); + + // Add the blocks to the function. + AddBlocksToFunction(loop->GetMergeBlock()); + blocks_to_add_.clear(); + + // Create a new merge block for the first loop. + InstructionBuilder builder{context_, new_exit_bb_raw}; + // Make the first loop branch to the second. 
+ builder.AddBranch(new_loop->GetHeaderBlock()->id()); + + loop_condition_block_ = state_.new_condition_block; + loop_induction_variable_ = state_.new_phi; + // Unroll the new loop by the factor with the usual -1 to account for the + // existing block iteration. + Unroll(new_loop.get(), factor); + + LinkLastPhisToStart(new_loop.get()); + AddBlocksToLoop(new_loop.get()); + + // Add the new merge block to the back of the list of blocks to be added. It + // needs to be the last block added to maintain dominator order in the binary. + blocks_to_add_.push_back( + std::unique_ptr(new_loop->GetMergeBlock())); + + // Add the blocks to the function. + AddBlocksToFunction(loop->GetMergeBlock()); + + // Reset the usedef analysis. + context_->InvalidateAnalysesExceptFor( + IRContext::Analysis::kAnalysisLoopAnalysis); + analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + + // The loop condition. + Instruction* condition_check = def_use_manager->GetDef( + original_conditional_branch.GetSingleWordOperand(0)); + + // This should have been checked by the LoopUtils::CanPerformUnroll function + // before entering this. + assert(loop->IsSupportedCondition(condition_check->opcode())); + + // We need to account for the initial body when calculating the remainder. + int64_t remainder = Loop::GetResidualConditionValue( + condition_check->opcode(), loop_init_value_, loop_step_value_, + number_of_loop_iterations_, factor); + + assert(remainder > std::numeric_limits::min() && + remainder < std::numeric_limits::max()); + + Instruction* new_constant = nullptr; + + // If the remainder is negative then we add a signed constant, otherwise just + // add an unsigned constant. + if (remainder < 0) { + new_constant = builder.GetSintConstant(static_cast(remainder)); + } else { + new_constant = builder.GetUintConstant(static_cast(remainder)); + } + + uint32_t constant_id = new_constant->result_id(); + + // Update the condition check. 
+ condition_check->SetInOperand(1, {constant_id}); + + // Update the next phi node. The phi will have a constant value coming in from + // the preheader block. For the duplicated loop we need to update the constant + // to be the amount of iterations covered by the first loop and the incoming + // block to be the first loops new merge block. + std::vector new_inductions; + new_loop->GetInductionVariables(new_inductions); + + std::vector old_inductions; + loop->GetInductionVariables(old_inductions); + for (size_t index = 0; index < new_inductions.size(); ++index) { + Instruction* new_induction = new_inductions[index]; + Instruction* old_induction = old_inductions[index]; + // Get the index of the loop initalizer, the value coming in from the + // preheader. + uint32_t initalizer_index = + GetPhiIndexFromLabel(new_loop->GetPreHeaderBlock(), old_induction); + + // Replace the second loop initalizer with the phi from the first + new_induction->SetInOperand(initalizer_index - 1, + {old_induction->result_id()}); + new_induction->SetInOperand(initalizer_index, {new_merge_id}); + + // If the use of the first loop induction variable is outside of the loop + // then replace that use with the second loop induction variable. 
+ uint32_t second_loop_induction = new_induction->result_id(); + auto replace_use_outside_of_loop = [loop, second_loop_induction]( + Instruction* user, + uint32_t operand_index) { + if (!loop->IsInsideLoop(user)) { + user->SetOperand(operand_index, {second_loop_induction}); + } + }; + + context_->get_def_use_mgr()->ForEachUse(old_induction, + replace_use_outside_of_loop); + } + + context_->InvalidateAnalysesExceptFor( + IRContext::Analysis::kAnalysisLoopAnalysis); + + context_->ReplaceAllUsesWith(loop->GetMergeBlock()->id(), new_merge_id); + + LoopDescriptor& loop_descriptor = *context_->GetLoopDescriptor(&function_); + + loop_descriptor.AddLoop(std::move(new_loop), loop->GetParent()); + + RemoveDeadInstructions(); +} + +// Mark this loop as DontUnroll as it will already be unrolled and it may not +// be safe to unroll a previously partially unrolled loop. +void LoopUnrollerUtilsImpl::MarkLoopControlAsDontUnroll(Loop* loop) const { + Instruction* loop_merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); + assert(loop_merge_inst && + "Loop merge instruction could not be found after entering unroller " + "(should have exited before this)"); + loop_merge_inst->SetInOperand(kLoopControlIndex, + {kLoopControlDontUnrollIndex}); +} + +// Duplicate the |loop| body |factor| - 1 number of times while keeping the loop +// backedge intact. This will leave the loop with |factor| number of bodies +// after accounting for the initial body. +void LoopUnrollerUtilsImpl::Unroll(Loop* loop, size_t factor) { + // If we unroll a loop partially it will not be safe to unroll it further. + // This is due to the current method of calculating the number of loop + // iterations. 
+ MarkLoopControlAsDontUnroll(loop); + + std::vector inductions; + loop->GetInductionVariables(inductions); + state_ = LoopUnrollState{loop_induction_variable_, loop->GetLatchBlock(), + loop_condition_block_, std::move(inductions)}; + for (size_t i = 0; i < factor - 1; ++i) { + CopyBody(loop, true); + } +} + +void LoopUnrollerUtilsImpl::RemoveDeadInstructions() { + // Remove the dead instructions. + for (Instruction* inst : invalidated_instructions_) { + context_->KillInst(inst); + } +} + +void LoopUnrollerUtilsImpl::ReplaceInductionUseWithFinalValue(Loop* loop) { + context_->InvalidateAnalysesExceptFor( + IRContext::Analysis::kAnalysisLoopAnalysis | + IRContext::Analysis::kAnalysisDefUse | + IRContext::Analysis::kAnalysisInstrToBlockMapping); + + std::vector inductions; + loop->GetInductionVariables(inductions); + + for (size_t index = 0; index < inductions.size(); ++index) { + uint32_t trip_step_id = GetPhiDefID(state_.previous_phis_[index], + state_.previous_latch_block_->id()); + context_->ReplaceAllUsesWith(inductions[index]->result_id(), trip_step_id); + invalidated_instructions_.push_back(inductions[index]); + } +} + +// Fully unroll the loop by partially unrolling it by the number of loop +// iterations minus one for the body already accounted for. +void LoopUnrollerUtilsImpl::FullyUnroll(Loop* loop) { + // We unroll the loop by number of iterations in the loop. + Unroll(loop, number_of_loop_iterations_); + + // The first condition block is preserved until now so it can be copied. + FoldConditionBlock(loop_condition_block_, 1); + + // Delete the OpLoopMerge and remove the backedge to the header. + CloseUnrolledLoop(loop); + + // Mark the loop for later deletion. This allows us to preserve the loop + // iterators but still disregard dead loops. + loop->MarkLoopForRemoval(); + + // If the loop has a parent add the new blocks to the parent. + if (loop->GetParent()) { + AddBlocksToLoop(loop->GetParent()); + } + + // Add the blocks to the function. 
+ AddBlocksToFunction(loop->GetMergeBlock()); + + ReplaceInductionUseWithFinalValue(loop); + + RemoveDeadInstructions(); + // Invalidate all analyses. + context_->InvalidateAnalysesExceptFor( + IRContext::Analysis::kAnalysisLoopAnalysis | + IRContext::Analysis::kAnalysisDefUse); +} + +// Copy a given basic block, give it a new result_id, and store the new block +// and the id mapping in the state. |preserve_instructions| is used to determine +// whether or not this function should edit instructions other than the +// |result_id|. +void LoopUnrollerUtilsImpl::CopyBasicBlock(Loop* loop, const BasicBlock* itr, + bool preserve_instructions) { + // Clone the block exactly, including the IDs. + BasicBlock* basic_block = itr->Clone(context_); + basic_block->SetParent(itr->GetParent()); + + // Assign each result a new unique ID and keep a mapping of the old ids to + // the new ones. + AssignNewResultIds(basic_block); + + // If this is the continue block we are copying. + if (itr == loop->GetContinueBlock()) { + // Make the OpLoopMerge point to this block for the continue. + if (!preserve_instructions) { + Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); + merge_inst->SetInOperand(1, {basic_block->id()}); + context_->UpdateDefUse(merge_inst); + } + + state_.new_continue_block = basic_block; + } + + // If this is the header block we are copying. + if (itr == loop->GetHeaderBlock()) { + state_.new_header_block = basic_block; + + if (!preserve_instructions) { + // Remove the loop merge instruction if it exists. + Instruction* merge_inst = basic_block->GetLoopMergeInst(); + if (merge_inst) invalidated_instructions_.push_back(merge_inst); + } + } + + // If this is the latch block being copied, record it in the state. + if (itr == loop->GetLatchBlock()) state_.new_latch_block = basic_block; + + // If this is the condition block we are copying. 
+ if (itr == loop_condition_block_) { + state_.new_condition_block = basic_block; + } + + // Add this block to the list of blocks to add to the function at the end of + // the unrolling process. + blocks_to_add_.push_back(std::unique_ptr(basic_block)); + + // Keep tracking the old block via a map. + state_.new_blocks[itr->id()] = basic_block; +} + +void LoopUnrollerUtilsImpl::CopyBody(Loop* loop, bool eliminate_conditions) { + // Copy each basic block in the loop, give them new ids, and save state + // information. + for (const BasicBlock* itr : loop_blocks_inorder_) { + CopyBasicBlock(loop, itr, false); + } + + // Set the previous latch block to point to the new header. + Instruction* latch_branch = state_.previous_latch_block_->terminator(); + latch_branch->SetInOperand(0, {state_.new_header_block->id()}); + context_->UpdateDefUse(latch_branch); + + // As the algorithm copies the original loop blocks exactly, the tail of the + // latch block on iterations after the first one will be a branch to the new + // header and not the actual loop header. The last continue block in the loop + // should always be a backedge to the global header. + Instruction* new_latch_branch = state_.new_latch_block->terminator(); + new_latch_branch->SetInOperand(0, {loop->GetHeaderBlock()->id()}); + context_->AnalyzeUses(new_latch_branch); + + std::vector inductions; + loop->GetInductionVariables(inductions); + for (size_t index = 0; index < inductions.size(); ++index) { + Instruction* master_copy = inductions[index]; + + assert(master_copy->result_id() != 0); + Instruction* induction_clone = + state_.ids_to_new_inst[state_.new_inst[master_copy->result_id()]]; + + state_.new_phis_.push_back(induction_clone); + assert(induction_clone->result_id() != 0); + + if (!state_.previous_phis_.empty()) { + state_.new_inst[master_copy->result_id()] = GetPhiDefID( + state_.previous_phis_[index], state_.previous_latch_block_->id()); + } else { + // Do not replace the first phi block ids. 
+ state_.new_inst[master_copy->result_id()] = master_copy->result_id(); + } + } + + if (eliminate_conditions && + state_.new_condition_block != loop_condition_block_) { + FoldConditionBlock(state_.new_condition_block, 1); + } + + // Only reference to the header block is the backedge in the latch block, + // don't change this. + state_.new_inst[loop->GetHeaderBlock()->id()] = loop->GetHeaderBlock()->id(); + + for (auto& pair : state_.new_blocks) { + RemapOperands(pair.second); + } + + for (Instruction* dead_phi : state_.new_phis_) + invalidated_instructions_.push_back(dead_phi); + + // Swap the state so the new is now the previous. + state_.NextIterationState(); +} + +uint32_t LoopUnrollerUtilsImpl::GetPhiDefID(const Instruction* phi, + uint32_t label) const { + for (uint32_t operand = 3; operand < phi->NumOperands(); operand += 2) { + if (phi->GetSingleWordOperand(operand) == label) { + return phi->GetSingleWordOperand(operand - 1); + } + } + assert(false && "Could not find a phi index matching the provided label"); + return 0; +} + +void LoopUnrollerUtilsImpl::FoldConditionBlock(BasicBlock* condition_block, + uint32_t operand_label) { + // Remove the old conditional branch to the merge and continue blocks. + Instruction& old_branch = *condition_block->tail(); + uint32_t new_target = old_branch.GetSingleWordOperand(operand_label); + + context_->KillInst(&old_branch); + // Add the new unconditional branch to the merge block. + InstructionBuilder builder( + context_, condition_block, + IRContext::Analysis::kAnalysisDefUse | + IRContext::Analysis::kAnalysisInstrToBlockMapping); + builder.AddBranch(new_target); +} + +void LoopUnrollerUtilsImpl::CloseUnrolledLoop(Loop* loop) { + // Remove the OpLoopMerge instruction from the function. + Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); + invalidated_instructions_.push_back(merge_inst); + + // Remove the final backedge to the header and make it point instead to the + // merge block. 
+ Instruction* latch_instruction = state_.previous_latch_block_->terminator(); + latch_instruction->SetInOperand(0, {loop->GetMergeBlock()->id()}); + context_->UpdateDefUse(latch_instruction); + + // Remove all induction variables as the phis will now be invalid. Replace all + // uses with the constant initializer value (all uses of phis will be in + // the first iteration with the subsequent phis already having been removed). + std::vector inductions; + loop->GetInductionVariables(inductions); + + // We can use the state instruction mechanism to replace all internal loop + // values within the first loop trip (as the subsequent ones will be updated + // by the copy function) with the value coming in from the preheader and then + // use context ReplaceAllUsesWith for the uses outside the loop with the final + // trip phi value. + state_.new_inst.clear(); + for (Instruction* induction : inductions) { + uint32_t initalizer_id = + GetPhiDefID(induction, loop->GetPreHeaderBlock()->id()); + + state_.new_inst[induction->result_id()] = initalizer_id; + } + + for (BasicBlock* block : loop_blocks_inorder_) { + RemapOperands(block); + } + + // Rewrite the last phis, since they may still reference the original phi. + for (Instruction* last_phi : state_.previous_phis_) { + RemapOperands(last_phi); + } +} + +// Uses the first loop to create a copy of the loop with new IDs. +void LoopUnrollerUtilsImpl::DuplicateLoop(Loop* old_loop, Loop* new_loop) { + std::vector new_block_order; + + // Copy every block in the old loop. + for (const BasicBlock* itr : loop_blocks_inorder_) { + CopyBasicBlock(old_loop, itr, true); + new_block_order.push_back(blocks_to_add_.back().get()); + } + + // Clone the merge block, give it a new id and record it in the state. 
+ BasicBlock* new_merge = old_loop->GetMergeBlock()->Clone(context_); + new_merge->SetParent(old_loop->GetMergeBlock()->GetParent()); + AssignNewResultIds(new_merge); + state_.new_blocks[old_loop->GetMergeBlock()->id()] = new_merge; + + // Remap the operands of every instruction in the loop to point to the new + // copies. + for (auto& pair : state_.new_blocks) { + RemapOperands(pair.second); + } + + loop_blocks_inorder_ = std::move(new_block_order); + + AddBlocksToLoop(new_loop); + + new_loop->SetHeaderBlock(state_.new_header_block); + new_loop->SetContinueBlock(state_.new_continue_block); + new_loop->SetLatchBlock(state_.new_latch_block); + new_loop->SetMergeBlock(new_merge); +} + +// Whenever the utility copies a block it stores it in a tempory buffer, this +// function adds the buffer into the Function. The blocks will be inserted +// after the block |insert_point|. +void LoopUnrollerUtilsImpl::AddBlocksToFunction( + const BasicBlock* insert_point) { + for (auto basic_block_iterator = function_.begin(); + basic_block_iterator != function_.end(); ++basic_block_iterator) { + if (basic_block_iterator->id() == insert_point->id()) { + basic_block_iterator.InsertBefore(&blocks_to_add_); + return; + } + } + + assert( + false && + "Could not add basic blocks to function as insert point was not found."); +} + +// Assign all result_ids in |basic_block| instructions to new IDs and preserve +// the mapping of new ids to old ones. +void LoopUnrollerUtilsImpl::AssignNewResultIds(BasicBlock* basic_block) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + // Label instructions aren't covered by normal traversal of the + // instructions. + // TODO(1841): Handle id overflow. + uint32_t new_label_id = context_->TakeNextId(); + + // Assign a new id to the label. 
+ state_.new_inst[basic_block->GetLabelInst()->result_id()] = new_label_id; + basic_block->GetLabelInst()->SetResultId(new_label_id); + def_use_mgr->AnalyzeInstDefUse(basic_block->GetLabelInst()); + + for (Instruction& inst : *basic_block) { + uint32_t old_id = inst.result_id(); + + // Ignore stores etc. + if (old_id == 0) { + continue; + } + + // Give the instruction a new id. + // TODO(1841): Handle id overflow. + inst.SetResultId(context_->TakeNextId()); + def_use_mgr->AnalyzeInstDef(&inst); + + // Save the mapping of old_id -> new_id. + state_.new_inst[old_id] = inst.result_id(); + // Check if this instruction is the induction variable. + if (loop_induction_variable_->result_id() == old_id) { + // Save a pointer to the new copy of it. + state_.new_phi = &inst; + } + state_.ids_to_new_inst[inst.result_id()] = &inst; + } +} + +void LoopUnrollerUtilsImpl::RemapOperands(Instruction* inst) { + auto remap_operands_to_new_ids = [this](uint32_t* id) { + auto itr = state_.new_inst.find(*id); + + if (itr != state_.new_inst.end()) { + *id = itr->second; + } + }; + + inst->ForEachInId(remap_operands_to_new_ids); + context_->AnalyzeUses(inst); +} + +void LoopUnrollerUtilsImpl::RemapOperands(BasicBlock* basic_block) { + for (Instruction& inst : *basic_block) { + RemapOperands(&inst); + } +} + +// Generate the ordered list of basic blocks in the |loop| and cache it for +// later use. +void LoopUnrollerUtilsImpl::ComputeLoopOrderedBlocks(Loop* loop) { + loop_blocks_inorder_.clear(); + loop->ComputeLoopStructuredOrder(&loop_blocks_inorder_); +} + +// Adds the blocks_to_add_ to both the loop and to the parent. +void LoopUnrollerUtilsImpl::AddBlocksToLoop(Loop* loop) const { + // Add the blocks to this loop. + for (auto& block_itr : blocks_to_add_) { + loop->AddBasicBlock(block_itr.get()); + } + + // Add the blocks to the parent as well. 
+ if (loop->GetParent()) AddBlocksToLoop(loop->GetParent()); +} + +void LoopUnrollerUtilsImpl::LinkLastPhisToStart(Loop* loop) const { + std::vector inductions; + loop->GetInductionVariables(inductions); + + for (size_t i = 0; i < inductions.size(); ++i) { + Instruction* last_phi_in_block = state_.previous_phis_[i]; + + uint32_t phi_index = + GetPhiIndexFromLabel(state_.previous_latch_block_, last_phi_in_block); + uint32_t phi_variable = + last_phi_in_block->GetSingleWordInOperand(phi_index - 1); + uint32_t phi_label = last_phi_in_block->GetSingleWordInOperand(phi_index); + + Instruction* phi = inductions[i]; + phi->SetInOperand(phi_index - 1, {phi_variable}); + phi->SetInOperand(phi_index, {phi_label}); + } +} + +// Duplicate the |loop| body |factor| number of times while keeping the loop +// backedge intact. +void LoopUnrollerUtilsImpl::PartiallyUnroll(Loop* loop, size_t factor) { + Unroll(loop, factor); + LinkLastPhisToStart(loop); + AddBlocksToLoop(loop); + AddBlocksToFunction(loop->GetMergeBlock()); + RemoveDeadInstructions(); +} + +/* + * End LoopUtilsImpl. + */ + +} // namespace + +/* + * + * Begin Utils. + * + * */ + +bool LoopUtils::CanPerformUnroll() { + // The loop is expected to be in structured order. + if (!loop_->GetHeaderBlock()->GetMergeInst()) { + return false; + } + + // Find check the loop has a condition we can find and evaluate. + const BasicBlock* condition = loop_->FindConditionBlock(); + if (!condition) return false; + + // Check that we can find and process the induction variable. + const Instruction* induction = loop_->FindConditionVariable(condition); + if (!induction || induction->opcode() != SpvOpPhi) return false; + + // Check that we can find the number of loop iterations. + if (!loop_->FindNumberOfIterations(induction, &*condition->ctail(), nullptr)) + return false; + + // Make sure the latch block is a unconditional branch to the header + // block. 
+ const Instruction& branch = *loop_->GetLatchBlock()->ctail(); + bool branching_assumption = + branch.opcode() == SpvOpBranch && + branch.GetSingleWordInOperand(0) == loop_->GetHeaderBlock()->id(); + if (!branching_assumption) { + return false; + } + + std::vector inductions; + loop_->GetInductionVariables(inductions); + + // Ban breaks within the loop. + const std::vector& merge_block_preds = + context_->cfg()->preds(loop_->GetMergeBlock()->id()); + if (merge_block_preds.size() != 1) { + return false; + } + + // Ban continues within the loop. + const std::vector& continue_block_preds = + context_->cfg()->preds(loop_->GetContinueBlock()->id()); + if (continue_block_preds.size() != 1) { + return false; + } + + // Ban returns in the loop. + // Iterate over all the blocks within the loop and check that none of them + // exit the loop. + for (uint32_t label_id : loop_->GetBlocks()) { + const BasicBlock* block = context_->cfg()->block(label_id); + if (block->ctail()->opcode() == SpvOp::SpvOpKill || + block->ctail()->opcode() == SpvOp::SpvOpReturn || + block->ctail()->opcode() == SpvOp::SpvOpReturnValue) { + return false; + } + } + // Can only unroll inner loops. + if (!loop_->AreAllChildrenMarkedForRemoval()) { + return false; + } + + return true; +} + +bool LoopUtils::PartiallyUnroll(size_t factor) { + if (factor == 1 || !CanPerformUnroll()) return false; + + // Create the unroller utility. + LoopUnrollerUtilsImpl unroller{context_, + loop_->GetHeaderBlock()->GetParent()}; + unroller.Init(loop_); + + // If the unrolling factor is larger than or the same size as the loop just + // fully unroll the loop. + if (factor >= unroller.GetLoopIterationCount()) { + unroller.FullyUnroll(loop_); + return true; + } + + // If the loop unrolling factor is an residual number of iterations we need to + // let run the loop for the residual part then let it branch into the unrolled + // remaining part. 
We add one when calucating the remainder to take into + // account the one iteration already in the loop. + if (unroller.GetLoopIterationCount() % factor != 0) { + unroller.PartiallyUnrollResidualFactor(loop_, factor); + } else { + unroller.PartiallyUnroll(loop_, factor); + } + + return true; +} + +bool LoopUtils::FullyUnroll() { + if (!CanPerformUnroll()) return false; + + std::vector inductions; + loop_->GetInductionVariables(inductions); + + LoopUnrollerUtilsImpl unroller{context_, + loop_->GetHeaderBlock()->GetParent()}; + + unroller.Init(loop_); + unroller.FullyUnroll(loop_); + + return true; +} + +void LoopUtils::Finalize() { + // Clean up the loop descriptor to preserve the analysis. + + LoopDescriptor* LD = context_->GetLoopDescriptor(&function_); + LD->PostModificationCleanup(); +} + +/* + * + * Begin Pass. + * + */ + +Pass::Status LoopUnroller::Process() { + bool changed = false; + for (Function& f : *context()->module()) { + LoopDescriptor* LD = context()->GetLoopDescriptor(&f); + for (Loop& loop : *LD) { + LoopUtils loop_utils{context(), &loop}; + if (!loop.HasUnrollLoopControl() || !loop_utils.CanPerformUnroll()) { + continue; + } + + if (fully_unroll_) { + loop_utils.FullyUnroll(); + } else { + loop_utils.PartiallyUnroll(unroll_factor_); + } + changed = true; + } + LD->PostModificationCleanup(); + } + + return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_unroller.h b/source/opt/loop_unroller.h new file mode 100644 index 000000000..71e7cca31 --- /dev/null +++ b/source/opt/loop_unroller.h @@ -0,0 +1,49 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_UNROLLER_H_ +#define SOURCE_OPT_LOOP_UNROLLER_H_ + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +class LoopUnroller : public Pass { + public: + LoopUnroller() : Pass(), fully_unroll_(true), unroll_factor_(0) {} + LoopUnroller(bool fully_unroll, int unroll_factor) + : Pass(), fully_unroll_(fully_unroll), unroll_factor_(unroll_factor) {} + + const char* name() const override { return "loop-unroll"; } + + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + bool fully_unroll_; + int unroll_factor_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_UNROLLER_H_ diff --git a/source/opt/loop_unswitch_pass.cpp b/source/opt/loop_unswitch_pass.cpp new file mode 100644 index 000000000..502fc6b68 --- /dev/null +++ b/source/opt/loop_unswitch_pass.cpp @@ -0,0 +1,620 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/loop_unswitch_pass.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/dominator_tree.h" +#include "source/opt/fold.h" +#include "source/opt/function.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" + +#include "source/opt/loop_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +static const uint32_t kTypePointerStorageClassInIdx = 0; + +} // anonymous namespace + +namespace { + +// This class handle the unswitch procedure for a given loop. +// The unswitch will not happen if: +// - The loop has any instruction that will prevent it; +// - The loop invariant condition is not uniform. +class LoopUnswitch { + public: + LoopUnswitch(IRContext* context, Function* function, Loop* loop, + LoopDescriptor* loop_desc) + : function_(function), + loop_(loop), + loop_desc_(*loop_desc), + context_(context), + switch_block_(nullptr) {} + + // Returns true if the loop can be unswitched. 
+ // A loop can be unswitched if: + // - The loop has no instruction that prevents it (such as a barrier); + // - The loop has one conditional branch or switch that does not depend on + // the loop; + // - The loop invariant condition is uniform. + // The block holding the selected branch is cached in |switch_block_|, so the + // scan is skipped on later calls until PerformUnswitch() resets it. + bool CanUnswitchLoop() { + if (switch_block_) return true; + // NOTE(review): this bails out when IsSafeToClone() is true, which reads as + // the opposite of the first condition listed above; confirm against + // Loop::IsSafeToClone() whether a negation is missing here. + if (loop_->IsSafeToClone()) return false; + + CFG& cfg = *context_->cfg(); + + for (uint32_t bb_id : loop_->GetBlocks()) { + BasicBlock* bb = cfg.block(bb_id); + // The latch block is never used as the unswitch point. + if (loop_->GetLatchBlock() == bb) { + continue; + } + + // Only terminators that are branches other than a plain OpBranch are + // candidates; keep the first one whose condition qualifies. + if (bb->terminator()->IsBranch() && + bb->terminator()->opcode() != SpvOpBranch) { + if (IsConditionNonConstantLoopInvariant(bb->terminator())) { + switch_block_ = bb; + break; + } + } + } + + // Pointer-to-bool conversion: true iff a candidate block was found. + return switch_block_; + } + + // Returns the iterator to the basic block |bb_to_find| in |function_|. + // Asserts that the block is found. + Function::iterator FindBasicBlockPosition(BasicBlock* bb_to_find) { + Function::iterator it = function_->FindBlock(bb_to_find->id()); + assert(it != function_->end() && "Basic Block not found"); + return it; + } + + // Creates a new basic block with a fresh label id and inserts it into + // |function_| before the position |ip|. This function updates the def/use + // and instr to block managers for the new label. 
+ BasicBlock* CreateBasicBlock(Function::iterator ip) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + // TODO(1841): Handle id overflow. + BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr( + new BasicBlock(std::unique_ptr(new Instruction( + context_, SpvOpLabel, 0, context_->TakeNextId(), {}))))); + bb->SetParent(function_); + def_use_mgr->AnalyzeInstDef(bb->GetLabelInst()); + context_->set_instr_block(bb->GetLabelInst(), bb); + + return bb; + } + + // Returns the instruction defining the unsigned constant used to specialize + // the loop for the default path of the OpSwitch |switch_inst|. + Instruction* GetValueForDefaultPathForSwitch(Instruction* switch_inst) { + assert(switch_inst->opcode() == SpvOpSwitch && + "The given instructoin must be an OpSwitch."); + + // Find a value that can be used to select the default path. + // If none are possible, then it will just use 0.
The value does not matter + // because this path will never be taken becaues the new switch outside of + // the loop cannot select this path either. + std::vector existing_values; + for (uint32_t i = 2; i < switch_inst->NumInOperands(); i += 2) { + existing_values.push_back(switch_inst->GetSingleWordInOperand(i)); + } + std::sort(existing_values.begin(), existing_values.end()); + uint32_t value_for_default_path = 0; + if (existing_values.size() < std::numeric_limits::max()) { + for (value_for_default_path = 0; + value_for_default_path < existing_values.size(); + value_for_default_path++) { + if (existing_values[value_for_default_path] != value_for_default_path) { + break; + } + } + } + InstructionBuilder builder( + context_, static_cast(nullptr), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + return builder.GetUintConstant(value_for_default_path); + } + + // Unswitches |loop_|. + void PerformUnswitch() { + assert(CanUnswitchLoop() && + "Cannot unswitch if there is not constant condition"); + assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block"); + assert(loop_->IsLCSSA() && "This loop is not in LCSSA form"); + + CFG& cfg = *context_->cfg(); + DominatorTree* dom_tree = + &context_->GetDominatorAnalysis(function_)->GetDomTree(); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + LoopUtils loop_utils(context_, loop_); + + ////////////////////////////////////////////////////////////////////////////// + // Step 1: Create the if merge block for structured modules. + // To do so, the |loop_| merge block will become the if's one and we + // create a merge for the loop. This will limit the amount of duplicated + // code the structured control flow imposes. + // For non structured program, the new loop will be connected to + // the old loop's exit blocks. + ////////////////////////////////////////////////////////////////////////////// + + // Get the merge block if it exists. 
+ BasicBlock* if_merge_block = loop_->GetMergeBlock(); + // The merge block is only created if the loop has a unique exit block. We + // have this guarantee for structured loops, for compute loop it will + // trivially help maintain both a structured-like form and LCSAA. + BasicBlock* loop_merge_block = + if_merge_block + ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block)) + : nullptr; + if (loop_merge_block) { + // Add the instruction and update managers. + InstructionBuilder builder( + context_, loop_merge_block, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + builder.AddBranch(if_merge_block->id()); + builder.SetInsertPoint(&*loop_merge_block->begin()); + cfg.RegisterBlock(loop_merge_block); + def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst()); + // Update CFG. + if_merge_block->ForEachPhiInst( + [loop_merge_block, &builder, this](Instruction* phi) { + Instruction* cloned = phi->Clone(context_); + cloned->SetResultId(TakeNextId()); + builder.AddInstruction(std::unique_ptr(cloned)); + phi->SetInOperand(0, {cloned->result_id()}); + phi->SetInOperand(1, {loop_merge_block->id()}); + for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--) + phi->RemoveInOperand(j); + }); + // Copy the predecessor list (will get invalidated otherwise). + std::vector preds = cfg.preds(if_merge_block->id()); + for (uint32_t pid : preds) { + if (pid == loop_merge_block->id()) continue; + BasicBlock* p_bb = cfg.block(pid); + p_bb->ForEachSuccessorLabel( + [if_merge_block, loop_merge_block](uint32_t* id) { + if (*id == if_merge_block->id()) *id = loop_merge_block->id(); + }); + cfg.AddEdge(pid, loop_merge_block->id()); + } + cfg.RemoveNonExistingEdges(if_merge_block->id()); + // Update loop descriptor. + if (Loop* ploop = loop_->GetParent()) { + ploop->AddBasicBlock(loop_merge_block); + loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop); + } + // Update the dominator tree. 
+ DominatorTreeNode* loop_merge_dtn = + dom_tree->GetOrInsertNode(loop_merge_block); + DominatorTreeNode* if_merge_block_dtn = + dom_tree->GetOrInsertNode(if_merge_block); + loop_merge_dtn->parent_ = if_merge_block_dtn->parent_; + loop_merge_dtn->children_.push_back(if_merge_block_dtn); + loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn); + if_merge_block_dtn->parent_->children_.erase(std::find( + if_merge_block_dtn->parent_->children_.begin(), + if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn)); + + loop_->SetMergeBlock(loop_merge_block); + } + + //////////////////////////////////////////////////////////////////////////// + // Step 2: Build a new preheader for |loop_|, use the old one + // for the invariant branch. + //////////////////////////////////////////////////////////////////////////// + + BasicBlock* if_block = loop_->GetPreHeaderBlock(); + // If this preheader is the parent loop header, + // we need to create a dedicated block for the if. + BasicBlock* loop_pre_header = + CreateBasicBlock(++FindBasicBlockPosition(if_block)); + InstructionBuilder( + context_, loop_pre_header, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping) + .AddBranch(loop_->GetHeaderBlock()->id()); + + if_block->tail()->SetInOperand(0, {loop_pre_header->id()}); + + // Update loop descriptor. + if (Loop* ploop = loop_desc_[if_block]) { + ploop->AddBasicBlock(loop_pre_header); + loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop); + } + + // Update the CFG. 
+ cfg.RegisterBlock(loop_pre_header); + def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst()); + cfg.AddEdge(if_block->id(), loop_pre_header->id()); + cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id()); + + loop_->GetHeaderBlock()->ForEachPhiInst( + [loop_pre_header, if_block](Instruction* phi) { + phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) { + if (*id == if_block->id()) { + *id = loop_pre_header->id(); + } + }); + }); + loop_->SetPreHeaderBlock(loop_pre_header); + + // Update the dominator tree. + DominatorTreeNode* loop_pre_header_dtn = + dom_tree->GetOrInsertNode(loop_pre_header); + DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block); + loop_pre_header_dtn->parent_ = if_block_dtn; + assert( + if_block_dtn->children_.size() == 1 && + "A loop preheader should only have the header block as a child in the " + "dominator tree"); + loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]); + if_block_dtn->children_.clear(); + if_block_dtn->children_.push_back(loop_pre_header_dtn); + + // Make domination queries valid. + dom_tree->ResetDFNumbering(); + + // Compute an ordered list of basic block to clone: loop blocks + pre-header + // + merge block. + loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true); + + ///////////////////////////// + // Do the actual unswitch: // + // - Clone the loop // + // - Connect exits // + // - Specialize the loop // + ///////////////////////////// + + Instruction* iv_condition = &*switch_block_->tail(); + SpvOp iv_opcode = iv_condition->opcode(); + Instruction* condition = + def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]); + + analysis::ConstantManager* cst_mgr = context_->get_constant_mgr(); + const analysis::Type* cond_type = + context_->get_type_mgr()->GetType(condition->type_id()); + + // Build the list of value for which we need to clone and specialize the + // loop. 
+ std::vector> constant_branch; + // Special case for the original loop + Instruction* original_loop_constant_value; + if (iv_opcode == SpvOpBranchConditional) { + constant_branch.emplace_back( + cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})), + nullptr); + original_loop_constant_value = + cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1})); + } else { + // We are looking to take the default branch, so we can't provide a + // specific value. + original_loop_constant_value = + GetValueForDefaultPathForSwitch(iv_condition); + + for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) { + constant_branch.emplace_back( + cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant( + cond_type, iv_condition->GetInOperand(i).words)), + nullptr); + } + } + + // Get the loop landing pads. + std::unordered_set if_merging_blocks; + std::function is_from_original_loop; + if (loop_->GetHeaderBlock()->GetLoopMergeInst()) { + if_merging_blocks.insert(if_merge_block->id()); + is_from_original_loop = [this](uint32_t id) { + return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id; + }; + } else { + loop_->GetExitBlocks(&if_merging_blocks); + is_from_original_loop = [this](uint32_t id) { + return loop_->IsInsideLoop(id); + }; + } + + for (auto& specialisation_pair : constant_branch) { + Instruction* specialisation_value = specialisation_pair.first; + ////////////////////////////////////////////////////////// + // Step 3: Duplicate |loop_|. + ////////////////////////////////////////////////////////// + LoopUtils::LoopCloningResult clone_result; + + Loop* cloned_loop = + loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_); + specialisation_pair.second = cloned_loop->GetPreHeaderBlock(); + + //////////////////////////////////// + // Step 4: Specialize the loop. 
// + //////////////////////////////////// + + { + SpecializeLoop(cloned_loop, condition, specialisation_value); + + /////////////////////////////////////////////////////////// + // Step 5: Connect convergent edges to the landing pads. // + /////////////////////////////////////////////////////////// + + for (uint32_t merge_bb_id : if_merging_blocks) { + BasicBlock* merge = context_->cfg()->block(merge_bb_id); + // We are in LCSSA so we only care about phi instructions. + merge->ForEachPhiInst( + [is_from_original_loop, &clone_result](Instruction* phi) { + uint32_t num_in_operands = phi->NumInOperands(); + for (uint32_t i = 0; i < num_in_operands; i += 2) { + uint32_t pred = phi->GetSingleWordInOperand(i + 1); + if (is_from_original_loop(pred)) { + pred = clone_result.value_map_.at(pred); + uint32_t incoming_value_id = phi->GetSingleWordInOperand(i); + // Not all the incoming values are coming from the loop. + ValueMapTy::iterator new_value = + clone_result.value_map_.find(incoming_value_id); + if (new_value != clone_result.value_map_.end()) { + incoming_value_id = new_value->second; + } + phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}}); + phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}}); + } + } + }); + } + } + function_->AddBasicBlocks(clone_result.cloned_bb_.begin(), + clone_result.cloned_bb_.end(), + ++FindBasicBlockPosition(if_block)); + } + + // Specialize the existing loop. + SpecializeLoop(loop_, condition, original_loop_constant_value); + BasicBlock* original_loop_target = loop_->GetPreHeaderBlock(); + + ///////////////////////////////////// + // Finally: connect the new loops. 
// + ///////////////////////////////////// + + // Delete the old jump + context_->KillInst(&*if_block->tail()); + InstructionBuilder builder(context_, if_block); + if (iv_opcode == SpvOpBranchConditional) { + assert(constant_branch.size() == 1); + builder.AddConditionalBranch( + condition->result_id(), original_loop_target->id(), + constant_branch[0].second->id(), + if_merge_block ? if_merge_block->id() : kInvalidId); + } else { + std::vector> targets; + for (auto& t : constant_branch) { + targets.emplace_back(t.first->GetInOperand(0).words, t.second->id()); + } + + builder.AddSwitch(condition->result_id(), original_loop_target->id(), + targets, + if_merge_block ? if_merge_block->id() : kInvalidId); + } + + switch_block_ = nullptr; + ordered_loop_blocks_.clear(); + + context_->InvalidateAnalysesExceptFor( + IRContext::Analysis::kAnalysisLoopAnalysis); + } + + private: + using ValueMapTy = std::unordered_map; + using BlockMapTy = std::unordered_map; + + Function* function_; + Loop* loop_; + LoopDescriptor& loop_desc_; + IRContext* context_; + + BasicBlock* switch_block_; + // Map between instructions and if they are dynamically uniform. + std::unordered_map dynamically_uniform_; + // The loop basic blocks in structured order. + std::vector ordered_loop_blocks_; + + // Returns the next usable id for the context. + uint32_t TakeNextId() { + // TODO(1841): Handle id overflow. + return context_->TakeNextId(); + } + + // Simplifies |loop| assuming the instruction |to_version_insn| takes the + // value |cst_value|. |block_range| is an iterator range returning the loop + // basic blocks in a structured order (dominator first). + // The function will ignore basic blocks returned by |block_range| if they + // does not belong to the loop. + // The set |dead_blocks| will contain all the dead basic blocks. + // + // Requirements: + // - |loop| must be in the LCSSA form; + // - |cst_value| must be constant. 
+ void SpecializeLoop(Loop* loop, Instruction* to_version_insn, + Instruction* cst_value) { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + // Predicate: true for basic blocks that are not inside |loop|. + std::function ignore_node; + ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); }; + + // Gather the (user, operand index) pairs of |to_version_insn| located inside + // |loop|. They are collected first and rewritten afterwards; presumably so + // the def/use manager is not updated while ForEachUse is walking it. + std::vector> use_list; + def_use_mgr->ForEachUse(to_version_insn, + [&use_list, &ignore_node, this]( + Instruction* inst, uint32_t operand_index) { + BasicBlock* bb = context_->get_instr_block(inst); + + if (!bb || ignore_node(bb->id())) { + // Out of the loop, the specialization does not + // apply any more. + return; + } + use_list.emplace_back(inst, operand_index); + }); + + // First pass: inject the specialized value into the loop (and only the + // loop). + for (auto use : use_list) { + Instruction* inst = use.first; + uint32_t operand_index = use.second; + + // |cst_value| must not be null: even for the default target of a switch + // the caller provides a concrete constant + // (see GetValueForDefaultPathForSwitch), so there is always a value to + // inject. + assert(cst_value && "We do not have a value to use."); + inst->SetOperand(operand_index, {cst_value->result_id()}); + // Refresh the recorded uses of |inst| now that its operand changed. + def_use_mgr->AnalyzeInstUse(inst); + } + } + + // Returns true if |var| is dynamically uniform. + // Note: this is currently approximated as uniform.
+ bool IsDynamicallyUniform(Instruction* var, const BasicBlock* entry, + const DominatorTree& post_dom_tree) { + assert(post_dom_tree.IsPostDominator()); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + auto it = dynamically_uniform_.find(var->result_id()); + + if (it != dynamically_uniform_.end()) return it->second; + + analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr(); + + bool& is_uniform = dynamically_uniform_[var->result_id()]; + is_uniform = false; + + dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform, + [&is_uniform](const Instruction&) { + is_uniform = true; + return false; + }); + if (is_uniform) { + return is_uniform; + } + + BasicBlock* parent = context_->get_instr_block(var); + if (!parent) { + return is_uniform = true; + } + + if (!post_dom_tree.Dominates(parent->id(), entry->id())) { + return is_uniform = false; + } + if (var->opcode() == SpvOpLoad) { + const uint32_t PtrTypeId = + def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id(); + const Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId); + uint32_t storage_class = + PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx); + if (storage_class != SpvStorageClassUniform && + storage_class != SpvStorageClassUniformConstant) { + return is_uniform = false; + } + } else { + if (!context_->IsCombinatorInstruction(var)) { + return is_uniform = false; + } + } + + return is_uniform = var->WhileEachInId([entry, &post_dom_tree, + this](const uint32_t* id) { + return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id), + entry, post_dom_tree); + }); + } + + // Returns true if |insn| is not a constant, but is loop invariant and + // dynamically uniform. 
+ bool IsConditionNonConstantLoopInvariant(Instruction* insn) { + assert(insn->IsBranch()); + assert(insn->opcode() != SpvOpBranch); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + Instruction* condition = def_use_mgr->GetDef(insn->GetOperand(0).words[0]); + if (condition->IsConstant()) { + return false; + } + + if (loop_->IsInsideLoop(condition)) { + return false; + } + + return IsDynamicallyUniform( + condition, function_->entry().get(), + context_->GetPostDominatorAnalysis(function_)->GetDomTree()); + } +}; + +} // namespace + +Pass::Status LoopUnswitchPass::Process() { + bool modified = false; + Module* module = context()->module(); + + // Process each function in the module + for (Function& f : *module) { + modified |= ProcessFunction(&f); + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool LoopUnswitchPass::ProcessFunction(Function* f) { + bool modified = false; + std::unordered_set processed_loop; + + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f); + + bool loop_changed = true; + while (loop_changed) { + loop_changed = false; + for (Loop& loop : + make_range(++TreeDFIterator(loop_descriptor.GetDummyRootLoop()), + TreeDFIterator())) { + if (processed_loop.count(&loop)) continue; + processed_loop.insert(&loop); + + LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor); + while (unswitcher.CanUnswitchLoop()) { + if (!loop.IsLCSSA()) { + LoopUtils(context(), &loop).MakeLoopClosedSSA(); + } + modified = true; + loop_changed = true; + unswitcher.PerformUnswitch(); + } + if (loop_changed) break; + } + } + + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_unswitch_pass.h b/source/opt/loop_unswitch_pass.h new file mode 100644 index 000000000..3ecdd6116 --- /dev/null +++ b/source/opt/loop_unswitch_pass.h @@ -0,0 +1,43 @@ +// Copyright (c) 2018 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_UNSWITCH_PASS_H_ +#define SOURCE_OPT_LOOP_UNSWITCH_PASS_H_ + +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// Implements the loop unswitch optimization. +// The loop unswitch hoists invariant "if" statements if the conditions are +// constant within the loop and clones the loop for each branch. +class LoopUnswitchPass : public Pass { + public: + const char* name() const override { return "loop-unswitch"; } + + // Processes the given |module|. Returns Status::Failure if errors occur when + // processing. Returns the corresponding Status::Success if processing is + // succesful to indicate whether changes have been made to the modue. + Pass::Status Process() override; + + private: + bool ProcessFunction(Function* f); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_UNSWITCH_PASS_H_ diff --git a/source/opt/loop_utils.cpp b/source/opt/loop_utils.cpp new file mode 100644 index 000000000..8c6d355d6 --- /dev/null +++ b/source/opt/loop_utils.cpp @@ -0,0 +1,694 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "source/cfa.h" +#include "source/opt/cfg.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_utils.h" + +namespace spvtools { +namespace opt { + +namespace { +// Returns true if |bb| dominates at least one block in |exits|, according to +// |dom_tree|. +static inline bool DominatesAnExit(BasicBlock* bb, + const std::unordered_set& exits, + const DominatorTree& dom_tree) { + for (BasicBlock* e_bb : exits) + if (dom_tree.Dominates(bb, e_bb)) return true; + return false; +} + +// Utility class to rewrite out-of-loop uses of an in-loop definition in terms +// of phi instructions to achieve a LCSSA form. +// For a given definition, the class user registers phi instructions using that +// definition in all loop exit blocks by which the definition escapes. +// Then, when rewriting a use of the definition, the rewriter walks the +// paths from the use to the loop exits. At each step, it will insert a phi +// instruction to merge the incoming value according to exit blocks definition. +class LCSSARewriter { + public: + LCSSARewriter(IRContext* context, const DominatorTree& dom_tree, + const std::unordered_set& exit_bb, + BasicBlock* merge_block) + : context_(context), + cfg_(context_->cfg()), + dom_tree_(dom_tree), + exit_bb_(exit_bb), + merge_block_id_(merge_block ?
merge_block->id() : 0) {} + + struct UseRewriter { + explicit UseRewriter(LCSSARewriter* base, const Instruction& def_insn) + : base_(base), def_insn_(def_insn) {} + // Rewrites the use of |def_insn_| by the instruction |user| at the index + // |operand_index| in terms of phi instruction. This recursively builds new + // phi instructions from |user| to the loop exit blocks' phis. The use of + // |def_insn_| in |user| is replaced by the relevant phi instruction at the + // end of the operation. + // It is assumed that |user| does not dominates any of the loop exit basic + // block. This operation does not update the def/use manager, instead it + // records what needs to be updated. The actual update is performed by + // UpdateManagers. + void RewriteUse(BasicBlock* bb, Instruction* user, uint32_t operand_index) { + assert( + (user->opcode() != SpvOpPhi || bb != GetParent(user)) && + "The root basic block must be the incoming edge if |user| is a phi " + "instruction"); + assert((user->opcode() == SpvOpPhi || bb == GetParent(user)) && + "The root basic block must be the instruction parent if |user| is " + "not " + "phi instruction"); + + Instruction* new_def = GetOrBuildIncoming(bb->id()); + + user->SetOperand(operand_index, {new_def->result_id()}); + rewritten_.insert(user); + } + + // In-place update of some managers (avoid full invalidation). + inline void UpdateManagers() { + analysis::DefUseManager* def_use_mgr = base_->context_->get_def_use_mgr(); + // Register all new definitions. + for (Instruction* insn : rewritten_) { + def_use_mgr->AnalyzeInstDef(insn); + } + // Register all new uses. + for (Instruction* insn : rewritten_) { + def_use_mgr->AnalyzeInstUse(insn); + } + } + + private: + // Return the basic block that |instr| belongs to. + BasicBlock* GetParent(Instruction* instr) { + return base_->context_->get_instr_block(instr); + } + + // Builds a phi instruction for the basic block |bb|. 
The function assumes + // that |defining_blocks| contains the list of basic block that define the + // usable value for each predecessor of |bb|. + inline Instruction* CreatePhiInstruction( + BasicBlock* bb, const std::vector& defining_blocks) { + std::vector incomings; + const std::vector& bb_preds = base_->cfg_->preds(bb->id()); + assert(bb_preds.size() == defining_blocks.size()); + for (size_t i = 0; i < bb_preds.size(); i++) { + incomings.push_back( + GetOrBuildIncoming(defining_blocks[i])->result_id()); + incomings.push_back(bb_preds[i]); + } + InstructionBuilder builder(base_->context_, &*bb->begin(), + IRContext::kAnalysisInstrToBlockMapping); + Instruction* incoming_phi = + builder.AddPhi(def_insn_.type_id(), incomings); + + rewritten_.insert(incoming_phi); + return incoming_phi; + } + + // Builds a phi instruction for the basic block |bb|, all incoming values + // will be |value|. + inline Instruction* CreatePhiInstruction(BasicBlock* bb, + const Instruction& value) { + std::vector incomings; + const std::vector& bb_preds = base_->cfg_->preds(bb->id()); + for (size_t i = 0; i < bb_preds.size(); i++) { + incomings.push_back(value.result_id()); + incomings.push_back(bb_preds[i]); + } + InstructionBuilder builder(base_->context_, &*bb->begin(), + IRContext::kAnalysisInstrToBlockMapping); + Instruction* incoming_phi = + builder.AddPhi(def_insn_.type_id(), incomings); + + rewritten_.insert(incoming_phi); + return incoming_phi; + } + + // Return the new def to use for the basic block |bb_id|. + // If |bb_id| does not have a suitable def to use then we: + // - return the common def used by all predecessors; + // - if there is no common def, then we build a new phi instr at the + // beginning of |bb_id| and return this new instruction. 
+ Instruction* GetOrBuildIncoming(uint32_t bb_id) { + assert(base_->cfg_->block(bb_id) != nullptr && "Unknown basic block"); + + Instruction*& incoming_phi = bb_to_phi_[bb_id]; + if (incoming_phi) { + return incoming_phi; + } + + BasicBlock* bb = &*base_->cfg_->block(bb_id); + // If this is an exit basic block, look if there already is an eligible + // phi instruction. An eligible phi has |def_insn_| as all incoming + // values. + if (base_->exit_bb_.count(bb)) { + // Look if there is an eligible phi in this block. + if (!bb->WhileEachPhiInst([&incoming_phi, this](Instruction* phi) { + for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { + if (phi->GetSingleWordInOperand(i) != def_insn_.result_id()) + return true; + } + incoming_phi = phi; + rewritten_.insert(incoming_phi); + return false; + })) { + return incoming_phi; + } + incoming_phi = CreatePhiInstruction(bb, def_insn_); + return incoming_phi; + } + + // Get the block that defines the value to use for each predecessor. + // If the vector has 1 value, then it means that this block does not need + // to build a phi instruction unless |bb_id| is the loop merge block. + const std::vector& defining_blocks = + base_->GetDefiningBlocks(bb_id); + + // Special case for structured loops: merge block might be different from + // the exit block set. To maintain structured properties it will ease + // transformations if the merge block also holds a phi instruction like + // the exit ones. 
+ if (defining_blocks.size() > 1 || bb_id == base_->merge_block_id_) { + if (defining_blocks.size() > 1) { + incoming_phi = CreatePhiInstruction(bb, defining_blocks); + } else { + assert(bb_id == base_->merge_block_id_); + incoming_phi = + CreatePhiInstruction(bb, *GetOrBuildIncoming(defining_blocks[0])); + } + } else { + incoming_phi = GetOrBuildIncoming(defining_blocks[0]); + } + + return incoming_phi; + } + + LCSSARewriter* base_; + const Instruction& def_insn_; + std::unordered_map bb_to_phi_; + std::unordered_set rewritten_; + }; + + private: + // Returns the ids of the blocks holding the definition to use when entering + // the basic block |bb_id|. The result is memoized in + // |bb_to_defining_blocks_|: + // - a single id when one block provides the value on every path (an exit + // block dominating |bb_id|, or a block common to all predecessors), in + // which case no phi instruction is needed; + // - otherwise one id per predecessor of |bb_id|, in predecessor order. + const std::vector& GetDefiningBlocks(uint32_t bb_id) { + assert(cfg_->block(bb_id) != nullptr && "Unknown basic block"); + std::vector& defining_blocks = bb_to_defining_blocks_[bb_id]; + + // Already computed for this block. + if (defining_blocks.size()) return defining_blocks; + + // Check if one of the loop exit basic blocks dominates |bb_id|. + for (const BasicBlock* e_bb : exit_bb_) { + if (dom_tree_.Dominates(e_bb->id(), bb_id)) { + defining_blocks.push_back(e_bb->id()); + return defining_blocks; + } + } + + // Process the predecessors, they will return their suitable blocks. + // If they are all the same, this means this basic block is dominated by a + // common block, so we won't need to build a phi instruction. 
+ for (uint32_t pred_id : cfg_->preds(bb_id)) { + const std::vector& pred_blocks = GetDefiningBlocks(pred_id); + if (pred_blocks.size() == 1) + defining_blocks.push_back(pred_blocks[0]); + else + defining_blocks.push_back(pred_id); + } + assert(defining_blocks.size()); + if (std::all_of(defining_blocks.begin(), defining_blocks.end(), + [&defining_blocks](uint32_t id) { + return id == defining_blocks[0]; + })) { + // No need for a phi.
+ defining_blocks.resize(1); + } + + return defining_blocks; + } + + IRContext* context_; + CFG* cfg_; + const DominatorTree& dom_tree_; + const std::unordered_set& exit_bb_; + uint32_t merge_block_id_; + // This map represent the set of known paths. For each key, the vector + // represent the set of blocks holding the definition to be used to build the + // phi instruction. + // If the vector has 0 value, then the path is unknown yet, and must be built. + // If the vector has 1 value, then the value defined by that basic block + // should be used. + // If the vector has more than 1 value, then a phi node must be created, the + // basic block ordering is the same as the predecessor ordering. + std::unordered_map> bb_to_defining_blocks_; +}; + +// Make the set |blocks| closed SSA. The set is closed SSA if all the uses +// outside the set are phi instructions in exiting basic block set (hold by +// |lcssa_rewriter|). +inline void MakeSetClosedSSA(IRContext* context, Function* function, + const std::unordered_set& blocks, + const std::unordered_set& exit_bb, + LCSSARewriter* lcssa_rewriter) { + CFG& cfg = *context->cfg(); + DominatorTree& dom_tree = + context->GetDominatorAnalysis(function)->GetDomTree(); + analysis::DefUseManager* def_use_manager = context->get_def_use_mgr(); + + for (uint32_t bb_id : blocks) { + BasicBlock* bb = cfg.block(bb_id); + // If bb does not dominate an exit block, then it cannot have escaping defs. 
+ if (!DominatesAnExit(bb, exit_bb, dom_tree)) continue; + for (Instruction& inst : *bb) { + LCSSARewriter::UseRewriter rewriter(lcssa_rewriter, inst); + def_use_manager->ForEachUse( + &inst, [&blocks, &rewriter, &exit_bb, context]( + Instruction* use, uint32_t operand_index) { + BasicBlock* use_parent = context->get_instr_block(use); + assert(use_parent); + if (blocks.count(use_parent->id())) return; + + if (use->opcode() == SpvOpPhi) { + // If the use is a Phi instruction and the incoming block is + // coming from the loop, then that's consistent with LCSSA form. + if (exit_bb.count(use_parent)) { + return; + } else { + // That's not an exit block, but the user is a phi instruction. + // Consider the incoming branch only. + use_parent = context->get_instr_block( + use->GetSingleWordOperand(operand_index + 1)); + } + } + // Rewrite the use. Note that this call does not invalidate the + // def/use manager. So this operation is safe. + rewriter.RewriteUse(use_parent, use, operand_index); + }); + rewriter.UpdateManagers(); + } + } +} + +} // namespace + +void LoopUtils::CreateLoopDedicatedExits() { + Function* function = loop_->GetHeaderBlock()->GetParent(); + LoopDescriptor& loop_desc = *context_->GetLoopDescriptor(function); + CFG& cfg = *context_->cfg(); + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + const IRContext::Analysis PreservedAnalyses = + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping; + + // Gathers the set of basic blocks that are not in this loop and have at least + // one predecessor in the loop and one not in the loop. + std::unordered_set exit_bb_set; + loop_->GetExitBlocks(&exit_bb_set); + + std::unordered_set new_loop_exits; + bool made_change = false; + // For each block, we create a new one that gathers all branches from + // the loop and branches to the block. 
+ for (uint32_t non_dedicate_id : exit_bb_set) { + BasicBlock* non_dedicate = cfg.block(non_dedicate_id); + const std::vector& bb_pred = cfg.preds(non_dedicate_id); + // Ignore the block if all the predecessors are in the loop. + if (std::all_of(bb_pred.begin(), bb_pred.end(), + [this](uint32_t id) { return loop_->IsInsideLoop(id); })) { + new_loop_exits.insert(non_dedicate); + continue; + } + + made_change = true; + Function::iterator insert_pt = function->begin(); + for (; insert_pt != function->end() && &*insert_pt != non_dedicate; + ++insert_pt) { + } + assert(insert_pt != function->end() && "Basic Block not found"); + + // Create the dedicated exit basic block. + // TODO(1841): Handle id overflow. + BasicBlock& exit = *insert_pt.InsertBefore(std::unique_ptr( + new BasicBlock(std::unique_ptr(new Instruction( + context_, SpvOpLabel, 0, context_->TakeNextId(), {}))))); + exit.SetParent(function); + + // Redirect in-loop predecessors to the |exit| block. + for (uint32_t exit_pred_id : bb_pred) { + if (loop_->IsInsideLoop(exit_pred_id)) { + BasicBlock* pred_block = cfg.block(exit_pred_id); + pred_block->ForEachSuccessorLabel([non_dedicate, &exit](uint32_t* id) { + if (*id == non_dedicate->id()) *id = exit.id(); + }); + // Update the CFG. + // |non_dedicate|'s predecessor list will be updated at the end of the + // loop. + cfg.RegisterBlock(pred_block); + } + } + + // Register the label to the def/use manager, required for the phi patching. + def_use_mgr->AnalyzeInstDefUse(exit.GetLabelInst()); + context_->set_instr_block(exit.GetLabelInst(), &exit); + + InstructionBuilder builder(context_, &exit, PreservedAnalyses); + // Now jump from our dedicated basic block to the old exit. + // We also reset the insert point so all instructions are inserted before + // the branch. + builder.SetInsertPoint(builder.AddBranch(non_dedicate->id())); + non_dedicate->ForEachPhiInst( + [&builder, &exit, def_use_mgr, this](Instruction* phi) { + // New phi operands for this instruction. 
+ std::vector new_phi_op; + // Phi operands for the dedicated exit block. + std::vector exit_phi_op; + for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { + uint32_t def_id = phi->GetSingleWordInOperand(i); + uint32_t incoming_id = phi->GetSingleWordInOperand(i + 1); + if (loop_->IsInsideLoop(incoming_id)) { + exit_phi_op.push_back(def_id); + exit_phi_op.push_back(incoming_id); + } else { + new_phi_op.push_back(def_id); + new_phi_op.push_back(incoming_id); + } + } + + // Build the new phi instruction in the dedicated exit block. + Instruction* exit_phi = builder.AddPhi(phi->type_id(), exit_phi_op); + // Build the new incoming branch. + new_phi_op.push_back(exit_phi->result_id()); + new_phi_op.push_back(exit.id()); + // Rewrite operands. + uint32_t idx = 0; + for (; idx < new_phi_op.size(); idx++) + phi->SetInOperand(idx, {new_phi_op[idx]}); + // Remove extra operands, from last to first (more efficient). + for (uint32_t j = phi->NumInOperands() - 1; j >= idx; j--) + phi->RemoveInOperand(j); + // Update the def/use manager for this |phi|. + def_use_mgr->AnalyzeInstUse(phi); + }); + // Update the CFG. + cfg.RegisterBlock(&exit); + cfg.RemoveNonExistingEdges(non_dedicate->id()); + new_loop_exits.insert(&exit); + // If non_dedicate is in a loop, add the new dedicated exit in that loop. 
+ if (Loop* parent_loop = loop_desc[non_dedicate]) + parent_loop->AddBasicBlock(&exit); + } + + if (new_loop_exits.size() == 1) { + loop_->SetMergeBlock(*new_loop_exits.begin()); + } + + if (made_change) { + context_->InvalidateAnalysesExceptFor( + PreservedAnalyses | IRContext::kAnalysisCFG | + IRContext::Analysis::kAnalysisLoopAnalysis); + } +} + +void LoopUtils::MakeLoopClosedSSA() { + CreateLoopDedicatedExits(); + + Function* function = loop_->GetHeaderBlock()->GetParent(); + CFG& cfg = *context_->cfg(); + DominatorTree& dom_tree = + context_->GetDominatorAnalysis(function)->GetDomTree(); + + std::unordered_set exit_bb; + { + std::unordered_set exit_bb_id; + loop_->GetExitBlocks(&exit_bb_id); + for (uint32_t bb_id : exit_bb_id) { + exit_bb.insert(cfg.block(bb_id)); + } + } + + LCSSARewriter lcssa_rewriter(context_, dom_tree, exit_bb, + loop_->GetMergeBlock()); + MakeSetClosedSSA(context_, function, loop_->GetBlocks(), exit_bb, + &lcssa_rewriter); + + // Make sure all defs post-dominated by the merge block have their last use no + // further than the merge block. + if (loop_->GetMergeBlock()) { + std::unordered_set merging_bb_id; + loop_->GetMergingBlocks(&merging_bb_id); + merging_bb_id.erase(loop_->GetMergeBlock()->id()); + // Reset the exit set, now only the merge block is the exit. + exit_bb.clear(); + exit_bb.insert(loop_->GetMergeBlock()); + // LCSSARewriter is reusable here only because it forces the creation of a + // phi instruction in the merge block. + MakeSetClosedSSA(context_, function, merging_bb_id, exit_bb, + &lcssa_rewriter); + } + + context_->InvalidateAnalysesExceptFor( + IRContext::Analysis::kAnalysisCFG | + IRContext::Analysis::kAnalysisDominatorAnalysis | + IRContext::Analysis::kAnalysisLoopAnalysis); +} + +Loop* LoopUtils::CloneLoop(LoopCloningResult* cloning_result) const { + // Compute the structured order of the loop basic blocks and store it in the + // vector ordered_loop_blocks. 
+ std::vector ordered_loop_blocks; + loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks); + + // Clone the loop. + return CloneLoop(cloning_result, ordered_loop_blocks); +} + +Loop* LoopUtils::CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result) { + // Clone the loop. + Loop* new_loop = CloneLoop(cloning_result); + + // Create a new exit block/label for the new loop. + // TODO(1841): Handle id overflow. + std::unique_ptr new_label{new Instruction( + context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})}; + std::unique_ptr new_exit_bb{new BasicBlock(std::move(new_label))}; + new_exit_bb->SetParent(loop_->GetMergeBlock()->GetParent()); + + // Create an unconditional branch to the header block. + InstructionBuilder builder{context_, new_exit_bb.get()}; + builder.AddBranch(loop_->GetHeaderBlock()->id()); + + // Save the ids of the new and old merge block. + const uint32_t old_merge_block = loop_->GetMergeBlock()->id(); + const uint32_t new_merge_block = new_exit_bb->id(); + + // Replace the uses of the old merge block in the new loop with the new merge + // block. + for (std::unique_ptr& basic_block : cloning_result->cloned_bb_) { + for (Instruction& inst : *basic_block) { + // For each operand in each instruction check if it is using the old merge + // block and change it to be the new merge block. + auto replace_merge_use = [old_merge_block, + new_merge_block](uint32_t* id) { + if (*id == old_merge_block) *id = new_merge_block; + }; + inst.ForEachInOperand(replace_merge_use); + } + } + + const uint32_t old_header = loop_->GetHeaderBlock()->id(); + const uint32_t new_header = new_loop->GetHeaderBlock()->id(); + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + + def_use->ForEachUse(old_header, + [new_header, this](Instruction* inst, uint32_t operand) { + if (!this->loop_->IsInsideLoop(inst)) + inst->SetOperand(operand, {new_header}); + }); + + // TODO(1841): Handle failure to create pre-header. 
+ def_use->ForEachUse( + loop_->GetOrCreatePreHeaderBlock()->id(), + [new_merge_block, this](Instruction* inst, uint32_t operand) { + if (this->loop_->IsInsideLoop(inst)) + inst->SetOperand(operand, {new_merge_block}); + + }); + new_loop->SetMergeBlock(new_exit_bb.get()); + + new_loop->SetPreHeaderBlock(loop_->GetPreHeaderBlock()); + + // Add the new block into the cloned instructions. + cloning_result->cloned_bb_.push_back(std::move(new_exit_bb)); + + return new_loop; +} + +Loop* LoopUtils::CloneLoop( + LoopCloningResult* cloning_result, + const std::vector& ordered_loop_blocks) const { + analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); + + std::unique_ptr new_loop = MakeUnique(context_); + + CFG& cfg = *context_->cfg(); + + // Clone and place blocks in a SPIR-V compliant order (dominators first). + for (BasicBlock* old_bb : ordered_loop_blocks) { + // For each basic block in the loop, we clone it and register the mapping + // between old and new ids. + BasicBlock* new_bb = old_bb->Clone(context_); + new_bb->SetParent(&function_); + // TODO(1841): Handle id overflow. + new_bb->GetLabelInst()->SetResultId(context_->TakeNextId()); + def_use_mgr->AnalyzeInstDef(new_bb->GetLabelInst()); + context_->set_instr_block(new_bb->GetLabelInst(), new_bb); + cloning_result->cloned_bb_.emplace_back(new_bb); + + cloning_result->old_to_new_bb_[old_bb->id()] = new_bb; + cloning_result->new_to_old_bb_[new_bb->id()] = old_bb; + cloning_result->value_map_[old_bb->id()] = new_bb->id(); + + if (loop_->IsInsideLoop(old_bb)) new_loop->AddBasicBlock(new_bb); + + for (auto new_inst = new_bb->begin(), old_inst = old_bb->begin(); + new_inst != new_bb->end(); ++new_inst, ++old_inst) { + cloning_result->ptr_map_[&*new_inst] = &*old_inst; + if (new_inst->HasResultId()) { + // TODO(1841): Handle id overflow. 
+ new_inst->SetResultId(context_->TakeNextId()); + cloning_result->value_map_[old_inst->result_id()] = + new_inst->result_id(); + + // Only look at the defs for now, uses are not updated yet. + def_use_mgr->AnalyzeInstDef(&*new_inst); + } + } + } + + // All instructions (including all labels) have been cloned, + // remap instruction operands id with the new ones. + for (std::unique_ptr& bb_ref : cloning_result->cloned_bb_) { + BasicBlock* bb = bb_ref.get(); + + for (Instruction& insn : *bb) { + insn.ForEachInId([cloning_result](uint32_t* old_id) { + // If the operand is defined in the loop, remap the id. + auto id_it = cloning_result->value_map_.find(*old_id); + if (id_it != cloning_result->value_map_.end()) { + *old_id = id_it->second; + } + }); + // Only look at what the instruction uses. All defs are register, so all + // should be fine now. + def_use_mgr->AnalyzeInstUse(&insn); + context_->set_instr_block(&insn, bb); + } + cfg.RegisterBlock(bb); + } + + PopulateLoopNest(new_loop.get(), *cloning_result); + + return new_loop.release(); +} + +void LoopUtils::PopulateLoopNest( + Loop* new_loop, const LoopCloningResult& cloning_result) const { + std::unordered_map loop_mapping; + loop_mapping[loop_] = new_loop; + + if (loop_->HasParent()) loop_->GetParent()->AddNestedLoop(new_loop); + PopulateLoopDesc(new_loop, loop_, cloning_result); + + for (Loop& sub_loop : + make_range(++TreeDFIterator(loop_), TreeDFIterator())) { + Loop* cloned = new Loop(context_); + if (Loop* parent = loop_mapping[sub_loop.GetParent()]) + parent->AddNestedLoop(cloned); + loop_mapping[&sub_loop] = cloned; + PopulateLoopDesc(cloned, &sub_loop, cloning_result); + } + + loop_desc_->AddLoopNest(std::unique_ptr(new_loop)); +} + +// Populates |new_loop| descriptor according to |old_loop|'s one. 
+void LoopUtils::PopulateLoopDesc( + Loop* new_loop, Loop* old_loop, + const LoopCloningResult& cloning_result) const { + for (uint32_t bb_id : old_loop->GetBlocks()) { + BasicBlock* bb = cloning_result.old_to_new_bb_.at(bb_id); + new_loop->AddBasicBlock(bb); + } + new_loop->SetHeaderBlock( + cloning_result.old_to_new_bb_.at(old_loop->GetHeaderBlock()->id())); + if (old_loop->GetLatchBlock()) + new_loop->SetLatchBlock( + cloning_result.old_to_new_bb_.at(old_loop->GetLatchBlock()->id())); + if (old_loop->GetContinueBlock()) + new_loop->SetContinueBlock( + cloning_result.old_to_new_bb_.at(old_loop->GetContinueBlock()->id())); + if (old_loop->GetMergeBlock()) { + auto it = + cloning_result.old_to_new_bb_.find(old_loop->GetMergeBlock()->id()); + BasicBlock* bb = it != cloning_result.old_to_new_bb_.end() + ? it->second + : old_loop->GetMergeBlock(); + new_loop->SetMergeBlock(bb); + } + if (old_loop->GetPreHeaderBlock()) { + auto it = + cloning_result.old_to_new_bb_.find(old_loop->GetPreHeaderBlock()->id()); + if (it != cloning_result.old_to_new_bb_.end()) { + new_loop->SetPreHeaderBlock(it->second); + } + } +} + +// Counts |loop|'s instructions, skipping labels, nops and phis. +void CodeMetrics::Analyze(const Loop& loop) { + CFG& cfg = *loop.GetContext()->cfg(); + + roi_size_ = 0; + block_sizes_.clear(); + + for (uint32_t id : loop.GetBlocks()) { + const BasicBlock* bb = cfg.block(id); + size_t bb_size = 0; + bb->ForEachInst([&bb_size](const Instruction* insn) { + if (insn->opcode() == SpvOpLabel) return; + if (insn->IsNop()) return; + if (insn->opcode() == SpvOpPhi) return; + bb_size++; + }); + block_sizes_[bb->id()] = bb_size; + roi_size_ += bb_size; + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_utils.h b/source/opt/loop_utils.h new file mode 100644 index 000000000..a4e61900b --- /dev/null +++ b/source/opt/loop_utils.h @@ -0,0 +1,182 @@ +// Copyright (c) 2018 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_UTILS_H_ +#define SOURCE_OPT_LOOP_UTILS_H_ + +#include +#include +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/loop_descriptor.h" + +namespace spvtools { + +namespace opt { + +// Class to gather some metrics about a Region Of Interest (ROI). +// So far it counts the number of instructions in a ROI (excluding debug, +// label, nop and phi instructions) per basic block and in total. +struct CodeMetrics { + void Analyze(const Loop& loop); + + // The number of instructions per basic block in the ROI. + std::unordered_map block_sizes_; + + // Number of instructions in the ROI. + size_t roi_size_; +}; + +// LoopUtils is used to encapsulate loop optimizations and separate them from +// the passes which use them. Any pass which needs a loop optimization should +// do it through this or through a pass which is using this. +class LoopUtils { + public: + // Holds the auxiliary results of the loop cloning procedure. + struct LoopCloningResult { + using ValueMapTy = std::unordered_map; + using BlockMapTy = std::unordered_map; + using PtrMap = std::unordered_map; + + PtrMap ptr_map_; + + // Mapping between the original loop ids and the new one. + ValueMapTy value_map_; + // Mapping between original loop blocks to the cloned one. + BlockMapTy old_to_new_bb_; + // Mapping between the cloned loop blocks to original one. + BlockMapTy new_to_old_bb_; + // List of cloned basic blocks. 
+ std::vector> cloned_bb_; + }; + + LoopUtils(IRContext* context, Loop* loop) + : context_(context), + loop_desc_( + context->GetLoopDescriptor(loop->GetHeaderBlock()->GetParent())), + loop_(loop), + function_(*loop_->GetHeaderBlock()->GetParent()) {} + + // This converts the current loop to loop closed SSA form. + // In the loop closed SSA, all loop exiting values go through a dedicated Phi + // instruction. For instance: + // + // for (...) { + // A1 = ... + // if (...) + // A2 = ... + // A = phi A1, A2 + // } + // ... = op A ... + // + // Becomes + // + // for (...) { + // A1 = ... + // if (...) + // A2 = ... + // A = phi A1, A2 + // } + // C = phi A + // ... = op C ... + // + // This makes some loop transformations (such as loop unswitch) simpler + // (removes the need to take care of exiting variables). + void MakeLoopClosedSSA(); + + // Create dedicated exit basic blocks. This ensures all exit basic blocks + // have the loop as sole predecessor. + // By construction, structured control flow already has a dedicated exit + // block. + // Preserves: CFG, def/use and instruction to block mapping. + void CreateLoopDedicatedExits(); + + // Clone |loop_| and remap its instructions. Newly created blocks + // will be added to the |cloning_result.cloned_bb_| list, correctly ordered to + // be inserted into a function. + // It is assumed that |ordered_loop_blocks| is compatible with the result of + // |Loop::ComputeLoopStructuredOrder|. If the preheader and merge block are in + // the list they will also be cloned. If not, the resulting loop will share + // them with the original loop. + // The function preserves the def/use, cfg and instr to block analyses. + // The cloned loop nest will be added to the loop descriptor and will have + // ownership. + Loop* CloneLoop(LoopCloningResult* cloning_result, + const std::vector& ordered_loop_blocks) const; + // Clone |loop_| and remap its instructions, as above. 
Overload to compute + // loop block ordering within method rather than taking in as parameter. + Loop* CloneLoop(LoopCloningResult* cloning_result) const; + + // Clone the |loop_| and make the new loop branch to the second loop on exit. + Loop* CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result); + + // Perform a partial unroll of |loop| by given |factor|. This will copy the + // body of the loop |factor| times. So a |factor| of one would give a new loop + // with the original body plus one unrolled copy body. + bool PartiallyUnroll(size_t factor); + + // Fully unroll |loop|. + bool FullyUnroll(); + + // This function validates that |loop| meets the assumptions made by the + // implementation of the loop unroller. As the implementation accommodates + // more types of loops this function can reduce its checks. + // + // The conditions checked to ensure the loop can be unrolled are as follows: + // 1. That the loop is in structured order. + // 2. That the continue block is a branch to the header. + // 3. That the only phi used in the loop is the induction variable. + // TODO(stephen@codeplay.com): This is a temporary measure, after the loop is + // converted into LCSSA form and has a single entry and exit we can rewrite + // the other phis. + // 4. That this is an innermost loop, or that loops contained within this + // loop have already been fully unrolled. + // 5. That each instruction in the loop is only used within the loop. + // (Related to the above phi condition). + bool CanPerformUnroll(); + + // Maintains the loop descriptor object after the unroll functions have been + // called, otherwise the analysis should be invalidated. + void Finalize(); + + // Returns the context associated with |loop_|. + IRContext* GetContext() { return context_; } + // Returns the loop descriptor owning |loop_|. + LoopDescriptor* GetLoopDescriptor() { return loop_desc_; } + // Returns the loop on which the object operates. 
+ Loop* GetLoop() const { return loop_; } + // Returns the function that |loop_| belongs to. + Function* GetFunction() const { return &function_; } + + private: + IRContext* context_; + LoopDescriptor* loop_desc_; + Loop* loop_; + Function& function_; + + // Populates the loop nest of |new_loop| according to |loop_| nest. + void PopulateLoopNest(Loop* new_loop, + const LoopCloningResult& cloning_result) const; + + // Populates |new_loop| descriptor according to |old_loop|'s one. + void PopulateLoopDesc(Loop* new_loop, Loop* old_loop, + const LoopCloningResult& cloning_result) const; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_UTILS_H_ diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp new file mode 100644 index 000000000..cc0976767 --- /dev/null +++ b/source/opt/mem_pass.cpp @@ -0,0 +1,495 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/mem_pass.h" + +#include +#include +#include + +#include "source/cfa.h" +#include "source/opt/basic_block.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/ir_context.h" +#include "source/opt/iterator.h" + +namespace spvtools { +namespace opt { + +namespace { + +const uint32_t kCopyObjectOperandInIdx = 0; +const uint32_t kTypePointerStorageClassInIdx = 0; +const uint32_t kTypePointerTypeIdInIdx = 1; + +} // namespace + +bool MemPass::IsBaseTargetType(const Instruction* typeInst) const { + switch (typeInst->opcode()) { + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypeBool: + case SpvOpTypeVector: + case SpvOpTypeMatrix: + case SpvOpTypeImage: + case SpvOpTypeSampler: + case SpvOpTypeSampledImage: + case SpvOpTypePointer: + return true; + default: + break; + } + return false; +} + +bool MemPass::IsTargetType(const Instruction* typeInst) const { + if (IsBaseTargetType(typeInst)) return true; + if (typeInst->opcode() == SpvOpTypeArray) { + if (!IsTargetType( + get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) { + return false; + } + return true; + } + if (typeInst->opcode() != SpvOpTypeStruct) return false; + // All struct members must be target types + return typeInst->WhileEachInId([this](const uint32_t* tid) { + Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid); + if (!IsTargetType(compTypeInst)) return false; + return true; + }); +} + +bool MemPass::IsNonPtrAccessChain(const SpvOp opcode) const { + return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain; +} + +bool MemPass::IsPtr(uint32_t ptrId) { + uint32_t varId = ptrId; + Instruction* ptrInst = get_def_use_mgr()->GetDef(varId); + while (ptrInst->opcode() == SpvOpCopyObject) { + varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx); + ptrInst = get_def_use_mgr()->GetDef(varId); + } + const SpvOp op = ptrInst->opcode(); + if (op == SpvOpVariable || IsNonPtrAccessChain(op)) return true; + if (op != 
SpvOpFunctionParameter) return false; + const uint32_t varTypeId = ptrInst->type_id(); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + return varTypeInst->opcode() == SpvOpTypePointer; +} + +Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) { + *varId = ptrId; + Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId); + Instruction* varInst; + + if (ptrInst->opcode() != SpvOpVariable && + ptrInst->opcode() != SpvOpFunctionParameter) { + varInst = ptrInst->GetBaseAddress(); + } else { + varInst = ptrInst; + } + if (varInst->opcode() == SpvOpVariable) { + *varId = varInst->result_id(); + } else { + *varId = 0; + } + + while (ptrInst->opcode() == SpvOpCopyObject) { + uint32_t temp = ptrInst->GetSingleWordInOperand(0); + ptrInst = get_def_use_mgr()->GetDef(temp); + } + + return ptrInst; +} + +Instruction* MemPass::GetPtr(Instruction* ip, uint32_t* varId) { + assert(ip->opcode() == SpvOpStore || ip->opcode() == SpvOpLoad || + ip->opcode() == SpvOpImageTexelPointer || ip->IsAtomicWithLoad()); + + // All of these opcodes place the pointer in position 0. + const uint32_t ptrId = ip->GetSingleWordInOperand(0); + return GetPtr(ptrId, varId); +} + +bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const { + return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) { + SpvOp op = user->opcode(); + if (op != SpvOpName && !IsNonTypeDecorate(op)) { + return false; + } + return true; + }); +} + +void MemPass::KillAllInsts(BasicBlock* bp, bool killLabel) { + bp->KillAllInsts(killLabel); +} + +bool MemPass::HasLoads(uint32_t varId) const { + return !get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) { + SpvOp op = user->opcode(); + // TODO(): The following is slightly conservative. Could be + // better handling of non-store/name. 
+ if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) { + if (HasLoads(user->result_id())) { + return false; + } + } else if (op != SpvOpStore && op != SpvOpName && !IsNonTypeDecorate(op)) { + return false; + } + return true; + }); +} + +bool MemPass::IsLiveVar(uint32_t varId) const { + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); + // assume live if not a variable eg. function parameter + if (varInst->opcode() != SpvOpVariable) return true; + // non-function scope vars are live + const uint32_t varTypeId = varInst->type_id(); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) != + SpvStorageClassFunction) + return true; + // test if variable is loaded from + return HasLoads(varId); +} + +void MemPass::AddStores(uint32_t ptr_id, std::queue* insts) { + get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) { + SpvOp op = user->opcode(); + if (IsNonPtrAccessChain(op)) { + AddStores(user->result_id(), insts); + } else if (op == SpvOpStore) { + insts->push(user); + } + }); +} + +void MemPass::DCEInst(Instruction* inst, + const std::function& call_back) { + std::queue deadInsts; + deadInsts.push(inst); + while (!deadInsts.empty()) { + Instruction* di = deadInsts.front(); + // Don't delete labels + if (di->opcode() == SpvOpLabel) { + deadInsts.pop(); + continue; + } + // Remember operands + std::set ids; + di->ForEachInId([&ids](uint32_t* iid) { ids.insert(*iid); }); + uint32_t varId = 0; + // Remember variable if dead load + if (di->opcode() == SpvOpLoad) (void)GetPtr(di, &varId); + if (call_back) { + call_back(di); + } + context()->KillInst(di); + // For all operands with no remaining uses, add their instruction + // to the dead instruction queue. 
+ for (auto id : ids) + if (HasOnlyNamesAndDecorates(id)) { + Instruction* odi = get_def_use_mgr()->GetDef(id); + if (context()->IsCombinatorInstruction(odi)) deadInsts.push(odi); + } + // if a load was deleted and it was the variable's + // last load, add all its stores to dead queue + if (varId != 0 && !IsLiveVar(varId)) AddStores(varId, &deadInsts); + deadInsts.pop(); + } +} + +MemPass::MemPass() {} + +bool MemPass::HasOnlySupportedRefs(uint32_t varId) { + return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) { + SpvOp op = user->opcode(); + if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName && + !IsNonTypeDecorate(op)) { + return false; + } + return true; + }); +} + +uint32_t MemPass::Type2Undef(uint32_t type_id) { + const auto uitr = type2undefs_.find(type_id); + if (uitr != type2undefs_.end()) return uitr->second; + const uint32_t undefId = TakeNextId(); + std::unique_ptr undef_inst( + new Instruction(context(), SpvOpUndef, type_id, undefId, {})); + get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst); + get_module()->AddGlobalValue(std::move(undef_inst)); + type2undefs_[type_id] = undefId; + return undefId; +} + +bool MemPass::IsTargetVar(uint32_t varId) { + if (varId == 0) { + return false; + } + + if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end()) + return false; + if (seen_target_vars_.find(varId) != seen_target_vars_.end()) return true; + const Instruction* varInst = get_def_use_mgr()->GetDef(varId); + if (varInst->opcode() != SpvOpVariable) return false; + const uint32_t varTypeId = varInst->type_id(); + const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); + if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) != + SpvStorageClassFunction) { + seen_non_target_vars_.insert(varId); + return false; + } + const uint32_t varPteTypeId = + varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx); + Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId); + if 
(!IsTargetType(varPteTypeInst)) { + seen_non_target_vars_.insert(varId); + return false; + } + seen_target_vars_.insert(varId); + return true; +} + +// Remove all |phi| operands coming from unreachable blocks (i.e., blocks not in +// |reachable_blocks|). There are two types of removal that this function can +// perform: +// +// 1- Any operand that comes directly from an unreachable block is completely +// removed. Since the block is unreachable, the edge between the unreachable +// block and the block holding |phi| has been removed. +// +// 2- Any operand that comes via a live block and was defined at an unreachable +// block gets its value replaced with an OpUndef value. Since the argument +// was generated in an unreachable block, it no longer exists, so it cannot +// be referenced. However, since the value does not reach |phi| directly +// from the unreachable block, the operand cannot be removed from |phi|. +// Therefore, we replace the argument value with OpUndef. +// +// For example, in the switch() below, assume that we want to remove the +// argument with value %11 coming from block %41. +// +// [ ... ] +// %41 = OpLabel <--- Unreachable block +// %11 = OpLoad %int %y +// [ ... ] +// OpSelectionMerge %16 None +// OpSwitch %12 %16 10 %13 13 %14 18 %15 +// %13 = OpLabel +// OpBranch %16 +// %14 = OpLabel +// OpStore %outparm %int_14 +// OpBranch %16 +// %15 = OpLabel +// OpStore %outparm %int_15 +// OpBranch %16 +// %16 = OpLabel +// %30 = OpPhi %int %11 %41 %int_42 %13 %11 %14 %11 %15 +// +// Since %41 is now an unreachable block, the first operand of |phi| needs to +// be removed completely. But the operands (%11 %14) and (%11 %15) cannot be +// removed because %14 and %15 are reachable blocks. Since %11 no longer exist, +// in those arguments, we replace all references to %11 with an OpUndef value. +// This results in |phi| looking like: +// +// %50 = OpUndef %int +// [ ... 
] +// %30 = OpPhi %int %int_42 %13 %50 %14 %50 %15 +void MemPass::RemovePhiOperands( + Instruction* phi, const std::unordered_set& reachable_blocks) { + std::vector keep_operands; + uint32_t type_id = 0; + // The id of an undefined value we've generated. + uint32_t undef_id = 0; + + // Traverse all the operands in |phi|. Build the new operand vector by adding + // all the original operands from |phi| except the unwanted ones. + for (uint32_t i = 0; i < phi->NumOperands();) { + if (i < 2) { + // The first two arguments are always preserved. + keep_operands.push_back(phi->GetOperand(i)); + ++i; + continue; + } + + // The remaining Phi arguments come in pairs. Index 'i' contains the + // variable id, index 'i + 1' is the originating block id. + assert(i % 2 == 0 && i < phi->NumOperands() - 1 && + "malformed Phi arguments"); + + BasicBlock* in_block = cfg()->block(phi->GetSingleWordOperand(i + 1)); + if (reachable_blocks.find(in_block) == reachable_blocks.end()) { + // If the incoming block is unreachable, remove both operands as this + // means that the |phi| has lost an incoming edge. + i += 2; + continue; + } + + // In all other cases, the operand must be kept but may need to be changed. + uint32_t arg_id = phi->GetSingleWordOperand(i); + Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id); + BasicBlock* def_block = context()->get_instr_block(arg_def_instr); + if (def_block && + reachable_blocks.find(def_block) == reachable_blocks.end()) { + // If the current |phi| argument was defined in an unreachable block, it + // means that this |phi| argument is no longer defined. Replace it with + // |undef_id|. + if (!undef_id) { + type_id = arg_def_instr->type_id(); + undef_id = Type2Undef(type_id); + } + keep_operands.push_back( + Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id})); + } else { + // Otherwise, the argument comes from a reachable block or from no block + // at all (meaning that it was defined in the global section of the + // program). 
In both cases, keep the argument intact. + keep_operands.push_back(phi->GetOperand(i)); + } + + keep_operands.push_back(phi->GetOperand(i + 1)); + + i += 2; + } + + context()->ForgetUses(phi); + phi->ReplaceOperands(keep_operands); + context()->AnalyzeUses(phi); +} + +void MemPass::RemoveBlock(Function::iterator* bi) { + auto& rm_block = **bi; + + // Remove instructions from the block. + rm_block.ForEachInst([&rm_block, this](Instruction* inst) { + // Note that we do not kill the block label instruction here. The label + // instruction is needed to identify the block, which is needed by the + // removal of phi operands. + if (inst != rm_block.GetLabelInst()) { + context()->KillInst(inst); + } + }); + + // Remove the label instruction last. + auto label = rm_block.GetLabelInst(); + context()->KillInst(label); + + *bi = bi->Erase(); +} + +bool MemPass::RemoveUnreachableBlocks(Function* func) { + bool modified = false; + + // Mark reachable all blocks reachable from the function's entry block. + std::unordered_set reachable_blocks; + std::unordered_set visited_blocks; + std::queue worklist; + reachable_blocks.insert(func->entry().get()); + + // Initially mark the function entry point as reachable. + worklist.push(func->entry().get()); + + auto mark_reachable = [&reachable_blocks, &visited_blocks, &worklist, + this](uint32_t label_id) { + auto successor = cfg()->block(label_id); + if (visited_blocks.count(successor) == 0) { + reachable_blocks.insert(successor); + worklist.push(successor); + visited_blocks.insert(successor); + } + }; + + // Transitively mark all blocks reachable from the entry as reachable. + while (!worklist.empty()) { + BasicBlock* block = worklist.front(); + worklist.pop(); + + // All the successors of a live block are also live. + static_cast(block)->ForEachSuccessorLabel( + mark_reachable); + + // All the Merge and ContinueTarget blocks of a live block are also live. 
+ block->ForMergeAndContinueLabel(mark_reachable); + } + + // Update operands of Phi nodes that reference unreachable blocks. + for (auto& block : *func) { + // If the block is about to be removed, don't bother updating its + // Phi instructions. + if (reachable_blocks.count(&block) == 0) { + continue; + } + + // If the block is reachable and has Phi instructions, remove all + // operands from its Phi instructions that reference unreachable blocks. + // If the block has no Phi instructions, this is a no-op. + block.ForEachPhiInst([&reachable_blocks, this](Instruction* phi) { + RemovePhiOperands(phi, reachable_blocks); + }); + } + + // Erase unreachable blocks. + for (auto ebi = func->begin(); ebi != func->end();) { + if (reachable_blocks.count(&*ebi) == 0) { + RemoveBlock(&ebi); + modified = true; + } else { + ++ebi; + } + } + + return modified; +} + +bool MemPass::CFGCleanup(Function* func) { + bool modified = false; + modified |= RemoveUnreachableBlocks(func); + return modified; +} + +void MemPass::CollectTargetVars(Function* func) { + seen_target_vars_.clear(); + seen_non_target_vars_.clear(); + type2undefs_.clear(); + + // Collect target (and non-) variable sets. Remove variables with + // non-load/store refs from target variable set + for (auto& blk : *func) { + for (auto& inst : blk) { + switch (inst.opcode()) { + case SpvOpStore: + case SpvOpLoad: { + uint32_t varId; + (void)GetPtr(&inst, &varId); + if (!IsTargetVar(varId)) break; + if (HasOnlySupportedRefs(varId)) break; + seen_non_target_vars_.insert(varId); + seen_target_vars_.erase(varId); + } break; + default: + break; + } + } + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/mem_pass.h b/source/opt/mem_pass.h new file mode 100644 index 000000000..67ce26b13 --- /dev/null +++ b/source/opt/mem_pass.h @@ -0,0 +1,163 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_MEM_PASS_H_
+#define SOURCE_OPT_MEM_PASS_H_
+
+// NOTE(review): the bare #include directives below, and the std:: containers /
+// std::function further down that carry no template arguments, look as if the
+// text between angle brackets was stripped from this copy -- TODO confirm
+// against the upstream header before relying on these declarations.
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "source/opt/basic_block.h"
+#include "source/opt/def_use_manager.h"
+#include "source/opt/dominator_analysis.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// A common base class for mem2reg-type passes. Provides common
+// utility functions and supporting state: caches of variables classified as
+// target / non-target, a per-type cache of OpUndef ids, and CFG cleanup
+// helpers that remove unreachable blocks.
+class MemPass : public Pass {
+ public:
+  virtual ~MemPass() = default;
+
+  // Returns an undef value for the given |var_id|'s type, i.e. the result of
+  // Type2Undef() applied to GetPointeeTypeId() of the instruction that
+  // defines |var_id|.
+  uint32_t GetUndefVal(uint32_t var_id) {
+    return Type2Undef(GetPointeeTypeId(get_def_use_mgr()->GetDef(var_id)));
+  }
+
+  // Given a load or store |ip|, return the pointer instruction.
+  // Also return the base variable's id in |varId|. If no base variable is
+  // found, |varId| will be 0.
+  Instruction* GetPtr(Instruction* ip, uint32_t* varId);
+
+  // Return true if |varId| is a target variable, consulting and updating the
+  // |seen_target_vars_| and |seen_non_target_vars_| caches:
+  //  - an id of 0 is never a target;
+  //  - a previously identified non-target variable yields false, and a
+  //    previously identified target variable yields true (the non-target
+  //    cache is checked first);
+  //  - otherwise the result is true only if |varId| is defined by an
+  //    OpVariable whose pointer type has Function storage class and whose
+  //    pointee type satisfies IsTargetType().
+  // An id that is not defined by an OpVariable yields false without being
+  // cached; every other freshly computed answer is recorded in the matching
+  // cache.
+  //
+  // NOTE(review): the earlier wording of this comment also described
+  // non-target variables in terms of constant-index versus
+  // non-constant-index access chains. This function itself makes no such
+  // distinction; presumably derived passes refine the caches further --
+  // confirm against the callers.
+  bool IsTargetVar(uint32_t varId);
+
+  // Collect target SSA variables. This resets |seen_target_vars_|,
+  // |seen_non_target_vars_| and the type-to-undef cache, then traverses all
+  // the loads and stores in function |func| looking for variables that can be
+  // replaced with SSA IDs. A base variable that passes IsTargetVar() but
+  // fails HasOnlySupportedRefs() is moved from |seen_target_vars_| to
+  // |seen_non_target_vars_|.
+  void CollectTargetVars(Function* func);
+
+ protected:
+  MemPass();
+
+  // Returns true if |typeInst| is a scalar type
+  // or a vector or matrix
+  bool IsBaseTargetType(const Instruction* typeInst) const;
+
+  // Returns true if |typeInst| is a math type or a struct or array
+  // of a math type.
+  // TODO(): Add more complex types to convert
+  bool IsTargetType(const Instruction* typeInst) const;
+
+  // Returns true if |opcode| is a non-ptr access chain op
+  bool IsNonPtrAccessChain(const SpvOp opcode) const;
+
+  // Given the id |ptrId|, return true if the top-most non-CopyObj is
+  // a variable, a non-ptr access chain or a parameter of pointer type.
+  bool IsPtr(uint32_t ptrId);
+
+  // Given the id of a pointer |ptrId|, return the top-most non-CopyObj.
+  // Also return the base variable's id in |varId|. If no base variable is
+  // found, |varId| will be 0.
+  Instruction* GetPtr(uint32_t ptrId, uint32_t* varId);
+
+  // Return true if all uses of |id| are only name or decorate ops.
+  bool HasOnlyNamesAndDecorates(uint32_t id) const;
+
+  // Kill all instructions in block |bp|. Whether or not to kill the label is
+  // indicated by |killLabel|.
+  void KillAllInsts(BasicBlock* bp, bool killLabel = true);
+
+  // Return true if any instruction loads from |varId|
+  bool HasLoads(uint32_t varId) const;
+
+  // Return true if |varId| is not a function variable or if it has
+  // a load
+  bool IsLiveVar(uint32_t varId) const;
+
+  // Add stores using |ptr_id| to |insts|
+  void AddStores(uint32_t ptr_id, std::queue* insts);
+
+  // Delete |inst| and iterate DCE on all its operands if they are now
+  // useless. If a load is deleted and its variable has no other loads,
+  // delete all its variable's stores.
+  void DCEInst(Instruction* inst, const std::function&);
+
+  // Call all the cleanup helper functions on |func|; at present that is
+  // only RemoveUnreachableBlocks(). Returns true if |func| was modified.
+  bool CFGCleanup(Function* func);
+
+  // Return true if |op| is a supported decorate, i.e. OpDecorate or
+  // OpDecorateId.
+  inline bool IsNonTypeDecorate(uint32_t op) const {
+    return (op == SpvOpDecorate || op == SpvOpDecorateId);
+  }
+
+  // Return the id of an OpUndef of type |type_id|. Ids are memoized per type
+  // in |type2undefs_|. On a cache miss a new OpUndef is created with a fresh
+  // id, registered with the def-use manager, added to the module as a global
+  // value, and recorded in the cache.
+  uint32_t Type2Undef(uint32_t type_id);
+
+  // Cache of verified target vars
+  std::unordered_set seen_target_vars_;
+
+  // Cache of verified non-target vars
+  std::unordered_set seen_non_target_vars_;
+
+ private:
+  // Return true if all uses of |varId| are only through supported reference
+  // operations: every user must be an OpStore, OpLoad, OpName or a decoration
+  // accepted by IsNonTypeDecorate(). The answer is recomputed on every call;
+  // nothing is cached here.
+  // TODO(dnovillo): This function is replicated in other passes and it's
+  // slightly different in every pass. Is it possible to make one common
+  // implementation?
+  bool HasOnlySupportedRefs(uint32_t varId);
+
+  // Remove all the unreachable basic blocks in |func|. A block is kept if it
+  // can be reached from the entry block by following successor labels and
+  // merge / continue-target labels. Phi instructions in the kept blocks are
+  // first fixed up with RemovePhiOperands(). Returns true if at least one
+  // block was removed.
+  bool RemoveUnreachableBlocks(Function* func);
+
+  // Remove the block pointed by the iterator |*bi|. This also kills
+  // all the instructions in the pointed-to block, the label instruction last.
+  // On return |*bi| holds the iterator produced by erasing the block.
+  void RemoveBlock(Function::iterator* bi);
+
+  // Remove Phi operands in |phi| that are coming from blocks not in
+  // |reachable_blocks|. An operand that arrives from a reachable block but
+  // whose value was defined in an unreachable block is kept, with its value
+  // replaced by an OpUndef id.
+  void RemovePhiOperands(
+      Instruction* phi,
+      const std::unordered_set& reachable_blocks);
+
+  // Map from type id to the id of the OpUndef created for that type.
+  std::unordered_map type2undefs_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_MEM_PASS_H_
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
new file mode 100644
index 000000000..4f49c70f2
--- /dev/null
+++ b/source/opt/merge_return_pass.cpp
@@ -0,0 +1,816 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "source/opt/merge_return_pass.h" + +#include +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/opt/reflect.h" +#include "source/util/bit_vector.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +Pass::Status MergeReturnPass::Process() { + bool is_shader = + context()->get_feature_mgr()->HasCapability(SpvCapabilityShader); + + bool failed = false; + ProcessFunction pfn = [&failed, is_shader, this](Function* function) { + std::vector return_blocks = CollectReturnBlocks(function); + if (return_blocks.size() <= 1) { + return false; + } + + function_ = function; + return_flag_ = nullptr; + return_value_ = nullptr; + final_return_block_ = nullptr; + + if (is_shader) { + if (!ProcessStructured(function, return_blocks)) { + failed = true; + } + } else { + MergeReturnBlocks(function, return_blocks); + } + return true; + }; + + bool modified = context()->ProcessReachableCallTree(pfn); + + if (failed) { + return Status::Failure; + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool MergeReturnPass::ProcessStructured( + Function* function, const std::vector& return_blocks) { + if (HasNontrivialUnreachableBlocks(function)) { + if (consumer()) { + std::string message = + "Module contains unreachable blocks during merge return. 
Run dead " + "branch elimination before merge return."; + consumer()(SPV_MSG_ERROR, 0, {0, 0, 0}, message.c_str()); + } + return false; + } + + AddDummyLoopAroundFunction(); + + std::list order; + cfg()->ComputeStructuredOrder(function, &*function->begin(), &order); + + state_.clear(); + state_.emplace_back(nullptr, nullptr); + for (auto block : order) { + if (cfg()->IsPseudoEntryBlock(block) || cfg()->IsPseudoExitBlock(block) || + block == final_return_block_) { + continue; + } + + auto blockId = block->GetLabelInst()->result_id(); + if (blockId == CurrentState().CurrentMergeId()) { + // Pop the current state as we've hit the merge + state_.pop_back(); + } + + ProcessStructuredBlock(block); + + // Generate state for next block + if (Instruction* mergeInst = block->GetMergeInst()) { + Instruction* loopMergeInst = block->GetLoopMergeInst(); + if (!loopMergeInst) loopMergeInst = state_.back().LoopMergeInst(); + state_.emplace_back(loopMergeInst, mergeInst); + } + } + + state_.clear(); + state_.emplace_back(nullptr, nullptr); + std::unordered_set predicated; + for (auto block : order) { + if (cfg()->IsPseudoEntryBlock(block) || cfg()->IsPseudoExitBlock(block)) { + continue; + } + + auto blockId = block->id(); + if (blockId == CurrentState().CurrentMergeId()) { + // Pop the current state as we've hit the merge + state_.pop_back(); + } + + // Predicate successors of the original return blocks as necessary. + if (std::find(return_blocks.begin(), return_blocks.end(), block) != + return_blocks.end()) { + if (!PredicateBlocks(block, &predicated, &order)) { + return false; + } + } + + // Generate state for next block + if (Instruction* mergeInst = block->GetMergeInst()) { + Instruction* loopMergeInst = block->GetLoopMergeInst(); + if (!loopMergeInst) loopMergeInst = state_.back().LoopMergeInst(); + state_.emplace_back(loopMergeInst, mergeInst); + } + } + + // We have not kept the dominator tree up-to-date. + // Invalidate it at this point to make sure it will be rebuilt. 
+ context()->RemoveDominatorAnalysis(function); + AddNewPhiNodes(); + return true; +} + +void MergeReturnPass::CreateReturnBlock() { + // Create a label for the new return block + std::unique_ptr return_label( + new Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {})); + + // Create the new basic block + std::unique_ptr return_block( + new BasicBlock(std::move(return_label))); + function_->AddBasicBlock(std::move(return_block)); + final_return_block_ = &*(--function_->end()); + context()->AnalyzeDefUse(final_return_block_->GetLabelInst()); + context()->set_instr_block(final_return_block_->GetLabelInst(), + final_return_block_); + final_return_block_->SetParent(function_); +} + +void MergeReturnPass::CreateReturn(BasicBlock* block) { + AddReturnValue(); + + if (return_value_) { + // Load and return the final return value + uint32_t loadId = TakeNextId(); + block->AddInstruction(MakeUnique( + context(), SpvOpLoad, function_->type_id(), loadId, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {return_value_->result_id()}}})); + Instruction* var_inst = block->terminator(); + context()->AnalyzeDefUse(var_inst); + context()->set_instr_block(var_inst, block); + context()->get_decoration_mgr()->CloneDecorations( + return_value_->result_id(), loadId, {SpvDecorationRelaxedPrecision}); + + block->AddInstruction(MakeUnique( + context(), SpvOpReturnValue, 0, 0, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {loadId}}})); + context()->AnalyzeDefUse(block->terminator()); + context()->set_instr_block(block->terminator(), block); + } else { + block->AddInstruction(MakeUnique(context(), SpvOpReturn)); + context()->AnalyzeDefUse(block->terminator()); + context()->set_instr_block(block->terminator(), block); + } +} + +void MergeReturnPass::ProcessStructuredBlock(BasicBlock* block) { + SpvOp tail_opcode = block->tail()->opcode(); + if (tail_opcode == SpvOpReturn || tail_opcode == SpvOpReturnValue) { + if (!return_flag_) { + AddReturnFlag(); + } + } + + if (tail_opcode == 
SpvOpReturn || tail_opcode == SpvOpReturnValue || + tail_opcode == SpvOpUnreachable) { + assert(CurrentState().InLoop() && "Should be in the dummy loop."); + BranchToBlock(block, CurrentState().LoopMergeId()); + } +} + +void MergeReturnPass::BranchToBlock(BasicBlock* block, uint32_t target) { + if (block->tail()->opcode() == SpvOpReturn || + block->tail()->opcode() == SpvOpReturnValue) { + RecordReturned(block); + RecordReturnValue(block); + } + + BasicBlock* target_block = context()->get_instr_block(target); + if (target_block->GetLoopMergeInst()) { + cfg()->SplitLoopHeader(target_block); + } + UpdatePhiNodes(block, target_block); + + Instruction* return_inst = block->terminator(); + return_inst->SetOpcode(SpvOpBranch); + return_inst->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {target}}}); + context()->get_def_use_mgr()->AnalyzeInstDefUse(return_inst); + cfg()->AddEdge(block->id(), target); +} + +void MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source, + BasicBlock* target) { + target->ForEachPhiInst([this, new_source](Instruction* inst) { + uint32_t undefId = Type2Undef(inst->type_id()); + inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}}); + inst->AddOperand({SPV_OPERAND_TYPE_ID, {new_source->id()}}); + context()->UpdateDefUse(inst); + }); + + const auto& target_pred = cfg()->preds(target->id()); + if (target_pred.size() == 1) { + MarkForNewPhiNodes(target, context()->get_instr_block(target_pred[0])); + } +} + +void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block, + uint32_t predecessor, + Instruction& inst) { + DominatorAnalysis* dom_tree = + context()->GetDominatorAnalysis(merge_block->GetParent()); + BasicBlock* inst_bb = context()->get_instr_block(&inst); + + if (inst.result_id() != 0) { + std::vector users_to_update; + context()->get_def_use_mgr()->ForEachUser( + &inst, + [&users_to_update, &dom_tree, &inst, inst_bb, this](Instruction* user) { + BasicBlock* user_bb = nullptr; + if (user->opcode() != SpvOpPhi) { + user_bb = 
context()->get_instr_block(user); + } else { + // For OpPhi, the use should be considered to be in the predecessor. + for (uint32_t i = 0; i < user->NumInOperands(); i += 2) { + if (user->GetSingleWordInOperand(i) == inst.result_id()) { + uint32_t user_bb_id = user->GetSingleWordInOperand(i + 1); + user_bb = context()->get_instr_block(user_bb_id); + break; + } + } + } + + // If |user_bb| is nullptr, then |user| is not in the function. It is + // something like an OpName or decoration, which should not be + // replaced with the result of the OpPhi. + if (user_bb && !dom_tree->Dominates(inst_bb, user_bb)) { + users_to_update.push_back(user); + } + }); + + if (users_to_update.empty()) { + return; + } + + // There is at least one values that needs to be replaced. + // First create the OpPhi instruction. + InstructionBuilder builder(context(), &*merge_block->begin(), + IRContext::kAnalysisDefUse); + uint32_t undef_id = Type2Undef(inst.type_id()); + std::vector phi_operands; + + // Add the operands for the defining instructions. + phi_operands.push_back(inst.result_id()); + phi_operands.push_back(predecessor); + + // Add undef from all other blocks. + std::vector preds = cfg()->preds(merge_block->id()); + for (uint32_t pred_id : preds) { + if (pred_id != predecessor) { + phi_operands.push_back(undef_id); + phi_operands.push_back(pred_id); + } + } + + Instruction* new_phi = builder.AddPhi(inst.type_id(), phi_operands); + uint32_t result_of_phi = new_phi->result_id(); + + // Update all of the users to use the result of the new OpPhi. + for (Instruction* user : users_to_update) { + user->ForEachInId([&inst, result_of_phi](uint32_t* id) { + if (*id == inst.result_id()) { + *id = result_of_phi; + } + }); + context()->AnalyzeUses(user); + } + } +} + +bool MergeReturnPass::PredicateBlocks( + BasicBlock* return_block, std::unordered_set* predicated, + std::list* order) { + // The CFG is being modified as the function proceeds so avoid caching + // successors. 
+ + if (predicated->count(return_block)) { + return true; + } + + BasicBlock* block = nullptr; + const BasicBlock* const_block = const_cast(return_block); + const_block->ForEachSuccessorLabel([this, &block](const uint32_t idx) { + BasicBlock* succ_block = context()->get_instr_block(idx); + assert(block == nullptr); + block = succ_block; + }); + assert(block && + "Return blocks should have returns already replaced by a single " + "unconditional branch."); + + auto state = state_.rbegin(); + std::unordered_set seen; + if (block->id() == state->CurrentMergeId()) { + state++; + } else if (block->id() == state->LoopMergeId()) { + while (state->LoopMergeId() == block->id()) { + state++; + } + } + + while (block != nullptr && block != final_return_block_) { + if (!predicated->insert(block).second) break; + // Skip structured subgraphs. + assert(state->InLoop() && "Should be in the dummy loop at the very least."); + Instruction* current_loop_merge_inst = state->LoopMergeInst(); + uint32_t merge_block_id = + current_loop_merge_inst->GetSingleWordInOperand(0); + while (state->LoopMergeId() == merge_block_id) { + state++; + } + if (!BreakFromConstruct(block, predicated, order, + current_loop_merge_inst)) { + return false; + } + block = context()->get_instr_block(merge_block_id); + } + return true; +} + +bool MergeReturnPass::BreakFromConstruct( + BasicBlock* block, std::unordered_set* predicated, + std::list* order, Instruction* loop_merge_inst) { + assert(loop_merge_inst->opcode() == SpvOpLoopMerge && + "loop_merge_inst must be a loop merge instruction."); + // Make sure the CFG is build here. If we don't then it becomes very hard + // to know which new blocks need to be updated. + context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG); + + // When predicating, be aware of whether this block is a header block, a + // merge block or both. + // + // If this block is a merge block, ensure the appropriate header stays + // up-to-date with any changes (i.e. 
points to the pre-header). + // + // If this block is a header block, predicate the entire structured + // subgraph. This can act recursively. + + // If |block| is a loop header, then the back edge must jump to the original + // code, not the new header. + if (block->GetLoopMergeInst()) { + if (cfg()->SplitLoopHeader(block) == nullptr) { + return false; + } + } + + uint32_t merge_block_id = loop_merge_inst->GetSingleWordInOperand(0); + BasicBlock* merge_block = context()->get_instr_block(merge_block_id); + if (merge_block->GetLoopMergeInst()) { + cfg()->SplitLoopHeader(merge_block); + } + + // Leave the phi instructions behind. + auto iter = block->begin(); + while (iter->opcode() == SpvOpPhi) { + ++iter; + } + + // Forget about the edges leaving block. They will be removed. + cfg()->RemoveSuccessorEdges(block); + + BasicBlock* old_body = block->SplitBasicBlock(context(), TakeNextId(), iter); + predicated->insert(old_body); + + // If |block| was a continue target for a loop |old_body| is now the correct + // continue target. + if (loop_merge_inst->GetSingleWordInOperand(1) == block->id()) { + loop_merge_inst->SetInOperand(1, {old_body->id()}); + context()->UpdateDefUse(loop_merge_inst); + } + + // Update |order| so old_block will be traversed. + InsertAfterElement(block, old_body, order); + + // Within the new header we need the following: + // 1. Load of the return status flag + // 2. Branch to |merge_block| (true) or old body (false) + // 3. Update OpPhi instructions in |merge_block|. + // 4. Update the CFG. + // + // Sine we are branching to the merge block of the current construct, there is + // no need for an OpSelectionMerge. + + InstructionBuilder builder( + context(), block, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + // 1. 
Load of the return status flag + analysis::Bool bool_type; + uint32_t bool_id = context()->get_type_mgr()->GetId(&bool_type); + assert(bool_id != 0); + uint32_t load_id = + builder.AddLoad(bool_id, return_flag_->result_id())->result_id(); + + // 2. Branch to |merge_block| (true) or |old_body| (false) + builder.AddConditionalBranch(load_id, merge_block->id(), old_body->id(), + old_body->id()); + + // 3. Update OpPhi instructions in |merge_block|. + BasicBlock* merge_original_pred = MarkedSinglePred(merge_block); + if (merge_original_pred == nullptr) { + UpdatePhiNodes(block, merge_block); + } else if (merge_original_pred == block) { + MarkForNewPhiNodes(merge_block, old_body); + } + + // 4. Update the CFG. We do this after updating the OpPhi instructions + // because |UpdatePhiNodes| assumes the edge from |block| has not been added + // to the CFG yet. + cfg()->AddEdges(block); + cfg()->RegisterBlock(old_body); + + assert(old_body->begin() != old_body->end()); + assert(block->begin() != block->end()); + return true; +} + +void MergeReturnPass::RecordReturned(BasicBlock* block) { + if (block->tail()->opcode() != SpvOpReturn && + block->tail()->opcode() != SpvOpReturnValue) + return; + + assert(return_flag_ && "Did not generate the return flag variable."); + + if (!constant_true_) { + analysis::Bool temp; + const analysis::Bool* bool_type = + context()->get_type_mgr()->GetRegisteredType(&temp)->AsBool(); + + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + const analysis::Constant* true_const = + const_mgr->GetConstant(bool_type, {true}); + constant_true_ = const_mgr->GetDefiningInstruction(true_const); + context()->UpdateDefUse(constant_true_); + } + + std::unique_ptr return_store(new Instruction( + context(), SpvOpStore, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {return_flag_->result_id()}}, + {SPV_OPERAND_TYPE_ID, {constant_true_->result_id()}}})); + + Instruction* store_inst = + 
&*block->tail().InsertBefore(std::move(return_store)); + context()->set_instr_block(store_inst, block); + context()->AnalyzeDefUse(store_inst); +} + +void MergeReturnPass::RecordReturnValue(BasicBlock* block) { + auto terminator = *block->tail(); + if (terminator.opcode() != SpvOpReturnValue) { + return; + } + + assert(return_value_ && + "Did not generate the variable to hold the return value."); + + std::unique_ptr value_store(new Instruction( + context(), SpvOpStore, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {return_value_->result_id()}}, + {SPV_OPERAND_TYPE_ID, {terminator.GetSingleWordInOperand(0u)}}})); + + Instruction* store_inst = + &*block->tail().InsertBefore(std::move(value_store)); + context()->set_instr_block(store_inst, block); + context()->AnalyzeDefUse(store_inst); +} + +void MergeReturnPass::AddReturnValue() { + if (return_value_) return; + + uint32_t return_type_id = function_->type_id(); + if (get_def_use_mgr()->GetDef(return_type_id)->opcode() == SpvOpTypeVoid) + return; + + uint32_t return_ptr_type = context()->get_type_mgr()->FindPointerToType( + return_type_id, SpvStorageClassFunction); + + uint32_t var_id = TakeNextId(); + std::unique_ptr returnValue(new Instruction( + context(), SpvOpVariable, return_ptr_type, var_id, + std::initializer_list{ + {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); + + auto insert_iter = function_->begin()->begin(); + insert_iter.InsertBefore(std::move(returnValue)); + BasicBlock* entry_block = &*function_->begin(); + return_value_ = &*entry_block->begin(); + context()->AnalyzeDefUse(return_value_); + context()->set_instr_block(return_value_, entry_block); + + context()->get_decoration_mgr()->CloneDecorations( + function_->result_id(), var_id, {SpvDecorationRelaxedPrecision}); +} + +void MergeReturnPass::AddReturnFlag() { + if (return_flag_) return; + + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + 
+ analysis::Bool temp; + uint32_t bool_id = type_mgr->GetTypeInstruction(&temp); + analysis::Bool* bool_type = type_mgr->GetType(bool_id)->AsBool(); + + const analysis::Constant* false_const = + const_mgr->GetConstant(bool_type, {false}); + uint32_t const_false_id = + const_mgr->GetDefiningInstruction(false_const)->result_id(); + + uint32_t bool_ptr_id = + type_mgr->FindPointerToType(bool_id, SpvStorageClassFunction); + + uint32_t var_id = TakeNextId(); + std::unique_ptr returnFlag(new Instruction( + context(), SpvOpVariable, bool_ptr_id, var_id, + std::initializer_list{ + {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, + {SPV_OPERAND_TYPE_ID, {const_false_id}}})); + + auto insert_iter = function_->begin()->begin(); + + insert_iter.InsertBefore(std::move(returnFlag)); + BasicBlock* entry_block = &*function_->begin(); + return_flag_ = &*entry_block->begin(); + context()->AnalyzeDefUse(return_flag_); + context()->set_instr_block(return_flag_, entry_block); +} + +std::vector MergeReturnPass::CollectReturnBlocks( + Function* function) { + std::vector return_blocks; + for (auto& block : *function) { + Instruction& terminator = *block.tail(); + if (terminator.opcode() == SpvOpReturn || + terminator.opcode() == SpvOpReturnValue) { + return_blocks.push_back(&block); + } + } + return return_blocks; +} + +void MergeReturnPass::MergeReturnBlocks( + Function* function, const std::vector& return_blocks) { + if (return_blocks.size() <= 1) { + // No work to do. + return; + } + + CreateReturnBlock(); + uint32_t return_id = final_return_block_->id(); + auto ret_block_iter = --function->end(); + // Create the PHI for the merged block (if necessary). + // Create new return. 
+ std::vector phi_ops; + for (auto block : return_blocks) { + if (block->tail()->opcode() == SpvOpReturnValue) { + phi_ops.push_back( + {SPV_OPERAND_TYPE_ID, {block->tail()->GetSingleWordInOperand(0u)}}); + phi_ops.push_back({SPV_OPERAND_TYPE_ID, {block->id()}}); + } + } + + if (!phi_ops.empty()) { + // Need a PHI node to select the correct return value. + uint32_t phi_result_id = TakeNextId(); + uint32_t phi_type_id = function->type_id(); + std::unique_ptr phi_inst(new Instruction( + context(), SpvOpPhi, phi_type_id, phi_result_id, phi_ops)); + ret_block_iter->AddInstruction(std::move(phi_inst)); + BasicBlock::iterator phiIter = ret_block_iter->tail(); + + std::unique_ptr return_inst( + new Instruction(context(), SpvOpReturnValue, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {phi_result_id}}})); + ret_block_iter->AddInstruction(std::move(return_inst)); + BasicBlock::iterator ret = ret_block_iter->tail(); + + // Register the phi def and mark instructions for use updates. + get_def_use_mgr()->AnalyzeInstDefUse(&*phiIter); + get_def_use_mgr()->AnalyzeInstDef(&*ret); + } else { + std::unique_ptr return_inst( + new Instruction(context(), SpvOpReturn)); + ret_block_iter->AddInstruction(std::move(return_inst)); + } + + // Replace returns with branches + for (auto block : return_blocks) { + context()->ForgetUses(block->terminator()); + block->tail()->SetOpcode(SpvOpBranch); + block->tail()->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {return_id}}}); + get_def_use_mgr()->AnalyzeInstUse(block->terminator()); + get_def_use_mgr()->AnalyzeInstUse(block->GetLabelInst()); + } + + get_def_use_mgr()->AnalyzeInstDefUse(ret_block_iter->GetLabelInst()); +} + +void MergeReturnPass::AddNewPhiNodes() { + DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(function_); + std::list order; + cfg()->ComputeStructuredOrder(function_, &*function_->begin(), &order); + + for (BasicBlock* bb : order) { + BasicBlock* dominator = dom_tree->ImmediateDominator(bb); + if (dominator) { + AddNewPhiNodes(bb, 
new_merge_nodes_[bb], dominator->id()); + } + } +} + +void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb, BasicBlock* pred, + uint32_t header_id) { + DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(function_); + // Insert as a stopping point. We do not have to add anything in the block + // or above because the header dominates |bb|. + + BasicBlock* current_bb = pred; + while (current_bb != nullptr && current_bb->id() != header_id) { + for (Instruction& inst : *current_bb) { + CreatePhiNodesForInst(bb, pred->id(), inst); + } + current_bb = dom_tree->ImmediateDominator(current_bb); + } +} + +void MergeReturnPass::MarkForNewPhiNodes(BasicBlock* block, + BasicBlock* single_original_pred) { + new_merge_nodes_[block] = single_original_pred; +} + +void MergeReturnPass::InsertAfterElement(BasicBlock* element, + BasicBlock* new_element, + std::list* list) { + auto pos = std::find(list->begin(), list->end(), element); + assert(pos != list->end()); + ++pos; + list->insert(pos, new_element); +} + +void MergeReturnPass::AddDummyLoopAroundFunction() { + CreateReturnBlock(); + CreateReturn(final_return_block_); + + if (context()->AreAnalysesValid(IRContext::kAnalysisCFG)) { + cfg()->RegisterBlock(final_return_block_); + } + + CreateDummyLoop(final_return_block_); +} + +BasicBlock* MergeReturnPass::CreateContinueTarget(uint32_t header_label_id) { + std::unique_ptr label( + new Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {})); + + // Create the new basic block + std::unique_ptr block(new BasicBlock(std::move(label))); + + // Insert the new block just before the return block + auto pos = function_->end(); + assert(pos != function_->begin()); + pos--; + assert(pos != function_->begin()); + assert(&*pos == final_return_block_); + auto new_block = &*pos.InsertBefore(std::move(block)); + new_block->SetParent(function_); + + context()->AnalyzeDefUse(new_block->GetLabelInst()); + context()->set_instr_block(new_block->GetLabelInst(), new_block); + + 
InstructionBuilder builder( + context(), new_block, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + + builder.AddBranch(header_label_id); + + if (context()->AreAnalysesValid(IRContext::kAnalysisCFG)) { + cfg()->RegisterBlock(new_block); + } + + return new_block; +} + +void MergeReturnPass::CreateDummyLoop(BasicBlock* merge_target) { + std::unique_ptr label( + new Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {})); + + // Create the new basic block + std::unique_ptr block(new BasicBlock(std::move(label))); + + // Insert the new block before any code is run. We have to split the entry + // block to make sure the OpVariable instructions remain in the entry block. + BasicBlock* start_block = &*function_->begin(); + auto split_pos = start_block->begin(); + while (split_pos->opcode() == SpvOpVariable) { + ++split_pos; + } + + BasicBlock* old_block = + start_block->SplitBasicBlock(context(), TakeNextId(), split_pos); + + // The new block must be inserted after the entry block. We cannot make the + // entry block the header for the dummy loop because it is not valid to have a + // branch to the entry block, and the continue target must branch back to the + // loop header. + auto pos = function_->begin(); + pos++; + BasicBlock* header_block = &*pos.InsertBefore(std::move(block)); + context()->AnalyzeDefUse(header_block->GetLabelInst()); + header_block->SetParent(function_); + + // We have to create the continue block before the OpLoopMerge instruction. + // Otherwise the def-use manager will complain that there is a use without a + // definition. + uint32_t continue_target = CreateContinueTarget(header_block->id())->id(); + + // Add the code to the header block.
+ InstructionBuilder builder( + context(), header_block, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + + builder.AddLoopMerge(merge_target->id(), continue_target); + builder.AddBranch(old_block->id()); + + // Fix up the entry block by adding a branch to the loop header. + InstructionBuilder builder2( + context(), start_block, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + builder2.AddBranch(header_block->id()); + + if (context()->AreAnalysesValid(IRContext::kAnalysisCFG)) { + cfg()->RegisterBlock(old_block); + cfg()->RegisterBlock(header_block); + cfg()->AddEdges(start_block); + } +} + +bool MergeReturnPass::HasNontrivialUnreachableBlocks(Function* function) { + utils::BitVector reachable_blocks; + cfg()->ForEachBlockInPostOrder( + function->entry().get(), + [&reachable_blocks](BasicBlock* bb) { reachable_blocks.Set(bb->id()); }); + + for (auto& bb : *function) { + if (reachable_blocks.Get(bb.id())) { + continue; + } + + StructuredCFGAnalysis* struct_cfg_analysis = + context()->GetStructuredCFGAnalysis(); + if (struct_cfg_analysis->IsMergeBlock(bb.id())) { + // |bb| must be an empty block ending with OpUnreachable. + if (bb.begin()->opcode() != SpvOpUnreachable) { + return true; + } + } else if (struct_cfg_analysis->IsContinueBlock(bb.id())) { + // |bb| must be an empty block ending with a branch to the header. + Instruction* inst = &*bb.begin(); + if (inst->opcode() != SpvOpBranch) { + return true; + } + + if (inst->GetSingleWordInOperand(0) != + struct_cfg_analysis->ContainingLoop(bb.id())) { + return true; + } + } else { + return true; + } + } + return false; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h new file mode 100644 index 000000000..d7e18f0c5 --- /dev/null +++ b/source/opt/merge_return_pass.h @@ -0,0 +1,339 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_MERGE_RETURN_PASS_H_ +#define SOURCE_OPT_MERGE_RETURN_PASS_H_ + +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/function.h" +#include "source/opt/mem_pass.h" + +namespace spvtools { +namespace opt { + +/******************************************************************************* + * + * Handling Structured Control Flow: + * + * Structured control flow guarantees that the CFG will converge at a given + * point (the merge block). Within structured control flow, all blocks must be + * post-dominated by the merge block, except return blocks and break blocks. + * A break block is a block that branches to the innermost loop's merge block. + * + * Beyond this, we further assume that all unreachable blocks have been + * cleaned up. This means that the only unreachable blocks are those necessary + * for valid structured control flow. + * + * Algorithm: + * + * If a return is encountered, it should record that: i) the function has + * "returned" and ii) the value of the return. The return should be replaced + * with a branch. If current block is not within structured control flow, this + * is the final return. This block should branch to the new return block (its + * direct successor). If the current block is within structured control flow, + * the branch destination should be the innermost loop's merge. 
This loop will + * always exist because a dummy loop is added around the entire function. + * If the merge block produces any live values it will need to be predicated. + * While the merge is nested in structured control flow, the predication path + * should branch to the merge block of the inner-most loop it is contained in. + * Once structured control flow has been exited, it will be at the merge of the + * dummy loop, which will simply return. + * + * In the final return block, the return value should be loaded and returned. + * Memory promotion passes should be able to promote the newly introduced + * variables ("has returned" and "return value"). + * + * Predicating the Final Merge: + * + * At each merge block predication needs to be introduced (optimization: only if + * that block produces a value live beyond it). This needs to be done carefully. + * The merge block should be split into multiple blocks. + * + * 1 (loop header) + * / \ + * (ret) 2 3 (merge) + * + * || + * \/ + * + * 0 (dummy loop header) + * | + * 1 (loop header) + * / \ + * 2 | (merge) + * \ / + * 3' (merge) + * / \ + * | 3 (original code in 3) + * \ / + * (ret) 4 (dummy loop merge) + * + * In the above (simple) example, the return originally in |2| is passed through + * the merge. That merge is predicated such that the old body of the block is + * the else branch. The branch condition is based on the value of the "has + * returned" variable.
+ * + ******************************************************************************/ + +// Documented in optimizer.hpp +class MergeReturnPass : public MemPass { + public: + MergeReturnPass() + : function_(nullptr), + return_flag_(nullptr), + return_value_(nullptr), + constant_true_(nullptr), + final_return_block_(nullptr) {} + + const char* name() const override { return "merge-return"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // This class is used to store a loop merge instruction and a selection + // merge instruction. The intended use is to represent the innermost + // containing selection construct and the innermost loop construct. + class StructuredControlState { + public: + StructuredControlState(Instruction* loop, Instruction* merge) + : loop_merge_(loop), current_merge_(merge) {} + + StructuredControlState(const StructuredControlState&) = default; + + bool InLoop() const { return loop_merge_; } + bool InStructuredFlow() const { return CurrentMergeId() != 0; } + + uint32_t CurrentMergeId() const { + return current_merge_ ? current_merge_->GetSingleWordInOperand(0u) : 0u; + } + + uint32_t CurrentMergeHeader() const { + return current_merge_ ? current_merge_->context() + ->get_instr_block(current_merge_) + ->id() + : 0; + } + + uint32_t LoopMergeId() const { + return loop_merge_ ? loop_merge_->GetSingleWordInOperand(0u) : 0u; + } + + uint32_t CurrentLoopHeader() const { + return loop_merge_ + ? loop_merge_->context()->get_instr_block(loop_merge_)->id() + : 0; + } + + Instruction* LoopMergeInst() const { return loop_merge_; } + + private: + Instruction* loop_merge_; + Instruction* current_merge_; + }; + + // Returns all BasicBlocks terminated by OpReturn or OpReturnValue in + // |function|. + std::vector CollectReturnBlocks(Function* function); + + // Creates a new basic block with a single return.
If |function| returns a + // value, a phi node is created to select the correct value to return. + // Replaces old returns with an unconditional branch to the new block. + void MergeReturnBlocks(Function* function, + const std::vector& returnBlocks); + + // Merges the return instructions in |function| so that it has a single return + // statement. It is assumed that |function| has structured control flow, and + // that |return_blocks| is a list of all of the basic blocks in |function| + // that have a return. + bool ProcessStructured(Function* function, + const std::vector& return_blocks); + + // Changes an OpReturn* or OpUnreachable instruction at the end of |block| + // into a store to |return_flag_|, a store to |return_value_| (if necessary), + // and a branch to the appropriate merge block. + // + // It is assumed that |AddReturnValue| has already been called to create the + // variable to store a return value if there is one. + // + // Note this will break the semantics. To fix this, PredicateBlock will have + // to be called on the merge block the branch targets. + void ProcessStructuredBlock(BasicBlock* block); + + // Creates a variable used to store whether or not the control flow has + // traversed a block that used to have a return. A pointer to the instruction + // declaring the variable is stored in |return_flag_|. + void AddReturnFlag(); + + // Creates the variable used to store the return value when passing through + // a block that used to contain an OpReturnValue. + void AddReturnValue(); + + // Adds a store that stores true to |return_flag_| immediately before the + // terminator of |block|. It is assumed that |AddReturnFlag| has already been + // called. + void RecordReturned(BasicBlock* block); + + // Adds an instruction that stores the value being returned in the + // OpReturnValue in |block|. The value is stored to |return_value_|, and the + // store is placed before the OpReturnValue.
+ // + // If |block| does not contain an OpReturnValue, then this function has no + // effect. If |block| contains an OpReturnValue, then |AddReturnValue| must + // have already been called to create the variable to store to. + void RecordReturnValue(BasicBlock* block); + + // Adds an unconditional branch in |block| that branches to |target|. It also + // adds stores to |return_flag_| and |return_value_| as needed. + // |AddReturnFlag| and |AddReturnValue| must have already been called. + void BranchToBlock(BasicBlock* block, uint32_t target); + + // For every basic block that is reachable from |return_block|, extra code is + // added to jump around any code that should not be executed because the + // original code would have already returned. This involves adding new + // selections constructs to jump around these instructions. + // + // If new blocks that are created will be added to |order|. This way a call + // can traverse these new block in structured order. + // + // Returns true if successful. + bool PredicateBlocks(BasicBlock* return_block, + std::unordered_set* pSet, + std::list* order); + + // Add a conditional branch at the start of |block| that either jumps to + // the merge block of |loop_merge_inst| or the original code in |block| + // depending on the value in |return_flag_|. The continue target in + // |loop_merge_inst| will be updated if needed. + // + // If new blocks that are created will be added to |order|. This way a call + // can traverse these new block in structured order. + // + // Returns true if successful. + bool BreakFromConstruct(BasicBlock* block, + std::unordered_set* predicated, + std::list* order, + Instruction* loop_merge_inst); + + // Add an |OpReturn| or |OpReturnValue| to the end of |block|. If an + // |OpReturnValue| is needed, the return value is loaded from |return_value_|. 
+ void CreateReturn(BasicBlock* block); + + // Creates a block at the end of the function that will become the single + // return block at the end of the pass. + void CreateReturnBlock(); + + // Creates a Phi node in |merge_block| for the result of |inst| coming from + // |predecessor|. Any uses of the result of |inst| that are no longer + // dominated by |inst|, are replaced with the result of the new |OpPhi| + // instruction. + void CreatePhiNodesForInst(BasicBlock* merge_block, uint32_t predecessor, + Instruction& inst); + + // Traverses the nodes in |new_merge_nodes_|, and adds the OpPhi instructions + // that are needed to make the code correct. It is assumed that at this point + // there are no unreachable blocks in the control flow graph. + void AddNewPhiNodes(); + + // Creates any new phi nodes that are needed in |bb| now that |pred| is no + // longer the only block that precedes |bb|. |header_id| is the id of the + // basic block for the loop or selection construct that merges at |bb|. + void AddNewPhiNodes(BasicBlock* bb, BasicBlock* pred, uint32_t header_id); + + // Saves |block| to a list of basic blocks that will require OpPhi nodes to be + // added by calling |AddNewPhiNodes|. It is assumed that |block| used to have + // a single predecessor, |single_original_pred|, but now has more. + void MarkForNewPhiNodes(BasicBlock* block, BasicBlock* single_original_pred); + + // Returns the original single predecessor of |block| if it was flagged as + // having a single predecessor. |nullptr| is returned otherwise. + BasicBlock* MarkedSinglePred(BasicBlock* block) { + auto it = new_merge_nodes_.find(block); + if (it != new_merge_nodes_.end()) { + return it->second; + } else { + return nullptr; + } + } + + // Modifies existing OpPhi instruction in |target| block to account for the + // new edge from |new_source|. The value for that edge will be an Undef. If + // |target| only had a single predecessor, then it is marked as needing new + // phi nodes.
See |MarkForNewPhiNodes|. + // + // The CFG must not include the edge from |new_source| to |target| yet. + void UpdatePhiNodes(BasicBlock* new_source, BasicBlock* target); + + StructuredControlState& CurrentState() { return state_.back(); } + + // Inserts |new_element| into |list| after the first occurrence of |element|. + // |element| must be in |list| at least once. + void InsertAfterElement(BasicBlock* element, BasicBlock* new_element, + std::list* list); + + // Creates a single-iteration loop around all of the executable code of the + // current function and returns after the loop is done. Sets + // |final_return_block_|. + void AddDummyLoopAroundFunction(); + + // Creates a new basic block that branches to |header_label_id|. Returns the + // new basic block. The block will be the second last basic block in the + // function. + BasicBlock* CreateContinueTarget(uint32_t header_label_id); + + // Creates a loop around the executable code of the function with + // |merge_target| as the merge node. + void CreateDummyLoop(BasicBlock* merge_target); + + // A stack used to keep track of the innermost containing loop and selection + // constructs. + std::vector state_; + + // The current function being transformed. + Function* function_; + + // The |OpVariable| instruction defining a boolean variable used to keep track + // of whether or not the function is trying to return. + Instruction* return_flag_; + + // The |OpVariable| instruction defining a variable used to keep track of + // the value that was returned when passing through a block that used to + // contain an |OpReturnValue|. + Instruction* return_value_; + + // The instruction defining the boolean constant true. + Instruction* constant_true_; + + // The basic block that is supposed to contain the only return instruction + // after processing the current function. + BasicBlock* final_return_block_; + + // This map contains the set of nodes that used to have a single predecessor, + // but now have more.
They will need new OpPhi nodes. For each of the nodes, + // it is mapped to its original single predecessor. It is assumed there are no + // values that will need a phi on the new edges. + std::unordered_map new_merge_nodes_; + bool HasNontrivialUnreachableBlocks(Function* function); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_MERGE_RETURN_PASS_H_ diff --git a/source/opt/module.cpp b/source/opt/module.cpp new file mode 100644 index 000000000..04e4e9733 --- /dev/null +++ b/source/opt/module.cpp @@ -0,0 +1,186 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "source/opt/module.h" + +#include +#include +#include + +#include "source/operand.h" +#include "source/opt/ir_context.h" +#include "source/opt/reflect.h" + +namespace spvtools { +namespace opt { + +uint32_t Module::TakeNextIdBound() { + if (context()) { + if (id_bound() >= context()->max_id_bound()) { + return 0; + } + } else if (id_bound() >= kDefaultMaxIdBound) { + return 0; + } + + return header_.bound++; +} + +std::vector Module::GetTypes() { + std::vector type_insts; + for (auto& inst : types_values_) { + if (IsTypeInst(inst.opcode())) type_insts.push_back(&inst); + } + return type_insts; +} + +std::vector Module::GetTypes() const { + std::vector type_insts; + for (auto& inst : types_values_) { + if (IsTypeInst(inst.opcode())) type_insts.push_back(&inst); + } + return type_insts; +} + +std::vector Module::GetConstants() { + std::vector const_insts; + for (auto& inst : types_values_) { + if (IsConstantInst(inst.opcode())) const_insts.push_back(&inst); + } + return const_insts; +} + +std::vector Module::GetConstants() const { + std::vector const_insts; + for (auto& inst : types_values_) { + if (IsConstantInst(inst.opcode())) const_insts.push_back(&inst); + } + return const_insts; +} + +uint32_t Module::GetGlobalValue(SpvOp opcode) const { + for (auto& inst : types_values_) { + if (inst.opcode() == opcode) return inst.result_id(); + } + return 0; +} + +void Module::AddGlobalValue(SpvOp opcode, uint32_t result_id, + uint32_t type_id) { + std::unique_ptr newGlobal( + new Instruction(context(), opcode, type_id, result_id, {})); + AddGlobalValue(std::move(newGlobal)); +} + +void Module::ForEachInst(const std::function& f, + bool run_on_debug_line_insts) { +#define DELEGATE(list) list.ForEachInst(f, run_on_debug_line_insts) + DELEGATE(capabilities_); + DELEGATE(extensions_); + DELEGATE(ext_inst_imports_); + if (memory_model_) memory_model_->ForEachInst(f, run_on_debug_line_insts); + DELEGATE(entry_points_); + DELEGATE(execution_modes_); + 
DELEGATE(debugs1_); + DELEGATE(debugs2_); + DELEGATE(debugs3_); + DELEGATE(annotations_); + DELEGATE(types_values_); + for (auto& i : functions_) i->ForEachInst(f, run_on_debug_line_insts); +#undef DELEGATE +} + +void Module::ForEachInst(const std::function& f, + bool run_on_debug_line_insts) const { +#define DELEGATE(i) i.ForEachInst(f, run_on_debug_line_insts) + for (auto& i : capabilities_) DELEGATE(i); + for (auto& i : extensions_) DELEGATE(i); + for (auto& i : ext_inst_imports_) DELEGATE(i); + if (memory_model_) + static_cast(memory_model_.get()) + ->ForEachInst(f, run_on_debug_line_insts); + for (auto& i : entry_points_) DELEGATE(i); + for (auto& i : execution_modes_) DELEGATE(i); + for (auto& i : debugs1_) DELEGATE(i); + for (auto& i : debugs2_) DELEGATE(i); + for (auto& i : debugs3_) DELEGATE(i); + for (auto& i : annotations_) DELEGATE(i); + for (auto& i : types_values_) DELEGATE(i); + for (auto& i : functions_) { + static_cast(i.get())->ForEachInst(f, + run_on_debug_line_insts); + } +#undef DELEGATE +} + +void Module::ToBinary(std::vector* binary, bool skip_nop) const { + binary->push_back(header_.magic_number); + binary->push_back(header_.version); + // TODO(antiagainst): should we change the generator number? 
+ binary->push_back(header_.generator); + binary->push_back(header_.bound); + binary->push_back(header_.reserved); + + auto write_inst = [binary, skip_nop](const Instruction* i) { + if (!(skip_nop && i->IsNop())) i->ToBinaryWithoutAttachedDebugInsts(binary); + }; + ForEachInst(write_inst, true); +} + +uint32_t Module::ComputeIdBound() const { + uint32_t highest = 0; + + ForEachInst( + [&highest](const Instruction* inst) { + for (const auto& operand : *inst) { + if (spvIsIdType(operand.type)) { + highest = std::max(highest, operand.words[0]); + } + } + }, + true /* scan debug line insts as well */); + + return highest + 1; +} + +bool Module::HasExplicitCapability(uint32_t cap) { + for (auto& ci : capabilities_) { + uint32_t tcap = ci.GetSingleWordOperand(0); + if (tcap == cap) { + return true; + } + } + return false; +} + +uint32_t Module::GetExtInstImportId(const char* extstr) { + for (auto& ei : ext_inst_imports_) + if (!strcmp(extstr, + reinterpret_cast(&(ei.GetInOperand(0).words[0])))) + return ei.result_id(); + return 0; +} + +std::ostream& operator<<(std::ostream& str, const Module& module) { + module.ForEachInst([&str](const Instruction* inst) { + str << *inst; + if (inst->opcode() != SpvOpFunctionEnd) { + str << std::endl; + } + }); + return str; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/module.h b/source/opt/module.h new file mode 100644 index 000000000..ede0bbbf3 --- /dev/null +++ b/source/opt/module.h @@ -0,0 +1,484 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_MODULE_H_ +#define SOURCE_OPT_MODULE_H_ + +#include +#include +#include +#include + +#include "source/opt/function.h" +#include "source/opt/instruction.h" +#include "source/opt/iterator.h" + +namespace spvtools { +namespace opt { + +class IRContext; + +// A struct for containing the module header information. +struct ModuleHeader { + uint32_t magic_number; + uint32_t version; + uint32_t generator; + uint32_t bound; + uint32_t reserved; +}; + +// A SPIR-V module. It contains all the information for a SPIR-V module and +// serves as the backbone of optimization transformations. +class Module { + public: + using iterator = UptrVectorIterator; + using const_iterator = UptrVectorIterator; + using inst_iterator = InstructionList::iterator; + using const_inst_iterator = InstructionList::const_iterator; + + // Creates an empty module with a zeroed header and no associated context. + Module() : header_({}), context_(nullptr) {} + + // Sets the header to the given |header|. + void SetHeader(const ModuleHeader& header) { header_ = header; } + + // Sets the Id bound. The Id bound cannot be set to 0. + void SetIdBound(uint32_t bound) { + assert(bound != 0); + header_.bound = bound; + } + + // Returns the Id bound. + uint32_t IdBound() { return header_.bound; } + + // Returns the current Id bound and increases it to the next available value. + // If the id bound has already reached its maximum value, then 0 is returned. + // The maximum value for the id bound is obtained from the context. If there + // is no context, the default limit (kDefaultMaxIdBound) is used instead. + // TODO(1841): Update the uses to check for a 0 return value. + uint32_t TakeNextIdBound(); + + // Appends a capability instruction to this module. + inline void AddCapability(std::unique_ptr c); + + // Appends an extension instruction to this module.
+ inline void AddExtension(std::unique_ptr e); + + // Appends an extended instruction set instruction to this module. + inline void AddExtInstImport(std::unique_ptr e); + + // Set the memory model for this module. + inline void SetMemoryModel(std::unique_ptr m); + + // Appends an entry point instruction to this module. + inline void AddEntryPoint(std::unique_ptr e); + + // Appends an execution mode instruction to this module. + inline void AddExecutionMode(std::unique_ptr e); + + // Appends a debug 1 instruction (excluding OpLine & OpNoLine) to this module. + // "debug 1" instructions are the ones in layout section 7.a), see section + // 2.4 Logical Layout of a Module from the SPIR-V specification. + inline void AddDebug1Inst(std::unique_ptr d); + + // Appends a debug 2 instruction (excluding OpLine & OpNoLine) to this module. + // "debug 2" instructions are the ones in layout section 7.b), see section + // 2.4 Logical Layout of a Module from the SPIR-V specification. + inline void AddDebug2Inst(std::unique_ptr d); + + // Appends a debug 3 instruction (OpModuleProcessed) to this module. + // This is due to decision by the SPIR Working Group, pending publication. + inline void AddDebug3Inst(std::unique_ptr d); + + // Appends an annotation instruction to this module. + inline void AddAnnotationInst(std::unique_ptr a); + + // Appends a type-declaration instruction to this module. + inline void AddType(std::unique_ptr t); + + // Appends a constant, global variable, or OpUndef instruction to this module. + inline void AddGlobalValue(std::unique_ptr v); + + // Appends a function to this module. + inline void AddFunction(std::unique_ptr f); + + // Returns a vector of pointers to type-declaration instructions in this + // module. + std::vector GetTypes(); + std::vector GetTypes() const; + // Returns a vector of pointers to constant-creation instructions in this + // module. 
+ std::vector GetConstants(); + std::vector GetConstants() const; + + // Return result id of global value with |opcode|, 0 if not present. + uint32_t GetGlobalValue(SpvOp opcode) const; + + // Add global value with |opcode|, |result_id| and |type_id| + void AddGlobalValue(SpvOp opcode, uint32_t result_id, uint32_t type_id); + + inline uint32_t id_bound() const { return header_.bound; } + + inline uint32_t version() const { return header_.version; } + + // Iterators for capabilities instructions contained in this module. + inline inst_iterator capability_begin(); + inline inst_iterator capability_end(); + inline IteratorRange capabilities(); + inline IteratorRange capabilities() const; + + // Iterators for ext_inst_imports instructions contained in this module. + inline inst_iterator ext_inst_import_begin(); + inline inst_iterator ext_inst_import_end(); + inline IteratorRange ext_inst_imports(); + inline IteratorRange ext_inst_imports() const; + + // Return the memory model instruction contained inthis module. + inline Instruction* GetMemoryModel() { return memory_model_.get(); } + inline const Instruction* GetMemoryModel() const { + return memory_model_.get(); + } + + // There are several kinds of debug instructions, according to where they can + // appear in the logical layout of a module: + // - Section 7a: OpString, OpSourceExtension, OpSource, OpSourceContinued + // - Section 7b: OpName, OpMemberName + // - Section 7c: OpModuleProcessed + // - Mostly anywhere: OpLine and OpNoLine + // + + // Iterators for debug 1 instructions (excluding OpLine & OpNoLine) contained + // in this module. These are for layout section 7a. + inline inst_iterator debug1_begin(); + inline inst_iterator debug1_end(); + inline IteratorRange debugs1(); + inline IteratorRange debugs1() const; + + // Iterators for debug 2 instructions (excluding OpLine & OpNoLine) contained + // in this module. These are for layout section 7b. 
+ inline inst_iterator debug2_begin(); + inline inst_iterator debug2_end(); + inline IteratorRange debugs2(); + inline IteratorRange debugs2() const; + + // Iterators for debug 3 instructions (excluding OpLine & OpNoLine) contained + // in this module. These are for layout section 7c. + inline inst_iterator debug3_begin(); + inline inst_iterator debug3_end(); + inline IteratorRange debugs3(); + inline IteratorRange debugs3() const; + + // Iterators for entry point instructions contained in this module + inline IteratorRange entry_points(); + inline IteratorRange entry_points() const; + + // Iterators for execution_modes instructions contained in this module. + inline inst_iterator execution_mode_begin(); + inline inst_iterator execution_mode_end(); + inline IteratorRange execution_modes(); + inline IteratorRange execution_modes() const; + + // Clears all debug instructions (excluding OpLine & OpNoLine). + void debug_clear() { + debug1_clear(); + debug2_clear(); + debug3_clear(); + } + + // Clears all debug 1 instructions (excluding OpLine & OpNoLine). + void debug1_clear() { debugs1_.clear(); } + + // Clears all debug 2 instructions (excluding OpLine & OpNoLine). + void debug2_clear() { debugs2_.clear(); } + + // Clears all debug 3 instructions (excluding OpLine & OpNoLine). + void debug3_clear() { debugs3_.clear(); } + + // Iterators for annotation instructions contained in this module. + inline inst_iterator annotation_begin(); + inline inst_iterator annotation_end(); + IteratorRange annotations(); + IteratorRange annotations() const; + + // Iterators for extension instructions contained in this module. + inline inst_iterator extension_begin(); + inline inst_iterator extension_end(); + IteratorRange extensions(); + IteratorRange extensions() const; + + // Iterators for types, constants and global variables instructions. 
+ inline inst_iterator types_values_begin(); + inline inst_iterator types_values_end(); + inline IteratorRange types_values(); + inline IteratorRange types_values() const; + + // Iterators for functions contained in this module. + iterator begin() { return iterator(&functions_, functions_.begin()); } + iterator end() { return iterator(&functions_, functions_.end()); } + const_iterator begin() const { return cbegin(); } + const_iterator end() const { return cend(); } + inline const_iterator cbegin() const; + inline const_iterator cend() const; + + // Invokes function |f| on all instructions in this module, and optionally on + // the debug line instructions that precede them. + void ForEachInst(const std::function& f, + bool run_on_debug_line_insts = false); + void ForEachInst(const std::function& f, + bool run_on_debug_line_insts = false) const; + + // Pushes the binary segments for this instruction into the back of *|binary|. + // If |skip_nop| is true and this is a OpNop, do nothing. + void ToBinary(std::vector* binary, bool skip_nop) const; + + // Returns 1 more than the maximum Id value mentioned in the module. + uint32_t ComputeIdBound() const; + + // Returns true if module has capability |cap| + bool HasExplicitCapability(uint32_t cap); + + // Returns id for OpExtInst instruction for extension |extstr|. + // Returns 0 if not found. + uint32_t GetExtInstImportId(const char* extstr); + + // Sets the associated context for this module + void SetContext(IRContext* c) { context_ = c; } + + // Gets the associated context for this module + IRContext* context() const { return context_; } + + private: + ModuleHeader header_; // Module header + + // The following fields respect the "Logical Layout of a Module" in + // Section 2.4 of the SPIR-V specification. + IRContext* context_; + InstructionList capabilities_; + InstructionList extensions_; + InstructionList ext_inst_imports_; + // A module only has one memory model instruction. 
+ std::unique_ptr memory_model_; + InstructionList entry_points_; + InstructionList execution_modes_; + InstructionList debugs1_; + InstructionList debugs2_; + InstructionList debugs3_; + InstructionList annotations_; + // Type declarations, constants, and global variable declarations. + InstructionList types_values_; + std::vector> functions_; +}; + +// Pretty-prints |module| to |str|. Returns |str|. +std::ostream& operator<<(std::ostream& str, const Module& module); + +inline void Module::AddCapability(std::unique_ptr c) { + capabilities_.push_back(std::move(c)); +} + +inline void Module::AddExtension(std::unique_ptr e) { + extensions_.push_back(std::move(e)); +} + +inline void Module::AddExtInstImport(std::unique_ptr e) { + ext_inst_imports_.push_back(std::move(e)); +} + +inline void Module::SetMemoryModel(std::unique_ptr m) { + memory_model_ = std::move(m); +} + +inline void Module::AddEntryPoint(std::unique_ptr e) { + entry_points_.push_back(std::move(e)); +} + +inline void Module::AddExecutionMode(std::unique_ptr e) { + execution_modes_.push_back(std::move(e)); +} + +inline void Module::AddDebug1Inst(std::unique_ptr d) { + debugs1_.push_back(std::move(d)); +} + +inline void Module::AddDebug2Inst(std::unique_ptr d) { + debugs2_.push_back(std::move(d)); +} + +inline void Module::AddDebug3Inst(std::unique_ptr d) { + debugs3_.push_back(std::move(d)); +} + +inline void Module::AddAnnotationInst(std::unique_ptr a) { + annotations_.push_back(std::move(a)); +} + +inline void Module::AddType(std::unique_ptr t) { + types_values_.push_back(std::move(t)); +} + +inline void Module::AddGlobalValue(std::unique_ptr v) { + types_values_.push_back(std::move(v)); +} + +inline void Module::AddFunction(std::unique_ptr f) { + functions_.emplace_back(std::move(f)); +} + +inline Module::inst_iterator Module::capability_begin() { + return capabilities_.begin(); +} +inline Module::inst_iterator Module::capability_end() { + return capabilities_.end(); +} + +inline IteratorRange 
Module::capabilities() { + return make_range(capabilities_.begin(), capabilities_.end()); +} + +inline IteratorRange Module::capabilities() const { + return make_range(capabilities_.begin(), capabilities_.end()); +} + +inline Module::inst_iterator Module::ext_inst_import_begin() { + return ext_inst_imports_.begin(); +} +inline Module::inst_iterator Module::ext_inst_import_end() { + return ext_inst_imports_.end(); +} + +inline IteratorRange Module::ext_inst_imports() { + return make_range(ext_inst_imports_.begin(), ext_inst_imports_.end()); +} + +inline IteratorRange Module::ext_inst_imports() + const { + return make_range(ext_inst_imports_.begin(), ext_inst_imports_.end()); +} + +inline Module::inst_iterator Module::debug1_begin() { return debugs1_.begin(); } +inline Module::inst_iterator Module::debug1_end() { return debugs1_.end(); } + +inline IteratorRange Module::debugs1() { + return make_range(debugs1_.begin(), debugs1_.end()); +} + +inline IteratorRange Module::debugs1() const { + return make_range(debugs1_.begin(), debugs1_.end()); +} + +inline Module::inst_iterator Module::debug2_begin() { return debugs2_.begin(); } +inline Module::inst_iterator Module::debug2_end() { return debugs2_.end(); } + +inline IteratorRange Module::debugs2() { + return make_range(debugs2_.begin(), debugs2_.end()); +} + +inline IteratorRange Module::debugs2() const { + return make_range(debugs2_.begin(), debugs2_.end()); +} + +inline Module::inst_iterator Module::debug3_begin() { return debugs3_.begin(); } +inline Module::inst_iterator Module::debug3_end() { return debugs3_.end(); } + +inline IteratorRange Module::debugs3() { + return make_range(debugs3_.begin(), debugs3_.end()); +} + +inline IteratorRange Module::debugs3() const { + return make_range(debugs3_.begin(), debugs3_.end()); +} + +inline IteratorRange Module::entry_points() { + return make_range(entry_points_.begin(), entry_points_.end()); +} + +inline IteratorRange Module::entry_points() const { + return 
make_range(entry_points_.begin(), entry_points_.end()); +} + +inline Module::inst_iterator Module::execution_mode_begin() { + return execution_modes_.begin(); +} +inline Module::inst_iterator Module::execution_mode_end() { + return execution_modes_.end(); +} + +inline IteratorRange Module::execution_modes() { + return make_range(execution_modes_.begin(), execution_modes_.end()); +} + +inline IteratorRange Module::execution_modes() + const { + return make_range(execution_modes_.begin(), execution_modes_.end()); +} + +inline Module::inst_iterator Module::annotation_begin() { + return annotations_.begin(); +} +inline Module::inst_iterator Module::annotation_end() { + return annotations_.end(); +} + +inline IteratorRange Module::annotations() { + return make_range(annotations_.begin(), annotations_.end()); +} + +inline IteratorRange Module::annotations() const { + return make_range(annotations_.begin(), annotations_.end()); +} + +inline Module::inst_iterator Module::extension_begin() { + return extensions_.begin(); +} +inline Module::inst_iterator Module::extension_end() { + return extensions_.end(); +} + +inline IteratorRange Module::extensions() { + return make_range(extensions_.begin(), extensions_.end()); +} + +inline IteratorRange Module::extensions() const { + return make_range(extensions_.begin(), extensions_.end()); +} + +inline Module::inst_iterator Module::types_values_begin() { + return types_values_.begin(); +} + +inline Module::inst_iterator Module::types_values_end() { + return types_values_.end(); +} + +inline IteratorRange Module::types_values() { + return make_range(types_values_.begin(), types_values_.end()); +} + +inline IteratorRange Module::types_values() const { + return make_range(types_values_.begin(), types_values_.end()); +} + +inline Module::const_iterator Module::cbegin() const { + return const_iterator(&functions_, functions_.cbegin()); +} + +inline Module::const_iterator Module::cend() const { + return const_iterator(&functions_, 
functions_.cend()); +} + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_MODULE_H_ diff --git a/source/opt/null_pass.h b/source/opt/null_pass.h new file mode 100644 index 000000000..2b5974fb9 --- /dev/null +++ b/source/opt/null_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_NULL_PASS_H_ +#define SOURCE_OPT_NULL_PASS_H_ + +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class NullPass : public Pass { + public: + const char* name() const override { return "null"; } + Status Process() override { return Status::SuccessWithoutChange; } +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_NULL_PASS_H_ diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp new file mode 100644 index 000000000..30e80d718 --- /dev/null +++ b/source/opt/optimizer.cpp @@ -0,0 +1,802 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "spirv-tools/optimizer.hpp" + +#include +#include +#include +#include +#include + +#include +#include "code_sink.h" +#include "source/opt/build_module.h" +#include "source/opt/log.h" +#include "source/opt/pass_manager.h" +#include "source/opt/passes.h" +#include "source/util/make_unique.h" +#include "source/util/string_utils.h" + +namespace spvtools { + +struct Optimizer::PassToken::Impl { + Impl(std::unique_ptr p) : pass(std::move(p)) {} + + std::unique_ptr pass; // Internal implementation pass. +}; + +Optimizer::PassToken::PassToken( + std::unique_ptr impl) + : impl_(std::move(impl)) {} + +Optimizer::PassToken::PassToken(std::unique_ptr&& pass) + : impl_(MakeUnique(std::move(pass))) {} + +Optimizer::PassToken::PassToken(PassToken&& that) + : impl_(std::move(that.impl_)) {} + +Optimizer::PassToken& Optimizer::PassToken::operator=(PassToken&& that) { + impl_ = std::move(that.impl_); + return *this; +} + +Optimizer::PassToken::~PassToken() {} + +struct Optimizer::Impl { + explicit Impl(spv_target_env env) : target_env(env), pass_manager() {} + + spv_target_env target_env; // Target environment. + opt::PassManager pass_manager; // Internal implementation pass manager. +}; + +Optimizer::Optimizer(spv_target_env env) : impl_(new Impl(env)) {} + +Optimizer::~Optimizer() {} + +void Optimizer::SetMessageConsumer(MessageConsumer c) { + // All passes' message consumer needs to be updated. 
+ for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); ++i) { + impl_->pass_manager.GetPass(i)->SetMessageConsumer(c); + } + impl_->pass_manager.SetMessageConsumer(std::move(c)); +} + +const MessageConsumer& Optimizer::consumer() const { + return impl_->pass_manager.consumer(); +} + +Optimizer& Optimizer::RegisterPass(PassToken&& p) { + // Change to use the pass manager's consumer. + p.impl_->pass->SetMessageConsumer(consumer()); + impl_->pass_manager.AddPass(std::move(p.impl_->pass)); + return *this; +} + +// The legalization passes take a spir-v shader generated by an HLSL front-end +// and turn it into a valid vulkan spir-v shader. There are two ways in which +// the code will be invalid at the start: +// +// 1) There will be opaque objects, like images, which will be passed around +// in intermediate objects. Valid spir-v will have to replace the use of +// the opaque object with an intermediate object that is the result of the +// load of the global opaque object. +// +// 2) There will be variables that contain pointers to structured or uniform +// buffers. To be legal, the variables must be eliminated, and the +// references to the structured buffers must use the result of OpVariable +// in the Uniform storage class. +// +// Optimizations in this list must accept shaders with these relaxations of the +// rules. There is no guarantee that this list of optimizations is able to +// legalize all inputs, but it is on a best effort basis. +// +// The legalization problem is essentially a very general copy propagation +// problem. The optimizations we use are all used to either do copy propagation +// or enable more copy propagation. +Optimizer& Optimizer::RegisterLegalizationPasses() { + return + // Remove unreachable block so that merge return works. + RegisterPass(CreateDeadBranchElimPass()) + // Merge the returns so we can inline. + .RegisterPass(CreateMergeReturnPass()) + // Make sure uses and definitions are in the same function.
+ .RegisterPass(CreateInlineExhaustivePass()) + // Make private variable function scope + .RegisterPass(CreateEliminateDeadFunctionsPass()) + .RegisterPass(CreatePrivateToLocalPass()) + // Propagate the value stored to the loads in very simple cases. + .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + // Split up aggregates so they are easier to deal with. + .RegisterPass(CreateScalarReplacementPass(0)) + // Remove loads and stores so everything is in intermediate values. + // Takes care of copy propagation of non-members. + .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateLocalMultiStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + // Propagate constants to get as many constant conditions on branches + // as possible. + .RegisterPass(CreateCCPPass()) + .RegisterPass(CreateLoopUnrollPass(true)) + .RegisterPass(CreateDeadBranchElimPass()) + // Copy propagate members. Cleans up code sequences generated by + // scalar replacement. Also important for removing OpPhi nodes. 
+ .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateCopyPropagateArraysPass()) + // May need loop unrolling here see + // https://github.com/Microsoft/DirectXShaderCompiler/pull/930 + // Get rid of unused code that contain traces of illegal code + // or unused references to unbound external objects + .RegisterPass(CreateVectorDCEPass()) + .RegisterPass(CreateDeadInsertElimPass()) + .RegisterPass(CreateReduceLoadSizePass()) + .RegisterPass(CreateAggressiveDCEPass()); +} + +Optimizer& Optimizer::RegisterPerformancePasses() { + return RegisterPass(CreateDeadBranchElimPass()) + .RegisterPass(CreateMergeReturnPass()) + .RegisterPass(CreateInlineExhaustivePass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreatePrivateToLocalPass()) + .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateScalarReplacementPass()) + .RegisterPass(CreateLocalAccessChainConvertPass()) + .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateLocalMultiStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateCCPPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateRedundancyEliminationPass()) + .RegisterPass(CreateCombineAccessChainsPass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateVectorDCEPass()) + .RegisterPass(CreateDeadInsertElimPass()) + .RegisterPass(CreateDeadBranchElimPass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateIfConversionPass()) + .RegisterPass(CreateCopyPropagateArraysPass()) + .RegisterPass(CreateReduceLoadSizePass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateBlockMergePass()) + .RegisterPass(CreateRedundancyEliminationPass()) + 
.RegisterPass(CreateDeadBranchElimPass()) + .RegisterPass(CreateBlockMergePass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateCodeSinkingPass()); + // Currently exposing driver bugs resulting in crashes (#946) + // .RegisterPass(CreateCommonUniformElimPass()) +} + +Optimizer& Optimizer::RegisterSizePasses() { + return RegisterPass(CreateDeadBranchElimPass()) + .RegisterPass(CreateMergeReturnPass()) + .RegisterPass(CreateInlineExhaustivePass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreatePrivateToLocalPass()) + .RegisterPass(CreateScalarReplacementPass()) + .RegisterPass(CreateLocalAccessChainConvertPass()) + .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(CreateLocalSingleStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateDeadInsertElimPass()) + .RegisterPass(CreateLocalMultiStoreElimPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateCCPPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateDeadBranchElimPass()) + .RegisterPass(CreateIfConversionPass()) + .RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateBlockMergePass()) + .RegisterPass(CreateSimplificationPass()) + .RegisterPass(CreateDeadInsertElimPass()) + .RegisterPass(CreateRedundancyEliminationPass()) + .RegisterPass(CreateCFGCleanupPass()) + // Currently exposing driver bugs resulting in crashes (#946) + // .RegisterPass(CreateCommonUniformElimPass()) + .RegisterPass(CreateAggressiveDCEPass()); +} + +Optimizer& Optimizer::RegisterWebGPUPasses() { + return RegisterPass(CreateAggressiveDCEPass()) + .RegisterPass(CreateDeadBranchElimPass()); +} + +bool Optimizer::RegisterPassesFromFlags(const std::vector& flags) { + for (const auto& flag : flags) { + if (!RegisterPassFromFlag(flag)) { + return false; + } + } + + return true; +} + +bool Optimizer::FlagHasValidForm(const std::string& flag) const { + if (flag == "-O" || 
flag == "-Os") { + return true; + } else if (flag.size() > 2 && flag.substr(0, 2) == "--") { + return true; + } + + Errorf(consumer(), nullptr, {}, + "%s is not a valid flag. Flag passes should have the form " + "'--pass_name[=pass_args]'. Special flag names also accepted: -O " + "and -Os.", + flag.c_str()); + return false; +} + +bool Optimizer::RegisterPassFromFlag(const std::string& flag) { + if (!FlagHasValidForm(flag)) { + return false; + } + + // Split flags of the form --pass_name=pass_args. + auto p = utils::SplitFlagArgs(flag); + std::string pass_name = p.first; + std::string pass_args = p.second; + + // FIXME(dnovillo): This should be re-factored so that pass names can be + // automatically checked against Pass::name() and PassToken instances created + // via a template function. Additionally, class Pass should have a desc() + // method that describes the pass (so it can be used in --help). + // + // Both Pass::name() and Pass::desc() should be static class members so they + // can be invoked without creating a pass instance. + if (pass_name == "strip-debug") { + RegisterPass(CreateStripDebugInfoPass()); + } else if (pass_name == "strip-reflect") { + RegisterPass(CreateStripReflectInfoPass()); + } else if (pass_name == "set-spec-const-default-value") { + if (pass_args.size() > 0) { + auto spec_ids_vals = + opt::SetSpecConstantDefaultValuePass::ParseDefaultValuesString( + pass_args.c_str()); + if (!spec_ids_vals) { + Errorf(consumer(), nullptr, {}, + "Invalid argument for --set-spec-const-default-value: %s", + pass_args.c_str()); + return false; + } + RegisterPass( + CreateSetSpecConstantDefaultValuePass(std::move(*spec_ids_vals))); + } else { + Errorf(consumer(), nullptr, {}, + "Invalid spec constant value string '%s'. 
Expected a string of " + ": pairs.", + pass_args.c_str()); + return false; + } + } else if (pass_name == "if-conversion") { + RegisterPass(CreateIfConversionPass()); + } else if (pass_name == "freeze-spec-const") { + RegisterPass(CreateFreezeSpecConstantValuePass()); + } else if (pass_name == "inline-entry-points-exhaustive") { + RegisterPass(CreateInlineExhaustivePass()); + } else if (pass_name == "inline-entry-points-opaque") { + RegisterPass(CreateInlineOpaquePass()); + } else if (pass_name == "combine-access-chains") { + RegisterPass(CreateCombineAccessChainsPass()); + } else if (pass_name == "convert-local-access-chains") { + RegisterPass(CreateLocalAccessChainConvertPass()); + } else if (pass_name == "eliminate-dead-code-aggressive") { + RegisterPass(CreateAggressiveDCEPass()); + } else if (pass_name == "propagate-line-info") { + RegisterPass(CreatePropagateLineInfoPass()); + } else if (pass_name == "eliminate-redundant-line-info") { + RegisterPass(CreateRedundantLineInfoElimPass()); + } else if (pass_name == "eliminate-insert-extract") { + RegisterPass(CreateInsertExtractElimPass()); + } else if (pass_name == "eliminate-local-single-block") { + RegisterPass(CreateLocalSingleBlockLoadStoreElimPass()); + } else if (pass_name == "eliminate-local-single-store") { + RegisterPass(CreateLocalSingleStoreElimPass()); + } else if (pass_name == "merge-blocks") { + RegisterPass(CreateBlockMergePass()); + } else if (pass_name == "merge-return") { + RegisterPass(CreateMergeReturnPass()); + } else if (pass_name == "eliminate-dead-branches") { + RegisterPass(CreateDeadBranchElimPass()); + } else if (pass_name == "eliminate-dead-functions") { + RegisterPass(CreateEliminateDeadFunctionsPass()); + } else if (pass_name == "eliminate-local-multi-store") { + RegisterPass(CreateLocalMultiStoreElimPass()); + } else if (pass_name == "eliminate-common-uniform") { + RegisterPass(CreateCommonUniformElimPass()); + } else if (pass_name == "eliminate-dead-const") { + 
RegisterPass(CreateEliminateDeadConstantPass()); + } else if (pass_name == "eliminate-dead-inserts") { + RegisterPass(CreateDeadInsertElimPass()); + } else if (pass_name == "eliminate-dead-variables") { + RegisterPass(CreateDeadVariableEliminationPass()); + } else if (pass_name == "fold-spec-const-op-composite") { + RegisterPass(CreateFoldSpecConstantOpAndCompositePass()); + } else if (pass_name == "loop-unswitch") { + RegisterPass(CreateLoopUnswitchPass()); + } else if (pass_name == "scalar-replacement") { + if (pass_args.size() == 0) { + RegisterPass(CreateScalarReplacementPass()); + } else { + int limit = -1; + if (pass_args.find_first_not_of("0123456789") == std::string::npos) { + limit = atoi(pass_args.c_str()); + } + + if (limit >= 0) { + RegisterPass(CreateScalarReplacementPass(limit)); + } else { + Error(consumer(), nullptr, {}, + "--scalar-replacement must have no arguments or a non-negative " + "integer argument"); + return false; + } + } + } else if (pass_name == "strength-reduction") { + RegisterPass(CreateStrengthReductionPass()); + } else if (pass_name == "unify-const") { + RegisterPass(CreateUnifyConstantPass()); + } else if (pass_name == "flatten-decorations") { + RegisterPass(CreateFlattenDecorationPass()); + } else if (pass_name == "compact-ids") { + RegisterPass(CreateCompactIdsPass()); + } else if (pass_name == "cfg-cleanup") { + RegisterPass(CreateCFGCleanupPass()); + } else if (pass_name == "local-redundancy-elimination") { + RegisterPass(CreateLocalRedundancyEliminationPass()); + } else if (pass_name == "loop-invariant-code-motion") { + RegisterPass(CreateLoopInvariantCodeMotionPass()); + } else if (pass_name == "reduce-load-size") { + RegisterPass(CreateReduceLoadSizePass()); + } else if (pass_name == "redundancy-elimination") { + RegisterPass(CreateRedundancyEliminationPass()); + } else if (pass_name == "private-to-local") { + RegisterPass(CreatePrivateToLocalPass()); + } else if (pass_name == "remove-duplicates") { + 
RegisterPass(CreateRemoveDuplicatesPass()); + } else if (pass_name == "workaround-1209") { + RegisterPass(CreateWorkaround1209Pass()); + } else if (pass_name == "replace-invalid-opcode") { + RegisterPass(CreateReplaceInvalidOpcodePass()); + } else if (pass_name == "inst-bindless-check") { + RegisterPass(CreateInstBindlessCheckPass(7, 23)); + RegisterPass(CreateSimplificationPass()); + RegisterPass(CreateDeadBranchElimPass()); + RegisterPass(CreateBlockMergePass()); + RegisterPass(CreateAggressiveDCEPass()); + } else if (pass_name == "simplify-instructions") { + RegisterPass(CreateSimplificationPass()); + } else if (pass_name == "ssa-rewrite") { + RegisterPass(CreateSSARewritePass()); + } else if (pass_name == "copy-propagate-arrays") { + RegisterPass(CreateCopyPropagateArraysPass()); + } else if (pass_name == "loop-fission") { + int register_threshold_to_split = + (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1; + if (register_threshold_to_split > 0) { + RegisterPass(CreateLoopFissionPass( + static_cast(register_threshold_to_split))); + } else { + Error(consumer(), nullptr, {}, + "--loop-fission must have a positive integer argument"); + return false; + } + } else if (pass_name == "loop-fusion") { + int max_registers_per_loop = + (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1; + if (max_registers_per_loop > 0) { + RegisterPass( + CreateLoopFusionPass(static_cast(max_registers_per_loop))); + } else { + Error(consumer(), nullptr, {}, + "--loop-fusion must have a positive integer argument"); + return false; + } + } else if (pass_name == "loop-unroll") { + RegisterPass(CreateLoopUnrollPass(true)); + } else if (pass_name == "upgrade-memory-model") { + RegisterPass(CreateUpgradeMemoryModelPass()); + } else if (pass_name == "vector-dce") { + RegisterPass(CreateVectorDCEPass()); + } else if (pass_name == "loop-unroll-partial") { + int factor = (pass_args.size() > 0) ? 
atoi(pass_args.c_str()) : 0; + if (factor > 0) { + RegisterPass(CreateLoopUnrollPass(false, factor)); + } else { + Error(consumer(), nullptr, {}, + "--loop-unroll-partial must have a positive integer argument"); + return false; + } + } else if (pass_name == "loop-peeling") { + RegisterPass(CreateLoopPeelingPass()); + } else if (pass_name == "loop-peeling-threshold") { + int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0; + if (factor > 0) { + opt::LoopPeelingPass::SetLoopPeelingThreshold(factor); + } else { + Error(consumer(), nullptr, {}, + "--loop-peeling-threshold must have a positive integer argument"); + return false; + } + } else if (pass_name == "ccp") { + RegisterPass(CreateCCPPass()); + } else if (pass_name == "code-sink") { + RegisterPass(CreateCodeSinkingPass()); + } else if (pass_name == "O") { + RegisterPerformancePasses(); + } else if (pass_name == "Os") { + RegisterSizePasses(); + } else if (pass_name == "legalize-hlsl") { + RegisterLegalizationPasses(); + } else { + Errorf(consumer(), nullptr, {}, + "Unknown flag '--%s'. 
Use --help for a list of valid flags", + pass_name.c_str()); + return false; + } + + return true; +} + +void Optimizer::SetTargetEnv(const spv_target_env env) { + impl_->target_env = env; +} + +bool Optimizer::Run(const uint32_t* original_binary, + const size_t original_binary_size, + std::vector* optimized_binary) const { + return Run(original_binary, original_binary_size, optimized_binary, + OptimizerOptions()); +} + +bool Optimizer::Run(const uint32_t* original_binary, + const size_t original_binary_size, + std::vector* optimized_binary, + const ValidatorOptions& validator_options, + bool skip_validation) const { + OptimizerOptions opt_options; + opt_options.set_run_validator(!skip_validation); + opt_options.set_validator_options(validator_options); + return Run(original_binary, original_binary_size, optimized_binary, + opt_options); +} + +bool Optimizer::Run(const uint32_t* original_binary, + const size_t original_binary_size, + std::vector* optimized_binary, + const spv_optimizer_options opt_options) const { + spvtools::SpirvTools tools(impl_->target_env); + tools.SetMessageConsumer(impl_->pass_manager.consumer()); + if (opt_options->run_validator_ && + !tools.Validate(original_binary, original_binary_size, + &opt_options->val_options_)) { + return false; + } + + std::unique_ptr context = BuildModule( + impl_->target_env, consumer(), original_binary, original_binary_size); + if (context == nullptr) return false; + + context->set_max_id_bound(opt_options->max_id_bound_); + + auto status = impl_->pass_manager.Run(context.get()); + if (status == opt::Pass::Status::SuccessWithChange || + (status == opt::Pass::Status::SuccessWithoutChange && + (optimized_binary->data() != original_binary || + optimized_binary->size() != original_binary_size))) { + optimized_binary->clear(); + context->module()->ToBinary(optimized_binary, /* skip_nop = */ true); + } + + return status != opt::Pass::Status::Failure; +} + +Optimizer& Optimizer::SetPrintAll(std::ostream* out) { + 
impl_->pass_manager.SetPrintAll(out); + return *this; +} + +Optimizer& Optimizer::SetTimeReport(std::ostream* out) { + impl_->pass_manager.SetTimeReport(out); + return *this; +} + +Optimizer::PassToken CreateNullPass() { + return MakeUnique(MakeUnique()); +} + +Optimizer::PassToken CreateStripDebugInfoPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateStripReflectInfoPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateEliminateDeadFunctionsPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateSetSpecConstantDefaultValuePass( + const std::unordered_map& id_value_map) { + return MakeUnique( + MakeUnique(id_value_map)); +} + +Optimizer::PassToken CreateSetSpecConstantDefaultValuePass( + const std::unordered_map>& id_value_map) { + return MakeUnique( + MakeUnique(id_value_map)); +} + +Optimizer::PassToken CreateFlattenDecorationPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateFreezeSpecConstantValuePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateFoldSpecConstantOpAndCompositePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateUnifyConstantPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateEliminateDeadConstantPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateDeadVariableEliminationPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateStrengthReductionPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateBlockMergePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateInlineExhaustivePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateInlineOpaquePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateLocalAccessChainConvertPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken 
CreateLocalSingleBlockLoadStoreElimPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateLocalSingleStoreElimPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateInsertExtractElimPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateDeadInsertElimPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateDeadBranchElimPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateLocalMultiStoreElimPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateAggressiveDCEPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreatePropagateLineInfoPass() { + return MakeUnique( + MakeUnique(opt::kLinesPropagateLines)); +} + +Optimizer::PassToken CreateRedundantLineInfoElimPass() { + return MakeUnique( + MakeUnique(opt::kLinesEliminateDeadLines)); +} + +Optimizer::PassToken CreateCommonUniformElimPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateCompactIdsPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateMergeReturnPass() { + return MakeUnique( + MakeUnique()); +} + +std::vector Optimizer::GetPassNames() const { + std::vector v; + for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); i++) { + v.push_back(impl_->pass_manager.GetPass(i)->name()); + } + return v; +} + +Optimizer::PassToken CreateCFGCleanupPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateLocalRedundancyEliminationPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateLoopFissionPass(size_t threshold) { + return MakeUnique( + MakeUnique(threshold)); +} + +Optimizer::PassToken CreateLoopFusionPass(size_t max_registers_per_loop) { + return MakeUnique( + MakeUnique(max_registers_per_loop)); +} + +Optimizer::PassToken CreateLoopInvariantCodeMotionPass() { + return MakeUnique(MakeUnique()); +} + +Optimizer::PassToken CreateLoopPeelingPass() 
{ + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateLoopUnswitchPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateRedundancyEliminationPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateRemoveDuplicatesPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit) { + return MakeUnique( + MakeUnique(size_limit)); +} + +Optimizer::PassToken CreatePrivateToLocalPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateCCPPass() { + return MakeUnique(MakeUnique()); +} + +Optimizer::PassToken CreateWorkaround1209Pass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateIfConversionPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateReplaceInvalidOpcodePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateSimplificationPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor) { + return MakeUnique( + MakeUnique(fully_unroll, factor)); +} + +Optimizer::PassToken CreateSSARewritePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateCopyPropagateArraysPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateVectorDCEPass() { + return MakeUnique(MakeUnique()); +} + +Optimizer::PassToken CreateReduceLoadSizePass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateCombineAccessChainsPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateUpgradeMemoryModelPass() { + return MakeUnique( + MakeUnique()); +} + +Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t desc_set, + uint32_t shader_id) { + return MakeUnique( + MakeUnique(desc_set, shader_id)); +} + +Optimizer::PassToken CreateCodeSinkingPass() { + return MakeUnique( + MakeUnique()); +} + +} // namespace spvtools diff 
--git a/source/opt/pass.cpp b/source/opt/pass.cpp new file mode 100644 index 000000000..edcd24516 --- /dev/null +++ b/source/opt/pass.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/pass.h" + +#include "source/opt/iterator.h" + +namespace spvtools { +namespace opt { + +namespace { + +const uint32_t kTypePointerTypeIdInIdx = 1; + +} // namespace + +Pass::Pass() : consumer_(nullptr), context_(nullptr), already_run_(false) {} + +Pass::Status Pass::Run(IRContext* ctx) { + if (already_run_) { + return Status::Failure; + } + already_run_ = true; + + context_ = ctx; + Pass::Status status = Process(); + context_ = nullptr; + + if (status == Status::SuccessWithChange) { + ctx->InvalidateAnalysesExceptFor(GetPreservedAnalyses()); + } + assert(ctx->IsConsistent()); + return status; +} + +uint32_t Pass::GetPointeeTypeId(const Instruction* ptrInst) const { + const uint32_t ptrTypeId = ptrInst->type_id(); + const Instruction* ptrTypeInst = get_def_use_mgr()->GetDef(ptrTypeId); + return ptrTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/pass.h b/source/opt/pass.h new file mode 100644 index 000000000..c95f5022b --- /dev/null +++ b/source/opt/pass.h @@ -0,0 +1,147 @@ +// Copyright (c) 2016 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_PASS_H_ +#define SOURCE_OPT_PASS_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { + +// Abstract class of a pass. All passes should implement this abstract class +// and all analysis and transformation is done via the Process() method. +class Pass { + public: + // The status of processing a module using a pass. + // + // The numbers for the cases are assigned to make sure that Failure & anything + // is Failure, SuccessWithChange & any success is SuccessWithChange. + enum class Status { + Failure = 0x00, + SuccessWithChange = 0x10, + SuccessWithoutChange = 0x11, + }; + + using ProcessFunction = std::function; + + // Destructs the pass. + virtual ~Pass() = default; + + // Returns a descriptive name for this pass. + // + // NOTE: When deriving a new pass class, make sure you make the name + // compatible with the corresponding spirv-opt command-line flag. For example, + // if you add the flag --my-pass to spirv-opt, make this function return + // "my-pass" (no leading hyphens). + virtual const char* name() const = 0; + + // Sets the message consumer to the given |consumer|. 
|consumer| which will be + // invoked every time there is a message to be communicated to the outside. + void SetMessageConsumer(MessageConsumer c) { consumer_ = std::move(c); } + + // Returns the reference to the message consumer for this pass. + const MessageConsumer& consumer() const { return consumer_; } + + // Returns the def-use manager used for this pass. TODO(dnovillo): This should + // be handled by the pass manager. + analysis::DefUseManager* get_def_use_mgr() const { + return context()->get_def_use_mgr(); + } + + analysis::DecorationManager* get_decoration_mgr() const { + return context()->get_decoration_mgr(); + } + + FeatureManager* get_feature_mgr() const { + return context()->get_feature_mgr(); + } + + // Returns a pointer to the current module for this pass. + Module* get_module() const { return context_->module(); } + + // Sets the pointer to the current context for this pass. + void SetContextForTesting(IRContext* ctx) { context_ = ctx; } + + // Returns a pointer to the current context for this pass. + IRContext* context() const { return context_; } + + // Returns a pointer to the CFG for current module. + CFG* cfg() const { return context()->cfg(); } + + // Run the pass on the given |module|. Returns Status::Failure if errors occur + // when processing. Returns the corresponding Status::Success if processing is + // successful to indicate whether changes are made to the module. If there + // were any changes it will also invalidate the analyses in the IRContext + // that are not preserved. + // + // It is an error if |Run| is called twice with the same instance of the pass. + // If this happens the return value will be |Failure|. + Status Run(IRContext* ctx); + + // Returns the set of analyses that the pass is guaranteed to preserve. 
+ virtual IRContext::Analysis GetPreservedAnalyses() { + return IRContext::kAnalysisNone; + } + + // Return type id for |ptrInst|'s pointee + uint32_t GetPointeeTypeId(const Instruction* ptrInst) const; + + protected: + // Constructs a new pass. + // + // The constructed instance will have an empty message consumer, which just + // ignores all messages from the library. Use SetMessageConsumer() to supply + // one if messages are of concern. + Pass(); + + // Processes the given |module|. Returns Status::Failure if errors occur when + // processing. Returns the corresponding Status::Success if processing is + // successful to indicate whether changes are made to the module. + virtual Status Process() = 0; + + // Return the next available SSA id and increment it. + // TODO(1841): Handle id overflow. + uint32_t TakeNextId() { return context_->TakeNextId(); } + + private: + MessageConsumer consumer_; // Message consumer. + + // The context that this pass belongs to. + IRContext* context_; + + // An instance of a pass can only be run once because it is too hard to + // enforce proper resetting of internal state for each instance. This member + // is used to check that we do not run the same instance twice. + bool already_run_; +}; + +inline Pass::Status CombineStatus(Pass::Status a, Pass::Status b) { + return std::min(a, b); +} + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_PASS_H_ diff --git a/source/opt/pass_manager.cpp b/source/opt/pass_manager.cpp new file mode 100644 index 000000000..fa1e1d8a8 --- /dev/null +++ b/source/opt/pass_manager.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/pass_manager.h" + +#include +#include +#include + +#include "source/opt/ir_context.h" +#include "source/util/timer.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { + +namespace opt { + +Pass::Status PassManager::Run(IRContext* context) { + auto status = Pass::Status::SuccessWithoutChange; + + // If print_all_stream_ is not null, prints the disassembly of the module + // to that stream, with the given preamble and optionally the pass name. + auto print_disassembly = [&context, this](const char* preamble, Pass* pass) { + if (print_all_stream_) { + std::vector binary; + context->module()->ToBinary(&binary, false); + SpirvTools t(SPV_ENV_UNIVERSAL_1_2); + std::string disassembly; + t.Disassemble(binary, &disassembly, 0); + *print_all_stream_ << preamble << (pass ? pass->name() : "") << "\n" + << disassembly << std::endl; + } + }; + + SPIRV_TIMER_DESCRIPTION(time_report_stream_, /* measure_mem_usage = */ true); + for (auto& pass : passes_) { + print_disassembly("; IR before pass ", pass.get()); + SPIRV_TIMER_SCOPED(time_report_stream_, (pass ? pass->name() : ""), true); + const auto one_status = pass->Run(context); + if (one_status == Pass::Status::Failure) return one_status; + if (one_status == Pass::Status::SuccessWithChange) status = one_status; + + // Reset the pass to free any memory used by the pass. + pass.reset(nullptr); + } + print_disassembly("; IR after last pass", nullptr); + + // Set the Id bound in the header in case a pass forgot to do so. 
+ // + // TODO(dnovillo): This should be unnecessary and automatically maintained by + // the IRContext. + if (status == Pass::Status::SuccessWithChange) { + context->module()->SetIdBound(context->module()->ComputeIdBound()); + } + passes_.clear(); + return status; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/pass_manager.h b/source/opt/pass_manager.h new file mode 100644 index 000000000..ed88aa17c --- /dev/null +++ b/source/opt/pass_manager.h @@ -0,0 +1,131 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_PASS_MANAGER_H_ +#define SOURCE_OPT_PASS_MANAGER_H_ + +#include +#include +#include +#include + +#include "source/opt/log.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +#include "source/opt/ir_context.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { + +// The pass manager, responsible for tracking and running passes. +// Clients should first call AddPass() to add passes and then call Run() +// to run on a module. Passes are executed in the exact order of addition. +class PassManager { + public: + // Constructs a pass manager. + // + // The constructed instance will have an empty message consumer, which just + // ignores all messages from the library. Use SetMessageConsumer() to supply + // one if messages are of concern. 
+ PassManager() + : consumer_(nullptr), + print_all_stream_(nullptr), + time_report_stream_(nullptr) {} + + // Sets the message consumer to the given |consumer|. + void SetMessageConsumer(MessageConsumer c) { consumer_ = std::move(c); } + + // Adds an externally constructed pass. + void AddPass(std::unique_ptr pass); + // Uses the argument |args| to construct a pass instance of type |T|, and adds + // the pass instance to this pass manager. The pass added will use this pass + // manager's message consumer. + template + void AddPass(Args&&... args); + + // Returns the number of passes added. + uint32_t NumPasses() const; + // Returns a pointer to the |index|th pass added. + inline Pass* GetPass(uint32_t index) const; + + // Returns the message consumer. + inline const MessageConsumer& consumer() const; + + // Runs all passes on the given |module|. Returns Status::Failure if errors + // occur when processing using one of the registered passes. All passes + // registered after the error-reporting pass will be skipped. Returns the + // corresponding Status::Success if processing is successful to indicate + // whether changes are made to the module. + // + // After running all the passes, they are removed from the list. + Pass::Status Run(IRContext* context); + + // Sets the option to print the disassembly before each pass and after the + // last pass. Output is written to |out| if that is not null. No output + // is generated if |out| is null. + PassManager& SetPrintAll(std::ostream* out) { + print_all_stream_ = out; + return *this; + } + + // Sets the option to print the resource utilization of each pass. Output is + // written to |out| if that is not null. No output is generated if |out| is + // null. + PassManager& SetTimeReport(std::ostream* out) { + time_report_stream_ = out; + return *this; + } + + private: + // Consumer for messages. + MessageConsumer consumer_; + // A vector of passes. Order matters. 
+ std::vector> passes_; + // The output stream to write disassembly to before each pass, and after + // the last pass. If this is null, no output is generated. + std::ostream* print_all_stream_; + // The output stream to write the resource utilization of each pass. If this + // is null, no output is generated. + std::ostream* time_report_stream_; +}; + +inline void PassManager::AddPass(std::unique_ptr pass) { + passes_.push_back(std::move(pass)); +} + +template +inline void PassManager::AddPass(Args&&... args) { + passes_.emplace_back(new T(std::forward(args)...)); + passes_.back()->SetMessageConsumer(consumer_); +} + +inline uint32_t PassManager::NumPasses() const { + return static_cast(passes_.size()); +} + +inline Pass* PassManager::GetPass(uint32_t index) const { + SPIRV_ASSERT(consumer_, index < passes_.size(), "index out of bound"); + return passes_[index].get(); +} + +inline const MessageConsumer& PassManager::consumer() const { + return consumer_; +} + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_PASS_MANAGER_H_ diff --git a/source/opt/passes.h b/source/opt/passes.h new file mode 100644 index 000000000..f7b675e28 --- /dev/null +++ b/source/opt/passes.h @@ -0,0 +1,72 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_PASSES_H_ +#define SOURCE_OPT_PASSES_H_ + +// A single header to include all passes. 
+ +#include "source/opt/aggressive_dead_code_elim_pass.h" +#include "source/opt/block_merge_pass.h" +#include "source/opt/ccp_pass.h" +#include "source/opt/cfg_cleanup_pass.h" +#include "source/opt/code_sink.h" +#include "source/opt/combine_access_chains.h" +#include "source/opt/common_uniform_elim_pass.h" +#include "source/opt/compact_ids_pass.h" +#include "source/opt/copy_prop_arrays.h" +#include "source/opt/dead_branch_elim_pass.h" +#include "source/opt/dead_insert_elim_pass.h" +#include "source/opt/dead_variable_elimination.h" +#include "source/opt/eliminate_dead_constant_pass.h" +#include "source/opt/eliminate_dead_functions_pass.h" +#include "source/opt/flatten_decoration_pass.h" +#include "source/opt/fold_spec_constant_op_and_composite_pass.h" +#include "source/opt/freeze_spec_constant_value_pass.h" +#include "source/opt/if_conversion.h" +#include "source/opt/inline_exhaustive_pass.h" +#include "source/opt/inline_opaque_pass.h" +#include "source/opt/inst_bindless_check_pass.h" +#include "source/opt/licm_pass.h" +#include "source/opt/local_access_chain_convert_pass.h" +#include "source/opt/local_redundancy_elimination.h" +#include "source/opt/local_single_block_elim_pass.h" +#include "source/opt/local_single_store_elim_pass.h" +#include "source/opt/local_ssa_elim_pass.h" +#include "source/opt/loop_fission.h" +#include "source/opt/loop_fusion_pass.h" +#include "source/opt/loop_peeling.h" +#include "source/opt/loop_unroller.h" +#include "source/opt/loop_unswitch_pass.h" +#include "source/opt/merge_return_pass.h" +#include "source/opt/null_pass.h" +#include "source/opt/private_to_local_pass.h" +#include "source/opt/process_lines_pass.h" +#include "source/opt/reduce_load_size.h" +#include "source/opt/redundancy_elimination.h" +#include "source/opt/remove_duplicates_pass.h" +#include "source/opt/replace_invalid_opc.h" +#include "source/opt/scalar_replacement_pass.h" +#include "source/opt/set_spec_constant_default_value_pass.h" +#include 
"source/opt/simplification_pass.h" +#include "source/opt/ssa_rewrite_pass.h" +#include "source/opt/strength_reduction_pass.h" +#include "source/opt/strip_debug_info_pass.h" +#include "source/opt/strip_reflect_info_pass.h" +#include "source/opt/unify_const_pass.h" +#include "source/opt/upgrade_memory_model.h" +#include "source/opt/vector_dce.h" +#include "source/opt/workaround1209.h" + +#endif // SOURCE_OPT_PASSES_H_ diff --git a/source/opt/pch_source_opt.cpp b/source/opt/pch_source_opt.cpp new file mode 100644 index 000000000..f45448dc5 --- /dev/null +++ b/source/opt/pch_source_opt.cpp @@ -0,0 +1,15 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "pch_source_opt.h" diff --git a/source/opt/pch_source_opt.h b/source/opt/pch_source_opt.h new file mode 100644 index 000000000..73566510e --- /dev/null +++ b/source/opt/pch_source_opt.h @@ -0,0 +1,32 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "source/opt/basic_block.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/mem_pass.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" +#include "source/util/hex_float.h" +#include "source/util/make_unique.h" diff --git a/source/opt/private_to_local_pass.cpp b/source/opt/private_to_local_pass.cpp new file mode 100644 index 000000000..02909a72b --- /dev/null +++ b/source/opt/private_to_local_pass.cpp @@ -0,0 +1,186 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/private_to_local_pass.h" + +#include +#include +#include + +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kVariableStorageClassInIdx = 0; +const uint32_t kSpvTypePointerTypeIdInIdx = 1; + +} // namespace + +Pass::Status PrivateToLocalPass::Process() { + bool modified = false; + + // Private variables require the shader capability. If this is not a shader, + // there is no work to do. 
+ if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses)) + return Status::SuccessWithoutChange; + + std::vector> variables_to_move; + for (auto& inst : context()->types_values()) { + if (inst.opcode() != SpvOpVariable) { + continue; + } + + if (inst.GetSingleWordInOperand(kVariableStorageClassInIdx) != + SpvStorageClassPrivate) { + continue; + } + + Function* target_function = FindLocalFunction(inst); + if (target_function != nullptr) { + variables_to_move.push_back({&inst, target_function}); + } + } + + modified = !variables_to_move.empty(); + for (auto p : variables_to_move) { + MoveVariable(p.first, p.second); + } + + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const { + bool found_first_use = false; + Function* target_function = nullptr; + context()->get_def_use_mgr()->ForEachUser( + inst.result_id(), + [&target_function, &found_first_use, this](Instruction* use) { + BasicBlock* current_block = context()->get_instr_block(use); + if (current_block == nullptr) { + return; + } + + if (!IsValidUse(use)) { + found_first_use = true; + target_function = nullptr; + return; + } + Function* current_function = current_block->GetParent(); + if (!found_first_use) { + found_first_use = true; + target_function = current_function; + } else if (target_function != current_function) { + target_function = nullptr; + } + }); + return target_function; +} // FindLocalFunction + +void PrivateToLocalPass::MoveVariable(Instruction* variable, + Function* function) { + // The variable needs to be removed from the global section, and placed in the + // header of the function. First step remove from the global list. + variable->RemoveFromList(); + std::unique_ptr var(variable); // Take ownership. + context()->ForgetUses(variable); + + // Update the storage class of the variable. 
+ variable->SetInOperand(kVariableStorageClassInIdx, {SpvStorageClassFunction}); + + // Update the type as well. + uint32_t new_type_id = GetNewType(variable->type_id()); + variable->SetResultType(new_type_id); + + // Place the variable at the start of the first basic block. + context()->AnalyzeUses(variable); + context()->set_instr_block(variable, &*function->begin()); + function->begin()->begin()->InsertBefore(move(var)); + + // Update uses where the type may have changed. + UpdateUses(variable->result_id()); +} + +uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) { + auto type_mgr = context()->get_type_mgr(); + Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id); + uint32_t pointee_type_id = + old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx); + uint32_t new_type_id = + type_mgr->FindPointerToType(pointee_type_id, SpvStorageClassFunction); + context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id)); + return new_type_id; +} + +bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const { + // The cases in this switch have to match the cases in |UpdateUse|. + // If we don't know how to update it, it is not valid. + switch (inst->opcode()) { + case SpvOpLoad: + case SpvOpStore: + case SpvOpImageTexelPointer: // Treat like a load + return true; + case SpvOpAccessChain: + return context()->get_def_use_mgr()->WhileEachUser( + inst, [this](const Instruction* user) { + if (!IsValidUse(user)) return false; + return true; + }); + case SpvOpName: + return true; + default: + return spvOpcodeIsDecoration(inst->opcode()); + } +} + +void PrivateToLocalPass::UpdateUse(Instruction* inst) { + // The cases in this switch have to match the cases in |IsValidUse|. If we + // don't think it is valid, the optimization will not view the variable as a + // candidate, and therefore the use will not be updated. 
+ switch (inst->opcode()) { + case SpvOpLoad: + case SpvOpStore: + case SpvOpImageTexelPointer: // Treat like a load + // The type is fine because it is the type pointed to, and that does not + // change. + break; + case SpvOpAccessChain: + context()->ForgetUses(inst); + inst->SetResultType(GetNewType(inst->type_id())); + context()->AnalyzeUses(inst); + + // Update uses where the type may have changed. + UpdateUses(inst->result_id()); + break; + case SpvOpName: + break; + default: + assert(spvOpcodeIsDecoration(inst->opcode()) && + "Do not know how to update the type for this instruction."); + break; + } +} +void PrivateToLocalPass::UpdateUses(uint32_t id) { + std::vector uses; + context()->get_def_use_mgr()->ForEachUser( + id, [&uses](Instruction* use) { uses.push_back(use); }); + + for (Instruction* use : uses) { + UpdateUse(use); + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/private_to_local_pass.h b/source/opt/private_to_local_pass.h new file mode 100644 index 000000000..467853030 --- /dev/null +++ b/source/opt/private_to_local_pass.h @@ -0,0 +1,73 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_PRIVATE_TO_LOCAL_PASS_H_ +#define SOURCE_OPT_PRIVATE_TO_LOCAL_PASS_H_ + +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// This pass makes private variables function-local. 
This is the same as +// moving a variable from the private storage class to the function storage +// class. A variable is moved only if it is referenced in a single function +// and all of its uses are valid as defined by |IsValidUse|. +class PrivateToLocalPass : public Pass { + public: + const char* name() const override { return "private-to-local"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + // Moves |variable| from the private storage class to the function storage + // class of |function|. + void MoveVariable(Instruction* variable, Function* function); + + // |inst| is an instruction declaring a variable. If that variable is + // referenced in a single function and all uses are valid as defined by + // |IsValidUse|, then that function is returned. Otherwise, the return + // value is |nullptr|. + Function* FindLocalFunction(const Instruction& inst) const; + + // Returns true if |inst| is a valid use of a pointer. In this case, a + // valid use is one where the transformation is able to rewrite the type to + // match a change in storage class of the original variable. + bool IsValidUse(const Instruction* inst) const; + + // Given the result id of a pointer type, |old_type_id|, this function + // returns the id of the same pointer type except the storage class has + // been changed to function. If the type does not already exist, it will be + // created. + uint32_t GetNewType(uint32_t old_type_id); + + // Updates |inst|, and any instruction dependent on |inst|, to reflect the + // change of the base pointer now pointing to the function storage class. 
+ void UpdateUse(Instruction* inst); + void UpdateUses(uint32_t id); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_PRIVATE_TO_LOCAL_PASS_H_ diff --git a/source/opt/process_lines_pass.cpp b/source/opt/process_lines_pass.cpp new file mode 100644 index 000000000..0ae2f7583 --- /dev/null +++ b/source/opt/process_lines_pass.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/process_lines_pass.h" + +#include +#include +#include + +namespace { + +// Input Operand Indices +static const int kSpvLineFileInIdx = 0; +static const int kSpvLineLineInIdx = 1; +static const int kSpvLineColInIdx = 2; + +} // anonymous namespace + +namespace spvtools { +namespace opt { + +Pass::Status ProcessLinesPass::Process() { + bool modified = ProcessLines(); + return (modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool ProcessLinesPass::ProcessLines() { + bool modified = false; + uint32_t file_id = 0; + uint32_t line = 0; + uint32_t col = 0; + // Process types, globals, constants + for (Instruction& inst : get_module()->types_values()) + modified |= line_process_func_(&inst, &file_id, &line, &col); + // Process functions + for (Function& function : *get_module()) { + modified |= line_process_func_(&function.DefInst(), &file_id, &line, &col); + function.ForEachParam( + [this, &modified, &file_id, &line, &col](Instruction* param) { + modified |= line_process_func_(param, &file_id, &line, &col); + }); + for (BasicBlock& block : function) { + modified |= + line_process_func_(block.GetLabelInst(), &file_id, &line, &col); + for (Instruction& inst : block) { + modified |= line_process_func_(&inst, &file_id, &line, &col); + // Don't process terminal instruction if preceeded by merge + if (inst.opcode() == SpvOpSelectionMerge || + inst.opcode() == SpvOpLoopMerge) + break; + } + // Nullify line info after each block. 
+ file_id = 0; + } + modified |= line_process_func_(function.EndInst(), &file_id, &line, &col); + } + return modified; +} + +bool ProcessLinesPass::PropagateLine(Instruction* inst, uint32_t* file_id, + uint32_t* line, uint32_t* col) { + bool modified = false; + // only the last debug instruction needs to be considered + auto line_itr = inst->dbg_line_insts().rbegin(); + // if no line instructions, propagate previous info + if (line_itr == inst->dbg_line_insts().rend()) { + // if no current line info, add OpNoLine, else OpLine + if (*file_id == 0) + inst->dbg_line_insts().push_back(Instruction(context(), SpvOpNoLine)); + else + inst->dbg_line_insts().push_back(Instruction( + context(), SpvOpLine, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*file_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {*line}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {*col}}})); + modified = true; + } else { + // else pre-existing line instruction, so update source line info + if (line_itr->opcode() == SpvOpNoLine) { + *file_id = 0; + } else { + assert(line_itr->opcode() == SpvOpLine && "unexpected debug inst"); + *file_id = line_itr->GetSingleWordInOperand(kSpvLineFileInIdx); + *line = line_itr->GetSingleWordInOperand(kSpvLineLineInIdx); + *col = line_itr->GetSingleWordInOperand(kSpvLineColInIdx); + } + } + return modified; +} + +bool ProcessLinesPass::EliminateDeadLines(Instruction* inst, uint32_t* file_id, + uint32_t* line, uint32_t* col) { + // If no debug line instructions, return without modifying lines + if (inst->dbg_line_insts().empty()) return false; + // Only the last debug instruction needs to be considered; delete all others + bool modified = inst->dbg_line_insts().size() > 1; + Instruction last_inst = inst->dbg_line_insts().back(); + inst->dbg_line_insts().clear(); + // If last line is OpNoLine + if (last_inst.opcode() == SpvOpNoLine) { + // If no propagated line info, throw away redundant OpNoLine + if (*file_id == 0) { + modified = 
true; + // Else replace OpNoLine and propagate no line info + } else { + inst->dbg_line_insts().push_back(last_inst); + *file_id = 0; + } + } else { + // Else last line is OpLine + assert(last_inst.opcode() == SpvOpLine && "unexpected debug inst"); + // If propagated info matches last line, throw away last line + if (*file_id == last_inst.GetSingleWordInOperand(kSpvLineFileInIdx) && + *line == last_inst.GetSingleWordInOperand(kSpvLineLineInIdx) && + *col == last_inst.GetSingleWordInOperand(kSpvLineColInIdx)) { + modified = true; + } else { + // Else replace last line and propagate line info + *file_id = last_inst.GetSingleWordInOperand(kSpvLineFileInIdx); + *line = last_inst.GetSingleWordInOperand(kSpvLineLineInIdx); + *col = last_inst.GetSingleWordInOperand(kSpvLineColInIdx); + inst->dbg_line_insts().push_back(last_inst); + } + } + return modified; +} + +ProcessLinesPass::ProcessLinesPass(uint32_t func_id) { + if (func_id == kLinesPropagateLines) { + line_process_func_ = [this](Instruction* inst, uint32_t* file_id, + uint32_t* line, uint32_t* col) { + return PropagateLine(inst, file_id, line, col); + }; + } else { + assert(func_id == kLinesEliminateDeadLines && "unknown Lines param"); + line_process_func_ = [this](Instruction* inst, uint32_t* file_id, + uint32_t* line, uint32_t* col) { + return EliminateDeadLines(inst, file_id, line, col); + }; + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/process_lines_pass.h b/source/opt/process_lines_pass.h new file mode 100644 index 000000000..c988bfd0d --- /dev/null +++ b/source/opt/process_lines_pass.h @@ -0,0 +1,87 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// Copyright (c) 2018 Valve Corporation +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_PROPAGATE_LINES_PASS_H_ +#define SOURCE_OPT_PROPAGATE_LINES_PASS_H_ + +#include "source/opt/function.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +namespace { + +// Constructor Parameters +static const int kLinesPropagateLines = 0; +static const int kLinesEliminateDeadLines = 1; + +} // anonymous namespace + +// See optimizer.hpp for documentation. +class ProcessLinesPass : public Pass { + using LineProcessFunction = + std::function; + + public: + ProcessLinesPass(uint32_t func_id); + ~ProcessLinesPass() override = default; + + const char* name() const override { return "propagate-lines"; } + + // See optimizer.hpp for this pass' user documentation. + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + // If |inst| has no debug line instruction, create one with + // |file_id, line, col|. If |inst| has debug line instructions, set + // |file_id, line, col| from the last. |file_id| equals 0 indicates no line + // info is available. Return true if |inst| modified. 
+ bool PropagateLine(Instruction* inst, uint32_t* file_id, uint32_t* line, + uint32_t* col); + + // If last debug line instruction of |inst| matches |file_id, line, col|, + // delete all debug line instructions of |inst|. If they do not match, + // replace all debug line instructions of |inst| with just the last one and + // set |file_id, line, col| from it. If |inst| has no debug line instructions, + // do not modify |inst|. |file_id| equals 0 indicates no line info is + // available. Return true if |inst| modified. + bool EliminateDeadLines(Instruction* inst, uint32_t* file_id, uint32_t* line, + uint32_t* col); + + // Apply |line_process_func_| to all type, constant, global variable and + // function instructions in their physical order. + bool ProcessLines(); + + // A function that calls either PropagateLine or EliminateDeadLines. + // Initialized by the class constructor. + LineProcessFunction line_process_func_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_PROPAGATE_LINES_PASS_H_ diff --git a/source/opt/propagator.cpp b/source/opt/propagator.cpp new file mode 100644 index 000000000..6a1f1aafb --- /dev/null +++ b/source/opt/propagator.cpp @@ -0,0 +1,291 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/propagator.h" + +namespace spvtools { +namespace opt { + +void SSAPropagator::AddControlEdge(const Edge& edge) { + BasicBlock* dest_bb = edge.dest; + + // Refuse to add the exit block to the work list. + if (dest_bb == ctx_->cfg()->pseudo_exit_block()) { + return; + } + + // Try to mark the edge executable. If it was already in the set of + // executable edges, do nothing. + if (!MarkEdgeExecutable(edge)) { + return; + } + + // If the edge had not already been marked executable, add the destination + // basic block to the work list. + blocks_.push(dest_bb); +} + +void SSAPropagator::AddSSAEdges(Instruction* instr) { + // Ignore instructions that produce no result. + if (instr->result_id() == 0) { + return; + } + + get_def_use_mgr()->ForEachUser( + instr->result_id(), [this](Instruction* use_instr) { + // If the basic block for |use_instr| has not been simulated yet, do + // nothing. The instruction |use_instr| will be simulated next time the + // block is scheduled. 
+ if (!BlockHasBeenSimulated(ctx_->get_instr_block(use_instr))) { + return; + } + + if (ShouldSimulateAgain(use_instr)) { + ssa_edge_uses_.push(use_instr); + } + }); +} + +bool SSAPropagator::IsPhiArgExecutable(Instruction* phi, uint32_t i) const { + BasicBlock* phi_bb = ctx_->get_instr_block(phi); + + uint32_t in_label_id = phi->GetSingleWordOperand(i + 1); + Instruction* in_label_instr = get_def_use_mgr()->GetDef(in_label_id); + BasicBlock* in_bb = ctx_->get_instr_block(in_label_instr); + + return IsEdgeExecutable(Edge(in_bb, phi_bb)); +} + +bool SSAPropagator::SetStatus(Instruction* inst, PropStatus status) { + bool has_old_status = false; + PropStatus old_status = kVarying; + if (HasStatus(inst)) { + has_old_status = true; + old_status = Status(inst); + } + + assert((!has_old_status || old_status <= status) && + "Invalid lattice transition"); + + bool status_changed = !has_old_status || (old_status != status); + if (status_changed) statuses_[inst] = status; + + return status_changed; +} + +bool SSAPropagator::Simulate(Instruction* instr) { + bool changed = false; + + // Don't bother visiting instructions that should not be simulated again. + if (!ShouldSimulateAgain(instr)) { + return changed; + } + + BasicBlock* dest_bb = nullptr; + PropStatus status = visit_fn_(instr, &dest_bb); + bool status_changed = SetStatus(instr, status); + + if (status == kVarying) { + // The statement produces a varying result, add it to the list of statements + // not to simulate anymore and add its SSA def-use edges for simulation. + DontSimulateAgain(instr); + if (status_changed) { + AddSSAEdges(instr); + } + + // If |instr| is a block terminator, add all the control edges out of its + // block. 
+ if (instr->IsBlockTerminator()) { + BasicBlock* block = ctx_->get_instr_block(instr); + for (const auto& e : bb_succs_.at(block)) { + AddControlEdge(e); + } + } + return false; + } else if (status == kInteresting) { + // Add the SSA edges coming out of this instruction if the propagation + // status has changed. + if (status_changed) { + AddSSAEdges(instr); + } + + // If there are multiple outgoing control flow edges and we know which one + // will be taken, add the destination block to the CFG work list. + if (dest_bb) { + AddControlEdge(Edge(ctx_->get_instr_block(instr), dest_bb)); + } + changed = true; + } + + // At this point, we are dealing with instructions that are in status + // kInteresting or kNotInteresting. To decide whether this instruction should + // be simulated again, we examine its operands. If at least one operand O is + // defined at an instruction D that should be simulated again, then the output + // of D might affect |instr|, so we should simulate |instr| again. + bool has_operands_to_simulate = false; + if (instr->opcode() == SpvOpPhi) { + // For Phi instructions, an operand causes the Phi to be simulated again if + // the operand comes from an edge that has not yet been traversed or if its + // definition should be simulated again. + for (uint32_t i = 2; i < instr->NumOperands(); i += 2) { + // Phi arguments come in pairs. Index 'i' contains the + // variable id, index 'i + 1' is the originating block id. + assert(i % 2 == 0 && i < instr->NumOperands() - 1 && + "malformed Phi arguments"); + + uint32_t arg_id = instr->GetSingleWordOperand(i); + Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id); + if (!IsPhiArgExecutable(instr, i) || ShouldSimulateAgain(arg_def_instr)) { + has_operands_to_simulate = true; + break; + } + } + } else { + // For regular instructions, check if the defining instruction of each + // operand needs to be simulated again. If so, then this instruction should + // also be simulated again. 
+ has_operands_to_simulate = + !instr->WhileEachInId([this](const uint32_t* use) { + Instruction* def_instr = get_def_use_mgr()->GetDef(*use); + if (ShouldSimulateAgain(def_instr)) { + return false; + } + return true; + }); + } + + if (!has_operands_to_simulate) { + DontSimulateAgain(instr); + } + + return changed; +} + +bool SSAPropagator::Simulate(BasicBlock* block) { + if (block == ctx_->cfg()->pseudo_exit_block()) { + return false; + } + + // Always simulate Phi instructions, even if we have simulated this block + // before. We do this because Phi instructions receive their inputs from + // incoming edges. When those edges are marked executable, the corresponding + // operand can be simulated. + bool changed = false; + block->ForEachPhiInst( + [&changed, this](Instruction* instr) { changed |= Simulate(instr); }); + + // If this is the first time this block is being simulated, simulate every + // statement in it. + if (!BlockHasBeenSimulated(block)) { + block->ForEachInst([this, &changed](Instruction* instr) { + if (instr->opcode() != SpvOpPhi) { + changed |= Simulate(instr); + } + }); + + MarkBlockSimulated(block); + + // If this block has exactly one successor, mark the edge to its successor + // as executable. + if (bb_succs_.at(block).size() == 1) { + AddControlEdge(bb_succs_.at(block).at(0)); + } + } + + return changed; +} + +void SSAPropagator::Initialize(Function* fn) { + // Compute predecessor and successor blocks for every block in |fn|'s CFG. + // TODO(dnovillo): Move this to CFG and always build them. Alternately, + // move it to IRContext and build CFG preds/succs on-demand. 
+ bb_succs_[ctx_->cfg()->pseudo_entry_block()].push_back( + Edge(ctx_->cfg()->pseudo_entry_block(), fn->entry().get())); + + for (auto& block : *fn) { + const auto& const_block = block; + const_block.ForEachSuccessorLabel([this, &block](const uint32_t label_id) { + BasicBlock* succ_bb = + ctx_->get_instr_block(get_def_use_mgr()->GetDef(label_id)); + bb_succs_[&block].push_back(Edge(&block, succ_bb)); + bb_preds_[succ_bb].push_back(Edge(succ_bb, &block)); + }); + if (block.IsReturnOrAbort()) { + bb_succs_[&block].push_back( + Edge(&block, ctx_->cfg()->pseudo_exit_block())); + bb_preds_[ctx_->cfg()->pseudo_exit_block()].push_back( + Edge(ctx_->cfg()->pseudo_exit_block(), &block)); + } + } + + // Add the edges out of the pseudo-entry block to seed the propagator. + const auto& entry_succs = bb_succs_[ctx_->cfg()->pseudo_entry_block()]; + for (const auto& e : entry_succs) { + AddControlEdge(e); + } +} + +bool SSAPropagator::Run(Function* fn) { + Initialize(fn); + + bool changed = false; + while (!blocks_.empty() || !ssa_edge_uses_.empty()) { + // Simulate all blocks first. Simulating blocks will add SSA edges to + // follow after all the blocks have been simulated. + if (!blocks_.empty()) { + auto block = blocks_.front(); + changed |= Simulate(block); + blocks_.pop(); + continue; + } + + // Simulate edges from the SSA queue. + if (!ssa_edge_uses_.empty()) { + Instruction* instr = ssa_edge_uses_.front(); + changed |= Simulate(instr); + ssa_edge_uses_.pop(); + } + } + +#ifndef NDEBUG + // Verify all visited values have settled. No value that has been simulated + // should end on not interesting. 
+ fn->ForEachInst([this](Instruction* inst) { + assert( + (!HasStatus(inst) || Status(inst) != SSAPropagator::kNotInteresting) && + "Unsettled value"); + }); +#endif + + return changed; +} + +std::ostream& operator<<(std::ostream& str, + const SSAPropagator::PropStatus& status) { + switch (status) { + case SSAPropagator::kVarying: + str << "Varying"; + break; + case SSAPropagator::kInteresting: + str << "Interesting"; + break; + default: + str << "Not interesting"; + break; + } + return str; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/propagator.h b/source/opt/propagator.h new file mode 100644 index 000000000..ac7c0e7ea --- /dev/null +++ b/source/opt/propagator.h @@ -0,0 +1,317 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_PROPAGATOR_H_ +#define SOURCE_OPT_PROPAGATOR_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" + +namespace spvtools { +namespace opt { + +// Represents a CFG control edge. 
+struct Edge { + Edge(BasicBlock* b1, BasicBlock* b2) : source(b1), dest(b2) { + assert(source && "CFG edges cannot have a null source block."); + assert(dest && "CFG edges cannot have a null destination block."); + } + BasicBlock* source; + BasicBlock* dest; + bool operator<(const Edge& o) const { + return std::make_pair(source->id(), dest->id()) < + std::make_pair(o.source->id(), o.dest->id()); + } +}; + +// This class implements a generic value propagation algorithm based on the +// conditional constant propagation algorithm proposed in +// +// Constant propagation with conditional branches, +// Wegman and Zadeck, ACM TOPLAS 13(2):181-210. +// +// A Propagation Engine for GCC +// Diego Novillo, GCC Summit 2005 +// http://ols.fedoraproject.org/GCC/Reprints-2005/novillo-Reprint.pdf +// +// The purpose of this implementation is to act as a common framework for any +// transformation that needs to propagate values from statements producing new +// values to statements using those values. Simulation proceeds as follows: +// +// 1- Initially, all edges of the CFG are marked not executable and the CFG +// worklist is seeded with the function's entry basic block. +// +// 2- Every instruction I is simulated by calling a pass-provided function +// |visit_fn|. This function is responsible for three things: +// +// (a) Keep a value table of interesting values. This table maps SSA IDs to +// their values. For instance, when implementing constant propagation, +// given a store operation 'OpStore %f %int_3', |visit_fn| should assign +// the value 3 to the table slot for %f. +// +// In general, |visit_fn| will need to use the value table to replace its +// operands, fold the result and decide whether a new value needs to be +// stored in the table. |visit_fn| should only create a new mapping in +// the value table if all the operands in the instruction are known and +// present in the value table. +// +// (b) Return a status indicator to direct the propagator logic. 
Once the +// instruction is simulated, the propagator needs to know whether this +// instruction produced something interesting. This is indicated via +// |visit_fn|'s return value: +// +// SSAPropagator::kNotInteresting: Instruction I produces nothing of +// interest and does not affect any of the work lists. The +// propagator will visit the statement again if any of its operands +// produce an interesting value in the future. +// +// |visit_fn| should always return this value when it is not sure +// whether the instruction will produce an interesting value in the +// future or not. For instance, for constant propagation, an OpIAdd +// instruction may produce a constant if its two operands are +// constant, but the first time we visit the instruction, we still +// may not have its operands in the value table. +// +// SSAPropagator::kVarying: The value produced by I cannot be determined +// at compile time. Further simulation of I is not required. The +// propagator will not visit this instruction again. Additionally, +// the propagator will add all the instructions at the end of SSA +// def-use edges to be simulated again. +// +// If I is a basic block terminator, it will mark all outgoing edges +// as executable so they are traversed one more time. Eventually +// the kVarying attribute will be spread out to all the data and +// control dependents for I. +// +// It is important for propagation to use kVarying as a bottom value +// for the propagation lattice. It should never be possible for an +// instruction to return kVarying once and kInteresting on a second +// visit. Otherwise, propagation would not stabilize. +// +// SSAPropagator::kInteresting: Instruction I produces a value that can +// be computed at compile time. In this case, |visit_fn| should +// create a new mapping between I's result ID and the produced +// value. Much like the kNotInteresting case, the propagator will +// visit this instruction again if any of its operands changes. 
+// This is useful when the statement changes from one interesting +// state to another. +// +// (c) For conditional branches, |visit_fn| may decide which edge to take out +// of I's basic block. For example, if the operand for an OpSwitch is +// known to take a specific constant value, |visit_fn| should figure out +// the destination basic block and pass it back by setting the second +// argument to |visit_fn|. +// +// At the end of propagation, values in the value table are guaranteed to be +// stable and can be replaced in the IR. +// +// 3- The propagator keeps two work queues. Instructions are only added to +// these queues if they produce an interesting or varying value. None of this +// should be handled by |visit_fn|. The propagator keeps track of this +// automatically (see SSAPropagator::Simulate for implementation). +// +// CFG blocks: contains the queue of blocks to be simulated. +// Blocks are added to this queue if their incoming edges are +// executable. +// +// SSA Edges: An SSA edge is a def-use edge between a value-producing +// instruction and its use instruction. The SSA edges list +// contains the statements at the end of a def-use edge that need +// to be re-visited when an instruction produces a kVarying or +// kInteresting result. +// +// 4- Simulation terminates when all work queues are drained. +// +// +// EXAMPLE: Basic constant store propagator. +// +// Suppose we want to propagate all constant assignments of the form "OpStore +// %id %cst" where "%id" is some variable and "%cst" an OpConstant. The +// following code builds a table |values| where every id that was assigned a +// constant value is mapped to the constant value it was assigned. 
+// +// auto ctx = BuildModule(...); +// std::map<uint32_t, uint32_t> values; +// const auto visit_fn = [&ctx, &values](Instruction* instr, +// BasicBlock** dest_bb) { +// if (instr->opcode() == SpvOpStore) { +// uint32_t rhs_id = instr->GetSingleWordOperand(1); +// Instruction* rhs_def = ctx->get_def_use_mgr()->GetDef(rhs_id); +// if (rhs_def->opcode() == SpvOpConstant) { +// uint32_t val = rhs_def->GetSingleWordOperand(2); +// values[instr->GetSingleWordOperand(0)] = val; +// return SSAPropagator::kInteresting; +// } +// } +// return SSAPropagator::kVarying; +// }; +// SSAPropagator propagator(ctx.get(), visit_fn); +// propagator.Run(&fn); +// +// Given the code: +// +// %int_4 = OpConstant %int 4 +// %int_3 = OpConstant %int 3 +// %int_1 = OpConstant %int 1 +// OpStore %x %int_4 +// OpStore %y %int_3 +// OpStore %z %int_1 +// +// After SSAPropagator::Run returns, the |values| map will contain the entries: +// values[%x] = 4, values[%y] = 3, and, values[%z] = 1. +class SSAPropagator { + public: + // Lattice values used for propagation. See class documentation for + // a description. + enum PropStatus { kNotInteresting, kInteresting, kVarying }; + + using VisitFunction = std::function; + + SSAPropagator(IRContext* context, const VisitFunction& visit_fn) + : ctx_(context), visit_fn_(visit_fn) {} + + // Runs the propagator on function |fn|. Returns true if changes were made to + // the function. Otherwise, it returns false. + bool Run(Function* fn); + + // Returns true if the |i|th argument for |phi| comes through a CFG edge that + // has been marked executable. |i| should be an index value accepted by + // Instruction::GetSingleWordOperand. + bool IsPhiArgExecutable(Instruction* phi, uint32_t i) const; + + // Returns true if |inst| has a recorded status. This will be true once |inst| + // has been simulated once. + bool HasStatus(Instruction* inst) const { return statuses_.count(inst); } + + // Returns the current propagation status of |inst|. Assumes + // |HasStatus(inst)| returns true. 
+ PropStatus Status(Instruction* inst) const { + return statuses_.find(inst)->second; + } + + // Records the propagation status |status| for |inst|. Returns true if the + // status for |inst| has changed or was set for the first time. + bool SetStatus(Instruction* inst, PropStatus status); + + private: + // Initialize processing. + void Initialize(Function* fn); + + // Simulate the execution of |block| by calling |visit_fn_| on every instruction + // in it. + bool Simulate(BasicBlock* block); + + // Simulate the execution of |instr| by replacing all the known values in + // every operand and determining whether the result is interesting for + // propagation. This invokes the callback function |visit_fn_| to determine + // the value computed by |instr|. + bool Simulate(Instruction* instr); + + // Returns true if |instr| should be simulated again. + bool ShouldSimulateAgain(Instruction* instr) const { + return do_not_simulate_.find(instr) == do_not_simulate_.end(); + } + + // Add |instr| to the set of instructions not to simulate again. + void DontSimulateAgain(Instruction* instr) { do_not_simulate_.insert(instr); } + + // Returns true if |block| has been simulated already. + bool BlockHasBeenSimulated(BasicBlock* block) const { + return simulated_blocks_.find(block) != simulated_blocks_.end(); + } + + // Marks block |block| as simulated. + void MarkBlockSimulated(BasicBlock* block) { + simulated_blocks_.insert(block); + } + + // Marks |edge| as executable. Returns false if the edge was already marked + // as executable. + bool MarkEdgeExecutable(const Edge& edge) { + return executable_edges_.insert(edge).second; + } + + // Returns true if |edge| has been marked as executable. + bool IsEdgeExecutable(const Edge& edge) const { + return executable_edges_.find(edge) != executable_edges_.end(); + } + + // Returns a pointer to the def-use manager for |ctx_|. 
+ analysis::DefUseManager* get_def_use_mgr() const { + return ctx_->get_def_use_mgr(); + } + + // If the CFG edge |e| has not been executed, this function adds |e|'s + // destination block to the work list. + void AddControlEdge(const Edge& e); + + // Adds all the instructions that use the result of |instr| to the SSA edges + // work list. If |instr| produces no result id, this does nothing. + void AddSSAEdges(Instruction* instr); + + // IR context to use. + IRContext* ctx_; + + // Function that visits instructions during simulation. The output of this + // function is used to determine if the simulated instruction produced a value + // interesting for propagation. The function is responsible for keeping + // track of interesting values by storing them in some user-provided map. + VisitFunction visit_fn_; + + // SSA def-use edges to traverse. Each entry is a destination statement for an + // SSA def-use edge as returned by |get_def_use_mgr()|. + std::queue ssa_edge_uses_; + + // Blocks to simulate. + std::queue blocks_; + + // Blocks simulated during propagation. + std::unordered_set simulated_blocks_; + + // Set of instructions that should not be simulated again, either because + // they are kVarying or because none of their operands will be re-simulated. + std::unordered_set do_not_simulate_; + + // Map between a basic block and its predecessor edges. + // TODO(dnovillo): Move this to CFG and always build them. Alternately, + // move it to IRContext and build CFG preds/succs on-demand. + std::unordered_map> bb_preds_; + + // Map between a basic block and its successor edges. + // TODO(dnovillo): Move this to CFG and always build them. Alternately, + // move it to IRContext and build CFG preds/succs on-demand. + std::unordered_map> bb_succs_; + + // Set of executable CFG edges. + std::set executable_edges_; + + // Tracks instruction propagation status. 
+ std::unordered_map statuses_; +}; + +std::ostream& operator<<(std::ostream& str, + const SSAPropagator::PropStatus& status); + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_PROPAGATOR_H_ diff --git a/source/opt/reduce_load_size.cpp b/source/opt/reduce_load_size.cpp new file mode 100644 index 000000000..7b5a015c4 --- /dev/null +++ b/source/opt/reduce_load_size.cpp @@ -0,0 +1,182 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/reduce_load_size.h" + +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/util/bit_vector.h" + +namespace { + +const uint32_t kExtractCompositeIdInIdx = 0; +const uint32_t kVariableStorageClassInIdx = 0; +const uint32_t kLoadPointerInIdx = 0; +const double kThreshold = 0.9; + +} // namespace + +namespace spvtools { +namespace opt { + +Pass::Status ReduceLoadSize::Process() { + bool modified = false; + + for (auto& func : *get_module()) { + func.ForEachInst([&modified, this](Instruction* inst) { + if (inst->opcode() == SpvOpCompositeExtract) { + if (ShouldReplaceExtract(inst)) { + modified |= ReplaceExtract(inst); + } + } + }); + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool ReduceLoadSize::ReplaceExtract(Instruction* inst) { + assert(inst->opcode() == SpvOpCompositeExtract && + "Wrong opcode. 
Should be OpCompositeExtract."); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + + uint32_t composite_id = + inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); + Instruction* composite_inst = def_use_mgr->GetDef(composite_id); + + if (composite_inst->opcode() != SpvOpLoad) { + return false; + } + + analysis::Type* composite_type = type_mgr->GetType(composite_inst->type_id()); + if (composite_type->kind() == analysis::Type::kVector || + composite_type->kind() == analysis::Type::kMatrix) { + return false; + } + + Instruction* var = composite_inst->GetBaseAddress(); + if (var == nullptr || var->opcode() != SpvOpVariable) { + return false; + } + + SpvStorageClass storage_class = static_cast( + var->GetSingleWordInOperand(kVariableStorageClassInIdx)); + switch (storage_class) { + case SpvStorageClassUniform: + case SpvStorageClassUniformConstant: + case SpvStorageClassInput: + break; + default: + return false; + } + + // Create a new access chain and load just after the old load. + // We cannot create the new access chain load in the position of the extract + // because the storage may have been written to in between. 
+ InstructionBuilder ir_builder( + inst->context(), composite_inst, + IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse); + + uint32_t pointer_to_result_type_id = + type_mgr->FindPointerToType(inst->type_id(), storage_class); + assert(pointer_to_result_type_id != 0 && + "We did not find the pointer type that we need."); + + analysis::Integer int_type(32, false); + const analysis::Type* uint32_type = type_mgr->GetRegisteredType(&int_type); + std::vector ids; + for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + uint32_t index = inst->GetSingleWordInOperand(i); + const analysis::Constant* index_const = + const_mgr->GetConstant(uint32_type, {index}); + ids.push_back(const_mgr->GetDefiningInstruction(index_const)->result_id()); + } + + Instruction* new_access_chain = ir_builder.AddAccessChain( + pointer_to_result_type_id, + composite_inst->GetSingleWordInOperand(kLoadPointerInIdx), ids); + Instruction* new_laod = + ir_builder.AddLoad(inst->type_id(), new_access_chain->result_id()); + + context()->ReplaceAllUsesWith(inst->result_id(), new_laod->result_id()); + context()->KillInst(inst); + return true; +} + +bool ReduceLoadSize::ShouldReplaceExtract(Instruction* inst) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + Instruction* op_inst = def_use_mgr->GetDef( + inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)); + + if (op_inst->opcode() != SpvOpLoad) { + return false; + } + + auto cached_result = should_replace_cache_.find(op_inst->result_id()); + if (cached_result != should_replace_cache_.end()) { + return cached_result->second; + } + + bool all_elements_used = false; + std::set elements_used; + + all_elements_used = + !def_use_mgr->WhileEachUser(op_inst, [&elements_used](Instruction* use) { + if (use->opcode() != SpvOpCompositeExtract || + use->NumInOperands() == 1) { + return false; + } + elements_used.insert(use->GetSingleWordInOperand(1)); + return true; + }); + + bool should_replace = false; + if 
(all_elements_used) { + should_replace = false; + } else { + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Type* load_type = type_mgr->GetType(op_inst->type_id()); + uint32_t total_size = 1; + switch (load_type->kind()) { + case analysis::Type::kArray: { + const analysis::Constant* size_const = + const_mgr->FindDeclaredConstant(load_type->AsArray()->LengthId()); + assert(size_const->AsIntConstant()); + total_size = size_const->GetU32(); + } break; + case analysis::Type::kStruct: + total_size = static_cast( + load_type->AsStruct()->element_types().size()); + break; + default: + break; + } + double percent_used = static_cast(elements_used.size()) / + static_cast(total_size); + should_replace = (percent_used < kThreshold); + } + + should_replace_cache_[op_inst->result_id()] = should_replace; + return should_replace; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/reduce_load_size.h b/source/opt/reduce_load_size.h new file mode 100644 index 000000000..ccac49be6 --- /dev/null +++ b/source/opt/reduce_load_size.h @@ -0,0 +1,65 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef SOURCE_OPT_REDUCE_LOAD_SIZE_H_
+#define SOURCE_OPT_REDUCE_LOAD_SIZE_H_
+
+#include <unordered_map>
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class ReduceLoadSize : public Pass {
+ public:
+  const char* name() const override { return "reduce-load-size"; }
+  Status Process() override;
+
+  // Return the mask of preserved Analyses.
+  IRContext::Analysis GetPreservedAnalyses() override {
+    return IRContext::kAnalysisDefUse |
+           IRContext::kAnalysisInstrToBlockMapping |
+           IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG |
+           IRContext::kAnalysisDominatorAnalysis |
+           IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
+           IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
+  }
+
+ private:
+  // Replaces |inst|, which must be an OpCompositeExtract instruction, with
+  // an OpAccessChain and a load if possible.  This happens only if it is a load
+  // feeding |inst|.  Returns true if the substitution happened.  The position
+  // of the new instructions will be in the same place as the load feeding the
+  // extract.
+  bool ReplaceExtract(Instruction* inst);
+
+  // Returns true if the OpCompositeExtract instruction |inst| should be
+  // replaced.  This is determined by looking at the load that feeds |inst| if
+  // it is a load.  |should_replace_cache_| is used to cache the results based
+  // on the load feeding |inst|.
+  bool ShouldReplaceExtract(Instruction* inst);
+
+  // Maps the result id of an OpLoad instruction to the result of whether or
+  // not the OpCompositeExtract instructions that use the id should be replaced.
+  std::unordered_map<uint32_t, bool> should_replace_cache_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_REDUCE_LOAD_SIZE_H_
diff --git a/source/opt/redundancy_elimination.cpp b/source/opt/redundancy_elimination.cpp
new file mode 100644
index 000000000..362e54dc6
--- /dev/null
+++ b/source/opt/redundancy_elimination.cpp
@@ -0,0 +1,56 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/redundancy_elimination.h"
+
+#include "source/opt/value_number_table.h"
+
+namespace spvtools {
+namespace opt {
+
+// Runs redundancy elimination over every function of the module, walking each
+// function's dominator tree from its root.  A single value number table is
+// shared by all functions.
+Pass::Status RedundancyEliminationPass::Process() {
+  bool modified = false;
+  ValueNumberTable vnTable(context());
+
+  for (auto& func : *get_module()) {
+    // Build the dominator tree for this function. It is how the code is
+    // traversed.
+    DominatorTree& dom_tree =
+        context()->GetDominatorAnalysis(&func)->GetDomTree();
+
+    // Keeps track of all ids that contain a given value number. We keep
+    // track of multiple values because they could have the same value, but
+    // different decorations.
+    std::map<uint32_t, uint32_t> value_to_ids;
+
+    if (EliminateRedundanciesFrom(dom_tree.GetRoot(), vnTable, value_to_ids)) {
+      modified = true;
+    }
+  }
+  return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
+}
+
+// Processes |bb| and then, recursively, every block it immediately dominates.
+// |value_to_ids| is taken by value on purpose: entries added while processing
+// one subtree must not be visible to sibling subtrees, whose blocks are not
+// dominated by it.
+bool RedundancyEliminationPass::EliminateRedundanciesFrom(
+    DominatorTreeNode* bb, const ValueNumberTable& vnTable,
+    std::map<uint32_t, uint32_t> value_to_ids) {
+  bool modified = EliminateRedundanciesInBB(bb->bb_, vnTable, &value_to_ids);
+
+  for (auto dominated_bb : bb->children_) {
+    modified |= EliminateRedundanciesFrom(dominated_bb, vnTable, value_to_ids);
+  }
+
+  return modified;
+}
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/redundancy_elimination.h b/source/opt/redundancy_elimination.h
new file mode 100644
index 000000000..91809b5d5
--- /dev/null
+++ b/source/opt/redundancy_elimination.h
@@ -0,0 +1,56 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_REDUNDANCY_ELIMINATION_H_
+#define SOURCE_OPT_REDUNDANCY_ELIMINATION_H_
+
+#include <map>
+
+#include "source/opt/ir_context.h"
+#include "source/opt/local_redundancy_elimination.h"
+#include "source/opt/pass.h"
+#include "source/opt/value_number_table.h"
+
+namespace spvtools {
+namespace opt {
+
+// This pass implements total redundancy elimination.  This is the same as
+// local redundancy elimination except it looks across basic block boundaries.
+// An instruction, inst, is totally redundant if there is another instruction
+// that dominates inst, and also computes the same value.
+class RedundancyEliminationPass : public LocalRedundancyEliminationPass {
+ public:
+  const char* name() const override { return "redundancy-elimination"; }
+  Status Process() override;
+
+ protected:
+  // Removes all total redundancies in the function starting at |bb|.
+  //
+  // |vnTable| must have computed a value number for every result id defined
+  // in the function containing |bb|.
+  //
+  // |value_to_ids| is a map from value number to ids.  If {vn, id} is in
+  // |value_to_ids| then vn is the value number of id, and the definition of id
+  // dominates |bb|.  The map is passed by value so that each call works on
+  // its own copy.
+  //
+  // Returns true if at least one instruction is deleted.
+  bool EliminateRedundanciesFrom(DominatorTreeNode* bb,
+                                 const ValueNumberTable& vnTable,
+                                 std::map<uint32_t, uint32_t> value_to_ids);
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_REDUNDANCY_ELIMINATION_H_
diff --git a/source/opt/reflect.h b/source/opt/reflect.h
new file mode 100644
index 000000000..79d90bda4
--- /dev/null
+++ b/source/opt/reflect.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_REFLECT_H_
+#define SOURCE_OPT_REFLECT_H_
+
+#include "source/latest_version_spirv_header.h"
+
+namespace spvtools {
+namespace opt {
+
+// Note that as SPIR-V evolves over time, new opcodes may appear. So the
+// following functions tend to be outdated and should be updated when SPIR-V
+// version bumps.
+
+// True for OpString and for every opcode in the closed range
+// [OpSourceContinued, OpSourceExtension].
+inline bool IsDebug1Inst(SpvOp opcode) {
+  if (opcode == SpvOpString) return true;
+  return SpvOpSourceContinued <= opcode && opcode <= SpvOpSourceExtension;
+}
+// True for OpName and OpMemberName.
+inline bool IsDebug2Inst(SpvOp opcode) {
+  switch (opcode) {
+    case SpvOpName:
+    case SpvOpMemberName:
+      return true;
+    default:
+      return false;
+  }
+}
+// True for OpModuleProcessed only.
+inline bool IsDebug3Inst(SpvOp opcode) {
+  return SpvOpModuleProcessed == opcode;
+}
+// True for OpLine and OpNoLine.
+inline bool IsDebugLineInst(SpvOp opcode) {
+  switch (opcode) {
+    case SpvOpLine:
+    case SpvOpNoLine:
+      return true;
+    default:
+      return false;
+  }
+}
+// True for the closed range [OpDecorate, OpGroupMemberDecorate] and for the
+// decoration opcodes listed individually below.
+inline bool IsAnnotationInst(SpvOp opcode) {
+  const bool in_core_range =
+      SpvOpDecorate <= opcode && opcode <= SpvOpGroupMemberDecorate;
+  if (in_core_range) return true;
+  return opcode == SpvOpDecorateId || opcode == SpvOpDecorateStringGOOGLE ||
+         opcode == SpvOpMemberDecorateStringGOOGLE;
+}
+// True for the closed range [OpTypeVoid, OpTypeForwardPointer] and for the
+// type opcodes listed individually below.
+inline bool IsTypeInst(SpvOp opcode) {
+  const bool in_core_range =
+      SpvOpTypeVoid <= opcode && opcode <= SpvOpTypeForwardPointer;
+  if (in_core_range) return true;
+  return opcode == SpvOpTypePipeStorage || opcode == SpvOpTypeNamedBarrier ||
+         opcode == SpvOpTypeAccelerationStructureNV;
+}
+// True for the closed range [OpConstantTrue, OpSpecConstantOp].
+inline bool IsConstantInst(SpvOp opcode) {
+  return SpvOpConstantTrue <= opcode && opcode <= SpvOpSpecConstantOp;
+}
+// True for the closed range [OpConstantTrue, OpConstantNull].
+inline bool IsCompileTimeConstantInst(SpvOp opcode) {
+  return SpvOpConstantTrue <= opcode && opcode <= SpvOpConstantNull;
+}
+// True for the closed range [OpSpecConstantTrue, OpSpecConstantOp].
+inline bool IsSpecConstantInst(SpvOp opcode) {
+  return SpvOpSpecConstantTrue <= opcode && opcode <= SpvOpSpecConstantOp;
+}
+// True for the closed range [OpBranch, OpUnreachable].
+inline bool IsTerminatorInst(SpvOp opcode) {
+  return SpvOpBranch <= opcode && opcode <= SpvOpUnreachable;
+}
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_REFLECT_H_
diff --git a/source/opt/register_pressure.cpp b/source/opt/register_pressure.cpp
new file mode 100644
index 000000000..34dac1d7b
--- /dev/null
+++ b/source/opt/register_pressure.cpp
@@ -0,0 +1,576 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/register_pressure.h"
+
+#include <algorithm>
+#include <iterator>
+
+#include "source/opt/cfg.h"
+#include "source/opt/def_use_manager.h"
+#include "source/opt/dominator_tree.h"
+#include "source/opt/function.h"
+#include "source/opt/ir_context.h"
+#include "source/opt/iterator.h"
+#include "source/opt/reflect.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+// Predicate for the FilterIterator to only consider instructions that are not
+// phi instructions defined in the basic block |bb|.
+class ExcludePhiDefinedInBlock {
+ public:
+  ExcludePhiDefinedInBlock(IRContext* context, const BasicBlock* bb)
+      : context_(context), bb_(bb) {}
+
+  // Returns false only for an OpPhi whose containing block is |bb_|.
+  bool operator()(Instruction* insn) const {
+    return !(insn->opcode() == SpvOpPhi &&
+             context_->get_instr_block(insn) == bb_);
+  }
+
+ private:
+  IRContext* context_;
+  const BasicBlock* bb_;
+};
+
+// Returns true if |insn| generates a SSA register that is likely to require a
+// physical register.  Instructions without a result id, OpUndef, constants
+// and OpLabel are excluded.
+bool CreatesRegisterUsage(Instruction* insn) {
+  if (!insn->HasResultId()) return false;
+  if (insn->opcode() == SpvOpUndef) return false;
+  if (IsConstantInst(insn->opcode())) return false;
+  if (insn->opcode() == SpvOpLabel) return false;
+  return true;
+}
+
+// Compute the register liveness for each basic block of a function. This also
+// fills in some information about the peak register usage and a breakdown of
+// register usage. This implements: "A non-iterative data-flow algorithm for
+// computing liveness sets in strict ssa programs" from Boissinot et al.
+class ComputeRegisterLiveness {
+ public:
+  // Caches the analyses of |reg_pressure|'s context that are needed for |f|:
+  // CFG, def-use manager, dominator tree and loop descriptor.
+  ComputeRegisterLiveness(RegisterLiveness* reg_pressure, Function* f)
+      : reg_pressure_(reg_pressure),
+        context_(reg_pressure->GetContext()),
+        function_(f),
+        cfg_(*reg_pressure->GetContext()->cfg()),
+        def_use_manager_(*reg_pressure->GetContext()->get_def_use_mgr()),
+        dom_tree_(
+            reg_pressure->GetContext()->GetDominatorAnalysis(f)->GetDomTree()),
+        loop_desc_(*reg_pressure->GetContext()->GetLoopDescriptor(f)) {}
+
+  // Computes the register liveness for |function_| and then estimate the
+  // register usage. The liveness algorithm works in 2 steps:
+  //   - First, compute the liveness for each basic blocks, but will ignore any
+  //   back-edge;
+  //   - Second, walk loop forest to propagate registers crossing back-edges
+  //   (add iterative values into the liveness set).
+  void Compute() {
+    cfg_.ForEachBlockInPostOrder(&*function_->begin(), [this](BasicBlock* bb) {
+      ComputePartialLiveness(bb);
+    });
+    DoLoopLivenessUnification();
+    EvaluateRegisterRequirements();
+  }
+
+ private:
+  // Registers all SSA register used by successors of |bb| in their phi
+  // instructions.
+  void ComputePhiUses(const BasicBlock& bb,
+                      RegisterLiveness::RegionRegisterLiveness::LiveSet* live) {
+    uint32_t bb_id = bb.id();
+    bb.ForEachSuccessorLabel([live, bb_id, this](uint32_t sid) {
+      BasicBlock* succ_bb = cfg_.block(sid);
+      succ_bb->ForEachPhiInst([live, bb_id, this](const Instruction* phi) {
+        // Phi in-operands are walked as (value id, predecessor label) pairs.
+        for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
+          if (phi->GetSingleWordInOperand(i + 1) == bb_id) {
+            Instruction* insn_op =
+                def_use_manager_.GetDef(phi->GetSingleWordInOperand(i));
+            if (CreatesRegisterUsage(insn_op)) {
+              live->insert(insn_op);
+              break;
+            }
+          }
+        }
+      });
+    });
+  }
+
+  // Computes register liveness for each basic blocks but ignores all
+  // back-edges.
+  void ComputePartialLiveness(BasicBlock* bb) {
+    assert(reg_pressure_->Get(bb) == nullptr &&
+           "Basic block already processed");
+
+    RegisterLiveness::RegionRegisterLiveness* live_inout =
+        reg_pressure_->GetOrInsert(bb->id());
+    ComputePhiUses(*bb, &live_inout->live_out_);
+
+    const BasicBlock* cbb = bb;
+    cbb->ForEachSuccessorLabel([&live_inout, bb, this](uint32_t sid) {
+      // Skip back edges.
+      if (dom_tree_.Dominates(sid, bb->id())) {
+        return;
+      }
+
+      BasicBlock* succ_bb = cfg_.block(sid);
+      RegisterLiveness::RegionRegisterLiveness* succ_live_inout =
+          reg_pressure_->Get(succ_bb);
+      assert(succ_live_inout &&
+             "Successor liveness analysis was not performed");
+
+      // The successor's own phis are excluded here; values flowing into them
+      // from |bb| were already added by ComputePhiUses.
+      ExcludePhiDefinedInBlock predicate(context_, succ_bb);
+      auto filter =
+          MakeFilterIteratorRange(succ_live_inout->live_in_.begin(),
+                                  succ_live_inout->live_in_.end(), predicate);
+      live_inout->live_out_.insert(filter.begin(), filter.end());
+    });
+
+    // Walk the block backwards: a definition kills liveness, a use creates it.
+    live_inout->live_in_ = live_inout->live_out_;
+    for (Instruction& insn : make_range(bb->rbegin(), bb->rend())) {
+      if (insn.opcode() == SpvOpPhi) {
+        live_inout->live_in_.insert(&insn);
+        break;
+      }
+      live_inout->live_in_.erase(&insn);
+      insn.ForEachInId([live_inout, this](uint32_t* id) {
+        Instruction* insn_op = def_use_manager_.GetDef(*id);
+        if (CreatesRegisterUsage(insn_op)) {
+          live_inout->live_in_.insert(insn_op);
+        }
+      });
+    }
+  }
+
+  // Propagates the register liveness information of each loop iterators.
+  void DoLoopLivenessUnification() {
+    for (const Loop* loop : *loop_desc_.GetDummyRootLoop()) {
+      DoLoopLivenessUnification(*loop);
+    }
+  }
+
+  // Propagates the register liveness information of loop iterators throughout
+  // the loop body.
+  void DoLoopLivenessUnification(const Loop& loop) {
+    // Blocks that belong directly to |loop| (not to a nested loop), header
+    // excluded.
+    auto blocks_in_loop = MakeFilterIteratorRange(
+        loop.GetBlocks().begin(), loop.GetBlocks().end(),
+        [&loop, this](uint32_t bb_id) {
+          return bb_id != loop.GetHeaderBlock()->id() &&
+                 loop_desc_[bb_id] == &loop;
+        });
+
+    RegisterLiveness::RegionRegisterLiveness* header_live_inout =
+        reg_pressure_->Get(loop.GetHeaderBlock());
+    assert(header_live_inout &&
+           "Liveness analysis was not performed for the current block");
+
+    ExcludePhiDefinedInBlock predicate(context_, loop.GetHeaderBlock());
+    auto live_loop =
+        MakeFilterIteratorRange(header_live_inout->live_in_.begin(),
+                                header_live_inout->live_in_.end(), predicate);
+
+    for (uint32_t bb_id : blocks_in_loop) {
+      BasicBlock* bb = cfg_.block(bb_id);
+
+      RegisterLiveness::RegionRegisterLiveness* live_inout =
+          reg_pressure_->Get(bb);
+      live_inout->live_in_.insert(live_loop.begin(), live_loop.end());
+      live_inout->live_out_.insert(live_loop.begin(), live_loop.end());
+    }
+
+    for (const Loop* inner_loop : loop) {
+      RegisterLiveness::RegionRegisterLiveness* live_inout =
+          reg_pressure_->Get(inner_loop->GetHeaderBlock());
+      live_inout->live_in_.insert(live_loop.begin(), live_loop.end());
+      live_inout->live_out_.insert(live_loop.begin(), live_loop.end());
+
+      DoLoopLivenessUnification(*inner_loop);
+    }
+  }
+
+  // Get the number of required registers for each basic block.
+  void EvaluateRegisterRequirements() {
+    for (BasicBlock& bb : *function_) {
+      RegisterLiveness::RegionRegisterLiveness* live_inout =
+          reg_pressure_->Get(bb.id());
+      assert(live_inout != nullptr && "Basic block not processed");
+
+      size_t reg_count = live_inout->live_out_.size();
+      for (Instruction* insn : live_inout->live_out_) {
+        live_inout->AddRegisterClass(insn);
+      }
+      live_inout->used_registers_ = reg_count;
+
+      // Ids of values whose last use is inside this block.
+      std::unordered_set<uint32_t> die_in_block;
+      for (Instruction& insn : make_range(bb.rbegin(), bb.rend())) {
+        // If it is a phi instruction, the register pressure will not change
+        // anymore.
+        if (insn.opcode() == SpvOpPhi) {
+          break;
+        }
+
+        insn.ForEachInId(
+            [live_inout, &die_in_block, &reg_count, this](uint32_t* id) {
+              Instruction* op_insn = def_use_manager_.GetDef(*id);
+              if (!CreatesRegisterUsage(op_insn) ||
+                  live_inout->live_out_.count(op_insn)) {
+                // already taken into account.
+                return;
+              }
+              if (!die_in_block.count(*id)) {
+                live_inout->AddRegisterClass(op_insn);
+                reg_count++;
+                die_in_block.insert(*id);
+              }
+            });
+        live_inout->used_registers_ =
+            std::max(live_inout->used_registers_, reg_count);
+        // Walking backwards past a definition frees its register.
+        if (CreatesRegisterUsage(&insn)) {
+          reg_count--;
+        }
+      }
+    }
+  }
+
+  RegisterLiveness* reg_pressure_;
+  IRContext* context_;
+  Function* function_;
+  CFG& cfg_;
+  analysis::DefUseManager& def_use_manager_;
+  DominatorTree& dom_tree_;
+  LoopDescriptor& loop_desc_;
+};
+}  // namespace
+
+// Records the register class (type and uniformity) of |insn| in this region.
+void RegisterLiveness::RegionRegisterLiveness::AddRegisterClass(
+    Instruction* insn) {
+  assert(CreatesRegisterUsage(insn) && "Instruction does not use a register");
+  analysis::Type* type =
+      insn->context()->get_type_mgr()->GetType(insn->type_id());
+
+  RegisterLiveness::RegisterClass reg_class{type, false};
+
+  // The callback returns false, so the walk stops at the first Uniform
+  // decoration found on the result id.
+  insn->context()->get_decoration_mgr()->WhileEachDecoration(
+      insn->result_id(), SpvDecorationUniform,
+      [&reg_class](const Instruction&) {
+        reg_class.is_uniform_ = true;
+        return false;
+      });
+
+  AddRegisterClass(reg_class);
+}
+
+// Discards any previous per-block result and recomputes liveness for |f|.
+void RegisterLiveness::Analyze(Function* f) {
+  block_pressure_.clear();
+  ComputeRegisterLiveness(this, f).Compute();
+}
+
+// Fills |loop_reg_pressure| for |loop|: live-in is the header's live-in,
+// live-out is the union of the live-in sets of the exit blocks, and
+// |used_registers_| is the maximum over the loop's blocks.
+void RegisterLiveness::ComputeLoopRegisterPressure(
+    const Loop& loop, RegionRegisterLiveness* loop_reg_pressure) const {
+  loop_reg_pressure->Clear();
+
+  const RegionRegisterLiveness* header_live_inout = Get(loop.GetHeaderBlock());
+  loop_reg_pressure->live_in_ = header_live_inout->live_in_;
+
+  std::unordered_set<uint32_t> exit_blocks;
+  loop.GetExitBlocks(&exit_blocks);
+
+  for (uint32_t bb_id : exit_blocks) {
+    const RegionRegisterLiveness* live_inout = Get(bb_id);
+    loop_reg_pressure->live_out_.insert(live_inout->live_in_.begin(),
+                                        live_inout->live_in_.end());
+  }
+
+  // Result ids whose register class has already been recorded, so that each
+  // value is counted once.
+  std::unordered_set<uint32_t> seen_insn;
+  for (Instruction* insn : loop_reg_pressure->live_out_) {
+    loop_reg_pressure->AddRegisterClass(insn);
+    seen_insn.insert(insn->result_id());
+  }
+  for (Instruction* insn : loop_reg_pressure->live_in_) {
+    // Skip values already counted through the live-out set.
+    if (seen_insn.count(insn->result_id())) {
+      continue;
+    }
+    loop_reg_pressure->AddRegisterClass(insn);
+    seen_insn.insert(insn->result_id());
+  }
+
+  loop_reg_pressure->used_registers_ = 0;
+
+  for (uint32_t bb_id : loop.GetBlocks()) {
+    BasicBlock* bb = context_->cfg()->block(bb_id);
+
+    const RegionRegisterLiveness* live_inout = Get(bb_id);
+    assert(live_inout != nullptr && "Basic block not processed");
+    loop_reg_pressure->used_registers_ = std::max(
+        loop_reg_pressure->used_registers_, live_inout->used_registers_);
+
+    for (Instruction& insn : *bb) {
+      if (insn.opcode() == SpvOpPhi || !CreatesRegisterUsage(&insn) ||
+          seen_insn.count(insn.result_id())) {
+        continue;
+      }
+      loop_reg_pressure->AddRegisterClass(&insn);
+    }
+  }
+}
+
+// Estimates into |sim_result| the register pressure of the loop obtained by
+// fusing |l1| and |l2| (|l1| placed before |l2|).
+void RegisterLiveness::SimulateFusion(
+    const Loop& l1, const Loop& l2, RegionRegisterLiveness* sim_result) const {
+  sim_result->Clear();
+
+  // Compute the live-in state:
+  //   sim_result.live_in = l1.live_in U l2.live_in
+  // This assumes that |l1| does not generate a register that is live-out for
+  // |l1|.
+  const RegionRegisterLiveness* l1_header_live_inout = Get(l1.GetHeaderBlock());
+  sim_result->live_in_ = l1_header_live_inout->live_in_;
+
+  const RegionRegisterLiveness* l2_header_live_inout = Get(l2.GetHeaderBlock());
+  sim_result->live_in_.insert(l2_header_live_inout->live_in_.begin(),
+                              l2_header_live_inout->live_in_.end());
+
+  // The live-out set of the fused loop is the l2 live-out set.
+  std::unordered_set<uint32_t> exit_blocks;
+  l2.GetExitBlocks(&exit_blocks);
+
+  for (uint32_t bb_id : exit_blocks) {
+    const RegionRegisterLiveness* live_inout = Get(bb_id);
+    sim_result->live_out_.insert(live_inout->live_in_.begin(),
+                                 live_inout->live_in_.end());
+  }
+
+  // Compute the register usage information.
+  std::unordered_set<uint32_t> seen_insn;
+  for (Instruction* insn : sim_result->live_out_) {
+    sim_result->AddRegisterClass(insn);
+    seen_insn.insert(insn->result_id());
+  }
+  for (Instruction* insn : sim_result->live_in_) {
+    // Skip values already counted through the live-out set.
+    if (seen_insn.count(insn->result_id())) {
+      continue;
+    }
+    sim_result->AddRegisterClass(insn);
+    seen_insn.insert(insn->result_id());
+  }
+
+  sim_result->used_registers_ = 0;
+
+  // The loop fusion is injecting the l1 before the l2, the latch of l1 will be
+  // connected to the header of l2.
+  // To compute the register usage, we inject the loop live-in (union of l1 and
+  // l2 live-in header blocks) into the live in/out of each basic block of
+  // l1 to get the peak register usage. We then repeat the operation for l2
+  // basic blocks but in this case we inject the live-out of the latch of l1.
+  auto live_loop = MakeFilterIteratorRange(
+      sim_result->live_in_.begin(), sim_result->live_in_.end(),
+      [&l1, &l2](Instruction* insn) {
+        BasicBlock* bb = insn->context()->get_instr_block(insn);
+        return insn->HasResultId() &&
+               !(insn->opcode() == SpvOpPhi &&
+                 (bb == l1.GetHeaderBlock() || bb == l2.GetHeaderBlock()));
+      });
+
+  for (uint32_t bb_id : l1.GetBlocks()) {
+    BasicBlock* bb = context_->cfg()->block(bb_id);
+
+    const RegionRegisterLiveness* live_inout_info = Get(bb_id);
+    assert(live_inout_info != nullptr && "Basic block not processed");
+    RegionRegisterLiveness::LiveSet live_out = live_inout_info->live_out_;
+    live_out.insert(live_loop.begin(), live_loop.end());
+    // Block peak plus the number of values newly added to its live-out set.
+    sim_result->used_registers_ =
+        std::max(sim_result->used_registers_,
+                 live_inout_info->used_registers_ + live_out.size() -
+                     live_inout_info->live_out_.size());
+
+    for (Instruction& insn : *bb) {
+      if (insn.opcode() == SpvOpPhi || !CreatesRegisterUsage(&insn) ||
+          seen_insn.count(insn.result_id())) {
+        continue;
+      }
+      sim_result->AddRegisterClass(&insn);
+    }
+  }
+
+  const RegionRegisterLiveness* l1_latch_live_inout_info =
+      Get(l1.GetLatchBlock()->id());
+  assert(l1_latch_live_inout_info != nullptr && "Basic block not processed");
+  RegionRegisterLiveness::LiveSet l1_latch_live_out =
+      l1_latch_live_inout_info->live_out_;
+  l1_latch_live_out.insert(live_loop.begin(), live_loop.end());
+
+  auto live_loop_l2 =
+      make_range(l1_latch_live_out.begin(), l1_latch_live_out.end());
+
+  for (uint32_t bb_id : l2.GetBlocks()) {
+    BasicBlock* bb = context_->cfg()->block(bb_id);
+
+    const RegionRegisterLiveness* live_inout_info = Get(bb_id);
+    assert(live_inout_info != nullptr && "Basic block not processed");
+    RegionRegisterLiveness::LiveSet live_out = live_inout_info->live_out_;
+    live_out.insert(live_loop_l2.begin(), live_loop_l2.end());
+    sim_result->used_registers_ =
+        std::max(sim_result->used_registers_,
+                 live_inout_info->used_registers_ + live_out.size() -
+                     live_inout_info->live_out_.size());
+
+    for (Instruction& insn : *bb) {
+      if (insn.opcode() == SpvOpPhi || !CreatesRegisterUsage(&insn) ||
+          seen_insn.count(insn.result_id())) {
+        continue;
+      }
+      sim_result->AddRegisterClass(&insn);
+    }
+  }
+}
+
+// Estimates the register pressure of the two loops obtained by splitting
+// |loop|: instructions in |moved_inst| go to the first loop only, those in
+// |copied_inst| are in both.  Results go to |l1_sim_result| and
+// |l2_sim_result|.
+void RegisterLiveness::SimulateFission(
+    const Loop& loop, const std::unordered_set<Instruction*>& moved_inst,
+    const std::unordered_set<Instruction*>& copied_inst,
+    RegionRegisterLiveness* l1_sim_result,
+    RegionRegisterLiveness* l2_sim_result) const {
+  l1_sim_result->Clear();
+  l2_sim_result->Clear();
+
+  // Filter predicates: consider instructions that only belong to the first and
+  // second loop.
+  auto belong_to_loop1 = [&moved_inst, &copied_inst, &loop](Instruction* insn) {
+    return moved_inst.count(insn) || copied_inst.count(insn) ||
+           !loop.IsInsideLoop(insn);
+  };
+  auto belong_to_loop2 = [&moved_inst](Instruction* insn) {
+    return !moved_inst.count(insn);
+  };
+
+  const RegionRegisterLiveness* header_live_inout = Get(loop.GetHeaderBlock());
+  // l1 live-in
+  {
+    auto live_loop = MakeFilterIteratorRange(
+        header_live_inout->live_in_.begin(), header_live_inout->live_in_.end(),
+        belong_to_loop1);
+    l1_sim_result->live_in_.insert(live_loop.begin(), live_loop.end());
+  }
+  // l2 live-in
+  {
+    auto live_loop = MakeFilterIteratorRange(
+        header_live_inout->live_in_.begin(), header_live_inout->live_in_.end(),
+        belong_to_loop2);
+    l2_sim_result->live_in_.insert(live_loop.begin(), live_loop.end());
+  }
+
+  std::unordered_set<uint32_t> exit_blocks;
+  loop.GetExitBlocks(&exit_blocks);
+
+  // l2 live-out.
+  for (uint32_t bb_id : exit_blocks) {
+    const RegionRegisterLiveness* live_inout = Get(bb_id);
+    l2_sim_result->live_out_.insert(live_inout->live_in_.begin(),
+                                    live_inout->live_in_.end());
+  }
+  // l1 live-out.
+  {
+    auto live_out = MakeFilterIteratorRange(l2_sim_result->live_out_.begin(),
+                                            l2_sim_result->live_out_.end(),
+                                            belong_to_loop1);
+    l1_sim_result->live_out_.insert(live_out.begin(), live_out.end());
+  }
+  {
+    auto live_out =
+        MakeFilterIteratorRange(l2_sim_result->live_in_.begin(),
+                                l2_sim_result->live_in_.end(), belong_to_loop1);
+    l1_sim_result->live_out_.insert(live_out.begin(), live_out.end());
+  }
+  // Lives out of l1 are live out of l2 so are live in of l2 as well.
+  l2_sim_result->live_in_.insert(l1_sim_result->live_out_.begin(),
+                                 l1_sim_result->live_out_.end());
+
+  for (Instruction* insn : l1_sim_result->live_in_) {
+    l1_sim_result->AddRegisterClass(insn);
+  }
+  for (Instruction* insn : l2_sim_result->live_in_) {
+    l2_sim_result->AddRegisterClass(insn);
+  }
+
+  l1_sim_result->used_registers_ = 0;
+  l2_sim_result->used_registers_ = 0;
+
+  for (uint32_t bb_id : loop.GetBlocks()) {
+    BasicBlock* bb = context_->cfg()->block(bb_id);
+
+    const RegisterLiveness::RegionRegisterLiveness* live_inout = Get(bb_id);
+    assert(live_inout != nullptr && "Basic block not processed");
+    auto l1_block_live_out =
+        MakeFilterIteratorRange(live_inout->live_out_.begin(),
+                                live_inout->live_out_.end(), belong_to_loop1);
+    auto l2_block_live_out =
+        MakeFilterIteratorRange(live_inout->live_out_.begin(),
+                                live_inout->live_out_.end(), belong_to_loop2);
+
+    size_t l1_reg_count =
+        std::distance(l1_block_live_out.begin(), l1_block_live_out.end());
+    size_t l2_reg_count =
+        std::distance(l2_block_live_out.begin(), l2_block_live_out.end());
+
+    // Ids of values whose last use is inside this block.
+    std::unordered_set<uint32_t> die_in_block;
+    for (Instruction& insn : make_range(bb->rbegin(), bb->rend())) {
+      if (insn.opcode() == SpvOpPhi) {
+        break;
+      }
+
+      bool does_belong_to_loop1 = belong_to_loop1(&insn);
+      bool does_belong_to_loop2 = belong_to_loop2(&insn);
+      insn.ForEachInId([live_inout, &die_in_block, &l1_reg_count, &l2_reg_count,
+                        does_belong_to_loop1, does_belong_to_loop2,
+                        this](uint32_t* id) {
+        Instruction* op_insn = context_->get_def_use_mgr()->GetDef(*id);
+        if (!CreatesRegisterUsage(op_insn) ||
+            live_inout->live_out_.count(op_insn)) {
+          // already taken into account.
+          return;
+        }
+        if (!die_in_block.count(*id)) {
+          if (does_belong_to_loop1) {
+            l1_reg_count++;
+          }
+          if (does_belong_to_loop2) {
+            l2_reg_count++;
+          }
+          die_in_block.insert(*id);
+        }
+      });
+      l1_sim_result->used_registers_ =
+          std::max(l1_sim_result->used_registers_, l1_reg_count);
+      l2_sim_result->used_registers_ =
+          std::max(l2_sim_result->used_registers_, l2_reg_count);
+      if (CreatesRegisterUsage(&insn)) {
+        if (does_belong_to_loop1) {
+          if (!l1_sim_result->live_in_.count(&insn)) {
+            l1_sim_result->AddRegisterClass(&insn);
+          }
+          l1_reg_count--;
+        }
+        if (does_belong_to_loop2) {
+          if (!l2_sim_result->live_in_.count(&insn)) {
+            l2_sim_result->AddRegisterClass(&insn);
+          }
+          l2_reg_count--;
+        }
+      }
+    }
+  }
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/register_pressure.h b/source/opt/register_pressure.h
new file mode 100644
index 000000000..cb3d2e270
--- /dev/null
+++ b/source/opt/register_pressure.h
@@ -0,0 +1,196 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#ifndef SOURCE_OPT_REGISTER_PRESSURE_H_ +#define SOURCE_OPT_REGISTER_PRESSURE_H_ + +#include +#include +#include +#include + +#include "source/opt/function.h" +#include "source/opt/types.h" + +namespace spvtools { +namespace opt { + +class IRContext; +class Loop; +class LoopDescriptor; + +// Handles the register pressure of a function for different regions (function, +// loop, basic block). It also contains some utilities to foresee the register +// pressure following code transformations. +class RegisterLiveness { + public: + // Classification of SSA registers. + struct RegisterClass { + analysis::Type* type_; + bool is_uniform_; + + bool operator==(const RegisterClass& rhs) const { + return std::tie(type_, is_uniform_) == + std::tie(rhs.type_, rhs.is_uniform_); + } + }; + + struct RegionRegisterLiveness { + using LiveSet = std::unordered_set; + using RegClassSetTy = std::vector>; + + // SSA register live when entering the basic block. + LiveSet live_in_; + // SSA register live when exiting the basic block. + LiveSet live_out_; + + // Maximum number of required registers. + size_t used_registers_; + // Break down of the number of required registers per class of register. + RegClassSetTy registers_classes_; + + void Clear() { + live_out_.clear(); + live_in_.clear(); + used_registers_ = 0; + registers_classes_.clear(); + } + + void AddRegisterClass(const RegisterClass& reg_class) { + auto it = std::find_if( + registers_classes_.begin(), registers_classes_.end(), + [®_class](const std::pair& class_count) { + return class_count.first == reg_class; + }); + if (it != registers_classes_.end()) { + it->second++; + } else { + registers_classes_.emplace_back(std::move(reg_class), + static_cast(1)); + } + } + + void AddRegisterClass(Instruction* insn); + }; + + RegisterLiveness(IRContext* context, Function* f) : context_(context) { + Analyze(f); + } + + // Returns liveness and register information for the basic block |bb|. 
If no + // entry exist for the basic block, the function returns null. + const RegionRegisterLiveness* Get(const BasicBlock* bb) const { + return Get(bb->id()); + } + + // Returns liveness and register information for the basic block id |bb_id|. + // If no entry exist for the basic block, the function returns null. + const RegionRegisterLiveness* Get(uint32_t bb_id) const { + RegionRegisterLivenessMap::const_iterator it = block_pressure_.find(bb_id); + if (it != block_pressure_.end()) { + return &it->second; + } + return nullptr; + } + + IRContext* GetContext() const { return context_; } + + // Returns liveness and register information for the basic block |bb|. If no + // entry exist for the basic block, the function returns null. + RegionRegisterLiveness* Get(const BasicBlock* bb) { return Get(bb->id()); } + + // Returns liveness and register information for the basic block id |bb_id|. + // If no entry exist for the basic block, the function returns null. + RegionRegisterLiveness* Get(uint32_t bb_id) { + RegionRegisterLivenessMap::iterator it = block_pressure_.find(bb_id); + if (it != block_pressure_.end()) { + return &it->second; + } + return nullptr; + } + + // Returns liveness and register information for the basic block id |bb_id| or + // create a new empty entry if no entry already existed. + RegionRegisterLiveness* GetOrInsert(uint32_t bb_id) { + return &block_pressure_[bb_id]; + } + + // Compute the register pressure for the |loop| and store the result into + // |reg_pressure|. The live-in set corresponds to the live-in set of the + // header block, the live-out set of the loop corresponds to the union of the + // live-in sets of each exit basic block. + void ComputeLoopRegisterPressure(const Loop& loop, + RegionRegisterLiveness* reg_pressure) const; + + // Estimate the register pressure for the |l1| and |l2| as if they were making + // one unique loop. The result is stored into |simulation_result|. 
+ void SimulateFusion(const Loop& l1, const Loop& l2, + RegionRegisterLiveness* simulation_result) const; + + // Estimate the register pressure of |loop| after it has been fissioned + // according to |moved_instructions| and |copied_instructions|. The function + // assumes that the fission creates a new loop before |loop|, moves any + // instructions present inside |moved_instructions| and copies any + // instructions present inside |copied_instructions| into this new loop. + // The set |loop1_sim_result| store the simulation result of the loop with the + // moved instructions. The set |loop2_sim_result| store the simulation result + // of the loop with the removed instructions. + void SimulateFission( + const Loop& loop, + const std::unordered_set& moved_instructions, + const std::unordered_set& copied_instructions, + RegionRegisterLiveness* loop1_sim_result, + RegionRegisterLiveness* loop2_sim_result) const; + + private: + using RegionRegisterLivenessMap = + std::unordered_map; + + IRContext* context_; + RegionRegisterLivenessMap block_pressure_; + + void Analyze(Function* f); +}; + +// Handles the register pressure of a function for different regions (function, +// loop, basic block). It also contains some utilities to foresee the register +// pressure following code transformations. +class LivenessAnalysis { + using LivenessAnalysisMap = + std::unordered_map; + + public: + LivenessAnalysis(IRContext* context) : context_(context) {} + + // Computes the liveness analysis for the function |f| and cache the result. + // If the analysis was performed for this function, then the cached analysis + // is returned. 
+ const RegisterLiveness* Get(Function* f) { + LivenessAnalysisMap::iterator it = analysis_cache_.find(f); + if (it != analysis_cache_.end()) { + return &it->second; + } + return &analysis_cache_.emplace(f, RegisterLiveness{context_, f}) + .first->second; + } + + private: + IRContext* context_; + LivenessAnalysisMap analysis_cache_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_REGISTER_PRESSURE_H_ diff --git a/source/opt/remove_duplicates_pass.cpp b/source/opt/remove_duplicates_pass.cpp new file mode 100644 index 000000000..a37e9df9e --- /dev/null +++ b/source/opt/remove_duplicates_pass.cpp @@ -0,0 +1,196 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/remove_duplicates_pass.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "source/opcode.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/reflect.h" + +namespace spvtools { +namespace opt { + +Pass::Status RemoveDuplicatesPass::Process() { + bool modified = RemoveDuplicateCapabilities(); + modified |= RemoveDuplicatesExtInstImports(); + modified |= RemoveDuplicateTypes(); + modified |= RemoveDuplicateDecorations(); + + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +bool RemoveDuplicatesPass::RemoveDuplicateCapabilities() const { + bool modified = false; + + if (context()->capabilities().empty()) { + return modified; + } + + std::unordered_set capabilities; + for (auto* i = &*context()->capability_begin(); i;) { + auto res = capabilities.insert(i->GetSingleWordOperand(0u)); + + if (res.second) { + // Never seen before, keep it. + i = i->NextNode(); + } else { + // It's a duplicate, remove it. + i = context()->KillInst(i); + modified = true; + } + } + + return modified; +} + +bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports() const { + bool modified = false; + + if (context()->ext_inst_imports().empty()) { + return modified; + } + + std::unordered_map ext_inst_imports; + for (auto* i = &*context()->ext_inst_import_begin(); i;) { + auto res = ext_inst_imports.emplace( + reinterpret_cast(i->GetInOperand(0u).words.data()), + i->result_id()); + if (res.second) { + // Never seen before, keep it. + i = i->NextNode(); + } else { + // It's a duplicate, remove it. + context()->ReplaceAllUsesWith(i->result_id(), res.first->second); + i = context()->KillInst(i); + modified = true; + } + } + + return modified; +} + +bool RemoveDuplicatesPass::RemoveDuplicateTypes() const { + bool modified = false; + + if (context()->types_values().empty()) { + return modified; + } + + std::vector visited_types; + std::vector to_delete; + for (auto* i = &*context()->types_values_begin(); i; i = i->NextNode()) { + // We only care about types. + if (!spvOpcodeGeneratesType((i->opcode())) && + i->opcode() != SpvOpTypeForwardPointer) { + continue; + } + + // Is the current type equal to one of the types we have aready visited? + SpvId id_to_keep = 0u; + // TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the + // ResultIdTrie from unify_const_pass.cpp for this. 
+ for (auto j : visited_types) { + if (AreTypesEqual(*i, *j, context())) { + id_to_keep = j->result_id(); + break; + } + } + + if (id_to_keep == 0u) { + // This is a never seen before type, keep it around. + visited_types.emplace_back(i); + } else { + // The same type has already been seen before, remove this one. + context()->KillNamesAndDecorates(i->result_id()); + context()->ReplaceAllUsesWith(i->result_id(), id_to_keep); + modified = true; + to_delete.emplace_back(i); + } + } + + for (auto i : to_delete) { + context()->KillInst(i); + } + + return modified; +} + +// TODO(pierremoreau): Duplicate decoration groups should be removed. For +// example, in +// OpDecorate %1 Constant +// %1 = OpDecorationGroup +// OpDecorate %2 Constant +// %2 = OpDecorationGroup +// OpGroupDecorate %1 %3 +// OpGroupDecorate %2 %4 +// group %2 could be removed. +bool RemoveDuplicatesPass::RemoveDuplicateDecorations() const { + bool modified = false; + + std::vector visited_decorations; + + analysis::DecorationManager decoration_manager(context()->module()); + for (auto* i = &*context()->annotation_begin(); i;) { + // Is the current decoration equal to one of the decorations we have aready + // visited? + bool already_visited = false; + // TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the + // ResultIdTrie from unify_const_pass.cpp for this. + for (const Instruction* j : visited_decorations) { + if (decoration_manager.AreDecorationsTheSame(&*i, j, false)) { + already_visited = true; + break; + } + } + + if (!already_visited) { + // This is a never seen before decoration, keep it around. + visited_decorations.emplace_back(&*i); + i = i->NextNode(); + } else { + // The same decoration has already been seen before, remove this one. 
+ modified = true; + i = context()->KillInst(i); + } + } + + return modified; +} + +bool RemoveDuplicatesPass::AreTypesEqual(const Instruction& inst1, + const Instruction& inst2, + IRContext* context) { + if (inst1.opcode() != inst2.opcode()) return false; + if (!IsTypeInst(inst1.opcode())) return false; + + const analysis::Type* type1 = + context->get_type_mgr()->GetType(inst1.result_id()); + const analysis::Type* type2 = + context->get_type_mgr()->GetType(inst2.result_id()); + if (type1 && type2 && *type1 == *type2) return true; + + return false; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/remove_duplicates_pass.h b/source/opt/remove_duplicates_pass.h new file mode 100644 index 000000000..8554a987d --- /dev/null +++ b/source/opt/remove_duplicates_pass.h @@ -0,0 +1,67 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_REMOVE_DUPLICATES_PASS_H_ +#define SOURCE_OPT_REMOVE_DUPLICATES_PASS_H_ + +#include +#include + +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +using IdDecorationsList = + std::unordered_map>; + +// See optimizer.hpp for documentation. 
// Pass that removes duplicated capabilities, extended instruction imports,
// types and decorations from a module.
class RemoveDuplicatesPass : public Pass {
 public:
  const char* name() const override { return "remove-duplicates"; }

  // Runs the four removal steps declared below and reports whether any of
  // them changed the module.
  Status Process() override;

  // TODO(pierremoreau): Move this function somewhere else (e.g. pass.h or
  //                     within the type manager)
  // Returns whether two types are equal, and have the same decorations.
  // NOTE(review): the definition compares the analysis::Type objects obtained
  // from the type manager of |context|; whether that comparison covers
  // decorations is not visible here -- confirm.
  static bool AreTypesEqual(const Instruction& inst1, const Instruction& inst2,
                            IRContext* context);

 private:
  // Remove duplicate capabilities from the module
  //
  // Returns true if the module was modified, false otherwise.
  bool RemoveDuplicateCapabilities() const;
  // Remove duplicate extended instruction imports from the module. Uses of a
  // removed import are redirected to the import that is kept.
  //
  // Returns true if the module was modified, false otherwise.
  bool RemoveDuplicatesExtInstImports() const;
  // Remove duplicate types from the module. Uses of a removed type are
  // redirected to the type that is kept.
  //
  // Returns true if the module was modified, false otherwise.
  bool RemoveDuplicateTypes() const;
  // Remove duplicate decorations from the module
  //
  // Returns true if the module was modified, false otherwise.
  bool RemoveDuplicateDecorations() const;
};
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/replace_invalid_opc.h" + +#include +#include + +namespace spvtools { +namespace opt { + +Pass::Status ReplaceInvalidOpcodePass::Process() { + bool modified = false; + + if (context()->get_feature_mgr()->HasCapability(SpvCapabilityLinkage)) { + return Status::SuccessWithoutChange; + } + + SpvExecutionModel execution_model = GetExecutionModel(); + if (execution_model == SpvExecutionModelKernel) { + // We do not handle kernels. + return Status::SuccessWithoutChange; + } + if (execution_model == SpvExecutionModelMax) { + // Mixed execution models for the entry points. This case is not currently + // handled. + return Status::SuccessWithoutChange; + } + + for (Function& func : *get_module()) { + modified |= RewriteFunction(&func, execution_model); + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +SpvExecutionModel ReplaceInvalidOpcodePass::GetExecutionModel() { + SpvExecutionModel result = SpvExecutionModelMax; + bool first = true; + for (Instruction& entry_point : get_module()->entry_points()) { + if (first) { + result = + static_cast(entry_point.GetSingleWordInOperand(0)); + first = false; + } else { + SpvExecutionModel current_model = + static_cast(entry_point.GetSingleWordInOperand(0)); + if (current_model != result) { + result = SpvExecutionModelMax; + break; + } + } + } + return result; +} + +bool ReplaceInvalidOpcodePass::RewriteFunction(Function* function, + SpvExecutionModel model) { + bool modified = false; + Instruction* last_line_dbg_inst = nullptr; + function->ForEachInst( + [model, &modified, &last_line_dbg_inst, this](Instruction* inst) { + // Track the debug information so we can have a meaningful message. 
+ if (inst->opcode() == SpvOpLabel || inst->opcode() == SpvOpNoLine) { + last_line_dbg_inst = nullptr; + return; + } else if (inst->opcode() == SpvOpLine) { + last_line_dbg_inst = inst; + return; + } + + bool replace = false; + if (model != SpvExecutionModelFragment && + IsFragmentShaderOnlyInstruction(inst)) { + replace = true; + } + + if (model != SpvExecutionModelTessellationControl && + model != SpvExecutionModelGLCompute) { + if (inst->opcode() == SpvOpControlBarrier) { + assert(model != SpvExecutionModelKernel && + "Expecting to be working on a shader module."); + replace = true; + } + } + + if (replace) { + modified = true; + if (last_line_dbg_inst == nullptr) { + ReplaceInstruction(inst, nullptr, 0, 0); + } else { + // Get the name of the source file. + Instruction* file_name = context()->get_def_use_mgr()->GetDef( + last_line_dbg_inst->GetSingleWordInOperand(0)); + const char* source = reinterpret_cast( + &file_name->GetInOperand(0).words[0]); + + // Get the line number and column number. + uint32_t line_number = + last_line_dbg_inst->GetSingleWordInOperand(1); + uint32_t col_number = last_line_dbg_inst->GetSingleWordInOperand(2); + + // Replace the instruction. + ReplaceInstruction(inst, source, line_number, col_number); + } + } + }, + /* run_on_debug_line_insts = */ true); + return modified; +} + +bool ReplaceInvalidOpcodePass::IsFragmentShaderOnlyInstruction( + Instruction* inst) { + switch (inst->opcode()) { + case SpvOpDPdx: + case SpvOpDPdy: + case SpvOpFwidth: + case SpvOpDPdxFine: + case SpvOpDPdyFine: + case SpvOpFwidthFine: + case SpvOpDPdxCoarse: + case SpvOpDPdyCoarse: + case SpvOpFwidthCoarse: + case SpvOpImageSampleImplicitLod: + case SpvOpImageSampleDrefImplicitLod: + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageQueryLod: + // TODO: Teach |ReplaceInstruction| to handle block terminators. 
Then + // uncomment the OpKill case. + // case SpvOpKill: + return true; + default: + return false; + } +} + +void ReplaceInvalidOpcodePass::ReplaceInstruction(Instruction* inst, + const char* source, + uint32_t line_number, + uint32_t column_number) { + if (inst->result_id() != 0) { + uint32_t const_id = GetSpecialConstant(inst->type_id()); + context()->KillNamesAndDecorates(inst); + context()->ReplaceAllUsesWith(inst->result_id(), const_id); + } + assert(!inst->IsBlockTerminator() && + "We cannot simply delete a block terminator. It must be replaced " + "with something."); + if (consumer()) { + std::string message = BuildWarningMessage(inst->opcode()); + consumer()(SPV_MSG_WARNING, source, {line_number, column_number, 0}, + message.c_str()); + } + context()->KillInst(inst); +} + +uint32_t ReplaceInvalidOpcodePass::GetSpecialConstant(uint32_t type_id) { + const analysis::Constant* special_const = nullptr; + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + + Instruction* type = context()->get_def_use_mgr()->GetDef(type_id); + if (type->opcode() == SpvOpTypeVector) { + uint32_t component_const = + GetSpecialConstant(type->GetSingleWordInOperand(0)); + std::vector ids; + for (uint32_t i = 0; i < type->GetSingleWordInOperand(1); ++i) { + ids.push_back(component_const); + } + special_const = const_mgr->GetConstant(type_mgr->GetType(type_id), ids); + } else { + assert(type->opcode() == SpvOpTypeInt || type->opcode() == SpvOpTypeFloat); + std::vector literal_words; + for (uint32_t i = 0; i < type->GetSingleWordInOperand(0); i += 32) { + literal_words.push_back(0xDEADBEEF); + } + special_const = + const_mgr->GetConstant(type_mgr->GetType(type_id), literal_words); + } + assert(special_const != nullptr); + return const_mgr->GetDefiningInstruction(special_const)->result_id(); +} + +std::string ReplaceInvalidOpcodePass::BuildWarningMessage(SpvOp opcode) { + spv_opcode_desc opcode_info; + 
context()->grammar().lookupOpcode(opcode, &opcode_info); + std::string message = "Removing "; + message += opcode_info->name; + message += " instruction because of incompatible execution model."; + return message; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/replace_invalid_opc.h b/source/opt/replace_invalid_opc.h new file mode 100644 index 000000000..426bcac5e --- /dev/null +++ b/source/opt/replace_invalid_opc.h @@ -0,0 +1,67 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_REPLACE_INVALID_OPC_H_ +#define SOURCE_OPT_REPLACE_INVALID_OPC_H_ + +#include + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// This pass will runs on shader modules only. It will replace the result of +// instructions that are valid for shader modules, but not the current shader +// stage, with a constant value. If the instruction does not have a return +// value, the instruction will simply be deleted. +class ReplaceInvalidOpcodePass : public Pass { + public: + const char* name() const override { return "replace-invalid-opcode"; } + Status Process() override; + + private: + // Returns the execution model that is used by every entry point in the + // module. If more than one execution model is used in the module, then the + // return value is SpvExecutionModelMax. 
+ SpvExecutionModel GetExecutionModel(); + + // Replaces all instructions in |function| that are invalid with execution + // model |mode|, but valid for another shader model, with a special constant + // value. See |GetSpecialConstant|. + bool RewriteFunction(Function* function, SpvExecutionModel mode); + + // Returns true if |inst| is valid for fragment shaders only. + bool IsFragmentShaderOnlyInstruction(Instruction* inst); + + // Replaces all uses of the result of |inst|, if there is one, with the id of + // a special constant. Then |inst| is killed. |inst| cannot be a block + // terminator because the basic block will then become invalid. |inst| is no + // longer valid after calling this function. + void ReplaceInstruction(Instruction* inst, const char* source, + uint32_t line_number, uint32_t column_number); + + // Returns the id of a constant with type |type_id|. The type must be an + // integer, float, or vector. For scalar types, the hex representation of the + // constant will be the concatenation of 0xDEADBEEF with itself until the + // width of the type has been reached. For a vector, each element of the + // constant will be constructed the same way. + uint32_t GetSpecialConstant(uint32_t type_id); + std::string BuildWarningMessage(SpvOp opcode); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_REPLACE_INVALID_OPC_H_ diff --git a/source/opt/scalar_analysis.cpp b/source/opt/scalar_analysis.cpp new file mode 100644 index 000000000..38555e649 --- /dev/null +++ b/source/opt/scalar_analysis.cpp @@ -0,0 +1,988 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/scalar_analysis.h" + +#include +#include +#include +#include + +#include "source/opt/ir_context.h" + +// Transforms a given scalar operation instruction into a DAG representation. +// +// 1. Take an instruction and traverse its operands until we reach a +// constant node or an instruction which we do not know how to compute the +// value, such as a load. +// +// 2. Create a new node for each instruction traversed and build the nodes for +// the in operands of that instruction as well. +// +// 3. Add the operand nodes as children of the first and hash the node. Use the +// hash to see if the node is already in the cache. We ensure the children are +// always in sorted order so that two nodes with the same children but inserted +// in a different order have the same hash and so that the overloaded operator== +// will return true. If the node is already in the cache return the cached +// version instead. +// +// 4. The created DAG can then be simplified by +// ScalarAnalysis::SimplifyExpression, implemented in +// scalar_analysis_simplification.cpp. See that file for further information on +// the simplification process. +// + +namespace spvtools { +namespace opt { + +uint32_t SENode::NumberOfNodes = 0; + +ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context) + : context_(context), pretend_equal_{} { + // Create and cached the CantComputeNode. 
+ cached_cant_compute_ = + GetCachedOrAdd(std::unique_ptr(new SECantCompute(this))); +} + +SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) { + // If operand is can't compute then the whole graph is can't compute. + if (operand->IsCantCompute()) return CreateCantComputeNode(); + + if (operand->GetType() == SENode::Constant) { + return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue()); + } + std::unique_ptr negation_node{new SENegative(this)}; + negation_node->AddChild(operand); + return GetCachedOrAdd(std::move(negation_node)); +} + +SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) { + return GetCachedOrAdd( + std::unique_ptr(new SEConstantNode(this, integer))); +} + +SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression( + const Loop* loop, SENode* offset, SENode* coefficient) { + assert(loop && "Recurrent add expressions must have a valid loop."); + + // If operands are can't compute then the whole graph is can't compute. + if (offset->IsCantCompute() || coefficient->IsCantCompute()) + return CreateCantComputeNode(); + + const Loop* loop_to_use = nullptr; + if (pretend_equal_[loop]) { + loop_to_use = pretend_equal_[loop]; + } else { + loop_to_use = loop; + } + + std::unique_ptr phi_node{ + new SERecurrentNode(this, loop_to_use)}; + phi_node->AddOffset(offset); + phi_node->AddCoefficient(coefficient); + + return GetCachedOrAdd(std::move(phi_node)); +} + +SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp( + const Instruction* multiply) { + assert(multiply->opcode() == SpvOp::SpvOpIMul && + "Multiply node did not come from a multiply instruction"); + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + + SENode* op1 = + AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0))); + SENode* op2 = + AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1))); + + return CreateMultiplyNode(op1, op2); +} + +SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1, + 
// Returns a node representing |operand_1| - |operand_2|, built as
// |operand_1| + (-|operand_2|) unless both operands are constants.
SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
                                                   SENode* operand_2) {
  // Fold if both operands are constant.
  if (operand_1->GetType() == SENode::Constant &&
      operand_2->GetType() == SENode::Constant) {
    return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() -
                          operand_2->AsSEConstantNode()->FoldToSingleValue());
  }

  return CreateAddNode(operand_1, CreateNegation(operand_2));
}

// Returns a node representing |operand_1| + |operand_2|.
SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
                                               SENode* operand_2) {
  // Fold if both operands are constant.
  if (operand_1->GetType() == SENode::Constant &&
      operand_2->GetType() == SENode::Constant) {
    return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() +
                          operand_2->AsSEConstantNode()->FoldToSingleValue());
  }

  // If operands are can't compute then the whole graph is can't compute.
  if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
    return CreateCantComputeNode();

  std::unique_ptr<SENode> add_node{new SEAddNode(this)};

  add_node->AddChild(operand_1);
  add_node->AddChild(operand_2);

  return GetCachedOrAdd(std::move(add_node));
}

// Returns the node describing the value produced by |inst|. Phis that were
// already analyzed (or are currently being analyzed) are answered from
// |recurrent_node_map_|; opcodes that are not handled become value-unknown
// nodes.
SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
  auto itr = recurrent_node_map_.find(inst);
  if (itr != recurrent_node_map_.end()) return itr->second;

  switch (inst->opcode()) {
    case SpvOp::SpvOpPhi:
      return AnalyzePhiInstruction(inst);
    case SpvOp::SpvOpConstant:
    case SpvOp::SpvOpConstantNull:
      return AnalyzeConstant(inst);
    case SpvOp::SpvOpISub:
    case SpvOp::SpvOpIAdd:
      return AnalyzeAddOp(inst);
    case SpvOp::SpvOpIMul:
      return AnalyzeMultiplyOp(inst);
    default:
      return CreateValueUnknownNode(inst);
  }
}

// Returns a constant node for the OpConstant / OpConstantNull instruction
// |inst|. OpConstantNull is treated as 0. Returns the CantCompute node when
// the constant manager does not know the constant, when it is not an integer
// constant, or when its value is not stored in exactly one word.
SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
  if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0);

  assert(inst->opcode() == SpvOp::SpvOpConstant);
  assert(inst->NumInOperands() == 1);

  // Look up the instruction in the constant manager.
  const analysis::Constant* constant =
      context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
  if (!constant) return CreateCantComputeNode();

  const analysis::IntConstant* int_constant = constant->AsIntConstant();

  // Exit out if it is not an integer, or if it needs more than one word
  // (e.g. a 64 bit integer).
  if (!int_constant || int_constant->words().size() != 1)
    return CreateCantComputeNode();

  // Widen to 64 bits according to the signedness of the declared type.
  int64_t value = 0;
  if (int_constant->type()->AsInteger()->IsSigned()) {
    value = int_constant->GetS32BitValue();
  } else {
    value = int_constant->GetU32BitValue();
  }

  return CreateConstant(value);
}

// Handles both addition and subtraction. For OpISub the result is built as
// op1+(-op2), for OpIAdd as op1+op2.
+SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) { + assert((inst->opcode() == SpvOp::SpvOpIAdd || + inst->opcode() == SpvOp::SpvOpISub) && + "Add node must be created from a OpIAdd or OpISub instruction"); + + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + + SENode* op1 = + AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0))); + + SENode* op2 = + AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1))); + + // To handle subtraction we wrap the second operand in a unary negation node. + if (inst->opcode() == SpvOp::SpvOpISub) { + op2 = CreateNegation(op2); + } + + return CreateAddNode(op1, op2); +} + +SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) { + // The phi should only have two incoming value pairs. + if (phi->NumInOperands() != 4) { + return CreateCantComputeNode(); + } + + analysis::DefUseManager* def_use = context_->get_def_use_mgr(); + + // Get the basic block this instruction belongs to. + BasicBlock* basic_block = + context_->get_instr_block(const_cast(phi)); + + // And then the function that the basic blocks belongs to. + Function* function = basic_block->GetParent(); + + // Use the function to get the loop descriptor. + LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function); + + // We only handle phis in loops at the moment. + if (!loop_descriptor) return CreateCantComputeNode(); + + // Get the innermost loop which this block belongs to. + Loop* loop = (*loop_descriptor)[basic_block->id()]; + + // If the loop doesn't exist or doesn't have a preheader or latch block, exit + // out. 
+ if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() || + loop->GetHeaderBlock() != basic_block) + return recurrent_node_map_[phi] = CreateCantComputeNode(); + + const Loop* loop_to_use = nullptr; + if (pretend_equal_[loop]) { + loop_to_use = pretend_equal_[loop]; + } else { + loop_to_use = loop; + } + std::unique_ptr phi_node{ + new SERecurrentNode(this, loop_to_use)}; + + // We add the node to this map to allow it to be returned before the node is + // fully built. This is needed as the subsequent call to AnalyzeInstruction + // could lead back to this |phi| instruction so we return the pointer + // immediately in AnalyzeInstruction to break the recursion. + recurrent_node_map_[phi] = phi_node.get(); + + // Traverse the operands of the instruction and create new nodes for each one. + for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { + uint32_t value_id = phi->GetSingleWordInOperand(i); + uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1); + + Instruction* value_inst = def_use->GetDef(value_id); + SENode* value_node = AnalyzeInstruction(value_inst); + + // If any operand is CantCompute then the whole graph is CantCompute. + if (value_node->IsCantCompute()) + return recurrent_node_map_[phi] = CreateCantComputeNode(); + + // If the value is coming from the preheader block then the value is the + // initial value of the phi. + if (incoming_label_id == loop->GetPreHeaderBlock()->id()) { + phi_node->AddOffset(value_node); + } else if (incoming_label_id == loop->GetLatchBlock()->id()) { + // Assumed to be in the form of step + phi. + if (value_node->GetType() != SENode::Add) + return recurrent_node_map_[phi] = CreateCantComputeNode(); + + SENode* step_node = nullptr; + SENode* phi_operand = nullptr; + SENode* operand_1 = value_node->GetChild(0); + SENode* operand_2 = value_node->GetChild(1); + + // Find which node is the step term.
+ if (!operand_1->AsSERecurrentNode()) + step_node = operand_1; + else if (!operand_2->AsSERecurrentNode()) + step_node = operand_2; + + // Find which node is the recurrent expression. + if (operand_1->AsSERecurrentNode()) + phi_operand = operand_1; + else if (operand_2->AsSERecurrentNode()) + phi_operand = operand_2; + + // If it is not in the form step + phi exit out. + if (!(step_node && phi_operand)) + return recurrent_node_map_[phi] = CreateCantComputeNode(); + + // If the phi operand is not the same phi node exit out. + if (phi_operand != phi_node.get()) + return recurrent_node_map_[phi] = CreateCantComputeNode(); + + if (!IsLoopInvariant(loop, step_node)) + return recurrent_node_map_[phi] = CreateCantComputeNode(); + + phi_node->AddCoefficient(step_node); + } + } + + // Once the node is fully built we update the map with the version from the + // cache (if it has already been added to the cache). + return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node)); +} + +SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode( + const Instruction* inst) { + std::unique_ptr load_node{ + new SEValueUnknown(this, inst->result_id())}; + return GetCachedOrAdd(std::move(load_node)); +} + +SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() { + return cached_cant_compute_; +} + +// Add the created node into the cache of nodes. If it already exists return it. 
+SENode* ScalarEvolutionAnalysis::GetCachedOrAdd( + std::unique_ptr prospective_node) { + auto itr = node_cache_.find(prospective_node); + if (itr != node_cache_.end()) { + return (*itr).get(); + } + + SENode* raw_ptr_to_node = prospective_node.get(); + node_cache_.insert(std::move(prospective_node)); + return raw_ptr_to_node; +} + +bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop, + const SENode* node) const { + for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) { + if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) { + const BasicBlock* header = rec->GetLoop()->GetHeaderBlock(); + + // If the loop which the recurrent expression belongs to is either |loop| + // or a nested loop inside |loop| then we assume it is variant. + if (loop->IsInsideLoop(header)) { + return false; + } + } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) { + // If the instruction is inside the loop we conservatively assume it is + // loop variant. + if (loop->IsInsideLoop(unknown->ResultId())) return false; + } + } + + return true; +} + +SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm( + SENode* node, const Loop* loop) { + // Traverse the DAG to find the recurrent expression belonging to |loop|. + for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { + SERecurrentNode* rec = itr->AsSERecurrentNode(); + if (rec && rec->GetLoop() == loop) { + return rec->GetCoefficient(); + } + } + return CreateConstant(0); +} + +SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent, + SENode* old_child, + SENode* new_child) { + // Only handles add.
+ if (parent->GetType() != SENode::Add) return parent; + + std::vector new_children; + for (SENode* child : *parent) { + if (child == old_child) { + new_children.push_back(new_child); + } else { + new_children.push_back(child); + } + } + + std::unique_ptr add_node{new SEAddNode(this)}; + for (SENode* child : new_children) { + add_node->AddChild(child); + } + + return SimplifyExpression(GetCachedOrAdd(std::move(add_node))); +} + +// Rebuild the |node| eliminating, if it exists, the recurrent term which +// belongs to the |loop|. +SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm( + SENode* node, const Loop* loop) { + // If the node is already a recurrent expression belonging to loop then just + // return the offset. + SERecurrentNode* recurrent = node->AsSERecurrentNode(); + if (recurrent) { + if (recurrent->GetLoop() == loop) { + return recurrent->GetOffset(); + } else { + return node; + } + } + + std::vector new_children; + // Otherwise find the recurrent node in the children of this node. + for (auto itr : *node) { + recurrent = itr->AsSERecurrentNode(); + if (recurrent && recurrent->GetLoop() == loop) { + new_children.push_back(recurrent->GetOffset()); + } else { + new_children.push_back(itr); + } + } + + std::unique_ptr add_node{new SEAddNode(this)}; + for (SENode* child : new_children) { + add_node->AddChild(child); + } + + return SimplifyExpression(GetCachedOrAdd(std::move(add_node))); +} + +// Return the recurrent term belonging to |loop| if it appears in the graph +// starting at |node| or null if it doesn't. 
+SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node, + const Loop* loop) { + for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { + SERecurrentNode* rec = itr->AsSERecurrentNode(); + if (rec && rec->GetLoop() == loop) { + return rec; + } + } + return nullptr; +} +std::string SENode::AsString() const { + switch (GetType()) { + case Constant: + return "Constant"; + case RecurrentAddExpr: + return "RecurrentAddExpr"; + case Add: + return "Add"; + case Negative: + return "Negative"; + case Multiply: + return "Multiply"; + case ValueUnknown: + return "Value Unknown"; + case CanNotCompute: + return "Can not compute"; + } + return "NULL"; +} + +bool SENode::operator==(const SENode& other) const { + if (GetType() != other.GetType()) return false; + + if (other.GetChildren().size() != children_.size()) return false; + + const SERecurrentNode* this_as_recurrent = AsSERecurrentNode(); + + // Check the children are the same, for SERecurrentNodes we need to check the + // offset and coefficient manually as the child vector is sorted by ids so the + // offset/coefficient information is lost. + if (!this_as_recurrent) { + for (size_t index = 0; index < children_.size(); ++index) { + if (other.GetChildren()[index] != children_[index]) return false; + } + } else { + const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode(); + + // We've already checked the types are the same, this should not fail if + // this->AsSERecurrentNode() succeeded. + assert(other_as_recurrent); + + if (this_as_recurrent->GetCoefficient() != + other_as_recurrent->GetCoefficient()) + return false; + + if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset()) + return false; + + if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop()) + return false; + } + + // If we're dealing with a value unknown node check both nodes were created by + // the same instruction. 
+ if (GetType() == SENode::ValueUnknown) { + if (AsSEValueUnknown()->ResultId() != + other.AsSEValueUnknown()->ResultId()) { + return false; + } + } + + if (AsSEConstantNode()) { + if (AsSEConstantNode()->FoldToSingleValue() != + other.AsSEConstantNode()->FoldToSingleValue()) + return false; + } + + return true; +} + +bool SENode::operator!=(const SENode& other) const { return !(*this == other); } + +namespace { +// Helper functions to insert 32/64 bit values into the 32 bit hash string. This +// allows us to add pointers to the string by reinterpreting the pointers as +// uintptr_t. PushToString will deduce the type, call sizeof on it and use +// that size to call into the correct PushToStringImpl functor depending on +// whether it is 32 or 64 bit. + +template +struct PushToStringImpl; + +template +struct PushToStringImpl { + void operator()(T id, std::u32string* str) { + str->push_back(static_cast(id >> 32)); + str->push_back(static_cast(id)); + } +}; + +template +struct PushToStringImpl { + void operator()(T id, std::u32string* str) { + str->push_back(static_cast(id)); + } +}; + +template +static void PushToString(T id, std::u32string* str) { + PushToStringImpl{}(id, str); +} + +} // namespace + +// Implements the hashing of SENodes. +size_t SENodeHash::operator()(const SENode* node) const { + // Concatenate the terms into a string which we can hash. + std::u32string hash_string{}; + + // Hashing the type as a string is safer than hashing the enum as the enum is + // very likely to collide with constants. + for (char ch : node->AsString()) { + hash_string.push_back(static_cast(ch)); + } + + // We just ignore the literal value unless it is a constant.
+ if (node->GetType() == SENode::Constant) + PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string); + + const SERecurrentNode* recurrent = node->AsSERecurrentNode(); + + // If we're dealing with a recurrent expression hash the loop as well so that + // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes. + if (recurrent) { + PushToString(reinterpret_cast(recurrent->GetLoop()), + &hash_string); + + // Recurrent expressions can't be hashed using the normal method as the + // order of coefficient and offset matters to the hash. + PushToString(reinterpret_cast(recurrent->GetCoefficient()), + &hash_string); + PushToString(reinterpret_cast(recurrent->GetOffset()), + &hash_string); + + return std::hash{}(hash_string); + } + + // Hash the result id of the original instruction which created this node if + // it is a value unknown node. + if (node->GetType() == SENode::ValueUnknown) { + PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string); + } + + // Hash the pointers of the child nodes, each SENode has a unique pointer + // associated with it. + const std::vector& children = node->GetChildren(); + for (const SENode* child : children) { + PushToString(reinterpret_cast(child), &hash_string); + } + + return std::hash{}(hash_string); +} + +// This overload is the actual overload used by the node_cache_ set. 
+size_t SENodeHash::operator()(const std::unique_ptr& node) const { + return this->operator()(node.get()); +} + +void SENode::DumpDot(std::ostream& out, bool recurse) const { + size_t unique_id = std::hash{}(this); + out << unique_id << " [label=\"" << AsString() << " "; + if (GetType() == SENode::Constant) { + out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue(); + } + out << "\"]\n"; + for (const SENode* child : children_) { + size_t child_unique_id = std::hash{}(child); + out << unique_id << " -> " << child_unique_id << " \n"; + if (recurse) child->DumpDot(out, true); + } +} + +namespace { +class IsGreaterThanZero { + public: + explicit IsGreaterThanZero(IRContext* context) : context_(context) {} + + // Determine if the value of |node| is always strictly greater than zero if + // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is + // true. It returns true if the evaluation was able to conclude something, in + // which case the result is stored in |result|. + // The algorithm works by going through all the nodes and determining the + // sign of each of them. + bool Eval(const SENode* node, bool or_equal_zero, bool* result) { + *result = false; + switch (Visit(node)) { + case Signedness::kPositiveOrNegative: { + return false; + } + case Signedness::kStrictlyNegative: { + *result = false; + break; + } + case Signedness::kNegative: { + if (!or_equal_zero) { + return false; + } + *result = false; + break; + } + case Signedness::kStrictlyPositive: { + *result = true; + break; + } + case Signedness::kPositive: { + if (!or_equal_zero) { + return false; + } + *result = true; + break; + } + } + return true; + } + + private: + enum class Signedness { + kPositiveOrNegative, // Yield a value positive or negative. + kStrictlyNegative, // Yield a value strictly less than 0. + kNegative, // Yield a value less or equal to 0. + kStrictlyPositive, // Yield a value strictly greater than 0. + kPositive // Yield a value greater or equal to 0.
+ }; + + // Combine the signedness according to arithmetic rules of a given operator. + using Combiner = std::function; + + // Returns a functor to interpret the signedness of 2 expressions as if they + // were added. + Combiner GetAddCombiner() const { + return [](Signedness lhs, Signedness rhs) { + switch (lhs) { + case Signedness::kPositiveOrNegative: + break; + case Signedness::kStrictlyNegative: + if (rhs == Signedness::kStrictlyNegative || + rhs == Signedness::kNegative) + return lhs; + break; + case Signedness::kNegative: { + if (rhs == Signedness::kStrictlyNegative) + return Signedness::kStrictlyNegative; + if (rhs == Signedness::kNegative) return Signedness::kNegative; + break; + } + case Signedness::kStrictlyPositive: { + if (rhs == Signedness::kStrictlyPositive || + rhs == Signedness::kPositive) { + return Signedness::kStrictlyPositive; + } + break; + } + case Signedness::kPositive: { + if (rhs == Signedness::kStrictlyPositive) + return Signedness::kStrictlyPositive; + if (rhs == Signedness::kPositive) return Signedness::kPositive; + break; + } + } + return Signedness::kPositiveOrNegative; + }; + } + + // Returns a functor to interpret the signedness of 2 expressions as if they + // were multiplied. 
+ Combiner GetMulCombiner() const { + return [](Signedness lhs, Signedness rhs) { + switch (lhs) { + case Signedness::kPositiveOrNegative: + break; + case Signedness::kStrictlyNegative: { + switch (rhs) { + case Signedness::kPositiveOrNegative: { + break; + } + case Signedness::kStrictlyNegative: { + return Signedness::kStrictlyPositive; + } + case Signedness::kNegative: { + return Signedness::kPositive; + } + case Signedness::kStrictlyPositive: { + return Signedness::kStrictlyNegative; + } + case Signedness::kPositive: { + return Signedness::kNegative; + } + } + break; + } + case Signedness::kNegative: { + switch (rhs) { + case Signedness::kPositiveOrNegative: { + break; + } + case Signedness::kStrictlyNegative: + case Signedness::kNegative: { + return Signedness::kPositive; + } + case Signedness::kStrictlyPositive: + case Signedness::kPositive: { + return Signedness::kNegative; + } + } + break; + } + case Signedness::kStrictlyPositive: { + return rhs; + } + case Signedness::kPositive: { + switch (rhs) { + case Signedness::kPositiveOrNegative: { + break; + } + case Signedness::kStrictlyNegative: + case Signedness::kNegative: { + return Signedness::kNegative; + } + case Signedness::kStrictlyPositive: + case Signedness::kPositive: { + return Signedness::kPositive; + } + } + break; + } + } + return Signedness::kPositiveOrNegative; + }; + } + + Signedness Visit(const SENode* node) { + switch (node->GetType()) { + case SENode::Constant: + return Visit(node->AsSEConstantNode()); + break; + case SENode::RecurrentAddExpr: + return Visit(node->AsSERecurrentNode()); + break; + case SENode::Negative: + return Visit(node->AsSENegative()); + break; + case SENode::CanNotCompute: + return Visit(node->AsSECantCompute()); + break; + case SENode::ValueUnknown: + return Visit(node->AsSEValueUnknown()); + break; + case SENode::Add: + return VisitExpr(node, GetAddCombiner()); + break; + case SENode::Multiply: + return VisitExpr(node, GetMulCombiner()); + break; + } + return 
Signedness::kPositiveOrNegative; + } + + // Returns the signedness of a constant |node|. + Signedness Visit(const SEConstantNode* node) { + if (0 == node->FoldToSingleValue()) return Signedness::kPositive; + if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive; + if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative; + return Signedness::kPositiveOrNegative; + } + + // Returns the signedness of an unknown |node| based on its integer type. + Signedness Visit(const SEValueUnknown* node) { + Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId()); + analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id()); + assert(type && "Can't retrieve a type for the instruction"); + analysis::Integer* int_type = type->AsInteger(); + assert(int_type && "Can't retrieve an integer type for the instruction"); + return int_type->IsSigned() ? Signedness::kPositiveOrNegative + : Signedness::kPositive; + } + + // Returns the signedness of a recurring expression. + Signedness Visit(const SERecurrentNode* node) { + Signedness coeff_sign = Visit(node->GetCoefficient()); + // SERecurrentNode represents an affine expression in the range [0, + // loop_bound], so the result cannot be strictly positive or negative. + switch (coeff_sign) { + default: + break; + case Signedness::kStrictlyNegative: + coeff_sign = Signedness::kNegative; + break; + case Signedness::kStrictlyPositive: + coeff_sign = Signedness::kPositive; + break; + } + return GetAddCombiner()(coeff_sign, Visit(node->GetOffset())); + } + + // Returns the signedness of a negation |node|.
+ Signedness Visit(const SENegative* node) { + switch (Visit(*node->begin())) { + case Signedness::kPositiveOrNegative: { + return Signedness::kPositiveOrNegative; + } + case Signedness::kStrictlyNegative: { + return Signedness::kStrictlyPositive; + } + case Signedness::kNegative: { + return Signedness::kPositive; + } + case Signedness::kStrictlyPositive: { + return Signedness::kStrictlyNegative; + } + case Signedness::kPositive: { + return Signedness::kNegative; + } + } + return Signedness::kPositiveOrNegative; + } + + Signedness Visit(const SECantCompute*) { + return Signedness::kPositiveOrNegative; + } + + // Returns the signedness of a binary expression by using the combiner + // |reduce|. + Signedness VisitExpr( + const SENode* node, + std::function reduce) { + Signedness result = Visit(*node->begin()); + for (const SENode* operand : make_range(++node->begin(), node->end())) { + if (result == Signedness::kPositiveOrNegative) { + return Signedness::kPositiveOrNegative; + } + result = reduce(result, Visit(operand)); + } + return result; + } + + IRContext* context_; +}; +} // namespace + +bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node, + bool* is_gt_zero) const { + return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero); +} + +bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero( + SENode* node, bool* is_ge_zero) const { + return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero); +} + +namespace { + +// Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z), +// if |node| is not in the chain, returns the original chain. 
+static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul, + const SENode* node) { + SENode* lhs = mul->GetChildren()[0]; + SENode* rhs = mul->GetChildren()[1]; + if (lhs == node) { + return rhs; + } + if (rhs == node) { + return lhs; + } + if (lhs->AsSEMultiplyNode()) { + SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node); + if (res != lhs) + return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); + } + if (rhs->AsSEMultiplyNode()) { + SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node); + if (res != rhs) + return mul->GetParentAnalysis()->CreateMultiplyNode(lhs, res); + } + + return mul; +} +} // namespace + +std::pair SExpression::operator/( + SExpression rhs_wrapper) const { + SENode* lhs = node_; + SENode* rhs = rhs_wrapper.node_; + // Check for division by 0. + if (rhs->AsSEConstantNode() && + !rhs->AsSEConstantNode()->FoldToSingleValue()) { + return {scev_->CreateCantComputeNode(), 0}; + } + + // Trivial case: both operands are constants, so fold the division. + if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) { + int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue(); + int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue(); + return {scev_->CreateConstant(lhs_value / rhs_value), + lhs_value % rhs_value}; + } + + // Look for a "c U / U" pattern. + if (lhs->AsSEMultiplyNode()) { + assert(lhs->GetChildren().size() == 2 && + "More than 2 operand for a multiply node."); + SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs); + if (res != lhs) { + return {res, 0}; + } + } + + return {scev_->CreateCantComputeNode(), 0}; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/scalar_analysis.h b/source/opt/scalar_analysis.h new file mode 100644 index 000000000..fb6d631f5 --- /dev/null +++ b/source/opt/scalar_analysis.h @@ -0,0 +1,314 @@ +// Copyright (c) 2018 Google LLC.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_SCALAR_ANALYSIS_H_ +#define SOURCE_OPT_SCALAR_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/instruction.h" +#include "source/opt/scalar_analysis_nodes.h" + +namespace spvtools { +namespace opt { + +class IRContext; +class Loop; + +// Manager for the Scalar Evolution analysis. Creates and maintains a DAG of +// scalar operations generated from analysing the use def graph from incoming +// instructions. Each node is hashed as it is added so like nodes (for instance, +// two induction variables i=0,i++ and j=0,j++) become the same node. After +// creating a DAG with AnalyzeInstruction it can then be simplified into a more +// usable form with SimplifyExpression. +class ScalarEvolutionAnalysis { + public: + explicit ScalarEvolutionAnalysis(IRContext* context); + + // Create a unary negative node on |operand|. + SENode* CreateNegation(SENode* operand); + + // Creates a subtraction between the two operands by adding |operand_1| to the + // negation of |operand_2|. + SENode* CreateSubtraction(SENode* operand_1, SENode* operand_2); + + // Create an addition node between two operands. NOTE(review): an earlier + // version of this comment referred to a |simplify| flag that is not part of + // the signature; confirm whether two constant operands fold to an SEConstant.
+ SENode* CreateAddNode(SENode* operand_1, SENode* operand_2); + + // Create a multiply node between two operands. + SENode* CreateMultiplyNode(SENode* operand_1, SENode* operand_2); + + // Create a node representing a constant integer. + SENode* CreateConstant(int64_t integer); + + // Create a value unknown node, such as a load. + SENode* CreateValueUnknownNode(const Instruction* inst); + + // Create a CantComputeNode. Used to exit out of analysis. + SENode* CreateCantComputeNode(); + + // Create a new recurrent node with |offset| and |coefficient|, with respect + // to |loop|. + SENode* CreateRecurrentExpression(const Loop* loop, SENode* offset, + SENode* coefficient); + + // Construct the DAG by traversing use def chain of |inst|. + SENode* AnalyzeInstruction(const Instruction* inst); + + // Simplify the |node| by grouping like terms or if contains a recurrent + // expression, rewrite the graph so the whole DAG (from |node| down) is in + // terms of that recurrent expression. + // + // For example. + // Induction variable i=0, i++ would produce Rec(0,1) so i+1 could be + // transformed into Rec(1,1). + // + // X+X*2+Y-Y+34-17 would be transformed into 3*X + 17, where X and Y are + // ValueUnknown nodes (such as a load instruction). + SENode* SimplifyExpression(SENode* node); + + // Add |prospective_node| into the cache and return a raw pointer to it. If + // |prospective_node| is already in the cache just return the raw pointer. + SENode* GetCachedOrAdd(std::unique_ptr prospective_node); + + // Checks that the graph starting from |node| is invariant to the |loop|. + bool IsLoopInvariant(const Loop* loop, const SENode* node) const; + + // Sets |is_gt_zero| to true if |node| represent a value always strictly + // greater than 0. The result of |is_gt_zero| is valid only if the function + // returns true. + bool IsAlwaysGreaterThanZero(SENode* node, bool* is_gt_zero) const; + + // Sets |is_ge_zero| to true if |node| represent a value greater or equals to + // 0. 
The result of |is_ge_zero| is valid only if the function returns true. + bool IsAlwaysGreaterOrEqualToZero(SENode* node, bool* is_ge_zero) const; + + // Find the recurrent term belonging to |loop| in the graph starting from + // |node| and return the coefficient of that recurrent term. Constant zero + // will be returned if no recurrent could be found. |node| should be in + // simplest form. + SENode* GetCoefficientFromRecurrentTerm(SENode* node, const Loop* loop); + + // Return a rebuilt graph starting from |node| with the recurrent expression + // belonging to |loop| being zeroed out. Returned node will be simplified. + SENode* BuildGraphWithoutRecurrentTerm(SENode* node, const Loop* loop); + + // Return the recurrent term belonging to |loop| if it appears in the graph + // starting at |node| or null if it doesn't. + SERecurrentNode* GetRecurrentTerm(SENode* node, const Loop* loop); + + SENode* UpdateChildNode(SENode* parent, SENode* child, SENode* new_child); + + // The loops in |loop_pair| will be considered the same when constructing + // SERecurrentNode objects. This enables analysing dependencies that will be + // created during loop fusion. + void AddLoopsToPretendAreTheSame( + const std::pair& loop_pair) { + pretend_equal_[std::get<1>(loop_pair)] = std::get<0>(loop_pair); + } + + private: + SENode* AnalyzeConstant(const Instruction* inst); + + // Handles both addition and subtraction. If the |instruction| is OpISub + // then the resulting node will be op1+(-op2) otherwise if it is OpIAdd then + // the result will be op1+op2. |instruction| must be OpIAdd or OpISub. + SENode* AnalyzeAddOp(const Instruction* instruction); + + SENode* AnalyzeMultiplyOp(const Instruction* multiply); + + SENode* AnalyzePhiInstruction(const Instruction* phi); + + IRContext* context_; + + // A map of instructions to SENodes. This is used to track recurrent + // expressions as they are added when analyzing instructions. 
Recurrent + // expressions come from phi nodes which by nature can include recursion so we + // check if nodes have already been built when analyzing instructions. + std::map recurrent_node_map_; + + // On creation we create and cache the CantCompute node so we do not need to + // perform a needless create step. + SENode* cached_cant_compute_; + + // Helper functor to allow two unique_ptr to nodes to be compared. Only + // needed + // for the unordered_set implementation. + struct NodePointersEquality { + bool operator()(const std::unique_ptr& lhs, + const std::unique_ptr& rhs) const { + return *lhs == *rhs; + } + }; + + // Cache of nodes. All pointers to the nodes are references to the memory + // managed by the set. + std::unordered_set, SENodeHash, NodePointersEquality> + node_cache_; + + // Loops that should be considered the same for performing analysis for loop + // fusion. + std::map pretend_equal_; +}; + +// Wrapping class to manipulate SENode pointer using + - * / operators. +class SExpression { + public: + // Implicit on purpose !
+ SExpression(SENode* node) + : node_(node->GetParentAnalysis()->SimplifyExpression(node)), + scev_(node->GetParentAnalysis()) {} + + inline operator SENode*() const { return node_; } + inline SENode* operator->() const { return node_; } + const SENode& operator*() const { return *node_; } + + inline ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() const { + return scev_; + } + + inline SExpression operator+(SENode* rhs) const; + template ::value, int>::type = 0> + inline SExpression operator+(T integer) const; + inline SExpression operator+(SExpression rhs) const; + + inline SExpression operator-() const; + inline SExpression operator-(SENode* rhs) const; + template ::value, int>::type = 0> + inline SExpression operator-(T integer) const; + inline SExpression operator-(SExpression rhs) const; + + inline SExpression operator*(SENode* rhs) const; + template ::value, int>::type = 0> + inline SExpression operator*(T integer) const; + inline SExpression operator*(SExpression rhs) const; + + template ::value, int>::type = 0> + inline std::pair operator/(T integer) const; + // Try to perform a division. Returns the pair (quotient, remainder). If it + // fails to simplify it, the function returns a CanNotCompute node.
+ std::pair operator/(SExpression rhs) const; + + private: + SENode* node_; + ScalarEvolutionAnalysis* scev_; +}; + +inline SExpression SExpression::operator+(SENode* rhs) const { + return scev_->CreateAddNode(node_, rhs); +} + +template ::value, int>::type> +inline SExpression SExpression::operator+(T integer) const { + return *this + scev_->CreateConstant(integer); +} + +inline SExpression SExpression::operator+(SExpression rhs) const { + return *this + rhs.node_; +} + +inline SExpression SExpression::operator-() const { + return scev_->CreateNegation(node_); +} + +inline SExpression SExpression::operator-(SENode* rhs) const { + return *this + scev_->CreateNegation(rhs); +} + +template ::value, int>::type> +inline SExpression SExpression::operator-(T integer) const { + return *this - scev_->CreateConstant(integer); +} + +inline SExpression SExpression::operator-(SExpression rhs) const { + return *this - rhs.node_; +} + +inline SExpression SExpression::operator*(SENode* rhs) const { + return scev_->CreateMultiplyNode(node_, rhs); +} + +template ::value, int>::type> +inline SExpression SExpression::operator*(T integer) const { + return *this * scev_->CreateConstant(integer); +} + +inline SExpression SExpression::operator*(SExpression rhs) const { + return *this * rhs.node_; +} + +template ::value, int>::type> +inline std::pair SExpression::operator/(T integer) const { + return *this / scev_->CreateConstant(integer); +} + +template ::value, int>::type> +inline SExpression operator+(T lhs, SExpression rhs) { + return rhs + lhs; +} +inline SExpression operator+(SENode* lhs, SExpression rhs) { return rhs + lhs; } + +template ::value, int>::type> +inline SExpression operator-(T lhs, SExpression rhs) { + // NOLINTNEXTLINE(whitespace/braces) + return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} - + rhs; +} +inline SExpression operator-(SENode* lhs, SExpression rhs) { + // NOLINTNEXTLINE(whitespace/braces) + return SExpression{lhs} - rhs; +} + 
+template ::value, int>::type> +inline SExpression operator*(T lhs, SExpression rhs) { + return rhs * lhs; +} +inline SExpression operator*(SENode* lhs, SExpression rhs) { return rhs * lhs; } + +template ::value, int>::type> +inline std::pair operator/(T lhs, SExpression rhs) { + // NOLINTNEXTLINE(whitespace/braces) + return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} / + rhs; +} +inline std::pair operator/(SENode* lhs, SExpression rhs) { + // NOLINTNEXTLINE(whitespace/braces) + return SExpression{lhs} / rhs; +} + +} // namespace opt +} // namespace spvtools +#endif // SOURCE_OPT_SCALAR_ANALYSIS_H_ diff --git a/source/opt/scalar_analysis_nodes.h b/source/opt/scalar_analysis_nodes.h new file mode 100644 index 000000000..450522ec3 --- /dev/null +++ b/source/opt/scalar_analysis_nodes.h @@ -0,0 +1,345 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_ +#define SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_ + +#include +#include +#include +#include + +#include "source/opt/tree_iterator.h" + +namespace spvtools { +namespace opt { + +class Loop; +class ScalarEvolutionAnalysis; +class SEConstantNode; +class SERecurrentNode; +class SEAddNode; +class SEMultiplyNode; +class SENegative; +class SEValueUnknown; +class SECantCompute; + +// Abstract class representing a node in the scalar evolution DAG.
Each node +// contains a vector of pointers to its children and each subclass of SENode +// implements GetType and an As method to allow casting. SENodes can be hashed +// using the SENodeHash functor. The vector of children is sorted when a node is +// added. This is important as it allows the hash of X+Y to be the same as Y+X. +class SENode { + public: + enum SENodeType { + Constant, + RecurrentAddExpr, + Add, + Multiply, + Negative, + ValueUnknown, + CanNotCompute + }; + + using ChildContainerType = std::vector; + + explicit SENode(ScalarEvolutionAnalysis* parent_analysis) + : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {} + + virtual SENodeType GetType() const = 0; + + virtual ~SENode() {} + + virtual inline void AddChild(SENode* child) { + // If this is a constant node, assert. + if (AsSEConstantNode()) { + assert(false && "Trying to add a child node to a constant!"); + } + + // Find the first point in the vector where |child| is greater than the node + // currently in the vector. + auto find_first_less_than = [child](const SENode* node) { + return child->unique_id_ <= node->unique_id_; + }; + + auto position = std::find_if_not(children_.begin(), children_.end(), + find_first_less_than); + // Children are sorted so the hashing and equality operator will be the same + // for a node with the same children. X+Y should be the same as Y+X. + children_.insert(position, child); + } + + // Get the type as an std::string. This is used to represent the node in the + // dot output and is used to hash the type as well. + std::string AsString() const; + + // Dump the SENode and its immediate children, if |recurse| is true then it + // will recurse through all children to print the DAG starting from this node + // as a root. + void DumpDot(std::ostream& out, bool recurse = false) const; + + // Checks if two nodes are the same by hashing them. + bool operator==(const SENode& other) const; + + // Checks if two nodes are not the same by comparing the hashes. 
+ bool operator!=(const SENode& other) const; + + // Return the child node at |index|. + inline SENode* GetChild(size_t index) { return children_[index]; } + inline const SENode* GetChild(size_t index) const { return children_[index]; } + + // Iterator to iterate over the child nodes. + using iterator = ChildContainerType::iterator; + using const_iterator = ChildContainerType::const_iterator; + + // Iterate over immediate child nodes. + iterator begin() { return children_.begin(); } + iterator end() { return children_.end(); } + + // Constant overloads for iterating over immediate child nodes. + const_iterator begin() const { return children_.cbegin(); } + const_iterator end() const { return children_.cend(); } + const_iterator cbegin() { return children_.cbegin(); } + const_iterator cend() { return children_.cend(); } + + // Collect all the recurrent nodes in this SENode + std::vector CollectRecurrentNodes() { + std::vector recurrent_nodes{}; + + if (auto recurrent_node = AsSERecurrentNode()) { + recurrent_nodes.push_back(recurrent_node); + } + + for (auto child : GetChildren()) { + auto child_recurrent_nodes = child->CollectRecurrentNodes(); + recurrent_nodes.insert(recurrent_nodes.end(), + child_recurrent_nodes.begin(), + child_recurrent_nodes.end()); + } + + return recurrent_nodes; + } + + // Collect all the value unknown nodes in this SENode + std::vector CollectValueUnknownNodes() { + std::vector value_unknown_nodes{}; + + if (auto value_unknown_node = AsSEValueUnknown()) { + value_unknown_nodes.push_back(value_unknown_node); + } + + for (auto child : GetChildren()) { + auto child_value_unknown_nodes = child->CollectValueUnknownNodes(); + value_unknown_nodes.insert(value_unknown_nodes.end(), + child_value_unknown_nodes.begin(), + child_value_unknown_nodes.end()); + } + + return value_unknown_nodes; + } + + // Iterator to iterate over the entire DAG. Even though we are using the tree + // iterator it should still be safe to iterate over. 
However, nodes with + // multiple parents will be visited multiple times, unlike in a tree. + using dag_iterator = TreeDFIterator; + using const_dag_iterator = TreeDFIterator; + + // Iterate over all child nodes in the graph. + dag_iterator graph_begin() { return dag_iterator(this); } + dag_iterator graph_end() { return dag_iterator(); } + const_dag_iterator graph_begin() const { return graph_cbegin(); } + const_dag_iterator graph_end() const { return graph_cend(); } + const_dag_iterator graph_cbegin() const { return const_dag_iterator(this); } + const_dag_iterator graph_cend() const { return const_dag_iterator(); } + + // Return the vector of immediate children. + const ChildContainerType& GetChildren() const { return children_; } + ChildContainerType& GetChildren() { return children_; } + + // Return true if this node is a cant compute node. + bool IsCantCompute() const { return GetType() == CanNotCompute; } + +// Implements a casting method for each type. +#define DeclareCastMethod(target) \ + virtual target* As##target() { return nullptr; } \ + virtual const target* As##target() const { return nullptr; } + DeclareCastMethod(SEConstantNode); + DeclareCastMethod(SERecurrentNode); + DeclareCastMethod(SEAddNode); + DeclareCastMethod(SEMultiplyNode); + DeclareCastMethod(SENegative); + DeclareCastMethod(SEValueUnknown); + DeclareCastMethod(SECantCompute); +#undef DeclareCastMethod + + // Get the analysis which has this node in its cache. + inline ScalarEvolutionAnalysis* GetParentAnalysis() const { + return parent_analysis_; + } + + protected: + ChildContainerType children_; + + ScalarEvolutionAnalysis* parent_analysis_; + + // The unique id of this node, assigned on creation by incrementing the static + // node count. + uint32_t unique_id_; + + // The number of nodes created. + static uint32_t NumberOfNodes; +}; + +// Function object to handle the hashing of SENodes. 
Hashing algorithm hashes +// the type (as a string), the literal value of any constants, and the child +// pointers which are assumed to be unique. +struct SENodeHash { + size_t operator()(const std::unique_ptr& node) const; + size_t operator()(const SENode* node) const; +}; + +// A node representing a constant integer. +class SEConstantNode : public SENode { + public: + SEConstantNode(ScalarEvolutionAnalysis* parent_analysis, int64_t value) + : SENode(parent_analysis), literal_value_(value) {} + + SENodeType GetType() const final { return Constant; } + + int64_t FoldToSingleValue() const { return literal_value_; } + + SEConstantNode* AsSEConstantNode() override { return this; } + const SEConstantNode* AsSEConstantNode() const override { return this; } + + inline void AddChild(SENode*) final { + assert(false && "Attempting to add a child to a constant node!"); + } + + protected: + int64_t literal_value_; +}; + +// A node representing a recurrent expression in the code. A recurrent +// expression is an expression whose value can be expressed as a linear +// expression of the loop iterations, such as an induction variable. The actual +// value of a recurrent expression is coefficient_ * iteration + offset_, hence +// an induction variable i=0, i++ becomes a recurrent expression with an offset +// of zero and a coefficient of one.
+class SERecurrentNode : public SENode { + public: + SERecurrentNode(ScalarEvolutionAnalysis* parent_analysis, const Loop* loop) + : SENode(parent_analysis), loop_(loop) {} + + SENodeType GetType() const final { return RecurrentAddExpr; } + + inline void AddCoefficient(SENode* child) { + coefficient_ = child; + SENode::AddChild(child); + } + + inline void AddOffset(SENode* child) { + offset_ = child; + SENode::AddChild(child); + } + + inline const SENode* GetCoefficient() const { return coefficient_; } + inline SENode* GetCoefficient() { return coefficient_; } + + inline const SENode* GetOffset() const { return offset_; } + inline SENode* GetOffset() { return offset_; } + + // Return the loop which this recurrent expression is recurring within. + const Loop* GetLoop() const { return loop_; } + + SERecurrentNode* AsSERecurrentNode() override { return this; } + const SERecurrentNode* AsSERecurrentNode() const override { return this; } + + private: + SENode* coefficient_ = nullptr;  // Null until AddCoefficient() is called. + SENode* offset_ = nullptr;  // Null until AddOffset() is called. + const Loop* loop_; +}; + +// A node representing an addition operation between child nodes. +class SEAddNode : public SENode { + public: + explicit SEAddNode(ScalarEvolutionAnalysis* parent_analysis) + : SENode(parent_analysis) {} + + SENodeType GetType() const final { return Add; } + + SEAddNode* AsSEAddNode() override { return this; } + const SEAddNode* AsSEAddNode() const override { return this; } +}; + +// A node representing a multiply operation between child nodes. +class SEMultiplyNode : public SENode { + public: + explicit SEMultiplyNode(ScalarEvolutionAnalysis* parent_analysis) + : SENode(parent_analysis) {} + + SENodeType GetType() const final { return Multiply; } + + SEMultiplyNode* AsSEMultiplyNode() override { return this; } + const SEMultiplyNode* AsSEMultiplyNode() const override { return this; } +}; + +// A node representing a unary negative operation.
+class SENegative : public SENode { + public: + explicit SENegative(ScalarEvolutionAnalysis* parent_analysis) + : SENode(parent_analysis) {} + + SENodeType GetType() const final { return Negative; } + + SENegative* AsSENegative() override { return this; } + const SENegative* AsSENegative() const override { return this; } +}; + +// A node representing a value which we do not know the value of, such as a load +// instruction. +class SEValueUnknown : public SENode { + public: + // SEValueUnknowns must come from an instruction; |result_id| is the result id + // of that instruction. This is so we can compare value unknowns and have a + // unique value unknown for each instruction. + SEValueUnknown(ScalarEvolutionAnalysis* parent_analysis, uint32_t result_id) + : SENode(parent_analysis), result_id_(result_id) {} + + SENodeType GetType() const final { return ValueUnknown; } + + SEValueUnknown* AsSEValueUnknown() override { return this; } + const SEValueUnknown* AsSEValueUnknown() const override { return this; } + + inline uint32_t ResultId() const { return result_id_; } + + private: + uint32_t result_id_; +}; + +// A node which we cannot reason about at all. +class SECantCompute : public SENode { + public: + explicit SECantCompute(ScalarEvolutionAnalysis* parent_analysis) + : SENode(parent_analysis) {} + + SENodeType GetType() const final { return CanNotCompute; } + + SECantCompute* AsSECantCompute() override { return this; } + const SECantCompute* AsSECantCompute() const override { return this; } +}; + +} // namespace opt +} // namespace spvtools +#endif // SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_ diff --git a/source/opt/scalar_analysis_simplification.cpp b/source/opt/scalar_analysis_simplification.cpp new file mode 100644 index 000000000..52f2d6ad9 --- /dev/null +++ b/source/opt/scalar_analysis_simplification.cpp @@ -0,0 +1,539 @@ +// Copyright (c) 2018 Google LLC.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/scalar_analysis.h" + +#include +#include +#include +#include +#include +#include +#include + +// Simplifies scalar analysis DAGs. +// +// 1. Given a node passed to SimplifyExpression we first simplify the graph by +// calling SimplifyPolynomial. This groups like nodes following basic arithmetic +// rules, so multiple adds of the same load instruction could be grouped into a +// single multiply of that instruction. SimplifyPolynomial will traverse the DAG +// and build up an accumulator buffer for each class of instruction it finds. +// For example take the loop: +// for (i=0, i accumulators_; +}; + +// From a |multiply| build up the accumulator objects. +bool SENodeSimplifyImpl::AccumulatorsFromMultiply(SENode* multiply, + bool negation) { + if (multiply->GetChildren().size() != 2 || + multiply->GetType() != SENode::Multiply) + return false; + + SENode* operand_1 = multiply->GetChild(0); + SENode* operand_2 = multiply->GetChild(1); + + SENode* value_unknown = nullptr; + SENode* constant = nullptr; + + // Work out which operand is the unknown value. + if (operand_1->GetType() == SENode::ValueUnknown || + operand_1->GetType() == SENode::RecurrentAddExpr) + value_unknown = operand_1; + else if (operand_2->GetType() == SENode::ValueUnknown || + operand_2->GetType() == SENode::RecurrentAddExpr) + value_unknown = operand_2; + + // Work out which operand is the constant coefficient. 
+ if (operand_1->GetType() == SENode::Constant) + constant = operand_1; + else if (operand_2->GetType() == SENode::Constant) + constant = operand_2; + + // If the expression is not a variable multiplied by a constant coefficient, + // exit out. + if (!(value_unknown && constant)) { + return false; + } + + int64_t sign = negation ? -1 : 1; + + auto iterator = accumulators_.find(value_unknown); + int64_t new_value = constant->AsSEConstantNode()->FoldToSingleValue() * sign; + // Add the result of the multiplication to the accumulators. + if (iterator != accumulators_.end()) { + (*iterator).second += new_value; + } else { + accumulators_.insert({value_unknown, new_value}); + } + + return true; +} + +SENode* SENodeSimplifyImpl::Simplify() { + // We only handle graphs with an addition, multiplication, or negation, at the + // root. + if (node_->GetType() != SENode::Add && node_->GetType() != SENode::Multiply && + node_->GetType() != SENode::Negative) + return node_; + + SENode* simplified_polynomial = SimplifyPolynomial(); + + SERecurrentNode* recurrent_expr = nullptr; + node_ = simplified_polynomial; + + // Fold recurrent expressions which are with respect to the same loop into a + // single recurrent expression. + simplified_polynomial = FoldRecurrentAddExpressions(simplified_polynomial); + + simplified_polynomial = + EliminateZeroCoefficientRecurrents(simplified_polynomial); + + // Traverse the immediate children of the new node to find the recurrent + // expression. If there is more than one there is nothing further we can do. + for (SENode* child : simplified_polynomial->GetChildren()) { + if (child->GetType() == SENode::RecurrentAddExpr) { + recurrent_expr = child->AsSERecurrentNode(); + } + } + + // We need to count the number of unique recurrent expressions in the DAG to + // ensure there is only one. 
+ for (auto child_iterator = simplified_polynomial->graph_begin(); + child_iterator != simplified_polynomial->graph_end(); ++child_iterator) { + if (child_iterator->GetType() == SENode::RecurrentAddExpr && + recurrent_expr != child_iterator->AsSERecurrentNode()) { + return simplified_polynomial; + } + } + + if (recurrent_expr) { + return SimplifyRecurrentAddExpression(recurrent_expr); + } + + return simplified_polynomial; +} + +// Traverse the graph to build up the accumulator objects. +void SENodeSimplifyImpl::GatherAccumulatorsFromChildNodes(SENode* new_node, + SENode* child, + bool negation) { + int32_t sign = negation ? -1 : 1; + + if (child->GetType() == SENode::Constant) { + // Collect all the constants and add them together. + constant_accumulator_ += + child->AsSEConstantNode()->FoldToSingleValue() * sign; + + } else if (child->GetType() == SENode::ValueUnknown || + child->GetType() == SENode::RecurrentAddExpr) { + // To rebuild the graph of X+X+X*2 into 4*X we count the occurrences of X + // and create a new node of count*X after. X can either be a ValueUnknown or + // a RecurrentAddExpr. The count for each X is stored in the accumulators_ + // map. + + auto iterator = accumulators_.find(child); + // If we've encountered this term before add to the accumulator for it. + if (iterator == accumulators_.end()) + accumulators_.insert({child, sign}); + else + iterator->second += sign; + + } else if (child->GetType() == SENode::Multiply) { + if (!AccumulatorsFromMultiply(child, negation)) { + new_node->AddChild(child); + } + + } else if (child->GetType() == SENode::Add) { + for (SENode* next_child : *child) { + GatherAccumulatorsFromChildNodes(new_node, next_child, negation); + } + + } else if (child->GetType() == SENode::Negative) { + SENode* negated_node = child->GetChild(0); + GatherAccumulatorsFromChildNodes(new_node, negated_node, !negation); + } else { + // If we can't work out how to fold the expression just add it back into + // the graph. 
+ new_node->AddChild(child); + } +} + +SERecurrentNode* SENodeSimplifyImpl::UpdateCoefficient( + SERecurrentNode* recurrent, int64_t coefficient_update) const { + std::unique_ptr new_recurrent_node{new SERecurrentNode( + recurrent->GetParentAnalysis(), recurrent->GetLoop())}; + + SENode* new_coefficient = analysis_.CreateMultiplyNode( + recurrent->GetCoefficient(), + analysis_.CreateConstant(coefficient_update)); + + // See if the node can be simplified. + SENode* simplified = analysis_.SimplifyExpression(new_coefficient); + if (simplified->GetType() != SENode::CanNotCompute) + new_coefficient = simplified; + + if (coefficient_update < 0) { + new_recurrent_node->AddOffset( + analysis_.CreateNegation(recurrent->GetOffset())); + } else { + new_recurrent_node->AddOffset(recurrent->GetOffset()); + } + + new_recurrent_node->AddCoefficient(new_coefficient); + + return analysis_.GetCachedOrAdd(std::move(new_recurrent_node)) + ->AsSERecurrentNode(); +} + +// Simplify all the terms in the polynomial function. +SENode* SENodeSimplifyImpl::SimplifyPolynomial() { + std::unique_ptr new_add{new SEAddNode(node_->GetParentAnalysis())}; + + // Traverse the graph and gather the accumulators from it. + GatherAccumulatorsFromChildNodes(new_add.get(), node_, false); + + // Fold all the constants into a single constant node. + if (constant_accumulator_ != 0) { + new_add->AddChild(analysis_.CreateConstant(constant_accumulator_)); + } + + for (auto& pair : accumulators_) { + SENode* term = pair.first; + int64_t count = pair.second; + + // We can eliminate the term completely. + if (count == 0) continue; + + if (count == 1) { + new_add->AddChild(term); + } else if (count == -1 && term->GetType() != SENode::RecurrentAddExpr) { + // If the count is -1 we can just add a negative version of that node, + // unless it is a recurrent expression as we would rather the negative + // goes on the recurrent expressions children. This makes it easier to + // work with in other places. 
+ new_add->AddChild(analysis_.CreateNegation(term)); + } else { + // Output value unknown terms as count*term and output recurrent + // expression terms as rec(offset, coefficient + count) offset and + // coefficient are the same as in the original expression. + if (term->GetType() == SENode::ValueUnknown) { + SENode* count_as_constant = analysis_.CreateConstant(count); + new_add->AddChild( + analysis_.CreateMultiplyNode(count_as_constant, term)); + } else { + assert(term->GetType() == SENode::RecurrentAddExpr && + "We only handle value unknowns or recurrent expressions"); + + // Create a new recurrent expression by adding the count to the + // coefficient of the old one. + new_add->AddChild(UpdateCoefficient(term->AsSERecurrentNode(), count)); + } + } + } + + // If there is only one term in the addition left just return that term. + if (new_add->GetChildren().size() == 1) { + return new_add->GetChild(0); + } + + // If there are no terms left in the addition just return 0. + if (new_add->GetChildren().size() == 0) { + return analysis_.CreateConstant(0); + } + + return analysis_.GetCachedOrAdd(std::move(new_add)); +} + +SENode* SENodeSimplifyImpl::FoldRecurrentAddExpressions(SENode* root) { + std::unique_ptr new_node{new SEAddNode(&analysis_)}; + + // A mapping of loops to the list of recurrent expressions which are with + // respect to those loops. 
+ std::map>> + loops_to_recurrent{}; + + bool has_multiple_same_loop_recurrent_terms = false; + + for (SENode* child : *root) { + bool negation = false; + + if (child->GetType() == SENode::Negative) { + child = child->GetChild(0); + negation = true; + } + + if (child->GetType() == SENode::RecurrentAddExpr) { + const Loop* loop = child->AsSERecurrentNode()->GetLoop(); + + SERecurrentNode* rec = child->AsSERecurrentNode(); + if (loops_to_recurrent.find(loop) == loops_to_recurrent.end()) { + loops_to_recurrent[loop] = {std::make_pair(rec, negation)}; + } else { + loops_to_recurrent[loop].push_back(std::make_pair(rec, negation)); + has_multiple_same_loop_recurrent_terms = true; + } + } else { + new_node->AddChild(child); + } + } + + if (!has_multiple_same_loop_recurrent_terms) return root; + + for (auto pair : loops_to_recurrent) { + std::vector>& recurrent_expressions = + pair.second; + const Loop* loop = pair.first; + + std::unique_ptr new_coefficient{new SEAddNode(&analysis_)}; + std::unique_ptr new_offset{new SEAddNode(&analysis_)}; + + for (auto node_pair : recurrent_expressions) { + SERecurrentNode* node = node_pair.first; + bool negative = node_pair.second; + + if (!negative) { + new_coefficient->AddChild(node->GetCoefficient()); + new_offset->AddChild(node->GetOffset()); + } else { + new_coefficient->AddChild( + analysis_.CreateNegation(node->GetCoefficient())); + new_offset->AddChild(analysis_.CreateNegation(node->GetOffset())); + } + } + + std::unique_ptr new_recurrent{ + new SERecurrentNode(&analysis_, loop)}; + + SENode* new_coefficient_simplified = + analysis_.SimplifyExpression(new_coefficient.get()); + + SENode* new_offset_simplified = + analysis_.SimplifyExpression(new_offset.get()); + + if (new_coefficient_simplified->GetType() == SENode::Constant && + new_coefficient_simplified->AsSEConstantNode()->FoldToSingleValue() == + 0) { + return new_offset_simplified; + } + + new_recurrent->AddCoefficient(new_coefficient_simplified); + 
new_recurrent->AddOffset(new_offset_simplified); + + new_node->AddChild(analysis_.GetCachedOrAdd(std::move(new_recurrent))); + } + + // If we only have one child in the add just return that. + if (new_node->GetChildren().size() == 1) { + return new_node->GetChild(0); + } + + return analysis_.GetCachedOrAdd(std::move(new_node)); +} + +SENode* SENodeSimplifyImpl::EliminateZeroCoefficientRecurrents(SENode* node) { + if (node->GetType() != SENode::Add) return node; + + bool has_change = false; + + std::vector new_children{}; + for (SENode* child : *node) { + if (child->GetType() == SENode::RecurrentAddExpr) { + SENode* coefficient = child->AsSERecurrentNode()->GetCoefficient(); + // If coefficient is zero then we can eliminate the recurrent expression + // entirely and just return the offset as the recurrent expression is + // representing the equation coefficient*iterations + offset. + if (coefficient->GetType() == SENode::Constant && + coefficient->AsSEConstantNode()->FoldToSingleValue() == 0) { + new_children.push_back(child->AsSERecurrentNode()->GetOffset()); + has_change = true; + } else { + new_children.push_back(child); + } + } else { + new_children.push_back(child); + } + } + + if (!has_change) return node; + + std::unique_ptr new_add{new SEAddNode(node_->GetParentAnalysis())}; + + for (SENode* child : new_children) { + new_add->AddChild(child); + } + + return analysis_.GetCachedOrAdd(std::move(new_add)); +} + +SENode* SENodeSimplifyImpl::SimplifyRecurrentAddExpression( + SERecurrentNode* recurrent_expr) { + const std::vector& children = node_->GetChildren(); + + std::unique_ptr recurrent_node{new SERecurrentNode( + recurrent_expr->GetParentAnalysis(), recurrent_expr->GetLoop())}; + + // Create and simplify the new offset node. 
+ std::unique_ptr new_offset{ + new SEAddNode(recurrent_expr->GetParentAnalysis())}; + new_offset->AddChild(recurrent_expr->GetOffset()); + + for (SENode* child : children) { + if (child->GetType() != SENode::RecurrentAddExpr) { + new_offset->AddChild(child); + } + } + + // Simplify the new offset. + SENode* simplified_child = analysis_.SimplifyExpression(new_offset.get()); + + // If the child can be simplified, add the simplified form; otherwise add the + // cached unsimplified sum. Either way the offset goes on the new node. + if (simplified_child->GetType() != SENode::CanNotCompute) { + recurrent_node->AddOffset(simplified_child); + } else { + recurrent_node->AddOffset(analysis_.GetCachedOrAdd(std::move(new_offset))); + } + + recurrent_node->AddCoefficient(recurrent_expr->GetCoefficient()); + + return analysis_.GetCachedOrAdd(std::move(recurrent_node)); +} + +/* + * Scalar Analysis simplification public methods. + */ + +SENode* ScalarEvolutionAnalysis::SimplifyExpression(SENode* node) { + SENodeSimplifyImpl impl{this, node}; + + return impl.Simplify(); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp new file mode 100644 index 000000000..2b2397d38 --- /dev/null +++ b/source/opt/scalar_replacement_pass.cpp @@ -0,0 +1,809 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "source/opt/scalar_replacement_pass.h" + +#include +#include +#include +#include + +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/opt/reflect.h" +#include "source/opt/types.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +Pass::Status ScalarReplacementPass::Process() { + Status status = Status::SuccessWithoutChange; + for (auto& f : *get_module()) { + Status functionStatus = ProcessFunction(&f); + if (functionStatus == Status::Failure) + return functionStatus; + else if (functionStatus == Status::SuccessWithChange) + status = functionStatus; + } + + return status; +} + +Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) { + std::queue worklist; + BasicBlock& entry = *function->begin(); + for (auto iter = entry.begin(); iter != entry.end(); ++iter) { + // Function storage class OpVariables must appear as the first instructions + // of the entry block. + if (iter->opcode() != SpvOpVariable) break; + + Instruction* varInst = &*iter; + if (CanReplaceVariable(varInst)) { + worklist.push(varInst); + } + } + + Status status = Status::SuccessWithoutChange; + while (!worklist.empty()) { + Instruction* varInst = worklist.front(); + worklist.pop(); + + if (!ReplaceVariable(varInst, &worklist)) + return Status::Failure; + else + status = Status::SuccessWithChange; + } + + return status; +} + +bool ScalarReplacementPass::ReplaceVariable( + Instruction* inst, std::queue* worklist) { + std::vector replacements; + CreateReplacementVariables(inst, &replacements); + + std::vector dead; + dead.push_back(inst); + if (!get_def_use_mgr()->WhileEachUser( + inst, [this, &replacements, &dead](Instruction* user) { + if (!IsAnnotationInst(user->opcode())) { + switch (user->opcode()) { + case SpvOpLoad: + ReplaceWholeLoad(user, replacements); + dead.push_back(user); + break; + case SpvOpStore: + ReplaceWholeStore(user, replacements); + dead.push_back(user); + break; + case 
SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + if (!ReplaceAccessChain(user, replacements)) return false; + dead.push_back(user); + break; + case SpvOpName: + case SpvOpMemberName: + break; + default: + assert(false && "Unexpected opcode"); + break; + } + } + return true; + })) + return false; + + // Clean up some dead code. + while (!dead.empty()) { + Instruction* toKill = dead.back(); + dead.pop_back(); + context()->KillInst(toKill); + } + + // Attempt to further scalarize. + for (auto var : replacements) { + if (var->opcode() == SpvOpVariable) { + if (get_def_use_mgr()->NumUsers(var) == 0) { + context()->KillInst(var); + } else if (CanReplaceVariable(var)) { + worklist->push(var); + } + } + } + + return true; +} + +void ScalarReplacementPass::ReplaceWholeLoad( + Instruction* load, const std::vector& replacements) { + // Replaces the load of the entire composite with a load from each replacement + // variable followed by a composite construction. + BasicBlock* block = context()->get_instr_block(load); + std::vector loads; + loads.reserve(replacements.size()); + BasicBlock::iterator where(load); + for (auto var : replacements) { + // Create a load of each replacement variable. + if (var->opcode() != SpvOpVariable) { + loads.push_back(var); + continue; + } + + Instruction* type = GetStorageType(var); + uint32_t loadId = TakeNextId(); + std::unique_ptr newLoad( + new Instruction(context(), SpvOpLoad, type->result_id(), loadId, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); + // Copy memory access attributes which start at index 1. Index 0 is the + // pointer to load. + for (uint32_t i = 1; i < load->NumInOperands(); ++i) { + Operand copy(load->GetInOperand(i)); + newLoad->AddOperand(std::move(copy)); + } + where = where.InsertBefore(std::move(newLoad)); + get_def_use_mgr()->AnalyzeInstDefUse(&*where); + context()->set_instr_block(&*where, block); + loads.push_back(&*where); + } + + // Construct a new composite. 
+ uint32_t compositeId = TakeNextId(); + where = load; + std::unique_ptr compositeConstruct(new Instruction( + context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {})); + for (auto l : loads) { + Operand op(SPV_OPERAND_TYPE_ID, + std::initializer_list{l->result_id()}); + compositeConstruct->AddOperand(std::move(op)); + } + where = where.InsertBefore(std::move(compositeConstruct)); + get_def_use_mgr()->AnalyzeInstDefUse(&*where); + context()->set_instr_block(&*where, block); + context()->ReplaceAllUsesWith(load->result_id(), compositeId); +} + +void ScalarReplacementPass::ReplaceWholeStore( + Instruction* store, const std::vector& replacements) { + // Replaces a store to the whole composite with a series of extract and stores + // to each element. + uint32_t storeInput = store->GetSingleWordInOperand(1u); + BasicBlock* block = context()->get_instr_block(store); + BasicBlock::iterator where(store); + uint32_t elementIndex = 0; + for (auto var : replacements) { + // Create the extract. + if (var->opcode() != SpvOpVariable) { + elementIndex++; + continue; + } + + Instruction* type = GetStorageType(var); + uint32_t extractId = TakeNextId(); + std::unique_ptr extract(new Instruction( + context(), SpvOpCompositeExtract, type->result_id(), extractId, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {storeInput}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}})); + auto iter = where.InsertBefore(std::move(extract)); + get_def_use_mgr()->AnalyzeInstDefUse(&*iter); + context()->set_instr_block(&*iter, block); + + // Create the store. + std::unique_ptr newStore( + new Instruction(context(), SpvOpStore, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}, + {SPV_OPERAND_TYPE_ID, {extractId}}})); + // Copy memory access attributes which start at index 2. Index 0 is the + // pointer and index 1 is the data. 
+ for (uint32_t i = 2; i < store->NumInOperands(); ++i) { + Operand copy(store->GetInOperand(i)); + newStore->AddOperand(std::move(copy)); + } + iter = where.InsertBefore(std::move(newStore)); + get_def_use_mgr()->AnalyzeInstDefUse(&*iter); + context()->set_instr_block(&*iter, block); + } +} + +bool ScalarReplacementPass::ReplaceAccessChain( + Instruction* chain, const std::vector& replacements) { + // Replaces the access chain with either another access chain (with one fewer + // indexes) or a direct use of the replacement variable. + uint32_t indexId = chain->GetSingleWordInOperand(1u); + const Instruction* index = get_def_use_mgr()->GetDef(indexId); + size_t indexValue = GetConstantInteger(index); + if (indexValue >= replacements.size()) { + // Out of bounds access (valid indices are 0..size-1), this is illegal IR. + return false; + } else { + const Instruction* var = replacements[indexValue]; + if (chain->NumInOperands() > 2) { + // Replace input access chain with another access chain. + BasicBlock::iterator chainIter(chain); + uint32_t replacementId = TakeNextId(); + std::unique_ptr replacementChain(new Instruction( + context(), chain->opcode(), chain->type_id(), replacementId, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); + // Add the remaining indexes. + for (uint32_t i = 2; i < chain->NumInOperands(); ++i) { + Operand copy(chain->GetInOperand(i)); + replacementChain->AddOperand(std::move(copy)); + } + auto iter = chainIter.InsertBefore(std::move(replacementChain)); + get_def_use_mgr()->AnalyzeInstDefUse(&*iter); + context()->set_instr_block(&*iter, context()->get_instr_block(chain)); + context()->ReplaceAllUsesWith(chain->result_id(), replacementId); + } else { + // Replace with a use of the variable.
+ context()->ReplaceAllUsesWith(chain->result_id(), var->result_id()); + } + } + + return true; +} + +void ScalarReplacementPass::CreateReplacementVariables( + Instruction* inst, std::vector* replacements) { + Instruction* type = GetStorageType(inst); + + std::unique_ptr> components_used = + GetUsedComponents(inst); + + uint32_t elem = 0; + switch (type->opcode()) { + case SpvOpTypeStruct: + type->ForEachInOperand( + [this, inst, &elem, replacements, &components_used](uint32_t* id) { + if (!components_used || components_used->count(elem)) { + CreateVariable(*id, inst, elem, replacements); + } else { + replacements->push_back(CreateNullConstant(*id)); + } + elem++; + }); + break; + case SpvOpTypeArray: + for (uint32_t i = 0; i != GetArrayLength(type); ++i) { + if (!components_used || components_used->count(i)) { + CreateVariable(type->GetSingleWordInOperand(0u), inst, i, + replacements); + } else { + replacements->push_back( + CreateNullConstant(type->GetSingleWordInOperand(0u))); + } + } + break; + + case SpvOpTypeMatrix: + case SpvOpTypeVector: + for (uint32_t i = 0; i != GetNumElements(type); ++i) { + CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); + } + break; + + default: + assert(false && "Unexpected type."); + break; + } + + TransferAnnotations(inst, replacements); +} + +void ScalarReplacementPass::TransferAnnotations( + const Instruction* source, std::vector* replacements) { + // Only transfer invariant and restrict decorations on the variable. There are + // no type or member decorations that are necessary to transfer. 
+ for (auto inst : + get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) { + assert(inst->opcode() == SpvOpDecorate); + uint32_t decoration = inst->GetSingleWordInOperand(1u); + if (decoration == SpvDecorationInvariant || + decoration == SpvDecorationRestrict) { + for (auto var : *replacements) { + std::unique_ptr annotation( + new Instruction(context(), SpvOpDecorate, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {var->result_id()}}, + {SPV_OPERAND_TYPE_DECORATION, {decoration}}})); + for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { + Operand copy(inst->GetInOperand(i)); + annotation->AddOperand(std::move(copy)); + } + context()->AddAnnotationInst(std::move(annotation)); + get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end()); + } + } + } +} + +void ScalarReplacementPass::CreateVariable( + uint32_t typeId, Instruction* varInst, uint32_t index, + std::vector* replacements) { + uint32_t ptrId = GetOrCreatePointerType(typeId); + uint32_t id = TakeNextId(); + std::unique_ptr variable(new Instruction( + context(), SpvOpVariable, ptrId, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); + + BasicBlock* block = context()->get_instr_block(varInst); + block->begin().InsertBefore(std::move(variable)); + Instruction* inst = &*block->begin(); + + // If varInst was initialized, make sure to initialize its replacement. 
+ GetOrCreateInitialValue(varInst, index, inst); + get_def_use_mgr()->AnalyzeInstDefUse(inst); + context()->set_instr_block(inst, block); + + replacements->push_back(inst); +} + +uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { + auto iter = pointee_to_pointer_.find(id); + if (iter != pointee_to_pointer_.end()) return iter->second; + + analysis::Type* pointeeTy; + std::unique_ptr pointerTy; + std::tie(pointeeTy, pointerTy) = + context()->get_type_mgr()->GetTypeAndPointerType(id, + SpvStorageClassFunction); + uint32_t ptrId = 0; + if (pointeeTy->IsUniqueType()) { + // Non-ambiguous type, just ask the type manager for an id. + ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get()); + pointee_to_pointer_[id] = ptrId; + return ptrId; + } + + // Ambiguous type. We must perform a linear search to try and find the right + // type. + for (auto global : context()->types_values()) { + if (global.opcode() == SpvOpTypePointer && + global.GetSingleWordInOperand(0u) == SpvStorageClassFunction && + global.GetSingleWordInOperand(1u) == id) { + if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) { + // Only reuse a decoration-less pointer of the correct type. + ptrId = global.result_id(); + break; + } + } + } + + if (ptrId != 0) { + pointee_to_pointer_[id] = ptrId; + return ptrId; + } + + ptrId = TakeNextId(); + context()->AddType(MakeUnique( + context(), SpvOpTypePointer, 0, ptrId, + std::initializer_list{ + {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, + {SPV_OPERAND_TYPE_ID, {id}}})); + Instruction* ptr = &*--context()->types_values_end(); + get_def_use_mgr()->AnalyzeInstDefUse(ptr); + pointee_to_pointer_[id] = ptrId; + // Register with the type manager if necessary. 
+ context()->get_type_mgr()->RegisterType(ptrId, *pointerTy); + + return ptrId; +} + +void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source, + uint32_t index, + Instruction* newVar) { + assert(source->opcode() == SpvOpVariable); + if (source->NumInOperands() < 2) return; + + uint32_t initId = source->GetSingleWordInOperand(1u); + uint32_t storageId = GetStorageType(newVar)->result_id(); + Instruction* init = get_def_use_mgr()->GetDef(initId); + uint32_t newInitId = 0; + // TODO(dnovillo): Refactor this with constant propagation. + if (init->opcode() == SpvOpConstantNull) { + // Initialize to appropriate NULL. + auto iter = type_to_null_.find(storageId); + if (iter == type_to_null_.end()) { + newInitId = TakeNextId(); + type_to_null_[storageId] = newInitId; + context()->AddGlobalValue( + MakeUnique(context(), SpvOpConstantNull, storageId, + newInitId, std::initializer_list{})); + Instruction* newNull = &*--context()->types_values_end(); + get_def_use_mgr()->AnalyzeInstDefUse(newNull); + } else { + newInitId = iter->second; + } + } else if (IsSpecConstantInst(init->opcode())) { + // Create a new constant extract. + newInitId = TakeNextId(); + context()->AddGlobalValue(MakeUnique( + context(), SpvOpSpecConstantOp, storageId, newInitId, + std::initializer_list{ + {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}}, + {SPV_OPERAND_TYPE_ID, {init->result_id()}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}})); + Instruction* newSpecConst = &*--context()->types_values_end(); + get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst); + } else if (init->opcode() == SpvOpConstantComposite) { + // Get the appropriate index constant. + newInitId = init->GetSingleWordInOperand(index); + Instruction* element = get_def_use_mgr()->GetDef(newInitId); + if (element->opcode() == SpvOpUndef) { + // Undef is not a valid initializer for a variable. 
+ newInitId = 0; + } + } else { + assert(false); + } + + if (newInitId != 0) { + newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}}); + } +} + +size_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const { + assert(op.words.size() <= 2); + size_t len = 0; + for (uint32_t i = 0; i != op.words.size(); ++i) { + len |= (op.words[i] << (32 * i)); + } + return len; +} + +size_t ScalarReplacementPass::GetConstantInteger( + const Instruction* constant) const { + assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() == + SpvOpTypeInt); + assert(constant->opcode() == SpvOpConstant || + constant->opcode() == SpvOpConstantNull); + if (constant->opcode() == SpvOpConstantNull) { + return 0; + } + + const Operand& op = constant->GetInOperand(0u); + return GetIntegerLiteral(op); +} + +size_t ScalarReplacementPass::GetArrayLength( + const Instruction* arrayType) const { + assert(arrayType->opcode() == SpvOpTypeArray); + const Instruction* length = + get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); + return GetConstantInteger(length); +} + +size_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { + assert(type->opcode() == SpvOpTypeVector || + type->opcode() == SpvOpTypeMatrix); + const Operand& op = type->GetInOperand(1u); + assert(op.words.size() <= 2); + size_t len = 0; + for (uint32_t i = 0; i != op.words.size(); ++i) { + len |= (op.words[i] << (32 * i)); + } + return len; +} + +bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const { + const Instruction* inst = get_def_use_mgr()->GetDef(id); + assert(inst); + return spvOpcodeIsSpecConstant(inst->opcode()); +} + +Instruction* ScalarReplacementPass::GetStorageType( + const Instruction* inst) const { + assert(inst->opcode() == SpvOpVariable); + + uint32_t ptrTypeId = inst->type_id(); + uint32_t typeId = + get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u); + return get_def_use_mgr()->GetDef(typeId); +} + +bool 
ScalarReplacementPass::CanReplaceVariable( + const Instruction* varInst) const { + assert(varInst->opcode() == SpvOpVariable); + + // Can only replace function scope variables. + if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) + return false; + + if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) + return false; + + const Instruction* typeInst = GetStorageType(varInst); + return CheckType(typeInst) && CheckAnnotations(varInst) && CheckUses(varInst); +} + +bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const { + if (!CheckTypeAnnotations(typeInst)) return false; + + switch (typeInst->opcode()) { + case SpvOpTypeStruct: + // Don't bother with empty structs or very large structs. + if (typeInst->NumInOperands() == 0 || + IsLargerThanSizeLimit(typeInst->NumInOperands())) + return false; + return true; + case SpvOpTypeArray: + if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) { + return false; + } + if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) { + return false; + } + return true; + // TODO(alanbaker): Develop some heuristics for when this should be + // re-enabled. + //// Specifically including matrix and vector in an attempt to reduce the + //// number of vector registers required. 
+ // case SpvOpTypeMatrix: + // case SpvOpTypeVector: + // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false; + // return true; + + case SpvOpTypeRuntimeArray: + default: + return false; + } +} + +bool ScalarReplacementPass::CheckTypeAnnotations( + const Instruction* typeInst) const { + for (auto inst : + get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { + uint32_t decoration; + if (inst->opcode() == SpvOpDecorate) { + decoration = inst->GetSingleWordInOperand(1u); + } else { + assert(inst->opcode() == SpvOpMemberDecorate); + decoration = inst->GetSingleWordInOperand(2u); + } + + switch (decoration) { + case SpvDecorationRowMajor: + case SpvDecorationColMajor: + case SpvDecorationArrayStride: + case SpvDecorationMatrixStride: + case SpvDecorationCPacked: + case SpvDecorationInvariant: + case SpvDecorationRestrict: + case SpvDecorationOffset: + case SpvDecorationAlignment: + case SpvDecorationAlignmentId: + case SpvDecorationMaxByteOffset: + break; + default: + return false; + } + } + + return true; +} + +bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const { + for (auto inst : + get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) { + assert(inst->opcode() == SpvOpDecorate); + uint32_t decoration = inst->GetSingleWordInOperand(1u); + switch (decoration) { + case SpvDecorationInvariant: + case SpvDecorationRestrict: + case SpvDecorationAlignment: + case SpvDecorationAlignmentId: + case SpvDecorationMaxByteOffset: + break; + default: + return false; + } + } + + return true; +} + +bool ScalarReplacementPass::CheckUses(const Instruction* inst) const { + VariableStats stats = {0, 0}; + bool ok = CheckUses(inst, &stats); + + // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when + // SRoA is costly, such as when the structure has many (unaccessed?) + // members. 
+ + return ok; +} + +bool ScalarReplacementPass::CheckUses(const Instruction* inst, + VariableStats* stats) const { + bool ok = true; + get_def_use_mgr()->ForEachUse( + inst, [this, stats, &ok](const Instruction* user, uint32_t index) { + // Annotations are check as a group separately. + if (!IsAnnotationInst(user->opcode())) { + switch (user->opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + if (index == 2u && user->NumInOperands() > 1) { + uint32_t id = user->GetSingleWordInOperand(1u); + const Instruction* opInst = get_def_use_mgr()->GetDef(id); + if (!IsCompileTimeConstantInst(opInst->opcode())) { + ok = false; + } else { + if (!CheckUsesRelaxed(user)) ok = false; + } + stats->num_partial_accesses++; + } else { + ok = false; + } + break; + case SpvOpLoad: + if (!CheckLoad(user, index)) ok = false; + stats->num_full_accesses++; + break; + case SpvOpStore: + if (!CheckStore(user, index)) ok = false; + stats->num_full_accesses++; + break; + case SpvOpName: + case SpvOpMemberName: + break; + default: + ok = false; + break; + } + } + }); + + return ok; +} + +bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const { + bool ok = true; + get_def_use_mgr()->ForEachUse( + inst, [this, &ok](const Instruction* user, uint32_t index) { + switch (user->opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + if (index != 2u) { + ok = false; + } else { + if (!CheckUsesRelaxed(user)) ok = false; + } + break; + case SpvOpLoad: + if (!CheckLoad(user, index)) ok = false; + break; + case SpvOpStore: + if (!CheckStore(user, index)) ok = false; + break; + default: + ok = false; + break; + } + }); + + return ok; +} + +bool ScalarReplacementPass::CheckLoad(const Instruction* inst, + uint32_t index) const { + if (index != 2u) return false; + if (inst->NumInOperands() >= 2 && + inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask) + return false; + return true; +} + +bool ScalarReplacementPass::CheckStore(const 
Instruction* inst, + uint32_t index) const { + if (index != 0u) return false; + if (inst->NumInOperands() >= 3 && + inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask) + return false; + return true; +} +bool ScalarReplacementPass::IsLargerThanSizeLimit(size_t length) const { + if (max_num_elements_ == 0) { + return false; + } + return length > max_num_elements_; +} + +std::unique_ptr> +ScalarReplacementPass::GetUsedComponents(Instruction* inst) { + std::unique_ptr> result( + new std::unordered_set()); + + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr, + this](Instruction* use) { + switch (use->opcode()) { + case SpvOpLoad: { + // Look for extract from the load. + std::vector t; + if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) { + if (use2->opcode() != SpvOpCompositeExtract) { + return false; + } + t.push_back(use2->GetSingleWordInOperand(1)); + return true; + })) { + result->insert(t.begin(), t.end()); + return true; + } else { + result.reset(nullptr); + return false; + } + } + case SpvOpName: + case SpvOpMemberName: + case SpvOpStore: + // No components are used. + return true; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: { + // Add the first index it if is a constant. + // TODO: Could be improved by checking if the address is used in a load. + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + uint32_t index_id = use->GetSingleWordInOperand(1); + const analysis::Constant* index_const = + const_mgr->FindDeclaredConstant(index_id); + if (index_const) { + const analysis::Integer* index_type = + index_const->type()->AsInteger(); + assert(index_type); + if (index_type->width() == 32) { + result->insert(index_const->GetU32()); + return true; + } else if (index_type->width() == 64) { + result->insert(index_const->GetU64()); + return true; + } + result.reset(nullptr); + return false; + } else { + // Could be any element. 
Assuming all are used. + result.reset(nullptr); + return false; + } + } + default: + // We do not know what is happening. Have to assume the worst. + result.reset(nullptr); + return false; + } + }); + + return result; +} + +Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + + const analysis::Type* type = type_mgr->GetType(type_id); + const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); + Instruction* null_inst = + const_mgr->GetDefiningInstruction(null_const, type_id); + context()->UpdateDefUse(null_inst); + return null_inst; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h new file mode 100644 index 000000000..f5d761278 --- /dev/null +++ b/source/opt/scalar_replacement_pass.h @@ -0,0 +1,238 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_SCALAR_REPLACEMENT_PASS_H_ +#define SOURCE_OPT_SCALAR_REPLACEMENT_PASS_H_ + +#include +#include +#include +#include +#include +#include + +#include "source/opt/function.h" +#include "source/opt/pass.h" +#include "source/opt/type_manager.h" + +namespace spvtools { +namespace opt { + +// Documented in optimizer.hpp +class ScalarReplacementPass : public Pass { + private: + static const uint32_t kDefaultLimit = 100; + + public: + ScalarReplacementPass(uint32_t limit = kDefaultLimit) + : max_num_elements_(limit) { + name_[0] = '\0'; + strcat(name_, "scalar-replacement="); + sprintf(&name_[strlen(name_)], "%d", max_num_elements_); + } + + const char* name() const override { return name_; } + + // Attempts to scalarize all appropriate function scope variables. Returns + // SuccessWithChange if any change is made. + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisNameMap | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Small container for tracking statistics about variables. + // + // TODO(alanbaker): Develop some useful heuristics to tune this pass. + struct VariableStats { + uint32_t num_partial_accesses; + uint32_t num_full_accesses; + }; + + // Attempts to scalarize all appropriate function scope variables in + // |function|. Returns SuccessWithChange if any changes are mode. + Status ProcessFunction(Function* function); + + // Returns true if |varInst| can be scalarized. + // + // Examines the use chain of |varInst| to verify all uses are valid for + // scalarization. + bool CanReplaceVariable(const Instruction* varInst) const; + + // Returns true if |typeInst| is an acceptable type to scalarize. + // + // Allows all aggregate types except runtime arrays. 
Additionally, checks the + // that the number of elements that would be scalarized is within bounds. + bool CheckType(const Instruction* typeInst) const; + + // Returns true if all the decorations for |varInst| are acceptable for + // scalarization. + bool CheckAnnotations(const Instruction* varInst) const; + + // Returns true if all the decorations for |typeInst| are acceptable for + // scalarization. + bool CheckTypeAnnotations(const Instruction* typeInst) const; + + // Returns true if the uses of |inst| are acceptable for scalarization. + // + // Recursively checks all the uses of |inst|. For |inst| specifically, only + // allows SpvOpAccessChain, SpvOpInBoundsAccessChain, SpvOpLoad and + // SpvOpStore. Access chains must have the first index be a compile-time + // constant. Subsequent uses of access chains (including other access chains) + // are checked in a more relaxed manner. + bool CheckUses(const Instruction* inst) const; + + // Helper function for the above |CheckUses|. + // + // This version tracks some stats about the current OpVariable. These stats + // are used to drive heuristics about when to scalarize. + bool CheckUses(const Instruction* inst, VariableStats* stats) const; + + // Relaxed helper function for |CheckUses|. + bool CheckUsesRelaxed(const Instruction* inst) const; + + // Transfers appropriate decorations from |source| to |replacements|. + void TransferAnnotations(const Instruction* source, + std::vector* replacements); + + // Scalarizes |inst| and updates its uses. + // + // |inst| must be an OpVariable. It is replaced with an OpVariable for each + // for element of the composite type. Uses of |inst| are updated as + // appropriate. If the replacement variables are themselves scalarizable, they + // get added to |worklist| for further processing. If any replacement + // variable ends up with no uses it is erased. Returns false if any + // subsequent access chain is out of bounds. 
+ bool ReplaceVariable(Instruction* inst, std::queue* worklist); + + // Returns the underlying storage type for |inst|. + // + // |inst| must be an OpVariable. Returns the type that is pointed to by + // |inst|. + Instruction* GetStorageType(const Instruction* inst) const; + + // Returns true if the load can be scalarized. + // + // |inst| must be an OpLoad. Returns true if |index| is the pointer operand of + // |inst| and the load is not from volatile memory. + bool CheckLoad(const Instruction* inst, uint32_t index) const; + + // Returns true if the store can be scalarized. + // + // |inst| must be an OpStore. Returns true if |index| is the pointer operand + // of |inst| and the store is not to volatile memory. + bool CheckStore(const Instruction* inst, uint32_t index) const; + + // Creates a variable of type |typeId| from the |index|'th element of + // |varInst|. The new variable is added to |replacements|. + void CreateVariable(uint32_t typeId, Instruction* varInst, uint32_t index, + std::vector* replacements); + + // Populates |replacements| with a new OpVariable for each element of |inst|. + // + // |inst| must be an OpVariable of a composite type. New variables are + // initialized the same as the corresponding index in |inst|. |replacements| + // will contain a variable for each element of the composite with matching + // indexes (i.e. the 0'th element of |inst| is the 0'th entry of + // |replacements|). + void CreateReplacementVariables(Instruction* inst, + std::vector* replacements); + + // Returns the value of an OpConstant of integer type. + // + // |constant| must use two or fewer words to generate the value. + size_t GetConstantInteger(const Instruction* constant) const; + + // Returns the integer literal for |op|. + size_t GetIntegerLiteral(const Operand& op) const; + + // Returns the array length for |arrayInst|. + size_t GetArrayLength(const Instruction* arrayInst) const; + + // Returns the number of elements in |type|. 
+ // + // |type| must be a vector or matrix type. + size_t GetNumElements(const Instruction* type) const; + + // Returns true if |id| is a specialization constant. + // + // |id| must be registered definition. + bool IsSpecConstant(uint32_t id) const; + + // Returns an id for a pointer to |id|. + uint32_t GetOrCreatePointerType(uint32_t id); + + // Creates the initial value for the |index| element of |source| in |newVar|. + // + // If there is an initial value for |source| for element |index|, it is + // appended as an operand on |newVar|. If the initial value is OpUndef, no + // initial value is added to |newVar|. + void GetOrCreateInitialValue(Instruction* source, uint32_t index, + Instruction* newVar); + + // Replaces the load to the entire composite. + // + // Generates a load for each replacement variable and then creates a new + // composite by combining all of the loads. + // + // |load| must be a load. + void ReplaceWholeLoad(Instruction* load, + const std::vector& replacements); + + // Replaces the store to the entire composite. + // + // Generates a composite extract and store for each element in the scalarized + // variable from the original store data input. + void ReplaceWholeStore(Instruction* store, + const std::vector& replacements); + + // Replaces an access chain to the composite variable with either a direct use + // of the appropriate replacement variable or another access chain with the + // replacement variable as the base and one fewer indexes. Returns false if + // the chain has an out of bounds access. + bool ReplaceAccessChain(Instruction* chain, + const std::vector& replacements); + + // Returns a set containing the which components of the result of |inst| are + // potentially used. If the return value is |nullptr|, then every components + // is possibly used. + std::unique_ptr> GetUsedComponents( + Instruction* inst); + + // Returns an instruction defining a null constant with type |type_id|. If + // one already exists, it is returned. 
Otherwise a new one is created. + Instruction* CreateNullConstant(uint32_t type_id); + + // Maps storage type to a pointer type enclosing that type. + std::unordered_map pointee_to_pointer_; + + // Maps type id to OpConstantNull for that type. + std::unordered_map type_to_null_; + + // Limit on the number of members in an object that will be replaced. + // 0 means there is no limit. + uint32_t max_num_elements_; + bool IsLargerThanSizeLimit(size_t length) const; + char name_[55]; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_SCALAR_REPLACEMENT_PASS_H_ diff --git a/source/opt/set_spec_constant_default_value_pass.cpp b/source/opt/set_spec_constant_default_value_pass.cpp new file mode 100644 index 000000000..4c8d116f7 --- /dev/null +++ b/source/opt/set_spec_constant_default_value_pass.cpp @@ -0,0 +1,368 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/set_spec_constant_default_value_pass.h" + +#include +#include +#include +#include +#include + +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/type_manager.h" +#include "source/opt/types.h" +#include "source/util/make_unique.h" +#include "source/util/parse_number.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace opt { + +namespace { +using utils::EncodeNumberStatus; +using utils::NumberType; +using utils::ParseAndEncodeNumber; +using utils::ParseNumber; + +// Given a numeric value in a null-terminated c string and the expected type of +// the value, parses the string and encodes it in a vector of words. If the +// value is a scalar integer or floating point value, encodes the value in +// SPIR-V encoding format. If the value is 'false' or 'true', returns a vector +// with single word with value 0 or 1 respectively. Returns the vector +// containing the encoded value on success. Otherwise returns an empty vector. +std::vector ParseDefaultValueStr(const char* text, + const analysis::Type* type) { + std::vector result; + if (!strcmp(text, "true") && type->AsBool()) { + result.push_back(1u); + } else if (!strcmp(text, "false") && type->AsBool()) { + result.push_back(0u); + } else { + NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT}; + if (const auto* IT = type->AsInteger()) { + number_type.bitwidth = IT->width(); + number_type.kind = + IT->IsSigned() ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; + } else if (const auto* FT = type->AsFloat()) { + number_type.bitwidth = FT->width(); + number_type.kind = SPV_NUMBER_FLOATING; + } else { + // Does not handle types other then boolean, integer or float. Returns + // empty vector. + result.clear(); + return result; + } + EncodeNumberStatus rc = ParseAndEncodeNumber( + text, number_type, [&result](uint32_t word) { result.push_back(word); }, + nullptr); + // Clear the result vector on failure. 
+ if (rc != EncodeNumberStatus::kSuccess) { + result.clear(); + } + } + return result; +} + +// Given a bit pattern and a type, checks if the bit pattern is compatible +// with the type. If so, returns the bit pattern, otherwise returns an empty +// bit pattern. If the given bit pattern is empty, returns an empty bit +// pattern. If the given type represents a SPIR-V Boolean type, the bit pattern +// to be returned is determined with the following standard: +// If any words in the input bit pattern are non zero, returns a bit pattern +// with 0x1, which represents a 'true'. +// If all words in the bit pattern are zero, returns a bit pattern with 0x0, +// which represents a 'false'. +std::vector ParseDefaultValueBitPattern( + const std::vector& input_bit_pattern, + const analysis::Type* type) { + std::vector result; + if (type->AsBool()) { + if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(), + [](uint32_t i) { return i != 0; })) { + result.push_back(1u); + } else { + result.push_back(0u); + } + return result; + } else if (const auto* IT = type->AsInteger()) { + if (IT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) { + return std::vector(input_bit_pattern); + } + } else if (const auto* FT = type->AsFloat()) { + if (FT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) { + return std::vector(input_bit_pattern); + } + } + result.clear(); + return result; +} + +// Returns true if the given instruction's result id could have a SpecId +// decoration. +bool CanHaveSpecIdDecoration(const Instruction& inst) { + switch (inst.opcode()) { + case SpvOp::SpvOpSpecConstant: + case SpvOp::SpvOpSpecConstantFalse: + case SpvOp::SpvOpSpecConstantTrue: + return true; + default: + return false; + } +} + +// Given a decoration group defining instruction that is decorated with SpecId +// decoration, finds the spec constant defining instruction which is the real +// target of the SpecId decoration. 
Returns the spec constant defining +// instruction if such an instruction is found, otherwise returns a nullptr. +Instruction* GetSpecIdTargetFromDecorationGroup( + const Instruction& decoration_group_defining_inst, + analysis::DefUseManager* def_use_mgr) { + // Find the OpGroupDecorate instruction which consumes the given decoration + // group. Note that the given decoration group has SpecId decoration, which + // is unique for different spec constants. So the decoration group cannot be + // consumed by different OpGroupDecorate instructions. Therefore we only need + // the first OpGroupDecoration instruction that uses the given decoration + // group. + Instruction* group_decorate_inst = nullptr; + if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst, + [&group_decorate_inst](Instruction* user) { + if (user->opcode() == + SpvOp::SpvOpGroupDecorate) { + group_decorate_inst = user; + return false; + } + return true; + })) + return nullptr; + + // Scan through the target ids of the OpGroupDecorate instruction. There + // should be only one spec constant target consumes the SpecId decoration. + // If multiple target ids are presented in the OpGroupDecorate instruction, + // they must be the same one that defined by an eligible spec constant + // instruction. If the OpGroupDecorate instruction has different target ids + // or a target id is not defined by an eligible spec cosntant instruction, + // returns a nullptr. + Instruction* target_inst = nullptr; + for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) { + // All the operands of a OpGroupDecorate instruction should be of type + // SPV_OPERAND_TYPE_ID. + uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i); + Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id); + + if (!candidate_inst) { + continue; + } + + if (!target_inst) { + // If the spec constant target has not been found yet, check if the + // candidate instruction is the target. 
+ if (CanHaveSpecIdDecoration(*candidate_inst)) { + target_inst = candidate_inst; + } else { + // Spec id decoration should not be applied on other instructions. + // TODO(qining): Emit an error message in the invalid case once the + // error handling is done. + return nullptr; + } + } else { + // If the spec constant target has been found, check if the candidate + // instruction is the same one as the target. The module is invalid if + // the candidate instruction is different with the found target. + // TODO(qining): Emit an error messaage in the invalid case once the + // error handling is done. + if (candidate_inst != target_inst) return nullptr; + } + } + return target_inst; +} +} // namespace + +Pass::Status SetSpecConstantDefaultValuePass::Process() { + // The operand index of decoration target in an OpDecorate instruction. + const uint32_t kTargetIdOperandIndex = 0; + // The operand index of the decoration literal in an OpDecorate instruction. + const uint32_t kDecorationOperandIndex = 1; + // The operand index of Spec id literal value in an OpDecorate SpecId + // instruction. + const uint32_t kSpecIdLiteralOperandIndex = 2; + // The number of operands in an OpDecorate SpecId instruction. + const uint32_t kOpDecorateSpecIdNumOperands = 3; + // The in-operand index of the default value in a OpSpecConstant instruction. + const uint32_t kOpSpecConstantLiteralInOperandIndex = 0; + + bool modified = false; + // Scan through all the annotation instructions to find 'OpDecorate SpecId' + // instructions. Then extract the decoration target of those instructions. + // The decoration targets should be spec constant defining instructions with + // opcode: OpSpecConstant{|True|False}. The spec id of those spec constants + // will be used to look up their new default values in the mapping from + // spec id to new default value strings. Once a new default value string + // is found for a spec id, the string will be parsed according to the target + // spec constant type. 
The parsed value will be used to replace the original + // default value of the target spec constant. + for (Instruction& inst : context()->annotations()) { + // Only process 'OpDecorate SpecId' instructions + if (inst.opcode() != SpvOp::SpvOpDecorate) continue; + if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue; + if (inst.GetSingleWordInOperand(kDecorationOperandIndex) != + uint32_t(SpvDecoration::SpvDecorationSpecId)) { + continue; + } + + // 'inst' is an OpDecorate SpecId instruction. + uint32_t spec_id = inst.GetSingleWordOperand(kSpecIdLiteralOperandIndex); + uint32_t target_id = inst.GetSingleWordOperand(kTargetIdOperandIndex); + + // Find the spec constant defining instruction. Note that the + // target_id might be a decoration group id. + Instruction* spec_inst = nullptr; + if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) { + if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) { + spec_inst = + GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr()); + } else { + spec_inst = target_inst; + } + } else { + continue; + } + if (!spec_inst) continue; + + // Get the default value bit pattern for this spec id. + std::vector bit_pattern; + + if (spec_id_to_value_str_.size() != 0) { + // Search for the new string-form default value for this spec id. + auto iter = spec_id_to_value_str_.find(spec_id); + if (iter == spec_id_to_value_str_.end()) { + continue; + } + + // Gets the string of the default value and parses it to bit pattern + // with the type of the spec constant. + const std::string& default_value_str = iter->second; + bit_pattern = ParseDefaultValueStr( + default_value_str.c_str(), + context()->get_type_mgr()->GetType(spec_inst->type_id())); + + } else { + // Search for the new bit-pattern-form default value for this spec id. 
+ auto iter = spec_id_to_value_bit_pattern_.find(spec_id); + if (iter == spec_id_to_value_bit_pattern_.end()) { + continue; + } + + // Gets the bit-pattern of the default value from the map directly. + bit_pattern = ParseDefaultValueBitPattern( + iter->second, + context()->get_type_mgr()->GetType(spec_inst->type_id())); + } + + if (bit_pattern.empty()) continue; + + // Update the operand bit patterns of the spec constant defining + // instruction. + switch (spec_inst->opcode()) { + case SpvOp::SpvOpSpecConstant: + // If the new value is the same with the original value, no + // need to do anything. Otherwise update the operand words. + if (spec_inst->GetInOperand(kOpSpecConstantLiteralInOperandIndex) + .words != bit_pattern) { + spec_inst->SetInOperand(kOpSpecConstantLiteralInOperandIndex, + std::move(bit_pattern)); + modified = true; + } + break; + case SpvOp::SpvOpSpecConstantTrue: + // If the new value is also 'true', no need to change anything. + // Otherwise, set the opcode to OpSpecConstantFalse; + if (!static_cast(bit_pattern.front())) { + spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantFalse); + modified = true; + } + break; + case SpvOp::SpvOpSpecConstantFalse: + // If the new value is also 'false', no need to change anything. + // Otherwise, set the opcode to OpSpecConstantTrue; + if (static_cast(bit_pattern.front())) { + spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantTrue); + modified = true; + } + break; + default: + break; + } + // No need to update the DefUse manager, as this pass does not change any + // ids. + } + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +// Returns true if the given char is ':', '\0' or considered as blank space +// (i.e.: '\n', '\r', '\v', '\t', '\f' and ' '). 
+bool IsSeparator(char ch) { + return std::strchr(":\0", ch) || std::isspace(ch) != 0; +} + +std::unique_ptr +SetSpecConstantDefaultValuePass::ParseDefaultValuesString(const char* str) { + if (!str) return nullptr; + + auto spec_id_to_value = MakeUnique(); + + // The parsing loop, break when points to the end. + while (*str) { + // Find the spec id. + while (std::isspace(*str)) str++; // skip leading spaces. + const char* entry_begin = str; + while (!IsSeparator(*str)) str++; + const char* entry_end = str; + std::string spec_id_str(entry_begin, entry_end - entry_begin); + uint32_t spec_id = 0; + if (!ParseNumber(spec_id_str.c_str(), &spec_id)) { + // The spec id is not a valid uint32 number. + return nullptr; + } + auto iter = spec_id_to_value->find(spec_id); + if (iter != spec_id_to_value->end()) { + // Same spec id has been defined before + return nullptr; + } + // Find the ':', spaces between the spec id and the ':' are not allowed. + if (*str++ != ':') { + // ':' not found + return nullptr; + } + // Find the value string + const char* val_begin = str; + while (!IsSeparator(*str)) str++; + const char* val_end = str; + if (val_end == val_begin) { + // Value string is empty. + return nullptr; + } + // Update the mapping with spec id and value string. + (*spec_id_to_value)[spec_id] = std::string(val_begin, val_end - val_begin); + + // Skip trailing spaces. + while (std::isspace(*str)) str++; + } + + return spec_id_to_value; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/set_spec_constant_default_value_pass.h b/source/opt/set_spec_constant_default_value_pass.h new file mode 100644 index 000000000..8bd1787cc --- /dev/null +++ b/source/opt/set_spec_constant_default_value_pass.h @@ -0,0 +1,114 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ +#define SOURCE_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class SetSpecConstantDefaultValuePass : public Pass { + public: + using SpecIdToValueStrMap = std::unordered_map; + using SpecIdToValueBitPatternMap = + std::unordered_map>; + using SpecIdToInstMap = std::unordered_map; + + // Constructs a pass instance with a map from spec ids to default values + // in the form of string. + explicit SetSpecConstantDefaultValuePass( + const SpecIdToValueStrMap& default_values) + : spec_id_to_value_str_(default_values), + spec_id_to_value_bit_pattern_() {} + explicit SetSpecConstantDefaultValuePass(SpecIdToValueStrMap&& default_values) + : spec_id_to_value_str_(std::move(default_values)), + spec_id_to_value_bit_pattern_() {} + + // Constructs a pass instance with a map from spec ids to default values in + // the form of bit pattern. 
+ explicit SetSpecConstantDefaultValuePass( + const SpecIdToValueBitPatternMap& default_values) + : spec_id_to_value_str_(), + spec_id_to_value_bit_pattern_(default_values) {} + explicit SetSpecConstantDefaultValuePass( + SpecIdToValueBitPatternMap&& default_values) + : spec_id_to_value_str_(), + spec_id_to_value_bit_pattern_(std::move(default_values)) {} + + const char* name() const override { return "set-spec-const-default-value"; } + Status Process() override; + + // Parses the given null-terminated C string to get a mapping from Spec Id to + // default value strings. Returns a unique pointer of the mapping from spec + // ids to spec constant default value strings built from the given |str| on + // success. Returns a nullptr if the given string is not valid for building + // the mapping. + // A valid string for building the mapping should follow the rule below: + // + // ": : ..." + // Example: + // "200:0x11 201:3.14 202:1.4728" + // + // Entries are separated with blank spaces (i.e.:' ', '\n', '\r', '\t', + // '\f', '\v'). Each entry corresponds to a Spec Id and default value pair. + // Multiple spaces between, before or after entries are allowed. However, + // spaces are not allowed within spec id or the default value string because + // spaces are always considered as delimiter to separate entries. + // + // In each entry, the spec id and value string is separated by ':'. Missing + // ':' in any entry is invalid. And it is invalid to have blank spaces in + // between the spec id and ':' or the default value and ':'. + // + // : specifies the spec id value. + // The text must represent a valid uint32_t number. + // Hex format with '0x' prefix is allowed. + // Empty is not allowed. + // One spec id value can only be defined once, multiple default values + // defined for the same spec id is not allowed. Spec ids with same value + // but different formats (e.g. 0x100 and 256) are considered the same. + // + // : the default value string. 
+ // Spaces before and after default value text is allowed. + // Spaces within the text is not allowed. + // Empty is not allowed. + static std::unique_ptr ParseDefaultValuesString( + const char* str); + + private: + // The mappings from spec ids to default values. Two maps are defined here, + // each to be used for one specific form of the default values. Only one of + // them will be populated in practice. + + // The mapping from spec ids to their string-form default values to be set. + const SpecIdToValueStrMap spec_id_to_value_str_; + // The mapping from spec ids to their bitpattern-form default values to be + // set. + const SpecIdToValueBitPatternMap spec_id_to_value_bit_pattern_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_SET_SPEC_CONSTANT_DEFAULT_VALUE_PASS_H_ diff --git a/source/opt/simplification_pass.cpp b/source/opt/simplification_pass.cpp new file mode 100644 index 000000000..5fbafbdd1 --- /dev/null +++ b/source/opt/simplification_pass.cpp @@ -0,0 +1,120 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/simplification_pass.h" + +#include +#include +#include + +#include "source/opt/fold.h" + +namespace spvtools { +namespace opt { + +Pass::Status SimplificationPass::Process() { + bool modified = false; + + for (Function& function : *get_module()) { + modified |= SimplifyFunction(&function); + } + return (modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool SimplificationPass::SimplifyFunction(Function* function) { + bool modified = false; + // Phase 1: Traverse all instructions in dominance order. + // The second phase will only be on the instructions whose inputs have changed + // after being processed during phase 1. Since OpPhi instructions are the + // only instructions whose inputs do not necessarily dominate the use, we keep + // track of the OpPhi instructions already seen, and add them to the work list + // for phase 2 when needed. + std::vector work_list; + std::unordered_set process_phis; + std::unordered_set inst_to_kill; + std::unordered_set in_work_list; + const InstructionFolder& folder = context()->get_instruction_folder(); + + cfg()->ForEachBlockInReversePostOrder( + function->entry().get(), + [&modified, &process_phis, &work_list, &in_work_list, &inst_to_kill, + folder, this](BasicBlock* bb) { + for (Instruction* inst = &*bb->begin(); inst; inst = inst->NextNode()) { + if (inst->opcode() == SpvOpPhi) { + process_phis.insert(inst); + } + + if (inst->opcode() == SpvOpCopyObject || + folder.FoldInstruction(inst)) { + modified = true; + context()->AnalyzeUses(inst); + get_def_use_mgr()->ForEachUser(inst, [&work_list, &process_phis, + &in_work_list]( + Instruction* use) { + if (process_phis.count(use) && in_work_list.insert(use).second) { + work_list.push_back(use); + } + }); + if (inst->opcode() == SpvOpCopyObject) { + context()->ReplaceAllUsesWith(inst->result_id(), + inst->GetSingleWordInOperand(0)); + inst_to_kill.insert(inst); + in_work_list.insert(inst); + } else if (inst->opcode() == SpvOpNop) { + inst_to_kill.insert(inst); + in_work_list.insert(inst); + } + } + } + }); + + // Phase 2: process the instructions in the work list until all of the work is + // done. This time we add all users to the work list because phase 1 + // has already finished. 
+ for (size_t i = 0; i < work_list.size(); ++i) { + Instruction* inst = work_list[i]; + in_work_list.erase(inst); + if (inst->opcode() == SpvOpCopyObject || folder.FoldInstruction(inst)) { + modified = true; + context()->AnalyzeUses(inst); + get_def_use_mgr()->ForEachUser( + inst, [&work_list, &in_work_list](Instruction* use) { + if (!use->IsDecoration() && use->opcode() != SpvOpName && + in_work_list.insert(use).second) { + work_list.push_back(use); + } + }); + + if (inst->opcode() == SpvOpCopyObject) { + context()->ReplaceAllUsesWith(inst->result_id(), + inst->GetSingleWordInOperand(0)); + inst_to_kill.insert(inst); + in_work_list.insert(inst); + } else if (inst->opcode() == SpvOpNop) { + inst_to_kill.insert(inst); + in_work_list.insert(inst); + } + } + } + + // Phase 3: Kill instructions we know are no longer needed. + for (Instruction* inst : inst_to_kill) { + context()->KillInst(inst); + } + + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/simplification_pass.h b/source/opt/simplification_pass.h new file mode 100644 index 000000000..bcb88bfcd --- /dev/null +++ b/source/opt/simplification_pass.h @@ -0,0 +1,50 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_SIMPLIFICATION_PASS_H_ +#define SOURCE_OPT_SIMPLIFICATION_PASS_H_ + +#include "source/opt/function.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class SimplificationPass : public Pass { + public: + const char* name() const override { return "simplify-instructions"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | + IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisNameMap | IRContext::kAnalysisConstants | + IRContext::kAnalysisTypes; + } + + private: + // Returns true if the module was changed. The simplifier is called on every + // instruction in |function| until nothing else in the function can be + // simplified. + bool SimplifyFunction(Function* function); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_SIMPLIFICATION_PASS_H_ diff --git a/source/opt/ssa_rewrite_pass.cpp b/source/opt/ssa_rewrite_pass.cpp new file mode 100644 index 000000000..f2cb2da28 --- /dev/null +++ b/source/opt/ssa_rewrite_pass.cpp @@ -0,0 +1,595 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// This file implements the SSA rewriting algorithm proposed in +// +// Simple and Efficient Construction of Static Single Assignment Form. +// Braun M., Buchwald S., Hack S., Leißa R., Mallon C., Zwinkau A. (2013) +// In: Jhala R., De Bosschere K. (eds) +// Compiler Construction. CC 2013. +// Lecture Notes in Computer Science, vol 7791. +// Springer, Berlin, Heidelberg +// +// https://link.springer.com/chapter/10.1007/978-3-642-37051-9_6 +// +// In contrast to common eager algorithms based on dominance and dominance +// frontier information, this algorithm works backwards from load operations. +// +// When a target variable is loaded, it queries the variable's reaching +// definition. If the reaching definition is unknown at the current location, +// it searches backwards in the CFG, inserting Phi instructions at join points +// in the CFG along the way until it finds the desired store instruction. +// +// The algorithm avoids repeated lookups using memoization. +// +// For reducible CFGs, which are a superset of the structured CFGs in SPIRV, +// this algorithm is proven to produce minimal SSA. That is, it inserts the +// minimal number of Phi instructions required to ensure the SSA property, but +// some Phi instructions may be dead +// (https://en.wikipedia.org/wiki/Static_single_assignment_form). + +#include "source/opt/ssa_rewrite_pass.h" + +#include +#include + +#include "source/opcode.h" +#include "source/opt/cfg.h" +#include "source/opt/mem_pass.h" +#include "source/util/make_unique.h" + +// Debug logging (0: Off, 1-N: Verbosity level). 
Replace this with the +// implementation done for +// https://github.com/KhronosGroup/SPIRV-Tools/issues/1351 +// #define SSA_REWRITE_DEBUGGING_LEVEL 3 + +#ifdef SSA_REWRITE_DEBUGGING_LEVEL +#include +#else +#define SSA_REWRITE_DEBUGGING_LEVEL 0 +#endif + +namespace spvtools { +namespace opt { + +namespace { +const uint32_t kStoreValIdInIdx = 1; +const uint32_t kVariableInitIdInIdx = 1; +} // namespace + +std::string SSARewriter::PhiCandidate::PrettyPrint(const CFG* cfg) const { + std::ostringstream str; + str << "%" << result_id_ << " = Phi[%" << var_id_ << ", BB %" << bb_->id() + << "]("; + if (phi_args_.size() > 0) { + uint32_t arg_ix = 0; + for (uint32_t pred_label : cfg->preds(bb_->id())) { + uint32_t arg_id = phi_args_[arg_ix++]; + str << "[%" << arg_id << ", bb(%" << pred_label << ")] "; + } + } + str << ")"; + if (copy_of_ != 0) { + str << " [COPY OF " << copy_of_ << "]"; + } + str << ((is_complete_) ? " [COMPLETE]" : " [INCOMPLETE]"); + + return str.str(); +} + +SSARewriter::PhiCandidate& SSARewriter::CreatePhiCandidate(uint32_t var_id, + BasicBlock* bb) { + // TODO(1841): Handle id overflow. + uint32_t phi_result_id = pass_->context()->TakeNextId(); + auto result = phi_candidates_.emplace( + phi_result_id, PhiCandidate(var_id, phi_result_id, bb)); + PhiCandidate& phi_candidate = result.first->second; + return phi_candidate; +} + +void SSARewriter::ReplacePhiUsersWith(const PhiCandidate& phi_to_remove, + uint32_t repl_id) { + for (uint32_t user_id : phi_to_remove.users()) { + PhiCandidate* user_phi = GetPhiCandidate(user_id); + if (user_phi) { + // If the user is a Phi candidate, replace all arguments that refer to + // |phi_to_remove.result_id()| with |repl_id|. + for (uint32_t& arg : user_phi->phi_args()) { + if (arg == phi_to_remove.result_id()) { + arg = repl_id; + } + } + } else { + // For regular loads, traverse the |load_replacement_| table looking for + // instances of |phi_to_remove|. 
+ for (auto& it : load_replacement_) { + if (it.second == phi_to_remove.result_id()) { + it.second = repl_id; + } + } + } + } +} + +uint32_t SSARewriter::TryRemoveTrivialPhi(PhiCandidate* phi_candidate) { + uint32_t same_id = 0; + for (uint32_t arg_id : phi_candidate->phi_args()) { + if (arg_id == same_id || arg_id == phi_candidate->result_id()) { + // This is a self-reference operand or a reference to the same value ID. + continue; + } + if (same_id != 0) { + // This Phi candidate merges at least two values. Therefore, it is not + // trivial. + assert(phi_candidate->copy_of() == 0 && + "Phi candidate transitioning from copy to non-copy."); + return phi_candidate->result_id(); + } + same_id = arg_id; + } + + // The previous logic has determined that this Phi candidate |phi_candidate| + // is trivial. It is essentially the copy operation phi_candidate->phi_result + // = Phi(same, same, same, ...). Since it is not necessary, we can re-route + // all the users of |phi_candidate->phi_result| to all its users, and remove + // |phi_candidate|. + + // Mark the Phi candidate as a trivial copy of |same_id|, so it won't be + // generated. + phi_candidate->MarkCopyOf(same_id); + + assert(same_id != 0 && "Completed Phis cannot have %0 in their arguments"); + + // Since |phi_candidate| always produces |same_id|, replace all the users of + // |phi_candidate| with |same_id|. + ReplacePhiUsersWith(*phi_candidate, same_id); + + return same_id; +} + +uint32_t SSARewriter::AddPhiOperands(PhiCandidate* phi_candidate) { + assert(phi_candidate->phi_args().size() == 0 && + "Phi candidate already has arguments"); + + bool found_0_arg = false; + for (uint32_t pred : pass_->cfg()->preds(phi_candidate->bb()->id())) { + BasicBlock* pred_bb = pass_->cfg()->block(pred); + + // If |pred_bb| is not sealed, use %0 to indicate that + // |phi_candidate| needs to be completed after the whole CFG has + // been processed. 
+ // + // Note that we cannot call GetReachingDef() in these cases + // because this would generate an empty Phi candidate in + // |pred_bb|. When |pred_bb| is later processed, a new definition + // for |phi_candidate->var_id_| will be lost because + // |phi_candidate| will still be reached by the empty Phi. + // + // Consider: + // + // BB %23: + // %38 = Phi[%i](%int_0[%1], %39[%25]) + // + // ... + // + // BB %25: [Starts unsealed] + // %39 = Phi[%i]() + // %34 = ... + // OpStore %i %34 -> Currdef(%i) at %25 is %34 + // OpBranch %23 + // + // When we first create the Phi in %38, we add an operandless Phi in + // %39 to hold the unknown reaching def for %i. + // + // But then, when we go to complete %39 at the end. The reaching def + // for %i in %25's predecessor is %38 itself. So we miss the fact + // that %25 has a def for %i that should be used. + // + // By making the argument %0, we make |phi_candidate| incomplete, + // which will cause it to be completed after the whole CFG has + // been scanned. + uint32_t arg_id = IsBlockSealed(pred_bb) + ? GetReachingDef(phi_candidate->var_id(), pred_bb) + : 0; + phi_candidate->phi_args().push_back(arg_id); + + if (arg_id == 0) { + found_0_arg = true; + } else { + // If this argument is another Phi candidate, add |phi_candidate| to the + // list of users for the defining Phi. + PhiCandidate* defining_phi = GetPhiCandidate(arg_id); + if (defining_phi && defining_phi != phi_candidate) { + defining_phi->AddUser(phi_candidate->result_id()); + } + } + } + + // If we could not fill-in all the arguments of this Phi, mark it incomplete + // so it gets completed after the whole CFG has been processed. + if (found_0_arg) { + phi_candidate->MarkIncomplete(); + incomplete_phis_.push(phi_candidate); + return phi_candidate->result_id(); + } + + // Try to remove |phi_candidate|, if it's trivial. 
+ uint32_t repl_id = TryRemoveTrivialPhi(phi_candidate); + if (repl_id == phi_candidate->result_id()) { + // |phi_candidate| is complete and not trivial. Add it to the + // list of Phi candidates to generate. + phi_candidate->MarkComplete(); + phis_to_generate_.push_back(phi_candidate); + } + + return repl_id; +} + +uint32_t SSARewriter::GetReachingDef(uint32_t var_id, BasicBlock* bb) { + // If |var_id| has a definition in |bb|, return it. + const auto& bb_it = defs_at_block_.find(bb); + if (bb_it != defs_at_block_.end()) { + const auto& current_defs = bb_it->second; + const auto& var_it = current_defs.find(var_id); + if (var_it != current_defs.end()) { + return var_it->second; + } + } + + // Otherwise, look up the value for |var_id| in |bb|'s predecessors. + uint32_t val_id = 0; + auto& predecessors = pass_->cfg()->preds(bb->id()); + if (predecessors.size() == 1) { + // If |bb| has exactly one predecessor, we look for |var_id|'s definition + // there. + val_id = GetReachingDef(var_id, pass_->cfg()->block(predecessors[0])); + } else if (predecessors.size() > 1) { + // If there is more than one predecessor, this is a join block which may + // require a Phi instruction. This will act as |var_id|'s current + // definition to break potential cycles. + PhiCandidate& phi_candidate = CreatePhiCandidate(var_id, bb); + WriteVariable(var_id, bb, phi_candidate.result_id()); + val_id = AddPhiOperands(&phi_candidate); + } + + // If we could not find a store for this variable in the path from the root + // of the CFG, the variable is not defined, so we use undef. 
+ if (val_id == 0) { + val_id = pass_->GetUndefVal(var_id); + } + + WriteVariable(var_id, bb, val_id); + + return val_id; +} + +void SSARewriter::SealBlock(BasicBlock* bb) { + auto result = sealed_blocks_.insert(bb); + (void)result; + assert(result.second == true && + "Tried to seal the same basic block more than once."); +} + +void SSARewriter::ProcessStore(Instruction* inst, BasicBlock* bb) { + auto opcode = inst->opcode(); + assert((opcode == SpvOpStore || opcode == SpvOpVariable) && + "Expecting a store or a variable definition instruction."); + + uint32_t var_id = 0; + uint32_t val_id = 0; + if (opcode == SpvOpStore) { + (void)pass_->GetPtr(inst, &var_id); + val_id = inst->GetSingleWordInOperand(kStoreValIdInIdx); + } else if (inst->NumInOperands() >= 2) { + var_id = inst->result_id(); + val_id = inst->GetSingleWordInOperand(kVariableInitIdInIdx); + } + if (pass_->IsTargetVar(var_id)) { + WriteVariable(var_id, bb, val_id); + +#if SSA_REWRITE_DEBUGGING_LEVEL > 1 + std::cerr << "\tFound store '%" << var_id << " = %" << val_id << "': " + << inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) + << "\n"; +#endif + } +} + +void SSARewriter::ProcessLoad(Instruction* inst, BasicBlock* bb) { + uint32_t var_id = 0; + (void)pass_->GetPtr(inst, &var_id); + if (pass_->IsTargetVar(var_id)) { + // Get the immediate reaching definition for |var_id|. + uint32_t val_id = GetReachingDef(var_id, bb); + + // Schedule a replacement for the result of this load instruction with + // |val_id|. After all the rewriting decisions are made, every use of + // this load will be replaced with |val_id|. 
+ const uint32_t load_id = inst->result_id(); + assert(load_replacement_.count(load_id) == 0); + load_replacement_[load_id] = val_id; + PhiCandidate* defining_phi = GetPhiCandidate(val_id); + if (defining_phi) { + defining_phi->AddUser(load_id); + } + +#if SSA_REWRITE_DEBUGGING_LEVEL > 1 + std::cerr << "\tFound load: " + << inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) + << " (replacement for %" << load_id << " is %" << val_id << ")\n"; +#endif + } +} + +void SSARewriter::PrintPhiCandidates() const { + std::cerr << "\nPhi candidates:\n"; + for (const auto& phi_it : phi_candidates_) { + std::cerr << "\tBB %" << phi_it.second.bb()->id() << ": " + << phi_it.second.PrettyPrint(pass_->cfg()) << "\n"; + } + std::cerr << "\n"; +} + +void SSARewriter::PrintReplacementTable() const { + std::cerr << "\nLoad replacement table\n"; + for (const auto& it : load_replacement_) { + std::cerr << "\t%" << it.first << " -> %" << it.second << "\n"; + } + std::cerr << "\n"; +} + +void SSARewriter::GenerateSSAReplacements(BasicBlock* bb) { +#if SSA_REWRITE_DEBUGGING_LEVEL > 1 + std::cerr << "Generating SSA replacements for block: " << bb->id() << "\n"; + std::cerr << bb->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) + << "\n"; +#endif + + for (auto& inst : *bb) { + auto opcode = inst.opcode(); + if (opcode == SpvOpStore || opcode == SpvOpVariable) { + ProcessStore(&inst, bb); + } else if (inst.opcode() == SpvOpLoad) { + ProcessLoad(&inst, bb); + } + } + + // Seal |bb|. This means that all the stores in it have been scanned and it's + // ready to feed them into its successors. 
+ SealBlock(bb); + +#if SSA_REWRITE_DEBUGGING_LEVEL > 1 + PrintPhiCandidates(); + PrintReplacementTable(); + std::cerr << "\n\n"; +#endif +} + +uint32_t SSARewriter::GetReplacement(std::pair repl) { + uint32_t val_id = repl.second; + auto it = load_replacement_.find(val_id); + while (it != load_replacement_.end()) { + val_id = it->second; + it = load_replacement_.find(val_id); + } + return val_id; +} + +uint32_t SSARewriter::GetPhiArgument(const PhiCandidate* phi_candidate, + uint32_t ix) { + assert(phi_candidate->IsReady() && + "Tried to get the final argument from an incomplete/trivial Phi"); + + uint32_t arg_id = phi_candidate->phi_args()[ix]; + while (arg_id != 0) { + PhiCandidate* phi_user = GetPhiCandidate(arg_id); + if (phi_user == nullptr || phi_user->IsReady()) { + // If the argument is not a Phi or it's a Phi candidate ready to be + // emitted, return it. + return arg_id; + } + arg_id = phi_user->copy_of(); + } + + assert(false && + "No Phi candidates in the copy-of chain are ready to be generated"); + + return 0; +} + +bool SSARewriter::ApplyReplacements() { + bool modified = false; + +#if SSA_REWRITE_DEBUGGING_LEVEL > 2 + std::cerr << "\n\nApplying replacement decisions to IR\n\n"; + PrintPhiCandidates(); + PrintReplacementTable(); + std::cerr << "\n\n"; +#endif + + // Add Phi instructions from completed Phi candidates. + std::vector generated_phis; + for (const PhiCandidate* phi_candidate : phis_to_generate_) { +#if SSA_REWRITE_DEBUGGING_LEVEL > 2 + std::cerr << "Phi candidate: " << phi_candidate->PrettyPrint(pass_->cfg()) + << "\n"; +#endif + + assert(phi_candidate->is_complete() && + "Tried to instantiate a Phi instruction from an incomplete Phi " + "candidate"); + + // Build the vector of operands for the new OpPhi instruction. 
+ uint32_t type_id = pass_->GetPointeeTypeId( + pass_->get_def_use_mgr()->GetDef(phi_candidate->var_id())); + std::vector phi_operands; + uint32_t arg_ix = 0; + std::unordered_map already_seen; + for (uint32_t pred_label : pass_->cfg()->preds(phi_candidate->bb()->id())) { + uint32_t op_val_id = GetPhiArgument(phi_candidate, arg_ix++); + if (already_seen.count(pred_label) == 0) { + phi_operands.push_back( + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {op_val_id}}); + phi_operands.push_back( + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {pred_label}}); + already_seen[pred_label] = op_val_id; + } else { + // It is possible that there are two edges from the same parent block. + // Since the OpPhi can have only one entry for each parent, we have to + // make sure the two edges are consistent with each other. + assert(already_seen[pred_label] == op_val_id && + "Inconsistent value for duplicate edges."); + } + } + + // Generate a new OpPhi instruction and insert it in its basic + // block. + std::unique_ptr phi_inst( + new Instruction(pass_->context(), SpvOpPhi, type_id, + phi_candidate->result_id(), phi_operands)); + generated_phis.push_back(phi_inst.get()); + pass_->get_def_use_mgr()->AnalyzeInstDef(&*phi_inst); + pass_->context()->set_instr_block(&*phi_inst, phi_candidate->bb()); + auto insert_it = phi_candidate->bb()->begin(); + insert_it.InsertBefore(std::move(phi_inst)); + pass_->context()->get_decoration_mgr()->CloneDecorations( + phi_candidate->var_id(), phi_candidate->result_id(), + {SpvDecorationRelaxedPrecision}); + + modified = true; + } + + // Scan uses for all inserted Phi instructions. Do this separately from the + // registration of the Phi instruction itself to avoid trying to analyze uses + // of Phi instructions that have not been registered yet. 
+ for (Instruction* phi_inst : generated_phis) { + pass_->get_def_use_mgr()->AnalyzeInstUse(&*phi_inst); + } + +#if SSA_REWRITE_DEBUGGING_LEVEL > 1 + std::cerr << "\n\nReplacing the result of load instructions with the " + "corresponding SSA id\n\n"; +#endif + + // Apply replacements from the load replacement table. + for (auto& repl : load_replacement_) { + uint32_t load_id = repl.first; + uint32_t val_id = GetReplacement(repl); + Instruction* load_inst = + pass_->context()->get_def_use_mgr()->GetDef(load_id); + +#if SSA_REWRITE_DEBUGGING_LEVEL > 2 + std::cerr << "\t" + << load_inst->PrettyPrint( + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) + << " (%" << load_id << " -> %" << val_id << ")\n"; +#endif + + // Remove the load instruction and replace all the uses of this load's + // result with |val_id|. Kill any names or decorates using the load's + // result before replacing to prevent incorrect replacement in those + // instructions. + pass_->context()->KillNamesAndDecorates(load_id); + pass_->context()->ReplaceAllUsesWith(load_id, val_id); + pass_->context()->KillInst(load_inst); + modified = true; + } + + return modified; +} + +void SSARewriter::FinalizePhiCandidate(PhiCandidate* phi_candidate) { + assert(phi_candidate->phi_args().size() > 0 && + "Phi candidate should have arguments"); + + uint32_t ix = 0; + for (uint32_t pred : pass_->cfg()->preds(phi_candidate->bb()->id())) { + BasicBlock* pred_bb = pass_->cfg()->block(pred); + uint32_t& arg_id = phi_candidate->phi_args()[ix++]; + if (arg_id == 0) { + // If |pred_bb| is still not sealed, it means it's unreachable. In this + // case, we just use Undef as an argument. + arg_id = IsBlockSealed(pred_bb) + ? GetReachingDef(phi_candidate->var_id(), pred_bb) + : pass_->GetUndefVal(phi_candidate->var_id()); + } + } + + // This candidate is now completed. + phi_candidate->MarkComplete(); + + // If |phi_candidate| is not trivial, add it to the list of Phis to generate. 
+ if (TryRemoveTrivialPhi(phi_candidate) == phi_candidate->result_id()) { + // If we could not remove |phi_candidate|, it means that it is complete + // and not trivial. Add it to the list of Phis to generate. + assert(!phi_candidate->copy_of() && "A completed Phi cannot be trivial."); + phis_to_generate_.push_back(phi_candidate); + } +} + +void SSARewriter::FinalizePhiCandidates() { +#if SSA_REWRITE_DEBUGGING_LEVEL > 1 + std::cerr << "Finalizing Phi candidates:\n\n"; + PrintPhiCandidates(); + std::cerr << "\n"; +#endif + + // Now, complete the collected candidates. + while (incomplete_phis_.size() > 0) { + PhiCandidate* phi_candidate = incomplete_phis_.front(); + incomplete_phis_.pop(); + FinalizePhiCandidate(phi_candidate); + } +} + +bool SSARewriter::RewriteFunctionIntoSSA(Function* fp) { +#if SSA_REWRITE_DEBUGGING_LEVEL > 0 + std::cerr << "Function before SSA rewrite:\n" + << fp->PrettyPrint(0) << "\n\n\n"; +#endif + + // Collect variables that can be converted into SSA IDs. + pass_->CollectTargetVars(fp); + + // Generate all the SSA replacements and Phi candidates. This will + // generate incomplete and trivial Phis. + pass_->cfg()->ForEachBlockInReversePostOrder( + fp->entry().get(), + [this](BasicBlock* bb) { GenerateSSAReplacements(bb); }); + + // Remove trivial Phis and add arguments to incomplete Phis. + FinalizePhiCandidates(); + + // Finally, apply all the replacements in the IR. + bool modified = ApplyReplacements(); + +#if SSA_REWRITE_DEBUGGING_LEVEL > 0 + std::cerr << "\n\n\nFunction after SSA rewrite:\n" + << fp->PrettyPrint(0) << "\n"; +#endif + + return modified; +} + +Pass::Status SSARewritePass::Process() { + bool modified = false; + for (auto& fn : *get_module()) { + modified |= SSARewriter(this).RewriteFunctionIntoSSA(&fn); + } + return modified ? 
Pass::Status::SuccessWithChange + : Pass::Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/ssa_rewrite_pass.h b/source/opt/ssa_rewrite_pass.h new file mode 100644 index 000000000..c0373dc06 --- /dev/null +++ b/source/opt/ssa_rewrite_pass.h @@ -0,0 +1,304 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_SSA_REWRITE_PASS_H_ +#define SOURCE_OPT_SSA_REWRITE_PASS_H_ + +#include +#include +#include +#include +#include +#include + +#include "source/opt/basic_block.h" +#include "source/opt/ir_context.h" +#include "source/opt/mem_pass.h" + +namespace spvtools { +namespace opt { + +// Utility class for passes that need to rewrite a function into SSA. This +// converts load/store operations on function-local variables into SSA IDs, +// which allows them to be the target of optimizing transformations. +// +// Store and load operations to these variables are converted into +// operations on SSA IDs. Phi instructions are added when needed. See the +// SSA construction paper for algorithmic details +// (https://link.springer.com/chapter/10.1007/978-3-642-37051-9_6) +class SSARewriter { + public: + SSARewriter(MemPass* pass) + : pass_(pass), first_phi_id_(pass_->get_module()->IdBound()) {} + + // Rewrites SSA-target variables in function |fp| into SSA. This is the + // entry point for the SSA rewrite algorithm. 
SSA-target variables are + // locally defined variables that meet the criteria set by IsSSATargetVar. + // + // It returns true if function |fp| was modified. Otherwise, it returns + // false. + bool RewriteFunctionIntoSSA(Function* fp); + + private: + class PhiCandidate { + public: + explicit PhiCandidate(uint32_t var, uint32_t result, BasicBlock* block) + : var_id_(var), + result_id_(result), + bb_(block), + phi_args_(), + copy_of_(0), + is_complete_(false), + users_() {} + + uint32_t var_id() const { return var_id_; } + uint32_t result_id() const { return result_id_; } + BasicBlock* bb() const { return bb_; } + std::vector& phi_args() { return phi_args_; } + const std::vector& phi_args() const { return phi_args_; } + uint32_t copy_of() const { return copy_of_; } + bool is_complete() const { return is_complete_; } + std::vector& users() { return users_; } + const std::vector& users() const { return users_; } + + // Marks this phi candidate as a trivial copy of |orig_id|. + void MarkCopyOf(uint32_t orig_id) { copy_of_ = orig_id; } + + // Marks this phi candidate as incomplete. + void MarkIncomplete() { is_complete_ = false; } + + // Marks this phi candidate as complete. + void MarkComplete() { is_complete_ = true; } + + // Returns true if this Phi candidate is ready to be emitted. + bool IsReady() const { return is_complete() && copy_of() == 0; } + + // Pretty prints this Phi candidate into a string and returns it. |cfg| is + // needed to lookup basic block predecessors. + std::string PrettyPrint(const CFG* cfg) const; + + // Registers |operand_id| as a user of this Phi candidate. + void AddUser(uint32_t operand_id) { users_.push_back(operand_id); } + + private: + // Variable ID that this Phi is merging. + uint32_t var_id_; + + // SSA ID generated by this Phi (i.e., this is the result ID of the eventual + // Phi instruction). + uint32_t result_id_; + + // Basic block to hold this Phi. 
+ BasicBlock* bb_; + + // Vector of operands for every predecessor block of |bb|. This vector is + // organized so that the Ith slot contains the argument coming from the Ith + // predecessor of |bb|. + std::vector phi_args_; + + // If this Phi is a trivial copy of another Phi, this is the ID of the + // original. If this is 0, it means that this is not a trivial Phi. + uint32_t copy_of_; + + // False, if this Phi candidate has no arguments or at least one argument is + // %0. + bool is_complete_; + + // List of all users for this Phi instruction. Each element is the result ID + // of the load instruction replaced by this Phi, or the result ID of a Phi + // candidate that has this Phi in its list of operands. + std::vector users_; + }; + + // Type used to keep track of store operations in each basic block. + typedef std::unordered_map> + BlockDefsMap; + + // Generates all the SSA rewriting decisions for basic block |bb|. This + // populates the Phi candidate table (|phi_candidates_|) and the load + // replacement table (|load_replacement_|). + void GenerateSSAReplacements(BasicBlock* bb); + + // Seals block |bb|. Sealing a basic block means that |bb| and all + // of its predecessors have been scanned for loads/stores. + void SealBlock(BasicBlock* bb); + + // Returns true if |bb| has been sealed. + bool IsBlockSealed(BasicBlock* bb) { return sealed_blocks_.count(bb) != 0; } + + // Returns the Phi candidate with result ID |id| if it exists in the table + // |phi_candidates_|. If no such Phi candidate exists, it returns nullptr. + PhiCandidate* GetPhiCandidate(uint32_t id) { + auto it = phi_candidates_.find(id); + return (it != phi_candidates_.end()) ? &it->second : nullptr; + } + + // Replaces all the users of Phi candidate |phi_cand| to be users of + // |repl_id|. + void ReplacePhiUsersWith(const PhiCandidate& phi_cand, uint32_t repl_id); + + // Returns the value ID that should replace the load ID in the given + // replacement pair |repl|.
The replacement is a pair (|load_id|, |val_id|). + // If |val_id| is itself replaced by another value in the table, this function + // will look up the replacement for |val_id| until it finds one that is not + // itself replaced. For instance, given: + // + // %34 = OpLoad %float %f1 + // OpStore %t %34 + // %36 = OpLoad %float %t + // + // Assuming that %f1 is reached by a Phi candidate %42, the load + // replacement table will have the following entries: + // + // %34 -> %42 + // %36 -> %34 + // + // So, when looking for the replacement for %36, we should not use + // %34. Rather, we should use %42. To do this, the chain of + // replacements must be followed until we reach an element that has + // no replacement. + uint32_t GetReplacement(std::pair repl); + + // Returns the argument at index |ix| from |phi_candidate|. If argument |ix| + // comes from a trivial Phi, it follows the copy-of chain from that trivial + // Phi until it finds the original Phi candidate. + // + // This is only valid after all Phi candidates have been completed. It can + // only be called when generating the IR for these Phis. + uint32_t GetPhiArgument(const PhiCandidate* phi_candidate, uint32_t ix); + + // Applies all the SSA replacement decisions. This replaces loads/stores to + // SSA target variables with their corresponding SSA IDs, and inserts Phi + // instructions for them. + bool ApplyReplacements(); + + // Registers a definition for variable |var_id| in basic block |bb| with + // value |val_id|. + void WriteVariable(uint32_t var_id, BasicBlock* bb, uint32_t val_id) { + defs_at_block_[bb][var_id] = val_id; + } + + // Processes the store operation |inst| in basic block |bb|. This extracts + // the variable ID being stored into, determines whether the variable is an + // SSA-target variable, and, if it is, it stores its value in the + // |defs_at_block_| map. + void ProcessStore(Instruction* inst, BasicBlock* bb); + + // Processes the load operation |inst| in basic block |bb|.
This extracts + // the variable ID being loaded from, determines whether the variable is an + // SSA-target variable, and, if it is, it reads its reaching definition by + // calling |GetReachingDef|. + void ProcessLoad(Instruction* inst, BasicBlock* bb); + + // Reads the current definition for variable |var_id| in basic block |bb|. + // If |var_id| is not defined in block |bb| it walks up the predecessors of + // |bb|, creating new Phi candidates along the way, if needed. + // + // It returns the value for |var_id| from the RHS of the current reaching + // definition for |var_id|. + uint32_t GetReachingDef(uint32_t var_id, BasicBlock* bb); + + // Adds arguments to |phi_candidate| by getting the reaching definition of + // |phi_candidate|'s variable on each of the predecessors of its basic + // block. After populating the argument list, it determines whether all its + // arguments are the same. If so, it returns the ID of the argument that + // this Phi copies. + uint32_t AddPhiOperands(PhiCandidate* phi_candidate); + + // Creates a Phi candidate instruction for variable |var_id| in basic block + // |bb|. + // + // Since the rewriting algorithm may remove Phi candidates when it finds + // them to be trivial, we avoid the expense of creating actual Phi + // instructions by keeping a pool of Phi candidates (|phi_candidates_|) + // during rewriting. + // + // Once the candidate Phi is created, it returns a reference to it. + PhiCandidate& CreatePhiCandidate(uint32_t var_id, BasicBlock* bb); + + // Attempts to remove a trivial Phi candidate |phi_cand|. Trivial Phis are + // those that only reference themselves and one other value |val| any number + // of times. This will try to remove any other Phis that become trivial + // after |phi_cand| is removed. + // + // If |phi_cand| is trivial, it returns the SSA ID for the value that should + // replace it. Otherwise, it returns the SSA ID for |phi_cand|.
+ uint32_t TryRemoveTrivialPhi(PhiCandidate* phi_cand); + + // Finalizes |phi_candidate| by replacing every argument that is still %0 + // with its reaching definition. + void FinalizePhiCandidate(PhiCandidate* phi_candidate); + + // Finalizes processing of Phi candidates. Once the whole function has been + // scanned for loads and stores, the CFG will still have some incomplete and + // trivial Phis. This will add missing arguments and remove trivial Phi + // candidates. + void FinalizePhiCandidates(); + + // Prints the table of Phi candidates to std::cerr. + void PrintPhiCandidates() const; + + // Prints the load replacement table to std::cerr. + void PrintReplacementTable() const; + + // Map holding the value of every SSA-target variable at every basic block + // where the variable is stored. defs_at_block_[block][var_id] = val_id + // means that there is a store or Phi instruction for variable |var_id| at + // basic block |block| with value |val_id|. + BlockDefsMap defs_at_block_; + + // Map, indexed by Phi ID, holding all the Phi candidates created during SSA + // rewriting. |phi_candidates_[id]| returns the Phi candidate whose result + // is |id|. + std::unordered_map phi_candidates_; + + // Queue of incomplete Phi candidates. These are Phi candidates created at + // unsealed blocks. They need to be completed before they are instantiated + // in ApplyReplacements. + std::queue incomplete_phis_; + + // List of completed Phi candidates. These are the only candidates that + // will become real Phi instructions. + std::vector phis_to_generate_; + + // SSA replacement table. This maps variable IDs, resulting from a load + // operation, to the value IDs that will replace them after SSA rewriting. + // After all the rewriting decisions are made, a final scan through the IR + // is done to replace all uses of the original load ID with the value ID. + std::unordered_map load_replacement_; + + // Set of blocks that have been sealed already. 
+ std::unordered_set sealed_blocks_; + + // Memory pass requesting the SSA rewriter. + MemPass* pass_; + + // ID of the first Phi created by the SSA rewriter. During rewriting, any + // ID bigger than this corresponds to a Phi candidate. + uint32_t first_phi_id_; +}; + +class SSARewritePass : public MemPass { + public: + SSARewritePass() = default; + + const char* name() const override { return "ssa-rewrite"; } + Status Process() override; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_SSA_REWRITE_PASS_H_ diff --git a/source/opt/strength_reduction_pass.cpp b/source/opt/strength_reduction_pass.cpp new file mode 100644 index 000000000..ab7c4eb8d --- /dev/null +++ b/source/opt/strength_reduction_pass.cpp @@ -0,0 +1,200 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/strength_reduction_pass.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/log.h" +#include "source/opt/reflect.h" + +namespace { +// Count the number of trailing zeros in the binary representation of +// |constVal|. +uint32_t CountTrailingZeros(uint32_t constVal) { + // Faster if we use the hardware count trailing zeros instruction. + // If not available, we could create a table. 
+ uint32_t shiftAmount = 0; + while ((constVal & 1) == 0) { + ++shiftAmount; + constVal = (constVal >> 1); + } + return shiftAmount; +} + +// Return true if |val| is a power of 2. +bool IsPowerOf2(uint32_t val) { + // The idea is that the & will clear out the least + // significant 1 bit. If it is a power of 2, then + // there is exactly 1 bit set, and the value becomes 0. + if (val == 0) return false; + return ((val - 1) & val) == 0; +} + +} // namespace + +namespace spvtools { +namespace opt { + +Pass::Status StrengthReductionPass::Process() { + // Initialize the member variables on a per module basis. + bool modified = false; + int32_type_id_ = 0; + uint32_type_id_ = 0; + std::memset(constant_ids_, 0, sizeof(constant_ids_)); + + FindIntTypesAndConstants(); + modified = ScanFunctions(); + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( + BasicBlock::iterator* inst) { + assert((*inst)->opcode() == SpvOp::SpvOpIMul && + "Only works for multiplication of integers."); + bool modified = false; + + // Currently only works on 32-bit integers. + if ((*inst)->type_id() != int32_type_id_ && + (*inst)->type_id() != uint32_type_id_) { + return modified; + } + + // Check the operands for a constant that is a power of 2. + for (int i = 0; i < 2; i++) { + uint32_t opId = (*inst)->GetSingleWordInOperand(i); + Instruction* opInst = get_def_use_mgr()->GetDef(opId); + if (opInst->opcode() == SpvOp::SpvOpConstant) { + // We found a constant operand. + uint32_t constVal = opInst->GetSingleWordOperand(2); + + if (IsPowerOf2(constVal)) { + modified = true; + uint32_t shiftAmount = CountTrailingZeros(constVal); + uint32_t shiftConstResultId = GetConstantId(shiftAmount); + + // Create the new instruction. 
+ uint32_t newResultId = TakeNextId(); + std::vector newOperands; + newOperands.push_back((*inst)->GetInOperand(1 - i)); + Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {shiftConstResultId}); + newOperands.push_back(shiftOperand); + std::unique_ptr newInstruction( + new Instruction(context(), SpvOp::SpvOpShiftLeftLogical, + (*inst)->type_id(), newResultId, newOperands)); + + // Insert the new instruction and update the data structures. + (*inst) = (*inst).InsertBefore(std::move(newInstruction)); + get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst)); + ++(*inst); + context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId); + + // Remove the old instruction. + Instruction* inst_to_delete = &*(*inst); + --(*inst); + context()->KillInst(inst_to_delete); + + // We do not want to replace the instruction twice if both operands + // are constants that are a power of 2. So we break here. + break; + } + } + } + + return modified; +} + +void StrengthReductionPass::FindIntTypesAndConstants() { + analysis::Integer int32(32, true); + int32_type_id_ = context()->get_type_mgr()->GetId(&int32); + analysis::Integer uint32(32, false); + uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32); + for (auto iter = get_module()->types_values_begin(); + iter != get_module()->types_values_end(); ++iter) { + switch (iter->opcode()) { + case SpvOp::SpvOpConstant: + if (iter->type_id() == uint32_type_id_) { + uint32_t value = iter->GetSingleWordOperand(2); + if (value <= 32) constant_ids_[value] = iter->result_id(); + } + break; + default: + break; + } + } +} + +uint32_t StrengthReductionPass::GetConstantId(uint32_t val) { + assert(val <= 32 && + "This function does not handle constants larger than 32."); + + if (constant_ids_[val] == 0) { + if (uint32_type_id_ == 0) { + analysis::Integer uint(32, false); + uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint); + } + + // Construct the constant. 
+ uint32_t resultId = TakeNextId(); + Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, + {val}); + std::unique_ptr newConstant( + new Instruction(context(), SpvOp::SpvOpConstant, uint32_type_id_, + resultId, {constant})); + get_module()->AddGlobalValue(std::move(newConstant)); + + // Notify the DefUseManager about this constant. + auto constantIter = --get_module()->types_values_end(); + get_def_use_mgr()->AnalyzeInstDef(&*constantIter); + + // Store the result id for next time. + constant_ids_[val] = resultId; + } + + return constant_ids_[val]; +} + +bool StrengthReductionPass::ScanFunctions() { + // I did not use |ForEachInst| in the module because the function that acts on + // the instruction gets a pointer to the instruction. We cannot use that to + // insert a new instruction. I want an iterator. + bool modified = false; + for (auto& func : *get_module()) { + for (auto& bb : func) { + for (auto inst = bb.begin(); inst != bb.end(); ++inst) { + switch (inst->opcode()) { + case SpvOp::SpvOpIMul: + if (ReplaceMultiplyByPowerOf2(&inst)) modified = true; + break; + default: + break; + } + } + } + } + return modified; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/strength_reduction_pass.h b/source/opt/strength_reduction_pass.h new file mode 100644 index 000000000..8dfeb307b --- /dev/null +++ b/source/opt/strength_reduction_pass.h @@ -0,0 +1,65 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_STRENGTH_REDUCTION_PASS_H_ +#define SOURCE_OPT_STRENGTH_REDUCTION_PASS_H_ + +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class StrengthReductionPass : public Pass { + public: + const char* name() const override { return "strength-reduction"; } + Status Process() override; + + private: + // Replaces multiplication by a power of 2 with an equivalent bit shift. + // Returns true if something changed. + bool ReplaceMultiplyByPowerOf2(BasicBlock::iterator*); + + // Scan the types and constants in the module looking for the integer + // types that we are + // interested in. The shift operation needs a small unsigned integer. We + // need to find + // them or create them. We do not want duplicates. + void FindIntTypesAndConstants(); + + // Get the id for the given constant. If it does not exist, it will be + // created. The parameter must be between 0 and 32 inclusive. + uint32_t GetConstantId(uint32_t); + + // Replaces certain instructions in function bodies with presumably cheaper + // ones. Returns true if something changed. + bool ScanFunctions(); + + // Type ids for the types of interest, or 0 if they do not exist. + uint32_t int32_type_id_; + uint32_t uint32_type_id_; + + // constant_ids_[i] is the id for unsigned integer constant i. + // We set the limit at 32 because a bit shift of a 32-bit integer does not + // need a value larger than 32.
+ uint32_t constant_ids_[33]; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_STRENGTH_REDUCTION_PASS_H_ diff --git a/source/opt/strip_debug_info_pass.cpp b/source/opt/strip_debug_info_pass.cpp new file mode 100644 index 000000000..5d9c5fec8 --- /dev/null +++ b/source/opt/strip_debug_info_pass.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/strip_debug_info_pass.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +Pass::Status StripDebugInfoPass::Process() { + bool modified = !context()->debugs1().empty() || + !context()->debugs2().empty() || + !context()->debugs3().empty(); + context()->debug_clear(); + + context()->module()->ForEachInst([&modified](Instruction* inst) { + modified |= !inst->dbg_line_insts().empty(); + inst->dbg_line_insts().clear(); + }); + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/strip_debug_info_pass.h b/source/opt/strip_debug_info_pass.h new file mode 100644 index 000000000..47a2cd409 --- /dev/null +++ b/source/opt/strip_debug_info_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_STRIP_DEBUG_INFO_PASS_H_ +#define SOURCE_OPT_STRIP_DEBUG_INFO_PASS_H_ + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class StripDebugInfoPass : public Pass { + public: + const char* name() const override { return "strip-debug"; } + Status Process() override; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_STRIP_DEBUG_INFO_PASS_H_ diff --git a/source/opt/strip_reflect_info_pass.cpp b/source/opt/strip_reflect_info_pass.cpp new file mode 100644 index 000000000..984073f9d --- /dev/null +++ b/source/opt/strip_reflect_info_pass.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/strip_reflect_info_pass.h" + +#include +#include + +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +Pass::Status StripReflectInfoPass::Process() { + bool modified = false; + + std::vector to_remove; + + bool other_uses_for_decorate_string = false; + for (auto& inst : context()->module()->annotations()) { + switch (inst.opcode()) { + case SpvOpDecorateStringGOOGLE: + if (inst.GetSingleWordInOperand(1) == SpvDecorationHlslSemanticGOOGLE) { + to_remove.push_back(&inst); + } else { + other_uses_for_decorate_string = true; + } + break; + + case SpvOpMemberDecorateStringGOOGLE: + if (inst.GetSingleWordInOperand(2) == SpvDecorationHlslSemanticGOOGLE) { + to_remove.push_back(&inst); + } else { + other_uses_for_decorate_string = true; + } + break; + + case SpvOpDecorateId: + if (inst.GetSingleWordInOperand(1) == + SpvDecorationHlslCounterBufferGOOGLE) { + to_remove.push_back(&inst); + } + break; + + default: + break; + } + } + + for (auto& inst : context()->module()->extensions()) { + const char* ext_name = + reinterpret_cast(&inst.GetInOperand(0).words[0]); + if (0 == std::strcmp(ext_name, "SPV_GOOGLE_hlsl_functionality1")) { + to_remove.push_back(&inst); + } else if (!other_uses_for_decorate_string && + 0 == std::strcmp(ext_name, "SPV_GOOGLE_decorate_string")) { + to_remove.push_back(&inst); + } + } + + for (auto* inst : to_remove) { + modified = true; + context()->KillInst(inst); + } + + return modified ? 
Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/strip_reflect_info_pass.h b/source/opt/strip_reflect_info_pass.h new file mode 100644 index 000000000..4e1999ed3 --- /dev/null +++ b/source/opt/strip_reflect_info_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_ +#define SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_ + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class StripReflectInfoPass : public Pass { + public: + const char* name() const override { return "strip-reflect"; } + Status Process() override; + + // Return the mask of preserved Analyses. 
// Returns the bitmask of IRContext analyses this pass reports as still valid
// after it has run. The mask is the OR of exactly the flags named below; any
// analysis not named here is not reported as preserved.
IRContext::Analysis GetPreservedAnalyses() override {
  return IRContext::kAnalysisInstrToBlockMapping |
         IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG |
         IRContext::kAnalysisDominatorAnalysis |
         IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
         IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
}
+ if (!context_->get_feature_mgr()->HasCapability(SpvCapabilityShader)) { + return; + } + + for (auto& func : *context_->module()) { + AddBlocksInFunction(&func); + } +} + +void StructuredCFGAnalysis::AddBlocksInFunction(Function* func) { + std::list order; + context_->cfg()->ComputeStructuredOrder(func, &*func->begin(), &order); + + struct TraversalInfo { + ConstructInfo cinfo; + uint32_t merge_node; + }; + + // Set up a stack to keep track of currently active constructs. + std::vector state; + state.emplace_back(); + state[0].cinfo.containing_construct = 0; + state[0].cinfo.containing_loop = 0; + state[0].merge_node = 0; + + for (BasicBlock* block : order) { + if (context_->cfg()->IsPseudoEntryBlock(block) || + context_->cfg()->IsPseudoExitBlock(block)) { + continue; + } + + if (block->id() == state.back().merge_node) { + state.pop_back(); + } + + bb_to_construct_.emplace(std::make_pair(block->id(), state.back().cinfo)); + + if (Instruction* merge_inst = block->GetMergeInst()) { + TraversalInfo new_state; + new_state.merge_node = + merge_inst->GetSingleWordInOperand(kMergeNodeIndex); + new_state.cinfo.containing_construct = block->id(); + + if (merge_inst->opcode() == SpvOpLoopMerge) { + new_state.cinfo.containing_loop = block->id(); + } else { + new_state.cinfo.containing_loop = state.back().cinfo.containing_loop; + } + + state.emplace_back(new_state); + merge_blocks_.Set(new_state.merge_node); + } + } +} + +uint32_t StructuredCFGAnalysis::MergeBlock(uint32_t bb_id) { + uint32_t header_id = ContainingConstruct(bb_id); + if (header_id == 0) { + return 0; + } + + BasicBlock* header = context_->cfg()->block(header_id); + Instruction* merge_inst = header->GetMergeInst(); + return merge_inst->GetSingleWordInOperand(kMergeNodeIndex); +} + +uint32_t StructuredCFGAnalysis::LoopMergeBlock(uint32_t bb_id) { + uint32_t header_id = ContainingLoop(bb_id); + if (header_id == 0) { + return 0; + } + + BasicBlock* header = context_->cfg()->block(header_id); + Instruction* 
merge_inst = header->GetMergeInst(); + return merge_inst->GetSingleWordInOperand(kMergeNodeIndex); +} + +uint32_t StructuredCFGAnalysis::LoopContinueBlock(uint32_t bb_id) { + uint32_t header_id = ContainingLoop(bb_id); + if (header_id == 0) { + return 0; + } + + BasicBlock* header = context_->cfg()->block(header_id); + Instruction* merge_inst = header->GetMergeInst(); + return merge_inst->GetSingleWordInOperand(kContinueNodeIndex); +} + +bool StructuredCFGAnalysis::IsContinueBlock(uint32_t bb_id) { + assert(bb_id != 0); + return LoopContinueBlock(bb_id) == bb_id; +} + +bool StructuredCFGAnalysis::IsMergeBlock(uint32_t bb_id) { + return merge_blocks_.Get(bb_id); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/struct_cfg_analysis.h b/source/opt/struct_cfg_analysis.h new file mode 100644 index 000000000..ef0229d05 --- /dev/null +++ b/source/opt/struct_cfg_analysis.h @@ -0,0 +1,101 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_STRUCT_CFG_ANALYSIS_H_ +#define SOURCE_OPT_STRUCT_CFG_ANALYSIS_H_ + +#include + +#include "source/opt/function.h" +#include "source/util/bit_vector.h" + +namespace spvtools { +namespace opt { + +class IRContext; + +// An analysis that, for each basic block, finds the constructs in which it is +// contained, so we can easily get headers and merge nodes. 
+class StructuredCFGAnalysis { + public: + explicit StructuredCFGAnalysis(IRContext* ctx); + + // Returns the id of the header of the innermost merge construct + // that contains |bb_id|. Returns |0| if |bb_id| is not contained in any + // merge construct. + uint32_t ContainingConstruct(uint32_t bb_id) { + auto it = bb_to_construct_.find(bb_id); + if (it == bb_to_construct_.end()) { + return 0; + } + return it->second.containing_construct; + } + + // Returns the id of the merge block of the innermost merge construct + // that contains |bb_id|. Returns |0| if |bb_id| is not contained in any + // merge construct. + uint32_t MergeBlock(uint32_t bb_id); + + // Returns the id of the header of the innermost loop construct + // that contains |bb_id|. Return |0| if |bb_id| is not contained in any loop + // construct. + uint32_t ContainingLoop(uint32_t bb_id) { + auto it = bb_to_construct_.find(bb_id); + if (it == bb_to_construct_.end()) { + return 0; + } + return it->second.containing_loop; + } + + // Returns the id of the merge block of the innermost loop construct + // that contains |bb_id|. Return |0| if |bb_id| is not contained in any loop + // construct. + uint32_t LoopMergeBlock(uint32_t bb_id); + + // Returns the id of the continue block of the innermost loop construct + // that contains |bb_id|. Return |0| if |bb_id| is not contained in any loop + // construct. + uint32_t LoopContinueBlock(uint32_t bb_id); + + bool IsContinueBlock(uint32_t bb_id); + bool IsMergeBlock(uint32_t bb_id); + + private: + // Struct used to hold the information for a basic block. + // |containing_construct| is the header for the innermost containing + // construct, or 0 if no such construct exists. It could be a selection + // construct or a loop construct. |containing_loop| is the innermost + // containing loop construct, or 0 if the basic bloc is not in a loop. 
If the + // basic block is in a selection construct that is contained in a loop + // construct, then these two values will not be the same. + struct ConstructInfo { + uint32_t containing_construct; + uint32_t containing_loop; + }; + + // Populates |bb_to_construct_| with the innermost containing merge and loop + // constructs for each basic block in |func|. + void AddBlocksInFunction(Function* func); + + IRContext* context_; + + // A map from a basic block to the headers of its inner most containing + // constructs. + std::unordered_map bb_to_construct_; + utils::BitVector merge_blocks_; +}; + +} // namespace opt +} // namespace spvtools +#endif // SOURCE_OPT_STRUCT_CFG_ANALYSIS_H_ diff --git a/source/opt/tree_iterator.h b/source/opt/tree_iterator.h new file mode 100644 index 000000000..05f42bc5b --- /dev/null +++ b/source/opt/tree_iterator.h @@ -0,0 +1,246 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_TREE_ITERATOR_H_ +#define SOURCE_OPT_TREE_ITERATOR_H_ + +#include +#include +#include + +namespace spvtools { +namespace opt { + +// Helper class to iterate over a tree in a depth first order. +// The class assumes the data structure is a tree, tree node type implements a +// forward iterator. +// At each step, the iterator holds the pointer to the current node and state of +// the walk. +// The state is recorded by stacking the iteration position of the node +// children. 
To move to the next node, the iterator: +// - Looks at the top of the stack; +// - Sets the node behind the iterator as the current node; +// - Increments the iterator if it has more children to visit, pops otherwise; +// - If the current node has children, the children iterator is pushed into the +// stack. +template +class TreeDFIterator { + static_assert(!std::is_pointer::value && + !std::is_reference::value, + "NodeTy should be a class"); + // Type alias to keep track of the const qualifier. + using NodeIterator = + typename std::conditional::value, + typename NodeTy::const_iterator, + typename NodeTy::iterator>::type; + + // Type alias to keep track of the const qualifier. + using NodePtr = NodeTy*; + + public: + // Standard iterator interface. + using reference = NodeTy&; + using value_type = NodeTy; + + explicit inline TreeDFIterator(NodePtr top_node) : current_(top_node) { + if (current_ && current_->begin() != current_->end()) + parent_iterators_.emplace(make_pair(current_, current_->begin())); + } + + // end() iterator. + inline TreeDFIterator() : TreeDFIterator(nullptr) {} + + bool operator==(const TreeDFIterator& x) const { + return current_ == x.current_; + } + + bool operator!=(const TreeDFIterator& x) const { return !(*this == x); } + + reference operator*() const { return *current_; } + + NodePtr operator->() const { return current_; } + + TreeDFIterator& operator++() { + MoveToNextNode(); + return *this; + } + + TreeDFIterator operator++(int) { + TreeDFIterator tmp = *this; + ++*this; + return tmp; + } + + private: + // Moves the iterator to the next node in the tree. + // If we are at the end, do nothing, otherwise + // if our current node has children, use the children iterator and push the + // current node into the stack. + // If we reach the end of the local iterator, pop it. 
+ inline void MoveToNextNode() { + if (!current_) return; + if (parent_iterators_.empty()) { + current_ = nullptr; + return; + } + std::pair& next_it = parent_iterators_.top(); + // Set the new node. + current_ = *next_it.second; + // Update the iterator for the next child. + ++next_it.second; + // If we finished with node, pop it. + if (next_it.first->end() == next_it.second) parent_iterators_.pop(); + // If our current node is not a leaf, store the iteration state for later. + if (current_->begin() != current_->end()) + parent_iterators_.emplace(make_pair(current_, current_->begin())); + } + + // The current node of the tree. + NodePtr current_; + // State of the tree walk: each pair contains the parent node (which has been + // already visited) and the iterator of the next children to visit. + // When all the children has been visited, we pop the entry, get the next + // child and push back the pair if the children iterator is not end(). + std::stack> parent_iterators_; +}; + +// Helper class to iterate over a tree in a depth first post-order. +// The class assumes the data structure is a tree, tree node type implements a +// forward iterator. +// At each step, the iterator holds the pointer to the current node and state of +// the walk. +// The state is recorded by stacking the iteration position of the node +// children. To move to the next node, the iterator: +// - Looks at the top of the stack; +// - If the children iterator has reach the end, then the node become the +// current one and we pop the stack; +// - Otherwise, we save the child and increment the iterator; +// - We walk the child sub-tree until we find a leaf, stacking all non-leaves +// states (pair of node pointer and child iterator) as we walk it. +template +class PostOrderTreeDFIterator { + static_assert(!std::is_pointer::value && + !std::is_reference::value, + "NodeTy should be a class"); + // Type alias to keep track of the const qualifier. 
+ using NodeIterator = + typename std::conditional::value, + typename NodeTy::const_iterator, + typename NodeTy::iterator>::type; + + // Type alias to keep track of the const qualifier. + using NodePtr = NodeTy*; + + public: + // Standard iterator interface. + using reference = NodeTy&; + using value_type = NodeTy; + + static inline PostOrderTreeDFIterator begin(NodePtr top_node) { + return PostOrderTreeDFIterator(top_node); + } + + static inline PostOrderTreeDFIterator end(NodePtr sentinel_node) { + return PostOrderTreeDFIterator(sentinel_node, false); + } + + bool operator==(const PostOrderTreeDFIterator& x) const { + return current_ == x.current_; + } + + bool operator!=(const PostOrderTreeDFIterator& x) const { + return !(*this == x); + } + + reference operator*() const { return *current_; } + + NodePtr operator->() const { return current_; } + + PostOrderTreeDFIterator& operator++() { + MoveToNextNode(); + return *this; + } + + PostOrderTreeDFIterator operator++(int) { + PostOrderTreeDFIterator tmp = *this; + ++*this; + return tmp; + } + + private: + explicit inline PostOrderTreeDFIterator(NodePtr top_node) + : current_(top_node) { + if (current_) WalkToLeaf(); + } + + // Constructor for the "end()" iterator. + // |end_sentinel| is the value that acts as end value (can be null). The bool + // parameters is to distinguish from the start() Ctor. + inline PostOrderTreeDFIterator(NodePtr sentinel_node, bool) + : current_(sentinel_node) {} + + // Moves the iterator to the next node in the tree. + // If we are at the end, do nothing, otherwise + // if our current node has children, use the children iterator and push the + // current node into the stack. + // If we reach the end of the local iterator, pop it. + inline void MoveToNextNode() { + if (!current_) return; + if (parent_iterators_.empty()) { + current_ = nullptr; + return; + } + std::pair& next_it = parent_iterators_.top(); + // If we visited all children, the current node is the top of the stack. 
+ if (next_it.second == next_it.first->end()) { + // Set the new node. + current_ = next_it.first; + parent_iterators_.pop(); + return; + } + // We have more children to visit, set the current node to the first child + // and dive to leaf. + current_ = *next_it.second; + // Update the iterator for the next child (avoid unneeded pop). + ++next_it.second; + WalkToLeaf(); + } + + // Moves the iterator to the next node in the tree. + // If we are at the end, do nothing, otherwise + // if our current node has children, use the children iterator and push the + // current node into the stack. + // If we reach the end of the local iterator, pop it. + inline void WalkToLeaf() { + while (current_->begin() != current_->end()) { + NodeIterator next = ++current_->begin(); + parent_iterators_.emplace(make_pair(current_, next)); + // Set the first child as the new node. + current_ = *current_->begin(); + } + } + + // The current node of the tree. + NodePtr current_; + // State of the tree walk: each pair contains the parent node and the iterator + // of the next children to visit. + // When all the children has been visited, we pop the first entry and the + // parent node become the current node. + std::stack> parent_iterators_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_TREE_ITERATOR_H_ diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp new file mode 100644 index 000000000..001883cad --- /dev/null +++ b/source/opt/type_manager.cpp @@ -0,0 +1,935 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/type_manager.h" + +#include +#include +#include +#include + +#include "source/opt/ir_context.h" +#include "source/opt/log.h" +#include "source/opt/reflect.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +const int kSpvTypePointerStorageClass = 1; +const int kSpvTypePointerTypeIdInIdx = 2; + +} // namespace + +TypeManager::TypeManager(const MessageConsumer& consumer, IRContext* c) + : consumer_(consumer), context_(c) { + AnalyzeTypes(*c->module()); +} + +Type* TypeManager::GetType(uint32_t id) const { + auto iter = id_to_type_.find(id); + if (iter != id_to_type_.end()) return (*iter).second; + iter = id_to_incomplete_type_.find(id); + if (iter != id_to_incomplete_type_.end()) return (*iter).second; + return nullptr; +} + +std::pair> TypeManager::GetTypeAndPointerType( + uint32_t id, SpvStorageClass sc) const { + Type* type = GetType(id); + if (type) { + return std::make_pair(type, MakeUnique(type, sc)); + } else { + return std::make_pair(type, std::unique_ptr()); + } +} + +uint32_t TypeManager::GetId(const Type* type) const { + auto iter = type_to_id_.find(type); + if (iter != type_to_id_.end()) { + return (*iter).second; + } + return 0; +} + +void TypeManager::AnalyzeTypes(const Module& module) { + // First pass through the types. Any types that reference a forward pointer + // (directly or indirectly) are incomplete, and are added to incomplete types. 
+ for (const auto* inst : module.GetTypes()) { + RecordIfTypeDefinition(*inst); + } + + if (incomplete_types_.empty()) { + return; + } + + // Get the real pointer definition for all of the forward pointers. + for (auto& type : incomplete_types_) { + if (type.type()->kind() == Type::kForwardPointer) { + auto* t = GetType(type.id()); + assert(t); + auto* p = t->AsPointer(); + assert(p); + type.type()->AsForwardPointer()->SetTargetPointer(p); + } + } + + // Replaces the references to the forward pointers in the incomplete types. + for (auto& type : incomplete_types_) { + ReplaceForwardPointers(type.type()); + } + + // Delete the forward pointers now that they are not referenced anymore. + for (auto& type : incomplete_types_) { + if (type.type()->kind() == Type::kForwardPointer) { + type.ResetType(nullptr); + } + } + + // Compare the complete types looking for types that are the same. If there + // are two types that are the same, then replace one with the other. + // Continue until we reach a fixed point. + bool restart = true; + while (restart) { + restart = false; + for (auto it1 = incomplete_types_.begin(); it1 != incomplete_types_.end(); + ++it1) { + uint32_t id1 = it1->id(); + Type* type1 = it1->type(); + if (!type1) { + continue; + } + + for (auto it2 = it1 + 1; it2 != incomplete_types_.end(); ++it2) { + uint32_t id2 = it2->id(); + (void)(id2 + id1); + Type* type2 = it2->type(); + if (!type2) { + continue; + } + + if (type1->IsSame(type2)) { + ReplaceType(type1, type2); + it2->ResetType(nullptr); + id_to_incomplete_type_[it2->id()] = type1; + restart = true; + } + } + } + } + + // Add the remaining incomplete types to the type pool. 
+ for (auto& type : incomplete_types_) { + if (type.type() && !type.type()->AsForwardPointer()) { + std::vector decorations = + context()->get_decoration_mgr()->GetDecorationsFor(type.id(), true); + for (auto dec : decorations) { + AttachDecoration(*dec, type.type()); + } + auto pair = type_pool_.insert(type.ReleaseType()); + id_to_type_[type.id()] = pair.first->get(); + type_to_id_[pair.first->get()] = type.id(); + id_to_incomplete_type_.erase(type.id()); + } + } + + // Add a mapping for any ids that whose original type was replaced by an + // equivalent type. + for (auto& type : id_to_incomplete_type_) { + id_to_type_[type.first] = type.second; + } + +#ifndef NDEBUG + // Check if the type pool contains two types that are the same. This + // is an indication that the hashing and comparision are wrong. It + // will cause a problem if the type pool gets resized and everything + // is rehashed. + for (auto& i : type_pool_) { + for (auto& j : type_pool_) { + Type* ti = i.get(); + Type* tj = j.get(); + assert((ti == tj || !ti->IsSame(tj)) && + "Type pool contains two types that are the same."); + } + } +#endif +} + +void TypeManager::RemoveId(uint32_t id) { + auto iter = id_to_type_.find(id); + if (iter == id_to_type_.end()) return; + + auto& type = iter->second; + if (!type->IsUniqueType(true)) { + auto tIter = type_to_id_.find(type); + if (tIter != type_to_id_.end() && tIter->second == id) { + // |type| currently maps to |id|. + // Search for an equivalent type to re-map. + bool found = false; + for (auto& pair : id_to_type_) { + if (pair.first != id && *pair.second == *type) { + // Equivalent ambiguous type, re-map type. + type_to_id_.erase(type); + type_to_id_[pair.second] = pair.first; + found = true; + break; + } + } + // No equivalent ambiguous type, remove mapping. + if (!found) type_to_id_.erase(tIter); + } + } else { + // Unique type, so just erase the entry. + type_to_id_.erase(type); + } + + // Erase the entry for |id|. 
+ id_to_type_.erase(iter); +} + +uint32_t TypeManager::GetTypeInstruction(const Type* type) { + uint32_t id = GetId(type); + if (id != 0) return id; + + std::unique_ptr typeInst; + // TODO(1841): Handle id overflow. + id = context()->TakeNextId(); + RegisterType(id, *type); + switch (type->kind()) { +#define DefineParameterlessCase(kind) \ + case Type::k##kind: \ + typeInst = MakeUnique(context(), SpvOpType##kind, 0, id, \ + std::initializer_list{}); \ + break; + DefineParameterlessCase(Void); + DefineParameterlessCase(Bool); + DefineParameterlessCase(Sampler); + DefineParameterlessCase(Event); + DefineParameterlessCase(DeviceEvent); + DefineParameterlessCase(ReserveId); + DefineParameterlessCase(Queue); + DefineParameterlessCase(PipeStorage); + DefineParameterlessCase(NamedBarrier); + DefineParameterlessCase(AccelerationStructureNV); +#undef DefineParameterlessCase + case Type::kInteger: + typeInst = MakeUnique( + context(), SpvOpTypeInt, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsInteger()->width()}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {(type->AsInteger()->IsSigned() ? 
1u : 0u)}}}); + break; + case Type::kFloat: + typeInst = MakeUnique( + context(), SpvOpTypeFloat, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsFloat()->width()}}}); + break; + case Type::kVector: { + uint32_t subtype = GetTypeInstruction(type->AsVector()->element_type()); + typeInst = + MakeUnique(context(), SpvOpTypeVector, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {subtype}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {type->AsVector()->element_count()}}}); + break; + } + case Type::kMatrix: { + uint32_t subtype = GetTypeInstruction(type->AsMatrix()->element_type()); + typeInst = + MakeUnique(context(), SpvOpTypeMatrix, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {subtype}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {type->AsMatrix()->element_count()}}}); + break; + } + case Type::kImage: { + const Image* image = type->AsImage(); + uint32_t subtype = GetTypeInstruction(image->sampled_type()); + typeInst = MakeUnique( + context(), SpvOpTypeImage, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {subtype}}, + {SPV_OPERAND_TYPE_DIMENSIONALITY, + {static_cast(image->dim())}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {image->depth()}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {(image->is_arrayed() ? 1u : 0u)}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {(image->is_multisampled() ? 
1u : 0u)}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {image->sampled()}}, + {SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT, + {static_cast(image->format())}}, + {SPV_OPERAND_TYPE_ACCESS_QUALIFIER, + {static_cast(image->access_qualifier())}}}); + break; + } + case Type::kSampledImage: { + uint32_t subtype = + GetTypeInstruction(type->AsSampledImage()->image_type()); + typeInst = MakeUnique( + context(), SpvOpTypeSampledImage, 0, id, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {subtype}}}); + break; + } + case Type::kArray: { + uint32_t subtype = GetTypeInstruction(type->AsArray()->element_type()); + typeInst = MakeUnique( + context(), SpvOpTypeArray, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {subtype}}, + {SPV_OPERAND_TYPE_ID, {type->AsArray()->LengthId()}}}); + break; + } + case Type::kRuntimeArray: { + uint32_t subtype = + GetTypeInstruction(type->AsRuntimeArray()->element_type()); + typeInst = MakeUnique( + context(), SpvOpTypeRuntimeArray, 0, id, + std::initializer_list{{SPV_OPERAND_TYPE_ID, {subtype}}}); + break; + } + case Type::kStruct: { + std::vector ops; + const Struct* structTy = type->AsStruct(); + for (auto ty : structTy->element_types()) { + ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)})); + } + typeInst = + MakeUnique(context(), SpvOpTypeStruct, 0, id, ops); + break; + } + case Type::kOpaque: { + const Opaque* opaque = type->AsOpaque(); + size_t size = opaque->name().size(); + // Convert to null-terminated packed UTF-8 string. 
+ std::vector words(size / 4 + 1, 0); + char* dst = reinterpret_cast(words.data()); + strncpy(dst, opaque->name().c_str(), size); + typeInst = MakeUnique( + context(), SpvOpTypeOpaque, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_LITERAL_STRING, words}}); + break; + } + case Type::kPointer: { + const Pointer* pointer = type->AsPointer(); + uint32_t subtype = GetTypeInstruction(pointer->pointee_type()); + typeInst = MakeUnique( + context(), SpvOpTypePointer, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_STORAGE_CLASS, + {static_cast(pointer->storage_class())}}, + {SPV_OPERAND_TYPE_ID, {subtype}}}); + break; + } + case Type::kFunction: { + std::vector ops; + const Function* function = type->AsFunction(); + ops.push_back(Operand(SPV_OPERAND_TYPE_ID, + {GetTypeInstruction(function->return_type())})); + for (auto ty : function->param_types()) { + ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)})); + } + typeInst = + MakeUnique(context(), SpvOpTypeFunction, 0, id, ops); + break; + } + case Type::kPipe: + typeInst = MakeUnique( + context(), SpvOpTypePipe, 0, id, + std::initializer_list{ + {SPV_OPERAND_TYPE_ACCESS_QUALIFIER, + {static_cast(type->AsPipe()->access_qualifier())}}}); + break; + case Type::kForwardPointer: + typeInst = MakeUnique( + context(), SpvOpTypeForwardPointer, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_ID, {type->AsForwardPointer()->target_id()}}, + {SPV_OPERAND_TYPE_STORAGE_CLASS, + {static_cast( + type->AsForwardPointer()->storage_class())}}}); + break; + default: + assert(false && "Unexpected type"); + break; + } + context()->AddType(std::move(typeInst)); + context()->AnalyzeDefUse(&*--context()->types_values_end()); + AttachDecorations(id, type); + return id; +} + +uint32_t TypeManager::FindPointerToType(uint32_t type_id, + SpvStorageClass storage_class) { + Type* pointeeTy = GetType(type_id); + Pointer pointerTy(pointeeTy, storage_class); + if (pointeeTy->IsUniqueType(true)) { + // Non-ambiguous type. 
Get the pointer type through the type manager. + return GetTypeInstruction(&pointerTy); + } + + // Ambiguous type, do a linear search. + Module::inst_iterator type_itr = context()->module()->types_values_begin(); + for (; type_itr != context()->module()->types_values_end(); ++type_itr) { + const Instruction* type_inst = &*type_itr; + if (type_inst->opcode() == SpvOpTypePointer && + type_inst->GetSingleWordOperand(kSpvTypePointerTypeIdInIdx) == + type_id && + type_inst->GetSingleWordOperand(kSpvTypePointerStorageClass) == + storage_class) + return type_inst->result_id(); + } + + // Must create the pointer type. + // TODO(1841): Handle id overflow. + uint32_t resultId = context()->TakeNextId(); + std::unique_ptr type_inst( + new Instruction(context(), SpvOpTypePointer, 0, resultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, + {uint32_t(storage_class)}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); + context()->AddType(std::move(type_inst)); + context()->get_type_mgr()->RegisterType(resultId, pointerTy); + return resultId; +} + +void TypeManager::AttachDecorations(uint32_t id, const Type* type) { + for (auto vec : type->decorations()) { + CreateDecoration(id, vec); + } + if (const Struct* structTy = type->AsStruct()) { + for (auto pair : structTy->element_decorations()) { + uint32_t element = pair.first; + for (auto vec : pair.second) { + CreateDecoration(id, vec, element); + } + } + } +} + +void TypeManager::CreateDecoration(uint32_t target, + const std::vector& decoration, + uint32_t element) { + std::vector ops; + ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {target})); + if (element != 0) { + ops.push_back(Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {element})); + } + ops.push_back(Operand(SPV_OPERAND_TYPE_DECORATION, {decoration[0]})); + for (size_t i = 1; i < decoration.size(); ++i) { + ops.push_back(Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration[i]})); + } + context()->AddAnnotationInst(MakeUnique( + context(), (element == 0 ? 
SpvOpDecorate : SpvOpMemberDecorate), 0, 0, + ops)); + Instruction* inst = &*--context()->annotation_end(); + context()->get_def_use_mgr()->AnalyzeInstUse(inst); +} + +Type* TypeManager::RebuildType(const Type& type) { + // The comparison and hash on the type pool will avoid inserting the rebuilt + // type if an equivalent type already exists. The rebuilt type will be deleted + // when it goes out of scope at the end of the function in that case. Repeated + // insertions of the same Type will, at most, keep one corresponding object in + // the type pool. + std::unique_ptr rebuilt_ty; + switch (type.kind()) { +#define DefineNoSubtypeCase(kind) \ + case Type::k##kind: \ + rebuilt_ty.reset(type.Clone().release()); \ + return type_pool_.insert(std::move(rebuilt_ty)).first->get(); + + DefineNoSubtypeCase(Void); + DefineNoSubtypeCase(Bool); + DefineNoSubtypeCase(Integer); + DefineNoSubtypeCase(Float); + DefineNoSubtypeCase(Sampler); + DefineNoSubtypeCase(Opaque); + DefineNoSubtypeCase(Event); + DefineNoSubtypeCase(DeviceEvent); + DefineNoSubtypeCase(ReserveId); + DefineNoSubtypeCase(Queue); + DefineNoSubtypeCase(Pipe); + DefineNoSubtypeCase(PipeStorage); + DefineNoSubtypeCase(NamedBarrier); + DefineNoSubtypeCase(AccelerationStructureNV); +#undef DefineNoSubtypeCase + case Type::kVector: { + const Vector* vec_ty = type.AsVector(); + const Type* ele_ty = vec_ty->element_type(); + rebuilt_ty = + MakeUnique(RebuildType(*ele_ty), vec_ty->element_count()); + break; + } + case Type::kMatrix: { + const Matrix* mat_ty = type.AsMatrix(); + const Type* ele_ty = mat_ty->element_type(); + rebuilt_ty = + MakeUnique(RebuildType(*ele_ty), mat_ty->element_count()); + break; + } + case Type::kImage: { + const Image* image_ty = type.AsImage(); + const Type* ele_ty = image_ty->sampled_type(); + rebuilt_ty = + MakeUnique(RebuildType(*ele_ty), image_ty->dim(), + image_ty->depth(), image_ty->is_arrayed(), + image_ty->is_multisampled(), image_ty->sampled(), + image_ty->format(), 
image_ty->access_qualifier()); + break; + } + case Type::kSampledImage: { + const SampledImage* image_ty = type.AsSampledImage(); + const Type* ele_ty = image_ty->image_type(); + rebuilt_ty = MakeUnique(RebuildType(*ele_ty)); + break; + } + case Type::kArray: { + const Array* array_ty = type.AsArray(); + const Type* ele_ty = array_ty->element_type(); + rebuilt_ty = + MakeUnique(RebuildType(*ele_ty), array_ty->LengthId()); + break; + } + case Type::kRuntimeArray: { + const RuntimeArray* array_ty = type.AsRuntimeArray(); + const Type* ele_ty = array_ty->element_type(); + rebuilt_ty = MakeUnique(RebuildType(*ele_ty)); + break; + } + case Type::kStruct: { + const Struct* struct_ty = type.AsStruct(); + std::vector subtypes; + subtypes.reserve(struct_ty->element_types().size()); + for (const auto* ele_ty : struct_ty->element_types()) { + subtypes.push_back(RebuildType(*ele_ty)); + } + rebuilt_ty = MakeUnique(subtypes); + Struct* rebuilt_struct = rebuilt_ty->AsStruct(); + for (auto pair : struct_ty->element_decorations()) { + uint32_t index = pair.first; + for (const auto& dec : pair.second) { + // Explicit copy intended. 
+ std::vector copy(dec); + rebuilt_struct->AddMemberDecoration(index, std::move(copy)); + } + } + break; + } + case Type::kPointer: { + const Pointer* pointer_ty = type.AsPointer(); + const Type* ele_ty = pointer_ty->pointee_type(); + rebuilt_ty = MakeUnique(RebuildType(*ele_ty), + pointer_ty->storage_class()); + break; + } + case Type::kFunction: { + const Function* function_ty = type.AsFunction(); + const Type* ret_ty = function_ty->return_type(); + std::vector param_types; + param_types.reserve(function_ty->param_types().size()); + for (const auto* param_ty : function_ty->param_types()) { + param_types.push_back(RebuildType(*param_ty)); + } + rebuilt_ty = MakeUnique(RebuildType(*ret_ty), param_types); + break; + } + case Type::kForwardPointer: { + const ForwardPointer* forward_ptr_ty = type.AsForwardPointer(); + rebuilt_ty = MakeUnique(forward_ptr_ty->target_id(), + forward_ptr_ty->storage_class()); + const Pointer* target_ptr = forward_ptr_ty->target_pointer(); + if (target_ptr) { + rebuilt_ty->AsForwardPointer()->SetTargetPointer( + RebuildType(*target_ptr)->AsPointer()); + } + break; + } + default: + assert(false && "Unhandled type"); + return nullptr; + } + for (const auto& dec : type.decorations()) { + // Explicit copy intended. + std::vector copy(dec); + rebuilt_ty->AddDecoration(std::move(copy)); + } + + return type_pool_.insert(std::move(rebuilt_ty)).first->get(); +} + +void TypeManager::RegisterType(uint32_t id, const Type& type) { + // Rebuild |type| so it and all its constituent types are owned by the type + // pool. 
+ Type* rebuilt = RebuildType(type); + assert(rebuilt->IsSame(&type)); + id_to_type_[id] = rebuilt; + if (GetId(rebuilt) == 0) { + type_to_id_[rebuilt] = id; + } +} + +Type* TypeManager::GetRegisteredType(const Type* type) { + uint32_t id = GetTypeInstruction(type); + return GetType(id); +} + +Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) { + if (!IsTypeInst(inst.opcode())) return nullptr; + + Type* type = nullptr; + switch (inst.opcode()) { + case SpvOpTypeVoid: + type = new Void(); + break; + case SpvOpTypeBool: + type = new Bool(); + break; + case SpvOpTypeInt: + type = new Integer(inst.GetSingleWordInOperand(0), + inst.GetSingleWordInOperand(1)); + break; + case SpvOpTypeFloat: + type = new Float(inst.GetSingleWordInOperand(0)); + break; + case SpvOpTypeVector: + type = new Vector(GetType(inst.GetSingleWordInOperand(0)), + inst.GetSingleWordInOperand(1)); + break; + case SpvOpTypeMatrix: + type = new Matrix(GetType(inst.GetSingleWordInOperand(0)), + inst.GetSingleWordInOperand(1)); + break; + case SpvOpTypeImage: { + const SpvAccessQualifier access = + inst.NumInOperands() < 8 + ? 
SpvAccessQualifierReadOnly + : static_cast(inst.GetSingleWordInOperand(7)); + type = new Image( + GetType(inst.GetSingleWordInOperand(0)), + static_cast(inst.GetSingleWordInOperand(1)), + inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(3) == 1, + inst.GetSingleWordInOperand(4) == 1, inst.GetSingleWordInOperand(5), + static_cast(inst.GetSingleWordInOperand(6)), access); + } break; + case SpvOpTypeSampler: + type = new Sampler(); + break; + case SpvOpTypeSampledImage: + type = new SampledImage(GetType(inst.GetSingleWordInOperand(0))); + break; + case SpvOpTypeArray: + type = new Array(GetType(inst.GetSingleWordInOperand(0)), + inst.GetSingleWordInOperand(1)); + if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } + break; + case SpvOpTypeRuntimeArray: + type = new RuntimeArray(GetType(inst.GetSingleWordInOperand(0))); + if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } + break; + case SpvOpTypeStruct: { + std::vector element_types; + bool incomplete_type = false; + for (uint32_t i = 0; i < inst.NumInOperands(); ++i) { + uint32_t type_id = inst.GetSingleWordInOperand(i); + element_types.push_back(GetType(type_id)); + if (id_to_incomplete_type_.count(type_id)) { + incomplete_type = true; + } + } + type = new Struct(element_types); + + if (incomplete_type) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } + } break; + case SpvOpTypeOpaque: { + const uint32_t* data = inst.GetInOperand(0).words.data(); + type = new Opaque(reinterpret_cast(data)); + } break; + case SpvOpTypePointer: { + uint32_t pointee_type_id = inst.GetSingleWordInOperand(1); + type = new Pointer( + GetType(pointee_type_id), + 
static_cast(inst.GetSingleWordInOperand(0))); + + if (id_to_incomplete_type_.count(pointee_type_id)) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } + id_to_incomplete_type_.erase(inst.result_id()); + + } break; + case SpvOpTypeFunction: { + bool incomplete_type = false; + uint32_t return_type_id = inst.GetSingleWordInOperand(0); + if (id_to_incomplete_type_.count(return_type_id)) { + incomplete_type = true; + } + Type* return_type = GetType(return_type_id); + std::vector param_types; + for (uint32_t i = 1; i < inst.NumInOperands(); ++i) { + uint32_t param_type_id = inst.GetSingleWordInOperand(i); + param_types.push_back(GetType(param_type_id)); + if (id_to_incomplete_type_.count(param_type_id)) { + incomplete_type = true; + } + } + + type = new Function(return_type, param_types); + + if (incomplete_type) { + incomplete_types_.emplace_back(inst.result_id(), type); + id_to_incomplete_type_[inst.result_id()] = type; + return type; + } + } break; + case SpvOpTypeEvent: + type = new Event(); + break; + case SpvOpTypeDeviceEvent: + type = new DeviceEvent(); + break; + case SpvOpTypeReserveId: + type = new ReserveId(); + break; + case SpvOpTypeQueue: + type = new Queue(); + break; + case SpvOpTypePipe: + type = new Pipe( + static_cast(inst.GetSingleWordInOperand(0))); + break; + case SpvOpTypeForwardPointer: { + // Handling of forward pointers is different from the other types. 
+ uint32_t target_id = inst.GetSingleWordInOperand(0); + type = new ForwardPointer(target_id, static_cast( + inst.GetSingleWordInOperand(1))); + incomplete_types_.emplace_back(target_id, type); + id_to_incomplete_type_[target_id] = type; + return type; + } + case SpvOpTypePipeStorage: + type = new PipeStorage(); + break; + case SpvOpTypeNamedBarrier: + type = new NamedBarrier(); + break; + case SpvOpTypeAccelerationStructureNV: + type = new AccelerationStructureNV(); + break; + default: + SPIRV_UNIMPLEMENTED(consumer_, "unhandled type"); + break; + } + + uint32_t id = inst.result_id(); + SPIRV_ASSERT(consumer_, id != 0, "instruction without result id found"); + SPIRV_ASSERT(consumer_, type != nullptr, + "type should not be nullptr at this point"); + std::vector decorations = + context()->get_decoration_mgr()->GetDecorationsFor(id, true); + for (auto dec : decorations) { + AttachDecoration(*dec, type); + } + std::unique_ptr unique(type); + auto pair = type_pool_.insert(std::move(unique)); + id_to_type_[id] = pair.first->get(); + type_to_id_[pair.first->get()] = id; + return type; +} + +void TypeManager::AttachDecoration(const Instruction& inst, Type* type) { + const SpvOp opcode = inst.opcode(); + if (!IsAnnotationInst(opcode)) return; + + switch (opcode) { + case SpvOpDecorate: { + const auto count = inst.NumOperands(); + std::vector data; + for (uint32_t i = 1; i < count; ++i) { + data.push_back(inst.GetSingleWordOperand(i)); + } + type->AddDecoration(std::move(data)); + } break; + case SpvOpMemberDecorate: { + const auto count = inst.NumOperands(); + const uint32_t index = inst.GetSingleWordOperand(1); + std::vector data; + for (uint32_t i = 2; i < count; ++i) { + data.push_back(inst.GetSingleWordOperand(i)); + } + if (Struct* st = type->AsStruct()) { + st->AddMemberDecoration(index, std::move(data)); + } else { + SPIRV_UNIMPLEMENTED(consumer_, "OpMemberDecorate non-struct type"); + } + } break; + default: + SPIRV_UNREACHABLE(consumer_); + break; + } +} + +const 
Type* TypeManager::GetMemberType( + const Type* parent_type, const std::vector& access_chain) { + for (uint32_t element_index : access_chain) { + if (const Struct* struct_type = parent_type->AsStruct()) { + parent_type = struct_type->element_types()[element_index]; + } else if (const Array* array_type = parent_type->AsArray()) { + parent_type = array_type->element_type(); + } else if (const RuntimeArray* runtime_array_type = + parent_type->AsRuntimeArray()) { + parent_type = runtime_array_type->element_type(); + } else if (const Vector* vector_type = parent_type->AsVector()) { + parent_type = vector_type->element_type(); + } else if (const Matrix* matrix_type = parent_type->AsMatrix()) { + parent_type = matrix_type->element_type(); + } else { + assert(false && "Trying to get a member of a type without members."); + } + } + return parent_type; +} + +void TypeManager::ReplaceForwardPointers(Type* type) { + switch (type->kind()) { + case Type::kArray: { + const ForwardPointer* element_type = + type->AsArray()->element_type()->AsForwardPointer(); + if (element_type) { + type->AsArray()->ReplaceElementType(element_type->target_pointer()); + } + } break; + case Type::kRuntimeArray: { + const ForwardPointer* element_type = + type->AsRuntimeArray()->element_type()->AsForwardPointer(); + if (element_type) { + type->AsRuntimeArray()->ReplaceElementType( + element_type->target_pointer()); + } + } break; + case Type::kStruct: { + auto& member_types = type->AsStruct()->element_types(); + for (auto& member_type : member_types) { + if (member_type->AsForwardPointer()) { + member_type = member_type->AsForwardPointer()->target_pointer(); + assert(member_type); + } + } + } break; + case Type::kPointer: { + const ForwardPointer* pointee_type = + type->AsPointer()->pointee_type()->AsForwardPointer(); + if (pointee_type) { + type->AsPointer()->SetPointeeType(pointee_type->target_pointer()); + } + } break; + case Type::kFunction: { + Function* func_type = type->AsFunction(); + const 
ForwardPointer* return_type = + func_type->return_type()->AsForwardPointer(); + if (return_type) { + func_type->SetReturnType(return_type->target_pointer()); + } + + auto& param_types = func_type->param_types(); + for (auto& param_type : param_types) { + if (param_type->AsForwardPointer()) { + param_type = param_type->AsForwardPointer()->target_pointer(); + } + } + } break; + default: + break; + } +} + +void TypeManager::ReplaceType(Type* new_type, Type* original_type) { + assert(original_type->kind() == new_type->kind() && + "Types must be the same for replacement.\n"); + for (auto& p : incomplete_types_) { + Type* type = p.type(); + if (!type) { + continue; + } + + switch (type->kind()) { + case Type::kArray: { + const Type* element_type = type->AsArray()->element_type(); + if (element_type == original_type) { + type->AsArray()->ReplaceElementType(new_type); + } + } break; + case Type::kRuntimeArray: { + const Type* element_type = type->AsRuntimeArray()->element_type(); + if (element_type == original_type) { + type->AsRuntimeArray()->ReplaceElementType(new_type); + } + } break; + case Type::kStruct: { + auto& member_types = type->AsStruct()->element_types(); + for (auto& member_type : member_types) { + if (member_type == original_type) { + member_type = new_type; + } + } + } break; + case Type::kPointer: { + const Type* pointee_type = type->AsPointer()->pointee_type(); + if (pointee_type == original_type) { + type->AsPointer()->SetPointeeType(new_type); + } + } break; + case Type::kFunction: { + Function* func_type = type->AsFunction(); + const Type* return_type = func_type->return_type(); + if (return_type == original_type) { + func_type->SetReturnType(new_type); + } + + auto& param_types = func_type->param_types(); + for (auto& param_type : param_types) { + if (param_type == original_type) { + param_type = new_type; + } + } + } break; + default: + break; + } + } +} + +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git 
a/source/opt/type_manager.h b/source/opt/type_manager.h new file mode 100644 index 000000000..c44969e84 --- /dev/null +++ b/source/opt/type_manager.h @@ -0,0 +1,218 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_TYPE_MANAGER_H_ +#define SOURCE_OPT_TYPE_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "source/opt/module.h" +#include "source/opt/types.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { + +class IRContext; + +namespace analysis { + +// Hashing functor. +// +// All type pointers must be non-null. +struct HashTypePointer { + size_t operator()(const Type* type) const { + assert(type); + return type->HashValue(); + } +}; +struct HashTypeUniquePointer { + size_t operator()(const std::unique_ptr& type) const { + assert(type); + return type->HashValue(); + } +}; + +// Equality functor. +// +// Checks if two types pointers are the same type. +// +// All type pointers must be non-null. +struct CompareTypePointers { + bool operator()(const Type* lhs, const Type* rhs) const { + assert(lhs && rhs); + return lhs->IsSame(rhs); + } +}; +struct CompareTypeUniquePointers { + bool operator()(const std::unique_ptr& lhs, + const std::unique_ptr& rhs) const { + assert(lhs && rhs); + return lhs->IsSame(rhs.get()); + } +}; + +// A class for managing the SPIR-V type hierarchy. 
+class TypeManager { + public: + using IdToTypeMap = std::unordered_map; + + // Constructs a type manager for the given context |c|. All internal messages + // will be communicated to the outside via the given message |consumer|. + // This instance only keeps a reference to the |consumer|, so the |consumer| + // should outlive this instance. + TypeManager(const MessageConsumer& consumer, IRContext* c); + + TypeManager(const TypeManager&) = delete; + TypeManager(TypeManager&&) = delete; + TypeManager& operator=(const TypeManager&) = delete; + TypeManager& operator=(TypeManager&&) = delete; + + // Returns the type for the given type |id|. Returns nullptr if the given |id| + // does not define a type. + Type* GetType(uint32_t id) const; + // Returns the id for the given |type|. Returns 0 if the given |type| cannot + // be found. + uint32_t GetId(const Type* type) const; + // Returns the number of types held in this manager. + size_t NumTypes() const { return id_to_type_.size(); } + // Iterators for all types contained in this manager. + IdToTypeMap::const_iterator begin() const { return id_to_type_.cbegin(); } + IdToTypeMap::const_iterator end() const { return id_to_type_.cend(); } + + // Returns a pair of the type and pointer to the type in |sc|. + // + // |id| must be a registered type. + std::pair> GetTypeAndPointerType( + uint32_t id, SpvStorageClass sc) const; + + // Returns an id for a declaration representing |type|. + // + // If |type| is registered, then the registered id is returned. Otherwise, + // this function recursively adds type and annotation instructions as + // necessary to fully define |type|. + uint32_t GetTypeInstruction(const Type* type); + + // Find pointer to type and storage in module, return its resultId. If it is + // not found, a new type is created, and its id is returned. + uint32_t FindPointerToType(uint32_t type_id, SpvStorageClass storage_class); + + // Registers |id| to |type|.
+ // + // If GetId(|type|) already returns a non-zero id, that mapping will be + // unchanged. + void RegisterType(uint32_t id, const Type& type); + // Returns GetType(GetTypeInstruction(|type|)); may thus add instructions. + Type* GetRegisteredType(const Type* type); + + // Removes knowledge of |id| from the manager. + // + // If |id| is an ambiguous type then multiple ids may be registered to |id|'s + // type (e.g. %struct1 and %struct2 might hash to the same type). In that + // case, calling GetId() with |id|'s type will return another suitable id + // defining that type. + void RemoveId(uint32_t id); + + // Returns the type of the member of |parent_type| that is identified by + // |access_chain|. The vector |access_chain| is a series of integers that are + // used to pick members as in the |OpCompositeExtract| instructions. If you + // want a member of an array, vector, or matrix that does not have a constant + // index, you can use 0 in that position. All elements have the same type. + const Type* GetMemberType(const Type* parent_type, + const std::vector& access_chain); + + private: + using TypeToIdMap = std::unordered_map; + using TypePool = + std::unordered_set, HashTypeUniquePointer, + CompareTypeUniquePointers>; + + class UnresolvedType { + public: + UnresolvedType(uint32_t i, Type* t) : id_(i), type_(t) {} + UnresolvedType(const UnresolvedType&) = delete; + UnresolvedType(UnresolvedType&& that) + : id_(that.id_), type_(std::move(that.type_)) {} + + uint32_t id() { return id_; } + Type* type() { return type_.get(); } + std::unique_ptr&& ReleaseType() { return std::move(type_); } + void ResetType(Type* t) { type_.reset(t); } + + private: + uint32_t id_; + std::unique_ptr type_; + }; + using IdToUnresolvedType = std::vector; + + // Analyzes the types and decorations on types in the given |module|. + void AnalyzeTypes(const Module& module); + + IRContext* context() { return context_; } + + // Attaches the decorations on |type| to |id|. + void AttachDecorations(uint32_t id, const Type* type); + + // Create the annotation instruction.
+ // + // If |element| is zero, an OpDecorate is created, otherwise an + // OpMemberDecorate. NOTE(review): member 0 thus also yields an OpDecorate. + // The annotation is registered with the DefUseManager and DecorationManager. + void CreateDecoration(uint32_t id, const std::vector& decoration, + uint32_t element = 0); + + // Creates and returns a type from the given SPIR-V |inst|. Returns nullptr if + // the given instruction is not for defining a type. + Type* RecordIfTypeDefinition(const Instruction& inst); + // Attaches the decoration encoded in |inst| to |type|. Does nothing if the + // given instruction is not a decoration instruction. Assumes the target is + // |type| (e.g. should be called in a loop over |type|'s decorations). + void AttachDecoration(const Instruction& inst, Type* type); + + // Returns an equivalent pointer to |type| built in terms of pointers owned by + // |type_pool_|. For example, if |type| is a vec3 of bool, it will be rebuilt + // replacing the bool subtype with one owned by |type_pool_|. + Type* RebuildType(const Type& type); + + // Completes the incomplete type |type|, by replacing all references to + // ForwardPointer with the defining Pointer. + void ReplaceForwardPointers(Type* type); + + // Replaces all references to |original_type| in |incomplete_types_| by + // |new_type|. + void ReplaceType(Type* new_type, Type* original_type); + + const MessageConsumer& consumer_; // Message consumer. + IRContext* context_; + IdToTypeMap id_to_type_; // Mapping from ids to their type representations. + TypeToIdMap type_to_id_; // Mapping from types to their defining ids. + TypePool type_pool_; // Memory owner of type pointers. + IdToUnresolvedType incomplete_types_; // All incomplete types. Stored in an + // std::vector to make traversals + // deterministic. + + IdToTypeMap id_to_incomplete_type_; // Maps ids to their type representations + // for incomplete types.
+}; + +} // namespace analysis +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_TYPE_MANAGER_H_ diff --git a/source/opt/types.cpp b/source/opt/types.cpp new file mode 100644 index 000000000..cfafc7dce --- /dev/null +++ b/source/opt/types.cpp @@ -0,0 +1,637 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "source/opt/types.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { +namespace analysis { + +using U32VecVec = std::vector>; + +namespace { + +// Returns true if the two vectors of vectors are identical.
+bool CompareTwoVectors(const U32VecVec& a, const U32VecVec& b) { + const auto size = a.size(); + if (size != b.size()) return false; + + if (size == 0) return true; + if (size == 1) return a.front() == b.front(); + + std::vector*> a_ptrs, b_ptrs; // Sorted views; |a| and |b| stay untouched. + a_ptrs.reserve(size); + b_ptrs.reserve(size); + for (uint32_t i = 0; i < size; ++i) { + a_ptrs.push_back(&a[i]); + b_ptrs.push_back(&b[i]); + } + + const auto cmp = [](const std::vector* m, + const std::vector* n) { + return *m < *n; // Whole-vector order, so equal first words still align. + }; + + std::sort(a_ptrs.begin(), a_ptrs.end(), cmp); + std::sort(b_ptrs.begin(), b_ptrs.end(), cmp); + + for (uint32_t i = 0; i < size; ++i) { + if (*a_ptrs[i] != *b_ptrs[i]) return false; + } + return true; +} + +} // anonymous namespace + +std::string Type::GetDecorationStr() const { + std::ostringstream oss; + oss << "[["; + for (const auto& decoration : decorations_) { + oss << "("; + for (size_t i = 0; i < decoration.size(); ++i) { + oss << (i > 0 ? ", " : ""); + oss << decoration.at(i); + } + oss << ")"; + } + oss << "]]"; + return oss.str(); +} + +bool Type::HasSameDecorations(const Type* that) const { + return CompareTwoVectors(decorations_, that->decorations_); +} + +bool Type::IsUniqueType(bool allowVariablePointers) const { + switch (kind_) { + case kPointer: + return !allowVariablePointers; + case kStruct: + case kArray: + case kRuntimeArray: + return false; + default: + return true; + } +} + +std::unique_ptr Type::Clone() const { + std::unique_ptr type; + switch (kind_) { +#define DeclareKindCase(kind) \ + case k##kind: \ + type = MakeUnique(*this->As##kind()); \ + break; + DeclareKindCase(Void); + DeclareKindCase(Bool); + DeclareKindCase(Integer); + DeclareKindCase(Float); + DeclareKindCase(Vector); + DeclareKindCase(Matrix); + DeclareKindCase(Image); + DeclareKindCase(Sampler); + DeclareKindCase(SampledImage); + DeclareKindCase(Array); + DeclareKindCase(RuntimeArray); + DeclareKindCase(Struct); + DeclareKindCase(Opaque); + DeclareKindCase(Pointer); +
DeclareKindCase(Function); + DeclareKindCase(Event); + DeclareKindCase(DeviceEvent); + DeclareKindCase(ReserveId); + DeclareKindCase(Queue); + DeclareKindCase(Pipe); + DeclareKindCase(ForwardPointer); + DeclareKindCase(PipeStorage); + DeclareKindCase(NamedBarrier); + DeclareKindCase(AccelerationStructureNV); +#undef DeclareKindCase + default: + assert(false && "Unhandled type"); + } + return type; +} + +std::unique_ptr Type::RemoveDecorations() const { + std::unique_ptr type(Clone()); + type->ClearDecorations(); + return type; +} + +bool Type::operator==(const Type& other) const { + if (kind_ != other.kind_) return false; + + switch (kind_) { +#define DeclareKindCase(kind) \ + case k##kind: \ + return As##kind()->IsSame(&other); + DeclareKindCase(Void); + DeclareKindCase(Bool); + DeclareKindCase(Integer); + DeclareKindCase(Float); + DeclareKindCase(Vector); + DeclareKindCase(Matrix); + DeclareKindCase(Image); + DeclareKindCase(Sampler); + DeclareKindCase(SampledImage); + DeclareKindCase(Array); + DeclareKindCase(RuntimeArray); + DeclareKindCase(Struct); + DeclareKindCase(Opaque); + DeclareKindCase(Pointer); + DeclareKindCase(Function); + DeclareKindCase(Event); + DeclareKindCase(DeviceEvent); + DeclareKindCase(ReserveId); + DeclareKindCase(Queue); + DeclareKindCase(Pipe); + DeclareKindCase(ForwardPointer); + DeclareKindCase(PipeStorage); + DeclareKindCase(NamedBarrier); + DeclareKindCase(AccelerationStructureNV); +#undef DeclareKindCase + default: + assert(false && "Unhandled type"); + return false; + } +} + +void Type::GetHashWords(std::vector* words, + std::unordered_set* seen) const { + if (!seen->insert(this).second) { + return; + } + + words->push_back(kind_); + for (const auto& d : decorations_) { + for (auto w : d) { + words->push_back(w); + } + } + + switch (kind_) { +#define DeclareKindCase(type) \ + case k##type: \ + As##type()->GetExtraHashWords(words, seen); \ + break; + DeclareKindCase(Void); + DeclareKindCase(Bool); + DeclareKindCase(Integer); + 
DeclareKindCase(Float); + DeclareKindCase(Vector); + DeclareKindCase(Matrix); + DeclareKindCase(Image); + DeclareKindCase(Sampler); + DeclareKindCase(SampledImage); + DeclareKindCase(Array); + DeclareKindCase(RuntimeArray); + DeclareKindCase(Struct); + DeclareKindCase(Opaque); + DeclareKindCase(Pointer); + DeclareKindCase(Function); + DeclareKindCase(Event); + DeclareKindCase(DeviceEvent); + DeclareKindCase(ReserveId); + DeclareKindCase(Queue); + DeclareKindCase(Pipe); + DeclareKindCase(ForwardPointer); + DeclareKindCase(PipeStorage); + DeclareKindCase(NamedBarrier); + DeclareKindCase(AccelerationStructureNV); +#undef DeclareKindCase + default: + assert(false && "Unhandled type"); + break; + } + + seen->erase(this); +} + +size_t Type::HashValue() const { + std::u32string h; + std::vector words; + GetHashWords(&words); + for (auto w : words) { + h.push_back(w); + } + + return std::hash()(h); +} + +bool Integer::IsSameImpl(const Type* that, IsSameCache*) const { + const Integer* it = that->AsInteger(); + return it && width_ == it->width_ && signed_ == it->signed_ && + HasSameDecorations(that); +} + +std::string Integer::str() const { + std::ostringstream oss; + oss << (signed_ ? 
"s" : "u") << "int" << width_; + return oss.str(); +} + +void Integer::GetExtraHashWords(std::vector* words, + std::unordered_set*) const { + words->push_back(width_); + words->push_back(signed_); +} + +bool Float::IsSameImpl(const Type* that, IsSameCache*) const { + const Float* ft = that->AsFloat(); + return ft && width_ == ft->width_ && HasSameDecorations(that); +} + +std::string Float::str() const { + std::ostringstream oss; + oss << "float" << width_; + return oss.str(); +} + +void Float::GetExtraHashWords(std::vector* words, + std::unordered_set*) const { + words->push_back(width_); +} + +Vector::Vector(Type* type, uint32_t count) + : Type(kVector), element_type_(type), count_(count) { + assert(type->AsBool() || type->AsInteger() || type->AsFloat()); +} + +bool Vector::IsSameImpl(const Type* that, IsSameCache* seen) const { + const Vector* vt = that->AsVector(); + if (!vt) return false; + return count_ == vt->count_ && + element_type_->IsSameImpl(vt->element_type_, seen) && + HasSameDecorations(that); +} + +std::string Vector::str() const { + std::ostringstream oss; + oss << "<" << element_type_->str() << ", " << count_ << ">"; + return oss.str(); +} + +void Vector::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + element_type_->GetHashWords(words, seen); + words->push_back(count_); +} + +Matrix::Matrix(Type* type, uint32_t count) + : Type(kMatrix), element_type_(type), count_(count) { + assert(type->AsVector()); +} + +bool Matrix::IsSameImpl(const Type* that, IsSameCache* seen) const { + const Matrix* mt = that->AsMatrix(); + if (!mt) return false; + return count_ == mt->count_ && + element_type_->IsSameImpl(mt->element_type_, seen) && + HasSameDecorations(that); +} + +std::string Matrix::str() const { + std::ostringstream oss; + oss << "<" << element_type_->str() << ", " << count_ << ">"; + return oss.str(); +} + +void Matrix::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + 
element_type_->GetHashWords(words, seen); + words->push_back(count_); +} + +Image::Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample, + uint32_t sampling, SpvImageFormat f, SpvAccessQualifier qualifier) + : Type(kImage), + sampled_type_(type), + dim_(dimen), + depth_(d), + arrayed_(array), + ms_(multisample), + sampled_(sampling), + format_(f), + access_qualifier_(qualifier) { + // TODO(antiagainst): check sampled_type +} + +bool Image::IsSameImpl(const Type* that, IsSameCache* seen) const { + const Image* it = that->AsImage(); + if (!it) return false; + return dim_ == it->dim_ && depth_ == it->depth_ && arrayed_ == it->arrayed_ && + ms_ == it->ms_ && sampled_ == it->sampled_ && format_ == it->format_ && + access_qualifier_ == it->access_qualifier_ && + sampled_type_->IsSameImpl(it->sampled_type_, seen) && + HasSameDecorations(that); +} + +std::string Image::str() const { + std::ostringstream oss; + oss << "image(" << sampled_type_->str() << ", " << dim_ << ", " << depth_ + << ", " << arrayed_ << ", " << ms_ << ", " << sampled_ << ", " << format_ + << ", " << access_qualifier_ << ")"; + return oss.str(); +} + +void Image::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + sampled_type_->GetHashWords(words, seen); + words->push_back(dim_); + words->push_back(depth_); + words->push_back(arrayed_); + words->push_back(ms_); + words->push_back(sampled_); + words->push_back(format_); + words->push_back(access_qualifier_); +} + +bool SampledImage::IsSameImpl(const Type* that, IsSameCache* seen) const { + const SampledImage* sit = that->AsSampledImage(); + if (!sit) return false; + return image_type_->IsSameImpl(sit->image_type_, seen) && + HasSameDecorations(that); +} + +std::string SampledImage::str() const { + std::ostringstream oss; + oss << "sampled_image(" << image_type_->str() << ")"; + return oss.str(); +} + +void SampledImage::GetExtraHashWords( + std::vector* words, std::unordered_set* seen) const { + 
image_type_->GetHashWords(words, seen); +} + +Array::Array(Type* type, uint32_t length_id) + : Type(kArray), element_type_(type), length_id_(length_id) { + assert(!type->AsVoid()); +} + +bool Array::IsSameImpl(const Type* that, IsSameCache* seen) const { + const Array* at = that->AsArray(); + if (!at) return false; + return length_id_ == at->length_id_ && + element_type_->IsSameImpl(at->element_type_, seen) && + HasSameDecorations(that); +} + +std::string Array::str() const { + std::ostringstream oss; + oss << "[" << element_type_->str() << ", id(" << length_id_ << ")]"; + return oss.str(); +} + +void Array::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + element_type_->GetHashWords(words, seen); + words->push_back(length_id_); +} + +void Array::ReplaceElementType(const Type* type) { element_type_ = type; } + +RuntimeArray::RuntimeArray(Type* type) + : Type(kRuntimeArray), element_type_(type) { + assert(!type->AsVoid()); +} + +bool RuntimeArray::IsSameImpl(const Type* that, IsSameCache* seen) const { + const RuntimeArray* rat = that->AsRuntimeArray(); + if (!rat) return false; + return element_type_->IsSameImpl(rat->element_type_, seen) && + HasSameDecorations(that); +} + +std::string RuntimeArray::str() const { + std::ostringstream oss; + oss << "[" << element_type_->str() << "]"; + return oss.str(); +} + +void RuntimeArray::GetExtraHashWords( + std::vector* words, std::unordered_set* seen) const { + element_type_->GetHashWords(words, seen); +} + +void RuntimeArray::ReplaceElementType(const Type* type) { + element_type_ = type; +} + +Struct::Struct(const std::vector& types) + : Type(kStruct), element_types_(types) { + for (const auto* t : types) { + (void)t; + assert(!t->AsVoid()); + } +} + +void Struct::AddMemberDecoration(uint32_t index, + std::vector&& decoration) { + if (index >= element_types_.size()) { + assert(0 && "index out of bound"); + return; + } + + element_decorations_[index].push_back(std::move(decoration)); +} + +bool 
Struct::IsSameImpl(const Type* that, IsSameCache* seen) const { + const Struct* st = that->AsStruct(); + if (!st) return false; + if (element_types_.size() != st->element_types_.size()) return false; + const auto size = element_decorations_.size(); + if (size != st->element_decorations_.size()) return false; + if (!HasSameDecorations(that)) return false; + + for (size_t i = 0; i < element_types_.size(); ++i) { + if (!element_types_[i]->IsSameImpl(st->element_types_[i], seen)) + return false; + } + for (const auto& p : element_decorations_) { + if (st->element_decorations_.count(p.first) == 0) return false; + if (!CompareTwoVectors(p.second, st->element_decorations_.at(p.first))) + return false; + } + return true; +} + +std::string Struct::str() const { + std::ostringstream oss; + oss << "{"; + const size_t count = element_types_.size(); + for (size_t i = 0; i < count; ++i) { + oss << element_types_[i]->str(); + if (i + 1 != count) oss << ", "; + } + oss << "}"; + return oss.str(); +} + +void Struct::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + for (auto* t : element_types_) { + t->GetHashWords(words, seen); + } + for (const auto& pair : element_decorations_) { + words->push_back(pair.first); + for (const auto& d : pair.second) { + for (auto w : d) { + words->push_back(w); + } + } + } +} + +bool Opaque::IsSameImpl(const Type* that, IsSameCache*) const { + const Opaque* ot = that->AsOpaque(); + if (!ot) return false; + return name_ == ot->name_ && HasSameDecorations(that); +} + +std::string Opaque::str() const { + std::ostringstream oss; + oss << "opaque('" << name_ << "')"; + return oss.str(); +} + +void Opaque::GetExtraHashWords(std::vector* words, + std::unordered_set*) const { + for (auto c : name_) { + words->push_back(static_cast(c)); + } +} + +Pointer::Pointer(const Type* type, SpvStorageClass sc) + : Type(kPointer), pointee_type_(type), storage_class_(sc) {} + +bool Pointer::IsSameImpl(const Type* that, IsSameCache* seen) const 
{ + const Pointer* pt = that->AsPointer(); + if (!pt) return false; + if (storage_class_ != pt->storage_class_) return false; + auto p = seen->insert(std::make_pair(this, that->AsPointer())); + if (!p.second) { + return true; + } + bool same_pointee = pointee_type_->IsSameImpl(pt->pointee_type_, seen); + seen->erase(p.first); + if (!same_pointee) { + return false; + } + return HasSameDecorations(that); +} + +std::string Pointer::str() const { return pointee_type_->str() + "*"; } + +void Pointer::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + pointee_type_->GetHashWords(words, seen); + words->push_back(storage_class_); +} + +void Pointer::SetPointeeType(const Type* type) { pointee_type_ = type; } + +Function::Function(Type* ret_type, const std::vector& params) + : Type(kFunction), return_type_(ret_type), param_types_(params) {} + +Function::Function(Type* ret_type, std::vector& params) + : Type(kFunction), return_type_(ret_type), param_types_(params) {} + +bool Function::IsSameImpl(const Type* that, IsSameCache* seen) const { + const Function* ft = that->AsFunction(); + if (!ft) return false; + if (!return_type_->IsSameImpl(ft->return_type_, seen)) return false; + if (param_types_.size() != ft->param_types_.size()) return false; + for (size_t i = 0; i < param_types_.size(); ++i) { + if (!param_types_[i]->IsSameImpl(ft->param_types_[i], seen)) return false; + } + return HasSameDecorations(that); +} + +std::string Function::str() const { + std::ostringstream oss; + const size_t count = param_types_.size(); + oss << "("; + for (size_t i = 0; i < count; ++i) { + oss << param_types_[i]->str(); + if (i + 1 != count) oss << ", "; + } + oss << ") -> " << return_type_->str(); + return oss.str(); +} + +void Function::GetExtraHashWords(std::vector* words, + std::unordered_set* seen) const { + return_type_->GetHashWords(words, seen); + for (const auto* t : param_types_) { + t->GetHashWords(words, seen); + } +} + +void Function::SetReturnType(const 
Type* type) { return_type_ = type; } + +bool Pipe::IsSameImpl(const Type* that, IsSameCache*) const { + const Pipe* pt = that->AsPipe(); + if (!pt) return false; + return access_qualifier_ == pt->access_qualifier_ && HasSameDecorations(that); +} + +std::string Pipe::str() const { + std::ostringstream oss; + oss << "pipe(" << access_qualifier_ << ")"; + return oss.str(); +} + +void Pipe::GetExtraHashWords(std::vector* words, + std::unordered_set*) const { + words->push_back(access_qualifier_); +} + +bool ForwardPointer::IsSameImpl(const Type* that, IsSameCache*) const { + const ForwardPointer* fpt = that->AsForwardPointer(); + if (!fpt) return false; + return target_id_ == fpt->target_id_ && + storage_class_ == fpt->storage_class_ && HasSameDecorations(that); +} + +std::string ForwardPointer::str() const { + std::ostringstream oss; + oss << "forward_pointer("; + if (pointer_ != nullptr) { + oss << pointer_->str(); + } else { + oss << target_id_; + } + oss << ")"; + return oss.str(); +} + +void ForwardPointer::GetExtraHashWords( + std::vector* words, std::unordered_set* seen) const { + words->push_back(target_id_); + words->push_back(storage_class_); + if (pointer_) pointer_->GetHashWords(words, seen); +} + +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/source/opt/types.h b/source/opt/types.h new file mode 100644 index 000000000..fe0f39af3 --- /dev/null +++ b/source/opt/types.h @@ -0,0 +1,608 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// This file provides a class hierarchy for representing SPIR-V types. + +#ifndef SOURCE_OPT_TYPES_H_ +#define SOURCE_OPT_TYPES_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/latest_version_spirv_header.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace opt { +namespace analysis { + +class Void; +class Bool; +class Integer; +class Float; +class Vector; +class Matrix; +class Image; +class Sampler; +class SampledImage; +class Array; +class RuntimeArray; +class Struct; +class Opaque; +class Pointer; +class Function; +class Event; +class DeviceEvent; +class ReserveId; +class Queue; +class Pipe; +class ForwardPointer; +class PipeStorage; +class NamedBarrier; +class AccelerationStructureNV; + +// Abstract class for a SPIR-V type. It has a bunch of As() methods, +// which is used as a way to probe the actual . +class Type { + public: + typedef std::set> IsSameCache; + + // Available subtypes. + // + // When adding a new derived class of Type, please add an entry to the enum. + enum Kind { + kVoid, + kBool, + kInteger, + kFloat, + kVector, + kMatrix, + kImage, + kSampler, + kSampledImage, + kArray, + kRuntimeArray, + kStruct, + kOpaque, + kPointer, + kFunction, + kEvent, + kDeviceEvent, + kReserveId, + kQueue, + kPipe, + kForwardPointer, + kPipeStorage, + kNamedBarrier, + kAccelerationStructureNV, + }; + + Type(Kind k) : kind_(k) {} + + virtual ~Type() {} + + // Attaches a decoration directly on this type. + void AddDecoration(std::vector&& d) { + decorations_.push_back(std::move(d)); + } + // Returns the decorations on this type as a string. + std::string GetDecorationStr() const; + // Returns true if this type has exactly the same decorations as |that| type. 
+ bool HasSameDecorations(const Type* that) const; + // Returns true if this type is exactly the same as |that| type, including + // decorations. + bool IsSame(const Type* that) const { + IsSameCache seen; + return IsSameImpl(that, &seen); + } + + // Returns true if this type is exactly the same as |that| type, including + // decorations. |seen| is the set of |Pointer*| pair that are currently being + // compared in a parent call to |IsSameImpl|. + virtual bool IsSameImpl(const Type* that, IsSameCache* seen) const = 0; + + // Returns a human-readable string to represent this type. + virtual std::string str() const = 0; + + Kind kind() const { return kind_; } + const std::vector>& decorations() const { + return decorations_; + } + + // Returns true if there is no decoration on this type. For struct types, + // returns true only when there is no decoration for both the struct type + // and the struct members. + virtual bool decoration_empty() const { return decorations_.empty(); } + + // Creates a clone of |this|. + std::unique_ptr Clone() const; + + // Returns a clone of |this| minus any decorations. + std::unique_ptr RemoveDecorations() const; + + // Returns true if this type must be unique. + // + // If variable pointers are allowed, then pointers are not required to be + // unique. + // TODO(alanbaker): Update this if variable pointers become a core feature. + bool IsUniqueType(bool allowVariablePointers = false) const; + +// A bunch of methods for casting this type to a given type. Returns this if the +// cast can be done, nullptr otherwise. 
+#define DeclareCastMethod(target) \ + virtual target* As##target() { return nullptr; } \ + virtual const target* As##target() const { return nullptr; } + DeclareCastMethod(Void); + DeclareCastMethod(Bool); + DeclareCastMethod(Integer); + DeclareCastMethod(Float); + DeclareCastMethod(Vector); + DeclareCastMethod(Matrix); + DeclareCastMethod(Image); + DeclareCastMethod(Sampler); + DeclareCastMethod(SampledImage); + DeclareCastMethod(Array); + DeclareCastMethod(RuntimeArray); + DeclareCastMethod(Struct); + DeclareCastMethod(Opaque); + DeclareCastMethod(Pointer); + DeclareCastMethod(Function); + DeclareCastMethod(Event); + DeclareCastMethod(DeviceEvent); + DeclareCastMethod(ReserveId); + DeclareCastMethod(Queue); + DeclareCastMethod(Pipe); + DeclareCastMethod(ForwardPointer); + DeclareCastMethod(PipeStorage); + DeclareCastMethod(NamedBarrier); + DeclareCastMethod(AccelerationStructureNV); +#undef DeclareCastMethod + + bool operator==(const Type& other) const; + + // Returns the hash value of this type. + size_t HashValue() const; + + // Adds the necessary words to compute a hash value of this type to |words|. + void GetHashWords(std::vector* words) const { + std::unordered_set seen; + GetHashWords(words, &seen); + } + + // Adds the necessary words to compute a hash value of this type to |words|. + void GetHashWords(std::vector* words, + std::unordered_set* seen) const; + + // Adds necessary extra words for a subtype to calculate a hash value into + // |words|. + virtual void GetExtraHashWords( + std::vector* words, + std::unordered_set* pSet) const = 0; + + protected: + // Decorations attached to this type. Each decoration is encoded as a vector + // of uint32_t numbers. The first uint32_t number is the decoration value, + // and the rest are the parameters to the decoration (if exists). + std::vector> decorations_; + + private: + // Removes decorations on this type. For struct types, also removes element + // decorations. 
+ virtual void ClearDecorations() { decorations_.clear(); } + + Kind kind_; +}; + +class Integer : public Type { + public: + Integer(uint32_t w, bool is_signed) + : Type(kInteger), width_(w), signed_(is_signed) {} + Integer(const Integer&) = default; + + std::string str() const override; + + Integer* AsInteger() override { return this; } + const Integer* AsInteger() const override { return this; } + uint32_t width() const { return width_; } + bool IsSigned() const { return signed_; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + uint32_t width_; // bit width + bool signed_; // true if this integer is signed +}; + +class Float : public Type { + public: + Float(uint32_t w) : Type(kFloat), width_(w) {} + Float(const Float&) = default; + + std::string str() const override; + + Float* AsFloat() override { return this; } + const Float* AsFloat() const override { return this; } + uint32_t width() const { return width_; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + uint32_t width_; // bit width +}; + +class Vector : public Type { + public: + Vector(Type* element_type, uint32_t count); + Vector(const Vector&) = default; + + std::string str() const override; + const Type* element_type() const { return element_type_; } + uint32_t element_count() const { return count_; } + + Vector* AsVector() override { return this; } + const Vector* AsVector() const override { return this; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* element_type_; + uint32_t count_; +}; + +class Matrix : public Type { + public: + Matrix(Type* element_type, uint32_t count); + Matrix(const Matrix&) = default; + + 
std::string str() const override; + const Type* element_type() const { return element_type_; } + uint32_t element_count() const { return count_; } + + Matrix* AsMatrix() override { return this; } + const Matrix* AsMatrix() const override { return this; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* element_type_; + uint32_t count_; +}; + +class Image : public Type { + public: + Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample, + uint32_t sampling, SpvImageFormat f, + SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly); + Image(const Image&) = default; + + std::string str() const override; + + Image* AsImage() override { return this; } + const Image* AsImage() const override { return this; } + + const Type* sampled_type() const { return sampled_type_; } + SpvDim dim() const { return dim_; } + uint32_t depth() const { return depth_; } + bool is_arrayed() const { return arrayed_; } + bool is_multisampled() const { return ms_; } + uint32_t sampled() const { return sampled_; } + SpvImageFormat format() const { return format_; } + SpvAccessQualifier access_qualifier() const { return access_qualifier_; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + Type* sampled_type_; + SpvDim dim_; + uint32_t depth_; + bool arrayed_; + bool ms_; + uint32_t sampled_; + SpvImageFormat format_; + SpvAccessQualifier access_qualifier_; +}; + +class SampledImage : public Type { + public: + SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {} + SampledImage(const SampledImage&) = default; + + std::string str() const override; + + SampledImage* AsSampledImage() override { return this; } + const SampledImage* AsSampledImage() const override { return this; } + + const Type* 
image_type() const { return image_type_; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + Type* image_type_; +}; + +class Array : public Type { + public: + Array(Type* element_type, uint32_t length_id); + Array(const Array&) = default; + + std::string str() const override; + const Type* element_type() const { return element_type_; } + uint32_t LengthId() const { return length_id_; } + + Array* AsArray() override { return this; } + const Array* AsArray() const override { return this; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + void ReplaceElementType(const Type* element_type); + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* element_type_; + uint32_t length_id_; +}; + +class RuntimeArray : public Type { + public: + RuntimeArray(Type* element_type); + RuntimeArray(const RuntimeArray&) = default; + + std::string str() const override; + const Type* element_type() const { return element_type_; } + + RuntimeArray* AsRuntimeArray() override { return this; } + const RuntimeArray* AsRuntimeArray() const override { return this; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + void ReplaceElementType(const Type* element_type); + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* element_type_; +}; + +class Struct : public Type { + public: + Struct(const std::vector& element_types); + Struct(const Struct&) = default; + + // Adds a decoration to the member at the given index. The first word is the + // decoration enum, and the remaining words, if any, are its operands. 
+ void AddMemberDecoration(uint32_t index, std::vector&& decoration); + + std::string str() const override; + const std::vector& element_types() const { + return element_types_; + } + std::vector& element_types() { return element_types_; } + bool decoration_empty() const override { + return decorations_.empty() && element_decorations_.empty(); + } + + const std::map>>& + element_decorations() const { + return element_decorations_; + } + + Struct* AsStruct() override { return this; } + const Struct* AsStruct() const override { return this; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + void ClearDecorations() override { + decorations_.clear(); + element_decorations_.clear(); + } + + std::vector element_types_; + // We can attach decorations to struct members and that should not affect the + // underlying element type. So we need an extra data structure here to keep + // track of element type decorations. They must be stored in an ordered map + // because |GetExtraHashWords| will traverse the structure. It must have a + // fixed order in order to hash to the same value every time. 
+ std::map>> element_decorations_; +}; + +class Opaque : public Type { + public: + Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {} + Opaque(const Opaque&) = default; + + std::string str() const override; + + Opaque* AsOpaque() override { return this; } + const Opaque* AsOpaque() const override { return this; } + + const std::string& name() const { return name_; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + std::string name_; +}; + +class Pointer : public Type { + public: + Pointer(const Type* pointee, SpvStorageClass sc); + Pointer(const Pointer&) = default; + + std::string str() const override; + const Type* pointee_type() const { return pointee_type_; } + SpvStorageClass storage_class() const { return storage_class_; } + + Pointer* AsPointer() override { return this; } + const Pointer* AsPointer() const override { return this; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + void SetPointeeType(const Type* type); + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + const Type* pointee_type_; + SpvStorageClass storage_class_; +}; + +class Function : public Type { + public: + Function(Type* ret_type, const std::vector& params); + Function(Type* ret_type, std::vector& params); + Function(const Function&) = default; + + std::string str() const override; + + Function* AsFunction() override { return this; } + const Function* AsFunction() const override { return this; } + + const Type* return_type() const { return return_type_; } + const std::vector& param_types() const { return param_types_; } + std::vector& param_types() { return param_types_; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set*) const override; + + void SetReturnType(const Type* type); + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + 
+ const Type* return_type_; + std::vector param_types_; +}; + +class Pipe : public Type { + public: + Pipe(SpvAccessQualifier qualifier) + : Type(kPipe), access_qualifier_(qualifier) {} + Pipe(const Pipe&) = default; + + std::string str() const override; + + Pipe* AsPipe() override { return this; } + const Pipe* AsPipe() const override { return this; } + + SpvAccessQualifier access_qualifier() const { return access_qualifier_; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + SpvAccessQualifier access_qualifier_; +}; + +class ForwardPointer : public Type { + public: + ForwardPointer(uint32_t id, SpvStorageClass sc) + : Type(kForwardPointer), + target_id_(id), + storage_class_(sc), + pointer_(nullptr) {} + ForwardPointer(const ForwardPointer&) = default; + + uint32_t target_id() const { return target_id_; } + void SetTargetPointer(const Pointer* pointer) { pointer_ = pointer; } + SpvStorageClass storage_class() const { return storage_class_; } + const Pointer* target_pointer() const { return pointer_; } + + std::string str() const override; + + ForwardPointer* AsForwardPointer() override { return this; } + const ForwardPointer* AsForwardPointer() const override { return this; } + + void GetExtraHashWords(std::vector* words, + std::unordered_set* pSet) const override; + + private: + bool IsSameImpl(const Type* that, IsSameCache*) const override; + + uint32_t target_id_; + SpvStorageClass storage_class_; + const Pointer* pointer_; +}; + +#define DefineParameterlessType(type, name) \ + class type : public Type { \ + public: \ + type() : Type(k##type) {} \ + type(const type&) = default; \ + \ + std::string str() const override { return #name; } \ + \ + type* As##type() override { return this; } \ + const type* As##type() const override { return this; } \ + \ + void GetExtraHashWords(std::vector*, \ + std::unordered_set*) const override {} \ + \ + 
private: \ + bool IsSameImpl(const Type* that, IsSameCache*) const override { \ + return that->As##type() && HasSameDecorations(that); \ + } \ + } +DefineParameterlessType(Void, void); +DefineParameterlessType(Bool, bool); +DefineParameterlessType(Sampler, sampler); +DefineParameterlessType(Event, event); +DefineParameterlessType(DeviceEvent, device_event); +DefineParameterlessType(ReserveId, reserve_id); +DefineParameterlessType(Queue, queue); +DefineParameterlessType(PipeStorage, pipe_storage); +DefineParameterlessType(NamedBarrier, named_barrier); +DefineParameterlessType(AccelerationStructureNV, accelerationStructureNV); +#undef DefineParameterlessType + +} // namespace analysis +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_TYPES_H_ diff --git a/source/opt/unify_const_pass.cpp b/source/opt/unify_const_pass.cpp new file mode 100644 index 000000000..227fd61da --- /dev/null +++ b/source/opt/unify_const_pass.cpp @@ -0,0 +1,177 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/unify_const_pass.h" + +#include +#include +#include +#include + +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +namespace { + +// The trie that stores a bunch of result ids and, for a given instruction, +// searches the result id that has been defined with the same opcode, type and +// operands. 
+class ResultIdTrie { + public: + ResultIdTrie() : root_(new Node) {} + + // For a given instruction, extracts its opcode, type id and operand words + // as an array of keys, looks up the trie to find a result id which is stored + // with the same opcode, type id and operand words. If none of such result id + // is found, creates a trie node with those keys, stores the instruction's + // result id and returns that result id. If an existing result id is found, + // returns the existing result id. + uint32_t LookupEquivalentResultFor(const Instruction& inst) { + auto keys = GetLookUpKeys(inst); + auto* node = root_.get(); + for (uint32_t key : keys) { + node = node->GetOrCreateTrieNodeFor(key); + } + if (node->result_id() == 0) { + node->SetResultId(inst.result_id()); + } + return node->result_id(); + } + + private: + // The trie node to store result ids. + class Node { + public: + using TrieNodeMap = std::unordered_map>; + + Node() : result_id_(0), next_() {} + uint32_t result_id() const { return result_id_; } + + // Sets the result id stored in this node. + void SetResultId(uint32_t id) { result_id_ = id; } + + // Searches for the child trie node with the given key. If the node is + // found, returns that node. Otherwise creates an empty child node with + // that key and returns that newly created node. + Node* GetOrCreateTrieNodeFor(uint32_t key) { + auto iter = next_.find(key); + if (iter == next_.end()) { + // insert a new node and return the node. + return next_.insert(std::make_pair(key, MakeUnique())) + .first->second.get(); + } + return iter->second.get(); + } + + private: + // The result id stored in this node. 0 means this node is empty. + uint32_t result_id_; + // The mapping from the keys to the child nodes of this node. + TrieNodeMap next_; + }; + + // Returns a vector of the opcode followed by the words in the raw SPIR-V + // instruction encoding but without the result id. 
+ std::vector GetLookUpKeys(const Instruction& inst) { + std::vector keys; + // Need to use the opcode, otherwise there might be a conflict with the + // following case when 's binary value equals xx's id: + // OpSpecConstantOp tt yy zz + // OpSpecConstantComposite tt xx yy zz; + keys.push_back(static_cast(inst.opcode())); + for (const auto& operand : inst) { + if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue; + keys.insert(keys.end(), operand.words.cbegin(), operand.words.cend()); + } + return keys; + } + + std::unique_ptr root_; // The root node of the trie. +}; +} // anonymous namespace + +Pass::Status UnifyConstantPass::Process() { + bool modified = false; + ResultIdTrie defined_constants; + + for (Instruction *next_instruction, + *inst = &*(context()->types_values_begin()); + inst; inst = next_instruction) { + next_instruction = inst->NextNode(); + + // Do not handle the instruction when there are decorations upon the result + // id. + if (get_def_use_mgr()->GetAnnotations(inst->result_id()).size() != 0) { + continue; + } + + // The overall algorithm is to store the result ids of all the eligible + // constants encountered so far in a trie. For a constant defining + // instruction under consideration, use its opcode, result type id and + // words in operands as an array of keys to lookup the trie. If a result id + // can be found for that array of keys, a constant with exactly the same + // value must has been defined before, the constant under processing + // should be replaced by the constant previously defined. If no such result + // id can be found for that array of keys, this must be the first time a + // constant with its value be defined, we then create a new trie node to + // store the result id with the keys. When replacing a duplicated constant + // with a previously defined constant, all the uses of the duplicated + // constant, which must be placed after the duplicated constant defining + // instruction, will be updated. 
This way, the descendants of the + // previously defined constant and the duplicated constant will both refer + // to the previously defined constant. So that the operand ids which are + // used in key arrays will be the ids of the unified constants, when + // processing is up to a descendant. This makes comparing the key array + // always valid for judging duplication. + switch (inst->opcode()) { + case SpvOp::SpvOpConstantTrue: + case SpvOp::SpvOpConstantFalse: + case SpvOp::SpvOpConstant: + case SpvOp::SpvOpConstantNull: + case SpvOp::SpvOpConstantSampler: + case SpvOp::SpvOpConstantComposite: + // Only spec constants defined with OpSpecConstantOp and + // OpSpecConstantComposite should be processed in this pass. Spec + // constants defined with OpSpecConstant{|True|False} are decorated with + // 'SpecId' decoration and all of them should be treated as unique. + // 'SpecId' is not applicable to SpecConstants defined with + // OpSpecConstant{Op|Composite}, their values are not necessary to be + // unique. When all the operands/compoents are the same between two + // OpSpecConstant{Op|Composite} results, their result values must be the + // same so are unifiable. + case SpvOp::SpvOpSpecConstantOp: + case SpvOp::SpvOpSpecConstantComposite: { + uint32_t id = defined_constants.LookupEquivalentResultFor(*inst); + if (id != inst->result_id()) { + // The constant is a duplicated one, use the cached constant to + // replace the uses of this duplicated one, then turn it to nop. + context()->ReplaceAllUsesWith(inst->result_id(), id); + context()->KillInst(inst); + modified = true; + } + break; + } + default: + break; + } + } + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/unify_const_pass.h b/source/opt/unify_const_pass.h new file mode 100644 index 000000000..f2b7897cc --- /dev/null +++ b/source/opt/unify_const_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2016 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_UNIFY_CONST_PASS_H_ +#define SOURCE_OPT_UNIFY_CONST_PASS_H_ + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class UnifyConstantPass : public Pass { + public: + const char* name() const override { return "unify-const"; } + Status Process() override; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_UNIFY_CONST_PASS_H_ diff --git a/source/opt/upgrade_memory_model.cpp b/source/opt/upgrade_memory_model.cpp new file mode 100644 index 000000000..d8836a408 --- /dev/null +++ b/source/opt/upgrade_memory_model.cpp @@ -0,0 +1,640 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "upgrade_memory_model.h" + +#include + +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { + +Pass::Status UpgradeMemoryModel::Process() { + // Only update Logical GLSL450 to Logical VulkanKHR. + Instruction* memory_model = get_module()->GetMemoryModel(); + if (memory_model->GetSingleWordInOperand(0u) != SpvAddressingModelLogical || + memory_model->GetSingleWordInOperand(1u) != SpvMemoryModelGLSL450) { + return Pass::Status::SuccessWithoutChange; + } + + UpgradeMemoryModelInstruction(); + UpgradeInstructions(); + CleanupDecorations(); + UpgradeBarriers(); + UpgradeMemoryScope(); + + return Pass::Status::SuccessWithChange; +} + +void UpgradeMemoryModel::UpgradeMemoryModelInstruction() { + // Overall changes necessary: + // 1. Add the OpExtension. + // 2. Add the OpCapability. + // 3. Modify the memory model. + Instruction* memory_model = get_module()->GetMemoryModel(); + get_module()->AddCapability(MakeUnique( + context(), SpvOpCapability, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityVulkanMemoryModelKHR}}})); + const std::string extension = "SPV_KHR_vulkan_memory_model"; + std::vector words(extension.size() / 4 + 1, 0); + char* dst = reinterpret_cast(words.data()); + strncpy(dst, extension.c_str(), extension.size()); + get_module()->AddExtension( + MakeUnique(context(), SpvOpExtension, 0, 0, + std::initializer_list{ + {SPV_OPERAND_TYPE_LITERAL_STRING, words}})); + memory_model->SetInOperand(1u, {SpvMemoryModelVulkanKHR}); +} + +void UpgradeMemoryModel::UpgradeInstructions() { + // Coherent and Volatile decorations are deprecated. Remove them and replace + // with flags on the memory/image operations. The decorations can occur on + // OpVariable, OpFunctionParameter (of pointer type) and OpStructType (member + // decoration). Trace from the decoration target(s) to the final memory/image + // instructions. 
Additionally, Workgroup storage class variables and function + // parameters are implicitly coherent in GLSL450. + + // Upgrade modf and frexp first since they generate new stores. + for (auto& func : *get_module()) { + func.ForEachInst([this](Instruction* inst) { + if (inst->opcode() == SpvOpExtInst) { + auto ext_inst = inst->GetSingleWordInOperand(1u); + if (ext_inst == GLSLstd450Modf || ext_inst == GLSLstd450Frexp) { + auto import = + get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); + if (reinterpret_cast(import->GetInOperand(0u).words.data()) == + std::string("GLSL.std.450")) { + UpgradeExtInst(inst); + } + } + } + }); + } + for (auto& func : *get_module()) { + func.ForEachInst([this](Instruction* inst) { + bool is_coherent = false; + bool is_volatile = false; + bool src_coherent = false; + bool src_volatile = false; + bool dst_coherent = false; + bool dst_volatile = false; + SpvScope scope = SpvScopeQueueFamilyKHR; + SpvScope src_scope = SpvScopeQueueFamilyKHR; + SpvScope dst_scope = SpvScopeQueueFamilyKHR; + switch (inst->opcode()) { + case SpvOpLoad: + case SpvOpStore: + std::tie(is_coherent, is_volatile, scope) = + GetInstructionAttributes(inst->GetSingleWordInOperand(0u)); + break; + case SpvOpImageRead: + case SpvOpImageSparseRead: + case SpvOpImageWrite: + std::tie(is_coherent, is_volatile, scope) = + GetInstructionAttributes(inst->GetSingleWordInOperand(0u)); + break; + case SpvOpCopyMemory: + case SpvOpCopyMemorySized: + std::tie(dst_coherent, dst_volatile, dst_scope) = + GetInstructionAttributes(inst->GetSingleWordInOperand(0u)); + std::tie(src_coherent, src_volatile, src_scope) = + GetInstructionAttributes(inst->GetSingleWordInOperand(1u)); + break; + default: + break; + } + + switch (inst->opcode()) { + case SpvOpLoad: + UpgradeFlags(inst, 1u, is_coherent, is_volatile, kVisibility, + kMemory); + break; + case SpvOpStore: + UpgradeFlags(inst, 2u, is_coherent, is_volatile, kAvailability, + kMemory); + break; + case SpvOpCopyMemory: + 
UpgradeFlags(inst, 2u, dst_coherent, dst_volatile, kAvailability, + kMemory); + UpgradeFlags(inst, 2u, src_coherent, src_volatile, kVisibility, + kMemory); + break; + case SpvOpCopyMemorySized: + UpgradeFlags(inst, 3u, dst_coherent, dst_volatile, kAvailability, + kMemory); + UpgradeFlags(inst, 3u, src_coherent, src_volatile, kVisibility, + kMemory); + break; + case SpvOpImageRead: + case SpvOpImageSparseRead: + UpgradeFlags(inst, 2u, is_coherent, is_volatile, kVisibility, kImage); + break; + case SpvOpImageWrite: + UpgradeFlags(inst, 3u, is_coherent, is_volatile, kAvailability, + kImage); + break; + default: + break; + } + + // |is_coherent| is never used for the same instructions as + // |src_coherent| and |dst_coherent|. + if (is_coherent) { + inst->AddOperand( + {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(scope)}}); + } + // According to SPV_KHR_vulkan_memory_model, if both available and + // visible flags are used the first scope operand is for availability + // (writes) and the second is for visibility (reads). + if (dst_coherent) { + inst->AddOperand( + {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(dst_scope)}}); + } + if (src_coherent) { + inst->AddOperand( + {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(src_scope)}}); + } + }); + } +} + +std::tuple UpgradeMemoryModel::GetInstructionAttributes( + uint32_t id) { + // |id| is a pointer used in a memory/image instruction. Need to determine if + // that pointer points to volatile or coherent memory. Workgroup storage + // class is implicitly coherent and cannot be decorated with volatile, so + // short circuit that case. 
+ Instruction* inst = context()->get_def_use_mgr()->GetDef(id); + analysis::Type* type = context()->get_type_mgr()->GetType(inst->type_id()); + if (type->AsPointer() && + type->AsPointer()->storage_class() == SpvStorageClassWorkgroup) { + return std::make_tuple(true, false, SpvScopeWorkgroup); + } + + bool is_coherent = false; + bool is_volatile = false; + std::unordered_set visited; + std::tie(is_coherent, is_volatile) = + TraceInstruction(context()->get_def_use_mgr()->GetDef(id), + std::vector(), &visited); + + return std::make_tuple(is_coherent, is_volatile, SpvScopeQueueFamilyKHR); +} + +std::pair UpgradeMemoryModel::TraceInstruction( + Instruction* inst, std::vector indices, + std::unordered_set* visited) { + auto iter = cache_.find(std::make_pair(inst->result_id(), indices)); + if (iter != cache_.end()) { + return iter->second; + } + + if (!visited->insert(inst->result_id()).second) { + return std::make_pair(false, false); + } + + // Initialize the cache before |indices| is (potentially) modified. + auto& cached_result = cache_[std::make_pair(inst->result_id(), indices)]; + cached_result.first = false; + cached_result.second = false; + + bool is_coherent = false; + bool is_volatile = false; + switch (inst->opcode()) { + case SpvOpVariable: + case SpvOpFunctionParameter: + is_coherent |= HasDecoration(inst, 0, SpvDecorationCoherent); + is_volatile |= HasDecoration(inst, 0, SpvDecorationVolatile); + if (!is_coherent || !is_volatile) { + bool type_coherent = false; + bool type_volatile = false; + std::tie(type_coherent, type_volatile) = + CheckType(inst->type_id(), indices); + is_coherent |= type_coherent; + is_volatile |= type_volatile; + } + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + // Store indices in reverse order. + for (uint32_t i = inst->NumInOperands() - 1; i > 0; --i) { + indices.push_back(inst->GetSingleWordInOperand(i)); + } + break; + case SpvOpPtrAccessChain: + // Store indices in reverse order. Skip the |Element| operand. 
+ for (uint32_t i = inst->NumInOperands() - 1; i > 1; --i) { + indices.push_back(inst->GetSingleWordInOperand(i)); + } + break; + default: + break; + } + + // No point searching further. + if (is_coherent && is_volatile) { + cached_result.first = true; + cached_result.second = true; + return std::make_pair(true, true); + } + + // Variables and function parameters are sources. Continue searching until we + // reach them. + if (inst->opcode() != SpvOpVariable && + inst->opcode() != SpvOpFunctionParameter) { + inst->ForEachInId([this, &is_coherent, &is_volatile, &indices, + &visited](const uint32_t* id_ptr) { + Instruction* op_inst = context()->get_def_use_mgr()->GetDef(*id_ptr); + const analysis::Type* type = + context()->get_type_mgr()->GetType(op_inst->type_id()); + if (type && + (type->AsPointer() || type->AsImage() || type->AsSampledImage())) { + bool operand_coherent = false; + bool operand_volatile = false; + std::tie(operand_coherent, operand_volatile) = + TraceInstruction(op_inst, indices, visited); + is_coherent |= operand_coherent; + is_volatile |= operand_volatile; + } + }); + } + + cached_result.first = is_coherent; + cached_result.second = is_volatile; + return std::make_pair(is_coherent, is_volatile); +} + +std::pair UpgradeMemoryModel::CheckType( + uint32_t type_id, const std::vector& indices) { + bool is_coherent = false; + bool is_volatile = false; + Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); + assert(type_inst->opcode() == SpvOpTypePointer); + Instruction* element_inst = context()->get_def_use_mgr()->GetDef( + type_inst->GetSingleWordInOperand(1u)); + for (int i = (int)indices.size() - 1; i >= 0; --i) { + if (is_coherent && is_volatile) break; + + if (element_inst->opcode() == SpvOpTypePointer) { + element_inst = context()->get_def_use_mgr()->GetDef( + element_inst->GetSingleWordInOperand(1u)); + } else if (element_inst->opcode() == SpvOpTypeStruct) { + uint32_t index = indices.at(i); + Instruction* index_inst = 
context()->get_def_use_mgr()->GetDef(index); + assert(index_inst->opcode() == SpvOpConstant); + uint64_t value = GetIndexValue(index_inst); + is_coherent |= HasDecoration(element_inst, static_cast(value), + SpvDecorationCoherent); + is_volatile |= HasDecoration(element_inst, static_cast(value), + SpvDecorationVolatile); + element_inst = context()->get_def_use_mgr()->GetDef( + element_inst->GetSingleWordInOperand(static_cast(value))); + } else { + assert(spvOpcodeIsComposite(element_inst->opcode())); + element_inst = context()->get_def_use_mgr()->GetDef( + element_inst->GetSingleWordInOperand(1u)); + } + } + + if (!is_coherent || !is_volatile) { + bool remaining_coherent = false; + bool remaining_volatile = false; + std::tie(remaining_coherent, remaining_volatile) = + CheckAllTypes(element_inst); + is_coherent |= remaining_coherent; + is_volatile |= remaining_volatile; + } + + return std::make_pair(is_coherent, is_volatile); +} + +std::pair UpgradeMemoryModel::CheckAllTypes( + const Instruction* inst) { + std::unordered_set visited; + std::vector stack; + stack.push_back(inst); + + bool is_coherent = false; + bool is_volatile = false; + while (!stack.empty()) { + const Instruction* def = stack.back(); + stack.pop_back(); + + if (!visited.insert(def).second) continue; + + if (def->opcode() == SpvOpTypeStruct) { + // Any member decorated with coherent and/or volatile is enough to have + // the related operation be flagged as coherent and/or volatile. + is_coherent |= HasDecoration(def, std::numeric_limits::max(), + SpvDecorationCoherent); + is_volatile |= HasDecoration(def, std::numeric_limits::max(), + SpvDecorationVolatile); + if (is_coherent && is_volatile) + return std::make_pair(is_coherent, is_volatile); + + // Check the subtypes. 
+ for (uint32_t i = 0; i < def->NumInOperands(); ++i) { + stack.push_back(context()->get_def_use_mgr()->GetDef( + def->GetSingleWordInOperand(i))); + } + } else if (spvOpcodeIsComposite(def->opcode())) { + stack.push_back(context()->get_def_use_mgr()->GetDef( + def->GetSingleWordInOperand(0u))); + } else if (def->opcode() == SpvOpTypePointer) { + stack.push_back(context()->get_def_use_mgr()->GetDef( + def->GetSingleWordInOperand(1u))); + } + } + + return std::make_pair(is_coherent, is_volatile); +} + +uint64_t UpgradeMemoryModel::GetIndexValue(Instruction* index_inst) { + const analysis::Constant* index_constant = + context()->get_constant_mgr()->GetConstantFromInst(index_inst); + assert(index_constant->AsIntConstant()); + if (index_constant->type()->AsInteger()->IsSigned()) { + if (index_constant->type()->AsInteger()->width() == 32) { + return index_constant->GetS32(); + } else { + return index_constant->GetS64(); + } + } else { + if (index_constant->type()->AsInteger()->width() == 32) { + return index_constant->GetU32(); + } else { + return index_constant->GetU64(); + } + } +} + +bool UpgradeMemoryModel::HasDecoration(const Instruction* inst, uint32_t value, + SpvDecoration decoration) { + // If the iteration was terminated early then an appropriate decoration was + // found. 
+ return !context()->get_decoration_mgr()->WhileEachDecoration( + inst->result_id(), decoration, [value](const Instruction& i) { + if (i.opcode() == SpvOpDecorate || i.opcode() == SpvOpDecorateId) { + return false; + } else if (i.opcode() == SpvOpMemberDecorate) { + if (value == i.GetSingleWordInOperand(1u) || + value == std::numeric_limits::max()) + return false; + } + + return true; + }); +} + +void UpgradeMemoryModel::UpgradeFlags(Instruction* inst, uint32_t in_operand, + bool is_coherent, bool is_volatile, + OperationType operation_type, + InstructionType inst_type) { + if (!is_coherent && !is_volatile) return; + + uint32_t flags = 0; + if (inst->NumInOperands() > in_operand) { + flags |= inst->GetSingleWordInOperand(in_operand); + } + if (is_coherent) { + if (inst_type == kMemory) { + flags |= SpvMemoryAccessNonPrivatePointerKHRMask; + if (operation_type == kVisibility) { + flags |= SpvMemoryAccessMakePointerVisibleKHRMask; + } else { + flags |= SpvMemoryAccessMakePointerAvailableKHRMask; + } + } else { + flags |= SpvImageOperandsNonPrivateTexelKHRMask; + if (operation_type == kVisibility) { + flags |= SpvImageOperandsMakeTexelVisibleKHRMask; + } else { + flags |= SpvImageOperandsMakeTexelAvailableKHRMask; + } + } + } + + if (is_volatile) { + if (inst_type == kMemory) { + flags |= SpvMemoryAccessVolatileMask; + } else { + flags |= SpvImageOperandsVolatileTexelKHRMask; + } + } + + if (inst->NumInOperands() > in_operand) { + inst->SetInOperand(in_operand, {flags}); + } else if (inst_type == kMemory) { + inst->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, {flags}}); + } else { + inst->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_IMAGE, {flags}}); + } +} + +uint32_t UpgradeMemoryModel::GetScopeConstant(SpvScope scope) { + analysis::Integer int_ty(32, false); + uint32_t int_id = context()->get_type_mgr()->GetTypeInstruction(&int_ty); + const analysis::Constant* constant = + context()->get_constant_mgr()->GetConstant( + context()->get_type_mgr()->GetType(int_id), + 
{static_cast(scope)}); + return context() + ->get_constant_mgr() + ->GetDefiningInstruction(constant) + ->result_id(); +} + +void UpgradeMemoryModel::CleanupDecorations() { + // All of the volatile and coherent decorations have been dealt with, so now + // we can just remove them. + get_module()->ForEachInst([this](Instruction* inst) { + if (inst->result_id() != 0) { + context()->get_decoration_mgr()->RemoveDecorationsFrom( + inst->result_id(), [](const Instruction& dec) { + switch (dec.opcode()) { + case SpvOpDecorate: + case SpvOpDecorateId: + if (dec.GetSingleWordInOperand(1u) == SpvDecorationCoherent || + dec.GetSingleWordInOperand(1u) == SpvDecorationVolatile) + return true; + break; + case SpvOpMemberDecorate: + if (dec.GetSingleWordInOperand(2u) == SpvDecorationCoherent || + dec.GetSingleWordInOperand(2u) == SpvDecorationVolatile) + return true; + break; + default: + break; + } + return false; + }); + } + }); +} + +void UpgradeMemoryModel::UpgradeBarriers() { + std::vector barriers; + // Collects all the control barriers in |function|. Returns true if the + // function operates on the Output storage class. + ProcessFunction CollectBarriers = [this, &barriers](Function* function) { + bool operates_on_output = false; + for (auto& block : *function) { + block.ForEachInst([this, &barriers, + &operates_on_output](Instruction* inst) { + if (inst->opcode() == SpvOpControlBarrier) { + barriers.push_back(inst); + } else if (!operates_on_output) { + // This instruction operates on output storage class if it is a + // pointer to output type or any input operand is a pointer to output + // type. 
+ analysis::Type* type = + context()->get_type_mgr()->GetType(inst->type_id()); + if (type && type->AsPointer() && + type->AsPointer()->storage_class() == SpvStorageClassOutput) { + operates_on_output = true; + return; + } + inst->ForEachInId([this, &operates_on_output](uint32_t* id_ptr) { + Instruction* op_inst = + context()->get_def_use_mgr()->GetDef(*id_ptr); + analysis::Type* op_type = + context()->get_type_mgr()->GetType(op_inst->type_id()); + if (op_type && op_type->AsPointer() && + op_type->AsPointer()->storage_class() == SpvStorageClassOutput) + operates_on_output = true; + }); + } + }); + } + return operates_on_output; + }; + + std::queue roots; + for (auto& e : get_module()->entry_points()) + if (e.GetSingleWordInOperand(0u) == SpvExecutionModelTessellationControl) { + roots.push(e.GetSingleWordInOperand(1u)); + if (context()->ProcessCallTreeFromRoots(CollectBarriers, &roots)) { + for (auto barrier : barriers) { + // Add OutputMemoryKHR to the semantics of the barriers. + uint32_t semantics_id = barrier->GetSingleWordInOperand(2u); + Instruction* semantics_inst = + context()->get_def_use_mgr()->GetDef(semantics_id); + analysis::Type* semantics_type = + context()->get_type_mgr()->GetType(semantics_inst->type_id()); + uint64_t semantics_value = GetIndexValue(semantics_inst); + const analysis::Constant* constant = + context()->get_constant_mgr()->GetConstant( + semantics_type, {static_cast(semantics_value) | + SpvMemorySemanticsOutputMemoryKHRMask}); + barrier->SetInOperand(2u, {context() + ->get_constant_mgr() + ->GetDefiningInstruction(constant) + ->result_id()}); + } + } + barriers.clear(); + } +} + +void UpgradeMemoryModel::UpgradeMemoryScope() { + get_module()->ForEachInst([this](Instruction* inst) { + // Don't need to handle all the operations that take a scope. + // * Group operations can only be subgroup + // * Non-uniform can only be workgroup or subgroup + // * Named barriers are not supported by Vulkan + // * Workgroup ops (e.g. 
async_copy) have at most workgroup scope. + if (spvOpcodeIsAtomicOp(inst->opcode())) { + if (IsDeviceScope(inst->GetSingleWordInOperand(1))) { + inst->SetInOperand(1, {GetScopeConstant(SpvScopeQueueFamilyKHR)}); + } + } else if (inst->opcode() == SpvOpControlBarrier) { + if (IsDeviceScope(inst->GetSingleWordInOperand(1))) { + inst->SetInOperand(1, {GetScopeConstant(SpvScopeQueueFamilyKHR)}); + } + } else if (inst->opcode() == SpvOpMemoryBarrier) { + if (IsDeviceScope(inst->GetSingleWordInOperand(0))) { + inst->SetInOperand(0, {GetScopeConstant(SpvScopeQueueFamilyKHR)}); + } + } + }); +} + +bool UpgradeMemoryModel::IsDeviceScope(uint32_t scope_id) { + const analysis::Constant* constant = + context()->get_constant_mgr()->FindDeclaredConstant(scope_id); + assert(constant && "Memory scope must be a constant"); + + const analysis::Integer* type = constant->type()->AsInteger(); + assert(type); + assert(type->width() == 32 || type->width() == 64); + if (type->width() == 32) { + if (type->IsSigned()) + return static_cast(constant->GetS32()) == SpvScopeDevice; + else + return static_cast(constant->GetU32()) == SpvScopeDevice; + } else { + if (type->IsSigned()) + return static_cast(constant->GetS64()) == SpvScopeDevice; + else + return static_cast(constant->GetU64()) == SpvScopeDevice; + } + + assert(false); + return false; +} + +void UpgradeMemoryModel::UpgradeExtInst(Instruction* ext_inst) { + const bool is_modf = ext_inst->GetSingleWordInOperand(1u) == GLSLstd450Modf; + auto ptr_id = ext_inst->GetSingleWordInOperand(3u); + auto ptr_type_id = get_def_use_mgr()->GetDef(ptr_id)->type_id(); + auto pointee_type_id = + get_def_use_mgr()->GetDef(ptr_type_id)->GetSingleWordInOperand(1u); + auto element_type_id = ext_inst->type_id(); + std::vector element_types(2); + element_types[0] = context()->get_type_mgr()->GetType(element_type_id); + element_types[1] = context()->get_type_mgr()->GetType(pointee_type_id); + analysis::Struct struct_type(element_types); + uint32_t struct_id = + 
context()->get_type_mgr()->GetTypeInstruction(&struct_type); + // Change the operation + GLSLstd450 new_op = is_modf ? GLSLstd450ModfStruct : GLSLstd450FrexpStruct; + ext_inst->SetOperand(3u, {static_cast(new_op)}); + // Remove the pointer argument + ext_inst->RemoveOperand(5u); + // Set the type id to the new struct. + ext_inst->SetResultType(struct_id); + + // The result is now a struct of the original result. The zero'th element is + // old result and should replace the old result. The one'th element needs to + // be stored via a new instruction. + auto where = ext_inst->NextNode(); + InstructionBuilder builder( + context(), where, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + auto extract_0 = + builder.AddCompositeExtract(element_type_id, ext_inst->result_id(), {0}); + context()->ReplaceAllUsesWith(ext_inst->result_id(), extract_0->result_id()); + // The extract's input was just changed to itself, so fix that. + extract_0->SetInOperand(0u, {ext_inst->result_id()}); + auto extract_1 = + builder.AddCompositeExtract(pointee_type_id, ext_inst->result_id(), {1}); + builder.AddStore(ptr_id, extract_1->result_id()); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/upgrade_memory_model.h b/source/opt/upgrade_memory_model.h new file mode 100644 index 000000000..9adc33b68 --- /dev/null +++ b/source/opt/upgrade_memory_model.h @@ -0,0 +1,134 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// Hashing functor for the memoized result store.
//
// The key is a (result id, access-chain indices) pair. The id followed by
// every index is packed into a std::u32string, which is then hashed with the
// standard library's string hash. Equal keys therefore always hash equally.
struct CacheHash {
  size_t operator()(
      const std::pair<uint32_t, std::vector<uint32_t>>& item) const {
    std::u32string to_hash;
    // One slot for the id plus one per index; avoids regrowth while packing.
    to_hash.reserve(item.second.size() + 1);
    to_hash.push_back(item.first);
    for (auto i : item.second) to_hash.push_back(i);
    return std::hash<std::u32string>()(to_hash);
  }
};
+ // |indices| tracks the access chain indices seen so far. + std::pair TraceInstruction(Instruction* inst, + std::vector indices, + std::unordered_set* visited); + + // Return true if |inst| is decorated with |decoration|. + // If |inst| is decorated by member decorations then either |value| must + // match the index or |value| must be a maximum allowable value. The max + // value allows any element to match. + bool HasDecoration(const Instruction* inst, uint32_t value, + SpvDecoration decoration); + + // Returns whether |type_id| indexed via |indices| is coherent and/or + // volatile. + std::pair CheckType(uint32_t type_id, + const std::vector& indices); + + // Returns whether any type/element under |inst| is coherent and/or volatile. + std::pair CheckAllTypes(const Instruction* inst); + + // Modifies the flags of |inst| to include the new flags for the Vulkan + // memory model. |operation_type| indicates whether flags should use + // MakeVisible or MakeAvailable variants. |inst_type| indicates whether the + // Pointer or Texel variants of flags should be used. + void UpgradeFlags(Instruction* inst, uint32_t in_operand, bool is_coherent, + bool is_volatile, OperationType operation_type, + InstructionType inst_type); + + // Returns the result id for a constant for |scope|. + uint32_t GetScopeConstant(SpvScope scope); + + // Returns the value of |index_inst|. |index_inst| must be an OpConstant of + // integer type.g + uint64_t GetIndexValue(Instruction* index_inst); + + // Removes coherent and volatile decorations. + void CleanupDecorations(); + + // For all tessellation control entry points, if there is an operation on + // Output storage class, then all barriers are modified to include the + // OutputMemoryKHR semantic. + void UpgradeBarriers(); + + // If the Vulkan memory model is specified, device scope actually means + // device scope. The memory scope must be modified to be QueueFamilyKHR + // scope. 
+ void UpgradeMemoryScope(); + + // Returns true if |scope_id| is SpvScopeDevice. + bool IsDeviceScope(uint32_t scope_id); + + // Upgrades GLSL.std.450 modf and frexp. Both instructions are replaced with + // their struct versions. New extracts and a store are added in order to + // facilitate adding memory model flags. + void UpgradeExtInst(Instruction* modf); + + // Caches the result of TraceInstruction. For a given result id and set of + // indices, stores whether that combination is coherent and/or volatile. + std::unordered_map>, + std::pair, CacheHash> + cache_; +}; +} // namespace opt +} // namespace spvtools +#endif // LIBSPIRV_OPT_UPGRADE_MEMORY_MODEL_H_ diff --git a/source/opt/value_number_table.cpp b/source/opt/value_number_table.cpp new file mode 100644 index 000000000..1bac63fab --- /dev/null +++ b/source/opt/value_number_table.cpp @@ -0,0 +1,227 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/value_number_table.h" + +#include + +#include "source/opt/cfg.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { + +uint32_t ValueNumberTable::GetValueNumber(Instruction* inst) const { + assert(inst->result_id() != 0 && + "inst must have a result id to get a value number."); + + // Check if this instruction already has a value. 
+ auto result_id_to_val = id_to_value_.find(inst->result_id()); + if (result_id_to_val != id_to_value_.end()) { + return result_id_to_val->second; + } + return 0; +} + +uint32_t ValueNumberTable::GetValueNumber(uint32_t id) const { + return GetValueNumber(context()->get_def_use_mgr()->GetDef(id)); +} + +uint32_t ValueNumberTable::AssignValueNumber(Instruction* inst) { + // If it already has a value return that. + uint32_t value = GetValueNumber(inst); + if (value != 0) { + return value; + } + + // If the instruction has other side effects, then it must + // have its own value number. + // OpSampledImage and OpImage must remain in the same basic block in which + // they are used, because of this we will assign each one it own value number. + if (!context()->IsCombinatorInstruction(inst)) { + value = TakeNextValueNumber(); + id_to_value_[inst->result_id()] = value; + return value; + } + + switch (inst->opcode()) { + case SpvOpSampledImage: + case SpvOpImage: + case SpvOpVariable: + value = TakeNextValueNumber(); + id_to_value_[inst->result_id()] = value; + return value; + default: + break; + } + + // If it is a load from memory that can be modified, we have to assume the + // memory has been modified, so we give it a new value number. + // + // Note that this test will also handle volatile loads because they are not + // read only. However, if this is ever relaxed because we analyze stores, we + // will have to add a new case for volatile loads. + if (inst->IsLoad() && !inst->IsReadOnlyLoad()) { + value = TakeNextValueNumber(); + id_to_value_[inst->result_id()] = value; + return value; + } + + // When we copy an object, the value numbers should be the same. + if (inst->opcode() == SpvOpCopyObject) { + value = GetValueNumber(inst->GetSingleWordInOperand(0)); + if (value != 0) { + id_to_value_[inst->result_id()] = value; + return value; + } + } + + // Phi nodes are a type of copy. 
If all of the inputs have the same value + // number, then we can assign the result of the phi the same value number. + if (inst->opcode() == SpvOpPhi) { + value = GetValueNumber(inst->GetSingleWordInOperand(0)); + if (value != 0) { + for (uint32_t op = 2; op < inst->NumInOperands(); op += 2) { + if (value != GetValueNumber(inst->GetSingleWordInOperand(op))) { + value = 0; + break; + } + } + if (value != 0) { + id_to_value_[inst->result_id()] = value; + return value; + } + } + } + + // Replace all of the operands by their value number. The sign bit will be + // set to distinguish between an id and a value number. + Instruction value_ins(context(), inst->opcode(), inst->type_id(), + inst->result_id(), {}); + for (uint32_t o = 0; o < inst->NumInOperands(); ++o) { + const Operand& op = inst->GetInOperand(o); + if (spvIsIdType(op.type)) { + uint32_t id_value = op.words[0]; + auto use_id_to_val = id_to_value_.find(id_value); + if (use_id_to_val != id_to_value_.end()) { + id_value = (1 << 31) | use_id_to_val->second; + } + value_ins.AddOperand(Operand(op.type, {id_value})); + } else { + value_ins.AddOperand(Operand(op.type, op.words)); + } + } + + // TODO: Implement a normal form for opcodes that commute like integer + // addition. This will let us know that a+b is the same value as b+a. + + // Otherwise, we check if this value has been computed before. + auto value_iterator = instruction_to_value_.find(value_ins); + if (value_iterator != instruction_to_value_.end()) { + value = id_to_value_[value_iterator->first.result_id()]; + id_to_value_[inst->result_id()] = value; + return value; + } + + // If not, assign it a new value number. + value = TakeNextValueNumber(); + id_to_value_[inst->result_id()] = value; + instruction_to_value_[value_ins] = value; + return value; +} + +void ValueNumberTable::BuildDominatorTreeValueNumberTable() { + // First value number the headers. 
+ for (auto& inst : context()->annotations()) { + if (inst.result_id() != 0) { + AssignValueNumber(&inst); + } + } + + for (auto& inst : context()->capabilities()) { + if (inst.result_id() != 0) { + AssignValueNumber(&inst); + } + } + + for (auto& inst : context()->types_values()) { + if (inst.result_id() != 0) { + AssignValueNumber(&inst); + } + } + + for (auto& inst : context()->module()->ext_inst_imports()) { + if (inst.result_id() != 0) { + AssignValueNumber(&inst); + } + } + + for (Function& func : *context()->module()) { + // For best results we want to traverse the code in reverse post order. + // This happens naturally because of the forward referencing rules. + for (BasicBlock& block : func) { + for (Instruction& inst : block) { + if (inst.result_id() != 0) { + AssignValueNumber(&inst); + } + } + } + } +} + +bool ComputeSameValue::operator()(const Instruction& lhs, + const Instruction& rhs) const { + if (lhs.result_id() == 0 || rhs.result_id() == 0) { + return false; + } + + if (lhs.opcode() != rhs.opcode()) { + return false; + } + + if (lhs.type_id() != rhs.type_id()) { + return false; + } + + if (lhs.NumInOperands() != rhs.NumInOperands()) { + return false; + } + + for (uint32_t i = 0; i < lhs.NumInOperands(); ++i) { + if (lhs.GetInOperand(i) != rhs.GetInOperand(i)) { + return false; + } + } + + return lhs.context()->get_decoration_mgr()->HaveTheSameDecorations( + lhs.result_id(), rhs.result_id()); +} + +std::size_t ValueTableHash::operator()(const Instruction& inst) const { + // We hash the opcode and in-operands, not the result, because we want + // instructions that are the same except for the result to hash to the + // same value. 
+ std::u32string h; + h.push_back(inst.opcode()); + h.push_back(inst.type_id()); + for (uint32_t i = 0; i < inst.NumInOperands(); ++i) { + const auto& opnd = inst.GetInOperand(i); + for (uint32_t word : opnd.words) { + h.push_back(word); + } + } + return std::hash()(h); +} +} // namespace opt +} // namespace spvtools diff --git a/source/opt/value_number_table.h b/source/opt/value_number_table.h new file mode 100644 index 000000000..39129ffa3 --- /dev/null +++ b/source/opt/value_number_table.h @@ -0,0 +1,91 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_VALUE_NUMBER_TABLE_H_ +#define SOURCE_OPT_VALUE_NUMBER_TABLE_H_ + +#include +#include + +#include "source/opt/instruction.h" + +namespace spvtools { +namespace opt { + +class IRContext; + +// Returns true if the two instructions compute the same value. Used by the +// value number table to compare two instructions. +class ComputeSameValue { + public: + bool operator()(const Instruction& lhs, const Instruction& rhs) const; +}; + +// The hash function used in the value number table. +class ValueTableHash { + public: + std::size_t operator()(const Instruction& inst) const; +}; + +// This class implements the value number analysis. It is using a hash-based +// approach to value numbering. It is essentially doing dominator-tree value +// numbering described in +// +// Preston Briggs, Keith D. Cooper, and L. Taylor Simpson. 1997. Value +// numbering. 
Softw. Pract. Exper. 27, 6 (June 1997), 701-724. +// https://www.cs.rice.edu/~keith/Promo/CRPC-TR94517.pdf.gz +// +// The main difference is that because we do not perform redundancy elimination +// as we build the value number table, we do not have to deal with cleaning up +// the scope. +class ValueNumberTable { + public: + ValueNumberTable(IRContext* ctx) : context_(ctx), next_value_number_(1) { + BuildDominatorTreeValueNumberTable(); + } + + // Returns the value number of the value computed by |inst|. |inst| must have + // a result id that will hold the computed value. If no value number has been + // assigned to the result id, then the return value is 0. + uint32_t GetValueNumber(Instruction* inst) const; + + // Returns the value number of the value contain in |id|. Returns 0 if it + // has not been assigned a value number. + uint32_t GetValueNumber(uint32_t id) const; + + IRContext* context() const { return context_; } + + private: + // Assigns a value number to every result id in the module. + void BuildDominatorTreeValueNumberTable(); + + // Returns the new value number. + uint32_t TakeNextValueNumber() { return next_value_number_++; } + + // Assigns a new value number to the result of |inst| if it does not already + // have one. Return the value number for |inst|. |inst| must have a result + // id. + uint32_t AssignValueNumber(Instruction* inst); + + std::unordered_map + instruction_to_value_; + std::unordered_map id_to_value_; + IRContext* context_; + uint32_t next_value_number_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_VALUE_NUMBER_TABLE_H_ diff --git a/source/opt/vector_dce.cpp b/source/opt/vector_dce.cpp new file mode 100644 index 000000000..92532e31a --- /dev/null +++ b/source/opt/vector_dce.cpp @@ -0,0 +1,397 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/vector_dce.h" + +#include + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kExtractCompositeIdInIdx = 0; +const uint32_t kInsertObjectIdInIdx = 0; +const uint32_t kInsertCompositeIdInIdx = 1; + +} // namespace + +Pass::Status VectorDCE::Process() { + bool modified = false; + for (Function& function : *get_module()) { + modified |= VectorDCEFunction(&function); + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool VectorDCE::VectorDCEFunction(Function* function) { + LiveComponentMap live_components; + FindLiveComponents(function, &live_components); + return RewriteInstructions(function, live_components); +} + +void VectorDCE::FindLiveComponents(Function* function, + LiveComponentMap* live_components) { + std::vector work_list; + + // Prime the work list. We will assume that any instruction that does + // not result in a vector is live. + // + // Extending to structures and matrices is not as straight forward because of + // the nesting. We cannot simply us a bit vector to keep track of which + // components are live because of arbitrary nesting of structs. + function->ForEachInst( + [&work_list, this, live_components](Instruction* current_inst) { + if (!HasVectorOrScalarResult(current_inst) || + !context()->IsCombinatorInstruction(current_inst)) { + MarkUsesAsLive(current_inst, all_components_live_, live_components, + &work_list); + } + }); + + // Process the work list propagating liveness. 
+ for (uint32_t i = 0; i < work_list.size(); i++) { + WorkListItem current_item = work_list[i]; + Instruction* current_inst = current_item.instruction; + + switch (current_inst->opcode()) { + case SpvOpCompositeExtract: + MarkExtractUseAsLive(current_inst, current_item.components, + live_components, &work_list); + break; + case SpvOpCompositeInsert: + MarkInsertUsesAsLive(current_item, live_components, &work_list); + break; + case SpvOpVectorShuffle: + MarkVectorShuffleUsesAsLive(current_item, live_components, &work_list); + break; + case SpvOpCompositeConstruct: + MarkCompositeContructUsesAsLive(current_item, live_components, + &work_list); + break; + default: + if (current_inst->IsScalarizable()) { + MarkUsesAsLive(current_inst, current_item.components, live_components, + &work_list); + } else { + MarkUsesAsLive(current_inst, all_components_live_, live_components, + &work_list); + } + break; + } + } +} + +void VectorDCE::MarkExtractUseAsLive(const Instruction* current_inst, + const utils::BitVector& live_elements, + LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + uint32_t operand_id = + current_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); + Instruction* operand_inst = def_use_mgr->GetDef(operand_id); + + if (HasVectorOrScalarResult(operand_inst)) { + WorkListItem new_item; + new_item.instruction = operand_inst; + if (current_inst->NumInOperands() < 2) { + new_item.components = live_elements; + } else { + new_item.components.Set(current_inst->GetSingleWordInOperand(1)); + } + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + } +} + +void VectorDCE::MarkInsertUsesAsLive( + const VectorDCE::WorkListItem& current_item, + LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + if (current_item.instruction->NumInOperands() > 2) { + uint32_t insert_position = + 
current_item.instruction->GetSingleWordInOperand(2); + + // Add the elements of the composite object that are used. + uint32_t operand_id = current_item.instruction->GetSingleWordInOperand( + kInsertCompositeIdInIdx); + Instruction* operand_inst = def_use_mgr->GetDef(operand_id); + + WorkListItem new_item; + new_item.instruction = operand_inst; + new_item.components = current_item.components; + new_item.components.Clear(insert_position); + + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + + // Add the element being inserted if it is used. + if (current_item.components.Get(insert_position)) { + uint32_t obj_operand_id = + current_item.instruction->GetSingleWordInOperand( + kInsertObjectIdInIdx); + Instruction* obj_operand_inst = def_use_mgr->GetDef(obj_operand_id); + WorkListItem new_item_for_obj; + new_item_for_obj.instruction = obj_operand_inst; + new_item_for_obj.components.Set(0); + AddItemToWorkListIfNeeded(new_item_for_obj, live_components, work_list); + } + } else { + // If there are no indices, then this is a copy of the object being + // inserted. 
+ uint32_t object_id = + current_item.instruction->GetSingleWordInOperand(kInsertObjectIdInIdx); + Instruction* object_inst = def_use_mgr->GetDef(object_id); + + WorkListItem new_item; + new_item.instruction = object_inst; + new_item.components = current_item.components; + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + } +} + +void VectorDCE::MarkVectorShuffleUsesAsLive( + const WorkListItem& current_item, + VectorDCE::LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + WorkListItem first_operand; + first_operand.instruction = + def_use_mgr->GetDef(current_item.instruction->GetSingleWordInOperand(0)); + WorkListItem second_operand; + second_operand.instruction = + def_use_mgr->GetDef(current_item.instruction->GetSingleWordInOperand(1)); + + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Vector* first_type = + type_mgr->GetType(first_operand.instruction->type_id())->AsVector(); + uint32_t size_of_first_operand = first_type->element_count(); + + for (uint32_t in_op = 2; in_op < current_item.instruction->NumInOperands(); + ++in_op) { + uint32_t index = current_item.instruction->GetSingleWordInOperand(in_op); + if (current_item.components.Get(in_op - 2)) { + if (index < size_of_first_operand) { + first_operand.components.Set(index); + } else { + second_operand.components.Set(index - size_of_first_operand); + } + } + } + + AddItemToWorkListIfNeeded(first_operand, live_components, work_list); + AddItemToWorkListIfNeeded(second_operand, live_components, work_list); +} + +void VectorDCE::MarkCompositeContructUsesAsLive( + VectorDCE::WorkListItem work_item, + VectorDCE::LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + + uint32_t current_component = 0; + Instruction* current_inst = 
work_item.instruction; + uint32_t num_in_operands = current_inst->NumInOperands(); + for (uint32_t i = 0; i < num_in_operands; ++i) { + uint32_t id = current_inst->GetSingleWordInOperand(i); + Instruction* op_inst = def_use_mgr->GetDef(id); + + if (HasScalarResult(op_inst)) { + WorkListItem new_work_item; + new_work_item.instruction = op_inst; + if (work_item.components.Get(current_component)) { + new_work_item.components.Set(0); + } + AddItemToWorkListIfNeeded(new_work_item, live_components, work_list); + current_component++; + } else { + assert(HasVectorResult(op_inst)); + WorkListItem new_work_item; + new_work_item.instruction = op_inst; + uint32_t op_vector_size = + type_mgr->GetType(op_inst->type_id())->AsVector()->element_count(); + + for (uint32_t op_vector_idx = 0; op_vector_idx < op_vector_size; + op_vector_idx++, current_component++) { + if (work_item.components.Get(current_component)) { + new_work_item.components.Set(op_vector_idx); + } + } + AddItemToWorkListIfNeeded(new_work_item, live_components, work_list); + } + } +} + +void VectorDCE::MarkUsesAsLive( + Instruction* current_inst, const utils::BitVector& live_elements, + LiveComponentMap* live_components, + std::vector* work_list) { + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + + current_inst->ForEachInId([&work_list, &live_elements, this, live_components, + def_use_mgr](uint32_t* operand_id) { + Instruction* operand_inst = def_use_mgr->GetDef(*operand_id); + + if (HasVectorResult(operand_inst)) { + WorkListItem new_item; + new_item.instruction = operand_inst; + new_item.components = live_elements; + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + } else if (HasScalarResult(operand_inst)) { + WorkListItem new_item; + new_item.instruction = operand_inst; + new_item.components.Set(0); + AddItemToWorkListIfNeeded(new_item, live_components, work_list); + } + }); +} + +bool VectorDCE::HasVectorOrScalarResult(const Instruction* inst) const { + return 
HasScalarResult(inst) || HasVectorResult(inst); +} + +bool VectorDCE::HasVectorResult(const Instruction* inst) const { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + if (inst->type_id() == 0) { + return false; + } + + const analysis::Type* current_type = type_mgr->GetType(inst->type_id()); + switch (current_type->kind()) { + case analysis::Type::kVector: + return true; + default: + return false; + } +} + +bool VectorDCE::HasScalarResult(const Instruction* inst) const { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + if (inst->type_id() == 0) { + return false; + } + + const analysis::Type* current_type = type_mgr->GetType(inst->type_id()); + switch (current_type->kind()) { + case analysis::Type::kBool: + case analysis::Type::kInteger: + case analysis::Type::kFloat: + return true; + default: + return false; + } +} + +bool VectorDCE::RewriteInstructions( + Function* function, const VectorDCE::LiveComponentMap& live_components) { + bool modified = false; + function->ForEachInst( + [&modified, this, live_components](Instruction* current_inst) { + if (!context()->IsCombinatorInstruction(current_inst)) { + return; + } + + auto live_component = live_components.find(current_inst->result_id()); + if (live_component == live_components.end()) { + // If this instruction is not in live_components then it does not + // produce a vector, or it is never referenced and ADCE will remove + // it. No point in trying to differentiate. + return; + } + + // If no element in the current instruction is used replace it with an + // OpUndef. 
+ if (live_component->second.Empty()) { + modified = true; + uint32_t undef_id = this->Type2Undef(current_inst->type_id()); + context()->KillNamesAndDecorates(current_inst); + context()->ReplaceAllUsesWith(current_inst->result_id(), undef_id); + context()->KillInst(current_inst); + return; + } + + switch (current_inst->opcode()) { + case SpvOpCompositeInsert: + modified |= + RewriteInsertInstruction(current_inst, live_component->second); + break; + case SpvOpCompositeConstruct: + // TODO: The members that are not live can be replaced by an undef + // or constant. This will remove uses of those values, and possibly + // create opportunities for ADCE. + break; + default: + // Do nothing. + break; + } + }); + return modified; +} + +bool VectorDCE::RewriteInsertInstruction( + Instruction* current_inst, const utils::BitVector& live_components) { + // If the value being inserted is not live, then we can skip the insert. + + if (current_inst->NumInOperands() == 2) { + // If there are no indices, then this is the same as a copy. + context()->KillNamesAndDecorates(current_inst->result_id()); + uint32_t object_id = + current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx); + context()->ReplaceAllUsesWith(current_inst->result_id(), object_id); + return true; + } + + uint32_t insert_index = current_inst->GetSingleWordInOperand(2); + if (!live_components.Get(insert_index)) { + context()->KillNamesAndDecorates(current_inst->result_id()); + uint32_t composite_id = + current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx); + context()->ReplaceAllUsesWith(current_inst->result_id(), composite_id); + return true; + } + + // If the values already in the composite are not used, then replace it with + // an undef. 
+ utils::BitVector temp = live_components; + temp.Clear(insert_index); + if (temp.Empty()) { + context()->ForgetUses(current_inst); + uint32_t undef_id = Type2Undef(current_inst->type_id()); + current_inst->SetInOperand(kInsertCompositeIdInIdx, {undef_id}); + context()->AnalyzeUses(current_inst); + return true; + } + + return false; +} + +void VectorDCE::AddItemToWorkListIfNeeded( + WorkListItem work_item, VectorDCE::LiveComponentMap* live_components, + std::vector* work_list) { + Instruction* current_inst = work_item.instruction; + auto it = live_components->find(current_inst->result_id()); + if (it == live_components->end()) { + live_components->emplace( + std::make_pair(current_inst->result_id(), work_item.components)); + work_list->emplace_back(work_item); + } else { + if (it->second.Or(work_item.components)) { + work_list->emplace_back(work_item); + } + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/vector_dce.h b/source/opt/vector_dce.h new file mode 100644 index 000000000..4f039c53f --- /dev/null +++ b/source/opt/vector_dce.h @@ -0,0 +1,151 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_OPT_VECTOR_DCE_H_ +#define SOURCE_OPT_VECTOR_DCE_H_ + +#include +#include + +#include "source/opt/mem_pass.h" +#include "source/util/bit_vector.h" + +namespace spvtools { +namespace opt { + +class VectorDCE : public MemPass { + private: + using LiveComponentMap = std::unordered_map; + + // According to the SPEC the maximum size for a vector is 16. See the data + // rules in the universal validation rules (section 2.16.1). + enum { kMaxVectorSize = 16 }; + + struct WorkListItem { + WorkListItem() : instruction(nullptr), components(kMaxVectorSize) {} + + Instruction* instruction; + utils::BitVector components; + }; + + public: + VectorDCE() : all_components_live_(kMaxVectorSize) { + for (uint32_t i = 0; i < kMaxVectorSize; i++) { + all_components_live_.Set(i); + } + } + + const char* name() const override { return "vector-dce"; } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisCFG | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisDecorations | + IRContext::kAnalysisDominatorAnalysis | IRContext::kAnalysisNameMap | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Runs the vector dce pass on |function|. Returns true if |function| was + // modified. + bool VectorDCEFunction(Function* function); + + // Identifies the live components of the vectors that are results of + // instructions in |function|. The results are stored in |live_components|. + void FindLiveComponents(Function* function, + LiveComponentMap* live_components); + + // Rewrites instructions in |function| that are dead or partially dead. If an + // instruction does not have an entry in |live_components|, then it is not + // changed. Returns true if |function| was modified. 
+ bool RewriteInstructions(Function* function, + const LiveComponentMap& live_components); + + // Rewrites the OpCompositeInsert instruction |current_inst| to avoid + // unnecessary computes given that the only components of the result that are + // live are |live_components|. + // + // If the value being inserted is not live, then the result of |current_inst| + // is replaced by the composite input to |current_inst|. + // + // If the composite input to |current_inst| is not live, then it is replaced + // by and OpUndef in |current_inst|. + bool RewriteInsertInstruction(Instruction* current_inst, + const utils::BitVector& live_components); + + // Returns true if the result of |inst| is a vector or a scalar. + bool HasVectorOrScalarResult(const Instruction* inst) const; + + // Returns true if the result of |inst| is a scalar. + bool HasVectorResult(const Instruction* inst) const; + + // Returns true if the result of |inst| is a vector. + bool HasScalarResult(const Instruction* inst) const; + + // Adds |work_item| to |work_list| if it is not already live according to + // |live_components|. |live_components| is updated to indicate that + // |work_item| is now live. + void AddItemToWorkListIfNeeded(WorkListItem work_item, + LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the components |live_elements| of the uses in |current_inst| as live + // according to |live_components|. If they were not live before, then they are + // added to |work_list|. + void MarkUsesAsLive(Instruction* current_inst, + const utils::BitVector& live_elements, + LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the uses in the OpVectorShuffle instruction in |current_item| as live + // based on the live components in |current_item|. If anything becomes live + // they are added to |work_list| and |live_components| is updated + // accordingly. 
+ void MarkVectorShuffleUsesAsLive(const WorkListItem& current_item, + VectorDCE::LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the uses in the OpCompositeInsert instruction in |current_item| as + // live based on the live components in |current_item|. If anything becomes + // live they are added to |work_list| and |live_components| is updated + // accordingly. + void MarkInsertUsesAsLive(const WorkListItem& current_item, + LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the uses in the OpCompositeExtract instruction |current_inst| as + // live. If anything becomes live they are added to |work_list| and + // |live_components| is updated accordingly. + void MarkExtractUseAsLive(const Instruction* current_inst, + const utils::BitVector& live_elements, + LiveComponentMap* live_components, + std::vector* work_list); + + // Marks the uses in the OpCompositeConstruct instruction |current_inst| as + // live. If anything becomes live they are added to |work_list| and + // |live_components| is updated accordingly. + void MarkCompositeContructUsesAsLive(WorkListItem work_item, + LiveComponentMap* live_components, + std::vector* work_list); + + // A BitVector that can always be used to say that all components of a vector + // are live. + utils::BitVector all_components_live_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_VECTOR_DCE_H_ diff --git a/source/opt/workaround1209.cpp b/source/opt/workaround1209.cpp new file mode 100644 index 000000000..d6e9d2cf7 --- /dev/null +++ b/source/opt/workaround1209.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/workaround1209.h" + +#include +#include +#include +#include + +namespace spvtools { +namespace opt { + +Pass::Status Workaround1209::Process() { + bool modified = false; + modified = RemoveOpUnreachableInLoops(); + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool Workaround1209::RemoveOpUnreachableInLoops() { + bool modified = false; + for (auto& func : *get_module()) { + std::list structured_order; + cfg()->ComputeStructuredOrder(&func, &*func.begin(), &structured_order); + + // Keep track of the loop merges. The top of the stack will always be the + // loop merge for the loop that immediately contains the basic block being + // processed. + std::stack loop_merges; + for (BasicBlock* bb : structured_order) { + if (!loop_merges.empty() && bb->id() == loop_merges.top()) { + loop_merges.pop(); + } + + if (bb->tail()->opcode() == SpvOpUnreachable) { + if (!loop_merges.empty()) { + // We found an OpUnreachable inside a loop. + // Replace it with an unconditional branch to the loop merge. 
+ context()->KillInst(&*bb->tail()); + std::unique_ptr new_branch( + new Instruction(context(), SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, + {loop_merges.top()}}})); + context()->AnalyzeDefUse(&*new_branch); + bb->AddInstruction(std::move(new_branch)); + modified = true; + } + } else { + if (bb->GetLoopMergeInst()) { + loop_merges.push(bb->MergeBlockIdIfAny()); + } + } + } + } + return modified; +} +} // namespace opt +} // namespace spvtools diff --git a/source/opt/workaround1209.h b/source/opt/workaround1209.h new file mode 100644 index 000000000..9a1f88d93 --- /dev/null +++ b/source/opt/workaround1209.h @@ -0,0 +1,41 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_WORKAROUND1209_H_ +#define SOURCE_OPT_WORKAROUND1209_H_ + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class Workaround1209 : public Pass { + public: + const char* name() const override { return "workaround-1209"; } + Status Process() override; + + private: + // There is at least one driver where an OpUnreachable found in a loop is not + // handled correctly. Workaround that by changing the OpUnreachable into a + // branch to the loop merge. + // + // Returns true if the code changed. 
+ bool RemoveOpUnreachableInLoops(); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_WORKAROUND1209_H_ diff --git a/source/parsed_operand.cpp b/source/parsed_operand.cpp new file mode 100644 index 000000000..7ad369cdb --- /dev/null +++ b/source/parsed_operand.cpp @@ -0,0 +1,74 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains utility functions for spv_parsed_operand_t. + +#include "source/parsed_operand.h" + +#include +#include "source/util/hex_float.h" + +namespace spvtools { + +void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst, + const spv_parsed_operand_t& operand) { + if (operand.type != SPV_OPERAND_TYPE_LITERAL_INTEGER && + operand.type != SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) + return; + if (operand.num_words < 1) return; + // TODO(dneto): Support more than 64-bits at a time. + if (operand.num_words > 2) return; + + const uint32_t word = inst.words[operand.offset]; + if (operand.num_words == 1) { + switch (operand.number_kind) { + case SPV_NUMBER_SIGNED_INT: + *out << int32_t(word); + break; + case SPV_NUMBER_UNSIGNED_INT: + *out << word; + break; + case SPV_NUMBER_FLOATING: + if (operand.number_bit_width == 16) { + *out << spvtools::utils::FloatProxy( + uint16_t(word & 0xFFFF)); + } else { + // Assume 32-bit floats. 
+ *out << spvtools::utils::FloatProxy(word); + } + break; + default: + break; + } + } else if (operand.num_words == 2) { + // Multi-word numbers are presented with lower order words first. + uint64_t bits = + uint64_t(word) | (uint64_t(inst.words[operand.offset + 1]) << 32); + switch (operand.number_kind) { + case SPV_NUMBER_SIGNED_INT: + *out << int64_t(bits); + break; + case SPV_NUMBER_UNSIGNED_INT: + *out << bits; + break; + case SPV_NUMBER_FLOATING: + // Assume only 64-bit floats. + *out << spvtools::utils::FloatProxy(bits); + break; + default: + break; + } + } +} +} // namespace spvtools diff --git a/source/parsed_operand.h b/source/parsed_operand.h new file mode 100644 index 000000000..bab861107 --- /dev/null +++ b/source/parsed_operand.h @@ -0,0 +1,33 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_PARSED_OPERAND_H_ +#define SOURCE_PARSED_OPERAND_H_ + +#include + +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// Emits the numeric literal representation of the given instruction operand +// to the stream. The operand must be of numeric type. If integral it may +// be up to 64 bits wide. If floating point, then it must be 16, 32, or 64 +// bits wide. 
+void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst, + const spv_parsed_operand_t& operand); + +} // namespace spvtools + +#endif // SOURCE_PARSED_OPERAND_H_ diff --git a/source/pch_source.cpp b/source/pch_source.cpp new file mode 100644 index 000000000..032e29ec4 --- /dev/null +++ b/source/pch_source.cpp @@ -0,0 +1,15 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "pch_source.h" diff --git a/source/pch_source.h b/source/pch_source.h new file mode 100644 index 000000000..6695ba268 --- /dev/null +++ b/source/pch_source.h @@ -0,0 +1,15 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/val/validation_state.h" diff --git a/source/print.cpp b/source/print.cpp new file mode 100644 index 000000000..f75e2d457 --- /dev/null +++ b/source/print.cpp @@ -0,0 +1,125 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/print.h" + +#if defined(SPIRV_ANDROID) || defined(SPIRV_LINUX) || defined(SPIRV_MAC) || \ + defined(SPIRV_FREEBSD) +namespace spvtools { + +clr::reset::operator const char*() { return "\x1b[0m"; } + +clr::grey::operator const char*() { return "\x1b[1;30m"; } + +clr::red::operator const char*() { return "\x1b[31m"; } + +clr::green::operator const char*() { return "\x1b[32m"; } + +clr::yellow::operator const char*() { return "\x1b[33m"; } + +clr::blue::operator const char*() { return "\x1b[34m"; } + +} // namespace spvtools +#elif defined(SPIRV_WINDOWS) +#include + +namespace spvtools { + +static void SetConsoleForegroundColorPrimary(HANDLE hConsole, WORD color) { + // Get screen buffer information from console handle + CONSOLE_SCREEN_BUFFER_INFO bufInfo; + GetConsoleScreenBufferInfo(hConsole, &bufInfo); + + // Get background color + color = WORD(color | (bufInfo.wAttributes & 0xfff0)); + + // Set foreground color + SetConsoleTextAttribute(hConsole, color); +} + +static void SetConsoleForegroundColor(WORD color) { + SetConsoleForegroundColorPrimary(GetStdHandle(STD_OUTPUT_HANDLE), color); + 
SetConsoleForegroundColorPrimary(GetStdHandle(STD_ERROR_HANDLE), color); +} + +clr::reset::operator const char*() { + if (isPrint) { + SetConsoleForegroundColor(0xf); + return ""; + } + return "\x1b[0m"; +} + +clr::grey::operator const char*() { + if (isPrint) { + SetConsoleForegroundColor(FOREGROUND_INTENSITY); + return ""; + } + return "\x1b[1;30m"; +} + +clr::red::operator const char*() { + if (isPrint) { + SetConsoleForegroundColor(FOREGROUND_RED); + return ""; + } + return "\x1b[31m"; +} + +clr::green::operator const char*() { + if (isPrint) { + SetConsoleForegroundColor(FOREGROUND_GREEN); + return ""; + } + return "\x1b[32m"; +} + +clr::yellow::operator const char*() { + if (isPrint) { + SetConsoleForegroundColor(FOREGROUND_RED | FOREGROUND_GREEN); + return ""; + } + return "\x1b[33m"; +} + +clr::blue::operator const char*() { + // Blue all by itself is hard to see against a black background (the + // default on command shell), or a medium blue background (the default + // on PowerShell). So increase its intensity. + + if (isPrint) { + SetConsoleForegroundColor(FOREGROUND_BLUE | FOREGROUND_INTENSITY); + return ""; + } + return "\x1b[94m"; +} + +} // namespace spvtools +#else +namespace spvtools { + +clr::reset::operator const char*() { return ""; } + +clr::grey::operator const char*() { return ""; } + +clr::red::operator const char*() { return ""; } + +clr::green::operator const char*() { return ""; } + +clr::yellow::operator const char*() { return ""; } + +clr::blue::operator const char*() { return ""; } + +} // namespace spvtools +#endif diff --git a/source/print.h b/source/print.h new file mode 100644 index 000000000..f31ba38e7 --- /dev/null +++ b/source/print.h @@ -0,0 +1,75 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_PRINT_H_ +#define SOURCE_PRINT_H_ + +#include +#include + +namespace spvtools { + +// Wrapper for out stream selection. +class out_stream { + public: + out_stream() : pStream(nullptr) {} + explicit out_stream(std::stringstream& stream) : pStream(&stream) {} + + std::ostream& get() { + if (pStream) { + return *pStream; + } + return std::cout; + } + + private: + std::stringstream* pStream; +}; + +namespace clr { +// Resets console color. +struct reset { + operator const char*(); + bool isPrint; +}; +// Sets console color to grey. +struct grey { + operator const char*(); + bool isPrint; +}; +// Sets console color to red. +struct red { + operator const char*(); + bool isPrint; +}; +// Sets console color to green. +struct green { + operator const char*(); + bool isPrint; +}; +// Sets console color to yellow. +struct yellow { + operator const char*(); + bool isPrint; +}; +// Sets console color to blue. +struct blue { + operator const char*(); + bool isPrint; +}; +} // namespace clr + +} // namespace spvtools + +#endif // SOURCE_PRINT_H_ diff --git a/source/reduce/CMakeLists.txt b/source/reduce/CMakeLists.txt new file mode 100644 index 000000000..0a6bce99c --- /dev/null +++ b/source/reduce/CMakeLists.txt @@ -0,0 +1,74 @@ +# Copyright (c) 2018 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set(SPIRV_TOOLS_REDUCE_SOURCES + change_operand_reduction_opportunity.h + change_operand_to_undef_reduction_opportunity.h + operand_to_const_reduction_pass.h + operand_to_undef_reduction_pass.h + operand_to_dominating_id_reduction_pass.h + reducer.h + reduction_opportunity.h + reduction_pass.h + reduction_util.h + remove_instruction_reduction_opportunity.h + remove_opname_instruction_reduction_pass.h + remove_unreferenced_instruction_reduction_pass.h + structured_loop_to_selection_reduction_opportunity.h + structured_loop_to_selection_reduction_pass.h + + change_operand_reduction_opportunity.cpp + change_operand_to_undef_reduction_opportunity.cpp + operand_to_const_reduction_pass.cpp + operand_to_undef_reduction_pass.cpp + operand_to_dominating_id_reduction_pass.cpp + reducer.cpp + reduction_opportunity.cpp + reduction_pass.cpp + reduction_util.cpp + remove_instruction_reduction_opportunity.cpp + remove_unreferenced_instruction_reduction_pass.cpp + remove_opname_instruction_reduction_pass.cpp + structured_loop_to_selection_reduction_opportunity.cpp + structured_loop_to_selection_reduction_pass.cpp + ) + +if(MSVC) + # Enable parallel builds across four cores for this lib + add_definitions(/MP4) +endif() + +spvtools_pch(SPIRV_TOOLS_REDUCE_SOURCES pch_source_reduce) + +add_library(SPIRV-Tools-reduce ${SPIRV_TOOLS_REDUCE_SOURCES}) + +spvtools_default_compile_options(SPIRV-Tools-reduce) +target_include_directories(SPIRV-Tools-reduce + PUBLIC ${spirv-tools_SOURCE_DIR}/include + PUBLIC ${SPIRV_HEADER_INCLUDE_DIR} + PRIVATE ${spirv-tools_BINARY_DIR} +) +# The 
reducer reuses a lot of functionality from the SPIRV-Tools library. +target_link_libraries(SPIRV-Tools-reduce + PUBLIC ${SPIRV_TOOLS} + PUBLIC SPIRV-Tools-opt) + +set_property(TARGET SPIRV-Tools-reduce PROPERTY FOLDER "SPIRV-Tools libraries") +spvtools_check_symbol_exports(SPIRV-Tools-reduce) + +if(ENABLE_SPIRV_TOOLS_INSTALL) + install(TARGETS SPIRV-Tools-reduce + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) +endif(ENABLE_SPIRV_TOOLS_INSTALL) diff --git a/source/reduce/change_operand_reduction_opportunity.cpp b/source/reduce/change_operand_reduction_opportunity.cpp new file mode 100644 index 000000000..5430d3e86 --- /dev/null +++ b/source/reduce/change_operand_reduction_opportunity.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "change_operand_reduction_opportunity.h" + +namespace spvtools { +namespace reduce { + +bool ChangeOperandReductionOpportunity::PreconditionHolds() { + // Check that the instruction still has the original operand. 
+ return inst_->NumOperands() > operand_index_ && + inst_->GetOperand(operand_index_).words[0] == original_id_ && + inst_->GetOperand(operand_index_).type == original_type_; +} + +void ChangeOperandReductionOpportunity::Apply() { + inst_->SetOperand(operand_index_, {new_id_}); +} + +} // namespace reduce +} // namespace spvtools diff --git a/source/reduce/change_operand_reduction_opportunity.h b/source/reduce/change_operand_reduction_opportunity.h new file mode 100644 index 000000000..7e1fc8e3b --- /dev/null +++ b/source/reduce/change_operand_reduction_opportunity.h @@ -0,0 +1,56 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_CHANGE_OPERAND_REDUCTION_OPPORTUNITY_H_ +#define SOURCE_REDUCE_CHANGE_OPERAND_REDUCTION_OPPORTUNITY_H_ + +#include "reduction_opportunity.h" +#include "source/opt/instruction.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace reduce { + +using namespace opt; + +// An opportunity to replace an id operand of an instruction with some other id. +class ChangeOperandReductionOpportunity : public ReductionOpportunity { + public: + // Constructs the opportunity to replace operand |operand_index| of |inst| + // with |new_id|. 
+ ChangeOperandReductionOpportunity(Instruction* inst, uint32_t operand_index, + uint32_t new_id) + : inst_(inst), + operand_index_(operand_index), + original_id_(inst->GetOperand(operand_index).words[0]), + original_type_(inst->GetOperand(operand_index).type), + new_id_(new_id) {} + + bool PreconditionHolds() override; + + protected: + void Apply() override; + + private: + Instruction* const inst_; + const uint32_t operand_index_; + const uint32_t original_id_; + const spv_operand_type_t original_type_; + const uint32_t new_id_; +}; + +} // namespace reduce +} // namespace spvtools + +#endif // SOURCE_REDUCE_CHANGE_OPERAND_REDUCTION_OPPORTUNITY_H_ diff --git a/source/reduce/change_operand_to_undef_reduction_opportunity.cpp b/source/reduce/change_operand_to_undef_reduction_opportunity.cpp new file mode 100644 index 000000000..8e33da661 --- /dev/null +++ b/source/reduce/change_operand_to_undef_reduction_opportunity.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/reduce/change_operand_to_undef_reduction_opportunity.h" + +#include "source/opt/ir_context.h" +#include "source/reduce/reduction_util.h" + +namespace spvtools { +namespace reduce { + +bool ChangeOperandToUndefReductionOpportunity::PreconditionHolds() { + // Check that the instruction still has the original operand. 
+ return inst_->NumOperands() > operand_index_ && + inst_->GetOperand(operand_index_).words[0] == original_id_; +} + +void ChangeOperandToUndefReductionOpportunity::Apply() { + auto operand = inst_->GetOperand(operand_index_); + auto operand_id = operand.words[0]; + auto operand_id_def = context_->get_def_use_mgr()->GetDef(operand_id); + auto operand_type_id = operand_id_def->type_id(); + // The opportunity should not exist unless this holds. + assert(operand_type_id); + auto undef_id = FindOrCreateGlobalUndef(context_, operand_type_id); + inst_->SetOperand(operand_index_, {undef_id}); +} + +} // namespace reduce +} // namespace spvtools diff --git a/source/reduce/change_operand_to_undef_reduction_opportunity.h b/source/reduce/change_operand_to_undef_reduction_opportunity.h new file mode 100644 index 000000000..ffd3155b0 --- /dev/null +++ b/source/reduce/change_operand_to_undef_reduction_opportunity.h @@ -0,0 +1,53 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_CHANGE_OPERAND_TO_UNDEF_REDUCTION_OPPORTUNITY_H_ +#define SOURCE_REDUCE_CHANGE_OPERAND_TO_UNDEF_REDUCTION_OPPORTUNITY_H_ + +#include "source/opt/instruction.h" +#include "source/reduce/reduction_opportunity.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace reduce { + +// An opportunity to replace an id operand of an instruction with undef. 
+class ChangeOperandToUndefReductionOpportunity : public ReductionOpportunity { + public: + // Constructs the opportunity to replace operand |operand_index| of |inst| + // with undef. + ChangeOperandToUndefReductionOpportunity(opt::IRContext* context, + opt::Instruction* inst, + uint32_t operand_index) + : context_(context), + inst_(inst), + operand_index_(operand_index), + original_id_(inst->GetOperand(operand_index).words[0]) {} + + bool PreconditionHolds() override; + + protected: + void Apply() override; + + private: + opt::IRContext* context_; + opt::Instruction* const inst_; + const uint32_t operand_index_; + const uint32_t original_id_; +}; + +} // namespace reduce +} // namespace spvtools + +#endif // SOURCE_REDUCE_CHANGE_OPERAND_TO_UNDEF_REDUCTION_OPPORTUNITY_H_ diff --git a/source/reduce/operand_to_const_reduction_pass.cpp b/source/reduce/operand_to_const_reduction_pass.cpp new file mode 100644 index 000000000..4d04506e1 --- /dev/null +++ b/source/reduce/operand_to_const_reduction_pass.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/reduce/operand_to_const_reduction_pass.h" + +#include "source/opt/instruction.h" +#include "source/reduce/change_operand_reduction_opportunity.h" + +namespace spvtools { +namespace reduce { + +using namespace opt; + +std::vector> +OperandToConstReductionPass::GetAvailableOpportunities( + opt::IRContext* context) const { + std::vector> result; + assert(result.empty()); + + // We first loop over all constants. This means that all the reduction + // opportunities to replace an operand with a particular constant will be + // contiguous, and in particular it means that multiple, incompatible + // reduction opportunities that try to replace the same operand with distinct + // constants are likely to be discontiguous. This is good because the + // reducer works in the spirit of delta debugging and tries applying large + // contiguous blocks of opportunities early on, and we want to avoid having a + // large block of incompatible opportunities if possible. + for (const auto& constant : context->GetConstants()) { + for (auto& function : *context->module()) { + for (auto& block : function) { + for (auto& inst : block) { + // We iterate through the operands using an explicit index (rather + // than using a lambda) so that we use said index in the construction + // of a ChangeOperandReductionOpportunity + for (uint32_t index = 0; index < inst.NumOperands(); index++) { + const auto& operand = inst.GetOperand(index); + if (spvIsInIdType(operand.type)) { + const auto id = operand.words[0]; + auto def = context->get_def_use_mgr()->GetDef(id); + if (spvOpcodeIsConstant(def->opcode())) { + // The argument is already a constant. + continue; + } + if (def->opcode() == SpvOpFunction) { + // The argument refers to a function, e.g. the function called + // by OpFunctionCall; avoid replacing this with a constant of + // the function's return type. 
+ continue; + } + auto type_id = def->type_id(); + if (type_id) { + if (constant->type_id() == type_id) { + result.push_back( + MakeUnique( + &inst, index, constant->result_id())); + } + } + } + } + } + } + } + } + return result; +} + +std::string OperandToConstReductionPass::GetName() const { + return "OperandToConstReductionPass"; +} + +} // namespace reduce +} // namespace spvtools diff --git a/source/reduce/operand_to_const_reduction_pass.h b/source/reduce/operand_to_const_reduction_pass.h new file mode 100644 index 000000000..4e7381e93 --- /dev/null +++ b/source/reduce/operand_to_const_reduction_pass.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_OPERAND_TO_CONST_REDUCTION_PASS_H_ +#define SOURCE_REDUCE_OPERAND_TO_CONST_REDUCTION_PASS_H_ + +#include "source/reduce/reduction_pass.h" + +namespace spvtools { +namespace reduce { + +// A reduction pass for replacing id operands of instructions with ids of +// constants. This reduces the extent to which ids of non-constants are used, +// paving the way for instructions that generate them to be eliminated by other +// passes. 
+class OperandToConstReductionPass : public ReductionPass { + public: + // Creates the reduction pass in the context of the given target environment + // |target_env| + explicit OperandToConstReductionPass(const spv_target_env target_env) + : ReductionPass(target_env) {} + + ~OperandToConstReductionPass() override = default; + + std::string GetName() const final; + + protected: + std::vector> GetAvailableOpportunities( + opt::IRContext* context) const final; + + private: +}; + +} // namespace reduce +} // namespace spvtools + +#endif // SOURCE_REDUCE_OPERAND_TO_CONST_REDUCTION_PASS_H_ diff --git a/source/reduce/operand_to_dominating_id_reduction_pass.cpp b/source/reduce/operand_to_dominating_id_reduction_pass.cpp new file mode 100644 index 000000000..9280a41dd --- /dev/null +++ b/source/reduce/operand_to_dominating_id_reduction_pass.cpp @@ -0,0 +1,114 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "operand_to_dominating_id_reduction_pass.h" +#include "change_operand_reduction_opportunity.h" +#include "source/opt/instruction.h" + +namespace spvtools { +namespace reduce { + +using namespace opt; + +std::vector> +OperandToDominatingIdReductionPass::GetAvailableOpportunities( + opt::IRContext* context) const { + std::vector> result; + + // Go through every instruction in every block, considering it as a potential + // dominator of other instructions. 
We choose this order for two reasons: + // + // (1) it is profitable for multiple opportunities to replace the same id x by + // different dominating ids y and z to be discontiguous, as they are + // incompatible. + // + // (2) We want to prioritise opportunities to replace an id with a more + // distant dominator. Intuitively, in a human-readable programming language + // if we have a complex expression e with many sub-expressions, we would like + // to prioritise replacing e with its smallest sub-expressions; generalising + // this idea to dominating ids this roughly corresponds to more distant + // dominators. + for (auto& function : *context->module()) { + for (auto dominating_block = function.begin(); + dominating_block != function.end(); ++dominating_block) { + for (auto& dominating_inst : *dominating_block) { + if (dominating_inst.HasResultId() && dominating_inst.type_id()) { + // Consider replacing any operand with matching type in a dominated + // instruction with the id generated by this instruction. + GetOpportunitiesForDominatingInst( + &result, &dominating_inst, dominating_block, &function, context); + } + } + } + } + return result; +} + +void OperandToDominatingIdReductionPass::GetOpportunitiesForDominatingInst( + std::vector>* opportunities, + opt::Instruction* candidate_dominator, + opt::Function::iterator candidate_dominator_block, opt::Function* function, + opt::IRContext* context) const { + assert(candidate_dominator->HasResultId()); + assert(candidate_dominator->type_id()); + auto dominator_analysis = context->GetDominatorAnalysis(function); + // SPIR-V requires a block to precede all blocks it dominates, so it suffices + // to search from the candidate dominator block onwards. 
+ for (auto block = candidate_dominator_block; block != function->end(); + ++block) { + if (!dominator_analysis->Dominates(&*candidate_dominator_block, &*block)) { + // If the candidate dominator block doesn't dominate this block then there + // cannot be any of the desired reduction opportunities in this block. + continue; + } + for (auto& inst : *block) { + // We iterate through the operands using an explicit index (rather + // than using a lambda) so that we use said index in the construction + // of a ChangeOperandReductionOpportunity + for (uint32_t index = 0; index < inst.NumOperands(); index++) { + const auto& operand = inst.GetOperand(index); + if (spvIsInIdType(operand.type)) { + const auto id = operand.words[0]; + auto def = context->get_def_use_mgr()->GetDef(id); + assert(def); + if (!context->get_instr_block(def)) { + // The definition does not come from a block; e.g. it might be a + // constant. It is thus not relevant to this pass. + continue; + } + // Sanity check that we don't get here if the argument is a constant. + assert(!context->get_constant_mgr()->GetConstantFromInst(def)); + if (def->type_id() != candidate_dominator->type_id()) { + // The types need to match. + continue; + } + if (candidate_dominator != def && + dominator_analysis->Dominates(candidate_dominator, def)) { + // A hit: the candidate dominator strictly dominates the definition. + opportunities->push_back( + MakeUnique( + &inst, index, candidate_dominator->result_id())); + } + } + } + } + } +} + +std::string OperandToDominatingIdReductionPass::GetName() const { + return "OperandToDominatingIdReductionPass"; +} + +} // namespace reduce +} // namespace spvtools diff --git a/source/reduce/operand_to_dominating_id_reduction_pass.h b/source/reduce/operand_to_dominating_id_reduction_pass.h new file mode 100644 index 000000000..36bb20112 --- /dev/null +++ b/source/reduce/operand_to_dominating_id_reduction_pass.h @@ -0,0 +1,59 @@ +// Copyright (c) 2018 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_OPERAND_TO_DOMINATING_ID_REDUCTION_PASS_H_ +#define SOURCE_REDUCE_OPERAND_TO_DOMINATING_ID_REDUCTION_PASS_H_ + +#include "reduction_pass.h" + +namespace spvtools { +namespace reduce { + +// A reduction pass that aims to bring to SPIR-V (and generalize) the idea from +// human-readable languages of e.g. replacing an expression with one of its +// arguments, (x + y) -> x, or with a reference to an identifier that was +// assigned to higher up in the program. The generalization of this is to +// replace an id with a different id of the same type defined in some +// dominating instruction. +// +// If id x is defined and then used several times, changing each use of x to +// some dominating definition may eventually allow the statement defining x +// to be eliminated by another pass. 
+class OperandToDominatingIdReductionPass : public ReductionPass { + public: + // Creates the reduction pass in the context of the given target environment + // |target_env| + explicit OperandToDominatingIdReductionPass(const spv_target_env target_env) + : ReductionPass(target_env) {} + + ~OperandToDominatingIdReductionPass() override = default; + + std::string GetName() const final; + + protected: + std::vector> GetAvailableOpportunities( + opt::IRContext* context) const final; + + private: + void GetOpportunitiesForDominatingInst( + std::vector>* opportunities, + opt::Instruction* dominating_instruction, + opt::Function::iterator candidate_dominator_block, + opt::Function* function, opt::IRContext* context) const; +}; + +} // namespace reduce +} // namespace spvtools + +#endif // SOURCE_REDUCE_OPERAND_TO_DOMINATING_ID_REDUCTION_PASS_H_ diff --git a/source/reduce/operand_to_undef_reduction_pass.cpp b/source/reduce/operand_to_undef_reduction_pass.cpp new file mode 100644 index 000000000..e3d8a8ea3 --- /dev/null +++ b/source/reduce/operand_to_undef_reduction_pass.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/reduce/operand_to_undef_reduction_pass.h" + +#include "source/opt/instruction.h" +#include "source/reduce/change_operand_to_undef_reduction_opportunity.h" + +namespace spvtools { +namespace reduce { + +using namespace opt; + +std::vector> +OperandToUndefReductionPass::GetAvailableOpportunities( + IRContext* context) const { + std::vector> result; + + for (auto& function : *context->module()) { + for (auto& block : function) { + for (auto& inst : block) { + // Skip instructions that result in a pointer type. + auto type_id = inst.type_id(); + if (type_id) { + auto type_id_def = context->get_def_use_mgr()->GetDef(type_id); + if (type_id_def->opcode() == SpvOpTypePointer) { + continue; + } + } + + // We iterate through the operands using an explicit index (rather + // than using a lambda) so that we use said index in the construction + // of a ChangeOperandToUndefReductionOpportunity + for (uint32_t index = 0; index < inst.NumOperands(); index++) { + const auto& operand = inst.GetOperand(index); + + if (spvIsInIdType(operand.type)) { + const auto operand_id = operand.words[0]; + auto operand_id_def = + context->get_def_use_mgr()->GetDef(operand_id); + + // Skip constant and undef operands. + // We always want the reducer to make the module "smaller", which + // ensures termination. + // Therefore, we assume: id > undef id > constant id. + if (spvOpcodeIsConstantOrUndef(operand_id_def->opcode())) { + continue; + } + + // Don't replace function operands with undef. + if (operand_id_def->opcode() == SpvOpFunction) { + continue; + } + + // Only consider operands that have a type. + auto operand_type_id = operand_id_def->type_id(); + if (operand_type_id) { + auto operand_type_id_def = + context->get_def_use_mgr()->GetDef(operand_type_id); + + // Skip pointer operands. 
+ if (operand_type_id_def->opcode() == SpvOpTypePointer) { + continue; + } + + result.push_back( + MakeUnique( + context, &inst, index)); + } + } + } + } + } + } + return result; +} + +std::string OperandToUndefReductionPass::GetName() const { + return "OperandToUndefReductionPass"; +} + +} // namespace reduce +} // namespace spvtools diff --git a/source/reduce/operand_to_undef_reduction_pass.h b/source/reduce/operand_to_undef_reduction_pass.h new file mode 100644 index 000000000..e4ec603fb --- /dev/null +++ b/source/reduce/operand_to_undef_reduction_pass.h @@ -0,0 +1,45 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_OPERAND_TO_UNDEF_REDUCTION_PASS_H_ +#define SOURCE_REDUCE_OPERAND_TO_UNDEF_REDUCTION_PASS_H_ + +#include "source/reduce/reduction_pass.h" + +namespace spvtools { +namespace reduce { + +// A reduction pass for replacing id operands of instructions with ids of undef. 
+class OperandToUndefReductionPass : public ReductionPass {
+ public:
+  // Creates the reduction pass in the context of the given target environment
+  // |target_env|.
+  explicit OperandToUndefReductionPass(const spv_target_env target_env)
+      : ReductionPass(target_env) {}
+
+  ~OperandToUndefReductionPass() override = default;
+
+  // Returns the name of this pass.
+  std::string GetName() const final;
+
+ protected:
+  // Returns the operand-to-undef opportunities found in the module owned by
+  // |context|.
+  std::vector<std::unique_ptr<ReductionOpportunity>> GetAvailableOpportunities(
+      opt::IRContext* context) const final;
+};
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_OPERAND_TO_UNDEF_REDUCTION_PASS_H_
diff --git a/source/reduce/pch_source_reduce.cpp b/source/reduce/pch_source_reduce.cpp
new file mode 100644
index 000000000..61e743645
--- /dev/null
+++ b/source/reduce/pch_source_reduce.cpp
@@ -0,0 +1,15 @@
+// Copyright (c) 2018 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pch_source_reduce.h"
diff --git a/source/reduce/pch_source_reduce.h b/source/reduce/pch_source_reduce.h
new file mode 100644
index 000000000..823b55a8a
--- /dev/null
+++ b/source/reduce/pch_source_reduce.h
@@ -0,0 +1,23 @@
+// Copyright (c) 2018 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <functional>
+#include <string>
+#include <vector>
+#include "source/reduce/change_operand_reduction_opportunity.h"
+#include "source/reduce/operand_to_const_reduction_pass.h"
+#include "source/reduce/reduction_opportunity.h"
+#include "source/reduce/reduction_pass.h"
+#include "source/reduce/remove_instruction_reduction_opportunity.h"
+#include "source/reduce/remove_unreferenced_instruction_reduction_pass.h"
diff --git a/source/reduce/reducer.cpp b/source/reduce/reducer.cpp
new file mode 100644
index 000000000..4f4429aab
--- /dev/null
+++ b/source/reduce/reducer.cpp
@@ -0,0 +1,154 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cassert>
+#include <sstream>
+
+#include "source/spirv_reducer_options.h"
+
+#include "reducer.h"
+#include "reduction_pass.h"
+
+namespace spvtools {
+namespace reduce {
+
+// Internal state of a Reducer, kept out of the public header.
+struct Reducer::Impl {
+  explicit Impl(spv_target_env env) : target_env(env) {}
+
+  // Returns true if |current_step| has reached the step limit carried by
+  // |options|.
+  bool ReachedStepLimit(uint32_t current_step,
+                        spv_const_reducer_options options);
+
+  // Sends |message| to the consumer at SPV_MSG_INFO level. Does nothing when
+  // no consumer has been set, so that a Reducer without a consumer really
+  // does ignore messages instead of invoking an empty callable.
+  // NOTE(review): relies on MessageConsumer being testable for emptiness
+  // (std::function-like) -- confirm against its declaration.
+  void Info(const char* message) const {
+    if (consumer) {
+      consumer(SPV_MSG_INFO, nullptr, {}, message);
+    }
+  }
+  void Info(const std::string& message) const { Info(message.c_str()); }
+
+  const spv_target_env target_env;  // Target environment.
+  MessageConsumer consumer;         // Message consumer.
+  InterestingnessFunction interestingness_function;
+  std::vector<std::unique_ptr<ReductionPass>> passes;
+};
+
+Reducer::Reducer(spv_target_env env) : impl_(MakeUnique<Impl>(env)) {}
+
+Reducer::~Reducer() = default;
+
+// Installs |c| on every pass added so far and remembers it for the reducer's
+// own progress messages.
+void Reducer::SetMessageConsumer(MessageConsumer c) {
+  for (auto& pass : impl_->passes) {
+    pass->SetMessageConsumer(c);
+  }
+  impl_->consumer = std::move(c);
+}
+
+void Reducer::SetInterestingnessFunction(
+    Reducer::InterestingnessFunction interestingness_function) {
+  impl_->interestingness_function = std::move(interestingness_function);
+}
+
+// Repeatedly applies the registered passes to |binary_in|, keeping a candidate
+// only when it validates and the interestingness function accepts it.
+// |binary_in| is consumed. |*binary_out| receives the final binary, except
+// when the initial binary is not interesting, in which case it is not written.
+// Returns kInitialStateNotInteresting, kReachedStepLimit or kComplete.
+Reducer::ReductionResultStatus Reducer::Run(
+    std::vector<uint32_t>&& binary_in, std::vector<uint32_t>* binary_out,
+    spv_const_reducer_options options) const {
+  // A named rvalue reference is an lvalue, so an explicit move is required to
+  // avoid copying the whole module.
+  std::vector<uint32_t> current_binary = std::move(binary_in);
+
+  // Keeps track of how many reduction attempts have been tried.  Reduction
+  // bails out if this reaches a given limit.
+  uint32_t reductions_applied = 0;
+
+  // Initial state should be interesting.
+  if (!impl_->interestingness_function(current_binary, reductions_applied)) {
+    impl_->Info("Initial state was not interesting; stopping.");
+    return Reducer::ReductionResultStatus::kInitialStateNotInteresting;
+  }
+
+  // Determines whether, on completing one round of reduction passes, it is
+  // worthwhile trying a further round.
+  bool another_round_worthwhile = true;
+
+  // Apply round after round of reduction passes until we hit the reduction
+  // step limit, or deem that another round is not going to be worthwhile.
+  while (!impl_->ReachedStepLimit(reductions_applied, options) &&
+         another_round_worthwhile) {
+    // At the start of a round of reduction passes, assume another round will
+    // not be worthwhile unless we find evidence to the contrary.
+    another_round_worthwhile = false;
+
+    // Iterate through the available passes
+    for (auto& pass : impl_->passes) {
+      // If this pass hasn't reached its minimum granularity then it's
+      // worth eventually doing another round of reductions, in order to
+      // try this pass at a finer granularity.
+      another_round_worthwhile |= !pass->ReachedMinimumGranularity();
+
+      // Keep applying this pass at its current granularity until it stops
+      // working or we hit the reduction step limit.
+      impl_->Info("Trying pass " + pass->GetName() + ".");
+      do {
+        auto maybe_result = pass->TryApplyReduction(current_binary);
+        if (maybe_result.empty()) {
+          // This pass did not have any impact, so move on to the next pass.
+          impl_->Info("Pass " + pass->GetName() +
+                      " did not make a reduction step.");
+          break;
+        }
+        std::stringstream stringstream;
+        reductions_applied++;
+        stringstream << "Pass " << pass->GetName() << " made reduction step "
+                     << reductions_applied << ".";
+        impl_->Info(stringstream.str());
+        if (!spvtools::SpirvTools(impl_->target_env).Validate(maybe_result)) {
+          // The reduction step went wrong and an invalid binary was produced.
+          // By design, this shouldn't happen; this is a safeguard to stop an
+          // invalid binary from being regarded as interesting.
+          impl_->Info("Reduction step produced an invalid binary.");
+        } else if (impl_->interestingness_function(maybe_result,
+                                                   reductions_applied)) {
+          // Success!  The binary produced by this reduction step is
+          // interesting, so make it the binary of interest henceforth, and
+          // note that it's worth doing another round of reduction passes.
+          impl_->Info("Reduction step succeeded.");
+          current_binary = std::move(maybe_result);
+          another_round_worthwhile = true;
+        }
+        // Bail out if the reduction step limit has been reached.
+      } while (!impl_->ReachedStepLimit(reductions_applied, options));
+    }
+  }
+
+  *binary_out = std::move(current_binary);
+
+  // Report whether reduction completed, or bailed out early due to reaching
+  // the step limit.
+  if (impl_->ReachedStepLimit(reductions_applied, options)) {
+    impl_->Info("Reached reduction step limit; stopping.");
+    return Reducer::ReductionResultStatus::kReachedStepLimit;
+  }
+  impl_->Info("No more to reduce; stopping.");
+  return Reducer::ReductionResultStatus::kComplete;
+}
+
+// Appends |reduction_pass| to the sequence of passes tried by Run().
+void Reducer::AddReductionPass(
+    std::unique_ptr<ReductionPass>&& reduction_pass) {
+  // Hand the pass any consumer installed before it was added, so that the
+  // relative order of SetMessageConsumer() and AddReductionPass() does not
+  // decide whether the pass can report messages.
+  if (reduction_pass && impl_->consumer) {
+    reduction_pass->SetMessageConsumer(impl_->consumer);
+  }
+  impl_->passes.push_back(std::move(reduction_pass));
+}
+
+bool Reducer::Impl::ReachedStepLimit(uint32_t current_step,
+                                     spv_const_reducer_options options) {
+  return current_step >= options->step_limit;
+}
+
+}  // namespace reduce
+}  // namespace spvtools
diff --git a/source/reduce/reducer.h b/source/reduce/reducer.h
new file mode 100644
index 000000000..3a4c26c97
--- /dev/null
+++ b/source/reduce/reducer.h
@@ -0,0 +1,97 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_REDUCER_H_
+#define SOURCE_REDUCE_REDUCER_H_
+
+#include <functional>
+#include <string>
+
+#include "spirv-tools/libspirv.hpp"
+
+#include "reduction_pass.h"
+
+namespace spvtools {
+namespace reduce {
+
+// This class manages the process of applying a reduction -- parameterized by a
+// number of reduction passes and an interestingness test, to a SPIR-V binary.
+class Reducer {
+ public:
+  // Possible statuses that can result from running a reduction.
+  enum ReductionResultStatus {
+    kInitialStateNotInteresting,
+    kReachedStepLimit,
+    kComplete
+  };
+
+  // The type for a function that will take a binary and return true if and
+  // only if the binary is deemed interesting. (The function also takes an
+  // integer argument that will be incremented each time the function is
+  // called; this is for debugging purposes).
+  //
+  // The notion of "interesting" depends on what properties of the binary or
+  // tools that process the binary we are trying to maintain during reduction.
+  using InterestingnessFunction =
+      std::function<bool(const std::vector<uint32_t>&, uint32_t)>;
+
+  // Constructs an instance with the given target |env|, which is used to
+  // decode the binary to be reduced later.
+  //
+  // The constructed instance will have an empty message consumer, which just
+  // ignores all messages from the library. Use SetMessageConsumer() to supply
+  // one if messages are of concern.
+  //
+  // The constructed instance also needs to have an interestingness function
+  // set and some reduction passes added to it in order to be useful.
+  explicit Reducer(spv_target_env env);
+
+  // Disables copy/move constructor/assignment operations.
+  Reducer(const Reducer&) = delete;
+  Reducer(Reducer&&) = delete;
+  Reducer& operator=(const Reducer&) = delete;
+  Reducer& operator=(Reducer&&) = delete;
+
+  // Destructs this instance.
+  ~Reducer();
+
+  // Sets the message consumer to the given |consumer|. The |consumer| will be
+  // invoked once for each message communicated from the library.
+  void SetMessageConsumer(MessageConsumer consumer);
+
+  // Sets the function that will be used to decide whether a reduced binary
+  // turned out to be interesting.
+  void SetInterestingnessFunction(
+      InterestingnessFunction interestingness_function);
+
+  // Adds a reduction pass to the sequence of passes that will be iterated
+  // over.
+  void AddReductionPass(std::unique_ptr<ReductionPass>&& reduction_pass);
+
+  // Reduces the given SPIR-V module |binary_in|, which is passed by rvalue
+  // reference.
+  // The reduced binary ends up in |binary_out|.
+  // A status is returned.
+  ReductionResultStatus Run(std::vector<uint32_t>&& binary_in,
+                            std::vector<uint32_t>* binary_out,
+                            spv_const_reducer_options options) const;
+
+ private:
+  struct Impl;                  // Opaque struct for holding internal data.
+  std::unique_ptr<Impl> impl_;  // Unique pointer to internal data.
+};
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_REDUCER_H_
diff --git a/source/reduce/reduction_opportunity.cpp b/source/reduce/reduction_opportunity.cpp
new file mode 100644
index 000000000..f562678dc
--- /dev/null
+++ b/source/reduce/reduction_opportunity.cpp
@@ -0,0 +1,27 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reduction_opportunity.h"
+
+namespace spvtools {
+namespace reduce {
+
+// Applies the opportunity only if its precondition still holds; otherwise does
+// nothing.
+void ReductionOpportunity::TryToApply() {
+  if (PreconditionHolds()) {
+    Apply();
+  }
+}
+
+}  // namespace reduce
+}  // namespace spvtools
diff --git a/source/reduce/reduction_opportunity.h b/source/reduce/reduction_opportunity.h
new file mode 100644
index 000000000..703a50a46
--- /dev/null
+++ b/source/reduce/reduction_opportunity.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_REDUCTION_OPPORTUNITY_H_
+#define SOURCE_REDUCE_REDUCTION_OPPORTUNITY_H_
+
+#include "spirv-tools/libspirv.hpp"
+
+namespace spvtools {
+namespace reduce {
+
+// Abstract class: an opportunity to apply a reducing transformation.
+class ReductionOpportunity {
+ public:
+  ReductionOpportunity() = default;
+  virtual ~ReductionOpportunity() = default;
+
+  // Returns true if this opportunity has not been disabled by the application
+  // of another conflicting opportunity.
+  virtual bool PreconditionHolds() = 0;
+
+  // Applies the opportunity, mutating the module from which the opportunity was
+  // created. It is a no-op if PreconditionHolds() returns false.
+  void TryToApply();
+
+ protected:
+  // Applies the opportunity, mutating the module from which the opportunity was
+  // created.
+  // Precondition: PreconditionHolds() must return true.
+  virtual void Apply() = 0;
+};
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_REDUCTION_OPPORTUNITY_H_
diff --git a/source/reduce/reduction_pass.cpp b/source/reduce/reduction_pass.cpp
new file mode 100644
index 000000000..befba8bc6
--- /dev/null
+++ b/source/reduce/reduction_pass.cpp
@@ -0,0 +1,86 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+
+#include "reduction_pass.h"
+
+#include "source/opt/build_module.h"
+
+namespace spvtools {
+namespace reduce {
+
+// Re-parses |binary|, applies the next chunk of up to |granularity_|
+// opportunities starting at |index_|, and returns the resulting binary.
+// Returns an empty vector when there are no opportunities, or when the index
+// has run past the available opportunities; in the latter case the index is
+// reset and the granularity is halved (never below 1) for the next sweep.
+std::vector<uint32_t> ReductionPass::TryApplyReduction(
+    const std::vector<uint32_t>& binary) {
+  // We represent modules as binaries because (a) attempts at reduction need to
+  // end up in binary form to be passed on to SPIR-V-consuming tools, and (b)
+  // when we apply a reduction step we need to do it on a fresh version of the
+  // module as if the reduction step proves to be uninteresting we need to
+  // backtrack; re-parsing from binary provides a very clean way of cloning the
+  // module.
+  std::unique_ptr<opt::IRContext> context =
+      BuildModule(target_env_, consumer_, binary.data(), binary.size());
+  assert(context);
+
+  std::vector<std::unique_ptr<ReductionOpportunity>> opportunities =
+      GetAvailableOpportunities(context.get());
+  const uint32_t num_opportunities =
+      static_cast<uint32_t>(opportunities.size());
+
+  if (!is_initialized_) {
+    // First use: start at the coarsest granularity, i.e. try to apply every
+    // opportunity in one go.
+    is_initialized_ = true;
+    index_ = 0;
+    granularity_ = num_opportunities;
+  }
+
+  if (opportunities.empty()) {
+    granularity_ = 1;
+    return std::vector<uint32_t>();
+  }
+
+  assert(granularity_ > 0);
+
+  if (index_ >= opportunities.size()) {
+    index_ = 0;
+    granularity_ = std::max(static_cast<uint32_t>(1), granularity_ / 2);
+    return std::vector<uint32_t>();
+  }
+
+  // Neither bound changes while opportunities are applied, so the end of the
+  // chunk is computed once.
+  const uint32_t chunk_end =
+      std::min(index_ + granularity_, num_opportunities);
+  for (uint32_t i = index_; i < chunk_end; ++i) {
+    opportunities[i]->TryToApply();
+  }
+
+  index_ += granularity_;
+
+  std::vector<uint32_t> result;
+  context->module()->ToBinary(&result, false);
+  return result;
+}
+
+void ReductionPass::SetMessageConsumer(MessageConsumer consumer) {
+  consumer_ = std::move(consumer);
+}
+
+bool ReductionPass::ReachedMinimumGranularity() const {
+  if (!is_initialized_) {
+    // Conceptually we can think that if the pass has not yet been initialized,
+    // it is operating at unbounded granularity.
+    return false;
+  }
+  assert(granularity_ != 0);
+  return granularity_ == 1;
+}
+
+}  // namespace reduce
+}  // namespace spvtools
diff --git a/source/reduce/reduction_pass.h b/source/reduce/reduction_pass.h
new file mode 100644
index 000000000..57e1c5f66
--- /dev/null
+++ b/source/reduce/reduction_pass.h
@@ -0,0 +1,73 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_REDUCTION_PASS_H_
+#define SOURCE_REDUCE_REDUCTION_PASS_H_
+
+#include "spirv-tools/libspirv.hpp"
+
+#include "reduction_opportunity.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace reduce {
+
+// Abstract class representing a reduction pass, which can be repeatedly
+// invoked to find and apply particular reduction opportunities to a SPIR-V
+// binary.  In the spirit of delta debugging, a pass initially tries to apply
+// large chunks of reduction opportunities, iterating through available
+// opportunities at a given granularity.  When an iteration over available
+// opportunities completes, the granularity is reduced and iteration starts
+// again, until the minimum granularity is reached.
+class ReductionPass {
+ public:
+  // Constructs a reduction pass with a given target environment, |target_env|.
+  // Initially the pass is uninitialized. The index and granularity are given
+  // defined (zero) values here; they receive their real values on the first
+  // call to TryApplyReduction().
+  explicit ReductionPass(const spv_target_env target_env)
+      : target_env_(target_env),
+        is_initialized_(false),
+        index_(0),
+        granularity_(0) {}
+
+  virtual ~ReductionPass() = default;
+
+  // Applies the reduction pass to the given binary.
+  std::vector<uint32_t> TryApplyReduction(const std::vector<uint32_t>& binary);
+
+  // Sets a consumer to which relevant messages will be directed.
+  void SetMessageConsumer(MessageConsumer consumer);
+
+  // Returns true if the granularity with which reduction opportunities are
+  // applied has reached a minimum.
+  bool ReachedMinimumGranularity() const;
+
+  // Returns the name of the reduction pass (useful for monitoring reduction
+  // progress).
+  virtual std::string GetName() const = 0;
+
+ protected:
+  // Finds and returns the reduction opportunities relevant to this pass that
+  // could be applied to the given SPIR-V module.
+  virtual std::vector<std::unique_ptr<ReductionOpportunity>>
+  GetAvailableOpportunities(opt::IRContext* context) const = 0;
+
+ private:
+  const spv_target_env target_env_;
+  MessageConsumer consumer_;
+  bool is_initialized_;   // False until the first TryApplyReduction() call.
+  uint32_t index_;        // Index of the first opportunity in the next chunk.
+  uint32_t granularity_;  // Maximum number of opportunities per chunk.
+};
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_REDUCTION_PASS_H_
diff --git a/source/reduce/reduction_util.cpp b/source/reduce/reduction_util.cpp
new file mode 100644
index 000000000..103d63f19
--- /dev/null
+++ b/source/reduce/reduction_util.cpp
@@ -0,0 +1,44 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/reduce/reduction_util.h"
+
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace reduce {
+
+using namespace opt;
+
+// Returns the result id of the first OpUndef in the module's types/values
+// section whose type is |type_id|. If there is none, takes a fresh id, adds a
+// new global OpUndef of that type to the module and returns the new id.
+uint32_t FindOrCreateGlobalUndef(IRContext* context, uint32_t type_id) {
+  for (auto& inst : context->module()->types_values()) {
+    if (inst.opcode() != SpvOpUndef) {
+      continue;
+    }
+    if (inst.type_id() == type_id) {
+      return inst.result_id();
+    }
+  }
+  // TODO(2182): this is adapted from MemPass::Type2Undef.  In due course it
+  // would be good to factor out this duplication.
+  const uint32_t undef_id = context->TakeNextId();
+  std::unique_ptr<Instruction> undef_inst(
+      new Instruction(context, SpvOpUndef, type_id, undef_id, {}));
+  assert(undef_id == undef_inst->result_id());
+  context->module()->AddGlobalValue(std::move(undef_inst));
+  return undef_id;
+}
+
+}  // namespace reduce
+}  // namespace spvtools
diff --git a/source/reduce/reduction_util.h b/source/reduce/reduction_util.h
new file mode 100644
index 000000000..d8efc9705
--- /dev/null
+++ b/source/reduce/reduction_util.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_REDUCTION_UTIL_H_
+#define SOURCE_REDUCE_REDUCTION_UTIL_H_
+
+#include "spirv-tools/libspirv.hpp"
+
+#include "source/opt/ir_context.h"
+#include "source/reduce/reduction_opportunity.h"
+
+namespace spvtools {
+namespace reduce {
+
+// Returns an OpUndef id from the global value list that is of the given type,
+// adding one if it does not exist.
+uint32_t FindOrCreateGlobalUndef(opt::IRContext* context, uint32_t type_id);
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_REDUCTION_UTIL_H_
diff --git a/source/reduce/remove_instruction_reduction_opportunity.cpp b/source/reduce/remove_instruction_reduction_opportunity.cpp
new file mode 100644
index 000000000..7b7a74e48
--- /dev/null
+++ b/source/reduce/remove_instruction_reduction_opportunity.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/ir_context.h"
+
+#include "remove_instruction_reduction_opportunity.h"
+
+namespace spvtools {
+namespace reduce {
+
+// Removal has no precondition: this always reports true.
+bool RemoveInstructionReductionOpportunity::PreconditionHolds() {
+  return true;
+}
+
+// Asks the context that owns the instruction to kill it.
+void RemoveInstructionReductionOpportunity::Apply() {
+  auto* owning_context = inst_->context();
+  owning_context->KillInst(inst_);
+}
+
+}  // namespace reduce
+}  // namespace spvtools
diff --git a/source/reduce/remove_instruction_reduction_opportunity.h b/source/reduce/remove_instruction_reduction_opportunity.h
new file mode 100644
index 000000000..e9f442e3f
--- /dev/null
+++ b/source/reduce/remove_instruction_reduction_opportunity.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_REMOVE_INSTRUCTION_REDUCTION_OPPORTUNITY_H_
+#define SOURCE_REDUCE_REMOVE_INSTRUCTION_REDUCTION_OPPORTUNITY_H_
+
+#include "reduction_opportunity.h"
+#include "source/opt/instruction.h"
+
+namespace spvtools {
+namespace reduce {
+
+// NOTE(review): a using-directive in a header is visible to every file that
+// includes it; other sources may rely on it, so it is kept as is.
+using namespace opt;
+
+// Reduction opportunity that deletes a single instruction from the SPIR-V
+// module.
+class RemoveInstructionReductionOpportunity : public ReductionOpportunity {
+ public:
+  // Records |inst| as the instruction to be removed. The pointer is stored,
+  // not copied.
+  explicit RemoveInstructionReductionOpportunity(Instruction* inst)
+      : inst_(inst) {}
+
+  // This opportunity can always be applied, so this always returns true.
+  bool PreconditionHolds() override;
+
+ protected:
+  // Removes the recorded instruction.
+  void Apply() override;
+
+ private:
+  // The instruction to remove.
+  Instruction* inst_;
+};
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_REMOVE_INSTRUCTION_REDUCTION_OPPORTUNITY_H_
diff --git a/source/reduce/remove_opname_instruction_reduction_pass.cpp b/source/reduce/remove_opname_instruction_reduction_pass.cpp
new file mode 100644
index 000000000..bf99bc57f
--- /dev/null
+++ b/source/reduce/remove_opname_instruction_reduction_pass.cpp
@@ -0,0 +1,44 @@
+// Copyright (c) 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "remove_opname_instruction_reduction_pass.h"
+#include "remove_instruction_reduction_opportunity.h"
+#include "source/opcode.h"
+#include "source/opt/instruction.h"
+
+namespace spvtools {
+namespace reduce {
+
+using namespace opt;
+
+// Returns one removal opportunity for every OpName and OpMemberName
+// instruction in the module's debugs2() section.
+std::vector<std::unique_ptr<ReductionOpportunity>>
+RemoveOpNameInstructionReductionPass::GetAvailableOpportunities(
+    opt::IRContext* context) const {
+  std::vector<std::unique_ptr<ReductionOpportunity>> result;
+
+  for (auto& inst : context->module()->debugs2()) {
+    if (inst.opcode() == SpvOpName || inst.opcode() == SpvOpMemberName) {
+      result.push_back(
+          MakeUnique<RemoveInstructionReductionOpportunity>(&inst));
+    }
+  }
+  return result;
+}
+
+// Returns the fixed, human-readable name of this pass.
+std::string RemoveOpNameInstructionReductionPass::GetName() const {
+  return "RemoveOpNameInstructionReductionPass";
+}
+
+}  // namespace reduce
+}  // namespace spvtools
diff --git a/source/reduce/remove_opname_instruction_reduction_pass.h b/source/reduce/remove_opname_instruction_reduction_pass.h
new file mode 100644
index 000000000..2990a49cc
--- /dev/null
+++ b/source/reduce/remove_opname_instruction_reduction_pass.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_REMOVE_OPNAME_INSTRUCTION_REDUCTION_PASS_H_
+#define SOURCE_REDUCE_REMOVE_OPNAME_INSTRUCTION_REDUCTION_PASS_H_
+
+#include "reduction_pass.h"
+
+namespace spvtools {
+namespace reduce {
+
+// A reduction pass for removing OpName instructions.  As well as making the
+// module smaller, removing an OpName instruction may create opportunities
+// for subsequently removing the instructions that create the ids to which the
+// OpName applies.
+class RemoveOpNameInstructionReductionPass : public ReductionPass {
+ public:
+  // Creates the reduction pass in the context of the given target environment
+  // |target_env|.
+  explicit RemoveOpNameInstructionReductionPass(const spv_target_env target_env)
+      : ReductionPass(target_env) {}
+
+  ~RemoveOpNameInstructionReductionPass() override = default;
+
+  // Returns the name of this pass.
+  std::string GetName() const final;
+
+ protected:
+  // Returns the name-removal opportunities found in the module owned by
+  // |context|.
+  std::vector<std::unique_ptr<ReductionOpportunity>> GetAvailableOpportunities(
+      opt::IRContext* context) const final;
+};
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_REMOVE_OPNAME_INSTRUCTION_REDUCTION_PASS_H_
diff --git a/source/reduce/remove_unreferenced_instruction_reduction_pass.cpp b/source/reduce/remove_unreferenced_instruction_reduction_pass.cpp
new file mode 100644
index 000000000..bc4998d6a
--- /dev/null
+++ b/source/reduce/remove_unreferenced_instruction_reduction_pass.cpp
@@ -0,0 +1,60 @@
+// Copyright (c) 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "remove_unreferenced_instruction_reduction_pass.h"
+#include "remove_instruction_reduction_opportunity.h"
+#include "source/opcode.h"
+#include "source/opt/instruction.h"
+
+namespace spvtools {
+namespace reduce {
+
+using namespace opt;
+
+// Returns one removal opportunity for every instruction inside a block that
+// has no uses according to the def-use manager and is neither a block
+// terminator nor an OpSelectionMerge / OpLoopMerge.
+std::vector<std::unique_ptr<ReductionOpportunity>>
+RemoveUnreferencedInstructionReductionPass::GetAvailableOpportunities(
+    opt::IRContext* context) const {
+  std::vector<std::unique_ptr<ReductionOpportunity>> result;
+
+  // Nothing is modified while opportunities are collected, so the def-use
+  // manager is looked up once rather than once per instruction.
+  auto* def_use_mgr = context->get_def_use_mgr();
+
+  for (auto& function : *context->module()) {
+    for (auto& block : function) {
+      for (auto& inst : block) {
+        if (def_use_mgr->NumUses(&inst) > 0) {
+          continue;
+        }
+        if (spvOpcodeIsBlockTerminator(inst.opcode()) ||
+            inst.opcode() == SpvOpSelectionMerge ||
+            inst.opcode() == SpvOpLoopMerge) {
+          // In this reduction pass we do not want to affect static control
+          // flow.
+          continue;
+        }
+        // Given that we're in a block, we should only get here if the
+        // instruction is not directly related to control flow; i.e., it's
+        // some straightforward instruction with an unused result, like an
+        // arithmetic operation or function call.
+        result.push_back(
+            MakeUnique<RemoveInstructionReductionOpportunity>(&inst));
+      }
+    }
+  }
+  return result;
+}
+
+// Returns the fixed, human-readable name of this pass.
+std::string RemoveUnreferencedInstructionReductionPass::GetName() const {
+  return "RemoveUnreferencedInstructionReductionPass";
+}
+
+}  // namespace reduce
+}  // namespace spvtools
diff --git a/source/reduce/remove_unreferenced_instruction_reduction_pass.h b/source/reduce/remove_unreferenced_instruction_reduction_pass.h
new file mode 100644
index 000000000..44a0d55e9
--- /dev/null
+++ b/source/reduce/remove_unreferenced_instruction_reduction_pass.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_PASS_H_
+#define SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_PASS_H_
+
+#include "reduction_pass.h"
+
+namespace spvtools {
+namespace reduce {
+
+// A reduction pass for removing non-control-flow instructions in blocks in
+// cases where the instruction's id is not referenced. As well as making the
+// module smaller, removing an instruction that references particular ids may
+// create opportunities for subsequently removing the instructions that
+// generated those ids.
+class RemoveUnreferencedInstructionReductionPass : public ReductionPass {
+ public:
+  // Creates the reduction pass in the context of the given target environment
+  // |target_env|.
+  explicit RemoveUnreferencedInstructionReductionPass(
+      const spv_target_env target_env)
+      : ReductionPass(target_env) {}
+
+  ~RemoveUnreferencedInstructionReductionPass() override = default;
+
+  std::string GetName() const final;
+
+ protected:
+  std::vector<std::unique_ptr<ReductionOpportunity>> GetAvailableOpportunities(
+      opt::IRContext* context) const final;
+
+ private:
+};
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_REMOVE_UNREFERENCED_INSTRUCTION_REDUCTION_PASS_H_
diff --git a/source/reduce/structured_loop_to_selection_reduction_opportunity.cpp b/source/reduce/structured_loop_to_selection_reduction_opportunity.cpp
new file mode 100644
index 000000000..bf0e085ad
--- /dev/null
+++ b/source/reduce/structured_loop_to_selection_reduction_opportunity.cpp
@@ -0,0 +1,360 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed
under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/reduce/structured_loop_to_selection_reduction_opportunity.h" + +#include "source/opt/aggressive_dead_code_elim_pass.h" +#include "source/opt/ir_context.h" +#include "source/reduce/reduction_util.h" + +namespace spvtools { +namespace reduce { + +namespace { +const uint32_t kMergeNodeIndex = 0; +const uint32_t kContinueNodeIndex = 1; +} // namespace + +bool StructuredLoopToSelectionReductionOpportunity::PreconditionHolds() { + // Is the loop header reachable? + return loop_construct_header_->GetLabel() + ->context() + ->GetDominatorAnalysis(enclosing_function_) + ->IsReachable(loop_construct_header_); +} + +void StructuredLoopToSelectionReductionOpportunity::Apply() { + // Force computation of dominator analysis, CFG and structured CFG analysis + // before we start to mess with edges in the function. + context_->GetDominatorAnalysis(enclosing_function_); + context_->cfg(); + context_->GetStructuredCFGAnalysis(); + + // (1) Redirect edges that point to the loop's continue target to their + // closest merge block. + RedirectToClosestMergeBlock( + loop_construct_header_->GetLoopMergeInst()->GetSingleWordOperand( + kContinueNodeIndex)); + + // (2) Redirect edges that point to the loop's merge block to their closest + // merge block (which might be that of an enclosing selection, for instance). 
+ RedirectToClosestMergeBlock( + loop_construct_header_->GetLoopMergeInst()->GetSingleWordOperand( + kMergeNodeIndex)); + + // (3) Turn the loop construct header into a selection. + ChangeLoopToSelection(); + + // We have made control flow changes that do not preserve the analyses that + // were performed. + context_->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone); + + // (4) By changing CFG edges we may have created scenarios where ids are used + // without being dominated; we fix instances of this. + FixNonDominatedIdUses(); + + // Invalidate the analyses we just used. + context_->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone); +} + +void StructuredLoopToSelectionReductionOpportunity::RedirectToClosestMergeBlock( + uint32_t original_target_id) { + // Consider every predecessor of the node with respect to which edges should + // be redirected. + std::set already_seen; + for (auto pred : context_->cfg()->preds(original_target_id)) { + if (already_seen.find(pred) != already_seen.end()) { + // We have already handled this predecessor (this scenario can arise if + // there are multiple edges from a block b to original_target_id). + continue; + } + already_seen.insert(pred); + + if (!context_->GetDominatorAnalysis(enclosing_function_) + ->IsReachable(pred)) { + // We do not care about unreachable predecessors (and dominance + // information, and thus the notion of structured control flow, makes + // little sense for unreachable blocks). + continue; + } + // Find the merge block of the structured control construct that most + // tightly encloses the predecessor. + uint32_t new_merge_target; + // The structured CFG analysis deliberately does not regard a header as + // belonging to the structure that it heads. We want it to, so handle this + // case specially. 
+ if (context_->cfg()->block(pred)->MergeBlockIdIfAny()) { + new_merge_target = context_->cfg()->block(pred)->MergeBlockIdIfAny(); + } else { + new_merge_target = context_->GetStructuredCFGAnalysis()->MergeBlock(pred); + } + assert(new_merge_target != pred); + + if (!new_merge_target) { + // If the loop being transformed is outermost, and the predecessor is + // part of that loop's continue construct, there will be no such + // enclosing control construct. In this case, the continue construct + // will become unreachable anyway, so it is fine not to redirect the + // edge. + continue; + } + + if (new_merge_target != original_target_id) { + // Redirect the edge if it doesn't already point to the desired block. + RedirectEdge(pred, original_target_id, new_merge_target); + } + } +} + +void StructuredLoopToSelectionReductionOpportunity::RedirectEdge( + uint32_t source_id, uint32_t original_target_id, uint32_t new_target_id) { + // Redirect edge source_id->original_target_id to edge + // source_id->new_target_id, where the blocks involved are all different. + assert(source_id != original_target_id); + assert(source_id != new_target_id); + assert(original_target_id != new_target_id); + + // original_target_id must either be the merge target or continue construct + // for the loop being operated on. + assert(original_target_id == + loop_construct_header_->GetMergeInst()->GetSingleWordOperand( + kMergeNodeIndex) || + original_target_id == + loop_construct_header_->GetMergeInst()->GetSingleWordOperand( + kContinueNodeIndex)); + + auto terminator = context_->cfg()->block(source_id)->terminator(); + + // Figure out which operands of the terminator need to be considered for + // redirection. 
+ std::vector operand_indices; + if (terminator->opcode() == SpvOpBranch) { + operand_indices = {0}; + } else if (terminator->opcode() == SpvOpBranchConditional) { + operand_indices = {1, 2}; + } else { + assert(terminator->opcode() == SpvOpSwitch); + for (uint32_t label_index = 1; label_index < terminator->NumOperands(); + label_index += 2) { + operand_indices.push_back(label_index); + } + } + + // Redirect the relevant operands, asserting that at least one redirection is + // made. + bool redirected = false; + for (auto operand_index : operand_indices) { + if (terminator->GetSingleWordOperand(operand_index) == original_target_id) { + terminator->SetOperand(operand_index, {new_target_id}); + redirected = true; + } + } + (void)(redirected); + assert(redirected); + + // The old and new targets may have phi instructions; these will need to + // respect the change in edges. + AdaptPhiInstructionsForRemovedEdge( + source_id, context_->cfg()->block(original_target_id)); + AdaptPhiInstructionsForAddedEdge(source_id, + context_->cfg()->block(new_target_id)); +} + +void StructuredLoopToSelectionReductionOpportunity:: + AdaptPhiInstructionsForRemovedEdge(uint32_t from_id, BasicBlock* to_block) { + to_block->ForEachPhiInst([&from_id](Instruction* phi_inst) { + Instruction::OperandList new_in_operands; + // Go through the OpPhi's input operands in (variable, parent) pairs. + for (uint32_t index = 0; index < phi_inst->NumInOperands(); index += 2) { + // Keep all pairs where the parent is not the block from which the edge + // is being removed. 
+ if (phi_inst->GetInOperand(index + 1).words[0] != from_id) { + new_in_operands.push_back(phi_inst->GetInOperand(index)); + new_in_operands.push_back(phi_inst->GetInOperand(index + 1)); + } + } + phi_inst->SetInOperands(std::move(new_in_operands)); + }); +} + +void StructuredLoopToSelectionReductionOpportunity:: + AdaptPhiInstructionsForAddedEdge(uint32_t from_id, BasicBlock* to_block) { + to_block->ForEachPhiInst([this, &from_id](Instruction* phi_inst) { + // Add to the phi operand an (undef, from_id) pair to reflect the added + // edge. + auto undef_id = FindOrCreateGlobalUndef(context_, phi_inst->type_id()); + phi_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {undef_id})); + phi_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {from_id})); + }); +} + +void StructuredLoopToSelectionReductionOpportunity::ChangeLoopToSelection() { + // Change the merge instruction from OpLoopMerge to OpSelectionMerge, with + // the same merge block. + auto loop_merge_inst = loop_construct_header_->GetLoopMergeInst(); + auto const loop_merge_block_id = + loop_merge_inst->GetSingleWordOperand(kMergeNodeIndex); + loop_merge_inst->SetOpcode(SpvOpSelectionMerge); + loop_merge_inst->ReplaceOperands( + {{loop_merge_inst->GetOperand(kMergeNodeIndex).type, + {loop_merge_block_id}}, + {SPV_OPERAND_TYPE_SELECTION_CONTROL, {SpvSelectionControlMaskNone}}}); + + // The loop header either finishes with OpBranch or OpBranchConditional. + // The latter is fine for a selection. In the former case we need to turn + // it into OpBranchConditional. We use "true" as the condition, and make + // the "else" branch be the merge block. 
+ auto terminator = loop_construct_header_->terminator(); + if (terminator->opcode() == SpvOpBranch) { + analysis::Bool temp; + const analysis::Bool* bool_type = + context_->get_type_mgr()->GetRegisteredType(&temp)->AsBool(); + auto const_mgr = context_->get_constant_mgr(); + auto true_const = const_mgr->GetConstant(bool_type, {true}); + auto true_const_result_id = + const_mgr->GetDefiningInstruction(true_const)->result_id(); + auto original_branch_id = terminator->GetSingleWordOperand(0); + terminator->SetOpcode(SpvOpBranchConditional); + terminator->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {true_const_result_id}}, + {SPV_OPERAND_TYPE_ID, {original_branch_id}}, + {SPV_OPERAND_TYPE_ID, {loop_merge_block_id}}}); + if (original_branch_id != loop_merge_block_id) { + AdaptPhiInstructionsForAddedEdge( + loop_construct_header_->id(), + context_->cfg()->block(loop_merge_block_id)); + } + } +} + +void StructuredLoopToSelectionReductionOpportunity::FixNonDominatedIdUses() { + // Consider each instruction in the function. + for (auto& block : *enclosing_function_) { + for (auto& def : block) { + if (def.opcode() == SpvOpVariable) { + // Variables are defined at the start of the function, and can be + // accessed by all blocks, even by unreachable blocks that have no + // dominators, so we do not need to worry about them. + continue; + } + context_->get_def_use_mgr()->ForEachUse(&def, [this, &block, &def]( + Instruction* use, + uint32_t index) { + // If a use is not appropriately dominated by its definition, + // replace the use with an OpUndef, unless the definition is an + // access chain, in which case replace it with some (possibly fresh) + // variable (as we cannot load from / store to OpUndef). 
+ if (!DefinitionSufficientlyDominatesUse(&def, use, index, block)) { + if (def.opcode() == SpvOpAccessChain) { + auto pointer_type = + context_->get_type_mgr()->GetType(def.type_id())->AsPointer(); + switch (pointer_type->storage_class()) { + case SpvStorageClassFunction: + use->SetOperand( + index, {FindOrCreateFunctionVariable( + context_->get_type_mgr()->GetId(pointer_type))}); + break; + default: + // TODO(2183) Need to think carefully about whether it makes + // sense to add new variables for all storage classes; it's fine + // for Private but might not be OK for input/output storage + // classes for example. + use->SetOperand( + index, {FindOrCreateGlobalVariable( + context_->get_type_mgr()->GetId(pointer_type))}); + break; + } + } else { + use->SetOperand(index, + {FindOrCreateGlobalUndef(context_, def.type_id())}); + } + } + }); + } + } +} + +bool StructuredLoopToSelectionReductionOpportunity:: + DefinitionSufficientlyDominatesUse(Instruction* def, Instruction* use, + uint32_t use_index, + BasicBlock& def_block) { + if (use->opcode() == SpvOpPhi) { + // A use in a phi doesn't need to be dominated by its definition, but the + // associated parent block does need to be dominated by the definition. + return context_->GetDominatorAnalysis(enclosing_function_) + ->Dominates(def_block.id(), use->GetSingleWordOperand(use_index + 1)); + } + // In non-phi cases, a use needs to be dominated by its definition. 
+ return context_->GetDominatorAnalysis(enclosing_function_) + ->Dominates(def, use); +} + +uint32_t +StructuredLoopToSelectionReductionOpportunity::FindOrCreateGlobalVariable( + uint32_t pointer_type_id) { + for (auto& inst : context_->module()->types_values()) { + if (inst.opcode() != SpvOpVariable) { + continue; + } + if (inst.type_id() == pointer_type_id) { + return inst.result_id(); + } + } + const uint32_t variable_id = context_->TakeNextId(); + std::unique_ptr variable_inst( + new Instruction(context_, SpvOpVariable, pointer_type_id, variable_id, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, + {(uint32_t)context_->get_type_mgr() + ->GetType(pointer_type_id) + ->AsPointer() + ->storage_class()}}})); + context_->module()->AddGlobalValue(std::move(variable_inst)); + return variable_id; +} + +uint32_t +StructuredLoopToSelectionReductionOpportunity::FindOrCreateFunctionVariable( + uint32_t pointer_type_id) { + // The pointer type of a function variable must have Function storage class. + assert(context_->get_type_mgr() + ->GetType(pointer_type_id) + ->AsPointer() + ->storage_class() == SpvStorageClassFunction); + + // Go through the instructions in the function's first block until we find a + // suitable variable, or go past all the variables. + BasicBlock::iterator iter = enclosing_function_->begin()->begin(); + for (;; ++iter) { + // We will either find a suitable variable, or find a non-variable + // instruction; we won't exhaust all instructions. + assert(iter != enclosing_function_->begin()->end()); + if (iter->opcode() != SpvOpVariable) { + // If we see a non-variable, we have gone through all the variables. + break; + } + if (iter->type_id() == pointer_type_id) { + return iter->result_id(); + } + } + // At this point, iter refers to the first non-function instruction of the + // function's entry block. 
+ const uint32_t variable_id = context_->TakeNextId(); + std::unique_ptr variable_inst(new Instruction( + context_, SpvOpVariable, pointer_type_id, variable_id, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); + iter->InsertBefore(std::move(variable_inst)); + return variable_id; +} + +} // namespace reduce +} // namespace spvtools diff --git a/source/reduce/structured_loop_to_selection_reduction_opportunity.h b/source/reduce/structured_loop_to_selection_reduction_opportunity.h new file mode 100644 index 000000000..71b0f0eb4 --- /dev/null +++ b/source/reduce/structured_loop_to_selection_reduction_opportunity.h @@ -0,0 +1,116 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_REDUCE_CUT_LOOP_REDUCTION_OPPORTUNITY_H_ +#define SOURCE_REDUCE_CUT_LOOP_REDUCTION_OPPORTUNITY_H_ + +#include "source/opt/def_use_manager.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/function.h" +#include "source/reduce/reduction_opportunity.h" + +namespace spvtools { +namespace reduce { + +using namespace opt; + +// An opportunity to replace a structured loop with a selection. +class StructuredLoopToSelectionReductionOpportunity + : public ReductionOpportunity { + public: + // Constructs an opportunity from a loop header block and the function that + // encloses it. 
+ explicit StructuredLoopToSelectionReductionOpportunity( + IRContext* context, BasicBlock* loop_construct_header, + Function* enclosing_function) + : context_(context), + loop_construct_header_(loop_construct_header), + enclosing_function_(enclosing_function) {} + + // Returns true if the loop header is reachable. A structured loop might + // become unreachable as a result of turning another structured loop into + // a selection. + bool PreconditionHolds() override; + + protected: + void Apply() override; + + private: + // Parameter |original_target_id| is the id of the loop's merge block or + // continue target. This method considers each edge of the form + // b->original_target_id and transforms it into an edge of the form b->c, + // where c is the merge block of the structured control flow construct that + // most tightly contains b. + void RedirectToClosestMergeBlock(uint32_t original_target_id); + + // |source_id|, |original_target_id| and |new_target_id| are required to all + // be distinct, with a CFG edge existing from |source_id| to + // |original_target_id|, and |original_target_id| being either the merge block + // or continue target for the loop being operated on. + // The method removes this edge and adds an edge from + // |source_id| to |new_target_id|. It takes care of fixing up any OpPhi + // instructions associated with |original_target_id| and |new_target_id|. + void RedirectEdge(uint32_t source_id, uint32_t original_target_id, + uint32_t new_target_id); + + // Removes any components of |to_block|'s phi instructions relating to + // |from_id|. + void AdaptPhiInstructionsForRemovedEdge(uint32_t from_id, + BasicBlock* to_block); + + // Adds components to |to_block|'s phi instructions to account for a new + // incoming edge from |from_id|. + void AdaptPhiInstructionsForAddedEdge(uint32_t from_id, BasicBlock* to_block); + + // Turns the OpLoopMerge for the loop into OpSelectionMerge, and adapts the + // following branch instruction accordingly. 
+ void ChangeLoopToSelection(); + + // Fixes any scenarios where, due to CFG changes, ids have uses not dominated + // by their definitions, by changing such uses to uses of OpUndef or of dummy + // variables. + void FixNonDominatedIdUses(); + + // Returns true if and only if at least one of the following holds: + // 1) |def| dominates |use| + // 2) |def| is an OpVariable + // 3) |use| is part of an OpPhi, with associated incoming block b, and |def| + // dominates b. + bool DefinitionSufficientlyDominatesUse(Instruction* def, Instruction* use, + uint32_t use_index, + BasicBlock& def_block); + + // Checks whether the global value list has an OpVariable of the given pointer + // type, adding one if not, and returns the id of such an OpVariable. + // + // TODO(2184): This will likely be used by other reduction passes, so should + // be factored out in due course. + uint32_t FindOrCreateGlobalVariable(uint32_t pointer_type_id); + + // Checks whether the enclosing function has an OpVariable of the given + // pointer type, adding one if not, and returns the id of such an OpVariable. + // + // TODO(2184): This will likely be used by other reduction passes, so should + // be factored out in due course. + uint32_t FindOrCreateFunctionVariable(uint32_t pointer_type_id); + + IRContext* context_; + BasicBlock* loop_construct_header_; + Function* enclosing_function_; +}; + +} // namespace reduce +} // namespace spvtools + +#endif // SOURCE_REDUCE_CUT_LOOP_REDUCTION_OPPORTUNITY_H_ diff --git a/source/reduce/structured_loop_to_selection_reduction_pass.cpp b/source/reduce/structured_loop_to_selection_reduction_pass.cpp new file mode 100644 index 000000000..768a2e8e1 --- /dev/null +++ b/source/reduce/structured_loop_to_selection_reduction_pass.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "structured_loop_to_selection_reduction_pass.h"
+#include "structured_loop_to_selection_reduction_opportunity.h"
+
+namespace spvtools {
+namespace reduce {
+
+using namespace opt;
+
+namespace {
+const uint32_t kMergeNodeIndex = 0;
+const uint32_t kContinueNodeIndex = 1;
+}  // namespace
+
+std::vector<std::unique_ptr<ReductionOpportunity>>
+StructuredLoopToSelectionReductionPass::GetAvailableOpportunities(
+    opt::IRContext* context) const {
+  std::vector<std::unique_ptr<ReductionOpportunity>> result;
+
+  std::set<uint32_t> merge_block_ids;  // Merge targets of every construct.
+  for (auto& function : *context->module()) {
+    for (auto& block : function) {
+      auto merge_inst = block.GetMergeInst();
+      if (merge_inst) {
+        merge_block_ids.insert(
+            merge_inst->GetSingleWordOperand(kMergeNodeIndex));
+      }
+    }
+  }
+
+  // Consider each loop construct header in the module.
+  for (auto& function : *context->module()) {
+    for (auto& block : function) {
+      auto loop_merge_inst = block.GetLoopMergeInst();
+      if (!loop_merge_inst) {
+        // This is not a loop construct header.
+        continue;
+      }
+
+      // Check whether the loop construct's continue target is the merge block
+      // of some structured control flow construct. If it is, we cautiously do
+      // not consider applying a transformation.
+      if (merge_block_ids.find(loop_merge_inst->GetSingleWordOperand(
+              kContinueNodeIndex)) != merge_block_ids.end()) {
+        continue;
+      }
+
+      // Check whether the loop construct header dominates its merge block.
+      // If not, the merge block must be unreachable in the control flow graph
+      // so we cautiously do not consider applying a transformation.
+      auto merge_block_id =
+          loop_merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
+      if (!context->GetDominatorAnalysis(&function)->Dominates(
+              block.id(), merge_block_id)) {
+        continue;
+      }
+
+      // Check whether the loop construct merge block postdominates the loop
+      // construct header. If not (e.g. because the loop contains OpReturn,
+      // OpKill or OpUnreachable), we cautiously do not consider applying
+      // a transformation.
+      if (!context->GetPostDominatorAnalysis(&function)->Dominates(
+              merge_block_id, block.id())) {
+        continue;
+      }
+
+      // We can turn this structured loop into a selection, so add the
+      // opportunity to do so.
+      result.push_back(
+          MakeUnique<StructuredLoopToSelectionReductionOpportunity>(
+              context, &block, &function));
+    }
+  }
+  return result;
+}
+
+std::string StructuredLoopToSelectionReductionPass::GetName() const {
+  return "StructuredLoopToSelectionReductionPass";
+}
+
+}  // namespace reduce
+}  // namespace spvtools
diff --git a/source/reduce/structured_loop_to_selection_reduction_pass.h b/source/reduce/structured_loop_to_selection_reduction_pass.h
new file mode 100644
index 000000000..a1f88bc54
--- /dev/null
+++ b/source/reduce/structured_loop_to_selection_reduction_pass.h
@@ -0,0 +1,61 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_REDUCE_CUT_LOOP_REDUCTION_PASS_H_
+#define SOURCE_REDUCE_CUT_LOOP_REDUCTION_PASS_H_
+
+#include "reduction_pass.h"
+
+namespace spvtools {
+namespace reduce {
+
+// Turns structured loops into selections, generalizing from a human-writable
+// language the idea of turning a loop:
+//
+// while (c) {
+//   body;
+// }
+//
+// into:
+//
+// if (c) {
+//   body;
+// }
+//
+// The pass results in continue constructs of transformed loops becoming
+// unreachable; another pass for eliminating blocks may end up being able to
+// remove them.
+class StructuredLoopToSelectionReductionPass : public ReductionPass {
+ public:
+  // Creates the reduction pass in the context of the given target environment
+  // |target_env|.
+  explicit StructuredLoopToSelectionReductionPass(
+      const spv_target_env target_env)
+      : ReductionPass(target_env) {}
+
+  ~StructuredLoopToSelectionReductionPass() override = default;
+
+  std::string GetName() const final;
+
+ protected:
+  std::vector<std::unique_ptr<ReductionOpportunity>> GetAvailableOpportunities(
+      opt::IRContext* context) const final;
+
+ private:
+};
+
+}  // namespace reduce
+}  // namespace spvtools
+
+#endif  // SOURCE_REDUCE_CUT_LOOP_REDUCTION_PASS_H_
diff --git a/source/software_version.cpp b/source/software_version.cpp
new file mode 100644
index 000000000..b258ebe90
--- /dev/null
+++ b/source/software_version.cpp
@@ -0,0 +1,27 @@
+// Copyright (c) 2015-2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "spirv-tools/libspirv.h"
+
+namespace {
+
+const char* kBuildVersions[] = {  // [0]: version, [1]: version details.
+#include "build-version.inc"
+};
+
+}  // anonymous namespace
+
+const char* spvSoftwareVersionString(void) { return kBuildVersions[0]; }
+
+const char* spvSoftwareVersionDetailsString(void) { return kBuildVersions[1]; }
diff --git a/source/spirv_constant.h b/source/spirv_constant.h
new file mode 100644
index 000000000..39771ccb2
--- /dev/null
+++ b/source/spirv_constant.h
@@ -0,0 +1,100 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_SPIRV_CONSTANT_H_
+#define SOURCE_SPIRV_CONSTANT_H_
+
+#include "source/latest_version_spirv_header.h"
+#include "spirv-tools/libspirv.h"
+
+// Version number macros.
+
+// Evaluates to a well-formed version header word (major in bits 16-23, minor
+// in bits 8-15), given valid SPIR-V major and minor version numbers.
+#define SPV_SPIRV_VERSION_WORD(MAJOR, MINOR) \
+  ((uint32_t(uint8_t(MAJOR)) << 16) | (uint32_t(uint8_t(MINOR)) << 8))
+// Returns the major version (bits 16-23) of a version header word.
+#define SPV_SPIRV_VERSION_MAJOR_PART(WORD) ((uint32_t(WORD) >> 16) & 0xff)
+// Returns the minor version (bits 8-15) of a version header word.
+#define SPV_SPIRV_VERSION_MINOR_PART(WORD) ((uint32_t(WORD) >> 8) & 0xff)
+
+// Header indices
+
+#define SPV_INDEX_MAGIC_NUMBER 0u
+#define SPV_INDEX_VERSION_NUMBER 1u
+#define SPV_INDEX_GENERATOR_NUMBER 2u
+#define SPV_INDEX_BOUND 3u
+#define SPV_INDEX_SCHEMA 4u
+#define SPV_INDEX_INSTRUCTION 5u
+
+// Universal limits
+
+// SPIR-V 1.0 limits
+#define SPV_LIMIT_INSTRUCTION_WORD_COUNT_MAX 0xffff
+#define SPV_LIMIT_LITERAL_STRING_UTF8_CHARS_MAX 0xffff
+
+// A single Unicode character in UTF-8 encoding can take
+// up to 4 bytes.
+#define SPV_LIMIT_LITERAL_STRING_BYTES_MAX \
+  (SPV_LIMIT_LITERAL_STRING_UTF8_CHARS_MAX * 4)
+
+// NOTE: These are set to the minimum maximum values
+// TODO(dneto): Check these.
+
+// libspirv limits.
+#define SPV_LIMIT_RESULT_ID_BOUND 0x00400000
+#define SPV_LIMIT_CONTROL_FLOW_NEST_DEPTH 0x00000400
+#define SPV_LIMIT_GLOBAL_VARIABLES_MAX 0x00010000
+#define SPV_LIMIT_LOCAL_VARIABLES_MAX 0x00080000
+// TODO: Decorations per target ID max, depends on decoration table size
+#define SPV_LIMIT_EXECUTION_MODE_PER_ENTRY_POINT_MAX 0x00000100
+#define SPV_LIMIT_INDICIES_MAX_ACCESS_CHAIN_COMPOSITE_MAX 0x00000100
+#define SPV_LIMIT_FUNCTION_PARAMETERS_PER_FUNCTION_DECL 0x00000100
+#define SPV_LIMIT_FUNCTION_CALL_ARGUMENTS_MAX 0x00000100
+#define SPV_LIMIT_EXT_FUNCTION_CALL_ARGUMENTS_MAX 0x00000100
+#define SPV_LIMIT_SWITCH_LITERAL_LABEL_PAIRS_MAX 0x00004000
+#define SPV_LIMIT_STRUCT_MEMBERS_MAX 0x0000400
+#define SPV_LIMIT_STRUCT_NESTING_DEPTH_MAX 0x00000100
+
+// Enumerations
+
+// Values mapping to registered tools. See the registry at
+// https://www.khronos.org/registry/spir-v/api/spir-v.xml
+// These values occupy the higher order 16 bits of the generator magic word.
+typedef enum spv_generator_t {
+  // TODO(dneto) Values 0 through 5 were registered only as vendor.
+  SPV_GENERATOR_KHRONOS = 0,
+  SPV_GENERATOR_LUNARG = 1,
+  SPV_GENERATOR_VALVE = 2,
+  SPV_GENERATOR_CODEPLAY = 3,
+  SPV_GENERATOR_NVIDIA = 4,
+  SPV_GENERATOR_ARM = 5,
+  // These are vendor and tool.
+  SPV_GENERATOR_KHRONOS_LLVM_TRANSLATOR = 6,
+  SPV_GENERATOR_KHRONOS_ASSEMBLER = 7,
+  SPV_GENERATOR_KHRONOS_GLSLANG = 8,
+  SPV_GENERATOR_NUM_ENTRIES,  // One past the last tool value above.
+  SPV_FORCE_16_BIT_ENUM(spv_generator_t)
+} spv_generator_t;
+
+// Evaluates to a well-formed generator magic word from a tool value and
+// miscellaneous 16-bit value.
+#define SPV_GENERATOR_WORD(TOOL, MISC) \
+  ((uint32_t(uint16_t(TOOL)) << 16) | uint16_t(MISC))
+// Returns the tool component of the generator word.
+#define SPV_GENERATOR_TOOL_PART(WORD) (uint32_t(WORD) >> 16)
+// Returns the misc part of the generator word.
+#define SPV_GENERATOR_MISC_PART(WORD) (uint32_t(WORD) & 0xFFFF)
+
+#endif  // SOURCE_SPIRV_CONSTANT_H_
diff --git a/source/spirv_definition.h b/source/spirv_definition.h
new file mode 100644
index 000000000..63a4ef0db
--- /dev/null
+++ b/source/spirv_definition.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_SPIRV_DEFINITION_H_
+#define SOURCE_SPIRV_DEFINITION_H_
+
+#include <cstdint>
+
+#include "source/latest_version_spirv_header.h"
+
+#define spvIsInBitfield(value, bitfield) ((value) == ((value) & (bitfield)))
+
+typedef struct spv_header_t {
+  uint32_t magic;
+  uint32_t version;
+  uint32_t generator;
+  uint32_t bound;
+  uint32_t schema;               // NOTE: Reserved
+  const uint32_t* instructions;  // NOTE: Unfixed pointer to instruction stream
+} spv_header_t;
+
+#endif  // SOURCE_SPIRV_DEFINITION_H_
diff --git a/source/spirv_endian.cpp b/source/spirv_endian.cpp
new file mode 100644
index 000000000..1d7709178
--- /dev/null
+++ b/source/spirv_endian.cpp
@@ -0,0 +1,85 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/spirv_endian.h"
+
+#include <cstring>
+
+// Values the probe union below yields when read as a uint32_t on a
+// little-endian host and on a big-endian host, respectively.
+enum {
+  I32_ENDIAN_LITTLE = 0x03020100ul,
+  I32_ENDIAN_BIG = 0x00010203ul,
+};
+
+// This constant value allows the detection of the host machine's endianness.
+// Accessing it through the "value" member is valid due to C++11 section 3.10
+// paragraph 10.
+static const union {
+  unsigned char bytes[4];
+  uint32_t value;
+} o32_host_order = {{0, 1, 2, 3}};
+
+#define I32_ENDIAN_HOST (o32_host_order.value)
+
+uint32_t spvFixWord(const uint32_t word, const spv_endianness_t endian) {
+  // Swap the four bytes only when |endian| differs from the host byte order.
+  if ((SPV_ENDIANNESS_LITTLE == endian && I32_ENDIAN_HOST == I32_ENDIAN_BIG) ||
+      (SPV_ENDIANNESS_BIG == endian && I32_ENDIAN_HOST == I32_ENDIAN_LITTLE)) {
+    return (word & 0x000000ff) << 24 | (word & 0x0000ff00) << 8 |
+           (word & 0x00ff0000) >> 8 | (word & 0xff000000) >> 24;
+  }
+
+  return word;
+}
+
+uint64_t spvFixDoubleWord(const uint32_t low, const uint32_t high,
+                          const spv_endianness_t endian) {
+  return (uint64_t(spvFixWord(high, endian)) << 32) | spvFixWord(low, endian);
+}
+
+spv_result_t spvBinaryEndianness(spv_const_binary binary,
+                                 spv_endianness_t* pEndian) {
+  // Guard against a null |binary| before touching its fields.
+  if (!binary || !binary->code || !binary->wordCount) {
+    return SPV_ERROR_INVALID_BINARY;
+  }
+  if (!pEndian) return SPV_ERROR_INVALID_POINTER;
+
+  // Inspect the first word byte by byte: the magic number 0x07230203 is stored
+  // as 03 02 23 07 when little-endian and as 07 23 02 03 when big-endian.
+  uint8_t bytes[4];
+  memcpy(bytes, binary->code, sizeof(uint32_t));
+
+  if (0x03 == bytes[0] && 0x02 == bytes[1] && 0x23 == bytes[2] &&
+      0x07 == bytes[3]) {
+    *pEndian = SPV_ENDIANNESS_LITTLE;
+    return SPV_SUCCESS;
+  }
+
+  if (0x07 == bytes[0] && 0x23 == bytes[1] && 0x02 == bytes[2] &&
+      0x03 == bytes[3]) {
+    *pEndian = SPV_ENDIANNESS_BIG;
+    return SPV_SUCCESS;
+  }
+
+  return SPV_ERROR_INVALID_BINARY;
+}
+
+bool spvIsHostEndian(spv_endianness_t endian) {
+  return ((SPV_ENDIANNESS_LITTLE == endian) &&
+          (I32_ENDIAN_LITTLE == I32_ENDIAN_HOST)) ||
+         ((SPV_ENDIANNESS_BIG == endian) &&
+          (I32_ENDIAN_BIG == I32_ENDIAN_HOST));
+}
diff --git a/source/spirv_endian.h b/source/spirv_endian.h
new file mode 100644
index 000000000..c2540bec9
--- /dev/null
+++ b/source/spirv_endian.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_SPIRV_ENDIAN_H_
+#define SOURCE_SPIRV_ENDIAN_H_
+
+#include "spirv-tools/libspirv.h"
+
+// Converts a word in the specified endianness to the host native endianness.
+uint32_t spvFixWord(const uint32_t word, const spv_endianness_t endianness);
+
+// Converts a pair of words in the specified endianness to the host native
+// endianness.
+uint64_t spvFixDoubleWord(const uint32_t low, const uint32_t high,
+                          const spv_endianness_t endianness);
+
+// Gets the endianness of the SPIR-V module given in the binary parameter.
+// Returns SPV_ERROR_INVALID_BINARY if the SPIR-V magic number is invalid;
+// otherwise writes the endianness into *endian and returns SPV_SUCCESS.
+spv_result_t spvBinaryEndianness(const spv_const_binary binary,
+                                 spv_endianness_t* endian);
+
+// Returns true if the given endianness matches the host's native endianness.
+bool spvIsHostEndian(spv_endianness_t endian);
+
+#endif  // SOURCE_SPIRV_ENDIAN_H_
diff --git a/source/spirv_optimizer_options.cpp b/source/spirv_optimizer_options.cpp
new file mode 100644
index 000000000..30db4e2de
--- /dev/null
+++ b/source/spirv_optimizer_options.cpp
@@ -0,0 +1,46 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include "source/spirv_optimizer_options.h"
+
+// Allocates a default-constructed options object with new.
+SPIRV_TOOLS_EXPORT spv_optimizer_options spvOptimizerOptionsCreate(void) {
+  return new spv_optimizer_options_t();
+}
+
+// Deletes an options object allocated by spvOptimizerOptionsCreate.
+SPIRV_TOOLS_EXPORT void spvOptimizerOptionsDestroy(
+    spv_optimizer_options options) {
+  delete options;
+}
+
+// Stores |val| as the run-validator flag of |options|.
+SPIRV_TOOLS_EXPORT void spvOptimizerOptionsSetRunValidator(
+    spv_optimizer_options options, bool val) {
+  options->run_validator_ = val;
+}
+
+// Copies |*val| into |options|; both pointers are dereferenced unchecked.
+SPIRV_TOOLS_EXPORT void spvOptimizerOptionsSetValidatorOptions(
+    spv_optimizer_options options, spv_validator_options val) {
+  options->val_options_ = *val;
+}
+// Stores |val| as the maximum id bound recorded in |options|.
+SPIRV_TOOLS_EXPORT void spvOptimizerOptionsSetMaxIdBound(
+    spv_optimizer_options options, uint32_t val) {
+  options->max_id_bound_ = val;
+}
diff --git a/source/spirv_optimizer_options.h b/source/spirv_optimizer_options.h
new file mode 100644
index 000000000..1eb4d3f1b
--- /dev/null
+++ b/source/spirv_optimizer_options.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_SPIRV_OPTIMIZER_OPTIONS_H_
+#define SOURCE_SPIRV_OPTIMIZER_OPTIONS_H_
+
+#include "source/spirv_validator_options.h"
+#include "spirv-tools/libspirv.h"
+
+// Manages command line options passed to the SPIR-V Optimizer. New struct
+// members may be added for any new option.
+struct spv_optimizer_options_t {
+  spv_optimizer_options_t()
+      : run_validator_(true),
+        val_options_(),
+        max_id_bound_(kDefaultMaxIdBound) {}
+
+  // When true the validator will be run before optimizations are run.
+  bool run_validator_;
+
+  // Options to pass to the validator if it is run.
+  spv_validator_options_t val_options_;
+
+  // The maximum value the id bound for a module can have. The Spir-V spec says
+  // this value must be at least 0x3FFFFF, but implementations can allow for a
+  // higher value.
+  uint32_t max_id_bound_;
+};
+#endif  // SOURCE_SPIRV_OPTIMIZER_OPTIONS_H_
diff --git a/source/spirv_reducer_options.cpp b/source/spirv_reducer_options.cpp
new file mode 100644
index 000000000..110ea3e09
--- /dev/null
+++ b/source/spirv_reducer_options.cpp
@@ -0,0 +1,34 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include "source/spirv_reducer_options.h"
+
+// Allocates a default-constructed options object with new.
+SPIRV_TOOLS_EXPORT spv_reducer_options spvReducerOptionsCreate() {
+  return new spv_reducer_options_t();
+}
+
+// Deletes an options object allocated by spvReducerOptionsCreate.
+SPIRV_TOOLS_EXPORT void spvReducerOptionsDestroy(spv_reducer_options options) {
+  delete options;
+}
+
+// Stores |step_limit| in |options|, which is dereferenced unchecked.
+SPIRV_TOOLS_EXPORT void spvReducerOptionsSetStepLimit(
+    spv_reducer_options options, uint32_t step_limit) {
+  options->step_limit = step_limit;
+}
diff --git a/source/spirv_reducer_options.h b/source/spirv_reducer_options.h
new file mode 100644
index 000000000..d48303ca6
--- /dev/null
+++ b/source/spirv_reducer_options.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_SPIRV_REDUCER_OPTIONS_H_
+#define SOURCE_SPIRV_REDUCER_OPTIONS_H_
+
+#include "spirv-tools/libspirv.h"
+
+#include
+#include
+
+// The default maximum number of steps for the reducer to run before giving up.
+const uint32_t kDefaultStepLimit = 250;
+
+// Manages command line options passed to the SPIR-V Reducer. New struct
+// members may be added for any new option.
+struct spv_reducer_options_t {
+  spv_reducer_options_t() : step_limit(kDefaultStepLimit) {}
+
+  // The number of steps the reducer will run for before giving up.
+  uint32_t step_limit;
+};
+
+#endif  // SOURCE_SPIRV_REDUCER_OPTIONS_H_
diff --git a/source/spirv_target_env.cpp b/source/spirv_target_env.cpp
new file mode 100644
index 000000000..a3aa0af03
--- /dev/null
+++ b/source/spirv_target_env.cpp
@@ -0,0 +1,256 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/spirv_target_env.h"
+
+#include
+
+#include "source/spirv_constant.h"
+#include "spirv-tools/libspirv.h"
+
+// Returns a human-readable description of |env|, or an empty string when
+// |env| is not one of the enumerators handled below.
+const char* spvTargetEnvDescription(spv_target_env env) {
+  switch (env) {
+    case SPV_ENV_UNIVERSAL_1_0:
+      return "SPIR-V 1.0";
+    case SPV_ENV_VULKAN_1_0:
+      return "SPIR-V 1.0 (under Vulkan 1.0 semantics)";
+    case SPV_ENV_UNIVERSAL_1_1:
+      return "SPIR-V 1.1";
+    case SPV_ENV_OPENCL_1_2:
+      return "SPIR-V 1.0 (under OpenCL 1.2 Full Profile semantics)";
+    case SPV_ENV_OPENCL_EMBEDDED_1_2:
+      return "SPIR-V 1.0 (under OpenCL 1.2 Embedded Profile semantics)";
+    case SPV_ENV_OPENCL_2_0:
+      return "SPIR-V 1.0 (under OpenCL 2.0 Full Profile semantics)";
+    case SPV_ENV_OPENCL_EMBEDDED_2_0:
+      return "SPIR-V 1.0 (under OpenCL 2.0 Embedded Profile semantics)";
+    case SPV_ENV_OPENCL_2_1:
+      return "SPIR-V 1.0 (under OpenCL 2.1 Full Profile semantics)";
+    case SPV_ENV_OPENCL_EMBEDDED_2_1:
+      return "SPIR-V 1.0 (under OpenCL 2.1 Embedded Profile semantics)";
+    case SPV_ENV_OPENCL_2_2:
+      return "SPIR-V 1.2 (under OpenCL 2.2 Full Profile semantics)";
+    case SPV_ENV_OPENCL_EMBEDDED_2_2:
+      return "SPIR-V 1.2 (under OpenCL 2.2 Embedded Profile semantics)";
+    case SPV_ENV_OPENGL_4_0:
+      return "SPIR-V 1.0 (under OpenGL 4.0 semantics)";
+    case SPV_ENV_OPENGL_4_1:
+      return "SPIR-V 1.0 (under OpenGL 4.1 semantics)";
+    case SPV_ENV_OPENGL_4_2:
+      return "SPIR-V 1.0 (under OpenGL 4.2 semantics)";
+    case SPV_ENV_OPENGL_4_3:
+      return "SPIR-V 1.0 (under OpenGL 4.3 semantics)";
+    case SPV_ENV_OPENGL_4_5:
+      return "SPIR-V 1.0 (under OpenGL 4.5 semantics)";
+    case SPV_ENV_UNIVERSAL_1_2:
+      return "SPIR-V 1.2";
+    case SPV_ENV_UNIVERSAL_1_3:
+      return "SPIR-V 1.3";
+    case SPV_ENV_VULKAN_1_1:
+      return "SPIR-V 1.3 (under Vulkan 1.1 semantics)";
+    case SPV_ENV_WEBGPU_0:
+      return "SPIR-V 1.3 (under WIP WebGPU semantics)";
+  }
+  return "";
+}
+
+// Values of |env| not handled below yield SPV_SPIRV_VERSION_WORD(0, 0).
+uint32_t spvVersionForTargetEnv(spv_target_env env) {
+  switch (env) {
+    case SPV_ENV_UNIVERSAL_1_0:
+    case SPV_ENV_VULKAN_1_0:
+    case SPV_ENV_OPENCL_1_2:
+    case SPV_ENV_OPENCL_EMBEDDED_1_2:
+    case SPV_ENV_OPENCL_2_0:
+    case SPV_ENV_OPENCL_EMBEDDED_2_0:
+    case SPV_ENV_OPENCL_2_1:
+    case SPV_ENV_OPENCL_EMBEDDED_2_1:
+    case SPV_ENV_OPENGL_4_0:
+    case SPV_ENV_OPENGL_4_1:
+    case SPV_ENV_OPENGL_4_2:
+    case SPV_ENV_OPENGL_4_3:
+    case SPV_ENV_OPENGL_4_5:
+      return SPV_SPIRV_VERSION_WORD(1, 0);
+    case SPV_ENV_UNIVERSAL_1_1:
+      return SPV_SPIRV_VERSION_WORD(1, 1);
+    case SPV_ENV_UNIVERSAL_1_2:
+    case SPV_ENV_OPENCL_2_2:
+    case SPV_ENV_OPENCL_EMBEDDED_2_2:
+      return SPV_SPIRV_VERSION_WORD(1, 2);
+    case SPV_ENV_UNIVERSAL_1_3:
+    case SPV_ENV_VULKAN_1_1:
+    case SPV_ENV_WEBGPU_0:
+      return SPV_SPIRV_VERSION_WORD(1, 3);
+  }
+  return SPV_SPIRV_VERSION_WORD(0, 0);
+}
+
+bool spvParseTargetEnv(const char* s, spv_target_env* env) {
+  // Prefix comparison: true when |s| is non-null and starts with |b|. Because
+  // of this, each "...embedded" name is tested before the shorter name that is
+  // its prefix.
+  auto match = [s](const char* b) {
+    return s && (0 == strncmp(s, b, strlen(b)));
+  };
+  if (match("vulkan1.0")) {
+    if (env) *env = SPV_ENV_VULKAN_1_0;
+    return true;
+  } else if (match("vulkan1.1")) {
+    if (env) *env = SPV_ENV_VULKAN_1_1;
+    return true;
+  } else if (match("spv1.0")) {
+    if (env) *env = SPV_ENV_UNIVERSAL_1_0;
+    return true;
+  } else if (match("spv1.1")) {
+    if (env) *env = SPV_ENV_UNIVERSAL_1_1;
+    return true;
+  } else if (match("spv1.2")) {
+    if (env) *env = SPV_ENV_UNIVERSAL_1_2;
+    return true;
+  } else if (match("spv1.3")) {
+    if (env) *env = SPV_ENV_UNIVERSAL_1_3;
+    return true;
+  } else if (match("opencl1.2embedded")) {
+    if (env) *env = SPV_ENV_OPENCL_EMBEDDED_1_2;
+    return true;
+  } else if (match("opencl1.2")) {
+    if (env) *env = SPV_ENV_OPENCL_1_2;
+    return true;
+  } else if (match("opencl2.0embedded")) {
+    if (env) *env = SPV_ENV_OPENCL_EMBEDDED_2_0;
+    return true;
+  } else if (match("opencl2.0")) {
+    if (env) *env = SPV_ENV_OPENCL_2_0;
+    return true;
+  } else if (match("opencl2.1embedded")) {
+    if (env) *env = SPV_ENV_OPENCL_EMBEDDED_2_1;
+    return true;
+  } else if (match("opencl2.1")) {
+    if (env) *env = SPV_ENV_OPENCL_2_1;
+    return true;
+  } else if (match("opencl2.2embedded")) {
+    if (env) *env = SPV_ENV_OPENCL_EMBEDDED_2_2;
+    return true;
+  } else if (match("opencl2.2")) {
+    if (env) *env = SPV_ENV_OPENCL_2_2;
+    return true;
+  } else if (match("opengl4.0")) {
+    if (env) *env = SPV_ENV_OPENGL_4_0;
+    return true;
+  } else if (match("opengl4.1")) {
+    if (env) *env = SPV_ENV_OPENGL_4_1;
+    return true;
+  } else if (match("opengl4.2")) {
+    if (env) *env = SPV_ENV_OPENGL_4_2;
+    return true;
+  } else if (match("opengl4.3")) {
+    if (env) *env = SPV_ENV_OPENGL_4_3;
+    return true;
+  } else if (match("opengl4.5")) {
+    if (env) *env = SPV_ENV_OPENGL_4_5;
+    return true;
+  } else if (match("webgpu0")) {
+    if (env) *env = SPV_ENV_WEBGPU_0;
+    return true;
+  } else {
+    if (env) *env = SPV_ENV_UNIVERSAL_1_0;
+    return false;
+  }
+}
+
+bool spvIsVulkanEnv(spv_target_env env) {
+  switch (env) {
+    case SPV_ENV_UNIVERSAL_1_0:
+    case SPV_ENV_OPENCL_1_2:
+    case SPV_ENV_OPENCL_EMBEDDED_1_2:
+    case SPV_ENV_OPENCL_2_0:
+    case SPV_ENV_OPENCL_EMBEDDED_2_0:
+    case SPV_ENV_OPENCL_2_1:
+    case SPV_ENV_OPENCL_EMBEDDED_2_1:
+    case SPV_ENV_OPENGL_4_0:
+    case SPV_ENV_OPENGL_4_1:
+    case SPV_ENV_OPENGL_4_2:
+    case SPV_ENV_OPENGL_4_3:
+    case SPV_ENV_OPENGL_4_5:
+    case SPV_ENV_UNIVERSAL_1_1:
+    case SPV_ENV_UNIVERSAL_1_2:
+    case SPV_ENV_OPENCL_2_2:
+    case SPV_ENV_OPENCL_EMBEDDED_2_2:
+    case SPV_ENV_UNIVERSAL_1_3:
+    case SPV_ENV_WEBGPU_0:
+      return false;
+    case SPV_ENV_VULKAN_1_0:
+    case SPV_ENV_VULKAN_1_1:
+      return true;
+  }
+  return false;
+}
+
+bool spvIsOpenCLEnv(spv_target_env env) {
+  switch (env) {
+    case SPV_ENV_UNIVERSAL_1_0:
+    case SPV_ENV_VULKAN_1_0:
+    case SPV_ENV_UNIVERSAL_1_1:
+    case SPV_ENV_OPENGL_4_0:
+    case SPV_ENV_OPENGL_4_1:
+    case SPV_ENV_OPENGL_4_2:
+    case SPV_ENV_OPENGL_4_3:
+    case SPV_ENV_OPENGL_4_5:
+    case SPV_ENV_UNIVERSAL_1_2:
+    case SPV_ENV_UNIVERSAL_1_3:
+    case SPV_ENV_VULKAN_1_1:
+    case SPV_ENV_WEBGPU_0:
+      return false;
+    case SPV_ENV_OPENCL_1_2:
+    case SPV_ENV_OPENCL_EMBEDDED_1_2:
+    case SPV_ENV_OPENCL_2_0:
+    case SPV_ENV_OPENCL_EMBEDDED_2_0:
+    case SPV_ENV_OPENCL_EMBEDDED_2_1:
+    case SPV_ENV_OPENCL_EMBEDDED_2_2:
+    case SPV_ENV_OPENCL_2_1:
+    case SPV_ENV_OPENCL_2_2:
+      return true;
+  }
+  return false;
+}
+
+bool spvIsWebGPUEnv(spv_target_env env) {
+  switch (env) {
+    case SPV_ENV_UNIVERSAL_1_0:
+    case SPV_ENV_VULKAN_1_0:
+    case SPV_ENV_UNIVERSAL_1_1:
+    case SPV_ENV_OPENGL_4_0:
+    case SPV_ENV_OPENGL_4_1:
+    case SPV_ENV_OPENGL_4_2:
+    case SPV_ENV_OPENGL_4_3:
+    case SPV_ENV_OPENGL_4_5:
+    case SPV_ENV_UNIVERSAL_1_2:
+    case SPV_ENV_UNIVERSAL_1_3:
+    case SPV_ENV_VULKAN_1_1:
+    case SPV_ENV_OPENCL_1_2:
+    case SPV_ENV_OPENCL_EMBEDDED_1_2:
+    case SPV_ENV_OPENCL_2_0:
+    case SPV_ENV_OPENCL_EMBEDDED_2_0:
+    case SPV_ENV_OPENCL_EMBEDDED_2_1:
+    case SPV_ENV_OPENCL_EMBEDDED_2_2:
+    case SPV_ENV_OPENCL_2_1:
+    case SPV_ENV_OPENCL_2_2:
+      return false;
+    case SPV_ENV_WEBGPU_0:
+      return true;
+  }
+  return false;
+}
diff --git a/source/spirv_target_env.h b/source/spirv_target_env.h
new file mode 100644
index 000000000..d1bd83512
--- /dev/null
+++ b/source/spirv_target_env.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_SPIRV_TARGET_ENV_H_
+#define SOURCE_SPIRV_TARGET_ENV_H_
+
+#include "spirv-tools/libspirv.h"
+
+// Parses s into *env and returns true if successful.  If unparsable, returns
+// false and sets *env to SPV_ENV_UNIVERSAL_1_0.
+bool spvParseTargetEnv(const char* s, spv_target_env* env);
+
+// Returns true if |env| is a VULKAN environment, false otherwise.
+bool spvIsVulkanEnv(spv_target_env env);
+
+// Returns true if |env| is an OPENCL environment, false otherwise.
+bool spvIsOpenCLEnv(spv_target_env env);
+
+// Returns true if |env| is a WEBGPU environment, false otherwise.
+bool spvIsWebGPUEnv(spv_target_env env);
+
+// Returns the version number for the given SPIR-V target environment.
+uint32_t spvVersionForTargetEnv(spv_target_env env);
+
+#endif  // SOURCE_SPIRV_TARGET_ENV_H_
diff --git a/source/spirv_validator_options.cpp b/source/spirv_validator_options.cpp
new file mode 100644
index 000000000..2e9cf2638
--- /dev/null
+++ b/source/spirv_validator_options.cpp
@@ -0,0 +1,118 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cassert>
+#include <cstring>
+
+#include "source/spirv_validator_options.h"
+
+// Maps the command line option |s| to the validator limit it configures. On a
+// match stores the limit in |*type| (which must not be null) and returns
+// true; otherwise returns false and leaves |*type| untouched.
+bool spvParseUniversalLimitsOptions(const char* s, spv_validator_limit* type) {
+  // Prefix comparison: true when |s| is non-null and starts with |b|.
+  auto match = [s](const char* b) {
+    return s && (0 == strncmp(s, b, strlen(b)));
+  };
+  if (match("--max-struct-members")) {
+    *type = spv_validator_limit_max_struct_members;
+  } else if (match("--max-struct-depth") || match("--max-struct_depth")) {
+    // The hyphenated spelling matches every other option; the underscore
+    // spelling is kept so existing command lines continue to work.
+    *type = spv_validator_limit_max_struct_depth;
+  } else if (match("--max-local-variables")) {
+    *type = spv_validator_limit_max_local_variables;
+  } else if (match("--max-global-variables")) {
+    *type = spv_validator_limit_max_global_variables;
+  } else if (match("--max-switch-branches")) {
+    *type = spv_validator_limit_max_switch_branches;
+  } else if (match("--max-function-args")) {
+    *type = spv_validator_limit_max_function_args;
+  } else if (match("--max-control-flow-nesting-depth")) {
+    *type = spv_validator_limit_max_control_flow_nesting_depth;
+  } else if (match("--max-access-chain-indexes")) {
+    *type = spv_validator_limit_max_access_chain_indexes;
+  } else if (match("--max-id-bound")) {
+    *type = spv_validator_limit_max_id_bound;
+  } else {
+    // The command line option for this validator limit has not been added.
+    // Therefore we return false.
+    return false;
+  }
+
+  return true;
+}
+
+// Allocates a default-constructed options object with new.
+spv_validator_options spvValidatorOptionsCreate(void) {
+  return new spv_validator_options_t;
+}
+
+// Deletes an options object allocated by spvValidatorOptionsCreate.
+void spvValidatorOptionsDestroy(spv_validator_options options) {
+  delete options;
+}
+
+// Sets the universal limit selected by |limit_type| to |limit|. Values of
+// |limit_type| without a LIMIT entry below are ignored.
+void spvValidatorOptionsSetUniversalLimit(spv_validator_options options,
+                                          spv_validator_limit limit_type,
+                                          uint32_t limit) {
+  assert(options && "Validator options object may not be Null");
+  switch (limit_type) {
+#define LIMIT(TYPE, FIELD)                    \
+  case TYPE:                                  \
+    options->universal_limits_.FIELD = limit; \
+    break;
+    LIMIT(spv_validator_limit_max_struct_members, max_struct_members)
+    LIMIT(spv_validator_limit_max_struct_depth, max_struct_depth)
+    LIMIT(spv_validator_limit_max_local_variables, max_local_variables)
+    LIMIT(spv_validator_limit_max_global_variables, max_global_variables)
+    LIMIT(spv_validator_limit_max_switch_branches, max_switch_branches)
+    LIMIT(spv_validator_limit_max_function_args, max_function_args)
+    LIMIT(spv_validator_limit_max_control_flow_nesting_depth,
+          max_control_flow_nesting_depth)
+    LIMIT(spv_validator_limit_max_access_chain_indexes,
+          max_access_chain_indexes)
+    LIMIT(spv_validator_limit_max_id_bound, max_id_bound)
+#undef LIMIT
+  }
+}
+
+// Each setter below stores |val| in the matching flag of |options|, which is
+// dereferenced without a null check.
+void spvValidatorOptionsSetRelaxStoreStruct(spv_validator_options options,
+                                            bool val) {
+  options->relax_struct_store = val;
+}
+
+void spvValidatorOptionsSetRelaxLogicalPointer(spv_validator_options options,
+                                               bool val) {
+  options->relax_logical_pointer = val;
+}
+
+void spvValidatorOptionsSetRelaxBlockLayout(spv_validator_options options,
+                                            bool val) {
+  options->relax_block_layout = val;
+}
+
+void spvValidatorOptionsSetScalarBlockLayout(spv_validator_options options,
+                                             bool val) {
+  options->scalar_block_layout = val;
+}
+
+void spvValidatorOptionsSetSkipBlockLayout(spv_validator_options options,
+                                           bool val) {
+  options->skip_block_layout = val;
+}
diff --git a/source/spirv_validator_options.h b/source/spirv_validator_options.h
new file mode 100644
index
000000000..f426ebfef
--- /dev/null
+++ b/source/spirv_validator_options.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_SPIRV_VALIDATOR_OPTIONS_H_
+#define SOURCE_SPIRV_VALIDATOR_OPTIONS_H_
+
+#include "spirv-tools/libspirv.h"
+
+// Returns true if |s| names a validator limit command line option, storing
+// the corresponding enum value in |*limit|. Returns false otherwise.
+bool spvParseUniversalLimitsOptions(const char* s, spv_validator_limit* limit);
+
+// Default initialization of this structure is to the default Universal Limits
+// described in the SPIR-V Spec.
+struct validator_universal_limits_t {
+  uint32_t max_struct_members{16383};
+  uint32_t max_struct_depth{255};
+  uint32_t max_local_variables{524287};
+  uint32_t max_global_variables{65535};
+  uint32_t max_switch_branches{16383};
+  uint32_t max_function_args{255};
+  uint32_t max_control_flow_nesting_depth{1023};
+  uint32_t max_access_chain_indexes{255};
+  uint32_t max_id_bound{0x3FFFFF};
+};
+
+// Manages command line options passed to the SPIR-V Validator. New struct
+// members may be added for any new option.
+struct spv_validator_options_t {
+  spv_validator_options_t()
+      : universal_limits_(),
+        relax_struct_store(false),
+        relax_logical_pointer(false),
+        relax_block_layout(false),
+        scalar_block_layout(false),
+        skip_block_layout(false) {}
+
+  validator_universal_limits_t universal_limits_;
+  bool relax_struct_store;
+  bool relax_logical_pointer;
+  bool relax_block_layout;
+  bool scalar_block_layout;
+  bool skip_block_layout;
+};
+
+#endif  // SOURCE_SPIRV_VALIDATOR_OPTIONS_H_
diff --git a/source/table.cpp b/source/table.cpp
new file mode 100644
index 000000000..b10d776da
--- /dev/null
+++ b/source/table.cpp
@@ -0,0 +1,68 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/table.h"
+
+#include
+
+spv_context spvContextCreate(spv_target_env env) {
+  // Only the environments enumerated here are supported; anything else yields
+  // a null context.
+  switch (env) {
+    case SPV_ENV_UNIVERSAL_1_0:
+    case SPV_ENV_VULKAN_1_0:
+    case SPV_ENV_UNIVERSAL_1_1:
+    case SPV_ENV_OPENCL_1_2:
+    case SPV_ENV_OPENCL_EMBEDDED_1_2:
+    case SPV_ENV_OPENCL_2_0:
+    case SPV_ENV_OPENCL_EMBEDDED_2_0:
+    case SPV_ENV_OPENCL_2_1:
+    case SPV_ENV_OPENCL_EMBEDDED_2_1:
+    case SPV_ENV_OPENCL_2_2:
+    case SPV_ENV_OPENCL_EMBEDDED_2_2:
+    case SPV_ENV_OPENGL_4_0:
+    case SPV_ENV_OPENGL_4_1:
+    case SPV_ENV_OPENGL_4_2:
+    case SPV_ENV_OPENGL_4_3:
+    case SPV_ENV_OPENGL_4_5:
+    case SPV_ENV_UNIVERSAL_1_2:
+    case SPV_ENV_UNIVERSAL_1_3:
+    case SPV_ENV_VULKAN_1_1:
+    case SPV_ENV_WEBGPU_0:
+      break;
+    default:
+      return nullptr;
+  }
+
+  spv_opcode_table opcode_table;
+  spv_operand_table operand_table;
+  spv_ext_inst_table ext_inst_table;
+
+  // NOTE(review): the spv_result_t returned by these lookups is ignored, so
+  // the table pointers stay uninitialized if a lookup fails -- confirm that
+  // the lookups cannot fail for the environments accepted above.
+  spvOpcodeTableGet(&opcode_table, env);
+  spvOperandTableGet(&operand_table, env);
+  spvExtInstTableGet(&ext_inst_table, env);
+
+  return new spv_context_t{env, opcode_table, operand_table, ext_inst_table,
+                           nullptr /* a null default consumer */};
+}
+
+void spvContextDestroy(spv_context context) { delete context; }
+
+void spvtools::SetContextMessageConsumer(spv_context context,
+                                         spvtools::MessageConsumer consumer) {
+  context->consumer = std::move(consumer);
+}
diff --git a/source/table.h b/source/table.h
new file mode 100644
index 000000000..64d73dbb9
--- /dev/null
+++ b/source/table.h
@@ -0,0 +1,139 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_TABLE_H_
+#define SOURCE_TABLE_H_
+
+#include "source/latest_version_spirv_header.h"
+
+#include "source/extensions.h"
+#include "spirv-tools/libspirv.hpp"
+
+// Describes a single opcode: name, capabilities, operands and availability.
+typedef struct spv_opcode_desc_t {
+  const char* name;
+  const SpvOp opcode;
+  const uint32_t numCapabilities;
+  const SpvCapability* capabilities;
+  // operandTypes[0..numTypes-1] describe logical operands for the instruction.
+  // The operand types include result id and result-type id, followed by
+  // the types of arguments.
+  const uint16_t numTypes;
+  spv_operand_type_t operandTypes[16];  // TODO: Smaller/larger?
+  const bool hasResult;  // Does the instruction have a result ID operand?
+  const bool hasType;    // Does the instruction have a type ID operand?
+  // A set of extensions that enable this feature. If empty then this opcode
+  // is in core and its availability is subject to minVersion. The
+  // assembler, binary parser, and disassembler ignore this rule, so you can
+  // freely process invalid modules.
+  const uint32_t numExtensions;
+  const spvtools::Extension* extensions;
+  // Minimal core SPIR-V version required for this feature, if without
+  // extensions. ~0u means reserved for future use. ~0u and non-empty extension
+  // lists means only available in extensions.
+  const uint32_t minVersion;
+} spv_opcode_desc_t;
+
+// Describes a single named value of an operand type.
+typedef struct spv_operand_desc_t {
+  const char* name;
+  const uint32_t value;
+  const uint32_t numCapabilities;
+  const SpvCapability* capabilities;
+  // A set of extensions that enable this feature. If empty then this operand
+  // value is in core and its availability is subject to minVersion. The
+  // assembler, binary parser, and disassembler ignore this rule, so you can
+  // freely process invalid modules.
+  const uint32_t numExtensions;
+  const spvtools::Extension* extensions;
+  const spv_operand_type_t operandTypes[16];  // TODO: Smaller/larger?
+  // Minimal core SPIR-V version required for this feature, if without
+  // extensions. ~0u means reserved for future use. ~0u and non-empty extension
+  // lists means only available in extensions.
+  const uint32_t minVersion;
+} spv_operand_desc_t;
+
+// The |count| operand descriptions in |entries| that belong to |type|.
+typedef struct spv_operand_desc_group_t {
+  const spv_operand_type_t type;
+  const uint32_t count;
+  const spv_operand_desc_t* entries;
+} spv_operand_desc_group_t;
+
+// Describes a single extended instruction.
+typedef struct spv_ext_inst_desc_t {
+  const char* name;
+  const uint32_t ext_inst;
+  const uint32_t numCapabilities;
+  const SpvCapability* capabilities;
+  const spv_operand_type_t operandTypes[16];  // TODO: Smaller/larger?
+} spv_ext_inst_desc_t;
+
+// The |count| extended-instruction descriptions in |entries| for |type|.
+typedef struct spv_ext_inst_group_t {
+  const spv_ext_inst_type_t type;
+  const uint32_t count;
+  const spv_ext_inst_desc_t* entries;
+} spv_ext_inst_group_t;
+
+// Top-level lookup tables; each is a counted array of the entries above.
+typedef struct spv_opcode_table_t {
+  const uint32_t count;
+  const spv_opcode_desc_t* entries;
+} spv_opcode_table_t;
+
+typedef struct spv_operand_table_t {
+  const uint32_t count;
+  const spv_operand_desc_group_t* types;
+} spv_operand_table_t;
+
+typedef struct spv_ext_inst_table_t {
+  const uint32_t count;
+  const spv_ext_inst_group_t* groups;
+} spv_ext_inst_table_t;
+
+typedef const spv_opcode_desc_t* spv_opcode_desc;
+typedef const spv_operand_desc_t* spv_operand_desc;
+typedef const spv_ext_inst_desc_t* spv_ext_inst_desc;
+
+typedef const spv_opcode_table_t* spv_opcode_table;
+typedef const spv_operand_table_t* spv_operand_table;
+typedef const spv_ext_inst_table_t* spv_ext_inst_table;
+
+// Bundles a target environment, its lookup tables and a message consumer.
+struct spv_context_t {
+  const spv_target_env target_env;
+  const spv_opcode_table opcode_table;
+  const spv_operand_table operand_table;
+  const spv_ext_inst_table ext_inst_table;
+  spvtools::MessageConsumer consumer;
+};
+
+namespace spvtools {
+
+// Sets the message consumer to |consumer| in the given |context|. The original
+// message consumer will be overwritten.
+void SetContextMessageConsumer(spv_context context, MessageConsumer consumer); +} // namespace spvtools + +// Populates *table with entries for env. +spv_result_t spvOpcodeTableGet(spv_opcode_table* table, spv_target_env env); + +// Populates *table with entries for env. +spv_result_t spvOperandTableGet(spv_operand_table* table, spv_target_env env); + +// Populates *table with entries for env. +spv_result_t spvExtInstTableGet(spv_ext_inst_table* table, spv_target_env env); + +#endif // SOURCE_TABLE_H_ diff --git a/source/text.cpp b/source/text.cpp new file mode 100644 index 000000000..adaf79652 --- /dev/null +++ b/source/text.cpp @@ -0,0 +1,811 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/text.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/ext_inst.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "source/spirv_target_env.h" +#include "source/table.h" +#include "source/text_handler.h" +#include "source/util/bitutils.h" +#include "source/util/parse_number.h" +#include "spirv-tools/libspirv.h" + +bool spvIsValidIDCharacter(const char value) { + return value == '_' || 0 != ::isalnum(value); +} + +// Returns true if the given string represents a valid ID name. +bool spvIsValidID(const char* textValue) { + const char* c = textValue; + for (; *c != '\0'; ++c) { + if (!spvIsValidIDCharacter(*c)) { + return false; + } + } + // If the string was empty, then the ID also is not valid. + return c != textValue; +} + +// Text API + +spv_result_t spvTextToLiteral(const char* textValue, spv_literal_t* pLiteral) { + bool isSigned = false; + int numPeriods = 0; + bool isString = false; + + const size_t len = strlen(textValue); + if (len == 0) return SPV_FAILED_MATCH; + + for (uint64_t index = 0; index < len; ++index) { + switch (textValue[index]) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + break; + case '.': + numPeriods++; + break; + case '-': + if (index == 0) { + isSigned = true; + } else { + isString = true; + } + break; + default: + isString = true; + index = len; // break out of the loop too. 
+ break; + } + } + + pLiteral->type = spv_literal_type_t(99); + + if (isString || numPeriods > 1 || (isSigned && len == 1)) { + if (len < 2 || textValue[0] != '"' || textValue[len - 1] != '"') + return SPV_FAILED_MATCH; + bool escaping = false; + for (const char* val = textValue + 1; val != textValue + len - 1; ++val) { + if ((*val == '\\') && (!escaping)) { + escaping = true; + } else { + // Have to save space for the null-terminator + if (pLiteral->str.size() >= SPV_LIMIT_LITERAL_STRING_BYTES_MAX) + return SPV_ERROR_OUT_OF_MEMORY; + pLiteral->str.push_back(*val); + escaping = false; + } + } + + pLiteral->type = SPV_LITERAL_TYPE_STRING; + } else if (numPeriods == 1) { + double d = std::strtod(textValue, nullptr); + float f = (float)d; + if (d == (double)f) { + pLiteral->type = SPV_LITERAL_TYPE_FLOAT_32; + pLiteral->value.f = f; + } else { + pLiteral->type = SPV_LITERAL_TYPE_FLOAT_64; + pLiteral->value.d = d; + } + } else if (isSigned) { + int64_t i64 = strtoll(textValue, nullptr, 10); + int32_t i32 = (int32_t)i64; + if (i64 == (int64_t)i32) { + pLiteral->type = SPV_LITERAL_TYPE_INT_32; + pLiteral->value.i32 = i32; + } else { + pLiteral->type = SPV_LITERAL_TYPE_INT_64; + pLiteral->value.i64 = i64; + } + } else { + uint64_t u64 = strtoull(textValue, nullptr, 10); + uint32_t u32 = (uint32_t)u64; + if (u64 == (uint64_t)u32) { + pLiteral->type = SPV_LITERAL_TYPE_UINT_32; + pLiteral->value.u32 = u32; + } else { + pLiteral->type = SPV_LITERAL_TYPE_UINT_64; + pLiteral->value.u64 = u64; + } + } + + return SPV_SUCCESS; +} + +namespace { + +/// Parses an immediate integer from text, guarding against overflow. If +/// successful, adds the parsed value to pInst, advances the context past it, +/// and returns SPV_SUCCESS. Otherwise, leaves pInst alone, emits diagnostics, +/// and returns SPV_ERROR_INVALID_TEXT. 
+spv_result_t encodeImmediate(spvtools::AssemblyContext* context, + const char* text, spv_instruction_t* pInst) { + assert(*text == '!'); + uint32_t parse_result; + if (!spvtools::utils::ParseNumber(text + 1, &parse_result)) { + return context->diagnostic(SPV_ERROR_INVALID_TEXT) + << "Invalid immediate integer: !" << text + 1; + } + context->binaryEncodeU32(parse_result, pInst); + context->seekForward(static_cast(strlen(text))); + return SPV_SUCCESS; +} + +} // anonymous namespace + +/// @brief Translate an Opcode operand to binary form +/// +/// @param[in] grammar the grammar to use for compilation +/// @param[in, out] context the dynamic compilation info +/// @param[in] type of the operand +/// @param[in] textValue word of text to be parsed +/// @param[out] pInst return binary Opcode +/// @param[in,out] pExpectedOperands the operand types expected +/// +/// @return result code +spv_result_t spvTextEncodeOperand(const spvtools::AssemblyGrammar& grammar, + spvtools::AssemblyContext* context, + const spv_operand_type_t type, + const char* textValue, + spv_instruction_t* pInst, + spv_operand_pattern_t* pExpectedOperands) { + // NOTE: Handle immediate int in the stream + if ('!' == textValue[0]) { + if (auto error = encodeImmediate(context, textValue, pInst)) { + return error; + } + *pExpectedOperands = + spvAlternatePatternFollowingImmediate(*pExpectedOperands); + return SPV_SUCCESS; + } + + // Optional literal operands can fail to parse. In that case use + // SPV_FAILED_MATCH to avoid emitting a diagostic. Use the following + // for those situations. + spv_result_t error_code_for_literals = + spvOperandIsOptional(type) ? 
SPV_FAILED_MATCH : SPV_ERROR_INVALID_TEXT; + + switch (type) { + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_TYPE_ID: + case SPV_OPERAND_TYPE_RESULT_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + case SPV_OPERAND_TYPE_OPTIONAL_ID: { + if ('%' == textValue[0]) { + textValue++; + } else { + return context->diagnostic() << "Expected id to start with %."; + } + if (!spvIsValidID(textValue)) { + return context->diagnostic() << "Invalid ID " << textValue; + } + const uint32_t id = context->spvNamedIdAssignOrGet(textValue); + if (type == SPV_OPERAND_TYPE_TYPE_ID) pInst->resultTypeId = id; + spvInstructionAddWord(pInst, id); + + // Set the extended instruction type. + // The import set id is the 3rd operand of OpExtInst. + if (pInst->opcode == SpvOpExtInst && pInst->words.size() == 4) { + auto ext_inst_type = context->getExtInstTypeForId(pInst->words[3]); + if (ext_inst_type == SPV_EXT_INST_TYPE_NONE) { + return context->diagnostic() + << "Invalid extended instruction import Id " + << pInst->words[2]; + } + pInst->extInstType = ext_inst_type; + } + } break; + + case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { + // The assembler accepts the symbolic name for an extended instruction, + // and emits its corresponding number. + spv_ext_inst_desc extInst; + if (grammar.lookupExtInst(pInst->extInstType, textValue, &extInst)) { + return context->diagnostic() + << "Invalid extended instruction name '" << textValue << "'."; + } + spvInstructionAddWord(pInst, extInst->ext_inst); + + // Prepare to parse the operands for the extended instructions. + spvPushOperandTypes(extInst->operandTypes, pExpectedOperands); + } break; + + case SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER: { + // The assembler accepts the symbolic name for the opcode, but without + // the "Op" prefix. For example, "IAdd" is accepted. The number + // of the opcode is emitted. 
+ SpvOp opcode; + if (grammar.lookupSpecConstantOpcode(textValue, &opcode)) { + return context->diagnostic() << "Invalid " << spvOperandTypeStr(type) + << " '" << textValue << "'."; + } + spv_opcode_desc opcodeEntry = nullptr; + if (grammar.lookupOpcode(opcode, &opcodeEntry)) { + return context->diagnostic(SPV_ERROR_INTERNAL) + << "OpSpecConstant opcode table out of sync"; + } + spvInstructionAddWord(pInst, uint32_t(opcodeEntry->opcode)); + + // Prepare to parse the operands for the opcode. Except skip the + // type Id and result Id, since they've already been processed. + assert(opcodeEntry->hasType); + assert(opcodeEntry->hasResult); + assert(opcodeEntry->numTypes >= 2); + spvPushOperandTypes(opcodeEntry->operandTypes + 2, pExpectedOperands); + } break; + + case SPV_OPERAND_TYPE_LITERAL_INTEGER: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { + // The current operand is an *unsigned* 32-bit integer. + // That's just how the grammar works. + spvtools::IdType expected_type = { + 32, false, spvtools::IdTypeClass::kScalarIntegerType}; + if (auto error = context->binaryEncodeNumericLiteral( + textValue, error_code_for_literals, expected_type, pInst)) { + return error; + } + } break; + + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_NUMBER: + // This is a context-independent literal number which can be a 32-bit + // number of floating point value. + if (auto error = context->binaryEncodeNumericLiteral( + textValue, error_code_for_literals, spvtools::kUnknownType, + pInst)) { + return error; + } + break; + + case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: + case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { + spvtools::IdType expected_type = spvtools::kUnknownType; + // The encoding for OpConstant, OpSpecConstant and OpSwitch all + // depend on either their own result-id or the result-id of + // one of their parameters. 
+ if (SpvOpConstant == pInst->opcode || + SpvOpSpecConstant == pInst->opcode) { + // The type of the literal is determined by the type Id of the + // instruction. + expected_type = + context->getTypeOfTypeGeneratingValue(pInst->resultTypeId); + if (!spvtools::isScalarFloating(expected_type) && + !spvtools::isScalarIntegral(expected_type)) { + spv_opcode_desc d; + const char* opcode_name = "opcode"; + if (SPV_SUCCESS == grammar.lookupOpcode(pInst->opcode, &d)) { + opcode_name = d->name; + } + return context->diagnostic() + << "Type for " << opcode_name + << " must be a scalar floating point or integer type"; + } + } else if (pInst->opcode == SpvOpSwitch) { + // The type of the literal is the same as the type of the selector. + expected_type = context->getTypeOfValueInstruction(pInst->words[1]); + if (!spvtools::isScalarIntegral(expected_type)) { + return context->diagnostic() + << "The selector operand for OpSwitch must be the result" + " of an instruction that generates an integer scalar"; + } + } + if (auto error = context->binaryEncodeNumericLiteral( + textValue, error_code_for_literals, expected_type, pInst)) { + return error; + } + } break; + + case SPV_OPERAND_TYPE_LITERAL_STRING: + case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { + spv_literal_t literal = {}; + spv_result_t error = spvTextToLiteral(textValue, &literal); + if (error != SPV_SUCCESS) { + if (error == SPV_ERROR_OUT_OF_MEMORY) return error; + return context->diagnostic(error_code_for_literals) + << "Invalid literal string '" << textValue << "'."; + } + if (literal.type != SPV_LITERAL_TYPE_STRING) { + return context->diagnostic() + << "Expected literal string, found literal number '" << textValue + << "'."; + } + + // NOTE: Special case for extended instruction library import + if (SpvOpExtInstImport == pInst->opcode) { + const spv_ext_inst_type_t ext_inst_type = + spvExtInstImportTypeGet(literal.str.c_str()); + if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { + return context->diagnostic() + << 
"Invalid extended instruction import '" << literal.str + << "'"; + } + if ((error = context->recordIdAsExtInstImport(pInst->words[1], + ext_inst_type))) + return error; + } + + if (context->binaryEncodeString(literal.str.c_str(), pInst)) + return SPV_ERROR_INVALID_TEXT; + } break; + + // Masks. + case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: + case SPV_OPERAND_TYPE_FUNCTION_CONTROL: + case SPV_OPERAND_TYPE_LOOP_CONTROL: + case SPV_OPERAND_TYPE_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: + case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: + case SPV_OPERAND_TYPE_SELECTION_CONTROL: + case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: { + uint32_t value; + if (grammar.parseMaskOperand(type, textValue, &value)) { + return context->diagnostic() << "Invalid " << spvOperandTypeStr(type) + << " operand '" << textValue << "'."; + } + if (auto error = context->binaryEncodeU32(value, pInst)) return error; + // Prepare to parse the operands for this logical operand. + grammar.pushOperandTypesForMask(type, value, pExpectedOperands); + } break; + case SPV_OPERAND_TYPE_OPTIONAL_CIV: { + auto error = spvTextEncodeOperand( + grammar, context, SPV_OPERAND_TYPE_OPTIONAL_LITERAL_NUMBER, textValue, + pInst, pExpectedOperands); + if (error == SPV_FAILED_MATCH) { + // It's not a literal number -- is it a literal string? + error = spvTextEncodeOperand(grammar, context, + SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING, + textValue, pInst, pExpectedOperands); + } + if (error == SPV_FAILED_MATCH) { + // It's not a literal -- is it an ID? + error = + spvTextEncodeOperand(grammar, context, SPV_OPERAND_TYPE_OPTIONAL_ID, + textValue, pInst, pExpectedOperands); + } + if (error) { + return context->diagnostic(error) + << "Invalid word following !: " << textValue; + } + if (pExpectedOperands->empty()) { + pExpectedOperands->push_back(SPV_OPERAND_TYPE_OPTIONAL_CIV); + } + } break; + default: { + // NOTE: All non literal operands are handled here using the operand + // table. 
+ spv_operand_desc entry; + if (grammar.lookupOperand(type, textValue, strlen(textValue), &entry)) { + return context->diagnostic() << "Invalid " << spvOperandTypeStr(type) + << " '" << textValue << "'."; + } + if (context->binaryEncodeU32(entry->value, pInst)) { + return context->diagnostic() << "Invalid " << spvOperandTypeStr(type) + << " '" << textValue << "'."; + } + + // Prepare to parse the operands for this logical operand. + spvPushOperandTypes(entry->operandTypes, pExpectedOperands); + } break; + } + return SPV_SUCCESS; +} + +namespace { + +/// Encodes an instruction started by ! at the given position in text. +/// +/// Puts the encoded words into *pInst. If successful, moves position past the +/// instruction and returns SPV_SUCCESS. Otherwise, returns an error code and +/// leaves position pointing to the error in text. +spv_result_t encodeInstructionStartingWithImmediate( + const spvtools::AssemblyGrammar& grammar, + spvtools::AssemblyContext* context, spv_instruction_t* pInst) { + std::string firstWord; + spv_position_t nextPosition = {}; + auto error = context->getWord(&firstWord, &nextPosition); + if (error) return context->diagnostic(error) << "Internal Error"; + + if ((error = encodeImmediate(context, firstWord.c_str(), pInst))) { + return error; + } + while (context->advance() != SPV_END_OF_STREAM) { + // A beginning of a new instruction means we're done. + if (context->isStartOfNewInst()) return SPV_SUCCESS; + + // Otherwise, there must be an operand that's either a literal, an ID, or + // an immediate. + std::string operandValue; + if ((error = context->getWord(&operandValue, &nextPosition))) + return context->diagnostic(error) << "Internal Error"; + + if (operandValue == "=") + return context->diagnostic() << firstWord << " not allowed before =."; + + // Needed to pass to spvTextEncodeOpcode(), but it shouldn't ever be + // expanded. 
+ spv_operand_pattern_t dummyExpectedOperands; + error = spvTextEncodeOperand( + grammar, context, SPV_OPERAND_TYPE_OPTIONAL_CIV, operandValue.c_str(), + pInst, &dummyExpectedOperands); + if (error) return error; + context->setPosition(nextPosition); + } + return SPV_SUCCESS; +} + +/// @brief Translate single Opcode and operands to binary form +/// +/// @param[in] grammar the grammar to use for compilation +/// @param[in, out] context the dynamic compilation info +/// @param[in] text stream to translate +/// @param[out] pInst returned binary Opcode +/// @param[in,out] pPosition in the text stream +/// +/// @return result code +spv_result_t spvTextEncodeOpcode(const spvtools::AssemblyGrammar& grammar, + spvtools::AssemblyContext* context, + spv_instruction_t* pInst) { + // Check for ! first. + if ('!' == context->peek()) { + return encodeInstructionStartingWithImmediate(grammar, context, pInst); + } + + std::string firstWord; + spv_position_t nextPosition = {}; + spv_result_t error = context->getWord(&firstWord, &nextPosition); + if (error) return context->diagnostic() << "Internal Error"; + + std::string opcodeName; + std::string result_id; + spv_position_t result_id_position = {}; + if (context->startsWithOp()) { + opcodeName = firstWord; + } else { + result_id = firstWord; + if ('%' != result_id.front()) { + return context->diagnostic() + << "Expected or at the beginning " + "of an instruction, found '" + << result_id << "'."; + } + result_id_position = context->position(); + + // The '=' sign. + context->setPosition(nextPosition); + if (context->advance()) + return context->diagnostic() << "Expected '=', found end of stream."; + std::string equal_sign; + error = context->getWord(&equal_sign, &nextPosition); + if ("=" != equal_sign) + return context->diagnostic() << "'=' expected after result id."; + + // The after the '=' sign. 
+ context->setPosition(nextPosition); + if (context->advance()) + return context->diagnostic() << "Expected opcode, found end of stream."; + error = context->getWord(&opcodeName, &nextPosition); + if (error) return context->diagnostic(error) << "Internal Error"; + if (!context->startsWithOp()) { + return context->diagnostic() + << "Invalid Opcode prefix '" << opcodeName << "'."; + } + } + + // NOTE: The table contains Opcode names without the "Op" prefix. + const char* pInstName = opcodeName.data() + 2; + + spv_opcode_desc opcodeEntry; + error = grammar.lookupOpcode(pInstName, &opcodeEntry); + if (error) { + return context->diagnostic(error) + << "Invalid Opcode name '" << opcodeName << "'"; + } + if (opcodeEntry->hasResult && result_id.empty()) { + return context->diagnostic() + << "Expected at the beginning of an instruction, found '" + << firstWord << "'."; + } + pInst->opcode = opcodeEntry->opcode; + context->setPosition(nextPosition); + // Reserve the first word for the instruction. + spvInstructionAddWord(pInst, 0); + + // Maintains the ordered list of expected operand types. + // For many instructions we only need the {numTypes, operandTypes} + // entries in opcodeEntry. However, sometimes we need to modify + // the list as we parse the operands. This occurs when an operand + // has its own logical operands (such as the LocalSize operand for + // ExecutionMode), or for extended instructions that may have their + // own operands depending on the selected extended instruction. + spv_operand_pattern_t expectedOperands; + expectedOperands.reserve(opcodeEntry->numTypes); + for (auto i = 0; i < opcodeEntry->numTypes; i++) + expectedOperands.push_back( + opcodeEntry->operandTypes[opcodeEntry->numTypes - i - 1]); + + while (!expectedOperands.empty()) { + const spv_operand_type_t type = expectedOperands.back(); + expectedOperands.pop_back(); + + // Expand optional tuples lazily. 
+ if (spvExpandOperandSequenceOnce(type, &expectedOperands)) continue; + + if (type == SPV_OPERAND_TYPE_RESULT_ID && !result_id.empty()) { + // Handle the for value generating instructions. + // We've already consumed it from the text stream. Here + // we inject its words into the instruction. + spv_position_t temp_pos = context->position(); + error = spvTextEncodeOperand(grammar, context, SPV_OPERAND_TYPE_RESULT_ID, + result_id.c_str(), pInst, nullptr); + result_id_position = context->position(); + // Because we are injecting we have to reset the position afterwards. + context->setPosition(temp_pos); + if (error) return error; + } else { + // Find the next word. + error = context->advance(); + if (error == SPV_END_OF_STREAM) { + if (spvOperandIsOptional(type)) { + // This would have been the last potential operand for the + // instruction, + // and we didn't find one. We're finished parsing this instruction. + break; + } else { + return context->diagnostic() + << "Expected operand, found end of stream."; + } + } + assert(error == SPV_SUCCESS && "Somebody added another way to fail"); + + if (context->isStartOfNewInst()) { + if (spvOperandIsOptional(type)) { + break; + } else { + return context->diagnostic() + << "Expected operand, found next instruction instead."; + } + } + + std::string operandValue; + error = context->getWord(&operandValue, &nextPosition); + if (error) return context->diagnostic(error) << "Internal Error"; + + error = spvTextEncodeOperand(grammar, context, type, operandValue.c_str(), + pInst, &expectedOperands); + + if (error == SPV_FAILED_MATCH && spvOperandIsOptional(type)) + return SPV_SUCCESS; + + if (error) return error; + + context->setPosition(nextPosition); + } + } + + if (spvOpcodeGeneratesType(pInst->opcode)) { + if (context->recordTypeDefinition(pInst) != SPV_SUCCESS) { + return SPV_ERROR_INVALID_TEXT; + } + } else if (opcodeEntry->hasType) { + // SPIR-V dictates that if an instruction has both a return value and a + // type ID then 
the type id is first, and the return value is second. + assert(opcodeEntry->hasResult && + "Unknown opcode: has a type but no result."); + context->recordTypeIdForValue(pInst->words[2], pInst->words[1]); + } + + if (pInst->words.size() > SPV_LIMIT_INSTRUCTION_WORD_COUNT_MAX) { + return context->diagnostic() + << "Instruction too long: " << pInst->words.size() + << " words, but the limit is " + << SPV_LIMIT_INSTRUCTION_WORD_COUNT_MAX; + } + + pInst->words[0] = + spvOpcodeMake(uint16_t(pInst->words.size()), opcodeEntry->opcode); + + return SPV_SUCCESS; +} + +enum { kAssemblerVersion = 0 }; + +// Populates a binary stream's |header|. The target environment is specified via +// |env| and Id bound is via |bound|. +spv_result_t SetHeader(spv_target_env env, const uint32_t bound, + uint32_t* header) { + if (!header) return SPV_ERROR_INVALID_BINARY; + + header[SPV_INDEX_MAGIC_NUMBER] = SpvMagicNumber; + header[SPV_INDEX_VERSION_NUMBER] = spvVersionForTargetEnv(env); + header[SPV_INDEX_GENERATOR_NUMBER] = + SPV_GENERATOR_WORD(SPV_GENERATOR_KHRONOS_ASSEMBLER, kAssemblerVersion); + header[SPV_INDEX_BOUND] = bound; + header[SPV_INDEX_SCHEMA] = 0; // NOTE: Reserved + + return SPV_SUCCESS; +} + +// Collects all numeric ids in the module source into |numeric_ids|. +// This function is essentially a dry-run of spvTextToBinary. +spv_result_t GetNumericIds(const spvtools::AssemblyGrammar& grammar, + const spvtools::MessageConsumer& consumer, + const spv_text text, + std::set* numeric_ids) { + spvtools::AssemblyContext context(text, consumer); + + if (!text->str) return context.diagnostic() << "Missing assembly text."; + + if (!grammar.isValid()) { + return SPV_ERROR_INVALID_TABLE; + } + + // Skip past whitespace and comments. 
+ context.advance(); + + while (context.hasText()) { + spv_instruction_t inst; + + if (spvTextEncodeOpcode(grammar, &context, &inst)) { + return SPV_ERROR_INVALID_TEXT; + } + + if (context.advance()) break; + } + + *numeric_ids = context.GetNumericIds(); + return SPV_SUCCESS; +} + +// Translates a given assembly language module into binary form. +// If a diagnostic is generated, it is not yet marked as being +// for a text-based input. +spv_result_t spvTextToBinaryInternal(const spvtools::AssemblyGrammar& grammar, + const spvtools::MessageConsumer& consumer, + const spv_text text, + const uint32_t options, + spv_binary* pBinary) { + // The ids in this set will have the same values both in source and binary. + // All other ids will be generated by filling in the gaps. + std::set ids_to_preserve; + + if (options & SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS) { + // Collect all numeric ids from the source into ids_to_preserve. + const spv_result_t result = + GetNumericIds(grammar, consumer, text, &ids_to_preserve); + if (result != SPV_SUCCESS) return result; + } + + spvtools::AssemblyContext context(text, consumer, std::move(ids_to_preserve)); + + if (!text->str) return context.diagnostic() << "Missing assembly text."; + + if (!grammar.isValid()) { + return SPV_ERROR_INVALID_TABLE; + } + if (!pBinary) return SPV_ERROR_INVALID_POINTER; + + std::vector instructions; + + // Skip past whitespace and comments. 
+ context.advance(); + + while (context.hasText()) { + instructions.push_back({}); + spv_instruction_t& inst = instructions.back(); + + if (spvTextEncodeOpcode(grammar, &context, &inst)) { + return SPV_ERROR_INVALID_TEXT; + } + + if (context.advance()) break; + } + + size_t totalSize = SPV_INDEX_INSTRUCTION; + for (auto& inst : instructions) { + totalSize += inst.words.size(); + } + + uint32_t* data = new uint32_t[totalSize]; + if (!data) return SPV_ERROR_OUT_OF_MEMORY; + uint64_t currentIndex = SPV_INDEX_INSTRUCTION; + for (auto& inst : instructions) { + memcpy(data + currentIndex, inst.words.data(), + sizeof(uint32_t) * inst.words.size()); + currentIndex += inst.words.size(); + } + + if (auto error = SetHeader(grammar.target_env(), context.getBound(), data)) + return error; + + spv_binary binary = new spv_binary_t(); + if (!binary) { + delete[] data; + return SPV_ERROR_OUT_OF_MEMORY; + } + binary->code = data; + binary->wordCount = totalSize; + + *pBinary = binary; + + return SPV_SUCCESS; +} + +} // anonymous namespace + +spv_result_t spvTextToBinary(const spv_const_context context, + const char* input_text, + const size_t input_text_size, spv_binary* pBinary, + spv_diagnostic* pDiagnostic) { + return spvTextToBinaryWithOptions(context, input_text, input_text_size, + SPV_BINARY_TO_TEXT_OPTION_NONE, pBinary, + pDiagnostic); +} + +spv_result_t spvTextToBinaryWithOptions(const spv_const_context context, + const char* input_text, + const size_t input_text_size, + const uint32_t options, + spv_binary* pBinary, + spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + spv_text_t text = {input_text, input_text_size}; + spvtools::AssemblyGrammar grammar(&hijack_context); + + spv_result_t result = spvTextToBinaryInternal( + grammar, hijack_context.consumer, &text, options, pBinary); + if (pDiagnostic && *pDiagnostic) 
(*pDiagnostic)->isTextSource = true; + + return result; +} + +void spvTextDestroy(spv_text text) { + if (!text) return; + delete[] text->str; + delete text; +} diff --git a/source/text.h b/source/text.h new file mode 100644 index 000000000..fa34ee16b --- /dev/null +++ b/source/text.h @@ -0,0 +1,53 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_TEXT_H_ +#define SOURCE_TEXT_H_ + +#include + +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "spirv-tools/libspirv.h" + +typedef enum spv_literal_type_t { + SPV_LITERAL_TYPE_INT_32, + SPV_LITERAL_TYPE_INT_64, + SPV_LITERAL_TYPE_UINT_32, + SPV_LITERAL_TYPE_UINT_64, + SPV_LITERAL_TYPE_FLOAT_32, + SPV_LITERAL_TYPE_FLOAT_64, + SPV_LITERAL_TYPE_STRING, + SPV_FORCE_32_BIT_ENUM(spv_literal_type_t) +} spv_literal_type_t; + +typedef struct spv_literal_t { + spv_literal_type_t type; + union value_t { + int32_t i32; + int64_t i64; + uint32_t u32; + uint64_t u64; + float f; + double d; + } value; + std::string str; // Special field for literal string. +} spv_literal_t; + +// Converts the given text string to a number/string literal and writes the +// result to *literal. String literals must be surrounded by double-quotes ("), +// which are then stripped. 
+spv_result_t spvTextToLiteral(const char* text, spv_literal_t* literal); + +#endif // SOURCE_TEXT_H_ diff --git a/source/text_handler.cpp b/source/text_handler.cpp new file mode 100644 index 000000000..c31f34a6b --- /dev/null +++ b/source/text_handler.cpp @@ -0,0 +1,397 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/text_handler.h" + +#include +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/binary.h" +#include "source/ext_inst.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/text.h" +#include "source/util/bitutils.h" +#include "source/util/hex_float.h" +#include "source/util/parse_number.h" + +namespace spvtools { +namespace { + +// Advances |text| to the start of the next line and writes the new position to +// |position|. +spv_result_t advanceLine(spv_text text, spv_position position) { + while (true) { + if (position->index >= text->length) return SPV_END_OF_STREAM; + switch (text->str[position->index]) { + case '\0': + return SPV_END_OF_STREAM; + case '\n': + position->column = 0; + position->line++; + position->index++; + return SPV_SUCCESS; + default: + position->column++; + position->index++; + break; + } + } +} + +// Advances |text| to first non white space character and writes the new +// position to |position|. 
+// If a null terminator is found during the text advance, SPV_END_OF_STREAM is +// returned, SPV_SUCCESS otherwise. No error checking is performed on the +// parameters, its the users responsibility to ensure these are non null. +spv_result_t advance(spv_text text, spv_position position) { + // NOTE: Consume white space, otherwise don't advance. + if (position->index >= text->length) return SPV_END_OF_STREAM; + switch (text->str[position->index]) { + case '\0': + return SPV_END_OF_STREAM; + case ';': + if (spv_result_t error = advanceLine(text, position)) return error; + return advance(text, position); + case ' ': + case '\t': + case '\r': + position->column++; + position->index++; + return advance(text, position); + case '\n': + position->column = 0; + position->line++; + position->index++; + return advance(text, position); + default: + break; + } + return SPV_SUCCESS; +} + +// Fetches the next word from the given text stream starting from the given +// *position. On success, writes the decoded word into *word and updates +// *position to the location past the returned word. +// +// A word ends at the next comment or whitespace. However, double-quoted +// strings remain intact, and a backslash always escapes the next character. +spv_result_t getWord(spv_text text, spv_position position, std::string* word) { + if (!text->str || !text->length) return SPV_ERROR_INVALID_TEXT; + if (!position) return SPV_ERROR_INVALID_POINTER; + + const size_t start_index = position->index; + + bool quoting = false; + bool escaping = false; + + // NOTE: Assumes first character is not white space! 
+ while (true) { + if (position->index >= text->length) { + word->assign(text->str + start_index, text->str + position->index); + return SPV_SUCCESS; + } + const char ch = text->str[position->index]; + if (ch == '\\') { + escaping = !escaping; + } else { + switch (ch) { + case '"': + if (!escaping) quoting = !quoting; + break; + case ' ': + case ';': + case '\t': + case '\n': + case '\r': + if (escaping || quoting) break; + // Fall through. + case '\0': { // NOTE: End of word found! + word->assign(text->str + start_index, text->str + position->index); + return SPV_SUCCESS; + } + default: + break; + } + escaping = false; + } + + position->column++; + position->index++; + } +} + +// Returns true if the characters in the text at position represent +// the start of an Opcode. +bool startsWithOp(spv_text text, spv_position position) { + if (text->length < position->index + 3) return false; + char ch0 = text->str[position->index]; + char ch1 = text->str[position->index + 1]; + char ch2 = text->str[position->index + 2]; + return ('O' == ch0 && 'p' == ch1 && ('A' <= ch2 && ch2 <= 'Z')); +} + +} // namespace + +const IdType kUnknownType = {0, false, IdTypeClass::kBottom}; + +// TODO(dneto): Reorder AssemblyContext definitions to match declaration order. + +// This represents all of the data that is only valid for the duration of +// a single compilation.
+uint32_t AssemblyContext::spvNamedIdAssignOrGet(const char* textValue) { + if (!ids_to_preserve_.empty()) { + uint32_t id = 0; + if (spvtools::utils::ParseNumber(textValue, &id)) { + if (ids_to_preserve_.find(id) != ids_to_preserve_.end()) { + bound_ = std::max(bound_, id + 1); + return id; + } + } + } + + const auto it = named_ids_.find(textValue); + if (it == named_ids_.end()) { + uint32_t id = next_id_++; + if (!ids_to_preserve_.empty()) { + while (ids_to_preserve_.find(id) != ids_to_preserve_.end()) { + id = next_id_++; + } + } + + named_ids_.emplace(textValue, id); + bound_ = std::max(bound_, id + 1); + return id; + } + + return it->second; +} + +uint32_t AssemblyContext::getBound() const { return bound_; } + +spv_result_t AssemblyContext::advance() { + return spvtools::advance(text_, &current_position_); +} + +spv_result_t AssemblyContext::getWord(std::string* word, + spv_position next_position) { + *next_position = current_position_; + return spvtools::getWord(text_, next_position, word); +} + +bool AssemblyContext::startsWithOp() { + return spvtools::startsWithOp(text_, &current_position_); +} + +bool AssemblyContext::isStartOfNewInst() { + spv_position_t pos = current_position_; + if (spvtools::advance(text_, &pos)) return false; + if (spvtools::startsWithOp(text_, &pos)) return true; + + std::string word; + pos = current_position_; + if (spvtools::getWord(text_, &pos, &word)) return false; + if (word.empty() || '%' != word.front()) return false; + + if (spvtools::advance(text_, &pos)) return false; + if (spvtools::getWord(text_, &pos, &word)) return false; + if ("=" != word) return false; + + if (spvtools::advance(text_, &pos)) return false; + if (spvtools::startsWithOp(text_, &pos)) return true; + return false; +} + +char AssemblyContext::peek() const { + return text_->str[current_position_.index]; +} + +bool AssemblyContext::hasText() const { + return text_->length > current_position_.index; +} + +void AssemblyContext::seekForward(uint32_t size) { + current_position_.index +=
size; + current_position_.column += size; +} + +spv_result_t AssemblyContext::binaryEncodeU32(const uint32_t value, + spv_instruction_t* pInst) { + pInst->words.insert(pInst->words.end(), value); + return SPV_SUCCESS; +} + +spv_result_t AssemblyContext::binaryEncodeNumericLiteral( + const char* val, spv_result_t error_code, const IdType& type, + spv_instruction_t* pInst) { + using spvtools::utils::EncodeNumberStatus; + // Populate the NumberType from the IdType for parsing. + spvtools::utils::NumberType number_type; + switch (type.type_class) { + case IdTypeClass::kOtherType: + return diagnostic(SPV_ERROR_INTERNAL) + << "Unexpected numeric literal type"; + case IdTypeClass::kScalarIntegerType: + if (type.isSigned) { + number_type = {type.bitwidth, SPV_NUMBER_SIGNED_INT}; + } else { + number_type = {type.bitwidth, SPV_NUMBER_UNSIGNED_INT}; + } + break; + case IdTypeClass::kScalarFloatType: + number_type = {type.bitwidth, SPV_NUMBER_FLOATING}; + break; + case IdTypeClass::kBottom: + // kBottom means the type is unknown and we need to infer the type before + // parsing the number. The rule is: If there is a decimal point, treat + // the value as a floating point value, otherwise a integer value, then + // if the first char of the integer text is '-', treat the integer as a + // signed integer, otherwise an unsigned integer. 
+ uint32_t bitwidth = static_cast(assumedBitWidth(type)); + if (strchr(val, '.')) { + number_type = {bitwidth, SPV_NUMBER_FLOATING}; + } else if (type.isSigned || val[0] == '-') { + number_type = {bitwidth, SPV_NUMBER_SIGNED_INT}; + } else { + number_type = {bitwidth, SPV_NUMBER_UNSIGNED_INT}; + } + break; + } + + std::string error_msg; + EncodeNumberStatus parse_status = ParseAndEncodeNumber( + val, number_type, + [this, pInst](uint32_t d) { this->binaryEncodeU32(d, pInst); }, + &error_msg); + switch (parse_status) { + case EncodeNumberStatus::kSuccess: + return SPV_SUCCESS; + case EncodeNumberStatus::kInvalidText: + return diagnostic(error_code) << error_msg; + case EncodeNumberStatus::kUnsupported: + return diagnostic(SPV_ERROR_INTERNAL) << error_msg; + case EncodeNumberStatus::kInvalidUsage: + return diagnostic(SPV_ERROR_INVALID_TEXT) << error_msg; + } + // This line is not reachable, only added to satisfy the compiler. + return diagnostic(SPV_ERROR_INTERNAL) + << "Unexpected result code from ParseAndEncodeNumber()"; +} + +spv_result_t AssemblyContext::binaryEncodeString(const char* value, + spv_instruction_t* pInst) { + const size_t length = strlen(value); + const size_t wordCount = (length / 4) + 1; + const size_t oldWordCount = pInst->words.size(); + const size_t newWordCount = oldWordCount + wordCount; + + // TODO(dneto): We can just defer this check until later. + if (newWordCount > SPV_LIMIT_INSTRUCTION_WORD_COUNT_MAX) { + return diagnostic() << "Instruction too long: more than " + << SPV_LIMIT_INSTRUCTION_WORD_COUNT_MAX << " words."; + } + + pInst->words.resize(newWordCount); + + // Make sure all the bytes in the last word are 0, in case we only + // write a partial word at the end. 
+ pInst->words.back() = 0; + + char* dest = (char*)&pInst->words[oldWordCount]; + strncpy(dest, value, length + 1); + + return SPV_SUCCESS; +} + +spv_result_t AssemblyContext::recordTypeDefinition( + const spv_instruction_t* pInst) { + uint32_t value = pInst->words[1]; + if (types_.find(value) != types_.end()) { + return diagnostic() << "Value " << value + << " has already been used to generate a type"; + } + + if (pInst->opcode == SpvOpTypeInt) { + if (pInst->words.size() != 4) + return diagnostic() << "Invalid OpTypeInt instruction"; + types_[value] = {pInst->words[2], pInst->words[3] != 0, + IdTypeClass::kScalarIntegerType}; + } else if (pInst->opcode == SpvOpTypeFloat) { + if (pInst->words.size() != 3) + return diagnostic() << "Invalid OpTypeFloat instruction"; + types_[value] = {pInst->words[2], false, IdTypeClass::kScalarFloatType}; + } else { + types_[value] = {0, false, IdTypeClass::kOtherType}; + } + return SPV_SUCCESS; +} + +IdType AssemblyContext::getTypeOfTypeGeneratingValue(uint32_t value) const { + auto type = types_.find(value); + if (type == types_.end()) { + return kUnknownType; + } + return std::get<1>(*type); +} + +IdType AssemblyContext::getTypeOfValueInstruction(uint32_t value) const { + auto type_value = value_types_.find(value); + if (type_value == value_types_.end()) { + return {0, false, IdTypeClass::kBottom}; + } + return getTypeOfTypeGeneratingValue(std::get<1>(*type_value)); +} + +spv_result_t AssemblyContext::recordTypeIdForValue(uint32_t value, + uint32_t type) { + bool successfully_inserted = false; + std::tie(std::ignore, successfully_inserted) = + value_types_.insert(std::make_pair(value, type)); + if (!successfully_inserted) + return diagnostic() << "Value is being defined a second time"; + return SPV_SUCCESS; +} + +spv_result_t AssemblyContext::recordIdAsExtInstImport( + uint32_t id, spv_ext_inst_type_t type) { + bool successfully_inserted = false; + std::tie(std::ignore, successfully_inserted) = + 
import_id_to_ext_inst_type_.insert(std::make_pair(id, type)); + if (!successfully_inserted) + return diagnostic() << "Import Id is being defined a second time"; + return SPV_SUCCESS; +} + +spv_ext_inst_type_t AssemblyContext::getExtInstTypeForId(uint32_t id) const { + auto type = import_id_to_ext_inst_type_.find(id); + if (type == import_id_to_ext_inst_type_.end()) { + return SPV_EXT_INST_TYPE_NONE; + } + return std::get<1>(*type); +} + +std::set AssemblyContext::GetNumericIds() const { + std::set ids; + for (const auto& kv : named_ids_) { + uint32_t id; + if (spvtools::utils::ParseNumber(kv.first.c_str(), &id)) ids.insert(id); + } + return ids; +} + +} // namespace spvtools diff --git a/source/text_handler.h b/source/text_handler.h new file mode 100644 index 000000000..19972e951 --- /dev/null +++ b/source/text_handler.h @@ -0,0 +1,264 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_TEXT_HANDLER_H_ +#define SOURCE_TEXT_HANDLER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/instruction.h" +#include "source/text.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// Structures + +// This is a lattice for tracking types. +enum class IdTypeClass { + kBottom = 0, // We have no information yet. 
+ kScalarIntegerType, + kScalarFloatType, + kOtherType +}; + +// Contains ID type information that needs to be tracked across all Ids. +// Bitwidth is only valid when type_class is kScalarIntegerType or +// kScalarFloatType. +struct IdType { + uint32_t bitwidth; // Safe to assume that we will not have > 2^32 bits. + bool isSigned; // This is only significant if type_class is integral. + IdTypeClass type_class; +}; + +// Default equality operator for IdType. Tests if all members are the same. +inline bool operator==(const IdType& first, const IdType& second) { + return (first.bitwidth == second.bitwidth) && + (first.isSigned == second.isSigned) && + (first.type_class == second.type_class); +} + +// Tests whether any member of the IdTypes do not match. +inline bool operator!=(const IdType& first, const IdType& second) { + return !(first == second); +} + +// A value representing an unknown type. +extern const IdType kUnknownType; + +// Returns true if the type is a scalar integer type. +inline bool isScalarIntegral(const IdType& type) { + return type.type_class == IdTypeClass::kScalarIntegerType; +} + +// Returns true if the type is a scalar floating point type. +inline bool isScalarFloating(const IdType& type) { + return type.type_class == IdTypeClass::kScalarFloatType; +} + +// Returns the number of bits in the type. +// This is only valid for bottom, scalar integer, and scalar floating +// classes. For bottom, assume 32 bits. +inline int assumedBitWidth(const IdType& type) { + switch (type.type_class) { + case IdTypeClass::kBottom: + return 32; + case IdTypeClass::kScalarIntegerType: + case IdTypeClass::kScalarFloatType: + return type.bitwidth; + default: + break; + } + // We don't care about this case. + return 0; +} + +// A templated class with a static member function Clamp, where Clamp +// sets a referenced value of type T to 0 if T is an unsigned +// integer type, and returns true if it modified the referenced +// value. 
+template +class ClampToZeroIfUnsignedType { + public: + // The default specialization does not clamp the value. + static bool Clamp(T*) { return false; } +}; + +// The specialization of ClampToZeroIfUnsignedType for unsigned integer +// types. +template +class ClampToZeroIfUnsignedType< + T, typename std::enable_if::value>::type> { + public: + static bool Clamp(T* value_pointer) { + if (*value_pointer) { + *value_pointer = 0; + return true; + } + return false; + } +}; + +// Encapsulates the data used during the assembly of a SPIR-V module. +class AssemblyContext { + public: + AssemblyContext(spv_text text, const MessageConsumer& consumer, + std::set&& ids_to_preserve = std::set()) + : current_position_({}), + consumer_(consumer), + text_(text), + bound_(1), + next_id_(1), + ids_to_preserve_(std::move(ids_to_preserve)) {} + + // Assigns a new integer value to the given text ID, or returns the previously + // assigned integer value if the ID has been seen before. + uint32_t spvNamedIdAssignOrGet(const char* textValue); + + // Returns the ID bound: one more than the largest numeric ID assigned. + uint32_t getBound() const; + + // Advances position to point to the next word in the input stream. + // Returns SPV_SUCCESS on success. + spv_result_t advance(); + + // Sets word to the next word in the input text. Fills next_position with + // the next location past the end of the word. + spv_result_t getWord(std::string* word, spv_position next_position); + + // Returns true if the next word in the input is the start of a new Opcode. + bool startsWithOp(); + + // Returns true if the next word in the input is the start of a new + // instruction. + bool isStartOfNewInst(); + + // Returns a diagnostic object initialized with current position in the input + // stream, and for the given error code. Any data written to this object will + // show up in pDiagnostic on destruction.
+ DiagnosticStream diagnostic(spv_result_t error) { + return DiagnosticStream(current_position_, consumer_, "", error); + } + + // Returns a diagnostic object with the default assembly error code. + DiagnosticStream diagnostic() { + // The default failure for assembly is invalid text. + return diagnostic(SPV_ERROR_INVALID_TEXT); + } + + // Returns the next character in the input stream. + char peek() const; + + // Returns true if there is more text in the input stream. + bool hasText() const; + + // Seeks the input stream forward by 'size' characters. + void seekForward(uint32_t size); + + // Sets the current position in the input stream to the given position. + void setPosition(const spv_position_t& newPosition) { + current_position_ = newPosition; + } + + // Returns the current position in the input stream. + const spv_position_t& position() const { return current_position_; } + + // Appends the given 32-bit value to the given instruction. + // Returns SPV_SUCCESS if the value could be correctly inserted in the + // instruction. + spv_result_t binaryEncodeU32(const uint32_t value, spv_instruction_t* pInst); + + // Appends the given string to the given instruction. + // Returns SPV_SUCCESS if the value could be correctly inserted in the + // instruction. + spv_result_t binaryEncodeString(const char* value, spv_instruction_t* pInst); + + // Appends the given numeric literal to the given instruction. + // Validates and respects the bitwidth supplied in the IdType argument. + // If the type is of class kBottom the value will be encoded as a + // 32-bit integer. + // Returns SPV_SUCCESS if the value could be correctly added to the + // instruction. Returns the given error code on failure, and emits + // a diagnostic if that error code is not SPV_FAILED_MATCH.
+ spv_result_t binaryEncodeNumericLiteral(const char* numeric_literal, + spv_result_t error_code, + const IdType& type, + spv_instruction_t* pInst); + + // Returns the IdType associated with this type-generating value. + // If the type has not been previously recorded with recordTypeDefinition, + // kUnknownType will be returned. + IdType getTypeOfTypeGeneratingValue(uint32_t value) const; + + // Returns the IdType that represents the return value of this Value + // generating instruction. + // If the value has not been recorded with recordTypeIdForValue, or the type + // could not be determined, kUnknownType will be returned. + IdType getTypeOfValueInstruction(uint32_t value) const; + + // Tracks the type-defining instruction. The result of the tracking can + // later be queried using getTypeOfTypeGeneratingValue. + // pInst is expected to be completely filled in by the time this instruction + // is called. + // Returns SPV_SUCCESS on success, or SPV_ERROR_INVALID_TEXT on error. + spv_result_t recordTypeDefinition(const spv_instruction_t* pInst); + + // Tracks the relationship between the value and its type. + spv_result_t recordTypeIdForValue(uint32_t value, uint32_t type); + + // Records the given Id as being the import of the given extended instruction + // type. + spv_result_t recordIdAsExtInstImport(uint32_t id, spv_ext_inst_type_t type); + + // Returns the extended instruction type corresponding to the import with + // the given Id, if it exists. Returns SPV_EXT_INST_TYPE_NONE if the + // id is not the id for an extended instruction type. + spv_ext_inst_type_t getExtInstTypeForId(uint32_t id) const; + + // Returns a set consisting of each ID generated by spvNamedIdAssignOrGet from + // a numeric ID text representation. For example, generated from "%12" but not + // from "%foo". + std::set GetNumericIds() const; + + private: + // Maps ID names to their corresponding numerical ids. + using spv_named_id_table = std::unordered_map; + // Maps type-defining IDs to their IdType.
+ using spv_id_to_type_map = std::unordered_map; + // Maps Ids to the id of their type. + using spv_id_to_type_id = std::unordered_map; + + spv_named_id_table named_ids_; + spv_id_to_type_map types_; + spv_id_to_type_id value_types_; + // Maps an extended instruction import Id to the extended instruction type. + std::unordered_map import_id_to_ext_inst_type_; + spv_position_t current_position_; + MessageConsumer consumer_; + spv_text text_; + uint32_t bound_; + uint32_t next_id_; + std::set ids_to_preserve_; +}; + +} // namespace spvtools + +#endif // SOURCE_TEXT_HANDLER_H_ diff --git a/source/util/bit_vector.cpp b/source/util/bit_vector.cpp new file mode 100644 index 000000000..47e275bf4 --- /dev/null +++ b/source/util/bit_vector.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/util/bit_vector.h" + +#include +#include + +namespace spvtools { +namespace utils { + +void BitVector::ReportDensity(std::ostream& out) { + uint32_t count = 0; + + for (BitContainer e : bits_) { + while (e != 0) { + if ((e & 1) != 0) { + ++count; + } + e = e >> 1; + } + } + + out << "count=" << count + << ", total size (bytes)=" << bits_.size() * sizeof(BitContainer) + << ", bytes per element=" + << (double)(bits_.size() * sizeof(BitContainer)) / (double)(count); +} + +bool BitVector::Or(const BitVector& other) { + auto this_it = this->bits_.begin(); + auto other_it = other.bits_.begin(); + bool modified = false; + + while (this_it != this->bits_.end() && other_it != other.bits_.end()) { + auto temp = *this_it | *other_it; + if (temp != *this_it) { + modified = true; + *this_it = temp; + } + ++this_it; + ++other_it; + } + + if (other_it != other.bits_.end()) { + modified = true; + this->bits_.insert(this->bits_.end(), other_it, other.bits_.end()); + } + + return modified; +} + +std::ostream& operator<<(std::ostream& out, const BitVector& bv) { + out << "{"; + for (uint32_t i = 0; i < bv.bits_.size(); ++i) { + BitVector::BitContainer b = bv.bits_[i]; + uint32_t j = 0; + while (b != 0) { + if (b & 1) { + out << ' ' << i * BitVector::kBitContainerSize + j; + } + ++j; + b = b >> 1; + } + } + out << "}"; + return out; +} + +} // namespace utils +} // namespace spvtools diff --git a/source/util/bit_vector.h b/source/util/bit_vector.h new file mode 100644 index 000000000..3e189cb10 --- /dev/null +++ b/source/util/bit_vector.h @@ -0,0 +1,119 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_BIT_VECTOR_H_ +#define SOURCE_UTIL_BIT_VECTOR_H_ + +#include +#include +#include + +namespace spvtools { +namespace utils { + +// Implements a bit vector class. +// +// All bits default to zero, and the upper bound is 2^32-1. +class BitVector { + private: + using BitContainer = uint64_t; + enum { kBitContainerSize = 64 }; + enum { kInitialNumBits = 1024 }; + + public: + // Creates a bit vector containing 0s. + BitVector(uint32_t reserved_size = kInitialNumBits) + : bits_((reserved_size - 1) / kBitContainerSize + 1, 0) {} + + // Sets the |i|th bit to 1. Returns the |i|th bit before it was set. + bool Set(uint32_t i) { + uint32_t element_index = i / kBitContainerSize; + uint32_t bit_in_element = i % kBitContainerSize; + + if (element_index >= bits_.size()) { + bits_.resize(element_index + 1, 0); + } + + BitContainer original = bits_[element_index]; + BitContainer ith_bit = static_cast(1) << bit_in_element; + + if ((original & ith_bit) != 0) { + return true; + } else { + bits_[element_index] = original | ith_bit; + return false; + } + } + + // Sets the |i|th bit to 0. Return the |i|th bit before it was cleared.
+ bool Clear(uint32_t i) { + uint32_t element_index = i / kBitContainerSize; + uint32_t bit_in_element = i % kBitContainerSize; + + if (element_index >= bits_.size()) { + return false; + } + + BitContainer original = bits_[element_index]; + BitContainer ith_bit = static_cast(1) << bit_in_element; + + if ((original & ith_bit) == 0) { + return false; + } else { + bits_[element_index] = original & (~ith_bit); + return true; + } + } + + // Returns the |i|th bit. + bool Get(uint32_t i) const { + uint32_t element_index = i / kBitContainerSize; + uint32_t bit_in_element = i % kBitContainerSize; + + if (element_index >= bits_.size()) { + return false; + } + + return (bits_[element_index] & + (static_cast(1) << bit_in_element)) != 0; + } + + // Returns true if every bit is 0. + bool Empty() const { + for (BitContainer b : bits_) { + if (b != 0) { + return false; + } + } + return true; + } + + // Print a report on the density of the bit vector, number of 1 bits, number + // of bytes, and average bytes for 1 bit, to |out|. + void ReportDensity(std::ostream& out); + + friend std::ostream& operator<<(std::ostream&, const BitVector&); + + // Performs a bitwise-or operation on |this| and |that|, storing the result in + // |this|. Return true if |this| changed. + bool Or(const BitVector& that); + + private: + std::vector bits_; +}; + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_BIT_VECTOR_H_ diff --git a/source/util/bitutils.h b/source/util/bitutils.h new file mode 100644 index 000000000..17d61df90 --- /dev/null +++ b/source/util/bitutils.h @@ -0,0 +1,96 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_BITUTILS_H_ +#define SOURCE_UTIL_BITUTILS_H_ + +#include +#include + +namespace spvtools { +namespace utils { + +// Performs a bitwise copy of source to the destination type Dest. +template +Dest BitwiseCast(Src source) { + Dest dest; + static_assert(sizeof(source) == sizeof(dest), + "BitwiseCast: Source and destination must have the same size"); + std::memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +// SetBits<T, First, Num> returns an integer of type <T> with bits set +// for position <First> through <First + Num - 1>, counting from the least +// significant bit. In particular when Num == 0, no positions are set to 1. +// A static assert will be triggered if First + Num > sizeof(T) * 8, that is, +// a bit that will not fit in the underlying type is set. +template +struct SetBits { + static_assert(First < sizeof(T) * 8, + "Tried to set a bit that is shifted too far."); + const static T get = (T(1) << First) | SetBits::get; +}; + +template +struct SetBits { + const static T get = T(0); +}; + +// This is all compile-time so we can put our tests right here.
+static_assert(SetBits::get == uint32_t(0x00000000), + "SetBits failed"); +static_assert(SetBits::get == uint32_t(0x00000001), + "SetBits failed"); +static_assert(SetBits::get == uint32_t(0x80000000), + "SetBits failed"); +static_assert(SetBits::get == uint32_t(0x00000006), + "SetBits failed"); +static_assert(SetBits::get == uint32_t(0xc0000000), + "SetBits failed"); +static_assert(SetBits::get == uint32_t(0x7FFFFFFF), + "SetBits failed"); +static_assert(SetBits::get == uint32_t(0xFFFFFFFF), + "SetBits failed"); +static_assert(SetBits::get == uint32_t(0xFFFF0000), + "SetBits failed"); + +static_assert(SetBits::get == uint64_t(0x0000000000000001LL), + "SetBits failed"); +static_assert(SetBits::get == uint64_t(0x8000000000000000LL), + "SetBits failed"); +static_assert(SetBits::get == uint64_t(0xc000000000000000LL), + "SetBits failed"); +static_assert(SetBits::get == uint64_t(0x0000000080000000LL), + "SetBits failed"); +static_assert(SetBits::get == uint64_t(0x00000000FFFF0000LL), + "SetBits failed"); + +// Returns number of '1' bits in a word. +template +size_t CountSetBits(T word) { + static_assert(std::is_integral::value, + "CountSetBits requires integer type"); + size_t count = 0; + while (word) { + word &= word - 1; + ++count; + } + return count; +} + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_BITUTILS_H_ diff --git a/source/util/hex_float.h b/source/util/hex_float.h new file mode 100644 index 000000000..cfc40fa68 --- /dev/null +++ b/source/util/hex_float.h @@ -0,0 +1,1150 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_HEX_FLOAT_H_ +#define SOURCE_UTIL_HEX_FLOAT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/util/bitutils.h" + +#ifndef __GNUC__ +#define GCC_VERSION 0 +#else +#define GCC_VERSION \ + (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) +#endif + +namespace spvtools { +namespace utils { + +class Float16 { + public: + Float16(uint16_t v) : val(v) {} + Float16() = default; + static bool isNan(const Float16& val) { + return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) != 0); + } + // Returns true if the given value is any kind of infinity. + static bool isInfinity(const Float16& val) { + return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) == 0); + } + Float16(const Float16& other) { val = other.val; } + uint16_t get_value() const { return val; } + + // Returns the maximum normal value. + static Float16 max() { return Float16(0x7bff); } + // Returns the lowest normal value. + static Float16 lowest() { return Float16(0xfbff); } + + private: + uint16_t val; +}; + +// To specialize this type, you must override uint_type to define +// an unsigned integer that can fit your floating point type. +// You must also add a isNan function that returns true if +// a value is Nan. +template +struct FloatProxyTraits { + using uint_type = void; +}; + +template <> +struct FloatProxyTraits { + using uint_type = uint32_t; + static bool isNan(float f) { return std::isnan(f); } + // Returns true if the given value is any kind of infinity. 
+ static bool isInfinity(float f) { return std::isinf(f); } + // Returns the maximum normal value. + static float max() { return std::numeric_limits::max(); } + // Returns the lowest normal value. + static float lowest() { return std::numeric_limits::lowest(); } + // Returns the value as the native floating point format. + static float getAsFloat(const uint_type& t) { return BitwiseCast(t); } + // Returns the bits from the given floating pointer number. + static uint_type getBitsFromFloat(const float& t) { + return BitwiseCast(t); + } + // Returns the bitwidth. + static uint32_t width() { return 32u; } +}; + +template <> +struct FloatProxyTraits { + using uint_type = uint64_t; + static bool isNan(double f) { return std::isnan(f); } + // Returns true if the given value is any kind of infinity. + static bool isInfinity(double f) { return std::isinf(f); } + // Returns the maximum normal value. + static double max() { return std::numeric_limits::max(); } + // Returns the lowest normal value. + static double lowest() { return std::numeric_limits::lowest(); } + // Returns the value as the native floating point format. + static double getAsFloat(const uint_type& t) { + return BitwiseCast(t); + } + // Returns the bits from the given floating pointer number. + static uint_type getBitsFromFloat(const double& t) { + return BitwiseCast(t); + } + // Returns the bitwidth. + static uint32_t width() { return 64u; } +}; + +template <> +struct FloatProxyTraits { + using uint_type = uint16_t; + static bool isNan(Float16 f) { return Float16::isNan(f); } + // Returns true if the given value is any kind of infinity. + static bool isInfinity(Float16 f) { return Float16::isInfinity(f); } + // Returns the maximum normal value. + static Float16 max() { return Float16::max(); } + // Returns the lowest normal value. + static Float16 lowest() { return Float16::lowest(); } + // Returns the value as the native floating point format. 
+ static Float16 getAsFloat(const uint_type& t) { return Float16(t); } + // Returns the bits from the given floating pointer number. + static uint_type getBitsFromFloat(const Float16& t) { return t.get_value(); } + // Returns the bitwidth. + static uint32_t width() { return 16u; } +}; + +// Since copying a floating point number (especially if it is NaN) +// does not guarantee that bits are preserved, this class lets us +// store the type and use it as a float when necessary. +template +class FloatProxy { + public: + using uint_type = typename FloatProxyTraits::uint_type; + + // Since this is to act similar to the normal floats, + // do not initialize the data by default. + FloatProxy() = default; + + // Intentionally non-explicit. This is a proxy type so + // implicit conversions allow us to use it more transparently. + FloatProxy(T val) { data_ = FloatProxyTraits::getBitsFromFloat(val); } + + // Intentionally non-explicit. This is a proxy type so + // implicit conversions allow us to use it more transparently. + FloatProxy(uint_type val) { data_ = val; } + + // This is helpful to have and is guaranteed not to stomp bits. + FloatProxy operator-() const { + return static_cast(data_ ^ + (uint_type(0x1) << (sizeof(T) * 8 - 1))); + } + + // Returns the data as a floating point value. + T getAsFloat() const { return FloatProxyTraits::getAsFloat(data_); } + + // Returns the raw data. + uint_type data() const { return data_; } + + // Returns a vector of words suitable for use in an Operand. + std::vector GetWords() const { + std::vector words; + if (FloatProxyTraits::width() == 64) { + FloatProxyTraits::uint_type d = data(); + words.push_back(static_cast(d)); + words.push_back(static_cast(d >> 32)); + } else { + words.push_back(static_cast(data())); + } + return words; + } + + // Returns true if the value represents any type of NaN. + bool isNan() { return FloatProxyTraits::isNan(getAsFloat()); } + // Returns true if the value represents any type of infinity. 
+ bool isInfinity() { return FloatProxyTraits::isInfinity(getAsFloat()); } + + // Returns the maximum normal value. + static FloatProxy max() { + return FloatProxy(FloatProxyTraits::max()); + } + // Returns the lowest normal value. + static FloatProxy lowest() { + return FloatProxy(FloatProxyTraits::lowest()); + } + + private: + uint_type data_; +}; + +template +bool operator==(const FloatProxy& first, const FloatProxy& second) { + return first.data() == second.data(); +} + +// Reads a FloatProxy value as a normal float from a stream. +template +std::istream& operator>>(std::istream& is, FloatProxy& value) { + T float_val; + is >> float_val; + value = FloatProxy(float_val); + return is; +} + +// This is an example traits. It is not meant to be used in practice, but will +// be the default for any non-specialized type. +template +struct HexFloatTraits { + // Integer type that can store this hex-float. + using uint_type = void; + // Signed integer type that can store this hex-float. + using int_type = void; + // The numerical type that this HexFloat represents. + using underlying_type = void; + // The type needed to construct the underlying type. + using native_type = void; + // The number of bits that are actually relevant in the uint_type. + // This allows us to deal with, for example, 24-bit values in a 32-bit + // integer. + static const uint32_t num_used_bits = 0; + // Number of bits that represent the exponent. + static const uint32_t num_exponent_bits = 0; + // Number of bits that represent the fractional part. + static const uint32_t num_fraction_bits = 0; + // The bias of the exponent. (How much we need to subtract from the stored + // value to get the correct value.) + static const uint32_t exponent_bias = 0; +}; + +// Traits for IEEE float. +// 1 sign bit, 8 exponent bits, 23 fractional bits. 
+template <> +struct HexFloatTraits> { + using uint_type = uint32_t; + using int_type = int32_t; + using underlying_type = FloatProxy; + using native_type = float; + static const uint_type num_used_bits = 32; + static const uint_type num_exponent_bits = 8; + static const uint_type num_fraction_bits = 23; + static const uint_type exponent_bias = 127; +}; + +// Traits for IEEE double. +// 1 sign bit, 11 exponent bits, 52 fractional bits. +template <> +struct HexFloatTraits> { + using uint_type = uint64_t; + using int_type = int64_t; + using underlying_type = FloatProxy; + using native_type = double; + static const uint_type num_used_bits = 64; + static const uint_type num_exponent_bits = 11; + static const uint_type num_fraction_bits = 52; + static const uint_type exponent_bias = 1023; +}; + +// Traits for IEEE half. +// 1 sign bit, 5 exponent bits, 10 fractional bits. +template <> +struct HexFloatTraits> { + using uint_type = uint16_t; + using int_type = int16_t; + using underlying_type = uint16_t; + using native_type = uint16_t; + static const uint_type num_used_bits = 16; + static const uint_type num_exponent_bits = 5; + static const uint_type num_fraction_bits = 10; + static const uint_type exponent_bias = 15; +}; + +enum class round_direction { + kToZero, + kToNearestEven, + kToPositiveInfinity, + kToNegativeInfinity, + max = kToNegativeInfinity +}; + +// Template class that houses a floating pointer number. +// It exposes a number of constants based on the provided traits to +// assist in interpreting the bits of the value. 
+template > +class HexFloat { + public: + using uint_type = typename Traits::uint_type; + using int_type = typename Traits::int_type; + using underlying_type = typename Traits::underlying_type; + using native_type = typename Traits::native_type; + + explicit HexFloat(T f) : value_(f) {} + + T value() const { return value_; } + void set_value(T f) { value_ = f; } + + // These are all written like this because it is convenient to have + // compile-time constants for all of these values. + + // Pass-through values to save typing. + static const uint32_t num_used_bits = Traits::num_used_bits; + static const uint32_t exponent_bias = Traits::exponent_bias; + static const uint32_t num_exponent_bits = Traits::num_exponent_bits; + static const uint32_t num_fraction_bits = Traits::num_fraction_bits; + + // Number of bits to shift left to set the highest relevant bit. + static const uint32_t top_bit_left_shift = num_used_bits - 1; + // How many nibbles (hex characters) the fractional part takes up. + static const uint32_t fraction_nibbles = (num_fraction_bits + 3) / 4; + // If the fractional part does not fit evenly into a hex character (4-bits) + // then we have to left-shift to get rid of leading 0s. This is the amount + // we have to shift (might be 0). + static const uint32_t num_overflow_bits = + fraction_nibbles * 4 - num_fraction_bits; + + // The representation of the fraction, not the actual bits. This + // includes the leading bit that is usually implicit. + static const uint_type fraction_represent_mask = + SetBits::get; + + // The topmost bit in the nibble-aligned fraction. + static const uint_type fraction_top_bit = + uint_type(1) << (num_fraction_bits + num_overflow_bits - 1); + + // The least significant bit in the exponent, which is also the bit + // immediately to the left of the significand. + static const uint_type first_exponent_bit = uint_type(1) + << (num_fraction_bits); + + // The mask for the encoded fraction. It does not include the + // implicit bit. 
+ static const uint_type fraction_encode_mask = + SetBits::get; + + // The bit that is used as a sign. + static const uint_type sign_mask = uint_type(1) << top_bit_left_shift; + + // The bits that represent the exponent. + static const uint_type exponent_mask = + SetBits::get; + + // How far left the exponent is shifted. + static const uint32_t exponent_left_shift = num_fraction_bits; + + // How far from the right edge the fraction is shifted. + static const uint32_t fraction_right_shift = + static_cast(sizeof(uint_type) * 8) - num_fraction_bits; + + // The maximum representable unbiased exponent. + static const int_type max_exponent = + (exponent_mask >> num_fraction_bits) - exponent_bias; + // The minimum representable exponent for normalized numbers. + static const int_type min_exponent = -static_cast(exponent_bias); + + // Returns the bits associated with the value. + uint_type getBits() const { return value_.data(); } + + // Returns the bits associated with the value, without the leading sign bit. + uint_type getUnsignedBits() const { + return static_cast(value_.data() & ~sign_mask); + } + + // Returns the bits associated with the exponent, shifted to start at the + // lsb of the type. + const uint_type getExponentBits() const { + return static_cast((getBits() & exponent_mask) >> + num_fraction_bits); + } + + // Returns the exponent in unbiased form. This is the exponent in the + // human-friendly form. + const int_type getUnbiasedExponent() const { + return static_cast(getExponentBits() - exponent_bias); + } + + // Returns just the significand bits from the value. + const uint_type getSignificandBits() const { + return getBits() & fraction_encode_mask; + } + + // If the number was normalized, returns the unbiased exponent. + // If the number was denormal, normalize the exponent first. 
+ const int_type getUnbiasedNormalizedExponent() const { + if ((getBits() & ~sign_mask) == 0) { // special case if everything is 0 + return 0; + } + int_type exp = getUnbiasedExponent(); + if (exp == min_exponent) { // We are in denorm land. + uint_type significand_bits = getSignificandBits(); + while ((significand_bits & (first_exponent_bit >> 1)) == 0) { + significand_bits = static_cast(significand_bits << 1); + exp = static_cast(exp - 1); + } + significand_bits &= fraction_encode_mask; + } + return exp; + } + + // Returns the signficand after it has been normalized. + const uint_type getNormalizedSignificand() const { + int_type unbiased_exponent = getUnbiasedNormalizedExponent(); + uint_type significand = getSignificandBits(); + for (int_type i = unbiased_exponent; i <= min_exponent; ++i) { + significand = static_cast(significand << 1); + } + significand &= fraction_encode_mask; + return significand; + } + + // Returns true if this number represents a negative value. + bool isNegative() const { return (getBits() & sign_mask) != 0; } + + // Sets this HexFloat from the individual components. + // Note this assumes EVERY significand is normalized, and has an implicit + // leading one. This means that the only way that this method will set 0, + // is if you set a number so denormalized that it underflows. + // Do not use this method with raw bits extracted from a subnormal number, + // since subnormals do not have an implicit leading 1 in the significand. + // The significand is also expected to be in the + // lowest-most num_fraction_bits of the uint_type. + // The exponent is expected to be unbiased, meaning an exponent of + // 0 actually means 0. + // If underflow_round_up is set, then on underflow, if a number is non-0 + // and would underflow, we round up to the smallest denorm. 
+ void setFromSignUnbiasedExponentAndNormalizedSignificand( + bool negative, int_type exponent, uint_type significand, + bool round_denorm_up) { + bool significand_is_zero = significand == 0; + + if (exponent <= min_exponent) { + // If this was denormalized, then we have to shift the bit on, meaning + // the significand is not zero. + significand_is_zero = false; + significand |= first_exponent_bit; + significand = static_cast(significand >> 1); + } + + while (exponent < min_exponent) { + significand = static_cast(significand >> 1); + ++exponent; + } + + if (exponent == min_exponent) { + if (significand == 0 && !significand_is_zero && round_denorm_up) { + significand = static_cast(0x1); + } + } + + uint_type new_value = 0; + if (negative) { + new_value = static_cast(new_value | sign_mask); + } + exponent = static_cast(exponent + exponent_bias); + assert(exponent >= 0); + + // put it all together + exponent = static_cast((exponent << exponent_left_shift) & + exponent_mask); + significand = static_cast(significand & fraction_encode_mask); + new_value = static_cast(new_value | (exponent | significand)); + value_ = T(new_value); + } + + // Increments the significand of this number by the given amount. + // If this would spill the significand into the implicit bit, + // carry is set to true and the significand is shifted to fit into + // the correct location, otherwise carry is set to false. + // All significands and to_increment are assumed to be within the bounds + // for a valid significand. + static uint_type incrementSignificand(uint_type significand, + uint_type to_increment, bool* carry) { + significand = static_cast(significand + to_increment); + *carry = false; + if (significand & first_exponent_bit) { + *carry = true; + // The implicit 1-bit will have carried, so we should zero-out the + // top bit and shift back. 
+ significand = static_cast(significand & ~first_exponent_bit); + significand = static_cast(significand >> 1); + } + return significand; + } + +#if GCC_VERSION == 40801 + // These exist because MSVC throws warnings on negative right-shifts + // even if they are not going to be executed. Eg: + // constant_number < 0? 0: constant_number + // These convert the negative left-shifts into right shifts. + template + struct negatable_left_shift { + static uint_type val(uint_type val) { + if (N > 0) { + return static_cast(val << N); + } else { + return static_cast(val >> N); + } + } + }; + + template + struct negatable_right_shift { + static uint_type val(uint_type val) { + if (N > 0) { + return static_cast(val >> N); + } else { + return static_cast(val << N); + } + } + }; + +#else + // These exist because MSVC throws warnings on negative right-shifts + // even if they are not going to be executed. Eg: + // constant_number < 0? 0: constant_number + // These convert the negative left-shifts into right shifts. + template + struct negatable_left_shift { + static uint_type val(uint_type val) { + return static_cast(val >> -N); + } + }; + + template + struct negatable_left_shift= 0>::type> { + static uint_type val(uint_type val) { + return static_cast(val << N); + } + }; + + template + struct negatable_right_shift { + static uint_type val(uint_type val) { + return static_cast(val << -N); + } + }; + + template + struct negatable_right_shift= 0>::type> { + static uint_type val(uint_type val) { + return static_cast(val >> N); + } + }; +#endif + + // Returns the significand, rounded to fit in a significand in + // other_T. This is shifted so that the most significant + // bit of the rounded number lines up with the most significant bit + // of the returned significand. 
+ template + typename other_T::uint_type getRoundedNormalizedSignificand( + round_direction dir, bool* carry_bit) { + using other_uint_type = typename other_T::uint_type; + static const int_type num_throwaway_bits = + static_cast(num_fraction_bits) - + static_cast(other_T::num_fraction_bits); + + static const uint_type last_significant_bit = + (num_throwaway_bits < 0) + ? 0 + : negatable_left_shift::val(1u); + static const uint_type first_rounded_bit = + (num_throwaway_bits < 1) + ? 0 + : negatable_left_shift::val(1u); + + static const uint_type throwaway_mask_bits = + num_throwaway_bits > 0 ? num_throwaway_bits : 0; + static const uint_type throwaway_mask = + SetBits::get; + + *carry_bit = false; + other_uint_type out_val = 0; + uint_type significand = getNormalizedSignificand(); + // If we are up-casting, then we just have to shift to the right location. + if (num_throwaway_bits <= 0) { + out_val = static_cast(significand); + uint_type shift_amount = static_cast(-num_throwaway_bits); + out_val = static_cast(out_val << shift_amount); + return out_val; + } + + // If every non-representable bit is 0, then we don't have any casting to + // do. + if ((significand & throwaway_mask) == 0) { + return static_cast( + negatable_right_shift::val(significand)); + } + + bool round_away_from_zero = false; + // We actually have to narrow the significand here, so we have to follow the + // rounding rules. + switch (dir) { + case round_direction::kToZero: + break; + case round_direction::kToPositiveInfinity: + round_away_from_zero = !isNegative(); + break; + case round_direction::kToNegativeInfinity: + round_away_from_zero = isNegative(); + break; + case round_direction::kToNearestEven: + // Have to round down, round bit is 0 + if ((first_rounded_bit & significand) == 0) { + break; + } + if (((significand & throwaway_mask) & ~first_rounded_bit) != 0) { + // If any subsequent bit of the rounded portion is non-0 then we round + // up. 
+ round_away_from_zero = true; + break; + } + // We are exactly half-way between 2 numbers, pick even. + if ((significand & last_significant_bit) != 0) { + // 1 for our last bit, round up. + round_away_from_zero = true; + break; + } + break; + } + + if (round_away_from_zero) { + return static_cast( + negatable_right_shift::val(incrementSignificand( + significand, last_significant_bit, carry_bit))); + } else { + return static_cast( + negatable_right_shift::val(significand)); + } + } + + // Casts this value to another HexFloat. If the cast is widening, + // then round_dir is ignored. If the cast is narrowing, then + // the result is rounded in the direction specified. + // This number will retain Nan and Inf values. + // It will also saturate to Inf if the number overflows, and + // underflow to (0 or min depending on rounding) if the number underflows. + template + void castTo(other_T& other, round_direction round_dir) { + other = other_T(static_cast(0)); + bool negate = isNegative(); + if (getUnsignedBits() == 0) { + if (negate) { + other.set_value(-other.value()); + } + return; + } + uint_type significand = getSignificandBits(); + bool carried = false; + typename other_T::uint_type rounded_significand = + getRoundedNormalizedSignificand(round_dir, &carried); + + int_type exponent = getUnbiasedExponent(); + if (exponent == min_exponent) { + // If we are denormal, normalize the exponent, so that we can encode + // easily. + exponent = static_cast(exponent + 1); + for (uint_type check_bit = first_exponent_bit >> 1; check_bit != 0; + check_bit = static_cast(check_bit >> 1)) { + exponent = static_cast(exponent - 1); + if (check_bit & significand) break; + } + } + + bool is_nan = + (getBits() & exponent_mask) == exponent_mask && significand != 0; + bool is_inf = + !is_nan && + ((exponent + carried) > static_cast(other_T::exponent_bias) || + (significand == 0 && (getBits() & exponent_mask) == exponent_mask)); + + // If we are Nan or Inf we should pass that through. 
+ if (is_inf) { + other.set_value(typename other_T::underlying_type( + static_cast( + (negate ? other_T::sign_mask : 0) | other_T::exponent_mask))); + return; + } + if (is_nan) { + typename other_T::uint_type shifted_significand; + shifted_significand = static_cast( + negatable_left_shift< + static_cast(other_T::num_fraction_bits) - + static_cast(num_fraction_bits)>::val(significand)); + + // We are some sort of Nan. We try to keep the bit-pattern of the Nan + // as close as possible. If we had to shift off bits so we are 0, then we + // just set the last bit. + other.set_value(typename other_T::underlying_type( + static_cast( + (negate ? other_T::sign_mask : 0) | other_T::exponent_mask | + (shifted_significand == 0 ? 0x1 : shifted_significand)))); + return; + } + + bool round_underflow_up = + isNegative() ? round_dir == round_direction::kToNegativeInfinity + : round_dir == round_direction::kToPositiveInfinity; + using other_int_type = typename other_T::int_type; + // setFromSignUnbiasedExponentAndNormalizedSignificand will + // zero out any underflowing value (but retain the sign). + other.setFromSignUnbiasedExponentAndNormalizedSignificand( + negate, static_cast(exponent), rounded_significand, + round_underflow_up); + return; + } + + private: + T value_; + + static_assert(num_used_bits == + Traits::num_exponent_bits + Traits::num_fraction_bits + 1, + "The number of bits do not fit"); + static_assert(sizeof(T) == sizeof(uint_type), "The type sizes do not match"); +}; + +// Returns 4 bits represented by the hex character. 
+inline uint8_t get_nibble_from_character(int character) { + const char* dec = "0123456789"; + const char* lower = "abcdef"; + const char* upper = "ABCDEF"; + const char* p = nullptr; + if ((p = strchr(dec, character))) { + return static_cast(p - dec); + } else if ((p = strchr(lower, character))) { + return static_cast(p - lower + 0xa); + } else if ((p = strchr(upper, character))) { + return static_cast(p - upper + 0xa); + } + + assert(false && "This was called with a non-hex character"); + return 0; +} + +// Outputs the given HexFloat to the stream. +template +std::ostream& operator<<(std::ostream& os, const HexFloat& value) { + using HF = HexFloat; + using uint_type = typename HF::uint_type; + using int_type = typename HF::int_type; + + static_assert(HF::num_used_bits != 0, + "num_used_bits must be non-zero for a valid float"); + static_assert(HF::num_exponent_bits != 0, + "num_exponent_bits must be non-zero for a valid float"); + static_assert(HF::num_fraction_bits != 0, + "num_fractin_bits must be non-zero for a valid float"); + + const uint_type bits = value.value().data(); + const char* const sign = (bits & HF::sign_mask) ? "-" : ""; + const uint_type exponent = static_cast( + (bits & HF::exponent_mask) >> HF::num_fraction_bits); + + uint_type fraction = static_cast((bits & HF::fraction_encode_mask) + << HF::num_overflow_bits); + + const bool is_zero = exponent == 0 && fraction == 0; + const bool is_denorm = exponent == 0 && !is_zero; + + // exponent contains the biased exponent we have to convert it back into + // the normal range. + int_type int_exponent = static_cast(exponent - HF::exponent_bias); + // If the number is all zeros, then we actually have to NOT shift the + // exponent. + int_exponent = is_zero ? 0 : int_exponent; + + // If we are denorm, then start shifting, and decreasing the exponent until + // our leading bit is 1. 
+ + if (is_denorm) { + while ((fraction & HF::fraction_top_bit) == 0) { + fraction = static_cast(fraction << 1); + int_exponent = static_cast(int_exponent - 1); + } + // Since this is denormalized, we have to consume the leading 1 since it + // will end up being implicit. + fraction = static_cast(fraction << 1); // eat the leading 1 + fraction &= HF::fraction_represent_mask; + } + + uint_type fraction_nibbles = HF::fraction_nibbles; + // We do not have to display any trailing 0s, since this represents the + // fractional part. + while (fraction_nibbles > 0 && (fraction & 0xF) == 0) { + // Shift off any trailing values; + fraction = static_cast(fraction >> 4); + --fraction_nibbles; + } + + const auto saved_flags = os.flags(); + const auto saved_fill = os.fill(); + + os << sign << "0x" << (is_zero ? '0' : '1'); + if (fraction_nibbles) { + // Make sure to keep the leading 0s in place, since this is the fractional + // part. + os << "." << std::setw(static_cast(fraction_nibbles)) + << std::setfill('0') << std::hex << fraction; + } + os << "p" << std::dec << (int_exponent >= 0 ? "+" : "") << int_exponent; + + os.flags(saved_flags); + os.fill(saved_fill); + + return os; +} + +// Returns true if negate_value is true and the next character on the +// input stream is a plus or minus sign. In that case we also set the fail bit +// on the stream and set the value to the zero value for its type. +template +inline bool RejectParseDueToLeadingSign(std::istream& is, bool negate_value, + HexFloat& value) { + if (negate_value) { + auto next_char = is.peek(); + if (next_char == '-' || next_char == '+') { + // Fail the parse. Emulate standard behaviour by setting the value to + // the zero value, and set the fail bit on the stream. + value = HexFloat(typename HexFloat::uint_type{0}); + is.setstate(std::ios_base::failbit); + return true; + } + } + return false; +} + +// Parses a floating point number from the given stream and stores it into the +// value parameter. 
+// If negate_value is true then the number may not have a leading minus or +// plus, and if it successfully parses, then the number is negated before +// being stored into the value parameter. +// If the value cannot be correctly parsed or overflows the target floating +// point type, then set the fail bit on the stream. +// TODO(dneto): Promise C++11 standard behavior in how the value is set in +// the error case, but only after all target platforms implement it correctly. +// In particular, the Microsoft C++ runtime appears to be out of spec. +template +inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value, + HexFloat& value) { + if (RejectParseDueToLeadingSign(is, negate_value, value)) { + return is; + } + T val; + is >> val; + if (negate_value) { + val = -val; + } + value.set_value(val); + // In the failure case, map -0.0 to 0.0. + if (is.fail() && value.getUnsignedBits() == 0u) { + value = HexFloat(typename HexFloat::uint_type{0}); + } + if (val.isInfinity()) { + // Fail the parse. Emulate standard behaviour by setting the value to + // the closest normal value, and set the fail bit on the stream. + value.set_value((value.isNegative() | negate_value) ? T::lowest() + : T::max()); + is.setstate(std::ios_base::failbit); + } + return is; +} + +// Specialization of ParseNormalFloat for FloatProxy values. +// This will parse the float as it were a 32-bit floating point number, +// and then round it down to fit into a Float16 value. +// The number is rounded towards zero. +// If negate_value is true then the number may not have a leading minus or +// plus, and if it successfully parses, then the number is negated before +// being stored into the value parameter. +// If the value cannot be correctly parsed or overflows the target floating +// point type, then set the fail bit on the stream. +// TODO(dneto): Promise C++11 standard behavior in how the value is set in +// the error case, but only after all target platforms implement it correctly. 
+// In particular, the Microsoft C++ runtime appears to be out of spec. +template <> +inline std::istream& +ParseNormalFloat, HexFloatTraits>>( + std::istream& is, bool negate_value, + HexFloat, HexFloatTraits>>& value) { + // First parse as a 32-bit float. + HexFloat> float_val(0.0f); + ParseNormalFloat(is, negate_value, float_val); + + // Then convert to 16-bit float, saturating at infinities, and + // rounding toward zero. + float_val.castTo(value, round_direction::kToZero); + + // Overflow on 16-bit behaves the same as for 32- and 64-bit: set the + // fail bit and set the lowest or highest value. + if (Float16::isInfinity(value.value().getAsFloat())) { + value.set_value(value.isNegative() ? Float16::lowest() : Float16::max()); + is.setstate(std::ios_base::failbit); + } + return is; +} + +// Reads a HexFloat from the given stream. +// If the float is not encoded as a hex-float then it will be parsed +// as a regular float. +// This may fail if your stream does not support at least one unget. +// Nan values can be encoded with "0x1.p+exponent_bias". +// This would normally overflow a float and round to +// infinity but this special pattern is the exact representation for a NaN, +// and therefore is actually encoded as the correct NaN. To encode inf, +// either 0x0p+exponent_bias can be specified or any exponent greater than +// exponent_bias. +// Examples using IEEE 32-bit float encoding. +// 0x1.0p+128 (+inf) +// -0x1.0p-128 (-inf) +// +// 0x1.1p+128 (+Nan) +// -0x1.1p+128 (-Nan) +// +// 0x1p+129 (+inf) +// -0x1p+129 (-inf) +template +std::istream& operator>>(std::istream& is, HexFloat& value) { + using HF = HexFloat; + using uint_type = typename HF::uint_type; + using int_type = typename HF::int_type; + + value.set_value(static_cast(0.f)); + + if (is.flags() & std::ios::skipws) { + // If the user wants to skip whitespace , then we should obey that. 
+ while (std::isspace(is.peek())) { + is.get(); + } + } + + auto next_char = is.peek(); + bool negate_value = false; + + if (next_char != '-' && next_char != '0') { + return ParseNormalFloat(is, negate_value, value); + } + + if (next_char == '-') { + negate_value = true; + is.get(); + next_char = is.peek(); + } + + if (next_char == '0') { + is.get(); // We may have to unget this. + auto maybe_hex_start = is.peek(); + if (maybe_hex_start != 'x' && maybe_hex_start != 'X') { + is.unget(); + return ParseNormalFloat(is, negate_value, value); + } else { + is.get(); // Throw away the 'x'; + } + } else { + return ParseNormalFloat(is, negate_value, value); + } + + // This "looks" like a hex-float so treat it as one. + bool seen_p = false; + bool seen_dot = false; + uint_type fraction_index = 0; + + uint_type fraction = 0; + int_type exponent = HF::exponent_bias; + + // Strip off leading zeros so we don't have to special-case them later. + while ((next_char = is.peek()) == '0') { + is.get(); + } + + bool is_denorm = + true; // Assume denorm "representation" until we hear otherwise. + // NB: This does not mean the value is actually denorm, + // it just means that it was written 0. + bool bits_written = false; // Stays false until we write a bit. + while (!seen_p && !seen_dot) { + // Handle characters that are left of the fractional part. + if (next_char == '.') { + seen_dot = true; + } else if (next_char == 'p') { + seen_p = true; + } else if (::isxdigit(next_char)) { + // We know this is not denormalized since we have stripped all leading + // zeroes and we are not a ".". + is_denorm = false; + int number = get_nibble_from_character(next_char); + for (int i = 0; i < 4; ++i, number <<= 1) { + uint_type write_bit = (number & 0x8) ? 0x1 : 0x0; + if (bits_written) { + // If we are here the bits represented belong in the fractional + // part of the float, and we have to adjust the exponent accordingly. 
+ fraction = static_cast( + fraction | + static_cast( + write_bit << (HF::top_bit_left_shift - fraction_index++))); + exponent = static_cast(exponent + 1); + } + bits_written |= write_bit != 0; + } + } else { + // We have not found our exponent yet, so we have to fail. + is.setstate(std::ios::failbit); + return is; + } + is.get(); + next_char = is.peek(); + } + bits_written = false; + while (seen_dot && !seen_p) { + // Handle only fractional parts now. + if (next_char == 'p') { + seen_p = true; + } else if (::isxdigit(next_char)) { + int number = get_nibble_from_character(next_char); + for (int i = 0; i < 4; ++i, number <<= 1) { + uint_type write_bit = (number & 0x8) ? 0x01 : 0x00; + bits_written |= write_bit != 0; + if (is_denorm && !bits_written) { + // Handle modifying the exponent here this way we can handle + // an arbitrary number of hex values without overflowing our + // integer. + exponent = static_cast(exponent - 1); + } else { + fraction = static_cast( + fraction | + static_cast( + write_bit << (HF::top_bit_left_shift - fraction_index++))); + } + } + } else { + // We still have not found our 'p' exponent yet, so this is not a valid + // hex-float. + is.setstate(std::ios::failbit); + return is; + } + is.get(); + next_char = is.peek(); + } + + bool seen_sign = false; + int8_t exponent_sign = 1; + int_type written_exponent = 0; + while (true) { + if ((next_char == '-' || next_char == '+')) { + if (seen_sign) { + is.setstate(std::ios::failbit); + return is; + } + seen_sign = true; + exponent_sign = (next_char == '-') ? -1 : 1; + } else if (::isdigit(next_char)) { + // Hex-floats express their exponent as decimal. 
+ written_exponent = static_cast(written_exponent * 10); + written_exponent = + static_cast(written_exponent + (next_char - '0')); + } else { + break; + } + is.get(); + next_char = is.peek(); + } + + written_exponent = static_cast(written_exponent * exponent_sign); + exponent = static_cast(exponent + written_exponent); + + bool is_zero = is_denorm && (fraction == 0); + if (is_denorm && !is_zero) { + fraction = static_cast(fraction << 1); + exponent = static_cast(exponent - 1); + } else if (is_zero) { + exponent = 0; + } + + if (exponent <= 0 && !is_zero) { + fraction = static_cast(fraction >> 1); + fraction |= static_cast(1) << HF::top_bit_left_shift; + } + + fraction = (fraction >> HF::fraction_right_shift) & HF::fraction_encode_mask; + + const int_type max_exponent = + SetBits::get; + + // Handle actual denorm numbers + while (exponent < 0 && !is_zero) { + fraction = static_cast(fraction >> 1); + exponent = static_cast(exponent + 1); + + fraction &= HF::fraction_encode_mask; + if (fraction == 0) { + // We have underflowed our fraction. We should clamp to zero. + is_zero = true; + exponent = 0; + } + } + + // We have overflowed so we should be inf/-inf. + if (exponent > max_exponent) { + exponent = max_exponent; + fraction = 0; + } + + uint_type output_bits = static_cast( + static_cast(negate_value ? 1 : 0) << HF::top_bit_left_shift); + output_bits |= fraction; + + uint_type shifted_exponent = static_cast( + static_cast(exponent << HF::exponent_left_shift) & + HF::exponent_mask); + output_bits |= shifted_exponent; + + T output_float(output_bits); + value.set_value(output_float); + + return is; +} + +// Writes a FloatProxy value to a stream. +// Zero and normal numbers are printed in the usual notation, but with +// enough digits to fully reproduce the value. Other values (subnormal, +// NaN, and infinity) are printed as a hex float. 
+template +std::ostream& operator<<(std::ostream& os, const FloatProxy& value) { + auto float_val = value.getAsFloat(); + switch (std::fpclassify(float_val)) { + case FP_ZERO: + case FP_NORMAL: { + auto saved_precision = os.precision(); + os.precision(std::numeric_limits::max_digits10); + os << float_val; + os.precision(saved_precision); + } break; + default: + os << HexFloat>(value); + break; + } + return os; +} + +template <> +inline std::ostream& operator<<(std::ostream& os, + const FloatProxy& value) { + os << HexFloat>(value); + return os; +} + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_HEX_FLOAT_H_ diff --git a/source/util/ilist.h b/source/util/ilist.h new file mode 100644 index 000000000..9837b09b3 --- /dev/null +++ b/source/util/ilist.h @@ -0,0 +1,365 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_ILIST_H_ +#define SOURCE_UTIL_ILIST_H_ + +#include +#include +#include +#include + +#include "source/util/ilist_node.h" + +namespace spvtools { +namespace utils { + +// An IntrusiveList is a generic implementation of a doubly-linked list. The +// intended convention for using this container is: +// +// class Node : public IntrusiveNodeBase { +// // Note that "Node", the class being defined is the template. +// // Must have a default constructor accessible to List. 
+// // Add whatever data is needed in the node +// }; +// +// using List = IntrusiveList; +// +// You can also inherit from IntrusiveList instead of a typedef if you want to +// add more functionality. +// +// The condition on the template for IntrusiveNodeBase is there to add some type +// checking to the container. The compiler will still allow inserting elements +// of type IntrusiveNodeBase, but that would be an error. This assumption +// allows NextNode and PreviousNode to return pointers to Node, and casting will +// not be required by the user. + +template +class IntrusiveList { + public: + static_assert( + std::is_base_of, NodeType>::value, + "The type from the node must be derived from IntrusiveNodeBase, with " + "itself in the template."); + + // Creates an empty list. + inline IntrusiveList(); + + // Moves the contents of the given list to the list being constructed. + IntrusiveList(IntrusiveList&&); + + // Destroys the list. Note that the elements of the list will not be deleted, + // but they will be removed from the list. + virtual ~IntrusiveList(); + + // Moves all of the elements in the list on the RHS to the list on the LHS. + IntrusiveList& operator=(IntrusiveList&&); + + // Basetype for iterators so an IntrusiveList can be traversed like STL + // containers.
+ template + class iterator_template { + public: + iterator_template(const iterator_template& i) : node_(i.node_) {} + + iterator_template& operator++() { + node_ = node_->next_node_; + return *this; + } + + iterator_template& operator--() { + node_ = node_->previous_node_; + return *this; + } + + iterator_template& operator=(const iterator_template& i) { + node_ = i.node_; + return *this; + } + + T& operator*() const { return *node_; } + T* operator->() const { return node_; } + + friend inline bool operator==(const iterator_template& lhs, + const iterator_template& rhs) { + return lhs.node_ == rhs.node_; + } + friend inline bool operator!=(const iterator_template& lhs, + const iterator_template& rhs) { + return !(lhs == rhs); + } + + // Moves the nodes in |list| to the list that |this| points to. The + // positions of the nodes will be immediately before the element pointed to + // by the iterator. The return value will be an iterator pointing to the + // first of the newly inserted elements. + iterator_template MoveBefore(IntrusiveList* list) { + if (list->empty()) return *this; + + NodeType* first_node = list->sentinel_.next_node_; + NodeType* last_node = list->sentinel_.previous_node_; + + this->node_->previous_node_->next_node_ = first_node; + first_node->previous_node_ = this->node_->previous_node_; + + last_node->next_node_ = this->node_; + this->node_->previous_node_ = last_node; + + list->sentinel_.next_node_ = &list->sentinel_; + list->sentinel_.previous_node_ = &list->sentinel_; + + return iterator(first_node); + } + + // Define the standard iterator types needed so this class can be + // used with <algorithm>.
+ using iterator_category = std::bidirectional_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T*; + using const_pointer = const T*; + using reference = T&; + using const_reference = const T&; + using size_type = size_t; + + protected: + iterator_template() = delete; + inline iterator_template(T* node) { node_ = node; } + T* node_; + + friend IntrusiveList; + }; + + using iterator = iterator_template; + using const_iterator = iterator_template; + + // Various types of iterators for the start (begin) and one past the end (end) + // of the list. + // + // Decrementing |end()| iterator will give an iterator pointing to the last + // element in the list, if one exists. + // + // Incrementing |end()| iterator will give |begin()|. + // + // Decrementing |begin()| will give |end()|. + // + // TODO: Not marking these functions as noexcept because Visual Studio 2013 + // does not support it. When we no longer care about that compiler, we should + // mark these as noexcept. + iterator begin(); + iterator end(); + const_iterator begin() const; + const_iterator end() const; + const_iterator cbegin() const; + const_iterator cend() const; + + // Appends |node| to the end of the list. If |node| is already in a list, it + // will be removed from that list first. + void push_back(NodeType* node); + + // Returns true if the list is empty. + bool empty() const; + + // Makes the current list empty. + inline void clear(); + + // Returns references to the first or last element in the list. It is an + // error to call these functions on an empty list. + NodeType& front(); + NodeType& back(); + const NodeType& front() const; + const NodeType& back() const; + + // Transfers [|first|, |last|) from |other| into the list at |where|. + // + // If |other| is |this|, no change is made.
+ void Splice(iterator where, IntrusiveList* other, iterator first, + iterator last); + + protected: + // Doing a deep copy of the list does not make sense if the list does not own + // the data. It is not clear who will own the newly created data. Making + // copies illegal for that reason. + IntrusiveList(const IntrusiveList&) = delete; + IntrusiveList& operator=(const IntrusiveList&) = delete; + + // This function will assert if it finds the list containing |node| is not in + // a valid state. + static void Check(NodeType* node); + + // A special node used to represent both the start and end of the list, + // without being part of the list. + NodeType sentinel_; +}; + +// Implementation of IntrusiveList + +template +inline IntrusiveList::IntrusiveList() : sentinel_() { + sentinel_.next_node_ = &sentinel_; + sentinel_.previous_node_ = &sentinel_; + sentinel_.is_sentinel_ = true; +} + +template +IntrusiveList::IntrusiveList(IntrusiveList&& list) : sentinel_() { + sentinel_.next_node_ = &sentinel_; + sentinel_.previous_node_ = &sentinel_; + sentinel_.is_sentinel_ = true; + list.sentinel_.ReplaceWith(&sentinel_); +} + +template +IntrusiveList::~IntrusiveList() { + clear(); +} + +template +IntrusiveList& IntrusiveList::operator=( + IntrusiveList&& list) { + list.sentinel_.ReplaceWith(&sentinel_); + return *this; +} + +template +inline typename IntrusiveList::iterator +IntrusiveList::begin() { + return iterator(sentinel_.next_node_); +} + +template +inline typename IntrusiveList::iterator +IntrusiveList::end() { + return iterator(&sentinel_); +} + +template +inline typename IntrusiveList::const_iterator +IntrusiveList::begin() const { + return const_iterator(sentinel_.next_node_); +} + +template +inline typename IntrusiveList::const_iterator +IntrusiveList::end() const { + return const_iterator(&sentinel_); +} + +template +inline typename IntrusiveList::const_iterator +IntrusiveList::cbegin() const { + return const_iterator(sentinel_.next_node_); +} + +template 
+inline typename IntrusiveList::const_iterator +IntrusiveList::cend() const { + return const_iterator(&sentinel_); +} + +template +void IntrusiveList::push_back(NodeType* node) { + node->InsertBefore(&sentinel_); +} + +template +bool IntrusiveList::empty() const { + return sentinel_.NextNode() == nullptr; +} + +template +void IntrusiveList::clear() { + while (!empty()) { + front().RemoveFromList(); + } +} + +template +NodeType& IntrusiveList::front() { + NodeType* node = sentinel_.NextNode(); + assert(node != nullptr && "Can't get the front of an empty list."); + return *node; +} + +template +NodeType& IntrusiveList::back() { + NodeType* node = sentinel_.PreviousNode(); + assert(node != nullptr && "Can't get the back of an empty list."); + return *node; +} + +template +const NodeType& IntrusiveList::front() const { + NodeType* node = sentinel_.NextNode(); + assert(node != nullptr && "Can't get the front of an empty list."); + return *node; +} + +template +const NodeType& IntrusiveList::back() const { + NodeType* node = sentinel_.PreviousNode(); + assert(node != nullptr && "Can't get the back of an empty list."); + return *node; +} + +template +void IntrusiveList::Splice(iterator where, + IntrusiveList* other, + iterator first, iterator last) { + if (first == last) return; + if (other == this) return; + + NodeType* first_prev = first.node_->previous_node_; + NodeType* where_next = where.node_->next_node_; + + // Attach first. + where.node_->next_node_ = first.node_; + first.node_->previous_node_ = where.node_; + + // Attach last. + where_next->previous_node_ = last.node_->previous_node_; + last.node_->previous_node_->next_node_ = where_next; + + // Fixup other. 
+ first_prev->next_node_ = last.node_; + last.node_->previous_node_ = first_prev; +} + +template +void IntrusiveList::Check(NodeType* start) { + int sentinel_count = 0; + NodeType* p = start; + do { + assert(p != nullptr); + assert(p->next_node_->previous_node_ == p); + assert(p->previous_node_->next_node_ == p); + if (p->is_sentinel_) sentinel_count++; + p = p->next_node_; + } while (p != start); + assert(sentinel_count == 1 && "List should have exactly 1 sentinel node."); + + p = start; + do { + assert(p != nullptr); + assert(p->previous_node_->next_node_ == p); + assert(p->next_node_->previous_node_ == p); + if (p->is_sentinel_) sentinel_count++; + p = p->previous_node_; + } while (p != start); +} + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_ILIST_H_ diff --git a/source/util/ilist_node.h b/source/util/ilist_node.h new file mode 100644 index 000000000..0579534b8 --- /dev/null +++ b/source/util/ilist_node.h @@ -0,0 +1,265 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_ILIST_NODE_H_ +#define SOURCE_UTIL_ILIST_NODE_H_ + +#include + +namespace spvtools { +namespace utils { + +template +class IntrusiveList; + +// IntrusiveNodeBase is the base class for nodes in an IntrusiveList. +// See the comments in ilist.h on how to use the class. + +template +class IntrusiveNodeBase { + public: + // Creates a new node that is not in a list. 
+ inline IntrusiveNodeBase(); + inline IntrusiveNodeBase(const IntrusiveNodeBase&); + inline IntrusiveNodeBase& operator=(const IntrusiveNodeBase&); + inline IntrusiveNodeBase(IntrusiveNodeBase&& that); + + // Destroys a node. It is an error to destroy a node that is part of a + // list, unless it is a sentinel. + virtual ~IntrusiveNodeBase(); + + IntrusiveNodeBase& operator=(IntrusiveNodeBase&& that); + + // Returns true if |this| is in a list. + inline bool IsInAList() const; + + // Returns the node that comes after the given node in the list, if one + // exists. If the given node is not in a list or is at the end of the list, + // the return value is nullptr. + inline NodeType* NextNode() const; + + // Returns the node that comes before the given node in the list, if one + // exists. If the given node is not in a list or is at the start of the + // list, the return value is nullptr. + inline NodeType* PreviousNode() const; + + // Inserts the given node immediately before |pos| in the list. + // If the given node is already in a list, it will first be removed + // from that list. + // + // It is assumed that the given node is of type NodeType. It is an error if + // |pos| is not already in a list. + inline void InsertBefore(NodeType* pos); + + // Inserts the given node immediately after |pos| in the list. + // If the given node is already in a list, it will first be removed + // from that list. + // + // It is assumed that the given node is of type NodeType. It is an error if + // |pos| is not already in a list, or if |pos| is equal to |this|. + inline void InsertAfter(NodeType* pos); + + // Removes the given node from the list. It is assumed that the node is + // in a list. Note that this does not free any storage related to the node, + // it becomes the caller's responsibility to free the storage. + inline void RemoveFromList(); + + protected: + // Replaces |this| with |target|. |this| is a sentinel if and only if + // |target| is also a sentinel. 
+ // + // If neither node is a sentinel, |target| takes + // the place of |this|. It is assumed that |target| is not in a list. + // + // If both are sentinels, then it will cause all of the + // nodes in the list containing |this| to be moved to the list containing + // |target|. In this case, it is assumed that |target| is an empty list. + // + // No storage will be deleted. + void ReplaceWith(NodeType* target); + + // Returns true if |this| is the sentinel node of an empty list. + bool IsEmptyList(); + + // The pointers to the next and previous nodes in the list. + // If the current node is not part of a list, then |next_node_| and + // |previous_node_| are equal to |nullptr|. + NodeType* next_node_; + NodeType* previous_node_; + + // Only true for the sentinel node stored in the list itself. + bool is_sentinel_; + + friend IntrusiveList; +}; + +// Implementation of IntrusiveNodeBase + +template +inline IntrusiveNodeBase::IntrusiveNodeBase() + : next_node_(nullptr), previous_node_(nullptr), is_sentinel_(false) {} + +template +inline IntrusiveNodeBase::IntrusiveNodeBase( + const IntrusiveNodeBase&) { + next_node_ = nullptr; + previous_node_ = nullptr; + is_sentinel_ = false; +} + +template +inline IntrusiveNodeBase& IntrusiveNodeBase::operator=( + const IntrusiveNodeBase&) { + assert(!is_sentinel_); + if (IsInAList()) { + RemoveFromList(); + } + return *this; +} + +template +inline IntrusiveNodeBase::IntrusiveNodeBase(IntrusiveNodeBase&& that) + : next_node_(nullptr), + previous_node_(nullptr), + is_sentinel_(that.is_sentinel_) { + if (is_sentinel_) { + next_node_ = this; + previous_node_ = this; + } + that.ReplaceWith(this); +} + +template +IntrusiveNodeBase::~IntrusiveNodeBase() { + assert(is_sentinel_ || !IsInAList()); +} + +template +IntrusiveNodeBase& IntrusiveNodeBase::operator=( + IntrusiveNodeBase&& that) { + that.ReplaceWith(this); + return *this; +} + +template +inline bool IntrusiveNodeBase::IsInAList() const { + return next_node_ != nullptr; +} + 
+template +inline NodeType* IntrusiveNodeBase::NextNode() const { + if (next_node_ == nullptr || next_node_->is_sentinel_) return nullptr; + return next_node_; +} + +template +inline NodeType* IntrusiveNodeBase::PreviousNode() const { + if (previous_node_ == nullptr || previous_node_->is_sentinel_) return nullptr; + return previous_node_; +} + +template +inline void IntrusiveNodeBase::InsertBefore(NodeType* pos) { + assert(!this->is_sentinel_ && "Sentinel nodes cannot be moved around."); + assert(pos->IsInAList() && "Pos should already be in a list."); + if (this->IsInAList()) this->RemoveFromList(); + + this->next_node_ = pos; + this->previous_node_ = pos->previous_node_; + pos->previous_node_ = static_cast(this); + this->previous_node_->next_node_ = static_cast(this); +} + +template +inline void IntrusiveNodeBase::InsertAfter(NodeType* pos) { + assert(!this->is_sentinel_ && "Sentinel nodes cannot be moved around."); + assert(pos->IsInAList() && "Pos should already be in a list."); + assert(this != pos && "Can't insert a node after itself."); + + if (this->IsInAList()) { + this->RemoveFromList(); + } + + this->previous_node_ = pos; + this->next_node_ = pos->next_node_; + pos->next_node_ = static_cast(this); + this->next_node_->previous_node_ = static_cast(this); +} + +template +inline void IntrusiveNodeBase::RemoveFromList() { + assert(!this->is_sentinel_ && "Sentinel nodes cannot be moved around."); + assert(this->IsInAList() && + "Cannot remove a node from a list if it is not in a list."); + + this->next_node_->previous_node_ = this->previous_node_; + this->previous_node_->next_node_ = this->next_node_; + this->next_node_ = nullptr; + this->previous_node_ = nullptr; +} + +template +void IntrusiveNodeBase::ReplaceWith(NodeType* target) { + if (this->is_sentinel_) { + assert(target->IsEmptyList() && + "If target is not an empty list, the nodes in that list would not " + "be linked to a sentinel."); + } else { + assert(IsInAList() && "The node being replaced must be in a list."); + assert(!target->is_sentinel_ && +
"Cannot turn a sentinel node into one that is not."); + } + + if (!this->IsEmptyList()) { + // Link target into the same position that |this| was in. + target->next_node_ = this->next_node_; + target->previous_node_ = this->previous_node_; + target->next_node_->previous_node_ = target; + target->previous_node_->next_node_ = target; + + // Reset |this| to its default value. + if (!this->is_sentinel_) { + // Reset |this| so that it is not in a list. + this->next_node_ = nullptr; + this->previous_node_ = nullptr; + } else { + // Set |this| so that it is the head of an empty list. + // We cannot treat sentinel nodes like others because it is invalid for + // a sentinel node to not be in a list. + this->next_node_ = static_cast(this); + this->previous_node_ = static_cast(this); + } + } else { + // If |this| points to itself, it must be a sentinel node with an empty + // list. Reset |this| so that it is the head of an empty list. We want + // |target| to be the same. The asserts above guarantee that. + } +} + +template +bool IntrusiveNodeBase::IsEmptyList() { + if (next_node_ == this) { + assert(is_sentinel_ && + "None sentinel nodes should never point to themselves."); + assert(previous_node_ == this && + "Inconsistency with the previous and next nodes."); + return true; + } + return false; +} + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_ILIST_NODE_H_ diff --git a/source/util/make_unique.h b/source/util/make_unique.h new file mode 100644 index 000000000..ad7976c34 --- /dev/null +++ b/source/util/make_unique.h @@ -0,0 +1,30 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_MAKE_UNIQUE_H_ +#define SOURCE_UTIL_MAKE_UNIQUE_H_ + +#include +#include + +namespace spvtools { + +template +std::unique_ptr MakeUnique(Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +} // namespace spvtools + +#endif // SOURCE_UTIL_MAKE_UNIQUE_H_ diff --git a/source/util/parse_number.cpp b/source/util/parse_number.cpp new file mode 100644 index 000000000..c3351c236 --- /dev/null +++ b/source/util/parse_number.cpp @@ -0,0 +1,217 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/util/parse_number.h" + +#include +#include +#include +#include +#include +#include + +#include "source/util/hex_float.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace utils { +namespace { + +// A helper class that temporarily stores error messages and dumps the messages +// to a string which is given as a pointer when it is destructed. If the given +// pointer is a nullptr, this class does not store error messages.
+class ErrorMsgStream { + public: + explicit ErrorMsgStream(std::string* error_msg_sink) + : error_msg_sink_(error_msg_sink) { + if (error_msg_sink_) stream_ = MakeUnique(); + } + ~ErrorMsgStream() { + if (error_msg_sink_ && stream_) *error_msg_sink_ = stream_->str(); + } + template + ErrorMsgStream& operator<<(T val) { + if (stream_) *stream_ << val; + return *this; + } + + private: + std::unique_ptr stream_; + // The destination string to which this class dump the error message when + // destructor is called. + std::string* error_msg_sink_; +}; +} // namespace + +EncodeNumberStatus ParseAndEncodeIntegerNumber( + const char* text, const NumberType& type, + std::function emit, std::string* error_msg) { + if (!text) { + ErrorMsgStream(error_msg) << "The given text is a nullptr"; + return EncodeNumberStatus::kInvalidText; + } + + if (!IsIntegral(type)) { + ErrorMsgStream(error_msg) << "The expected type is not a integer type"; + return EncodeNumberStatus::kInvalidUsage; + } + + const uint32_t bit_width = AssumedBitWidth(type); + + if (bit_width > 64) { + ErrorMsgStream(error_msg) + << "Unsupported " << bit_width << "-bit integer literals"; + return EncodeNumberStatus::kUnsupported; + } + + // Either we are expecting anything or integer. + bool is_negative = text[0] == '-'; + bool can_be_signed = IsSigned(type); + + if (is_negative && !can_be_signed) { + ErrorMsgStream(error_msg) + << "Cannot put a negative number in an unsigned literal"; + return EncodeNumberStatus::kInvalidUsage; + } + + const bool is_hex = text[0] == '0' && (text[1] == 'x' || text[1] == 'X'); + + uint64_t decoded_bits; + if (is_negative) { + int64_t decoded_signed = 0; + + if (!ParseNumber(text, &decoded_signed)) { + ErrorMsgStream(error_msg) << "Invalid signed integer literal: " << text; + return EncodeNumberStatus::kInvalidText; + } + + if (!CheckRangeAndIfHexThenSignExtend(decoded_signed, type, is_hex, + &decoded_signed)) { + ErrorMsgStream(error_msg) + << "Integer " << (is_hex ? 
std::hex : std::dec) << std::showbase + << decoded_signed << " does not fit in a " << std::dec << bit_width + << "-bit " << (IsSigned(type) ? "signed" : "unsigned") << " integer"; + return EncodeNumberStatus::kInvalidText; + } + decoded_bits = decoded_signed; + } else { + // There's no leading minus sign, so parse it as an unsigned integer. + if (!ParseNumber(text, &decoded_bits)) { + ErrorMsgStream(error_msg) << "Invalid unsigned integer literal: " << text; + return EncodeNumberStatus::kInvalidText; + } + if (!CheckRangeAndIfHexThenSignExtend(decoded_bits, type, is_hex, + &decoded_bits)) { + ErrorMsgStream(error_msg) + << "Integer " << (is_hex ? std::hex : std::dec) << std::showbase + << decoded_bits << " does not fit in a " << std::dec << bit_width + << "-bit " << (IsSigned(type) ? "signed" : "unsigned") << " integer"; + return EncodeNumberStatus::kInvalidText; + } + } + if (bit_width > 32) { + uint32_t low = uint32_t(0x00000000ffffffff & decoded_bits); + uint32_t high = uint32_t((0xffffffff00000000 & decoded_bits) >> 32); + emit(low); + emit(high); + } else { + emit(uint32_t(decoded_bits)); + } + return EncodeNumberStatus::kSuccess; +} + +EncodeNumberStatus ParseAndEncodeFloatingPointNumber( + const char* text, const NumberType& type, + std::function emit, std::string* error_msg) { + if (!text) { + ErrorMsgStream(error_msg) << "The given text is a nullptr"; + return EncodeNumberStatus::kInvalidText; + } + + if (!IsFloating(type)) { + ErrorMsgStream(error_msg) << "The expected type is not a float type"; + return EncodeNumberStatus::kInvalidUsage; + } + + const auto bit_width = AssumedBitWidth(type); + switch (bit_width) { + case 16: { + HexFloat> hVal(0); + if (!ParseNumber(text, &hVal)) { + ErrorMsgStream(error_msg) << "Invalid 16-bit float literal: " << text; + return EncodeNumberStatus::kInvalidText; + } + // getAsFloat will return the Float16 value, and get_value + // will return a uint16_t representing the bits of the float. 
+ // The encoding is therefore correct from the perspective of the SPIR-V + // spec since the top 16 bits will be 0. + emit(static_cast(hVal.value().getAsFloat().get_value())); + return EncodeNumberStatus::kSuccess; + } break; + case 32: { + HexFloat> fVal(0.0f); + if (!ParseNumber(text, &fVal)) { + ErrorMsgStream(error_msg) << "Invalid 32-bit float literal: " << text; + return EncodeNumberStatus::kInvalidText; + } + emit(BitwiseCast(fVal)); + return EncodeNumberStatus::kSuccess; + } break; + case 64: { + HexFloat> dVal(0.0); + if (!ParseNumber(text, &dVal)) { + ErrorMsgStream(error_msg) << "Invalid 64-bit float literal: " << text; + return EncodeNumberStatus::kInvalidText; + } + uint64_t decoded_val = BitwiseCast(dVal); + uint32_t low = uint32_t(0x00000000ffffffff & decoded_val); + uint32_t high = uint32_t((0xffffffff00000000 & decoded_val) >> 32); + emit(low); + emit(high); + return EncodeNumberStatus::kSuccess; + } break; + default: + break; + } + ErrorMsgStream(error_msg) + << "Unsupported " << bit_width << "-bit float literals"; + return EncodeNumberStatus::kUnsupported; +} + +EncodeNumberStatus ParseAndEncodeNumber(const char* text, + const NumberType& type, + std::function emit, + std::string* error_msg) { + if (!text) { + ErrorMsgStream(error_msg) << "The given text is a nullptr"; + return EncodeNumberStatus::kInvalidText; + } + + if (IsUnknown(type)) { + ErrorMsgStream(error_msg) + << "The expected type is not a integer or float type"; + return EncodeNumberStatus::kInvalidUsage; + } + + // If we explicitly expect a floating-point number, we should handle that + // first. 
+ if (IsFloating(type)) { + return ParseAndEncodeFloatingPointNumber(text, type, emit, error_msg); + } + + return ParseAndEncodeIntegerNumber(text, type, emit, error_msg); +} + +} // namespace utils +} // namespace spvtools diff --git a/source/util/parse_number.h b/source/util/parse_number.h new file mode 100644 index 000000000..729aac54b --- /dev/null +++ b/source/util/parse_number.h @@ -0,0 +1,252 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_PARSE_NUMBER_H_ +#define SOURCE_UTIL_PARSE_NUMBER_H_ + +#include +#include +#include + +#include "source/util/hex_float.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace utils { + +// A struct to hold the expected type information for the number in text to be +// parsed. +struct NumberType { + uint32_t bitwidth; + // SPV_NUMBER_NONE means the type is unknown and is invalid to be used with + // ParseAndEncode{|Integer|Floating}Number(). + spv_number_kind_t kind; +}; + +// Returns true if the type is a scalar integer type. +inline bool IsIntegral(const NumberType& type) { + return type.kind == SPV_NUMBER_UNSIGNED_INT || + type.kind == SPV_NUMBER_SIGNED_INT; +} + +// Returns true if the type is a scalar floating point type. +inline bool IsFloating(const NumberType& type) { + return type.kind == SPV_NUMBER_FLOATING; +} + +// Returns true if the type is a signed value. 
+inline bool IsSigned(const NumberType& type) { + return type.kind == SPV_NUMBER_FLOATING || type.kind == SPV_NUMBER_SIGNED_INT; +} + +// Returns true if the type is unknown. +inline bool IsUnknown(const NumberType& type) { + return type.kind == SPV_NUMBER_NONE; +} + +// Returns the number of bits in the type. This is only valid for integer and +// floating types. +inline int AssumedBitWidth(const NumberType& type) { + switch (type.kind) { + case SPV_NUMBER_SIGNED_INT: + case SPV_NUMBER_UNSIGNED_INT: + case SPV_NUMBER_FLOATING: + return type.bitwidth; + default: + break; + } + // Any other kind is not an integer or floating type; report a width of 0. + return 0; +} + +// A templated class with a static member function Clamp, where Clamp sets a +// referenced value of type T to 0 if T is an unsigned integer type, and +// returns true if it modified the referenced value. +template +class ClampToZeroIfUnsignedType { + public: + // The primary template does not clamp the value. + static bool Clamp(T*) { return false; } +}; + +// The specialization of ClampToZeroIfUnsignedType for unsigned integer types. +template +class ClampToZeroIfUnsignedType< + T, typename std::enable_if::value>::type> { + public: + static bool Clamp(T* value_pointer) { + if (*value_pointer) { + *value_pointer = 0; + return true; + } + return false; + } +}; + +// Returns true if the given value fits within the target scalar integral type. +// The target type may have an unusual bit width. If the value was originally +// specified as a hexadecimal number, then the overflow bits should be zero. +// If it was hex and the target type is signed, then return the sign-extended +// value through the updated_value_for_hex pointer argument. On failure, +// returns false.
+template +bool CheckRangeAndIfHexThenSignExtend(T value, const NumberType& type, + bool is_hex, T* updated_value_for_hex) { + // The encoded result has three regions of bits that are of interest, from + // least to most significant: + // - magnitude bits, where the magnitude of the number would be stored if + // we were using a signed-magnitude representation. + // - an optional sign bit + // - overflow bits, up to bit 63 of a 64-bit number + // For example: + // Type Overflow Sign Magnitude + // --------------- -------- ---- --------- + // unsigned 8 bit 8-63 n/a 0-7 + // signed 8 bit 8-63 7 0-6 + // unsigned 16 bit 16-63 n/a 0-15 + // signed 16 bit 16-63 15 0-14 + + // We'll use masks to define the three regions. + // At first we'll assume the number is unsigned. + const uint32_t bit_width = AssumedBitWidth(type); + uint64_t magnitude_mask = + (bit_width == 64) ? -1 : ((uint64_t(1) << bit_width) - 1); + uint64_t sign_mask = 0; + uint64_t overflow_mask = ~magnitude_mask; + + if (value < 0 || IsSigned(type)) { + // Accommodate the sign bit. + magnitude_mask >>= 1; + sign_mask = magnitude_mask + 1; + } + + bool failed = false; + if (value < 0) { + // The top bits must all be 1 for a negative signed value. + failed = ((value & overflow_mask) != overflow_mask) || + ((value & sign_mask) != sign_mask); + } else { + if (is_hex) { + // Hex values are a bit special. They decode as unsigned values, but may + // represent a negative number. In this case, the overflow bits should + // be zero. + failed = (value & overflow_mask) != 0; + } else { + const uint64_t value_as_u64 = static_cast(value); + // Check overflow in the ordinary case. + failed = (value_as_u64 & magnitude_mask) != value_as_u64; + } + } + + if (failed) { + return false; + } + + // Sign-extend the hex number. + if (is_hex && (value & sign_mask)) + *updated_value_for_hex = (value | overflow_mask); + + return true; +} + +// Parses a numeric value of a given type from the given text.
The number +// should take up the entire string, and should be within bounds for the target +// type. On success, returns true and populates the object referenced by +// value_pointer. On failure, returns false. +template +bool ParseNumber(const char* text, T* value_pointer) { + // C++11 doesn't define std::istringstream(int8_t&), so calling this method + // with a single-byte type leads to implementation-defined behaviour. + // Similarly for uint8_t. + static_assert(sizeof(T) > 1, + "Single-byte types are not supported in this parse method"); + + if (!text) return false; + std::istringstream text_stream(text); + // Allow both decimal and hex input for integers. + // It also allows octal input, but we don't care about that case. + text_stream >> std::setbase(0); + text_stream >> *value_pointer; + + // We should have read something. + bool ok = (text[0] != 0) && !text_stream.bad(); + // It should have been all the text. + ok = ok && text_stream.eof(); + // It should have been in range. + ok = ok && !text_stream.fail(); + + // Work around a bug in the GNU C++11 library. It will happily parse + // "-1" for uint16_t as 65535. + if (ok && text[0] == '-') + ok = !ClampToZeroIfUnsignedType::Clamp(value_pointer); + + return ok; +} + +// Enum to indicate the parsing and encoding status. +enum class EncodeNumberStatus { + kSuccess = 0, + // Unsupported bit width etc. + kUnsupported, + // Expected type (NumberType) is not a scalar int or float, or putting a + // negative number in an unsigned literal. + kInvalidUsage, + // Number value does not fit the bit width of the expected type etc. + kInvalidText, +}; + +// Parses an integer value of a given |type| from the given |text| and encodes +// the number by the given |emit| function. On success, returns +// EncodeNumberStatus::kSuccess and the parsed number will be consumed by the +// given |emit| function word by word (least significant word first). 
On +// failure, this function returns the error code of the encoding status and +// the |emit| function will not be called. If the string pointer |error_msg| is +// not a nullptr, it will be overwritten with error messages in case of failure. +// In case of success, |error_msg| will not be touched. Integers up to 64 bits +// are supported. +EncodeNumberStatus ParseAndEncodeIntegerNumber( + const char* text, const NumberType& type, + std::function emit, std::string* error_msg); + +// Parses a floating point value of a given |type| from the given |text| and +// encodes the number by the given |emit| function. On success, returns +// EncodeNumberStatus::kSuccess and the parsed number will be consumed by the +// given |emit| function word by word (least significant word first). On +// failure, this function returns the error code of the encoding status and +// the |emit| function will not be called. If the string pointer |error_msg| is +// not a nullptr, it will be overwritten with error messages in case of failure. +// In case of success, |error_msg| will not be touched. Only 16, 32 and 64 bit +// floating point numbers are supported. +EncodeNumberStatus ParseAndEncodeFloatingPointNumber( + const char* text, const NumberType& type, + std::function emit, std::string* error_msg); + +// Parses an integer or floating point number of a given |type| from the given +// |text| and encodes the number by the given |emit| function. On success, +// returns EncodeNumberStatus::kSuccess and the parsed number will be consumed +// by the given |emit| function word by word (least significant word first). On +// failure, this function returns the error code of the encoding status and +// the |emit| function will not be called. If the string pointer |error_msg| is +// not a nullptr, it will be overwritten with error messages in case of failure. +// In case of success, |error_msg| will not be touched. Integers up to 64 bits +// and 16/32/64 bit floating point values are supported.
+EncodeNumberStatus ParseAndEncodeNumber(const char* text, + const NumberType& type, + std::function emit, + std::string* error_msg); + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_PARSE_NUMBER_H_ diff --git a/source/util/small_vector.h b/source/util/small_vector.h new file mode 100644 index 000000000..f2c1147be --- /dev/null +++ b/source/util/small_vector.h @@ -0,0 +1,466 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_UTIL_SMALL_VECTOR_H_ +#define SOURCE_UTIL_SMALL_VECTOR_H_ + +#include +#include +#include +#include +#include + +#include "source/util/make_unique.h" + +namespace spvtools { +namespace utils { + +// The |SmallVector| class is intended to be a drop-in replacement for +// |std::vector|. The difference is in the implementation. A |SmallVector| is +// optimized for when the number of elements in the vector are small. Small is +// defined by the template parameter |small_size|. +// +// Note that |SmallVector| is not always faster than an |std::vector|, so you +// should experiment with different values for |small_size| and compare to +// using and |std::vector|. +// +// TODO: I have implemented the public member functions from |std::vector| that +// I needed. If others are needed they should be implemented. Do not implement +// public member functions that are not defined by std::vector. 
+template +class SmallVector { + public: + using iterator = T*; + using const_iterator = const T*; + + // Creates an empty vector that uses the inline buffer. + SmallVector() + : size_(0), + small_data_(reinterpret_cast(buffer)), + large_data_(nullptr) {} + + SmallVector(const SmallVector& that) : SmallVector() { *this = that; } + + SmallVector(SmallVector&& that) : SmallVector() { *this = std::move(that); } + + // Copies |vec|. More than |small_size| elements go to the heap vector; + // otherwise they are copy constructed into the inline buffer. + SmallVector(const std::vector& vec) : SmallVector() { + if (vec.size() > small_size) { + large_data_ = MakeUnique>(vec); + } else { + size_ = vec.size(); + for (uint32_t i = 0; i < size_; i++) { + new (small_data_ + i) T(vec[i]); + } + } + } + + // Same as above, but moves from |vec|, which is left empty. + SmallVector(std::vector&& vec) : SmallVector() { + if (vec.size() > small_size) { + large_data_ = MakeUnique>(std::move(vec)); + } else { + size_ = vec.size(); + for (uint32_t i = 0; i < size_; i++) { + new (small_data_ + i) T(std::move(vec[i])); + } + } + vec.clear(); + } + + SmallVector(std::initializer_list init_list) : SmallVector() { + // A list of exactly |small_size| elements still fits in the inline buffer, + // so use the same threshold as the std::vector constructors above. + if (init_list.size() <= small_size) { + for (auto it = init_list.begin(); it != init_list.end(); ++it) { + new (small_data_ + (size_++)) T(std::move(*it)); + } + } else { + large_data_ = MakeUnique>(std::move(init_list)); + } + } + + SmallVector(size_t s, const T& v) : SmallVector() { resize(s, v); } + + // Destroys the elements held in the inline buffer; |large_data_| releases + // its own elements. + virtual ~SmallVector() { + for (T* p = small_data_; p < small_data_ + size_; ++p) { + p->~T(); + } + } + + SmallVector& operator=(const SmallVector& that) { + assert(small_data_); + if (that.large_data_) { + if (large_data_) { + *large_data_ = *that.large_data_; + } else { + large_data_ = MakeUnique>(*that.large_data_); + } + } else { + large_data_.reset(nullptr); + size_t i = 0; + // Do a copy for any element in |this| that is already constructed. + for (; i < size_ && i < that.size_; ++i) { + small_data_[i] = that.small_data_[i]; + } + + if (i >= that.size_) { + // If the size of |this| becomes smaller after the assignment, then + // destroy any extra elements.
+ for (; i < size_; ++i) { + small_data_[i].~T(); + } + } else { + // If the size of |this| becomes larger after the assignment, copy + // construct the new elements that are needed. + for (; i < that.size_; ++i) { + new (small_data_ + i) T(that.small_data_[i]); + } + } + size_ = that.size_; + } + return *this; + } + + // Move assignment. Takes over |that.large_data_| when it exists; otherwise + // moves the inline elements one by one. + SmallVector& operator=(SmallVector&& that) { + if (that.large_data_) { + large_data_.reset(that.large_data_.release()); + } else { + large_data_.reset(nullptr); + size_t i = 0; + // Do a move for any element in |this| that is already constructed. + for (; i < size_ && i < that.size_; ++i) { + small_data_[i] = std::move(that.small_data_[i]); + } + + if (i >= that.size_) { + // If the size of |this| becomes smaller after the assignment, then + // destroy any extra elements. + for (; i < size_; ++i) { + small_data_[i].~T(); + } + } else { + // If the size of |this| becomes larger after the assignment, move + // construct the new elements that are needed. + for (; i < that.size_; ++i) { + new (small_data_ + i) T(std::move(that.small_data_[i])); + } + } + size_ = that.size_; + } + + // Reset |that| because all of the data has been moved to |this|.
+ that.DestructSmallData(); + return *this; + } + + template + friend bool operator==(const SmallVector& lhs, const OtherVector& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + auto rit = rhs.begin(); + for (auto lit = lhs.begin(); lit != lhs.end(); ++lit, ++rit) { + if (*lit != *rit) { + return false; + } + } + return true; + } + + friend bool operator==(const std::vector& lhs, const SmallVector& rhs) { + return rhs == lhs; + } + + friend bool operator!=(const SmallVector& lhs, const std::vector& rhs) { + return !(lhs == rhs); + } + + friend bool operator!=(const std::vector& lhs, const SmallVector& rhs) { + return rhs != lhs; + } + + T& operator[](size_t i) { + if (!large_data_) { + return small_data_[i]; + } else { + return (*large_data_)[i]; + } + } + + const T& operator[](size_t i) const { + if (!large_data_) { + return small_data_[i]; + } else { + return (*large_data_)[i]; + } + } + + size_t size() const { + if (!large_data_) { + return size_; + } else { + return large_data_->size(); + } + } + + iterator begin() { + if (large_data_) { + return large_data_->data(); + } else { + return small_data_; + } + } + + const_iterator begin() const { + if (large_data_) { + return large_data_->data(); + } else { + return small_data_; + } + } + + const_iterator cbegin() const { return begin(); } + + iterator end() { + if (large_data_) { + return large_data_->data() + large_data_->size(); + } else { + return small_data_ + size_; + } + } + + const_iterator end() const { + if (large_data_) { + return large_data_->data() + large_data_->size(); + } else { + return small_data_ + size_; + } + } + + const_iterator cend() const { return end(); } + + T* data() { return begin(); } + + const T* data() const { return cbegin(); } + + T& front() { return (*this)[0]; } + + const T& front() const { return (*this)[0]; } + + iterator erase(const_iterator pos) { return erase(pos, pos + 1); } + + iterator erase(const_iterator first, const_iterator last) { + if (large_data_) { 
+ size_t start_index = first - large_data_->data(); + size_t end_index = last - large_data_->data(); + auto r = large_data_->erase(large_data_->begin() + start_index, + large_data_->begin() + end_index); + return large_data_->data() + (r - large_data_->begin()); + } + + // Since C++11, std::vector has |const_iterator| for the parameters, so I + // follow that. However, I need iterators to modify the current container, + // which is not const. This is why I cast away the const. + iterator f = const_cast(first); + iterator l = const_cast(last); + iterator e = end(); + + size_t num_of_del_elements = last - first; + iterator ret = f; + if (first == last) { + return ret; + } + + // Move |last| and any elements after it their earlier position. + while (l != e) { + *f = std::move(*l); + ++f; + ++l; + } + + // Destroy the elements that were supposed to be deleted. + while (f != l) { + f->~T(); + ++f; + } + + // Update the size. + size_ -= num_of_del_elements; + return ret; + } + + void push_back(const T& value) { + if (!large_data_ && size_ == small_size) { + MoveToLargeData(); + } + + if (large_data_) { + large_data_->push_back(value); + return; + } + + new (small_data_ + size_) T(value); + ++size_; + } + + void push_back(T&& value) { + if (!large_data_ && size_ == small_size) { + MoveToLargeData(); + } + + if (large_data_) { + large_data_->push_back(std::move(value)); + return; + } + + new (small_data_ + size_) T(std::move(value)); + ++size_; + } + + template + iterator insert(iterator pos, InputIt first, InputIt last) { + size_t element_idx = (pos - begin()); + size_t num_of_new_elements = std::distance(first, last); + size_t new_size = size_ + num_of_new_elements; + if (!large_data_ && new_size > small_size) { + MoveToLargeData(); + } + + if (large_data_) { + typename std::vector::iterator new_pos = + large_data_->begin() + element_idx; + large_data_->insert(new_pos, first, last); + return begin() + element_idx; + } + + // Move |pos| and all of the elements after it 
over |num_of_new_elements| + // places. We start at the end and work backwards, to make sure we do not + // overwrite data that we have not moved yet. + for (iterator i = begin() + new_size - 1, j = end() - 1; j >= pos; + --i, --j) { + if (i >= begin() + size_) { + new (i) T(std::move(*j)); + } else { + *i = std::move(*j); + } + } + + // Copy the new elements into position. + iterator p = pos; + for (; first != last; ++p, ++first) { + if (p >= small_data_ + size_) { + new (p) T(*first); + } else { + *p = *first; + } + } + + // Update the size. + size_ += num_of_new_elements; + return pos; + } + + bool empty() const { + if (large_data_) { + return large_data_->empty(); + } + return size_ == 0; + } + + void clear() { + if (large_data_) { + large_data_->clear(); + } else { + DestructSmallData(); + } + } + + template + void emplace_back(Args&&... args) { + if (!large_data_ && size_ == small_size) { + MoveToLargeData(); + } + + if (large_data_) { + large_data_->emplace_back(std::forward(args)...); + } else { + new (small_data_ + size_) T(std::forward(args)...); + ++size_; + } + } + + void resize(size_t new_size, const T& v) { + if (!large_data_ && new_size > small_size) { + MoveToLargeData(); + } + + if (large_data_) { + large_data_->resize(new_size, v); + return; + } + + // If |new_size| < |size_|, then destroy the extra elements. + for (size_t i = new_size; i < size_; ++i) { + small_data_[i].~T(); + } + + // If |new_size| > |size_|, then copy construct the new elements. + for (size_t i = size_; i < new_size; ++i) { + new (small_data_ + i) T(v); + } + + // Update the size. + size_ = new_size; + } + + private: + // Moves all of the elements from |small_data_| into a new std::vector that + // can be accessed through |large_data_|.
+ void MoveToLargeData() { + assert(!large_data_); + large_data_ = MakeUnique>(); + for (size_t i = 0; i < size_; ++i) { + large_data_->emplace_back(std::move(small_data_[i])); + } + DestructSmallData(); + } + + // Destroys all of the elements in |small_data_| that have been constructed + // and sets |size_| to 0. + void DestructSmallData() { + for (size_t i = 0; i < size_; ++i) { + small_data_[i].~T(); + } + size_ = 0; + } + + // The number of elements in |small_data_| that have been constructed. + size_t size_; + + // The pointer used to access the array of elements when the number of + // elements is small. + T* small_data_; + + // The actual data used to store the array elements. It must never be used + // directly, but must only be accessed through |small_data_|. + typename std::aligned_storage::value>::type + buffer[small_size]; + + // A pointer to a vector that is used to store the elements of the vector when + // this size exceeds |small_size|. If |large_data_| is nullptr, then the data + // is stored in |small_data_|. Otherwise, the data is stored in + // |large_data_|. + std::unique_ptr> large_data_; +}; // class SmallVector + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_SMALL_VECTOR_H_ diff --git a/source/util/string_utils.cpp b/source/util/string_utils.cpp new file mode 100644 index 000000000..b56c353af --- /dev/null +++ b/source/util/string_utils.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "source/util/string_utils.h" + +namespace spvtools { +namespace utils { + +// Returns |cardinal| followed by its English ordinal suffix ("st", "nd", "rd" +// or "th"). Values whose last two digits are 11, 12 or 13 take "th". +std::string CardinalToOrdinal(size_t cardinal) { + const size_t mod10 = cardinal % 10; + const size_t mod100 = cardinal % 100; + std::string suffix; + if (mod10 == 1 && mod100 != 11) + suffix = "st"; + else if (mod10 == 2 && mod100 != 12) + suffix = "nd"; + else if (mod10 == 3 && mod100 != 13) + suffix = "rd"; + else + suffix = "th"; + + return ToString(cardinal) + suffix; +} + +// Splits |flag| into its name, with a leading "-" or "--" removed, and the text +// after the first '='. If there is no '=', the second string is empty. A flag +// shorter than two characters is returned unchanged as the name. +std::pair SplitFlagArgs(const std::string& flag) { + if (flag.size() < 2) return make_pair(flag, std::string()); + + // Detect the last dash before the pass name. Since we have to + // handle single dash options (-O and -Os), count up to two dashes. + size_t dash_ix = 0; + if (flag[0] == '-' && flag[1] == '-') + dash_ix = 2; + else if (flag[0] == '-') + dash_ix = 1; + + size_t ix = flag.find('='); + // The name runs from |dash_ix| up to, but not including, the '='. Its length + // is measured from |dash_ix| rather than a fixed 2, so that single-dash and + // dash-less flags keep their whole name. + return (ix != std::string::npos) + ? make_pair(flag.substr(dash_ix, ix - dash_ix), flag.substr(ix + 1)) + : make_pair(flag.substr(dash_ix), std::string()); +} + +} // namespace utils +} // namespace spvtools diff --git a/source/util/string_utils.h b/source/util/string_utils.h new file mode 100644 index 000000000..f1cd179c9 --- /dev/null +++ b/source/util/string_utils.h @@ -0,0 +1,48 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef SOURCE_UTIL_STRING_UTILS_H_ +#define SOURCE_UTIL_STRING_UTILS_H_ + +#include +#include + +#include "source/util/string_utils.h" + +namespace spvtools { +namespace utils { + +// Converts arithmetic value |val| to its default string representation. +template +std::string ToString(T val) { + static_assert( + std::is_arithmetic::value, + "spvtools::utils::ToString is restricted to only arithmetic values"); + std::stringstream os; + os << val; + return os.str(); +} + +// Converts cardinal number to ordinal number string. +std::string CardinalToOrdinal(size_t cardinal); + +// Splits the string |flag|, of the form '--pass_name[=pass_args]' into two +// strings "pass_name" and "pass_args". If |flag| has no arguments, the second +// string will be empty. +std::pair SplitFlagArgs(const std::string& flag); + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_STRING_UTILS_H_ diff --git a/source/util/timer.cpp b/source/util/timer.cpp new file mode 100644 index 000000000..c8b8d5b61 --- /dev/null +++ b/source/util/timer.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#if defined(SPIRV_TIMER_ENABLED) + +#include "source/util/timer.h" + +#include +#include +#include +#include +#include + +namespace spvtools { +namespace utils { + +void PrintTimerDescription(std::ostream* out, bool measure_mem_usage) { + if (out) { + *out << std::setw(30) << "PASS name" << std::setw(12) << "CPU time" + << std::setw(12) << "WALL time" << std::setw(12) << "USR time" + << std::setw(12) << "SYS time"; + if (measure_mem_usage) { + *out << std::setw(12) << "RSS delta" << std::setw(16) << "PGFault delta"; + } + *out << std::endl; + } +} + +// Do not change the order of invoking system calls. We want to make CPU/Wall +// time correct as much as possible. Calling functions to get CPU/Wall time must +// closely surround the target code of measuring. +void Timer::Start() { + if (report_stream_) { + if (getrusage(RUSAGE_SELF, &usage_before_) == -1) + usage_status_ |= kGetrusageFailed; + if (clock_gettime(CLOCK_MONOTONIC, &wall_before_) == -1) + usage_status_ |= kClockGettimeWalltimeFailed; + if (clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpu_before_) == -1) + usage_status_ |= kClockGettimeCPUtimeFailed; + } +} + +// The order of invoking system calls is important with the same reason as +// Timer::Start(). 
+void Timer::Stop() { + // Measure only when reporting is enabled and Start() recorded no failure. + if (report_stream_ && usage_status_ == kSucceeded) { + if (clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpu_after_) == -1) + usage_status_ |= kClockGettimeCPUtimeFailed; + if (clock_gettime(CLOCK_MONOTONIC, &wall_after_) == -1) + usage_status_ |= kClockGettimeWalltimeFailed; + // Accumulate with |= so a getrusage() failure does not erase a + // clock_gettime() failure recorded just above. + if (getrusage(RUSAGE_SELF, &usage_after_) == -1) + usage_status_ |= kGetrusageFailed; + } +} + +// Prints |tag| followed by one column per measured resource to +// |report_stream_|. A resource whose failure flag is set in |usage_status_| is +// printed as "Failed". Does nothing when |report_stream_| is NULL. +void Timer::Report(const char* tag) { + if (!report_stream_) return; + + report_stream_->precision(2); + *report_stream_ << std::fixed << std::setw(30) << tag; + + if (usage_status_ & kClockGettimeCPUtimeFailed) + *report_stream_ << std::setw(12) << "Failed"; + else + *report_stream_ << std::setw(12) << CPUTime(); + + if (usage_status_ & kClockGettimeWalltimeFailed) + *report_stream_ << std::setw(12) << "Failed"; + else + *report_stream_ << std::setw(12) << WallTime(); + + if (usage_status_ & kGetrusageFailed) { + *report_stream_ << std::setw(12) << "Failed" << std::setw(12) << "Failed"; + if (measure_mem_usage_) { + *report_stream_ << std::setw(12) << "Failed" << std::setw(12) << "Failed"; + } + } else { + *report_stream_ << std::setw(12) << UserTime() << std::setw(12) + << SystemTime(); + if (measure_mem_usage_) { + *report_stream_ << std::fixed << std::setw(12) << RSS() << std::setw(16) + << PageFault(); + } + } + *report_stream_ << std::endl; +} + +} // namespace utils +} // namespace spvtools + +#endif // defined(SPIRV_TIMER_ENABLED) diff --git a/source/util/timer.h b/source/util/timer.h new file mode 100644 index 000000000..fc4b747b9 --- /dev/null +++ b/source/util/timer.h @@ -0,0 +1,392 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Contains utils for getting resource utilization + +#ifndef SOURCE_UTIL_TIMER_H_ +#define SOURCE_UTIL_TIMER_H_ + +#if defined(SPIRV_TIMER_ENABLED) + +#include +#include +#include + +// A macro to call spvtools::utils::PrintTimerDescription(std::ostream*, bool). +// The first argument must be given as std::ostream*. If it is NULL, the +// function does nothing. Otherwise, it prints resource types measured by Timer +// class. The second is optional and if it is true, the function also prints +// resource type fields related to memory. Otherwise, it does not print memory +// related fields. Its default is false. Usually, this must be placed before +// calling Timer::Report() to inform what those fields printed by +// Timer::Report() indicate (or spvtools::utils::PrintTimerDescription() must be +// used instead). +#define SPIRV_TIMER_DESCRIPTION(...) \ + spvtools::utils::PrintTimerDescription(__VA_ARGS__) + +// Pastes two tokens after macro-expanding them. The extra level of indirection +// is required because operands of ## are not expanded, so pasting __LINE__ +// directly would produce the literal identifier "timer__LINE__". +#define SPIRV_TIMER_CONCAT_IMPL(a, b) a##b +#define SPIRV_TIMER_CONCAT(a, b) SPIRV_TIMER_CONCAT_IMPL(a, b) + +// Creates an object of ScopedTimer to measure the resource utilization for the +// scope surrounding it as the following example: +// +// { // <-- beginning of this scope +// +// /* ... code out of interest ... */ +// +// SPIRV_TIMER_SCOPED(&std::cout, tag); +// +// /* ... lines of code that we want to know its resource usage ... */ +// +// } // <-- end of this scope. The destructor of ScopedTimer prints tag and +// the resource utilization to std::cout. +// +// The timer variable is named after the source line, so the macro may be used +// more than once in the same scope as long as each use is on its own line. +#define SPIRV_TIMER_SCOPED(...)
\ + spvtools::utils::ScopedTimer SPIRV_TIMER_CONCAT(timer, __LINE__)( \ + __VA_ARGS__) + +namespace spvtools { +namespace utils { + +// Prints the description of resource types measured by Timer class. If |out| is +// NULL, it does nothing. Otherwise, it prints resource types. The second is +// optional and if it is true, the function also prints resource type fields +// related to memory. Its default is false. Usually, this must be placed before +// calling Timer::Report() to inform what those fields printed by +// Timer::Report() indicate. +void PrintTimerDescription(std::ostream*, bool = false); + +// Status of Timer. kGetrusageFailed means it failed in calling getrusage(). +// kClockGettimeWalltimeFailed means it failed in getting wall time when calling +// clock_gettime(). kClockGettimeCPUtimeFailed means it failed in getting CPU +// time when calling clock_gettime(). +enum UsageStatus { + kSucceeded = 0, + kGetrusageFailed = 1 << 0, + kClockGettimeWalltimeFailed = 1 << 1, + kClockGettimeCPUtimeFailed = 1 << 2, +}; + +// Timer measures the resource utilization for a range of code. The resource +// utilization consists of CPU time (i.e., process time), WALL time (elapsed +// time), USR time, SYS time, RSS delta, and the delta of the number of page +// faults. RSS delta and the delta of the number of page faults are measured +// only when |measure_mem_usage| given to the constructor is true. This class +// should be used as the following example: +// +// spvtools::utils::Timer timer(std::cout); +// timer.Start(); // <-- set |usage_before_|, |wall_before_|, +// and |cpu_before_| +// +// /* ... lines of code that we want to know its resource usage ... */ +// +// timer.Stop(); // <-- set |cpu_after_|, |wall_after_|, and +// |usage_after_| +// timer.Report(tag); // <-- print tag and the resource utilization to +// std::cout.
+class Timer { + public: + Timer(std::ostream* out, bool measure_mem_usage = false) + : report_stream_(out), + usage_status_(kSucceeded), + measure_mem_usage_(measure_mem_usage) {} + + // Sets |usage_before_|, |wall_before_|, and |cpu_before_| as results of + // getrusage(), clock_gettime() for the wall time, and clock_gettime() for the + // CPU time respectively. Note that this method erases all previous state of + // |usage_before_|, |wall_before_|, |cpu_before_|. + virtual void Start(); + + // Sets |cpu_after_|, |wall_after_|, and |usage_after_| as results of + // clock_gettime() for the wall time, and clock_gettime() for the CPU time, + // getrusage() respectively. Note that this method erases all previous state + // of |cpu_after_|, |wall_after_|, |usage_after_|. + virtual void Stop(); + + // If |report_stream_| is NULL, it does nothing. Otherwise, it prints the + // resource utilization (i.e., CPU/WALL/USR/SYS time, RSS delta) between the + // time of calling Timer::Start() and the time of calling Timer::Stop(). If we + // cannot get a resource usage because of failures, it prints "Failed" instead + // for the resource. + void Report(const char* tag); + + // Returns the measured CPU Time (i.e., process time) for a range of code + // execution. If kClockGettimeCPUtimeFailed is set by the failure of calling + // clock_gettime(), it returns -1. + virtual double CPUTime() { + if (usage_status_ & kClockGettimeCPUtimeFailed) return -1; + return TimeDifference(cpu_before_, cpu_after_); + } + + // Returns the measured Wall Time (i.e., elapsed time) for a range of code + // execution. If kClockGettimeWalltimeFailed is set by the failure of + // calling clock_gettime(), it returns -1. + virtual double WallTime() { + if (usage_status_ & kClockGettimeWalltimeFailed) return -1; + return TimeDifference(wall_before_, wall_after_); + } + + // Returns the measured USR Time for a range of code execution. 
If + // kGetrusageFailed is set because of the failure of calling getrusage(), it + // returns -1. + virtual double UserTime() { + if (usage_status_ & kGetrusageFailed) return -1; + return TimeDifference(usage_before_.ru_utime, usage_after_.ru_utime); + } + + // Returns the measured SYS Time for a range of code execution. If + // kGetrusageFailed is set because of the failure of calling getrusage(), it + // returns -1. + virtual double SystemTime() { + if (usage_status_ & kGetrusageFailed) return -1; + return TimeDifference(usage_before_.ru_stime, usage_after_.ru_stime); + } + + // Returns the measured RSS delta for a range of code execution. If + // kGetrusageFailed is set because of the failure of calling getrusage(), it + // returns -1. + virtual long RSS() const { + if (usage_status_ & kGetrusageFailed) return -1; + return usage_after_.ru_maxrss - usage_before_.ru_maxrss; + } + + // Returns the measured the delta of the number of page faults for a range of + // code execution. If kGetrusageFailed is set because of the failure of + // calling getrusage(), it returns -1. + virtual long PageFault() const { + if (usage_status_ & kGetrusageFailed) return -1; + return (usage_after_.ru_minflt - usage_before_.ru_minflt) + + (usage_after_.ru_majflt - usage_before_.ru_majflt); + } + + virtual ~Timer() {} + + private: + // Returns the time gap between |from| and |to| in seconds. + static double TimeDifference(const timeval& from, const timeval& to) { + assert((to.tv_sec > from.tv_sec) || + (to.tv_sec == from.tv_sec && to.tv_usec >= from.tv_usec)); + return static_cast(to.tv_sec - from.tv_sec) + + static_cast(to.tv_usec - from.tv_usec) * .000001; + } + + // Returns the time gap between |from| and |to| in seconds. 
+ static double TimeDifference(const timespec& from, const timespec& to) { + assert((to.tv_sec > from.tv_sec) || + (to.tv_sec == from.tv_sec && to.tv_nsec >= from.tv_nsec)); + return static_cast(to.tv_sec - from.tv_sec) + + static_cast(to.tv_nsec - from.tv_nsec) * .000000001; + } + + // Output stream to print out the resource utilization. If it is NULL, + // Report() does nothing. + std::ostream* report_stream_; + + // Bitwise OR of UsageStatus flags recording which system calls failed. Stop() + // skips its measurements unless this is still kSucceeded. + unsigned usage_status_; + + // Variable to save the result of clock_gettime(CLOCK_PROCESS_CPUTIME_ID) when + // Timer::Start() is called. It is used as the base status of CPU time. + timespec cpu_before_; + + // Variable to save the result of clock_gettime(CLOCK_MONOTONIC) when + // Timer::Start() is called. It is used as the base status of WALL time. + timespec wall_before_; + + // Variable to save the result of getrusage() when Timer::Start() is called. + // It is used as the base status of USR time, SYS time, and RSS. + rusage usage_before_; + + // Variable to save the result of clock_gettime(CLOCK_PROCESS_CPUTIME_ID) when + // Timer::Stop() is called. It is used as the last status of CPU time. The + // resource usage is measured by subtracting |cpu_before_| from it. + timespec cpu_after_; + + // Variable to save the result of clock_gettime(CLOCK_MONOTONIC) when + // Timer::Stop() is called. It is used as the last status of WALL time. The + // resource usage is measured by subtracting |wall_before_| from it. + timespec wall_after_; + + // Variable to save the result of getrusage() when Timer::Stop() is called. It + // is used as the last status of USR time, SYS time, and RSS. Those resource + // usages are measured by subtracting |usage_before_| from it. + rusage usage_after_; + + // If true, Timer reports the memory usage information too. Otherwise, Timer + // reports only CPU time, WALL time, USR time, and SYS time.
+ bool measure_mem_usage_; +}; + +// The purpose of ScopedTimer is to measure the resource utilization for a +// scope. Simply creating a local variable of ScopedTimer will call +// Timer::Start() and it calls Timer::Stop() and Timer::Report() at the end of +// the scope by its destructor. When we use this class, we must choose the +// proper Timer class (for class TimerType template) in advance. This class +// should be used as the following example: +// +// { // <-- beginning of this scope +// +// /* ... code out of interest ... */ +// +// spvtools::utils::ScopedTimer +// scopedtimer(std::cout, tag); +// +// /* ... lines of code that we want to know its resource usage ... */ +// +// } // <-- end of this scope. The destructor of ScopedTimer prints tag and +// the resource utilization to std::cout. +// +// The template is used to choose a Timer class. Currently, +// only options for the Timer class are Timer and MockTimer in the unit test. +template +class ScopedTimer { + public: + ScopedTimer(std::ostream* out, const char* tag, + bool measure_mem_usage = false) + : timer(new TimerType(out, measure_mem_usage)), tag_(tag) { + timer->Start(); + } + + // At the end of the scope surrounding the instance of this class, this + // destructor saves the last status of resource usage and reports it. + virtual ~ScopedTimer() { + timer->Stop(); + timer->Report(tag_); + delete timer; + } + + private: + // Actual timer that measures the resource utilization. It must be an instance + // of Timer class if there is no special reason to use other class. + TimerType* timer; + + // A tag that will be printed in front of the trace reported by Timer class. + const char* tag_; +}; + +// CumulativeTimer is the same as Timer class, but it supports a cumulative +// measurement as the following example: +// +// CumulativeTimer *ctimer = new CumulativeTimer(std::cout); +// ctimer->Start(); +// +// /* ... lines of code that we want to know its resource usage ... 
*/ +// +// ctimer->Stop(); +// +// /* ... code out of interest ... */ +// +// ctimer->Start(); +// +// /* ... lines of code that we want to know its resource usage ... */ +// +// ctimer->Stop(); +// ctimer->Report(tag); +// delete ctimer; +// +class CumulativeTimer : public Timer { + public: + CumulativeTimer(std::ostream* out, bool measure_mem_usage = false) + : Timer(out, measure_mem_usage), + cpu_time_(0), + wall_time_(0), + usr_time_(0), + sys_time_(0), + rss_(0), + pgfaults_(0) {} + + // If we cannot get a resource usage because of failures, it sets -1 for the + // resource usage. + void Stop() override { + Timer::Stop(); + + if (cpu_time_ >= 0 && Timer::CPUTime() >= 0) + cpu_time_ += Timer::CPUTime(); + else + cpu_time_ = -1; + + if (wall_time_ >= 0 && Timer::WallTime() >= 0) + wall_time_ += Timer::WallTime(); + else + wall_time_ = -1; + + if (usr_time_ >= 0 && Timer::UserTime() >= 0) + usr_time_ += Timer::UserTime(); + else + usr_time_ = -1; + + if (sys_time_ >= 0 && Timer::SystemTime() >= 0) + sys_time_ += Timer::SystemTime(); + else + sys_time_ = -1; + + if (rss_ >= 0 && Timer::RSS() >= 0) + rss_ += Timer::RSS(); + else + rss_ = -1; + + if (pgfaults_ >= 0 && Timer::PageFault() >= 0) + pgfaults_ += Timer::PageFault(); + else + pgfaults_ = -1; + } + + // Returns the cumulative CPU Time (i.e., process time) for a range of code + // execution. + double CPUTime() override { return cpu_time_; } + + // Returns the cumulative Wall Time (i.e., elapsed time) for a range of code + // execution. + double WallTime() override { return wall_time_; } + + // Returns the cumulative USR Time for a range of code execution. + double UserTime() override { return usr_time_; } + + // Returns the cumulative SYS Time for a range of code execution. + double SystemTime() override { return sys_time_; } + + // Returns the cumulative RSS delta for a range of code execution. 
+ long RSS() const override { return rss_; } + + // Returns the cumulative delta of number of page faults for a range of code + // execution. + long PageFault() const override { return pgfaults_; } + + private: + // Variable to save the cumulative CPU time (i.e., process time). + double cpu_time_; + + // Variable to save the cumulative wall time (i.e., elapsed time). + double wall_time_; + + // Variable to save the cumulative user time. + double usr_time_; + + // Variable to save the cumulative system time. + double sys_time_; + + // Variable to save the cumulative RSS delta. + long rss_; + + // Variable to save the cumulative delta of the number of page faults. + long pgfaults_; +}; + +} // namespace utils +} // namespace spvtools + +#else // defined(SPIRV_TIMER_ENABLED) + +#define SPIRV_TIMER_DESCRIPTION(...) +#define SPIRV_TIMER_SCOPED(...) + +#endif // defined(SPIRV_TIMER_ENABLED) + +#endif // SOURCE_UTIL_TIMER_H_ diff --git a/source/val/basic_block.cpp b/source/val/basic_block.cpp new file mode 100644 index 000000000..a53103c8a --- /dev/null +++ b/source/val/basic_block.cpp @@ -0,0 +1,149 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/val/basic_block.h" + +#include +#include +#include + +namespace spvtools { +namespace val { + +BasicBlock::BasicBlock(uint32_t label_id) + : id_(label_id), + immediate_dominator_(nullptr), + immediate_post_dominator_(nullptr), + predecessors_(), + successors_(), + type_(0), + reachable_(false), + label_(nullptr), + terminator_(nullptr) {} + +void BasicBlock::SetImmediateDominator(BasicBlock* dom_block) { + immediate_dominator_ = dom_block; +} + +void BasicBlock::SetImmediatePostDominator(BasicBlock* pdom_block) { + immediate_post_dominator_ = pdom_block; +} + +const BasicBlock* BasicBlock::immediate_dominator() const { + return immediate_dominator_; +} + +const BasicBlock* BasicBlock::immediate_post_dominator() const { + return immediate_post_dominator_; +} + +BasicBlock* BasicBlock::immediate_dominator() { return immediate_dominator_; } +BasicBlock* BasicBlock::immediate_post_dominator() { + return immediate_post_dominator_; +} + +void BasicBlock::RegisterSuccessors( + const std::vector& next_blocks) { + for (auto& block : next_blocks) { + block->predecessors_.push_back(this); + successors_.push_back(block); + if (block->reachable_ == false) block->set_reachable(reachable_); + } +} + +void BasicBlock::RegisterBranchInstruction(SpvOp branch_instruction) { + if (branch_instruction == SpvOpUnreachable) reachable_ = false; + return; +} + +bool BasicBlock::dominates(const BasicBlock& other) const { + return (this == &other) || + !(other.dom_end() == + std::find(other.dom_begin(), other.dom_end(), this)); +} + +bool BasicBlock::postdominates(const BasicBlock& other) const { + return (this == &other) || + !(other.pdom_end() == + std::find(other.pdom_begin(), other.pdom_end(), this)); +} + +BasicBlock::DominatorIterator::DominatorIterator() : current_(nullptr) {} + +BasicBlock::DominatorIterator::DominatorIterator( + const BasicBlock* block, + std::function dominator_func) + : current_(block), dom_func_(dominator_func) {} + 
+BasicBlock::DominatorIterator& BasicBlock::DominatorIterator::operator++() { + if (current_ == dom_func_(current_)) { + current_ = nullptr; + } else { + current_ = dom_func_(current_); + } + return *this; +} + +const BasicBlock::DominatorIterator BasicBlock::dom_begin() const { + return DominatorIterator( + this, [](const BasicBlock* b) { return b->immediate_dominator(); }); +} + +BasicBlock::DominatorIterator BasicBlock::dom_begin() { + return DominatorIterator( + this, [](const BasicBlock* b) { return b->immediate_dominator(); }); +} + +const BasicBlock::DominatorIterator BasicBlock::dom_end() const { + return DominatorIterator(); +} + +BasicBlock::DominatorIterator BasicBlock::dom_end() { + return DominatorIterator(); +} + +const BasicBlock::DominatorIterator BasicBlock::pdom_begin() const { + return DominatorIterator( + this, [](const BasicBlock* b) { return b->immediate_post_dominator(); }); +} + +BasicBlock::DominatorIterator BasicBlock::pdom_begin() { + return DominatorIterator( + this, [](const BasicBlock* b) { return b->immediate_post_dominator(); }); +} + +const BasicBlock::DominatorIterator BasicBlock::pdom_end() const { + return DominatorIterator(); +} + +BasicBlock::DominatorIterator BasicBlock::pdom_end() { + return DominatorIterator(); +} + +bool operator==(const BasicBlock::DominatorIterator& lhs, + const BasicBlock::DominatorIterator& rhs) { + return lhs.current_ == rhs.current_; +} + +bool operator!=(const BasicBlock::DominatorIterator& lhs, + const BasicBlock::DominatorIterator& rhs) { + return !(lhs == rhs); +} + +const BasicBlock*& BasicBlock::DominatorIterator::operator*() { + return current_; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/basic_block.h b/source/val/basic_block.h new file mode 100644 index 000000000..efbd243b6 --- /dev/null +++ b/source/val/basic_block.h @@ -0,0 +1,247 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_VAL_BASIC_BLOCK_H_ +#define SOURCE_VAL_BASIC_BLOCK_H_ + +#include +#include +#include +#include +#include + +#include "source/latest_version_spirv_header.h" + +namespace spvtools { +namespace val { + +enum BlockType : uint32_t { + kBlockTypeUndefined, + kBlockTypeHeader, + kBlockTypeLoop, + kBlockTypeMerge, + kBlockTypeBreak, + kBlockTypeContinue, + kBlockTypeReturn, + kBlockTypeCOUNT ///< Total number of block types. 
(must be the last element) +}; + +class Instruction; + +// This class represents a basic block in a SPIR-V module +class BasicBlock { + public: + /// Constructor for a BasicBlock + /// + /// @param[in] id The ID of the basic block + explicit BasicBlock(uint32_t id); + + /// Returns the id of the BasicBlock + uint32_t id() const { return id_; } + + /// Returns the predecessors of the BasicBlock + const std::vector* predecessors() const { + return &predecessors_; + } + + /// Returns the predecessors of the BasicBlock + std::vector* predecessors() { return &predecessors_; } + + /// Returns the successors of the BasicBlock + const std::vector* successors() const { return &successors_; } + + /// Returns the successors of the BasicBlock + std::vector* successors() { return &successors_; } + + /// Returns true if the block is reachable in the CFG + bool reachable() const { return reachable_; } + + /// Returns true if BasicBlock is of the given type + bool is_type(BlockType type) const { + if (type == kBlockTypeUndefined) return type_.none(); + return type_.test(type); + } + + /// Sets the reachability of the basic block in the CFG + void set_reachable(bool reachability) { reachable_ = reachability; } + + /// Sets the type of the BasicBlock + void set_type(BlockType type) { + if (type == kBlockTypeUndefined) + type_.reset(); + else + type_.set(type); + } + + /// Sets the immediate dominator of this basic block + /// + /// @param[in] dom_block The dominator block + void SetImmediateDominator(BasicBlock* dom_block); + + /// Sets the immediate post dominator of this basic block + /// + /// @param[in] pdom_block The post dominator block + void SetImmediatePostDominator(BasicBlock* pdom_block); + + /// Returns the immediate dominator of this basic block + BasicBlock* immediate_dominator(); + + /// Returns the immediate dominator of this basic block + const BasicBlock* immediate_dominator() const; + + /// Returns the immediate post dominator of this basic block + BasicBlock*
immediate_post_dominator(); + + /// Returns the immediate post dominator of this basic block + const BasicBlock* immediate_post_dominator() const; + + /// Ends the block without a successor + void RegisterBranchInstruction(SpvOp branch_instruction); + + /// Returns the label instruction for the block, or nullptr if not set. + const Instruction* label() const { return label_; } + + /// Registers the label instruction for the block. + void set_label(const Instruction* t) { label_ = t; } + + /// Registers the terminator instruction for the block. + void set_terminator(const Instruction* t) { terminator_ = t; } + + /// Returns the terminator instruction for the block. + const Instruction* terminator() const { return terminator_; } + + /// Adds @p next BasicBlocks as successors of this BasicBlock + void RegisterSuccessors( + const std::vector& next = std::vector()); + + /// Returns true if the id of the BasicBlock matches + bool operator==(const BasicBlock& other) const { return other.id_ == id_; } + + /// Returns true if the id of the BasicBlock matches + bool operator==(const uint32_t& other_id) const { return other_id == id_; } + + /// Returns true if this block dominates the other block. + /// Assumes dominators have been computed. + bool dominates(const BasicBlock& other) const; + + /// Returns true if this block postdominates the other block. + /// Assumes dominators have been computed.
+ bool postdominates(const BasicBlock& other) const; + + /// @brief A BasicBlock dominator iterator class + /// + /// This iterator will iterate over the (post)dominators of the block + class DominatorIterator + : public std::iterator { + public: + /// @brief Constructs the end of dominator iterator + /// + /// This will create an iterator which will represent the element + /// before the root node of the dominator tree + DominatorIterator(); + + /// @brief Constructs an iterator for the given block which points to + /// @p block + /// + /// @param block The block which is referenced by the iterator + /// @param dominator_func This function will be called to get the immediate + /// (post)dominator of the current block + DominatorIterator( + const BasicBlock* block, + std::function dominator_func); + + /// @brief Advances the iterator + DominatorIterator& operator++(); + + /// @brief Returns the current element + const BasicBlock*& operator*(); + + friend bool operator==(const DominatorIterator& lhs, + const DominatorIterator& rhs); + + private: + const BasicBlock* current_; + std::function dom_func_; + }; + + /// Returns a dominator iterator which points to the current block + const DominatorIterator dom_begin() const; + + /// Returns a dominator iterator which points to the current block + DominatorIterator dom_begin(); + + /// Returns a dominator iterator which points to one element past the first + /// block + const DominatorIterator dom_end() const; + + /// Returns a dominator iterator which points to one element past the first + /// block + DominatorIterator dom_end(); + + /// Returns a post dominator iterator which points to the current block + const DominatorIterator pdom_begin() const; + /// Returns a post dominator iterator which points to the current block + DominatorIterator pdom_begin(); + + /// Returns a post dominator iterator which points to one element past the + /// last block + const DominatorIterator pdom_end() const; + + /// Returns a post 
dominator iterator which points to one element past the + /// last block + DominatorIterator pdom_end(); + + private: + /// Id of the BasicBlock + const uint32_t id_; + + /// Pointer to the immediate dominator of the BasicBlock + BasicBlock* immediate_dominator_; + + /// Pointer to the immediate post dominator of the BasicBlock + BasicBlock* immediate_post_dominator_; + + /// The set of predecessors of the BasicBlock + std::vector predecessors_; + + /// The set of successors of the BasicBlock + std::vector successors_; + + /// The type of the block + std::bitset type_; + + /// True if the block is reachable in the CFG + bool reachable_; + + /// label of this block, if any. + const Instruction* label_; + + /// Terminator of this block. + const Instruction* terminator_; +}; + +/// @brief Returns true if the iterators point to the same element or if both +/// iterators point to the @p dom_end block +bool operator==(const BasicBlock::DominatorIterator& lhs, + const BasicBlock::DominatorIterator& rhs); + +/// @brief Returns true if the iterators point to different elements and they +/// do not both point to the @p dom_end block +bool operator!=(const BasicBlock::DominatorIterator& lhs, + const BasicBlock::DominatorIterator& rhs); + +} // namespace val +} // namespace spvtools + +#endif // SOURCE_VAL_BASIC_BLOCK_H_ diff --git a/source/val/construct.cpp b/source/val/construct.cpp new file mode 100644 index 000000000..7b0cb2dcc --- /dev/null +++ b/source/val/construct.cpp @@ -0,0 +1,131 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/construct.h" + +#include +#include +#include + +#include "source/val/function.h" + +namespace spvtools { +namespace val { + +Construct::Construct(ConstructType construct_type, BasicBlock* entry, + BasicBlock* exit, std::vector constructs) + : type_(construct_type), + corresponding_constructs_(constructs), + entry_block_(entry), + exit_block_(exit) {} + +ConstructType Construct::type() const { return type_; } + +const std::vector& Construct::corresponding_constructs() const { + return corresponding_constructs_; +} +std::vector& Construct::corresponding_constructs() { + return corresponding_constructs_; +} + +bool ValidateConstructSize(ConstructType type, size_t size) { + switch (type) { + case ConstructType::kSelection: + return size == 0; + case ConstructType::kContinue: + return size == 1; + case ConstructType::kLoop: + return size == 1; + case ConstructType::kCase: + return size >= 1; + default: + assert(1 == 0 && "Type not defined"); + } + return false; +} + +void Construct::set_corresponding_constructs( + std::vector constructs) { + assert(ValidateConstructSize(type_, constructs.size())); + corresponding_constructs_ = constructs; +} + +const BasicBlock* Construct::entry_block() const { return entry_block_; } +BasicBlock* Construct::entry_block() { return entry_block_; } + +const BasicBlock* Construct::exit_block() const { return exit_block_; } +BasicBlock* Construct::exit_block() { return exit_block_; } + +void Construct::set_exit(BasicBlock* block) { exit_block_ = block; } + +Construct::ConstructBlockSet 
Construct::blocks(Function* function) const { + auto header = entry_block(); + auto merge = exit_block(); + assert(header); + assert(merge); + int header_depth = function->GetBlockDepth(const_cast(header)); + ConstructBlockSet construct_blocks; + std::unordered_set corresponding_headers; + for (auto& other : corresponding_constructs()) { + corresponding_headers.insert(other->entry_block()); + } + std::vector stack; + stack.push_back(const_cast(header)); + while (!stack.empty()) { + BasicBlock* block = stack.back(); + stack.pop_back(); + + if (merge == block && ExitBlockIsMergeBlock()) { + // Merge block is not part of the construct. + continue; + } + + if (corresponding_headers.count(block)) { + // Entered a corresponding construct. + continue; + } + + int block_depth = function->GetBlockDepth(block); + if (block_depth < header_depth) { + // Broke to outer construct. + continue; + } + + // In a loop, the continue target is at a depth of the loop construct + 1. + // A selection construct nested directly within the loop construct is also + // at the same depth. It is valid, however, to branch directly to the + // continue target from within the selection construct. + if (block_depth == header_depth && type() == ConstructType::kSelection && + block->is_type(kBlockTypeContinue)) { + // Continued to outer construct. + continue; + } + + if (!construct_blocks.insert(block).second) continue; + + if (merge != block) { + for (auto succ : *block->successors()) { + // All blocks in the construct must be dominated by the header. + if (header->dominates(*succ)) { + stack.push_back(succ); + } + } + } + } + + return construct_blocks; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/construct.h b/source/val/construct.h new file mode 100644 index 000000000..c7e7a780d --- /dev/null +++ b/source/val/construct.h @@ -0,0 +1,151 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_VAL_CONSTRUCT_H_ +#define SOURCE_VAL_CONSTRUCT_H_ + +#include +#include +#include + +#include "source/val/basic_block.h" + +namespace spvtools { +namespace val { + +/// Functor for ordering BasicBlocks. BasicBlock pointers must not be null. +struct less_than_id { + bool operator()(const BasicBlock* lhs, const BasicBlock* rhs) const { + return lhs->id() < rhs->id(); + } +}; + +enum class ConstructType : int { + kNone = 0, + /// The set of blocks dominated by a selection header, minus the set of blocks + /// dominated by the header's merge block + kSelection, + /// The set of blocks dominated by an OpLoopMerge's Continue Target and post + /// dominated by the corresponding back-edge block + kContinue, + /// The set of blocks dominated by a loop header, minus the set of blocks + /// dominated by the loop's merge block, minus the loop's corresponding + /// continue construct + kLoop, + /// The set of blocks dominated by an OpSwitch's Target or Default, minus the + /// set of blocks dominated by the OpSwitch's merge block (this construct is + /// only defined for those OpSwitch Target or Default that are not equal to + /// the OpSwitch's corresponding merge block) + kCase +}; + +class Function; + +/// @brief This class tracks the CFG constructs as defined in the SPIR-V spec +class Construct { + public: + Construct(ConstructType type, BasicBlock* dominator, + BasicBlock* exit = nullptr, + std::vector
constructs = std::vector()); + + /// Returns the type of the construct + ConstructType type() const; + + const std::vector& corresponding_constructs() const; + std::vector& corresponding_constructs(); + void set_corresponding_constructs(std::vector constructs); + + /// Returns the dominator block of the construct. + /// + /// This is usually the header block or the first block of the construct. + const BasicBlock* entry_block() const; + + /// Returns the dominator block of the construct. + /// + /// This is usually the header block or the first block of the construct. + BasicBlock* entry_block(); + + /// Returns the exit block of the construct. + /// + /// For a continue construct it is the backedge block of the corresponding + /// loop construct. For the case construct it is the block that branches to + /// the OpSwitch merge block or other case blocks. Otherwise it is the merge + /// block of the corresponding header block + const BasicBlock* exit_block() const; + + /// Returns the exit block of the construct. + /// + /// For a continue construct it is the backedge block of the corresponding + /// loop construct. For the case construct it is the block that branches to + /// the OpSwitch merge block or other case blocks. Otherwise it is the merge + /// block of the corresponding header block + BasicBlock* exit_block(); + + /// Sets the exit block for this construct. This is useful for continue + /// constructs which do not know the back-edge block during construction + void set_exit(BasicBlock* exit_block); + + // Returns whether the exit block of this construct is the merge block + // for an OpLoopMerge or OpSelectionMerge + bool ExitBlockIsMergeBlock() const { + return type_ == ConstructType::kLoop || type_ == ConstructType::kSelection; + } + + using ConstructBlockSet = std::set; + + // Returns the basic blocks in this construct. This function should not + // be called before the exit block is set and dominators have been + // calculated. 
+ ConstructBlockSet blocks(Function* function) const; + + private: + /// The type of the construct + ConstructType type_; + + /// These are the constructs that are related to this construct. These + /// constructs can be the continue construct, for the corresponding loop + /// construct, the case construct that are part of the same OpSwitch + /// instruction + /// + /// Here is a table that describes what constructs are included in + /// @p corresponding_constructs_ + /// | this construct | corresponding construct | + /// |----------------|----------------------------------| + /// | loop | continue | + /// | continue | loop | + /// | case | other cases in the same OpSwitch | + /// + /// kContinue and kLoop constructs will always have corresponding + /// constructs even if they are represented by the same block + std::vector corresponding_constructs_; + + /// @brief Dominator block for the construct + /// + /// The dominator block for the construct. Depending on the construct this may + /// be a selection header, a continue target of a loop, a loop header or a + /// Target or Default block of a switch + BasicBlock* entry_block_; + + /// @brief Exiting block for the construct + /// + /// The exit block for the construct. This can be a merge block for the loop + /// and selection constructs, a back-edge block for a continue construct, or + /// the branching block for the case construct + BasicBlock* exit_block_; +}; + +} // namespace val +} // namespace spvtools + +#endif // SOURCE_VAL_CONSTRUCT_H_ diff --git a/source/val/decoration.h b/source/val/decoration.h new file mode 100644 index 000000000..ed3320f87 --- /dev/null +++ b/source/val/decoration.h @@ -0,0 +1,89 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_VAL_DECORATION_H_ +#define SOURCE_VAL_DECORATION_H_ + +#include +#include +#include + +#include "source/latest_version_spirv_header.h" + +namespace spvtools { +namespace val { + +// An object of this class represents a specific decoration including its +// parameters (if any). Decorations are used by OpDecorate and OpMemberDecorate, +// and they describe certain properties that can be assigned to one or several +// Ids. +// +// A Decoration object contains the decoration type (an enum), associated +// literal parameters, and struct member index. If the decoration does not apply +// to a struct member, then the index is kInvalidMember. A Decoration object does +// not store the target Id, i.e. the Id to which it applies. It is +// possible for the same decoration to be applied to several Ids (and they +// might be assigned using separate SPIR-V instructions, possibly using an +// assignment through GroupDecorate).
+// +// Example 1: Decoration for an object with no parameters: +// OpDecorate %obj Flat +// dec_type_ = SpvDecorationFlat +// params_ = empty vector +// struct_member_index_ = kInvalidMember +// +// Example 2: Decoration for an object with two parameters: +// OpDecorate %obj LinkageAttributes "link" Import +// dec_type_ = SpvDecorationLinkageAttributes +// params_ = vector { link, Import } +// struct_member_index_ = kInvalidMember +// +// Example 3: Decoration for a member of a structure with one parameter: +// OpMemberDecorate %struct 2 Offset 2 +// dec_type_ = SpvDecorationOffset +// params_ = vector { 2 } +// struct_member_index_ = 2 +// +class Decoration { + public: + enum { kInvalidMember = -1 }; + Decoration(SpvDecoration t, + const std::vector& parameters = std::vector(), + uint32_t member_index = kInvalidMember) + : dec_type_(t), params_(parameters), struct_member_index_(member_index) {} + + void set_struct_member_index(uint32_t index) { struct_member_index_ = index; } + int struct_member_index() const { return struct_member_index_; } + SpvDecoration dec_type() const { return dec_type_; } + std::vector& params() { return params_; } + const std::vector& params() const { return params_; } + + inline bool operator==(const Decoration& rhs) const { + return (dec_type_ == rhs.dec_type_ && params_ == rhs.params_ && + struct_member_index_ == rhs.struct_member_index_); + } + + private: + SpvDecoration dec_type_; + std::vector params_; + + // If the decoration applies to a member of a structure type, then the index + // of the member is stored here. Otherwise, this is kInvalidMember. + int struct_member_index_; +}; + +} // namespace val +} // namespace spvtools + +#endif // SOURCE_VAL_DECORATION_H_ diff --git a/source/val/function.cpp b/source/val/function.cpp new file mode 100644 index 000000000..f638fb5b4 --- /dev/null +++ b/source/val/function.cpp @@ -0,0 +1,387 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/function.h" + +#include + +#include +#include +#include +#include +#include + +#include "source/cfa.h" +#include "source/val/basic_block.h" +#include "source/val/construct.h" +#include "source/val/validate.h" + +namespace spvtools { +namespace val { + +// Universal Limit of ResultID + 1 +static const uint32_t kInvalidId = 0x400000; + +Function::Function(uint32_t function_id, uint32_t result_type_id, + SpvFunctionControlMask function_control, + uint32_t function_type_id) + : id_(function_id), + function_type_id_(function_type_id), + result_type_id_(result_type_id), + function_control_(function_control), + declaration_type_(FunctionDecl::kFunctionDeclUnknown), + end_has_been_registered_(false), + blocks_(), + current_block_(nullptr), + pseudo_entry_block_(0), + pseudo_exit_block_(kInvalidId), + cfg_constructs_(), + variable_ids_(), + parameter_ids_() {} + +bool Function::IsFirstBlock(uint32_t block_id) const { + return !ordered_blocks_.empty() && *first_block() == block_id; +} + +spv_result_t Function::RegisterFunctionParameter(uint32_t parameter_id, + uint32_t type_id) { + assert(current_block_ == nullptr && + "RegisterFunctionParameter can only be called when parsing the binary " + "ouside of a block"); + // TODO(umar): Validate function parameter type order and count + // TODO(umar): Use these variables to validate parameter type + (void)parameter_id; + (void)type_id; + return 
SPV_SUCCESS; +} + +spv_result_t Function::RegisterLoopMerge(uint32_t merge_id, + uint32_t continue_id) { + RegisterBlock(merge_id, false); + RegisterBlock(continue_id, false); + BasicBlock& merge_block = blocks_.at(merge_id); + BasicBlock& continue_target_block = blocks_.at(continue_id); + assert(current_block_ && + "RegisterLoopMerge must be called when called within a block"); + + current_block_->set_type(kBlockTypeLoop); + merge_block.set_type(kBlockTypeMerge); + continue_target_block.set_type(kBlockTypeContinue); + Construct& loop_construct = + AddConstruct({ConstructType::kLoop, current_block_, &merge_block}); + Construct& continue_construct = + AddConstruct({ConstructType::kContinue, &continue_target_block}); + + continue_construct.set_corresponding_constructs({&loop_construct}); + loop_construct.set_corresponding_constructs({&continue_construct}); + merge_block_header_[&merge_block] = current_block_; + + return SPV_SUCCESS; +} + +spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) { + RegisterBlock(merge_id, false); + BasicBlock& merge_block = blocks_.at(merge_id); + current_block_->set_type(kBlockTypeHeader); + merge_block.set_type(kBlockTypeMerge); + merge_block_header_[&merge_block] = current_block_; + + AddConstruct({ConstructType::kSelection, current_block(), &merge_block}); + + return SPV_SUCCESS; +} + +spv_result_t Function::RegisterSetFunctionDeclType(FunctionDecl type) { + assert(declaration_type_ == FunctionDecl::kFunctionDeclUnknown); + declaration_type_ = type; + return SPV_SUCCESS; +} + +spv_result_t Function::RegisterBlock(uint32_t block_id, bool is_definition) { + assert( + declaration_type_ == FunctionDecl::kFunctionDeclDefinition && + "RegisterBlocks can only be called after declaration_type_ is defined"); + + std::unordered_map::iterator inserted_block; + bool success = false; + tie(inserted_block, success) = + blocks_.insert({block_id, BasicBlock(block_id)}); + if (is_definition) { // new block definition + 
assert(current_block_ == nullptr && + "Register Block can only be called when parsing a binary outside of " + "a BasicBlock"); + + undefined_blocks_.erase(block_id); + current_block_ = &inserted_block->second; + ordered_blocks_.push_back(current_block_); + if (IsFirstBlock(block_id)) current_block_->set_reachable(true); + } else if (success) { // Block doesn't exist but this is not a definition + undefined_blocks_.insert(block_id); + } + + return SPV_SUCCESS; +} + +void Function::RegisterBlockEnd(std::vector next_list, + SpvOp branch_instruction) { + assert( + current_block_ && + "RegisterBlockEnd can only be called when parsing a binary in a block"); + std::vector next_blocks; + next_blocks.reserve(next_list.size()); + + std::unordered_map::iterator inserted_block; + bool success; + for (uint32_t successor_id : next_list) { + tie(inserted_block, success) = + blocks_.insert({successor_id, BasicBlock(successor_id)}); + if (success) { + undefined_blocks_.insert(successor_id); + } + next_blocks.push_back(&inserted_block->second); + } + + if (current_block_->is_type(kBlockTypeLoop)) { + // For each loop header, record the set of its successors, and include + // its continue target if the continue target is not the loop header + // itself.
+ std::vector& next_blocks_plus_continue_target = + loop_header_successors_plus_continue_target_map_[current_block_]; + next_blocks_plus_continue_target = next_blocks; + auto continue_target = + FindConstructForEntryBlock(current_block_, ConstructType::kLoop) + .corresponding_constructs() + .back() + ->entry_block(); + if (continue_target != current_block_) { + next_blocks_plus_continue_target.push_back(continue_target); + } + } + + current_block_->RegisterBranchInstruction(branch_instruction); + current_block_->RegisterSuccessors(next_blocks); + current_block_ = nullptr; + return; +} + +void Function::RegisterFunctionEnd() { + if (!end_has_been_registered_) { + end_has_been_registered_ = true; + + ComputeAugmentedCFG(); + } +} + +size_t Function::block_count() const { return blocks_.size(); } + +size_t Function::undefined_block_count() const { + return undefined_blocks_.size(); +} + +const std::vector& Function::ordered_blocks() const { + return ordered_blocks_; +} +std::vector& Function::ordered_blocks() { return ordered_blocks_; } + +const BasicBlock* Function::current_block() const { return current_block_; } +BasicBlock* Function::current_block() { return current_block_; } + +const std::list& Function::constructs() const { + return cfg_constructs_; +} +std::list& Function::constructs() { return cfg_constructs_; } + +const BasicBlock* Function::first_block() const { + if (ordered_blocks_.empty()) return nullptr; + return ordered_blocks_[0]; +} +BasicBlock* Function::first_block() { + if (ordered_blocks_.empty()) return nullptr; + return ordered_blocks_[0]; +} + +bool Function::IsBlockType(uint32_t merge_block_id, BlockType type) const { + bool ret = false; + const BasicBlock* block; + std::tie(block, std::ignore) = GetBlock(merge_block_id); + if (block) { + ret = block->is_type(type); + } + return ret; +} + +std::pair Function::GetBlock(uint32_t block_id) const { + const auto b = blocks_.find(block_id); + if (b != end(blocks_)) { + const BasicBlock* block = 
&(b->second); + bool defined = + undefined_blocks_.find(block->id()) == std::end(undefined_blocks_); + return std::make_pair(block, defined); + } else { + return std::make_pair(nullptr, false); + } +} + +std::pair Function::GetBlock(uint32_t block_id) { + const BasicBlock* out; + bool defined; + std::tie(out, defined) = + const_cast(this)->GetBlock(block_id); + return std::make_pair(const_cast(out), defined); +} + +Function::GetBlocksFunction Function::AugmentedCFGSuccessorsFunction() const { + return [this](const BasicBlock* block) { + auto where = augmented_successors_map_.find(block); + return where == augmented_successors_map_.end() ? block->successors() + : &(*where).second; + }; +} + +Function::GetBlocksFunction +Function::AugmentedCFGSuccessorsFunctionIncludingHeaderToContinueEdge() const { + return [this](const BasicBlock* block) { + auto where = loop_header_successors_plus_continue_target_map_.find(block); + return where == loop_header_successors_plus_continue_target_map_.end() + ? AugmentedCFGSuccessorsFunction()(block) + : &(*where).second; + }; +} + +Function::GetBlocksFunction Function::AugmentedCFGPredecessorsFunction() const { + return [this](const BasicBlock* block) { + auto where = augmented_predecessors_map_.find(block); + return where == augmented_predecessors_map_.end() ? block->predecessors() + : &(*where).second; + }; +} + +void Function::ComputeAugmentedCFG() { + // Compute the successors of the pseudo-entry block, and + // the predecessors of the pseudo exit block. 
+ auto succ_func = [](const BasicBlock* b) { return b->successors(); }; + auto pred_func = [](const BasicBlock* b) { return b->predecessors(); }; + CFA::ComputeAugmentedCFG( + ordered_blocks_, &pseudo_entry_block_, &pseudo_exit_block_, + &augmented_successors_map_, &augmented_predecessors_map_, succ_func, + pred_func); +} + +Construct& Function::AddConstruct(const Construct& new_construct) { + cfg_constructs_.push_back(new_construct); + auto& result = cfg_constructs_.back(); + entry_block_to_construct_[std::make_pair(new_construct.entry_block(), + new_construct.type())] = &result; + return result; +} + +Construct& Function::FindConstructForEntryBlock(const BasicBlock* entry_block, + ConstructType type) { + auto where = + entry_block_to_construct_.find(std::make_pair(entry_block, type)); + assert(where != entry_block_to_construct_.end()); + auto construct_ptr = (*where).second; + assert(construct_ptr); + return *construct_ptr; +} + +int Function::GetBlockDepth(BasicBlock* bb) { + // Guard against nullptr. + if (!bb) { + return 0; + } + // Only calculate the depth if it's not already calculated. + // This function uses memoization to avoid duplicate CFG depth calculations. + if (block_depth_.find(bb) != block_depth_.end()) { + return block_depth_[bb]; + } + + BasicBlock* bb_dom = bb->immediate_dominator(); + if (!bb_dom || bb == bb_dom) { + // This block has no dominator, so it's at depth 0. + block_depth_[bb] = 0; + } else if (bb->is_type(kBlockTypeMerge)) { + // If this is a merge block, its depth is equal to the block before + // branching. + BasicBlock* header = merge_block_header_[bb]; + assert(header); + block_depth_[bb] = GetBlockDepth(header); + } else if (bb->is_type(kBlockTypeContinue)) { + // The depth of the continue block entry point is 1 + loop header depth. 
+ Construct* continue_construct = + entry_block_to_construct_[std::make_pair(bb, ConstructType::kContinue)]; + assert(continue_construct); + // Continue construct has only 1 corresponding construct (loop header). + Construct* loop_construct = + continue_construct->corresponding_constructs()[0]; + assert(loop_construct); + BasicBlock* loop_header = loop_construct->entry_block(); + // The continue target may be the loop itself (while 1). + // In such cases, the depth of the continue block is: 1 + depth of the + // loop's dominator block. + if (loop_header == bb) { + block_depth_[bb] = 1 + GetBlockDepth(bb_dom); + } else { + block_depth_[bb] = 1 + GetBlockDepth(loop_header); + } + } else if (bb_dom->is_type(kBlockTypeHeader) || + bb_dom->is_type(kBlockTypeLoop)) { + // The dominator of the given block is a header block. So, the nesting + // depth of this block is: 1 + nesting depth of the header. + block_depth_[bb] = 1 + GetBlockDepth(bb_dom); + } else { + block_depth_[bb] = GetBlockDepth(bb_dom); + } + return block_depth_[bb]; +} + +void Function::RegisterExecutionModelLimitation(SpvExecutionModel model, + const std::string& message) { + execution_model_limitations_.push_back( + [model, message](SpvExecutionModel in_model, std::string* out_message) { + if (model != in_model) { + if (out_message) { + *out_message = message; + } + return false; + } + return true; + }); +} + +bool Function::IsCompatibleWithExecutionModel(SpvExecutionModel model, + std::string* reason) const { + bool return_value = true; + std::stringstream ss_reason; + + for (const auto& is_compatible : execution_model_limitations_) { + std::string message; + if (!is_compatible(model, &message)) { + if (!reason) return false; + return_value = false; + if (!message.empty()) { + ss_reason << message << "\n"; + } + } + } + + if (!return_value && reason) { + *reason = ss_reason.str(); + } + + return return_value; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/function.h 
b/source/val/function.h new file mode 100644 index 000000000..a052bbda0 --- /dev/null +++ b/source/val/function.h @@ -0,0 +1,360 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_VAL_FUNCTION_H_ +#define SOURCE_VAL_FUNCTION_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/latest_version_spirv_header.h" +#include "source/val/basic_block.h" +#include "source/val/construct.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace val { + +struct bb_constr_type_pair_hash { + std::size_t operator()( + const std::pair& p) const { + auto h1 = std::hash{}(p.first); + auto h2 = std::hash::type>{}( + static_cast::type>(p.second)); + return (h1 ^ h2); + } +}; + +enum class FunctionDecl { + kFunctionDeclUnknown, /// < Unknown function declaration + kFunctionDeclDeclaration, /// < Function declaration + kFunctionDeclDefinition /// < Function definition +}; + +/// This class manages all function declaration and definitions in a module. It +/// handles the state and id information while parsing a function in the SPIR-V +/// binary. 
+class Function { + public: + Function(uint32_t id, uint32_t result_type_id, + SpvFunctionControlMask function_control, uint32_t function_type_id); + + /// Registers a function parameter in the current function + /// @return Returns SPV_SUCCESS if the call was successful + spv_result_t RegisterFunctionParameter(uint32_t id, uint32_t type_id); + + /// Sets the declaration type of the current function + /// @return Returns SPV_SUCCESS if the call was successful + spv_result_t RegisterSetFunctionDeclType(FunctionDecl type); + + /// Registers a block in the current function. Subsequent block instructions + /// will target this block + /// @param id The ID of the label of the block + /// @param is_definition True when the block is being defined, in which case + /// it becomes the current block. False when the block is only referenced + /// (for example as a merge or continue target); a block first seen this + /// way is recorded as undefined until its definition is registered + /// @return Returns SPV_SUCCESS if the call was successful + spv_result_t RegisterBlock(uint32_t id, bool is_definition = true); + + /// Registers a variable in the current block + /// + /// @param[in] type_id The type ID of the variable + /// @param[in] id The ID of the variable + /// @param[in] storage The storage of the variable + /// @param[in] init_id The initializer ID of the variable + /// + /// @return Returns SPV_SUCCESS if the call was successful + spv_result_t RegisterBlockVariable(uint32_t type_id, uint32_t id, + SpvStorageClass storage, uint32_t init_id); + + /// Registers a loop merge construct in the function + /// + /// @param[in] merge_id The merge block ID of the loop + /// @param[in] continue_id The continue block ID of the loop + /// + /// @return Returns SPV_SUCCESS if the call was successful + spv_result_t RegisterLoopMerge(uint32_t merge_id, uint32_t continue_id); + + /// Registers a selection merge construct in the function + /// @param[in] merge_id The merge block ID of the selection + /// @return Returns SPV_SUCCESS if the call was successful + spv_result_t RegisterSelectionMerge(uint32_t merge_id); + + /// Registers the end of the block + /// + /// @param[in] successors_list A list of ids to the block's successors + /// @param[in] branch_instruction the branch instruction that ended the block + void
RegisterBlockEnd(std::vector successors_list, + SpvOp branch_instruction); + + /// Registers the end of the function. This is idempotent. + void RegisterFunctionEnd(); + + /// Returns true if the \p id block is the first block of this function + bool IsFirstBlock(uint32_t id) const; + + /// Returns true if the \p merge_block_id is a BlockType of \p type + bool IsBlockType(uint32_t merge_block_id, BlockType type) const; + + /// Returns a pair consisting of the BasicBlock with \p id and a bool + /// which is true if the block has been defined, and false if it is + /// declared but not defined. This function will return nullptr if the + /// \p id was not declared and not defined at the current point in the binary + std::pair GetBlock(uint32_t id) const; + std::pair GetBlock(uint32_t id); + + /// Returns the first block of the current function, or nullptr if no block + /// has been defined yet + const BasicBlock* first_block() const; + + /// Returns the first block of the current function, or nullptr if no block + /// has been defined yet + BasicBlock* first_block(); + + /// Returns a vector of all the blocks in the function + const std::vector& ordered_blocks() const; + + /// Returns a vector of all the blocks in the function + std::vector& ordered_blocks(); + + /// Returns a list of all the cfg constructs in the function + const std::list& constructs() const; + + /// Returns a list of all the cfg constructs in the function + std::list& constructs(); + + /// Returns the number of blocks in the current function being parsed + size_t block_count() const; + + /// Returns the id of the function + uint32_t id() const { return id_; } + + /// Returns return type id of the function + uint32_t GetResultTypeId() const { return result_type_id_; } + + /// Returns the number of blocks in the current function that have been + /// referenced but not yet defined + size_t undefined_block_count() const; + /// Returns the ids of the blocks that have been referenced but not yet + /// defined + const std::unordered_set& undefined_blocks() const { + return undefined_blocks_; + } + + /// Returns the block that is currently being parsed in the binary + BasicBlock* current_block(); + + /// Returns the block that is
currently being parsed in the binary + const BasicBlock* current_block() const; + + // For dominance calculations, we want to analyze all the + // blocks in the function, even in degenerate control flow cases + // including unreachable blocks. We therefore make an "augmented CFG" + // which is the same as the ordinary CFG but adds: + // - A pseudo-entry node. + // - A pseudo-exit node. + // - A minimal set of edges so that a forward traversal from the + // pseudo-entry node will visit all nodes. + // - A minimal set of edges so that a backward traversal from the + // pseudo-exit node will visit all nodes. + // In particular, the pseudo-entry node is the unique source of the + // augmented CFG, and the pseudo-exit node is the unique sink of the + // augmented CFG. + + /// Returns the pseudo entry block + BasicBlock* pseudo_entry_block() { return &pseudo_entry_block_; } + + /// Returns the pseudo entry block + const BasicBlock* pseudo_entry_block() const { return &pseudo_entry_block_; } + + /// Returns the pseudo exit block + BasicBlock* pseudo_exit_block() { return &pseudo_exit_block_; } + + /// Returns the pseudo exit block + const BasicBlock* pseudo_exit_block() const { return &pseudo_exit_block_; } + + using GetBlocksFunction = + std::function*(const BasicBlock*)>; + /// Returns the block successors function for the augmented CFG. + GetBlocksFunction AugmentedCFGSuccessorsFunction() const; + /// Like AugmentedCFGSuccessorsFunction, but also includes a forward edge from + /// a loop header block to its continue target, if they are different blocks. + GetBlocksFunction + AugmentedCFGSuccessorsFunctionIncludingHeaderToContinueEdge() const; + /// Returns the block predecessors function for the augmented CFG. + GetBlocksFunction AugmentedCFGPredecessorsFunction() const; + + /// Returns the control flow nesting depth of the given basic block. + /// This function only works when you have structured control flow.
+ /// This function should only be called after the control flow constructs have + /// been identified and dominators have been computed. + /// Returns 0 when \p bb is null; computed depths are cached for later calls. + int GetBlockDepth(BasicBlock* bb); + + /// Prints a GraphViz digraph of the CFG of the current function + void PrintDotGraph() const; + + /// Prints a directed graph of the CFG of the current function + void PrintBlocks() const; + + /// Registers execution model limitation such as "Feature X is only available + /// with Execution Model Y". + void RegisterExecutionModelLimitation(SpvExecutionModel model, + const std::string& message); + + /// Registers execution model limitation with an |is_compatible| functor. + void RegisterExecutionModelLimitation( + std::function is_compatible) { + execution_model_limitations_.push_back(is_compatible); + } + + /// Returns true if the given execution model passes the limitations stored in + /// execution_model_limitations_. Returns false otherwise and fills optional + /// |reason| parameter. + bool IsCompatibleWithExecutionModel(SpvExecutionModel model, + std::string* reason = nullptr) const; + + // Inserts id to the set of functions called from this function. + void AddFunctionCallTarget(uint32_t call_target_id) { + function_call_targets_.insert(call_target_id); + } + + // Returns a set with ids of all functions called from this function. + // The set is returned by value, so every call makes a copy of it. + const std::set function_call_targets() const { + return function_call_targets_; + } + + private: + // Computes the representation of the augmented CFG. + // Populates augmented_successors_map_ and augmented_predecessors_map_. + void ComputeAugmentedCFG(); + + // Adds a copy of the given Construct, and tracks it by its entry block. + // Returns a reference to the stored construct. + Construct& AddConstruct(const Construct& new_construct); + + // Returns a reference to the construct corresponding to the given entry + // block.
+ Construct& FindConstructForEntryBlock(const BasicBlock* entry_block, + ConstructType t); + + /// The result id of the function + uint32_t id_; + + /// The type of the function + uint32_t function_type_id_; + + /// The type of the return value + uint32_t result_type_id_; + + /// The function control mask of the function + SpvFunctionControlMask function_control_; + + /// The type of declaration of this function + FunctionDecl declaration_type_; + + // Have we finished parsing this function? + bool end_has_been_registered_; + + /// The blocks in the function mapped by block ID + std::unordered_map blocks_; + + /// A list of blocks in the order they appeared in the binary + std::vector ordered_blocks_; + + /// Blocks which are forward referenced by blocks but not defined + std::unordered_set undefined_blocks_; + + /// The block that is currently being parsed, or nullptr when outside a block + BasicBlock* current_block_; + + /// A pseudo entry node used in dominance analysis. + /// After the function end has been registered, the successor list of the + /// pseudo entry node is the minimal set of nodes such that all nodes in the + /// CFG can be reached by following successor lists. That is, the successors + /// will be: + /// - Any basic block without predecessors. This includes the entry + /// block to the function. + /// - A single node from each otherwise unreachable cycle in the CFG, if + /// such cycles exist. + /// The pseudo entry node does not appear in the predecessor or successor + /// list of any ordinary block. + /// It has no predecessors. + /// It has Id 0. + BasicBlock pseudo_entry_block_; + + /// A pseudo exit block used in dominance analysis. + /// After the function end has been registered, the predecessor list of the + /// pseudo exit node is the minimal set of nodes such that all nodes in the + /// CFG can be reached by following predecessor lists. That is, the + /// predecessors will be: + /// - Any basic block without successors.
This includes any basic block + /// ending with an OpReturn, OpReturnValue or similar instructions. + /// - A single node from each otherwise unreachable cycle in the CFG, if + /// such cycles exist. + /// The pseudo exit node does not appear in the predecessor or successor + /// list of any ordinary block. + /// It has no successors. + BasicBlock pseudo_exit_block_; + + // Maps a block to its successors in the augmented CFG, if that set is + // different from its successors in the ordinary CFG. + std::unordered_map> + augmented_successors_map_; + // Maps a block to its predecessors in the augmented CFG, if that set is + // different from its predecessors in the ordinary CFG. + std::unordered_map> + augmented_predecessors_map_; + + // Maps a structured loop header to its CFG successors and also its + // continue target if that continue target is not the loop header + // itself. This might have duplicates. + std::unordered_map> + loop_header_successors_plus_continue_target_map_; + + /// The constructs that are available in this function + std::list cfg_constructs_; + + /// The variable IDs of the functions + std::vector variable_ids_; + + /// The function parameter ids of the functions + std::vector parameter_ids_; + + /// Maps a construct's entry block to the construct(s). + /// Since a basic block may be the entry block of different types of + /// constructs, the type of the construct should also be specified in order to + /// get the unique construct. + std::unordered_map, Construct*, + bb_constr_type_pair_hash> + entry_block_to_construct_; + + /// This map provides the header block for a given merge block. + std::unordered_map merge_block_header_; + + /// Stores the control flow nesting depth of a given basic block + std::unordered_map block_depth_; + + /// Stores execution model limitations imposed by instructions used within the + /// function. The functor stored in the list return true if execution model + /// is compatible, false otherwise. 
If the functor returns false, it can also + /// optionally fill the string parameter with the reason for incompatibility. + std::list> + execution_model_limitations_; + + /// Stores ids of all functions called from this function. + std::set function_call_targets_; +}; + +} // namespace val +} // namespace spvtools + +#endif // SOURCE_VAL_FUNCTION_H_ diff --git a/source/val/instruction.cpp b/source/val/instruction.cpp new file mode 100644 index 000000000..b9155898a --- /dev/null +++ b/source/val/instruction.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/val/instruction.h" + +#include + +namespace spvtools { +namespace val { + +Instruction::Instruction(const spv_parsed_instruction_t* inst) + : words_(inst->words, inst->words + inst->num_words), + operands_(inst->operands, inst->operands + inst->num_operands), + inst_({words_.data(), inst->num_words, inst->opcode, inst->ext_inst_type, + inst->type_id, inst->result_id, operands_.data(), + inst->num_operands}) {} + +void Instruction::RegisterUse(const Instruction* inst, uint32_t index) { + uses_.push_back(std::make_pair(inst, index)); +} + +bool operator<(const Instruction& lhs, const Instruction& rhs) { + return lhs.id() < rhs.id(); +} +bool operator<(const Instruction& lhs, uint32_t rhs) { return lhs.id() < rhs; } +bool operator==(const Instruction& lhs, const Instruction& rhs) { + return lhs.id() == rhs.id(); +} +bool operator==(const Instruction& lhs, uint32_t rhs) { + return lhs.id() == rhs; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/instruction.h b/source/val/instruction.h new file mode 100644 index 000000000..1fa855fca --- /dev/null +++ b/source/val/instruction.h @@ -0,0 +1,140 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_VAL_INSTRUCTION_H_ +#define SOURCE_VAL_INSTRUCTION_H_ + +#include +#include +#include +#include +#include + +#include "source/table.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace val { + +class BasicBlock; +class Function; + +/// Wraps the spv_parsed_instruction struct along with use and definition of the +/// instruction's result id +class Instruction { + public: + explicit Instruction(const spv_parsed_instruction_t* inst); + + /// Registers the use of the Instruction in instruction \p inst at \p index + void RegisterUse(const Instruction* inst, uint32_t index); + + uint32_t id() const { return inst_.result_id; } + uint32_t type_id() const { return inst_.type_id; } + SpvOp opcode() const { return static_cast(inst_.opcode); } + + /// Returns the Function where the instruction was defined. nullptr if it was + /// defined outside of a Function + const Function* function() const { return function_; } + void set_function(Function* func) { function_ = func; } + + /// Returns the BasicBlock where the instruction was defined. nullptr if it + /// was defined outside of a BasicBlock + const BasicBlock* block() const { return block_; } + void set_block(BasicBlock* b) { block_ = b; } + + /// Returns a vector of pairs of all references to this instruction's result + /// id. The first element is the instruction in which this result id was + /// referenced and the second is the index of the word in that instruction + /// where this result id appeared + const std::vector>& uses() const { + return uses_; + } + + /// The word used to define the Instruction + uint32_t word(size_t index) const { return words_[index]; } + + /// The words used to define the Instruction + const std::vector& words() const { return words_; } + + /// Returns the operand at |idx|. 
+ const spv_parsed_operand_t& operand(size_t idx) const { + return operands_[idx]; + } + + /// The operands of the Instruction + const std::vector& operands() const { + return operands_; + } + + /// Provides direct access to the stored C instruction object. + const spv_parsed_instruction_t& c_inst() const { return inst_; } + + /// Provides direct access to instructions spv_ext_inst_type_t object. + const spv_ext_inst_type_t& ext_inst_type() const { + return inst_.ext_inst_type; + } + + // Casts the words belonging to the operand under |index| to |T| and returns. + template + T GetOperandAs(size_t index) const { + const spv_parsed_operand_t& o = operands_.at(index); + assert(o.num_words * 4 >= sizeof(T)); + assert(o.offset + o.num_words <= inst_.num_words); + return *reinterpret_cast(&words_[o.offset]); + } + + size_t LineNum() const { return line_num_; } + void SetLineNum(size_t pos) { line_num_ = pos; } + + private: + const std::vector words_; + const std::vector operands_; + spv_parsed_instruction_t inst_; + size_t line_num_ = 0; + + /// The function in which this instruction was declared + Function* function_ = nullptr; + + /// The basic block in which this instruction was declared + BasicBlock* block_ = nullptr; + + /// This is a vector of pairs of all references to this instruction's result + /// id. 
The first element is the instruction in which this result id was + /// referenced and the second is the index of the word in the referencing + /// instruction where this instruction appeared + std::vector> uses_; +}; + +bool operator<(const Instruction& lhs, const Instruction& rhs); +bool operator<(const Instruction& lhs, uint32_t rhs); +bool operator==(const Instruction& lhs, const Instruction& rhs); +bool operator==(const Instruction& lhs, uint32_t rhs); + +} // namespace val +} // namespace spvtools + +// custom specialization of std::hash for Instruction +namespace std { +template <> +struct hash { + typedef spvtools::val::Instruction argument_type; + typedef std::size_t result_type; + result_type operator()(const argument_type& inst) const { + return hash()(inst.id()); + } +}; + +} // namespace std + +#endif // SOURCE_VAL_INSTRUCTION_H_ diff --git a/source/val/validate.cpp b/source/val/validate.cpp new file mode 100644 index 000000000..9797d31ad --- /dev/null +++ b/source/val/validate.cpp @@ -0,0 +1,531 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/val/validate.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "source/spirv_endian.h" +#include "source/spirv_target_env.h" +#include "source/spirv_validator_options.h" +#include "source/val/construct.h" +#include "source/val/function.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" +#include "spirv-tools/libspirv.h" + +namespace { +// TODO(issue 1950): The validator only returns a single message anyway, so no +// point in generating more than 1 warning. +static uint32_t kDefaultMaxNumOfWarnings = 1; +} // namespace + +namespace spvtools { +namespace val { +namespace { + +// TODO(umar): Validate header +// TODO(umar): The binary parser validates the magic word, and the length of the +// header, but nothing else. +spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t, + uint32_t version, uint32_t generator, uint32_t id_bound, + uint32_t) { + // Record the ID bound so that the validator can ensure no ID is out of bound. + ValidationState_t& _ = *(reinterpret_cast(user_data)); + _.setIdBound(id_bound); + _.setGenerator(generator); + _.setVersion(version); + + return SPV_SUCCESS; +} + +// Parses OpExtension instruction and registers extension. +void RegisterExtension(ValidationState_t& _, + const spv_parsed_instruction_t* inst) { + const std::string extension_str = spvtools::GetExtensionString(inst); + Extension extension; + if (!GetExtensionFromString(extension_str.c_str(), &extension)) { + // The error will be logged in the ProcessInstruction pass. + return; + } + + _.RegisterExtension(extension); +} + +// Parses the beginning of the module searching for OpExtension instructions. 
+// Registers extensions if recognized. Returns SPV_REQUESTED_TERMINATION +// once an instruction which is not SpvOpCapability and SpvOpExtension is +// encountered. According to the SPIR-V spec extensions are declared after +// capabilities and before everything else. +spv_result_t ProcessExtensions(void* user_data, + const spv_parsed_instruction_t* inst) { + const SpvOp opcode = static_cast(inst->opcode); + if (opcode == SpvOpCapability) return SPV_SUCCESS; + + if (opcode == SpvOpExtension) { + ValidationState_t& _ = *(reinterpret_cast(user_data)); + RegisterExtension(_, inst); + return SPV_SUCCESS; + } + + // OpExtension block is finished, requesting termination. + return SPV_REQUESTED_TERMINATION; +} + +spv_result_t ProcessInstruction(void* user_data, + const spv_parsed_instruction_t* inst) { + ValidationState_t& _ = *(reinterpret_cast(user_data)); + + auto* instruction = _.AddOrderedInstruction(inst); + _.RegisterDebugInstruction(instruction); + + return SPV_SUCCESS; +} + +void printDot(const ValidationState_t& _, const BasicBlock& other) { + std::string block_string; + if (other.successors()->empty()) { + block_string += "end "; + } else { + for (auto block : *other.successors()) { + block_string += _.getIdName(block->id()) + " "; + } + } + printf("%10s -> {%s\b}\n", _.getIdName(other.id()).c_str(), + block_string.c_str()); +} + +void PrintBlocks(ValidationState_t& _, Function func) { + assert(func.first_block()); + + printf("%10s -> %s\n", _.getIdName(func.id()).c_str(), + _.getIdName(func.first_block()->id()).c_str()); + for (const auto& block : func.ordered_blocks()) { + printDot(_, *block); + } +} + +#ifdef __clang__ +#define UNUSED(func) [[gnu::unused]] func +#elif defined(__GNUC__) +#define UNUSED(func) \ + func __attribute__((unused)); \ + func +#elif defined(_MSC_VER) +#define UNUSED(func) func +#endif + +UNUSED(void PrintDotGraph(ValidationState_t& _, Function func)) { + if (func.first_block()) { + std::string func_name(_.getIdName(func.id())); + 
printf("digraph %s {\n", func_name.c_str()); + PrintBlocks(_, func); + printf("}\n"); + } +} + +spv_result_t ValidateForwardDecls(ValidationState_t& _) { + if (_.unresolved_forward_id_count() == 0) return SPV_SUCCESS; + + std::stringstream ss; + std::vector ids = _.UnresolvedForwardIds(); + + std::transform( + std::begin(ids), std::end(ids), + std::ostream_iterator(ss, " "), + bind(&ValidationState_t::getIdName, std::ref(_), std::placeholders::_1)); + + auto id_str = ss.str(); + return _.diag(SPV_ERROR_INVALID_ID, nullptr) + << "The following forward referenced IDs have not been defined:\n" + << id_str.substr(0, id_str.size() - 1); +} + +std::vector CalculateNamesForEntryPoint(ValidationState_t& _, + const uint32_t id) { + auto id_descriptions = _.entry_point_descriptions(id); + auto id_names = std::vector(); + id_names.reserve((id_descriptions.size())); + + for (auto description : id_descriptions) id_names.push_back(description.name); + + return id_names; +} + +spv_result_t ValidateEntryPointNameUnique(ValidationState_t& _, + const uint32_t id) { + auto id_names = CalculateNamesForEntryPoint(_, id); + const auto names = + std::unordered_set(id_names.begin(), id_names.end()); + + if (id_names.size() != names.size()) { + std::sort(id_names.begin(), id_names.end()); + for (size_t i = 0; i < id_names.size() - 1; i++) { + if (id_names[i] == id_names[i + 1]) { + return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(id)) + << "Entry point name \"" << id_names[i] + << "\" is not unique, which is not allow in WebGPU env."; + } + } + } + + for (const auto other_id : _.entry_points()) { + if (other_id == id) continue; + const auto other_id_names = CalculateNamesForEntryPoint(_, other_id); + for (const auto other_id_name : other_id_names) { + if (names.find(other_id_name) != names.end()) { + return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(id)) + << "Entry point name \"" << other_id_name + << "\" is not unique, which is not allow in WebGPU env."; + } + } + } + + return 
SPV_SUCCESS; +} + +spv_result_t ValidateEntryPointNamesUnique(ValidationState_t& _) { + for (const auto id : _.entry_points()) { + auto result = ValidateEntryPointNameUnique(_, id); + if (result != SPV_SUCCESS) return result; + } + return SPV_SUCCESS; +} + +// Entry point validation. Based on 2.16.1 (Universal Validation Rules) of the +// SPIRV spec: +// * There is at least one OpEntryPoint instruction, unless the Linkage +// capability is being used. +// * No function can be targeted by both an OpEntryPoint instruction and an +// OpFunctionCall instruction. +// +// Additionally enforces that entry points for Vulkan and WebGPU should not have +// recursion. And that entry names should be unique for WebGPU. +spv_result_t ValidateEntryPoints(ValidationState_t& _) { + _.ComputeFunctionToEntryPointMapping(); + _.ComputeRecursiveEntryPoints(); + + if (_.entry_points().empty() && !_.HasCapability(SpvCapabilityLinkage)) { + return _.diag(SPV_ERROR_INVALID_BINARY, nullptr) + << "No OpEntryPoint instruction was found. This is only allowed if " + "the Linkage capability is being used."; + } + + for (const auto& entry_point : _.entry_points()) { + if (_.IsFunctionCallTarget(entry_point)) { + return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point)) + << "A function (" << entry_point + << ") may not be targeted by both an OpEntryPoint instruction and " + "an OpFunctionCall instruction."; + } + + // For Vulkan and WebGPU, the static function-call graph for an entry point + // must not contain cycles. + if (spvIsWebGPUEnv(_.context()->target_env) || + spvIsVulkanEnv(_.context()->target_env)) { + if (_.recursive_entry_points().find(entry_point) != + _.recursive_entry_points().end()) { + return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point)) + << "Entry points may not have a call graph with cycles."; + } + } + + // For WebGPU all entry point names must be unique. 
+ if (spvIsWebGPUEnv(_.context()->target_env)) { + const auto result = ValidateEntryPointNamesUnique(_); + if (result != SPV_SUCCESS) return result; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateBinaryUsingContextAndValidationState( + const spv_context_t& context, const uint32_t* words, const size_t num_words, + spv_diagnostic* pDiagnostic, ValidationState_t* vstate) { + auto binary = std::unique_ptr( + new spv_const_binary_t{words, num_words}); + + spv_endianness_t endian; + spv_position_t position = {}; + if (spvBinaryEndianness(binary.get(), &endian)) { + return DiagnosticStream(position, context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Invalid SPIR-V magic number."; + } + + spv_header_t header; + if (spvBinaryHeaderGet(binary.get(), endian, &header)) { + return DiagnosticStream(position, context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Invalid SPIR-V header."; + } + + if (header.version > spvVersionForTargetEnv(context.target_env)) { + return DiagnosticStream(position, context.consumer, "", + SPV_ERROR_WRONG_VERSION) + << "Invalid SPIR-V binary version " + << SPV_SPIRV_VERSION_MAJOR_PART(header.version) << "." + << SPV_SPIRV_VERSION_MINOR_PART(header.version) + << " for target environment " + << spvTargetEnvDescription(context.target_env) << "."; + } + + if (header.bound > vstate->options()->universal_limits_.max_id_bound) { + return DiagnosticStream(position, context.consumer, "", + SPV_ERROR_INVALID_BINARY) + << "Invalid SPIR-V. The id bound is larger than the max id bound " + << vstate->options()->universal_limits_.max_id_bound << "."; + } + + // Look for OpExtension instructions and register extensions. + // This parse should not produce any error messages. Hijack the context and + // replace the message consumer so that we do not pollute any state in input + // consumer. 
+ spv_context_t hijacked_context = context; + hijacked_context.consumer = [](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}; + spvBinaryParse(&hijacked_context, vstate, words, num_words, + /* parsed_header = */ nullptr, ProcessExtensions, + /* diagnostic = */ nullptr); + + // Parse the module and perform inline validation checks. These checks do + // not require the the knowledge of the whole module. + if (auto error = spvBinaryParse(&context, vstate, words, num_words, setHeader, + ProcessInstruction, pDiagnostic)) { + return error; + } + + for (auto& instruction : vstate->ordered_instructions()) { + { + // In order to do this work outside of Process Instruction we need to be + // able to, briefly, de-const the instruction. + Instruction* inst = const_cast(&instruction); + + if (inst->opcode() == SpvOpEntryPoint) { + const auto entry_point = inst->GetOperandAs(1); + const auto execution_model = inst->GetOperandAs(0); + const char* str = reinterpret_cast( + inst->words().data() + inst->operand(2).offset); + + ValidationState_t::EntryPointDescription desc; + desc.name = str; + + std::vector interfaces; + for (size_t j = 3; j < inst->operands().size(); ++j) + desc.interfaces.push_back(inst->word(inst->operand(j).offset)); + + vstate->RegisterEntryPoint(entry_point, execution_model, + std::move(desc)); + } + if (inst->opcode() == SpvOpFunctionCall) { + if (!vstate->in_function_body()) { + return vstate->diag(SPV_ERROR_INVALID_LAYOUT, &instruction) + << "A FunctionCall must happen within a function body."; + } + + const auto called_id = inst->GetOperandAs(2); + if (spvIsWebGPUEnv(context.target_env) && + !vstate->IsFunctionCallDefined(called_id)) { + return vstate->diag(SPV_ERROR_INVALID_LAYOUT, &instruction) + << "For WebGPU, functions need to be defined before being " + "called."; + } + + vstate->AddFunctionCallTarget(called_id); + } + + if (vstate->in_function_body()) { + inst->set_function(&(vstate->current_function())); + 
inst->set_block(vstate->current_function().current_block()); + + if (vstate->in_block() && spvOpcodeIsBlockTerminator(inst->opcode())) { + vstate->current_function().current_block()->set_terminator(inst); + } + } + + if (auto error = IdPass(*vstate, inst)) return error; + } + + if (auto error = CapabilityPass(*vstate, &instruction)) return error; + if (auto error = DataRulesPass(*vstate, &instruction)) return error; + if (auto error = ModuleLayoutPass(*vstate, &instruction)) return error; + if (auto error = CfgPass(*vstate, &instruction)) return error; + if (auto error = InstructionPass(*vstate, &instruction)) return error; + + // Now that all of the checks are done, update the state. + { + Instruction* inst = const_cast(&instruction); + vstate->RegisterInstruction(inst); + } + if (auto error = UpdateIdUse(*vstate, &instruction)) return error; + } + + if (!vstate->has_memory_model_specified()) + return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr) + << "Missing required OpMemoryModel instruction."; + + if (vstate->in_function_body()) + return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr) + << "Missing OpFunctionEnd at end of module."; + + // Catch undefined forward references before performing further checks. + if (auto error = ValidateForwardDecls(*vstate)) return error; + + // Validate individual opcodes. + for (size_t i = 0; i < vstate->ordered_instructions().size(); ++i) { + auto& instruction = vstate->ordered_instructions()[i]; + + // Keep these passes in the order they appear in the SPIR-V specification + // sections to maintain test consistency. 
+ // Miscellaneous + if (auto error = DebugPass(*vstate, &instruction)) return error; + if (auto error = AnnotationPass(*vstate, &instruction)) return error; + if (auto error = ExtensionPass(*vstate, &instruction)) return error; + if (auto error = ModeSettingPass(*vstate, &instruction)) return error; + if (auto error = TypePass(*vstate, &instruction)) return error; + if (auto error = ConstantPass(*vstate, &instruction)) return error; + if (auto error = MemoryPass(*vstate, &instruction)) return error; + if (auto error = FunctionPass(*vstate, &instruction)) return error; + if (auto error = ImagePass(*vstate, &instruction)) return error; + if (auto error = ConversionPass(*vstate, &instruction)) return error; + if (auto error = CompositesPass(*vstate, &instruction)) return error; + if (auto error = ArithmeticsPass(*vstate, &instruction)) return error; + if (auto error = BitwisePass(*vstate, &instruction)) return error; + if (auto error = LogicalsPass(*vstate, &instruction)) return error; + if (auto error = ControlFlowPass(*vstate, &instruction)) return error; + if (auto error = DerivativesPass(*vstate, &instruction)) return error; + if (auto error = AtomicsPass(*vstate, &instruction)) return error; + if (auto error = PrimitivesPass(*vstate, &instruction)) return error; + if (auto error = BarriersPass(*vstate, &instruction)) return error; + // Group + // Device-Side Enqueue + // Pipe + if (auto error = NonUniformPass(*vstate, &instruction)) return error; + + if (auto error = LiteralsPass(*vstate, &instruction)) return error; + } + + // Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi + // must only be preceeded by SpvOpLabel, SpvOpPhi, or SpvOpLine. 
+ if (auto error = ValidateAdjacency(*vstate)) return error; + + if (auto error = ValidateEntryPoints(*vstate)) return error; + // CFG checks are performed after the binary has been parsed + // and the CFGPass has collected information about the control flow + if (auto error = PerformCfgChecks(*vstate)) return error; + if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error; + if (auto error = ValidateDecorations(*vstate)) return error; + if (auto error = ValidateInterfaces(*vstate)) return error; + // TODO(dsinclair): Restructure ValidateBuiltins so we can move into the + // for() above as it loops over all ordered_instructions internally. + if (auto error = ValidateBuiltIns(*vstate)) return error; + // These checks must be performed after individual opcode checks because + // those checks register the limitation checked here. + for (const auto inst : vstate->ordered_instructions()) { + if (auto error = ValidateExecutionLimitations(*vstate, &inst)) return error; + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ValidateBinaryAndKeepValidationState( + const spv_const_context context, spv_const_validator_options options, + const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, + std::unique_ptr* vstate) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + vstate->reset(new ValidationState_t(&hijack_context, options, words, + num_words, kDefaultMaxNumOfWarnings)); + + return ValidateBinaryUsingContextAndValidationState( + hijack_context, words, num_words, pDiagnostic, vstate->get()); +} + +} // namespace val +} // namespace spvtools + +spv_result_t spvValidate(const spv_const_context context, + const spv_const_binary binary, + spv_diagnostic* pDiagnostic) { + return spvValidateBinary(context, binary->code, binary->wordCount, + pDiagnostic); +} + +spv_result_t spvValidateBinary(const spv_const_context context, 
+ const uint32_t* words, const size_t num_words, + spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + // This interface is used for default command line options. + spv_validator_options default_options = spvValidatorOptionsCreate(); + + // Create the ValidationState using the context and default options. + spvtools::val::ValidationState_t vstate(&hijack_context, default_options, + words, num_words, + kDefaultMaxNumOfWarnings); + + spv_result_t result = + spvtools::val::ValidateBinaryUsingContextAndValidationState( + hijack_context, words, num_words, pDiagnostic, &vstate); + + spvValidatorOptionsDestroy(default_options); + return result; +} + +spv_result_t spvValidateWithOptions(const spv_const_context context, + spv_const_validator_options options, + const spv_const_binary binary, + spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + // Create the ValidationState using the context. + spvtools::val::ValidationState_t vstate(&hijack_context, options, + binary->code, binary->wordCount, + kDefaultMaxNumOfWarnings); + + return spvtools::val::ValidateBinaryUsingContextAndValidationState( + hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate); +} diff --git a/source/val/validate.h b/source/val/validate.h new file mode 100644 index 000000000..fe357a2f8 --- /dev/null +++ b/source/val/validate.h @@ -0,0 +1,234 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_VAL_VALIDATE_H_ +#define SOURCE_VAL_VALIDATE_H_ + +#include +#include +#include +#include + +#include "source/instruction.h" +#include "source/table.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace val { + +class ValidationState_t; +class BasicBlock; +class Instruction; + +/// A function that returns a vector of BasicBlocks given a BasicBlock. Used to +/// get the successor and predecessor nodes of a CFG block +using get_blocks_func = + std::function*(const BasicBlock*)>; + +/// @brief Performs the Control Flow Graph checks +/// +/// @param[in] _ the validation state of the module +/// +/// @return SPV_SUCCESS if no errors are found. SPV_ERROR_INVALID_CFG otherwise +spv_result_t PerformCfgChecks(ValidationState_t& _); + +/// @brief Updates the use vectors of all instructions that can be referenced +/// +/// This function will update the vector which define where an instruction was +/// referenced in the binary. +/// +/// @param[in] _ the validation state of the module +/// +/// @return SPV_SUCCESS if no errors are found. +spv_result_t UpdateIdUse(ValidationState_t& _, const Instruction* inst); + +/// @brief This function checks all ID definitions dominate their use in the +/// CFG. +/// +/// This function will iterate over all ID definitions that are defined in the +/// functions of a module and make sure that the definitions appear in a +/// block that dominates their use. +/// +/// @param[in] _ the validation state of the module +/// +/// @return SPV_SUCCESS if no errors are found. 
SPV_ERROR_INVALID_ID otherwise +spv_result_t CheckIdDefinitionDominateUse(ValidationState_t& _); + +/// @brief This function checks for preconditions involving the adjacent +/// instructions. +/// +/// This function will iterate over all instructions and check for any required +/// predecessor and/or successor instructions. e.g. SpvOpPhi must only be +/// preceeded by SpvOpLabel, SpvOpPhi, or SpvOpLine. +/// +/// @param[in] _ the validation state of the module +/// +/// @return SPV_SUCCESS if no errors are found. SPV_ERROR_INVALID_DATA otherwise +spv_result_t ValidateAdjacency(ValidationState_t& _); + +/// @brief Validates static uses of input and output variables +/// +/// Checks that any entry point that uses a input or output variable lists that +/// variable in its interface. +/// +/// @param[in] _ the validation state of the module +/// +/// @return SPV_SUCCESS if no errors are found. +spv_result_t ValidateInterfaces(ValidationState_t& _); + +/// @brief Validates memory instructions +/// +/// @param[in] _ the validation state of the module +/// @return SPV_SUCCESS if no errors are found. +spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst); + +/// @brief Updates the immediate dominator for each of the block edges +/// +/// Updates the immediate dominator of the blocks for each of the edges +/// provided by the @p dom_edges parameter +/// +/// @param[in,out] dom_edges The edges of the dominator tree +/// @param[in] set_func This function will be called to updated the Immediate +/// dominator +void UpdateImmediateDominators( + const std::vector>& dom_edges, + std::function set_func); + +/// @brief Prints all of the dominators of a BasicBlock +/// +/// @param[in] block The dominators of this block will be printed +void printDominatorList(BasicBlock& block); + +/// Performs logical layout validation as described in section 2.4 of the SPIR-V +/// spec. 
+spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst); + +/// Performs Control Flow Graph validation and construction. +spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst); + +/// Validates Control Flow Graph instructions. +spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst); + +/// Performs Id and SSA validation of a module +spv_result_t IdPass(ValidationState_t& _, Instruction* inst); + +/// Performs validation of the Data Rules subsection of 2.16.1 Universal +/// Validation Rules. +/// TODO(ehsann): add more comments here as more validation code is added. +spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst); + +/// Performs instruction validation. +spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst); + +/// Performs decoration validation. Assumes each decoration on a group +/// has been propagated down to the group members. +spv_result_t ValidateDecorations(ValidationState_t& _); + +/// Performs validation of built-in variables. +spv_result_t ValidateBuiltIns(ValidationState_t& _); + +/// Validates type instructions. +spv_result_t TypePass(ValidationState_t& _, const Instruction* inst); + +/// Validates constant instructions. +spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of arithmetic instructions. +spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of composite instructions. +spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of conversion instructions. +spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of derivative instructions. +spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of logical instructions. 
+spv_result_t LogicalsPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of bitwise instructions. +spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of image instructions. +spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of atomic instructions. +spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of barrier instructions. +spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of literal numbers. +spv_result_t LiteralsPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of extension instructions. +spv_result_t ExtensionPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of annotation instructions. +spv_result_t AnnotationPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of non-uniform group instructions. +spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of debug instructions. +spv_result_t DebugPass(ValidationState_t& _, const Instruction* inst); + +// Validates that capability declarations use operands allowed in the current +// context. +spv_result_t CapabilityPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of primitive instructions. +spv_result_t PrimitivesPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of mode setting instructions. +spv_result_t ModeSettingPass(ValidationState_t& _, const Instruction* inst); + +/// Validates correctness of function instructions. +spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst); + +/// Validates execution limitations. +/// +/// Verifies execution models are allowed for all functionality they contain. 
+spv_result_t ValidateExecutionLimitations(ValidationState_t& _, + const Instruction* inst); + +/// @brief Validate the ID's within a SPIR-V binary +/// +/// @param[in] pInstructions array of instructions +/// @param[in] count number of elements in instruction array +/// @param[in] bound the binary header +/// @param[in,out] position current word in the binary +/// @param[in] consumer message consumer callback +/// +/// @return result code +spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions, + const uint64_t count, const uint32_t bound, + spv_position position, + const MessageConsumer& consumer); + +// Performs validation for the SPIRV-V module binary. +// The main difference between this API and spvValidateBinary is that the +// "Validation State" is not destroyed upon function return; it lives on and is +// pointed to by the vstate unique_ptr. +spv_result_t ValidateBinaryAndKeepValidationState( + const spv_const_context context, spv_const_validator_options options, + const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, + std::unique_ptr* vstate); + +} // namespace val +} // namespace spvtools + +#endif // SOURCE_VAL_VALIDATE_H_ diff --git a/source/val/validate_adjacency.cpp b/source/val/validate_adjacency.cpp new file mode 100644 index 000000000..108e36106 --- /dev/null +++ b/source/val/validate_adjacency.cpp @@ -0,0 +1,118 @@ +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Validates correctness of the intra-block preconditions of SPIR-V +// instructions. + +#include "source/val/validate.h" + +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +enum { + // Status right after meeting OpFunction. + IN_NEW_FUNCTION, + // Status right after meeting the entry block. + IN_ENTRY_BLOCK, + // Status right after meeting non-entry blocks. + PHI_VALID, + // Status right after meeting non-OpVariable instructions in the entry block + // or non-OpPhi instructions in non-entry blocks, except OpLine. + PHI_AND_VAR_INVALID, +}; + +spv_result_t ValidateAdjacency(ValidationState_t& _) { + const auto& instructions = _.ordered_instructions(); + int adjacency_status = PHI_AND_VAR_INVALID; + + for (size_t i = 0; i < instructions.size(); ++i) { + const auto& inst = instructions[i]; + switch (inst.opcode()) { + case SpvOpFunction: + case SpvOpFunctionParameter: + adjacency_status = IN_NEW_FUNCTION; + break; + case SpvOpLabel: + adjacency_status = + adjacency_status == IN_NEW_FUNCTION ? IN_ENTRY_BLOCK : PHI_VALID; + break; + case SpvOpPhi: + if (adjacency_status != PHI_VALID) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "OpPhi must appear within a non-entry block before all " + << "non-OpPhi instructions " + << "(except for OpLine, which can be mixed with OpPhi)."; + } + break; + case SpvOpLine: + case SpvOpNoLine: + break; + case SpvOpLoopMerge: + adjacency_status = PHI_AND_VAR_INVALID; + if (i != (instructions.size() - 1)) { + switch (instructions[i + 1].opcode()) { + case SpvOpBranch: + case SpvOpBranchConditional: + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "OpLoopMerge must immediately precede either an " + << "OpBranch or OpBranchConditional instruction. 
" + << "OpLoopMerge must be the second-to-last instruction in " + << "its block."; + } + } + break; + case SpvOpSelectionMerge: + adjacency_status = PHI_AND_VAR_INVALID; + if (i != (instructions.size() - 1)) { + switch (instructions[i + 1].opcode()) { + case SpvOpBranchConditional: + case SpvOpSwitch: + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "OpSelectionMerge must immediately precede either an " + << "OpBranchConditional or OpSwitch instruction. " + << "OpSelectionMerge must be the second-to-last " + << "instruction in its block."; + } + } + break; + case SpvOpVariable: + if (inst.GetOperandAs(2) == SpvStorageClassFunction && + adjacency_status != IN_ENTRY_BLOCK) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "All OpVariable instructions in a function must be the " + "first instructions in the first block."; + } + break; + default: + adjacency_status = PHI_AND_VAR_INVALID; + break; + } + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_annotation.cpp b/source/val/validate_annotation.cpp new file mode 100644 index 000000000..621ace2a6 --- /dev/null +++ b/source/val/validate_annotation.cpp @@ -0,0 +1,263 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/val/validate.h" + +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateDecorate(ValidationState_t& _, const Instruction* inst) { + const auto decoration = inst->GetOperandAs(1); + if (decoration == SpvDecorationSpecId) { + const auto target_id = inst->GetOperandAs(0); + const auto target = _.FindDef(target_id); + if (!target || !spvOpcodeIsScalarSpecConstant(target->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpDecorate SpecId decoration target '" + << _.getIdName(decoration) + << "' is not a scalar specialization constant."; + } + } + // TODO: Add validations for all decorations. + return SPV_SUCCESS; +} + +spv_result_t ValidateMemberDecorate(ValidationState_t& _, + const Instruction* inst) { + const auto struct_type_id = inst->GetOperandAs(0); + const auto struct_type = _.FindDef(struct_type_id); + if (!struct_type || SpvOpTypeStruct != struct_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpMemberDecorate Structure type '" + << _.getIdName(struct_type_id) << "' is not a struct type."; + } + const auto member = inst->GetOperandAs(1); + const auto member_count = + static_cast(struct_type->words().size() - 2); + if (member_count <= member) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Index " << member + << " provided in OpMemberDecorate for struct " + << _.getIdName(struct_type_id) + << " is out of bounds. The structure has " << member_count + << " members. 
Largest valid index is " << member_count - 1 << "."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateDecorationGroup(ValidationState_t& _, + const Instruction* inst) { + if (spvIsWebGPUEnv(_.context()->target_env)) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "OpDecorationGroup is not allowed in the WebGPU execution " + << "environment."; + } + + const auto decoration_group_id = inst->GetOperandAs(0); + const auto decoration_group = _.FindDef(decoration_group_id); + for (auto pair : decoration_group->uses()) { + auto use = pair.first; + if (use->opcode() != SpvOpDecorate && use->opcode() != SpvOpGroupDecorate && + use->opcode() != SpvOpGroupMemberDecorate && + use->opcode() != SpvOpName) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Result id of OpDecorationGroup can only " + << "be targeted by OpName, OpGroupDecorate, " + << "OpDecorate, and OpGroupMemberDecorate"; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateGroupDecorate(ValidationState_t& _, + const Instruction* inst) { + if (spvIsWebGPUEnv(_.context()->target_env)) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "OpGroupDecorate is not allowed in the WebGPU execution " + << "environment."; + } + + const auto decoration_group_id = inst->GetOperandAs(0); + auto decoration_group = _.FindDef(decoration_group_id); + if (!decoration_group || SpvOpDecorationGroup != decoration_group->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpGroupDecorate Decoration group '" + << _.getIdName(decoration_group_id) + << "' is not a decoration group."; + } + for (unsigned i = 1; i < inst->operands().size(); ++i) { + auto target_id = inst->GetOperandAs(i); + auto target = _.FindDef(target_id); + if (!target || target->opcode() == SpvOpDecorationGroup) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpGroupDecorate may not target OpDecorationGroup '" + << _.getIdName(target_id) << "'"; + } + } + return SPV_SUCCESS; +} + +spv_result_t 
ValidateGroupMemberDecorate(ValidationState_t& _, + const Instruction* inst) { + if (spvIsWebGPUEnv(_.context()->target_env)) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "OpGroupMemberDecorate is not allowed in the WebGPU execution " + << "environment."; + } + + const auto decoration_group_id = inst->GetOperandAs(0); + const auto decoration_group = _.FindDef(decoration_group_id); + if (!decoration_group || SpvOpDecorationGroup != decoration_group->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpGroupMemberDecorate Decoration group '" + << _.getIdName(decoration_group_id) + << "' is not a decoration group."; + } + // Grammar checks ensures that the number of arguments to this instruction + // is an odd number: 1 decoration group + (id,literal) pairs. + for (size_t i = 1; i + 1 < inst->operands().size(); i += 2) { + const uint32_t struct_id = inst->GetOperandAs(i); + const uint32_t index = inst->GetOperandAs(i + 1); + auto struct_instr = _.FindDef(struct_id); + if (!struct_instr || SpvOpTypeStruct != struct_instr->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpGroupMemberDecorate Structure type '" + << _.getIdName(struct_id) << "' is not a struct type."; + } + const uint32_t num_struct_members = + static_cast(struct_instr->words().size() - 2); + if (index >= num_struct_members) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Index " << index + << " provided in OpGroupMemberDecorate for struct " + << _.getIdName(struct_id) + << " is out of bounds. The structure has " << num_struct_members + << " members. Largest valid index is " << num_struct_members - 1 + << "."; + } + } + return SPV_SUCCESS; +} + +// Registers necessary decoration(s) for the appropriate IDs based on the +// instruction. 
+spv_result_t RegisterDecorations(ValidationState_t& _,
+                                 const Instruction* inst) {
+  switch (inst->opcode()) {
+    case SpvOpDecorate: {
+      // Word 1 is the target, word 2 the decoration; any further words are
+      // the decoration's parameters.
+      const uint32_t target_id = inst->word(1);
+      const SpvDecoration dec_type = static_cast<SpvDecoration>(inst->word(2));
+      std::vector<uint32_t> dec_params;
+      if (inst->words().size() > 3) {
+        dec_params.insert(dec_params.end(), inst->words().begin() + 3,
+                          inst->words().end());
+      }
+      _.RegisterDecorationForId(target_id, Decoration(dec_type, dec_params));
+      break;
+    }
+    case SpvOpMemberDecorate: {
+      // Same layout as OpDecorate, shifted by one word for the member index.
+      const uint32_t struct_id = inst->word(1);
+      const uint32_t index = inst->word(2);
+      const SpvDecoration dec_type = static_cast<SpvDecoration>(inst->word(3));
+      std::vector<uint32_t> dec_params;
+      if (inst->words().size() > 4) {
+        dec_params.insert(dec_params.end(), inst->words().begin() + 4,
+                          inst->words().end());
+      }
+      _.RegisterDecorationForId(struct_id,
+                                Decoration(dec_type, dec_params, index));
+      break;
+    }
+    case SpvOpDecorationGroup: {
+      // We don't need to do anything right now. Assigning decorations to groups
+      // will be taken care of via OpGroupDecorate.
+      break;
+    }
+    case SpvOpGroupDecorate: {
+      // Word 1 is the group <id>. All subsequent words are target <id>s that
+      // are going to be decorated with the decorations.
+      const uint32_t decoration_group_id = inst->word(1);
+      auto& group_decorations = _.id_decorations(decoration_group_id);
+      for (size_t i = 2; i < inst->words().size(); ++i) {
+        const uint32_t target_id = inst->word(i);
+        _.RegisterDecorationsForId(target_id, group_decorations.begin(),
+                                   group_decorations.end());
+      }
+      break;
+    }
+    case SpvOpGroupMemberDecorate: {
+      // Word 1 is the Decoration Group <id> followed by (struct <id>, literal)
+      // pairs. All decorations of the group should be applied to all the struct
+      // members that are specified in the instructions.
+      const uint32_t decoration_group_id = inst->word(1);
+      auto& group_decorations = _.id_decorations(decoration_group_id);
+      // Grammar checks ensure that the number of arguments to this instruction
+      // is an odd number: 1 decoration group + (id,literal) pairs.
+      for (size_t i = 2; i + 1 < inst->words().size(); i = i + 2) {
+        const uint32_t struct_id = inst->word(i);
+        const uint32_t index = inst->word(i + 1);
+        // ID validation phase ensures this is in fact a struct instruction and
+        // that the index is not out of bound.
+        _.RegisterDecorationsForStructMember(struct_id, index,
+                                             group_decorations.begin(),
+                                             group_decorations.end());
+      }
+      break;
+    }
+    default:
+      break;
+  }
+  return SPV_SUCCESS;
+}
+
+}  // namespace
+
+// Entry point of the annotation pass: validates the five decoration
+// instructions (other opcodes pass through unchecked) and then records the
+// decorations the instruction applies.
+spv_result_t AnnotationPass(ValidationState_t& _, const Instruction* inst) {
+  switch (inst->opcode()) {
+    case SpvOpDecorate:
+      if (auto error = ValidateDecorate(_, inst)) return error;
+      break;
+    case SpvOpMemberDecorate:
+      if (auto error = ValidateMemberDecorate(_, inst)) return error;
+      break;
+    case SpvOpDecorationGroup:
+      if (auto error = ValidateDecorationGroup(_, inst)) return error;
+      break;
+    case SpvOpGroupDecorate:
+      if (auto error = ValidateGroupDecorate(_, inst)) return error;
+      break;
+    case SpvOpGroupMemberDecorate:
+      if (auto error = ValidateGroupMemberDecorate(_, inst)) return error;
+      break;
+    default:
+      break;
+  }
+
+  // In order to validate decoration rules, we need to know all the decorations
+  // that are applied to any given <id>. The result is propagated, like that
+  // of the validators above, rather than silently dropped.
+  if (auto error = RegisterDecorations(_, inst)) return error;
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp
new file mode 100644
index 000000000..2314e7dfc
--- /dev/null
+++ b/source/val/validate_arithmetics.cpp
@@ -0,0 +1,453 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Performs validation of arithmetic instructions.
+
+#include "source/val/validate.h"
+
+#include <cassert>
+#include <vector>
+
+#include "source/diagnostic.h"
+#include "source/opcode.h"
+#include "source/val/instruction.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+
+// Validates correctness of arithmetic instructions.
+//
+// Checks the Result Type and the operand types of the arithmetic opcodes
+// handled below; any other opcode passes through unchecked. Returns
+// SPV_SUCCESS, or the SPV_ERROR_INVALID_DATA diagnostic of the first rule
+// that is violated. Operand indices 2 and up are the instruction's inputs.
+spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
+  const SpvOp opcode = inst->opcode();
+  const uint32_t result_type = inst->type_id();
+
+  switch (opcode) {
+    // Float arithmetic: every operand must have exactly the Result Type.
+    case SpvOpFAdd:
+    case SpvOpFSub:
+    case SpvOpFMul:
+    case SpvOpFDiv:
+    case SpvOpFRem:
+    case SpvOpFMod:
+    case SpvOpFNegate: {
+      if (!_.IsFloatScalarType(result_type) &&
+          !_.IsFloatVectorType(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected floating scalar or vector type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      for (size_t operand_index = 2; operand_index < inst->operands().size();
+           ++operand_index) {
+        if (_.GetOperandTypeId(inst, operand_index) != result_type)
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected arithmetic operands to be of Result Type: "
+                 << spvOpcodeString(opcode) << " operand index "
+                 << operand_index;
+      }
+      break;
+    }
+
+    // Unsigned division/modulo: every operand must have exactly the Result
+    // Type.
+    case SpvOpUDiv:
+    case SpvOpUMod: {
+      if (!_.IsUnsignedIntScalarType(result_type) &&
+          !_.IsUnsignedIntVectorType(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected unsigned int scalar or vector type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      for (size_t operand_index = 2; operand_index < inst->operands().size();
+           ++operand_index) {
+        if (_.GetOperandTypeId(inst, operand_index) != result_type)
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected arithmetic operands to be of Result Type: "
+                 << spvOpcodeString(opcode) << " operand index "
+                 << operand_index;
+      }
+      break;
+    }
+
+    // Integer arithmetic: operands only need to be integer types with the
+    // same dimension and bit width as the Result Type, not the same type id.
+    case SpvOpISub:
+    case SpvOpIAdd:
+    case SpvOpIMul:
+    case SpvOpSDiv:
+    case SpvOpSMod:
+    case SpvOpSRem:
+    case SpvOpSNegate: {
+      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected int scalar or vector type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      const uint32_t dimension = _.GetDimension(result_type);
+      const uint32_t bit_width = _.GetBitWidth(result_type);
+
+      for (size_t operand_index = 2; operand_index < inst->operands().size();
+           ++operand_index) {
+        const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+        if (!type_id ||
+            (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected int scalar or vector type as operand: "
+                 << spvOpcodeString(opcode) << " operand index "
+                 << operand_index;
+
+        if (_.GetDimension(type_id) != dimension)
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected arithmetic operands to have the same dimension "
+                 << "as Result Type: " << spvOpcodeString(opcode)
+                 << " operand index " << operand_index;
+
+        if (_.GetBitWidth(type_id) != bit_width)
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected arithmetic operands to have the same bit width "
+                 << "as Result Type: " << spvOpcodeString(opcode)
+                 << " operand index " << operand_index;
+      }
+      break;
+    }
+
+    // Dot product: float scalar result; operands are float vectors whose
+    // component type is the Result Type and whose sizes all match.
+    case SpvOpDot: {
+      if (!_.IsFloatScalarType(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float scalar type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      uint32_t first_vector_num_components = 0;
+
+      for (size_t operand_index = 2; operand_index < inst->operands().size();
+           ++operand_index) {
+        const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+
+        if (!type_id || !_.IsFloatVectorType(type_id))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected float vector as operand: "
+                 << spvOpcodeString(opcode) << " operand index "
+                 << operand_index;
+
+        const uint32_t component_type = _.GetComponentType(type_id);
+        if (component_type != result_type)
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected component type to be equal to Result Type: "
+                 << spvOpcodeString(opcode) << " operand index "
+                 << operand_index;
+
+        // Remember the size of the first vector and compare the rest to it.
+        const uint32_t num_components = _.GetDimension(type_id);
+        if (operand_index == 2) {
+          first_vector_num_components = num_components;
+        } else if (num_components != first_vector_num_components) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected operands to have the same number of componenets: "
+                 << spvOpcodeString(opcode);
+        }
+      }
+      break;
+    }
+
+    // Vector * scalar: the vector operand has the Result Type and the scalar
+    // operand has the vector's component type.
+    case SpvOpVectorTimesScalar: {
+      if (!_.IsFloatVectorType(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float vector type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
+      if (result_type != vector_type_id)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected vector operand type to be equal to Result Type: "
+               << spvOpcodeString(opcode);
+
+      const uint32_t component_type = _.GetComponentType(vector_type_id);
+
+      const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
+      if (component_type != scalar_type_id)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected scalar operand type to be equal to the component "
+               << "type of the vector operand: " << spvOpcodeString(opcode);
+
+      break;
+    }
+
+    // Matrix * scalar: the matrix operand has the Result Type and the scalar
+    // operand has the matrix's component type.
+    case SpvOpMatrixTimesScalar: {
+      if (!_.IsFloatMatrixType(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float matrix type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
+      if (result_type != matrix_type_id)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected matrix operand type to be equal to Result Type: "
+               << spvOpcodeString(opcode);
+
+      const uint32_t component_type = _.GetComponentType(matrix_type_id);
+
+      const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
+      if (component_type != scalar_type_id)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected scalar operand type to be equal to the component "
+               << "type of the matrix operand: " << spvOpcodeString(opcode);
+
+      break;
+    }
+
+    // Vector * matrix: result size equals the matrix column count, and the
+    // vector operand size equals the matrix row count.
+    case SpvOpVectorTimesMatrix: {
+      const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
+      const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3);
+
+      if (!_.IsFloatVectorType(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float vector type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      const uint32_t res_component_type = _.GetComponentType(result_type);
+
+      if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float vector type as left operand: "
+               << spvOpcodeString(opcode);
+
+      if (res_component_type != _.GetComponentType(vector_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected component types of Result Type and vector to be "
+               << "equal: " << spvOpcodeString(opcode);
+
+      uint32_t matrix_num_rows = 0;
+      uint32_t matrix_num_cols = 0;
+      uint32_t matrix_col_type = 0;
+      uint32_t matrix_component_type = 0;
+      if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
+                               &matrix_num_cols, &matrix_col_type,
+                               &matrix_component_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float matrix type as right operand: "
+               << spvOpcodeString(opcode);
+
+      if (res_component_type != matrix_component_type)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected component types of Result Type and matrix to be "
+               << "equal: " << spvOpcodeString(opcode);
+
+      if (matrix_num_cols != _.GetDimension(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected number of columns of the matrix to be equal to "
+               << "Result Type vector size: " << spvOpcodeString(opcode);
+
+      if (matrix_num_rows != _.GetDimension(vector_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected number of rows of the matrix to be equal to the "
+               << "vector operand size: " << spvOpcodeString(opcode);
+
+      break;
+    }
+
+    // Matrix * vector: the Result Type is the matrix column type, and the
+    // vector operand size equals the matrix column count.
+    case SpvOpMatrixTimesVector: {
+      const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
+      const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3);
+
+      if (!_.IsFloatVectorType(result_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float vector type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      uint32_t matrix_num_rows = 0;
+      uint32_t matrix_num_cols = 0;
+      uint32_t matrix_col_type = 0;
+      uint32_t matrix_component_type = 0;
+      if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
+                               &matrix_num_cols, &matrix_col_type,
+                               &matrix_component_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float matrix type as left operand: "
+               << spvOpcodeString(opcode);
+
+      if (result_type != matrix_col_type)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected column type of the matrix to be equal to Result "
+                  "Type: "
+               << spvOpcodeString(opcode);
+
+      if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float vector type as right operand: "
+               << spvOpcodeString(opcode);
+
+      if (matrix_component_type != _.GetComponentType(vector_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected component types of the operands to be equal: "
+               << spvOpcodeString(opcode);
+
+      if (matrix_num_cols != _.GetDimension(vector_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected number of columns of the matrix to be equal to the "
+               << "vector size: " << spvOpcodeString(opcode);
+
+      break;
+    }
+
+    // Matrix * matrix: the result shares its column type with the left
+    // operand and its column count with the right operand; the left column
+    // count must equal the right row count.
+    case SpvOpMatrixTimesMatrix: {
+      const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
+      const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
+
+      uint32_t res_num_rows = 0;
+      uint32_t res_num_cols = 0;
+      uint32_t res_col_type = 0;
+      uint32_t res_component_type = 0;
+      if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
+                               &res_col_type, &res_component_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float matrix type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      uint32_t left_num_rows = 0;
+      uint32_t left_num_cols = 0;
+      uint32_t left_col_type = 0;
+      uint32_t left_component_type = 0;
+      if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
+                               &left_col_type, &left_component_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float matrix type as left operand: "
+               << spvOpcodeString(opcode);
+
+      uint32_t right_num_rows = 0;
+      uint32_t right_num_cols = 0;
+      uint32_t right_col_type = 0;
+      uint32_t right_component_type = 0;
+      if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
+                               &right_col_type, &right_component_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float matrix type as right operand: "
+               << spvOpcodeString(opcode);
+
+      if (!_.IsFloatScalarType(res_component_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float matrix type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      if (res_col_type != left_col_type)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected column types of Result Type and left matrix to be "
+               << "equal: " << spvOpcodeString(opcode);
+
+      if (res_component_type != right_component_type)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected component types of Result Type and right matrix to "
+                  "be "
+               << "equal: " << spvOpcodeString(opcode);
+
+      if (res_num_cols != right_num_cols)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected number of columns of Result Type and right matrix "
+                  "to "
+               << "be equal: " << spvOpcodeString(opcode);
+
+      if (left_num_cols != right_num_rows)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected number of columns of left matrix and number of "
+                  "rows "
+               << "of right matrix to be equal: " << spvOpcodeString(opcode);
+
+      // Sanity check only; presumably implied by the column types being
+      // equal (checked above).
+      assert(left_num_rows == res_num_rows);
+      break;
+    }
+
+    // Outer product: the left operand has the result's column type; the right
+    // operand is a float vector whose size is the result's column count.
+    case SpvOpOuterProduct: {
+      const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
+      const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
+
+      uint32_t res_num_rows = 0;
+      uint32_t res_num_cols = 0;
+      uint32_t res_col_type = 0;
+      uint32_t res_component_type = 0;
+      if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
+                               &res_col_type, &res_component_type))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float matrix type as Result Type: "
+               << spvOpcodeString(opcode);
+
+      if (left_type_id != res_col_type)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected column type of Result Type to be equal to the type "
+               << "of the left operand: " << spvOpcodeString(opcode);
+
+      if (!right_type_id || !_.IsFloatVectorType(right_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected float vector type as right operand: "
+               << spvOpcodeString(opcode);
+
+      if (res_component_type != _.GetComponentType(right_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected component types of the operands to be equal: "
+               << spvOpcodeString(opcode);
+
+      if (res_num_cols != _.GetDimension(right_type_id))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected number of columns of the matrix to be equal to the "
+               << "vector size of the right operand: "
+               << spvOpcodeString(opcode);
+
+      break;
+    }
+
+    // Extended arithmetic: the Result Type is a struct of two identical
+    // integer members, and both operands have that member type. Only
+    // OpSMulExtended accepts signed member types.
+    case SpvOpIAddCarry:
+    case SpvOpISubBorrow:
+    case SpvOpUMulExtended:
+    case SpvOpSMulExtended: {
+      std::vector<uint32_t> result_types;
+      if (!_.GetStructMemberTypes(result_type, &result_types))
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected a struct as Result Type: "
+               << spvOpcodeString(opcode);
+
+      if (result_types.size() != 2)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected Result Type struct to have two members: "
+               << spvOpcodeString(opcode);
+
+      if (opcode == SpvOpSMulExtended) {
+        if (!_.IsIntScalarType(result_types[0]) &&
+            !_.IsIntVectorType(result_types[0]))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected Result Type struct member types to be integer "
+                    "scalar "
+                 << "or vector: " << spvOpcodeString(opcode);
+      } else {
+        if (!_.IsUnsignedIntScalarType(result_types[0]) &&
+            !_.IsUnsignedIntVectorType(result_types[0]))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected Result Type struct member types to be unsigned "
+                 << "integer scalar or vector: " << spvOpcodeString(opcode);
+      }
+
+      if (result_types[0] != result_types[1])
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected Result Type struct member types to be identical: "
+               << spvOpcodeString(opcode);
+
+      const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
+      const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
+
+      if (left_type_id != result_types[0] || right_type_id != result_types[0])
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected both operands to be of Result Type member type: "
+               << spvOpcodeString(opcode);
+
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_atomics.cpp b/source/val/validate_atomics.cpp
new file mode 100644
index 000000000..38c7053b9
--- /dev/null
+++ b/source/val/validate_atomics.cpp
@@ -0,0 +1,229 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates correctness of atomic SPIR-V instructions.
+
+#include "source/val/validate.h"
+
+#include "source/diagnostic.h"
+#include "source/opcode.h"
+#include "source/spirv_target_env.h"
+#include "source/util/bitutils.h"
+#include "source/val/instruction.h"
+#include "source/val/validate_memory_semantics.h"
+#include "source/val/validate_scopes.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+
+// Validates correctness of atomic instructions.
+//
+// For the atomic opcodes listed below this checks, in order: the Result Type,
+// the Pointer operand (type, storage class and pointee type), the memory
+// scope, the memory semantics operand(s), and finally the Value / Comparator
+// operand types. Any other opcode passes through unchecked. Returns
+// SPV_SUCCESS or the diagnostic of the first rule that is violated.
+spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
+  const SpvOp opcode = inst->opcode();
+  const uint32_t result_type = inst->type_id();
+
+  switch (opcode) {
+    case SpvOpAtomicLoad:
+    case SpvOpAtomicStore:
+    case SpvOpAtomicExchange:
+    case SpvOpAtomicCompareExchange:
+    case SpvOpAtomicCompareExchangeWeak:
+    case SpvOpAtomicIIncrement:
+    case SpvOpAtomicIDecrement:
+    case SpvOpAtomicIAdd:
+    case SpvOpAtomicISub:
+    case SpvOpAtomicSMin:
+    case SpvOpAtomicUMin:
+    case SpvOpAtomicSMax:
+    case SpvOpAtomicUMax:
+    case SpvOpAtomicAnd:
+    case SpvOpAtomicOr:
+    case SpvOpAtomicXor:
+    case SpvOpAtomicFlagTestAndSet:
+    case SpvOpAtomicFlagClear: {
+      // With the Kernel capability, load/exchange/compare-exchange may also
+      // produce a float scalar; everything else must produce an int scalar
+      // (or a bool for OpAtomicFlagTestAndSet).
+      if (_.HasCapability(SpvCapabilityKernel) &&
+          (opcode == SpvOpAtomicLoad || opcode == SpvOpAtomicExchange ||
+           opcode == SpvOpAtomicCompareExchange)) {
+        if (!_.IsFloatScalarType(result_type) &&
+            !_.IsIntScalarType(result_type)) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Result Type to be int or float scalar type";
+        }
+      } else if (opcode == SpvOpAtomicFlagTestAndSet) {
+        if (!_.IsBoolScalarType(result_type)) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Result Type to be bool scalar type";
+        }
+      } else if (opcode == SpvOpAtomicFlagClear || opcode == SpvOpAtomicStore) {
+        // These two opcodes are expected to carry no Result Type.
+        assert(result_type == 0);
+      } else {
+        if (!_.IsIntScalarType(result_type)) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Result Type to be int scalar type";
+        }
+        // Vulkan: a non-32-bit result is only tolerated for the opcodes
+        // listed here, and a 64-bit one additionally needs Int64Atomics.
+        if (spvIsVulkanEnv(_.context()->target_env) &&
+            _.GetBitWidth(result_type) != 32) {
+          switch (opcode) {
+            case SpvOpAtomicSMin:
+            case SpvOpAtomicUMin:
+            case SpvOpAtomicSMax:
+            case SpvOpAtomicUMax:
+            case SpvOpAtomicAnd:
+            case SpvOpAtomicOr:
+            case SpvOpAtomicXor:
+            case SpvOpAtomicIAdd:
+            case SpvOpAtomicLoad:
+            case SpvOpAtomicStore:
+            case SpvOpAtomicExchange:
+            case SpvOpAtomicCompareExchange: {
+              if (_.GetBitWidth(result_type) == 64 &&
+                  !_.HasCapability(SpvCapabilityInt64Atomics))
+                return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                       << spvOpcodeString(opcode)
+                       << ": 64-bit atomics require the Int64Atomics "
+                          "capability";
+            } break;
+            default:
+              return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                     << spvOpcodeString(opcode)
+                     << ": according to the Vulkan spec atomic Result Type "
+                        "needs "
+                        "to be a 32-bit int scalar type";
+          }
+        }
+      }
+
+      // OpAtomicFlagClear and OpAtomicStore start with the Pointer at operand
+      // 0; for every other opcode it is operand 2. operand_index then walks
+      // the remaining operands in order.
+      uint32_t operand_index =
+          opcode == SpvOpAtomicFlagClear || opcode == SpvOpAtomicStore ? 0 : 2;
+      const uint32_t pointer_type = _.GetOperandTypeId(inst, operand_index++);
+
+      uint32_t data_type = 0;
+      uint32_t storage_class = 0;
+      if (!_.GetPointerTypeInfo(pointer_type, &data_type, &storage_class)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << spvOpcodeString(opcode)
+               << ": expected Pointer to be of type OpTypePointer";
+      }
+
+      // Function storage is additionally accepted in OpenCL environments.
+      switch (storage_class) {
+        case SpvStorageClassUniform:
+        case SpvStorageClassWorkgroup:
+        case SpvStorageClassCrossWorkgroup:
+        case SpvStorageClassGeneric:
+        case SpvStorageClassAtomicCounter:
+        case SpvStorageClassImage:
+        case SpvStorageClassStorageBuffer:
+        case SpvStorageClassPhysicalStorageBufferEXT:
+          break;
+        default:
+          if (spvIsOpenCLEnv(_.context()->target_env)) {
+            if (storage_class != SpvStorageClassFunction) {
+              return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                     << spvOpcodeString(opcode)
+                     << ": expected Pointer Storage Class to be Uniform, "
+                        "Workgroup, CrossWorkgroup, Generic, AtomicCounter, "
+                        "Image, StorageBuffer or Function";
+            }
+          } else {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << spvOpcodeString(opcode)
+                   << ": expected Pointer Storage Class to be Uniform, "
+                      "Workgroup, CrossWorkgroup, Generic, AtomicCounter, "
+                      "Image or StorageBuffer";
+          }
+      }
+
+      if (opcode == SpvOpAtomicFlagTestAndSet ||
+          opcode == SpvOpAtomicFlagClear) {
+        if (!_.IsIntScalarType(data_type) || _.GetBitWidth(data_type) != 32) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Pointer to point to a value of 32-bit int type";
+        }
+      } else if (opcode == SpvOpAtomicStore) {
+        if (!_.IsFloatScalarType(data_type) && !_.IsIntScalarType(data_type)) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Pointer to be a pointer to int or float "
+                 << "scalar type";
+        }
+      } else {
+        if (data_type != result_type) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Pointer to point to a value of type Result "
+                    "Type";
+        }
+      }
+
+      auto memory_scope = inst->GetOperandAs<uint32_t>(operand_index++);
+      if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
+        return error;
+      }
+
+      if (auto error = ValidateMemorySemantics(_, inst, operand_index++))
+        return error;
+
+      // The compare-exchange opcodes carry a second memory semantics operand.
+      if (opcode == SpvOpAtomicCompareExchange ||
+          opcode == SpvOpAtomicCompareExchangeWeak) {
+        if (auto error = ValidateMemorySemantics(_, inst, operand_index++))
+          return error;
+      }
+
+      if (opcode == SpvOpAtomicStore) {
+        // OpAtomicStore has no result, so its Value is matched against the
+        // pointee type instead.
+        const uint32_t value_type = _.GetOperandTypeId(inst, 3);
+        if (value_type != data_type) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Value type and the type pointed to by "
+                    "Pointer to be the same";
+        }
+      } else if (opcode != SpvOpAtomicLoad && opcode != SpvOpAtomicIIncrement &&
+                 opcode != SpvOpAtomicIDecrement &&
+                 opcode != SpvOpAtomicFlagTestAndSet &&
+                 opcode != SpvOpAtomicFlagClear) {
+        const uint32_t value_type = _.GetOperandTypeId(inst, operand_index++);
+        if (value_type != result_type) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Value to be of type Result Type";
+        }
+      }
+
+      if (opcode == SpvOpAtomicCompareExchange ||
+          opcode == SpvOpAtomicCompareExchangeWeak) {
+        const uint32_t comparator_type =
+            _.GetOperandTypeId(inst, operand_index++);
+        if (comparator_type != result_type) {
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << spvOpcodeString(opcode)
+                 << ": expected Comparator to be of type Result Type";
+        }
+      }
+
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_barriers.cpp b/source/val/validate_barriers.cpp
new file mode 100644
index 000000000..4fbe9c90a
--- /dev/null
+++ b/source/val/validate_barriers.cpp
@@ -0,0 +1,138 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates correctness of barrier SPIR-V instructions.
+
+#include "source/val/validate.h"
+
+#include <string>
+
+#include "source/diagnostic.h"
+#include "source/opcode.h"
+#include "source/spirv_constant.h"
+#include "source/spirv_target_env.h"
+#include "source/util/bitutils.h"
+#include "source/val/instruction.h"
+#include "source/val/validate_memory_semantics.h"
+#include "source/val/validate_scopes.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+
+// Validates correctness of barrier instructions.
+//
+// Handles OpControlBarrier, OpMemoryBarrier, OpNamedBarrierInitialize and
+// OpMemoryNamedBarrier; any other opcode passes through unchecked. Returns
+// SPV_SUCCESS or the diagnostic of the first rule that is violated.
+spv_result_t BarriersPass(ValidationState_t& _, const Instruction* inst) {
+  const SpvOp opcode = inst->opcode();
+  const uint32_t result_type = inst->type_id();
+
+  switch (opcode) {
+    case SpvOpControlBarrier: {
+      // The execution model restriction is only registered when the target
+      // environment's SPIR-V version is below 1.3.
+      if (spvVersionForTargetEnv(_.context()->target_env) <
+          SPV_SPIRV_VERSION_WORD(1, 3)) {
+        _.function(inst->function()->id())
+            ->RegisterExecutionModelLimitation(
+                [](SpvExecutionModel model, std::string* message) {
+                  if (model != SpvExecutionModelTessellationControl &&
+                      model != SpvExecutionModelGLCompute &&
+                      model != SpvExecutionModelKernel &&
+                      model != SpvExecutionModelTaskNV &&
+                      model != SpvExecutionModelMeshNV) {
+                    if (message) {
+                      // Keep this list in sync with the models accepted by
+                      // the condition above.
+                      *message =
+                          "OpControlBarrier requires one of the following "
+                          "Execution "
+                          "Models: TessellationControl, GLCompute, Kernel, "
+                          "MeshNV or TaskNV";
+                    }
+                    return false;
+                  }
+                  return true;
+                });
+      }
+
+      // Words 1 and 2 are the execution and memory scope ids; the memory
+      // semantics operand is passed by operand index (2).
+      const uint32_t execution_scope = inst->word(1);
+      const uint32_t memory_scope = inst->word(2);
+
+      if (auto error = ValidateExecutionScope(_, inst, execution_scope)) {
+        return error;
+      }
+
+      if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
+        return error;
+      }
+
+      if (auto error = ValidateMemorySemantics(_, inst, 2)) {
+        return error;
+      }
+      break;
+    }
+
+    case SpvOpMemoryBarrier: {
+      const uint32_t memory_scope = inst->word(1);
+
+      if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
+        return error;
+      }
+
+      if (auto error = ValidateMemorySemantics(_, inst, 1)) {
+        return error;
+      }
+      break;
+    }
+
+    case SpvOpNamedBarrierInitialize: {
+      if (_.GetIdOpcode(result_type) != SpvOpTypeNamedBarrier) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << spvOpcodeString(opcode)
+               << ": expected Result Type to be OpTypeNamedBarrier";
+      }
+
+      const uint32_t subgroup_count_type = _.GetOperandTypeId(inst, 2);
+      if (!_.IsIntScalarType(subgroup_count_type) ||
+          _.GetBitWidth(subgroup_count_type) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << spvOpcodeString(opcode)
+               << ": expected Subgroup Count to be a 32-bit int";
+      }
+      break;
+    }
+
+    case SpvOpMemoryNamedBarrier: {
+      const uint32_t named_barrier_type = _.GetOperandTypeId(inst, 0);
+      if (_.GetIdOpcode(named_barrier_type) != SpvOpTypeNamedBarrier) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << spvOpcodeString(opcode)
+               << ": expected Named Barrier to be of type OpTypeNamedBarrier";
+      }
+
+      const uint32_t memory_scope = inst->word(2);
+
+      if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
+        return error;
+      }
+
+      if (auto error = ValidateMemorySemantics(_, inst, 2)) {
+        return error;
+      }
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_bitwise.cpp b/source/val/validate_bitwise.cpp
new file mode 100644
index 000000000..d46b3fcab
--- /dev/null
+++ b/source/val/validate_bitwise.cpp
@@ -0,0 +1,219 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates correctness of bitwise instructions.
+
+#include "source/val/validate.h"
+
+#include "source/diagnostic.h"
+#include "source/opcode.h"
+#include "source/val/instruction.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+
+// Validates correctness of bitwise instructions.
+spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); + + switch (opcode) { + case SpvOpShiftRightLogical: + case SpvOpShiftRightArithmetic: + case SpvOpShiftLeftLogical: { + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t result_dimension = _.GetDimension(result_type); + const uint32_t base_type = _.GetOperandTypeId(inst, 2); + const uint32_t shift_type = _.GetOperandTypeId(inst, 3); + + if (!base_type || + (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base to be int scalar or vector: " + << spvOpcodeString(opcode); + + if (_.GetDimension(base_type) != result_dimension) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base to have the same dimension " + << "as Result Type: " << spvOpcodeString(opcode); + + if (_.GetBitWidth(base_type) != _.GetBitWidth(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base to have the same bit width " + << "as Result Type: " << spvOpcodeString(opcode); + + if (!shift_type || + (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Shift to be int scalar or vector: " + << spvOpcodeString(opcode); + + if (_.GetDimension(shift_type) != result_dimension) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Shift to have the same dimension " + << "as Result Type: " << spvOpcodeString(opcode); + break; + } + + case SpvOpBitwiseOr: + case SpvOpBitwiseXor: + case SpvOpBitwiseAnd: + case SpvOpNot: { + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as 
Result Type: " + << spvOpcodeString(opcode); + + const uint32_t result_dimension = _.GetDimension(result_type); + const uint32_t result_bit_width = _.GetBitWidth(result_type); + + for (size_t operand_index = 2; operand_index < inst->operands().size(); + ++operand_index) { + const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); + if (!type_id || + (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector as operand: " + << spvOpcodeString(opcode) << " operand index " + << operand_index; + + if (_.GetDimension(type_id) != result_dimension) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operands to have the same dimension " + << "as Result Type: " << spvOpcodeString(opcode) + << " operand index " << operand_index; + + if (_.GetBitWidth(type_id) != result_bit_width) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operands to have the same bit width " + << "as Result Type: " << spvOpcodeString(opcode) + << " operand index " << operand_index; + } + break; + } + + case SpvOpBitFieldInsert: { + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t base_type = _.GetOperandTypeId(inst, 2); + const uint32_t insert_type = _.GetOperandTypeId(inst, 3); + const uint32_t offset_type = _.GetOperandTypeId(inst, 4); + const uint32_t count_type = _.GetOperandTypeId(inst, 5); + + if (base_type != result_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base Type to be equal to Result Type: " + << spvOpcodeString(opcode); + + if (insert_type != result_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Insert Type to be equal to Result Type: " + << spvOpcodeString(opcode); + + if (!offset_type || !_.IsIntScalarType(offset_type)) + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Offset Type to be int scalar: " + << spvOpcodeString(opcode); + + if (!count_type || !_.IsIntScalarType(count_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Count Type to be int scalar: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpBitFieldSExtract: + case SpvOpBitFieldUExtract: { + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t base_type = _.GetOperandTypeId(inst, 2); + const uint32_t offset_type = _.GetOperandTypeId(inst, 3); + const uint32_t count_type = _.GetOperandTypeId(inst, 4); + + if (base_type != result_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base Type to be equal to Result Type: " + << spvOpcodeString(opcode); + + if (!offset_type || !_.IsIntScalarType(offset_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Offset Type to be int scalar: " + << spvOpcodeString(opcode); + + if (!count_type || !_.IsIntScalarType(count_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Count Type to be int scalar: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpBitReverse: { + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t base_type = _.GetOperandTypeId(inst, 2); + + if (base_type != result_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base Type to be equal to Result Type: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpBitCount: { + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as Result Type: " + << 
spvOpcodeString(opcode); + + const uint32_t base_type = _.GetOperandTypeId(inst, 2); + if (!base_type || + (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base Type to be int scalar or vector: " + << spvOpcodeString(opcode); + + const uint32_t base_dimension = _.GetDimension(base_type); + const uint32_t result_dimension = _.GetDimension(result_type); + + if (base_dimension != result_dimension) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Base dimension to be equal to Result Type " + "dimension: " + << spvOpcodeString(opcode); + break; + } + + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp new file mode 100644 index 000000000..aaba32490 --- /dev/null +++ b/source/val/validate_builtins.cpp @@ -0,0 +1,2675 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of built-in variables. 
+ +#include "source/val/validate.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Returns a short textual description of the id defined by the given +// instruction. +std::string GetIdDesc(const Instruction& inst) { + std::ostringstream ss; + ss << "ID <" << inst.id() << "> (Op" << spvOpcodeString(inst.opcode()) << ")"; + return ss.str(); +} + +// Gets underlying data type which is +// - member type if instruction is OpTypeStruct +// (member index is taken from decoration). +// - data type if id creates a pointer. +// - type of the constant if instruction is OpConst or OpSpecConst. +// +// Fails in any other case. The function is based on built-ins allowed by +// the Vulkan spec. +// TODO: If non-Vulkan validation rules are added then it might need +// to be refactored. +spv_result_t GetUnderlyingType(ValidationState_t& _, + const Decoration& decoration, + const Instruction& inst, + uint32_t* underlying_type) { + if (decoration.struct_member_index() != Decoration::kInvalidMember) { + assert(inst.opcode() == SpvOpTypeStruct); + *underlying_type = inst.word(decoration.struct_member_index() + 2); + return SPV_SUCCESS; + } + + assert(inst.opcode() != SpvOpTypeStruct); + + if (spvOpcodeIsConstant(inst.opcode())) { + *underlying_type = inst.type_id(); + return SPV_SUCCESS; + } + + uint32_t storage_class = 0; + if (!_.GetPointerTypeInfo(inst.type_id(), underlying_type, &storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << GetIdDesc(inst) + << " is decorated with BuiltIn. 
BuiltIn decoration should only be " + "applied to struct types, variables and constants."; + } + return SPV_SUCCESS; +} + +// Returns Storage Class used by the instruction if applicable. +// Returns SpvStorageClassMax if not. +SpvStorageClass GetStorageClass(const Instruction& inst) { + switch (inst.opcode()) { + case SpvOpTypePointer: + case SpvOpTypeForwardPointer: { + return SpvStorageClass(inst.word(2)); + } + case SpvOpVariable: { + return SpvStorageClass(inst.word(3)); + } + case SpvOpGenericCastToPtrExplicit: { + return SpvStorageClass(inst.word(4)); + } + default: { break; } + } + return SpvStorageClassMax; +} + +// Helper class managing validation of built-ins. +// TODO: Generic functionality of this class can be moved into +// ValidationState_t to be made available to other users. +class BuiltInsValidator { + public: + BuiltInsValidator(ValidationState_t& vstate) : _(vstate) {} + + // Run validation. + spv_result_t Run(); + + private: + // Goes through all decorations in the module, if decoration is BuiltIn + // calls ValidateSingleBuiltInAtDefinition(). + spv_result_t ValidateBuiltInsAtDefinition(); + + // Validates the instruction defining an id with built-in decoration. + // Can be called multiple times for the same id, if multiple built-ins are + // specified. Seeds id_to_at_reference_checks_ with decorated ids if needed. + spv_result_t ValidateSingleBuiltInAtDefinition(const Decoration& decoration, + const Instruction& inst); + + // The following section contains functions which are called when id defined + // by |inst| is decorated with BuiltIn |decoration|. + // Most functions are specific to a single built-in and have naming scheme: + // ValidateXYZAtDefinition. Some functions are common to multiple kinds of + // BuiltIn. 
+ spv_result_t ValidateClipOrCullDistanceAtDefinition( + const Decoration& decoration, const Instruction& inst); + spv_result_t ValidateFragCoordAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateFragDepthAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateFrontFacingAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateHelperInvocationAtDefinition( + const Decoration& decoration, const Instruction& inst); + spv_result_t ValidateInvocationIdAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateInstanceIndexAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateLayerOrViewportIndexAtDefinition( + const Decoration& decoration, const Instruction& inst); + spv_result_t ValidatePatchVerticesAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidatePointCoordAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidatePointSizeAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidatePositionAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidatePrimitiveIdAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateSampleIdAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateSampleMaskAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateSamplePositionAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateTessCoordAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateTessLevelOuterAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateTessLevelInnerAtDefinition(const Decoration& decoration, + const 
Instruction& inst); + spv_result_t ValidateVertexIndexAtDefinition(const Decoration& decoration, + const Instruction& inst); + spv_result_t ValidateVertexIdOrInstanceIdAtDefinition( + const Decoration& decoration, const Instruction& inst); + spv_result_t ValidateWorkgroupSizeAtDefinition(const Decoration& decoration, + const Instruction& inst); + // Used for GlobalInvocationId, LocalInvocationId, NumWorkgroups, WorkgroupId. + spv_result_t ValidateComputeShaderI32Vec3InputAtDefinition( + const Decoration& decoration, const Instruction& inst); + + // The following section contains functions which are called when id defined + // by |referenced_inst| is + // 1. referenced by |referenced_from_inst| + // 2. dependent on |built_in_inst| which is decorated with BuiltIn + // |decoration|. Most functions are specific to a single built-in and have + // naming scheme: ValidateXYZAtReference. Some functions are common to + // multiple kinds of BuiltIn. + spv_result_t ValidateFragCoordAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateFragDepthAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateFrontFacingAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateHelperInvocationAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateInvocationIdAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateInstanceIdAtReference( + const 
Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateInstanceIndexAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidatePatchVerticesAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidatePointCoordAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidatePointSizeAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidatePositionAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidatePrimitiveIdAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateSampleIdAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateSampleMaskAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateSamplePositionAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateTessCoordAtReference( + const Decoration& decoration, const 
Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateTessLevelAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateVertexIndexAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateLayerOrViewportIndexAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateWorkgroupSizeAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + spv_result_t ValidateClipOrCullDistanceAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + // Used for GlobalInvocationId, LocalInvocationId, NumWorkgroups, WorkgroupId. + spv_result_t ValidateComputeShaderI32Vec3InputAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + // Validates that |built_in_inst| is not (even indirectly) referenced from + // within a function which can be called with |execution_model|. + // + // |comment| - text explaining why the restriction was imposed. + // |decoration| - BuiltIn decoration which causes the restriction. + // |referenced_inst| - instruction which is dependent on |built_in_inst| and + // defines the id which was referenced. + // |referenced_from_inst| - instruction which references id defined by + // |referenced_inst| from within a function. 
+ spv_result_t ValidateNotCalledWithExecutionModel( + const char* comment, SpvExecutionModel execution_model, + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst); + + // The following section contains functions which check that the decorated + // variable has the type specified in the function name. |diag| would be + // called with a corresponding error message, if validation is not successful. + spv_result_t ValidateBool( + const Decoration& decoration, const Instruction& inst, + const std::function& diag); + spv_result_t ValidateI32( + const Decoration& decoration, const Instruction& inst, + const std::function& diag); + spv_result_t ValidateI32Vec( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag); + spv_result_t ValidateI32Arr( + const Decoration& decoration, const Instruction& inst, + const std::function& diag); + spv_result_t ValidateF32( + const Decoration& decoration, const Instruction& inst, + const std::function& diag); + spv_result_t ValidateOptionalArrayedF32( + const Decoration& decoration, const Instruction& inst, + const std::function& diag); + spv_result_t ValidateF32Helper( + const Decoration& decoration, const Instruction& inst, + const std::function& diag, + uint32_t underlying_type); + spv_result_t ValidateF32Vec( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag); + spv_result_t ValidateOptionalArrayedF32Vec( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag); + spv_result_t ValidateF32VecHelper( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag, + uint32_t underlying_type); + // If |num_components| is zero, the number of components is not checked. 
+ spv_result_t ValidateF32Arr( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag); + spv_result_t ValidateOptionalArrayedF32Arr( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag); + spv_result_t ValidateF32ArrHelper( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag, + uint32_t underlying_type); + + // Generates strings like "Member #0 of struct ID <2>". + std::string GetDefinitionDesc(const Decoration& decoration, + const Instruction& inst) const; + + // Generates strings like "ID <51> (OpTypePointer) is referencing ID <2> + // (OpTypeStruct) which is decorated with BuiltIn Position". + std::string GetReferenceDesc( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst, + SpvExecutionModel execution_model = SpvExecutionModelMax) const; + + // Generates strings like "ID <51> (OpTypePointer) uses storage class + // UniformConstant". + std::string GetStorageClassDesc(const Instruction& inst) const; + + // Updates inner working of the class. Is called sequentially for every + // instruction. + void Update(const Instruction& inst); + + ValidationState_t& _; + + // Mapping id -> list of rules which validate instruction referencing the + // id. Rules can create new rules and add them to this container. + // Using std::map, and not std::unordered_map to avoid iterator invalidation + // during rehashing. + std::map>> + id_to_at_reference_checks_; + + // Id of the function we are currently inside. 0 if not inside a function. + uint32_t function_id_ = 0; + + // Entry points which can (indirectly) call the current function. + // The pointer either points to a vector inside to function_to_entry_points_ + // or to no_entry_points_. The pointer is guaranteed to never be null. 
+ const std::vector no_entry_points; + const std::vector* entry_points_ = &no_entry_points; + + // Execution models with which the current function can be called. + std::set execution_models_; +}; + +void BuiltInsValidator::Update(const Instruction& inst) { + const SpvOp opcode = inst.opcode(); + if (opcode == SpvOpFunction) { + // Entering a function. + assert(function_id_ == 0); + function_id_ = inst.id(); + execution_models_.clear(); + entry_points_ = &_.FunctionEntryPoints(function_id_); + // Collect execution models from all entry points from which the current + // function can be called. + for (const uint32_t entry_point : *entry_points_) { + if (const auto* models = _.GetExecutionModels(entry_point)) { + execution_models_.insert(models->begin(), models->end()); + } + } + } + + if (opcode == SpvOpFunctionEnd) { + // Exiting a function. + assert(function_id_ != 0); + function_id_ = 0; + entry_points_ = &no_entry_points; + execution_models_.clear(); + } +} + +std::string BuiltInsValidator::GetDefinitionDesc( + const Decoration& decoration, const Instruction& inst) const { + std::ostringstream ss; + if (decoration.struct_member_index() != Decoration::kInvalidMember) { + assert(inst.opcode() == SpvOpTypeStruct); + ss << "Member #" << decoration.struct_member_index(); + ss << " of struct ID <" << inst.id() << ">"; + } else { + ss << GetIdDesc(inst); + } + return ss.str(); +} + +std::string BuiltInsValidator::GetReferenceDesc( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, const Instruction& referenced_from_inst, + SpvExecutionModel execution_model) const { + std::ostringstream ss; + ss << GetIdDesc(referenced_from_inst) << " is referencing " + << GetIdDesc(referenced_inst); + if (built_in_inst.id() != referenced_inst.id()) { + ss << " which is dependent on " << GetIdDesc(built_in_inst); + } + + ss << " which is decorated with BuiltIn "; + ss << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + 
decoration.params()[0]); + if (function_id_) { + ss << " in function <" << function_id_ << ">"; + if (execution_model != SpvExecutionModelMax) { + ss << " called with execution model "; + ss << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_EXECUTION_MODEL, + execution_model); + } + } + ss << "."; + return ss.str(); +} + +std::string BuiltInsValidator::GetStorageClassDesc( + const Instruction& inst) const { + std::ostringstream ss; + ss << GetIdDesc(inst) << " uses storage class "; + ss << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_STORAGE_CLASS, + GetStorageClass(inst)); + ss << "."; + return ss.str(); +} + +spv_result_t BuiltInsValidator::ValidateBool( + const Decoration& decoration, const Instruction& inst, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + if (!_.IsBoolScalarType(underlying_type)) { + return diag(GetDefinitionDesc(decoration, inst) + " is not a bool scalar."); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateI32( + const Decoration& decoration, const Instruction& inst, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + if (!_.IsIntScalarType(underlying_type)) { + return diag(GetDefinitionDesc(decoration, inst) + " is not an int scalar."); + } + + const uint32_t bit_width = _.GetBitWidth(underlying_type); + if (bit_width != 32) { + std::ostringstream ss; + ss << GetDefinitionDesc(decoration, inst) << " has bit width " << bit_width + << "."; + return diag(ss.str()); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateOptionalArrayedF32( + const Decoration& decoration, const Instruction& inst, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return 
error; + } + + // Strip the array, if present. + if (_.GetIdOpcode(underlying_type) == SpvOpTypeArray) { + underlying_type = _.FindDef(underlying_type)->word(2u); + } + + return ValidateF32Helper(decoration, inst, diag, underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateF32( + const Decoration& decoration, const Instruction& inst, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + return ValidateF32Helper(decoration, inst, diag, underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateF32Helper( + const Decoration& decoration, const Instruction& inst, + const std::function& diag, + uint32_t underlying_type) { + if (!_.IsFloatScalarType(underlying_type)) { + return diag(GetDefinitionDesc(decoration, inst) + + " is not a float scalar."); + } + + const uint32_t bit_width = _.GetBitWidth(underlying_type); + if (bit_width != 32) { + std::ostringstream ss; + ss << GetDefinitionDesc(decoration, inst) << " has bit width " << bit_width + << "."; + return diag(ss.str()); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateI32Vec( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + if (!_.IsIntVectorType(underlying_type)) { + return diag(GetDefinitionDesc(decoration, inst) + " is not an int vector."); + } + + const uint32_t actual_num_components = _.GetDimension(underlying_type); + if (_.GetDimension(underlying_type) != num_components) { + std::ostringstream ss; + ss << GetDefinitionDesc(decoration, inst) << " has " + << actual_num_components << " components."; + return diag(ss.str()); + } + + const uint32_t bit_width = _.GetBitWidth(underlying_type); + if (bit_width != 32) { + std::ostringstream ss; + ss << 
GetDefinitionDesc(decoration, inst) + << " has components with bit width " << bit_width << "."; + return diag(ss.str()); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateOptionalArrayedF32Vec( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + // Strip the array, if present. + if (_.GetIdOpcode(underlying_type) == SpvOpTypeArray) { + underlying_type = _.FindDef(underlying_type)->word(2u); + } + + return ValidateF32VecHelper(decoration, inst, num_components, diag, + underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateF32Vec( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + return ValidateF32VecHelper(decoration, inst, num_components, diag, + underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateF32VecHelper( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag, + uint32_t underlying_type) { + if (!_.IsFloatVectorType(underlying_type)) { + return diag(GetDefinitionDesc(decoration, inst) + + " is not a float vector."); + } + + const uint32_t actual_num_components = _.GetDimension(underlying_type); + if (_.GetDimension(underlying_type) != num_components) { + std::ostringstream ss; + ss << GetDefinitionDesc(decoration, inst) << " has " + << actual_num_components << " components."; + return diag(ss.str()); + } + + const uint32_t bit_width = _.GetBitWidth(underlying_type); + if (bit_width != 32) { + std::ostringstream ss; + ss << GetDefinitionDesc(decoration, inst) + << " has components with bit width " << bit_width << "."; + return diag(ss.str()); + } + + 
return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateI32Arr( + const Decoration& decoration, const Instruction& inst, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + const Instruction* const type_inst = _.FindDef(underlying_type); + if (type_inst->opcode() != SpvOpTypeArray) { + return diag(GetDefinitionDesc(decoration, inst) + " is not an array."); + } + + const uint32_t component_type = type_inst->word(2); + if (!_.IsIntScalarType(component_type)) { + return diag(GetDefinitionDesc(decoration, inst) + + " components are not int scalar."); + } + + const uint32_t bit_width = _.GetBitWidth(component_type); + if (bit_width != 32) { + std::ostringstream ss; + ss << GetDefinitionDesc(decoration, inst) + << " has components with bit width " << bit_width << "."; + return diag(ss.str()); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateF32Arr( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + return ValidateF32ArrHelper(decoration, inst, num_components, diag, + underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateOptionalArrayedF32Arr( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag) { + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + // Strip an extra layer of arraying if present. 
+ if (_.GetIdOpcode(underlying_type) == SpvOpTypeArray) { + uint32_t subtype = _.FindDef(underlying_type)->word(2u); + if (_.GetIdOpcode(subtype) == SpvOpTypeArray) { + underlying_type = subtype; + } + } + + return ValidateF32ArrHelper(decoration, inst, num_components, diag, + underlying_type); +} + +spv_result_t BuiltInsValidator::ValidateF32ArrHelper( + const Decoration& decoration, const Instruction& inst, + uint32_t num_components, + const std::function& diag, + uint32_t underlying_type) { + const Instruction* const type_inst = _.FindDef(underlying_type); + if (type_inst->opcode() != SpvOpTypeArray) { + return diag(GetDefinitionDesc(decoration, inst) + " is not an array."); + } + + const uint32_t component_type = type_inst->word(2); + if (!_.IsFloatScalarType(component_type)) { + return diag(GetDefinitionDesc(decoration, inst) + + " components are not float scalar."); + } + + const uint32_t bit_width = _.GetBitWidth(component_type); + if (bit_width != 32) { + std::ostringstream ss; + ss << GetDefinitionDesc(decoration, inst) + << " has components with bit width " << bit_width << "."; + return diag(ss.str()); + } + + if (num_components != 0) { + uint64_t actual_num_components = 0; + if (!_.GetConstantValUint64(type_inst->word(3), &actual_num_components)) { + assert(0 && "Array type definition is corrupt"); + } + if (actual_num_components != num_components) { + std::ostringstream ss; + ss << GetDefinitionDesc(decoration, inst) << " has " + << actual_num_components << " components."; + return diag(ss.str()); + } + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateNotCalledWithExecutionModel( + const char* comment, SpvExecutionModel execution_model, + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (function_id_) { + if (execution_models_.count(execution_model)) { + const char* execution_model_str = _.grammar().lookupOperandName( + 
SPV_OPERAND_TYPE_EXECUTION_MODEL, execution_model); + const char* built_in_str = _.grammar().lookupOperandName( + SPV_OPERAND_TYPE_BUILT_IN, decoration.params()[0]); + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << comment << " " << GetIdDesc(referenced_inst) << " depends on " + << GetIdDesc(built_in_inst) << " which is decorated with BuiltIn " + << built_in_str << "." + << " Id <" << referenced_inst.id() << "> is later referenced by " + << GetIdDesc(referenced_from_inst) << " in function <" + << function_id_ << "> which is called with execution model " + << execution_model_str << "."; + } + } else { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back( + std::bind(&BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + comment, execution_model, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateClipOrCullDistanceAtDefinition( + const Decoration& decoration, const Instruction& inst) { + // Seed at reference checks with this built-in. + return ValidateClipOrCullDistanceAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateClipOrCullDistanceAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " to be only used for variables with Input or Output storage " + "class. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + if (storage_class == SpvStorageClassInput) { + assert(function_id_ == 0); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow BuiltIn ClipDistance/CullDistance to be " + "used for variables with Input storage class if execution model is " + "Vertex.", + SpvExecutionModelVertex, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + + if (storage_class == SpvStorageClassOutput) { + assert(function_id_ == 0); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow BuiltIn ClipDistance/CullDistance to be " + "used for variables with Output storage class if execution model is " + "Fragment.", + SpvExecutionModelFragment, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + switch (execution_model) { + case SpvExecutionModelFragment: + case SpvExecutionModelVertex: { + if (spv_result_t error = ValidateF32Arr( + decoration, built_in_inst, /* Any number of components */ 0, + [this, &decoration, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + << _.grammar().lookupOperandName( + SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " variable needs to be a 32-bit float array. 
" + << message; + })) { + return error; + } + break; + } + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: + case SpvExecutionModelGeometry: + case SpvExecutionModelMeshNV: { + if (decoration.struct_member_index() != Decoration::kInvalidMember) { + // The outer level of array is applied on the variable. + if (spv_result_t error = ValidateF32Arr( + decoration, built_in_inst, /* Any number of components */ 0, + [this, &decoration, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + << _.grammar().lookupOperandName( + SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " variable needs to be a 32-bit float array. " + << message; + })) { + return error; + } + } else { + if (spv_result_t error = ValidateOptionalArrayedF32Arr( + decoration, built_in_inst, /* Any number of components */ 0, + [this, &decoration, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + << _.grammar().lookupOperandName( + SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " variable needs to be a 32-bit float array. " + << message; + })) { + return error; + } + } + break; + } + + default: { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " to be used only with Fragment, Vertex, " + "TessellationControl, TessellationEvaluation or Geometry " + "execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. 
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back( + std::bind(&BuiltInsValidator::ValidateClipOrCullDistanceAtReference, + this, decoration, built_in_inst, referenced_from_inst, + std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateFragCoordAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateF32Vec( + decoration, inst, 4, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn FragCoord " + "variable needs to be a 4-component 32-bit float " + "vector. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateFragCoordAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateFragCoordAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn FragCoord to be only used for " + "variables with Input storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelFragment) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn FragCoord to be used only with " + "Fragment execution model. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateFragCoordAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateFragDepthAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateF32( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn FragDepth " + "variable needs to be a 32-bit float scalar. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateFragDepthAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateFragDepthAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassOutput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn FragDepth to be only used for " + "variables with Output storage class. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelFragment) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn FragDepth to be used only with " + "Fragment execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + + for (const uint32_t entry_point : *entry_points_) { + // Every entry point from which this function is called needs to have + // Execution Mode DepthReplacing. + const auto* modes = _.GetExecutionModes(entry_point); + if (!modes || !modes->count(SpvExecutionModeDepthReplacing)) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec requires DepthReplacing execution mode to be " + "declared when using BuiltIn FragDepth. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateFragDepthAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateFrontFacingAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateBool( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn FrontFacing " + "variable needs to be a bool scalar. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. 
+ return ValidateFrontFacingAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateFrontFacingAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn FrontFacing to be only used for " + "variables with Input storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelFragment) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn FrontFacing to be used only with " + "Fragment execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. 
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateFrontFacingAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateHelperInvocationAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateBool( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn HelperInvocation " + "variable needs to be a bool scalar. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateHelperInvocationAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateHelperInvocationAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn HelperInvocation to be only used " + "for variables with Input storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelFragment) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn HelperInvocation to be used only " + "with Fragment execution model. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back( + std::bind(&BuiltInsValidator::ValidateHelperInvocationAtReference, this, + decoration, built_in_inst, referenced_from_inst, + std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateInvocationIdAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn InvocationId " + "variable needs to be a 32-bit int scalar. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateInvocationIdAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateInvocationIdAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn InvocationId to be only used for " + "variables with Input storage class. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelTessellationControl && + execution_model != SpvExecutionModelGeometry) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn InvocationId to be used only " + "with TessellationControl or Geometry execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateInvocationIdAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateInstanceIndexAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn InstanceIndex " + "variable needs to be a 32-bit int scalar. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. 
+ return ValidateInstanceIndexAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateInstanceIndexAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn InstanceIndex to be only used for " + "variables with Input storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelVertex) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn InstanceIndex to be used only " + "with Vertex execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. 
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateInstanceIndexAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidatePatchVerticesAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn PatchVertices " + "variable needs to be a 32-bit int scalar. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidatePatchVerticesAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidatePatchVerticesAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn PatchVertices to be only used for " + "variables with Input storage class. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelTessellationControl && + execution_model != SpvExecutionModelTessellationEvaluation) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn PatchVertices to be used only " + "with TessellationControl or TessellationEvaluation " + "execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidatePatchVerticesAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidatePointCoordAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateF32Vec( + decoration, inst, 2, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn PointCoord " + "variable needs to be a 2-component 32-bit float " + "vector. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. 
+ return ValidatePointCoordAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidatePointCoordAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn PointCoord to be only used for " + "variables with Input storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelFragment) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn PointCoord to be used only with " + "Fragment execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidatePointCoordAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidatePointSizeAtDefinition( + const Decoration& decoration, const Instruction& inst) { + // Seed at reference checks with this built-in. 
+ return ValidatePointSizeAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidatePointSizeAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn PointSize to be only used for " + "variables with Input or Output storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + if (storage_class == SpvStorageClassInput) { + assert(function_id_ == 0); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow BuiltIn PointSize to be used for " + "variables with Input storage class if execution model is Vertex.", + SpvExecutionModelVertex, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + switch (execution_model) { + case SpvExecutionModelVertex: { + if (spv_result_t error = ValidateF32( + decoration, built_in_inst, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "According to the Vulkan spec BuiltIn PointSize " + "variable needs to be a 32-bit float scalar. 
" + << message; + })) { + return error; + } + break; + } + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: + case SpvExecutionModelGeometry: + case SpvExecutionModelMeshNV: { + // PointSize can be a per-vertex variable for tessellation control, + // tessellation evaluation and geometry shader stages. In such cases + // variables will have an array of 32-bit floats. + if (decoration.struct_member_index() != Decoration::kInvalidMember) { + // The array is on the variable, so this must be a 32-bit float. + if (spv_result_t error = ValidateF32( + decoration, built_in_inst, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + "PointSize variable needs to be a 32-bit " + "float scalar. " + << message; + })) { + return error; + } + } else { + if (spv_result_t error = ValidateOptionalArrayedF32( + decoration, built_in_inst, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn " + "PointSize variable needs to be a 32-bit " + "float scalar. " + << message; + })) { + return error; + } + } + break; + } + + default: { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn PointSize to be used only with " + "Vertex, TessellationControl, TessellationEvaluation or " + "Geometry execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. 
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidatePointSizeAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidatePositionAtDefinition( + const Decoration& decoration, const Instruction& inst) { + // Seed at reference checks with this built-in. + return ValidatePositionAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidatePositionAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn Position to be only used for " + "variables with Input or Output storage class. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + if (storage_class == SpvStorageClassInput) { + assert(function_id_ == 0); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow BuiltIn Position to be used for variables " + "with Input storage class if execution model is Vertex.", + SpvExecutionModelVertex, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + switch (execution_model) { + case SpvExecutionModelVertex: { + if (spv_result_t error = ValidateF32Vec( + decoration, built_in_inst, 4, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "According to the Vulkan spec BuiltIn Position " + "variable needs to be a 4-component 32-bit float " + "vector. " + << message; + })) { + return error; + } + break; + } + case SpvExecutionModelGeometry: + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: + case SpvExecutionModelMeshNV: { + // Position can be a per-vertex variable for tessellation control, + // tessellation evaluation, geometry and mesh shader stages. In such + // cases variables will have an array of 4-component 32-bit float + // vectors. + if (decoration.struct_member_index() != Decoration::kInvalidMember) { + // The array is on the variable, so this must be a 4-component + // 32-bit float vector. 
+ if (spv_result_t error = ValidateF32Vec( + decoration, built_in_inst, 4, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn Position " + "variable needs to be a 4-component 32-bit " + "float vector. " + << message; + })) { + return error; + } + } else { + if (spv_result_t error = ValidateOptionalArrayedF32Vec( + decoration, built_in_inst, 4, + [this, &referenced_from_inst]( + const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, + &referenced_from_inst) + << "According to the Vulkan spec BuiltIn Position " + "variable needs to be a 4-component 32-bit " + "float vector. " + << message; + })) { + return error; + } + } + break; + } + + default: { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn Position to be used only " + "with Vertex, TessellationControl, TessellationEvaluation" + " or Geometry execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidatePositionAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn PrimitiveId " + "variable needs to be a 32-bit int scalar. 
" + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidatePrimitiveIdAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidatePrimitiveIdAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn PrimitiveId to be only used for " + "variables with Input or Output storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + if (storage_class == SpvStorageClassOutput) { + assert(function_id_ == 0); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow BuiltIn PrimitiveId to be used for " + "variables with Output storage class if execution model is " + "TessellationControl.", + SpvExecutionModelTessellationControl, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow BuiltIn PrimitiveId to be used for " + "variables with Output storage class if execution model is " + "TessellationEvaluation.", + SpvExecutionModelTessellationEvaluation, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + 
&BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow BuiltIn PrimitiveId to be used for " + "variables with Output storage class if execution model is " + "Fragment.", + SpvExecutionModelFragment, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + switch (execution_model) { + case SpvExecutionModelFragment: + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: + case SpvExecutionModelGeometry: + case SpvExecutionModelMeshNV: + case SpvExecutionModelRayGenerationNV: + case SpvExecutionModelIntersectionNV: + case SpvExecutionModelAnyHitNV: + case SpvExecutionModelClosestHitNV: + case SpvExecutionModelMissNV: + case SpvExecutionModelCallableNV: { + // Ok. + break; + } + + default: { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn PrimitiveId to be used only " + "with Fragment, TessellationControl, " + "TessellationEvaluation or Geometry execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidatePrimitiveIdAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateSampleIdAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn SampleId " + "variable needs to be a 32-bit int scalar. 
" + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateSampleIdAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateSampleIdAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn SampleId to be only used for " + "variables with Input storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelFragment) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn SampleId to be used only with " + "Fragment execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. 
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateSampleIdAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateSampleMaskAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32Arr( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn SampleMask " + "variable needs to be a 32-bit int array. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateSampleMaskAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateSampleMaskAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn SampleMask to be only used for " + "variables with Input or Output storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelFragment) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn SampleMask to be used only " + "with " + "Fragment execution model. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateSampleMaskAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateSamplePositionAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateF32Vec( + decoration, inst, 2, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn SamplePosition " + "variable needs to be a 2-component 32-bit float " + "vector. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateSamplePositionAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateSamplePositionAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn SamplePosition to be only used " + "for " + "variables with Input storage class. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelFragment) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn SamplePosition to be used only " + "with " + "Fragment execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateSamplePositionAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateTessCoordAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateF32Vec( + decoration, inst, 3, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn TessCoord " + "variable needs to be a 3-component 32-bit float " + "vector. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. 
+ return ValidateTessCoordAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateTessCoordAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn TessCoord to be only used for " + "variables with Input storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelTessellationEvaluation) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn TessCoord to be used only with " + "TessellationEvaluation execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. 
+    id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+        &BuiltInsValidator::ValidateTessCoordAtReference, this, decoration,
+        built_in_inst, referenced_from_inst, std::placeholders::_1));
+  }
+
+  return SPV_SUCCESS;
+}
+
+// Definition-time check for BuiltIn TessLevelOuter. In Vulkan environments
+// the decorated type is checked with ValidateF32Arr using a component count
+// of 4; the at-reference checks are then seeded through
+// ValidateTessLevelAtReference. Returns the first error encountered.
+spv_result_t BuiltInsValidator::ValidateTessLevelOuterAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    if (spv_result_t error = ValidateF32Arr(
+            decoration, inst, 4,
+            [this, &inst](const std::string& message) -> spv_result_t {
+              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                     << "According to the Vulkan spec BuiltIn TessLevelOuter "
+                        "variable needs to be a 4-component 32-bit float "
+                        "array. "
+                     << message;
+            })) {
+      return error;
+    }
+  }
+
+  // Seed at reference checks with this built-in.
+  return ValidateTessLevelAtReference(decoration, inst, inst, inst);
+}
+
+// Definition-time check for BuiltIn TessLevelInner. Same shape as the
+// TessLevelOuter check, but with a component count of 2. The diagnostic names
+// TessLevelInner (it previously named TessLevelOuter, a copy-paste slip that
+// pointed users at the wrong variable).
+spv_result_t BuiltInsValidator::ValidateTessLevelInnerAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    if (spv_result_t error = ValidateF32Arr(
+            decoration, inst, 2,
+            [this, &inst](const std::string& message) -> spv_result_t {
+              return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+                     << "According to the Vulkan spec BuiltIn TessLevelInner "
+                        "variable needs to be a 2-component 32-bit float "
+                        "array. "
+                     << message;
+            })) {
+      return error;
+    }
+  }
+
+  // Seed at reference checks with this built-in.
+ return ValidateTessLevelAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateTessLevelAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " to be only used for variables with Input or Output storage " + "class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + if (storage_class == SpvStorageClassInput) { + assert(function_id_ == 0); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be " + "used " + "for variables with Input storage class if execution model is " + "TessellationControl.", + SpvExecutionModelTessellationControl, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + + if (storage_class == SpvStorageClassOutput) { + assert(function_id_ == 0); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be " + "used " + "for variables with Output storage class if execution model is " + "TessellationEvaluation.", + SpvExecutionModelTessellationEvaluation, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + + for 
(const SpvExecutionModel execution_model : execution_models_) { + switch (execution_model) { + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: { + // Ok. + break; + } + + default: { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " to be used only with TessellationControl or " + "TessellationEvaluation execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateTessLevelAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateVertexIndexAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32( + decoration, inst, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn VertexIndex " + "variable needs to be a 32-bit int scalar. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. 
+  return ValidateVertexIndexAtReference(decoration, inst, inst, inst);
+}
+
+// Definition-time check shared by BuiltIn VertexId and InstanceId.
+// In Vulkan environments both are rejected, except InstanceId when the
+// RayTracingNV capability is declared. InstanceId additionally seeds the
+// at-reference checks via ValidateInstanceIdAtReference.
+spv_result_t BuiltInsValidator::ValidateVertexIdOrInstanceIdAtDefinition(
+    const Decoration& decoration, const Instruction& inst) {
+  const SpvBuiltIn label = SpvBuiltIn(decoration.params()[0]);
+  const bool allow_instance_id = _.HasCapability(SpvCapabilityRayTracingNV) &&
+                                 label == SpvBuiltInInstanceId;
+  if (spvIsVulkanEnv(_.context()->target_env) && !allow_instance_id) {
+    return _.diag(SPV_ERROR_INVALID_DATA, &inst)
+           << "Vulkan spec doesn't allow BuiltIn VertexId/InstanceId "
+              "to be used.";
+  }
+
+  if (label == SpvBuiltInInstanceId) {
+    return ValidateInstanceIdAtReference(decoration, inst, inst, inst);
+  }
+  return SPV_SUCCESS;
+}
+
+// At-reference check for BuiltIn InstanceId. In Vulkan environments every
+// execution model currently recorded in execution_models_ must be
+// IntersectionNV, ClosestHitNV or AnyHitNV; anything else is reported as
+// SPV_ERROR_INVALID_DATA against |referenced_from_inst|. When evaluated in
+// the global scope (function_id_ == 0) the check re-registers itself for
+// |referenced_from_inst| so that dependant ids are validated too.
+spv_result_t BuiltInsValidator::ValidateInstanceIdAtReference(
+    const Decoration& decoration, const Instruction& built_in_inst,
+    const Instruction& referenced_inst,
+    const Instruction& referenced_from_inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    for (const SpvExecutionModel execution_model : execution_models_) {
+      switch (execution_model) {
+        case SpvExecutionModelIntersectionNV:
+        case SpvExecutionModelClosestHitNV:
+        case SpvExecutionModelAnyHitNV:
+          // Do nothing, valid stages
+          break;
+        default:
+          // Pass the offending execution model along, as the other
+          // execution-model diagnostics in this file do, so the message
+          // identifies which entry point kind triggered the failure.
+          return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+                 << "Vulkan spec allows BuiltIn InstanceId to be used "
+                    "only with IntersectionNV, ClosestHitNV and AnyHitNV "
+                    "execution models. "
+                 << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+                                     referenced_from_inst, execution_model);
+      }
+    }
+  }
+
+  if (function_id_ == 0) {
+    // Propagate this rule to all dependant ids in the global scope.
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateInstanceIdAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateVertexIndexAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn VertexIndex to be only used for " + "variables with Input storage class. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelVertex) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn VertexIndex to be used only " + "with " + "Vertex execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. 
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateVertexIndexAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32( + decoration, inst, + [this, &decoration, + &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << "variable needs to be a 32-bit int scalar. " << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateLayerOrViewportIndexAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateLayerOrViewportIndexAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst); + if (storage_class != SpvStorageClassMax && + storage_class != SpvStorageClassInput && + storage_class != SpvStorageClassOutput) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " to be only used for variables with Input or Output storage " + "class. 
" + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst) + << " " << GetStorageClassDesc(referenced_from_inst); + } + + if (storage_class == SpvStorageClassInput) { + assert(function_id_ == 0); + for (const auto em : + {SpvExecutionModelVertex, SpvExecutionModelTessellationEvaluation, + SpvExecutionModelGeometry}) { + id_to_at_reference_checks_[referenced_from_inst.id()].push_back( + std::bind(&BuiltInsValidator::ValidateNotCalledWithExecutionModel, + this, + "Vulkan spec doesn't allow BuiltIn Layer and " + "ViewportIndex to be " + "used for variables with Input storage class if " + "execution model is Vertex, TessellationEvaluation, or " + "Geometry.", + em, decoration, built_in_inst, referenced_from_inst, + std::placeholders::_1)); + } + } + + if (storage_class == SpvStorageClassOutput) { + assert(function_id_ == 0); + id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateNotCalledWithExecutionModel, this, + "Vulkan spec doesn't allow BuiltIn Layer and " + "ViewportIndex to be " + "used for variables with Output storage class if " + "execution model is " + "Fragment.", + SpvExecutionModelFragment, decoration, built_in_inst, + referenced_from_inst, std::placeholders::_1)); + } + + for (const SpvExecutionModel execution_model : execution_models_) { + switch (execution_model) { + case SpvExecutionModelGeometry: + case SpvExecutionModelFragment: + case SpvExecutionModelMeshNV: { + // Ok. 
+ break; + case SpvExecutionModelVertex: + case SpvExecutionModelTessellationEvaluation: + if (!_.HasCapability(SpvCapabilityShaderViewportIndexLayerEXT)) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Using BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " in Vertex or Tessellation execution model requires " + "the ShaderViewportIndexLayerEXT capability."; + } + break; + } + default: { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " to be used only with Vertex, TessellationEvaluation, " + "Geometry, or Fragment execution models. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. + id_to_at_reference_checks_[referenced_from_inst.id()].push_back( + std::bind(&BuiltInsValidator::ValidateLayerOrViewportIndexAtReference, + this, decoration, built_in_inst, referenced_from_inst, + std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateComputeShaderI32Vec3InputAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (spv_result_t error = ValidateI32Vec( + decoration, inst, 3, + [this, &decoration, + &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " variable needs to be a 3-component 32-bit int " + "vector. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. 
+  return ValidateComputeShaderI32Vec3InputAtReference(decoration, inst, inst,
+                                                      inst);
+}
+
+// At-reference check for the 3-component int compute inputs routed here by
+// ValidateComputeShaderI32Vec3InputAtDefinition. In Vulkan environments the
+// storage class of |referenced_from_inst| must be Input (SpvStorageClassMax
+// is let through), and every execution model in execution_models_ must be
+// GLCompute, TaskNV or MeshNV. When evaluated in the global scope
+// (function_id_ == 0) the check re-registers itself for
+// |referenced_from_inst|.
+spv_result_t BuiltInsValidator::ValidateComputeShaderI32Vec3InputAtReference(
+    const Decoration& decoration, const Instruction& built_in_inst,
+    const Instruction& referenced_inst,
+    const Instruction& referenced_from_inst) {
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    const SpvStorageClass storage_class = GetStorageClass(referenced_from_inst);
+    if (storage_class != SpvStorageClassMax &&
+        storage_class != SpvStorageClassInput) {
+      return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+             << "Vulkan spec allows BuiltIn "
+             << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                              decoration.params()[0])
+             << " to be only used for variables with Input storage class. "
+             << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+                                 referenced_from_inst)
+             << " " << GetStorageClassDesc(referenced_from_inst);
+    }
+
+    for (const SpvExecutionModel execution_model : execution_models_) {
+      if (execution_model != SpvExecutionModelGLCompute &&
+          execution_model != SpvExecutionModelTaskNV &&
+          execution_model != SpvExecutionModelMeshNV) {
+        // The message lists every model the condition above accepts; it
+        // previously claimed GLCompute was the only one.
+        return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+               << "Vulkan spec allows BuiltIn "
+               << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                decoration.params()[0])
+               << " to be used only with GLCompute, MeshNV, or TaskNV "
+                  "execution model. "
+               << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+                                   referenced_from_inst, execution_model);
+      }
+    }
+  }
+
+  if (function_id_ == 0) {
+    // Propagate this rule to all dependant ids in the global scope.
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateComputeShaderI32Vec3InputAtReference, this, + decoration, built_in_inst, referenced_from_inst, + std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtDefinition( + const Decoration& decoration, const Instruction& inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + if (!spvOpcodeIsConstant(inst.opcode())) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "Vulkan spec requires BuiltIn WorkgroupSize to be a " + "constant. " + << GetIdDesc(inst) << " is not a constant."; + } + + if (spv_result_t error = ValidateI32Vec( + decoration, inst, 3, + [this, &inst](const std::string& message) -> spv_result_t { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << "According to the Vulkan spec BuiltIn WorkgroupSize " + "variable " + "needs to be a 3-component 32-bit int vector. " + << message; + })) { + return error; + } + } + + // Seed at reference checks with this built-in. + return ValidateWorkgroupSizeAtReference(decoration, inst, inst, inst); +} + +spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtReference( + const Decoration& decoration, const Instruction& built_in_inst, + const Instruction& referenced_inst, + const Instruction& referenced_from_inst) { + if (spvIsVulkanEnv(_.context()->target_env)) { + for (const SpvExecutionModel execution_model : execution_models_) { + if (execution_model != SpvExecutionModelGLCompute) { + return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst) + << "Vulkan spec allows BuiltIn " + << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN, + decoration.params()[0]) + << " to be used only with GLCompute execution model. " + << GetReferenceDesc(decoration, built_in_inst, referenced_inst, + referenced_from_inst, execution_model); + } + } + } + + if (function_id_ == 0) { + // Propagate this rule to all dependant ids in the global scope. 
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind( + &BuiltInsValidator::ValidateWorkgroupSizeAtReference, this, decoration, + built_in_inst, referenced_from_inst, std::placeholders::_1)); + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinition( + const Decoration& decoration, const Instruction& inst) { + const SpvBuiltIn label = SpvBuiltIn(decoration.params()[0]); + // If you are adding a new BuiltIn enum, please register it here. + // If the newly added enum has validation rules associated with it + // consider leaving a TODO and/or creating an issue. + switch (label) { + case SpvBuiltInClipDistance: + case SpvBuiltInCullDistance: { + return ValidateClipOrCullDistanceAtDefinition(decoration, inst); + } + case SpvBuiltInFragCoord: { + return ValidateFragCoordAtDefinition(decoration, inst); + } + case SpvBuiltInFragDepth: { + return ValidateFragDepthAtDefinition(decoration, inst); + } + case SpvBuiltInFrontFacing: { + return ValidateFrontFacingAtDefinition(decoration, inst); + } + case SpvBuiltInGlobalInvocationId: + case SpvBuiltInLocalInvocationId: + case SpvBuiltInNumWorkgroups: + case SpvBuiltInWorkgroupId: { + return ValidateComputeShaderI32Vec3InputAtDefinition(decoration, inst); + } + case SpvBuiltInHelperInvocation: { + return ValidateHelperInvocationAtDefinition(decoration, inst); + } + case SpvBuiltInInvocationId: { + return ValidateInvocationIdAtDefinition(decoration, inst); + } + case SpvBuiltInInstanceIndex: { + return ValidateInstanceIndexAtDefinition(decoration, inst); + } + case SpvBuiltInLayer: + case SpvBuiltInViewportIndex: { + return ValidateLayerOrViewportIndexAtDefinition(decoration, inst); + } + case SpvBuiltInPatchVertices: { + return ValidatePatchVerticesAtDefinition(decoration, inst); + } + case SpvBuiltInPointCoord: { + return ValidatePointCoordAtDefinition(decoration, inst); + } + case SpvBuiltInPointSize: { + return ValidatePointSizeAtDefinition(decoration, inst); + } 
+ case SpvBuiltInPosition: { + return ValidatePositionAtDefinition(decoration, inst); + } + case SpvBuiltInPrimitiveId: { + return ValidatePrimitiveIdAtDefinition(decoration, inst); + } + case SpvBuiltInSampleId: { + return ValidateSampleIdAtDefinition(decoration, inst); + } + case SpvBuiltInSampleMask: { + return ValidateSampleMaskAtDefinition(decoration, inst); + } + case SpvBuiltInSamplePosition: { + return ValidateSamplePositionAtDefinition(decoration, inst); + } + case SpvBuiltInTessCoord: { + return ValidateTessCoordAtDefinition(decoration, inst); + } + case SpvBuiltInTessLevelOuter: { + return ValidateTessLevelOuterAtDefinition(decoration, inst); + } + case SpvBuiltInTessLevelInner: { + return ValidateTessLevelInnerAtDefinition(decoration, inst); + } + case SpvBuiltInVertexIndex: { + return ValidateVertexIndexAtDefinition(decoration, inst); + } + case SpvBuiltInWorkgroupSize: { + return ValidateWorkgroupSizeAtDefinition(decoration, inst); + } + case SpvBuiltInVertexId: + case SpvBuiltInInstanceId: { + return ValidateVertexIdOrInstanceIdAtDefinition(decoration, inst); + } + case SpvBuiltInLocalInvocationIndex: + case SpvBuiltInWorkDim: + case SpvBuiltInGlobalSize: + case SpvBuiltInEnqueuedWorkgroupSize: + case SpvBuiltInGlobalOffset: + case SpvBuiltInGlobalLinearId: + case SpvBuiltInSubgroupSize: + case SpvBuiltInSubgroupMaxSize: + case SpvBuiltInNumSubgroups: + case SpvBuiltInNumEnqueuedSubgroups: + case SpvBuiltInSubgroupId: + case SpvBuiltInSubgroupLocalInvocationId: + case SpvBuiltInSubgroupEqMaskKHR: + case SpvBuiltInSubgroupGeMaskKHR: + case SpvBuiltInSubgroupGtMaskKHR: + case SpvBuiltInSubgroupLeMaskKHR: + case SpvBuiltInSubgroupLtMaskKHR: + case SpvBuiltInBaseVertex: + case SpvBuiltInBaseInstance: + case SpvBuiltInDrawIndex: + case SpvBuiltInDeviceIndex: + case SpvBuiltInViewIndex: + case SpvBuiltInBaryCoordNoPerspAMD: + case SpvBuiltInBaryCoordNoPerspCentroidAMD: + case SpvBuiltInBaryCoordNoPerspSampleAMD: + case SpvBuiltInBaryCoordSmoothAMD: + case 
SpvBuiltInBaryCoordSmoothCentroidAMD: + case SpvBuiltInBaryCoordSmoothSampleAMD: + case SpvBuiltInBaryCoordPullModelAMD: + case SpvBuiltInFragStencilRefEXT: + case SpvBuiltInViewportMaskNV: + case SpvBuiltInSecondaryPositionNV: + case SpvBuiltInSecondaryViewportMaskNV: + case SpvBuiltInPositionPerViewNV: + case SpvBuiltInViewportMaskPerViewNV: + case SpvBuiltInFullyCoveredEXT: + case SpvBuiltInMax: + case SpvBuiltInTaskCountNV: + case SpvBuiltInPrimitiveCountNV: + case SpvBuiltInPrimitiveIndicesNV: + case SpvBuiltInClipDistancePerViewNV: + case SpvBuiltInCullDistancePerViewNV: + case SpvBuiltInLayerPerViewNV: + case SpvBuiltInMeshViewCountNV: + case SpvBuiltInMeshViewIndicesNV: + case SpvBuiltInBaryCoordNV: + case SpvBuiltInBaryCoordNoPerspNV: + case SpvBuiltInFragmentSizeNV: // alias SpvBuiltInFragSizeEXT + case SpvBuiltInInvocationsPerPixelNV: // alias + // SpvBuiltInFragInvocationCountEXT + case SpvBuiltInLaunchIdNV: + case SpvBuiltInLaunchSizeNV: + case SpvBuiltInWorldRayOriginNV: + case SpvBuiltInWorldRayDirectionNV: + case SpvBuiltInObjectRayOriginNV: + case SpvBuiltInObjectRayDirectionNV: + case SpvBuiltInRayTminNV: + case SpvBuiltInRayTmaxNV: + case SpvBuiltInInstanceCustomIndexNV: + case SpvBuiltInObjectToWorldNV: + case SpvBuiltInWorldToObjectNV: + case SpvBuiltInHitTNV: + case SpvBuiltInHitKindNV: + case SpvBuiltInIncomingRayFlagsNV: { + // No validation rules (for the moment). 
+ break; + } + } + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::ValidateBuiltInsAtDefinition() { + for (const auto& kv : _.id_decorations()) { + const uint32_t id = kv.first; + const auto& decorations = kv.second; + if (decorations.empty()) { + continue; + } + + const Instruction* inst = _.FindDef(id); + assert(inst); + + for (const auto& decoration : kv.second) { + if (decoration.dec_type() != SpvDecorationBuiltIn) { + continue; + } + + if (spv_result_t error = + ValidateSingleBuiltInAtDefinition(decoration, *inst)) { + return error; + } + } + } + + return SPV_SUCCESS; +} + +spv_result_t BuiltInsValidator::Run() { + // First pass: validate all built-ins at definition and seed + // id_to_at_reference_checks_ with built-ins. + if (auto error = ValidateBuiltInsAtDefinition()) { + return error; + } + + if (id_to_at_reference_checks_.empty()) { + // No validation tasks were seeded. Nothing else to do. + return SPV_SUCCESS; + } + + // Second pass: validate every id reference in the module using + // rules in id_to_at_reference_checks_. + for (const Instruction& inst : _.ordered_instructions()) { + Update(inst); + + std::set already_checked; + + for (const auto& operand : inst.operands()) { + if (!spvIsIdType(operand.type)) { + // Not id. + continue; + } + + const uint32_t id = inst.word(operand.offset); + if (id == inst.id()) { + // No need to check result id. + continue; + } + + if (!already_checked.insert(id).second) { + // The instruction has already referenced this id. + continue; + } + + // Instruction references the id. Run all checks associated with the id + // on the instruction. id_to_at_reference_checks_ can be modified in the + // process, iterators are safe because it's a tree-based map. 
+ const auto it = id_to_at_reference_checks_.find(id); + if (it != id_to_at_reference_checks_.end()) { + for (const auto& check : it->second) { + if (spv_result_t error = check(inst)) { + return error; + } + } + } + } + } + + return SPV_SUCCESS; +} + +} // namespace + +// Validates correctness of built-in variables. +spv_result_t ValidateBuiltIns(ValidationState_t& _) { + if (!spvIsVulkanEnv(_.context()->target_env)) { + // Early return. All currently implemented rules are based on Vulkan spec. + // + // TODO: If you are adding validation rules for environments other than + // Vulkan (or general rules which are not environment independent), then you + // need to modify or remove this condition. Consider also adding early + // returns into BuiltIn-specific rules, so that the system doesn't spawn new + // rules which don't do anything. + return SPV_SUCCESS; + } + + BuiltInsValidator validator(_); + return validator.Run(); +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_capability.cpp b/source/val/validate_capability.cpp new file mode 100644 index 000000000..ad6cb2654 --- /dev/null +++ b/source/val/validate_capability.cpp @@ -0,0 +1,335 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates OpCapability instruction. 
+ +#include "source/val/validate.h" + +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +bool IsSupportGuaranteedVulkan_1_0(uint32_t capability) { + switch (capability) { + case SpvCapabilityMatrix: + case SpvCapabilityShader: + case SpvCapabilityInputAttachment: + case SpvCapabilitySampled1D: + case SpvCapabilityImage1D: + case SpvCapabilitySampledBuffer: + case SpvCapabilityImageBuffer: + case SpvCapabilityImageQuery: + case SpvCapabilityDerivativeControl: + return true; + } + return false; +} + +bool IsSupportGuaranteedVulkan_1_1(uint32_t capability) { + if (IsSupportGuaranteedVulkan_1_0(capability)) return true; + switch (capability) { + case SpvCapabilityDeviceGroup: + case SpvCapabilityMultiView: + return true; + } + return false; +} + +bool IsSupportOptionalVulkan_1_0(uint32_t capability) { + switch (capability) { + case SpvCapabilityGeometry: + case SpvCapabilityTessellation: + case SpvCapabilityFloat64: + case SpvCapabilityInt64: + case SpvCapabilityInt16: + case SpvCapabilityTessellationPointSize: + case SpvCapabilityGeometryPointSize: + case SpvCapabilityImageGatherExtended: + case SpvCapabilityStorageImageMultisample: + case SpvCapabilityUniformBufferArrayDynamicIndexing: + case SpvCapabilitySampledImageArrayDynamicIndexing: + case SpvCapabilityStorageBufferArrayDynamicIndexing: + case SpvCapabilityStorageImageArrayDynamicIndexing: + case SpvCapabilityClipDistance: + case SpvCapabilityCullDistance: + case SpvCapabilityImageCubeArray: + case SpvCapabilitySampleRateShading: + case SpvCapabilitySparseResidency: + case SpvCapabilityMinLod: + case SpvCapabilitySampledCubeArray: + case SpvCapabilityImageMSArray: + case SpvCapabilityStorageImageExtendedFormats: + case SpvCapabilityInterpolationFunction: + case SpvCapabilityStorageImageReadWithoutFormat: + case 
SpvCapabilityStorageImageWriteWithoutFormat: + case SpvCapabilityMultiViewport: + case SpvCapabilityInt64Atomics: + case SpvCapabilityTransformFeedback: + case SpvCapabilityGeometryStreams: + case SpvCapabilityFloat16: + case SpvCapabilityInt8: + return true; + } + return false; +} + +bool IsSupportOptionalVulkan_1_1(uint32_t capability) { + if (IsSupportOptionalVulkan_1_0(capability)) return true; + + switch (capability) { + case SpvCapabilityGroupNonUniform: + case SpvCapabilityGroupNonUniformVote: + case SpvCapabilityGroupNonUniformArithmetic: + case SpvCapabilityGroupNonUniformBallot: + case SpvCapabilityGroupNonUniformShuffle: + case SpvCapabilityGroupNonUniformShuffleRelative: + case SpvCapabilityGroupNonUniformClustered: + case SpvCapabilityGroupNonUniformQuad: + case SpvCapabilityDrawParameters: + // Alias SpvCapabilityStorageBuffer16BitAccess. + case SpvCapabilityStorageUniformBufferBlock16: + // Alias SpvCapabilityUniformAndStorageBuffer16BitAccess. + case SpvCapabilityStorageUniform16: + case SpvCapabilityStoragePushConstant16: + case SpvCapabilityStorageInputOutput16: + case SpvCapabilityDeviceGroup: + case SpvCapabilityMultiView: + case SpvCapabilityVariablePointersStorageBuffer: + case SpvCapabilityVariablePointers: + return true; + } + return false; +} + +bool IsSupportGuaranteedOpenCL_1_2(uint32_t capability, bool embedded_profile) { + switch (capability) { + case SpvCapabilityAddresses: + case SpvCapabilityFloat16Buffer: + case SpvCapabilityGroups: + case SpvCapabilityInt16: + case SpvCapabilityInt8: + case SpvCapabilityKernel: + case SpvCapabilityLinkage: + case SpvCapabilityVector16: + return true; + case SpvCapabilityInt64: + return !embedded_profile; + case SpvCapabilityPipes: + return embedded_profile; + } + return false; +} + +bool IsSupportGuaranteedOpenCL_2_0(uint32_t capability, bool embedded_profile) { + if (IsSupportGuaranteedOpenCL_1_2(capability, embedded_profile)) return true; + + switch (capability) { + case 
SpvCapabilityDeviceEnqueue: + case SpvCapabilityGenericPointer: + case SpvCapabilityPipes: + return true; + } + return false; +} + +bool IsSupportGuaranteedOpenCL_2_2(uint32_t capability, bool embedded_profile) { + if (IsSupportGuaranteedOpenCL_2_0(capability, embedded_profile)) return true; + + switch (capability) { + case SpvCapabilitySubgroupDispatch: + case SpvCapabilityPipeStorage: + return true; + } + return false; +} + +bool IsSupportOptionalOpenCL_1_2(uint32_t capability) { + switch (capability) { + case SpvCapabilityImageBasic: + case SpvCapabilityFloat64: + return true; + } + return false; +} + +// Checks if |capability| was enabled by extension. +bool IsEnabledByExtension(ValidationState_t& _, uint32_t capability) { + spv_operand_desc operand_desc = nullptr; + _.grammar().lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, capability, + &operand_desc); + + // operand_desc is expected to be not null, otherwise validator would have + // failed at an earlier stage. This 'assert' is 'just in case'. 
+ assert(operand_desc); + + ExtensionSet operand_exts(operand_desc->numExtensions, + operand_desc->extensions); + if (operand_exts.IsEmpty()) return false; + + return _.HasAnyOfExtensions(operand_exts); +} + +bool IsEnabledByCapabilityOpenCL_1_2(ValidationState_t& _, + uint32_t capability) { + if (_.HasCapability(SpvCapabilityImageBasic)) { + switch (capability) { + case SpvCapabilityLiteralSampler: + case SpvCapabilitySampled1D: + case SpvCapabilityImage1D: + case SpvCapabilitySampledBuffer: + case SpvCapabilityImageBuffer: + return true; + } + return false; + } + return false; +} + +bool IsEnabledByCapabilityOpenCL_2_0(ValidationState_t& _, + uint32_t capability) { + if (_.HasCapability(SpvCapabilityImageBasic)) { + switch (capability) { + case SpvCapabilityImageReadWrite: + case SpvCapabilityLiteralSampler: + case SpvCapabilitySampled1D: + case SpvCapabilityImage1D: + case SpvCapabilitySampledBuffer: + case SpvCapabilityImageBuffer: + return true; + } + return false; + } + return false; +} + +bool IsSupportGuaranteedWebGPU(uint32_t capability) { + switch (capability) { + case SpvCapabilityMatrix: + case SpvCapabilityShader: + case SpvCapabilitySampled1D: + case SpvCapabilityImage1D: + case SpvCapabilityDerivativeControl: + case SpvCapabilityImageQuery: + return true; + } + return false; +} + +} // namespace + +// Validates that capability declarations use operands allowed in the current +// context. 
+spv_result_t CapabilityPass(ValidationState_t& _, const Instruction* inst) { + if (inst->opcode() != SpvOpCapability) return SPV_SUCCESS; + + assert(inst->operands().size() == 1); + + const spv_parsed_operand_t& operand = inst->operand(0); + + assert(operand.num_words == 1); + assert(operand.offset < inst->words().size()); + + const uint32_t capability = inst->word(operand.offset); + const auto capability_str = [&_, capability]() { + spv_operand_desc desc = nullptr; + if (_.grammar().lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, capability, + &desc) != SPV_SUCCESS || + !desc) { + return std::string("Unknown"); + } + return std::string(desc->name); + }; + + const auto env = _.context()->target_env; + const bool opencl_embedded = env == SPV_ENV_OPENCL_EMBEDDED_1_2 || + env == SPV_ENV_OPENCL_EMBEDDED_2_0 || + env == SPV_ENV_OPENCL_EMBEDDED_2_1 || + env == SPV_ENV_OPENCL_EMBEDDED_2_2; + const std::string opencl_profile = opencl_embedded ? "Embedded" : "Full"; + if (env == SPV_ENV_VULKAN_1_0) { + if (!IsSupportGuaranteedVulkan_1_0(capability) && + !IsSupportOptionalVulkan_1_0(capability) && + !IsEnabledByExtension(_, capability)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Capability " << capability_str() + << " is not allowed by Vulkan 1.0 specification" + << " (or requires extension)"; + } + } else if (env == SPV_ENV_VULKAN_1_1) { + if (!IsSupportGuaranteedVulkan_1_1(capability) && + !IsSupportOptionalVulkan_1_1(capability) && + !IsEnabledByExtension(_, capability)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Capability " << capability_str() + << " is not allowed by Vulkan 1.1 specification" + << " (or requires extension)"; + } + } else if (env == SPV_ENV_OPENCL_1_2 || env == SPV_ENV_OPENCL_EMBEDDED_1_2) { + if (!IsSupportGuaranteedOpenCL_1_2(capability, opencl_embedded) && + !IsSupportOptionalOpenCL_1_2(capability) && + !IsEnabledByExtension(_, capability) && + !IsEnabledByCapabilityOpenCL_1_2(_, capability)) { + return 
_.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Capability " << capability_str() + << " is not allowed by OpenCL 1.2 " << opencl_profile + << " Profile specification" + << " (or requires extension or capability)"; + } + } else if (env == SPV_ENV_OPENCL_2_0 || env == SPV_ENV_OPENCL_EMBEDDED_2_0 || + env == SPV_ENV_OPENCL_2_1 || env == SPV_ENV_OPENCL_EMBEDDED_2_1) { + if (!IsSupportGuaranteedOpenCL_2_0(capability, opencl_embedded) && + !IsSupportOptionalOpenCL_1_2(capability) && + !IsEnabledByExtension(_, capability) && + !IsEnabledByCapabilityOpenCL_2_0(_, capability)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Capability " << capability_str() + << " is not allowed by OpenCL 2.0/2.1 " << opencl_profile + << " Profile specification" + << " (or requires extension or capability)"; + } + } else if (env == SPV_ENV_OPENCL_2_2 || env == SPV_ENV_OPENCL_EMBEDDED_2_2) { + if (!IsSupportGuaranteedOpenCL_2_2(capability, opencl_embedded) && + !IsSupportOptionalOpenCL_1_2(capability) && + !IsEnabledByExtension(_, capability) && + !IsEnabledByCapabilityOpenCL_2_0(_, capability)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Capability " << capability_str() + << " is not allowed by OpenCL 2.2 " << opencl_profile + << " Profile specification" + << " (or requires extension or capability)"; + } + } else if (env == SPV_ENV_WEBGPU_0) { + if (!IsSupportGuaranteedWebGPU(capability) && + !IsEnabledByExtension(_, capability)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Capability " << capability_str() + << " is not allowed by WebGPU specification" + << " (or requires extension)"; + } + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp new file mode 100644 index 000000000..8fe30a881 --- /dev/null +++ b/source/val/validate_cfg.cpp @@ -0,0 +1,781 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/cfa.h" +#include "source/opcode.h" +#include "source/spirv_validator_options.h" +#include "source/val/basic_block.h" +#include "source/val/construct.h" +#include "source/val/function.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidatePhi(ValidationState_t& _, const Instruction* inst) { + auto block = inst->block(); + size_t num_in_ops = inst->words().size() - 3; + if (num_in_ops % 2 != 0) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi does not have an equal number of incoming values and " + "basic blocks."; + } + + const Instruction* type_inst = _.FindDef(inst->type_id()); + assert(type_inst); + + const SpvOp type_opcode = type_inst->opcode(); + if (type_opcode == SpvOpTypePointer && + _.addressing_model() == SpvAddressingModelLogical) { + if (!_.features().variable_pointers && + !_.features().variable_pointers_storage_buffer) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Using pointers with OpPhi requires capability " + << "VariablePointers or VariablePointersStorageBuffer"; + } + } + + // Create a uniqued vector of predecessor ids for comparison against + // incoming values. 
OpBranchConditional %cond %label %label produces two + // predecessors in the CFG. + std::vector pred_ids; + std::transform(block->predecessors()->begin(), block->predecessors()->end(), + std::back_inserter(pred_ids), + [](const BasicBlock* b) { return b->id(); }); + std::sort(pred_ids.begin(), pred_ids.end()); + pred_ids.erase(std::unique(pred_ids.begin(), pred_ids.end()), pred_ids.end()); + + size_t num_edges = num_in_ops / 2; + if (num_edges != pred_ids.size()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi's number of incoming blocks (" << num_edges + << ") does not match block's predecessor count (" + << block->predecessors()->size() << ")."; + } + + for (size_t i = 3; i < inst->words().size(); ++i) { + auto inc_id = inst->word(i); + if (i % 2 == 1) { + // Incoming value type must match the phi result type. + auto inc_type_id = _.GetTypeId(inc_id); + if (inst->type_id() != inc_type_id) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi's result type " << _.getIdName(inst->type_id()) + << " does not match incoming value " << _.getIdName(inc_id) + << " type " << _.getIdName(inc_type_id) << "."; + } + } else { + if (_.GetIdOpcode(inc_id) != SpvOpLabel) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi's incoming basic block " << _.getIdName(inc_id) + << " is not an OpLabel."; + } + + // Incoming basic block must be an immediate predecessor of the phi's + // block. 
+ if (!std::binary_search(pred_ids.begin(), pred_ids.end(), inc_id)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpPhi's incoming basic block " << _.getIdName(inc_id) + << " is not a predecessor of " << _.getIdName(block->id()) + << "."; + } + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateBranchConditional(ValidationState_t& _, + const Instruction* inst) { + // num_operands is either 3 or 5 --- if 5, the last two need to be literal + // integers + const auto num_operands = inst->operands().size(); + if (num_operands != 3 && num_operands != 5) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpBranchConditional requires either 3 or 5 parameters"; + } + + // grab the condition operand and check that it is a bool + const auto cond_id = inst->GetOperandAs(0); + const auto cond_op = _.FindDef(cond_id); + if (!cond_op || !cond_op->type_id() || + !_.IsBoolScalarType(cond_op->type_id())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) << "Condition operand for " + "OpBranchConditional must be " + "of boolean type"; + } + + // target operands must be OpLabel + // note that we don't need to check that the target labels are in the same + // function, + // PerformCfgChecks already checks for that + const auto true_id = inst->GetOperandAs(1); + const auto true_target = _.FindDef(true_id); + if (!true_target || SpvOpLabel != true_target->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The 'True Label' operand for OpBranchConditional must be the " + "ID of an OpLabel instruction"; + } + + const auto false_id = inst->GetOperandAs(2); + const auto false_target = _.FindDef(false_id); + if (!false_target || SpvOpLabel != false_target->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The 'False Label' operand for OpBranchConditional must be the " + "ID of an OpLabel instruction"; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateReturnValue(ValidationState_t& _, + const Instruction* inst) { + const auto value_id = 
inst->GetOperandAs(0); + const auto value = _.FindDef(value_id); + if (!value || !value->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpReturnValue Value '" << _.getIdName(value_id) + << "' does not represent a value."; + } + auto value_type = _.FindDef(value->type_id()); + if (!value_type || SpvOpTypeVoid == value_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpReturnValue value's type '" + << _.getIdName(value->type_id()) << "' is missing or void."; + } + + const bool uses_variable_pointer = + _.features().variable_pointers || + _.features().variable_pointers_storage_buffer; + + if (_.addressing_model() == SpvAddressingModelLogical && + SpvOpTypePointer == value_type->opcode() && !uses_variable_pointer && + !_.options()->relax_logical_pointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpReturnValue value's type '" + << _.getIdName(value->type_id()) + << "' is a pointer, which is invalid in the Logical addressing " + "model."; + } + + const auto function = inst->function(); + const auto return_type = _.FindDef(function->GetResultTypeId()); + if (!return_type || return_type->id() != value_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpReturnValue Value '" << _.getIdName(value_id) + << "'s type does not match OpFunction's return type."; + } + + return SPV_SUCCESS; +} + +} // namespace + +void printDominatorList(const BasicBlock& b) { + std::cout << b.id() << " is dominated by: "; + const BasicBlock* bb = &b; + while (bb->immediate_dominator() != bb) { + bb = bb->immediate_dominator(); + std::cout << bb->id() << " "; + } +} + +#define CFG_ASSERT(ASSERT_FUNC, TARGET) \ + if (spv_result_t rcode = ASSERT_FUNC(_, TARGET)) return rcode + +spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) { + if (_.current_function().IsFirstBlock(target)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(_.current_function().id())) + << "First block " << _.getIdName(target) << " of function " + << 
_.getIdName(_.current_function().id()) << " is targeted by block " + << _.getIdName(_.current_function().current_block()->id()); + } + return SPV_SUCCESS; +} + +spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) { + if (_.current_function().IsBlockType(merge_block, kBlockTypeMerge)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(_.current_function().id())) + << "Block " << _.getIdName(merge_block) + << " is already a merge block for another header"; + } + return SPV_SUCCESS; +} + +/// Update the continue construct's exit blocks once the backedge blocks are +/// identified in the CFG. +void UpdateContinueConstructExitBlocks( + Function& function, + const std::vector>& back_edges) { + auto& constructs = function.constructs(); + // TODO(umar): Think of a faster way to do this + for (auto& edge : back_edges) { + uint32_t back_edge_block_id; + uint32_t loop_header_block_id; + std::tie(back_edge_block_id, loop_header_block_id) = edge; + auto is_this_header = [=](Construct& c) { + return c.type() == ConstructType::kLoop && + c.entry_block()->id() == loop_header_block_id; + }; + + for (auto construct : constructs) { + if (is_this_header(construct)) { + Construct* continue_construct = + construct.corresponding_constructs().back(); + assert(continue_construct->type() == ConstructType::kContinue); + + BasicBlock* back_edge_block; + std::tie(back_edge_block, std::ignore) = + function.GetBlock(back_edge_block_id); + continue_construct->set_exit(back_edge_block); + } + } + } +} + +std::tuple ConstructNames( + ConstructType type) { + std::string construct_name, header_name, exit_name; + + switch (type) { + case ConstructType::kSelection: + construct_name = "selection"; + header_name = "selection header"; + exit_name = "merge block"; + break; + case ConstructType::kLoop: + construct_name = "loop"; + header_name = "loop header"; + exit_name = "merge block"; + break; + case ConstructType::kContinue: + construct_name = "continue"; + header_name = "continue 
target"; + exit_name = "back-edge block"; + break; + case ConstructType::kCase: + construct_name = "case"; + header_name = "case entry block"; + exit_name = "case exit block"; + break; + default: + assert(1 == 0 && "Not defined type"); + } + + return std::make_tuple(construct_name, header_name, exit_name); +} + +/// Constructs an error message for construct validation errors +std::string ConstructErrorString(const Construct& construct, + const std::string& header_string, + const std::string& exit_string, + const std::string& dominate_text) { + std::string construct_name, header_name, exit_name; + std::tie(construct_name, header_name, exit_name) = + ConstructNames(construct.type()); + + // TODO(umar): Add header block for continue constructs to error message + return "The " + construct_name + " construct with the " + header_name + " " + + header_string + " " + dominate_text + " the " + exit_name + " " + + exit_string; +} + +// Finds the fall through case construct of |target_block| and records it in +// |case_fall_through|. Returns SPV_ERROR_INVALID_CFG if the case construct +// headed by |target_block| branches to multiple case constructs. +spv_result_t FindCaseFallThrough( + ValidationState_t& _, BasicBlock* target_block, uint32_t* case_fall_through, + const BasicBlock* merge, const std::unordered_set& case_targets, + Function* function) { + std::vector stack; + stack.push_back(target_block); + std::unordered_set visited; + bool target_reachable = target_block->reachable(); + int target_depth = function->GetBlockDepth(target_block); + while (!stack.empty()) { + auto block = stack.back(); + stack.pop_back(); + + if (block == merge) continue; + + if (!visited.insert(block).second) continue; + + if (target_reachable && block->reachable() && + target_block->dominates(*block)) { + // Still in the case construct. + for (auto successor : *block->successors()) { + stack.push_back(successor); + } + } else { + // Exiting the case construct to non-merge block. 
+ if (!case_targets.count(block->id())) { + int depth = function->GetBlockDepth(block); + if ((depth < target_depth) || + (depth == target_depth && block->is_type(kBlockTypeContinue))) { + continue; + } + + return _.diag(SPV_ERROR_INVALID_CFG, target_block->label()) + << "Case construct that targets " + << _.getIdName(target_block->id()) + << " has invalid branch to block " << _.getIdName(block->id()) + << " (not another case construct, corresponding merge, outer " + "loop merge or outer loop continue)"; + } + + if (*case_fall_through == 0u) { + if (target_block != block) { + *case_fall_through = block->id(); + } + } else if (*case_fall_through != block->id()) { + // Case construct has at most one branch to another case construct. + return _.diag(SPV_ERROR_INVALID_CFG, target_block->label()) + << "Case construct that targets " + << _.getIdName(target_block->id()) + << " has branches to multiple other case construct targets " + << _.getIdName(*case_fall_through) << " and " + << _.getIdName(block->id()); + } + } + } + + return SPV_SUCCESS; +} + +spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function, + const Instruction* switch_inst, + const BasicBlock* header, + const BasicBlock* merge) { + std::unordered_set case_targets; + for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) { + uint32_t target = switch_inst->GetOperandAs(i); + if (target != merge->id()) case_targets.insert(target); + } + // Tracks how many times each case construct is targeted by another case + // construct. 
+ std::map num_fall_through_targeted; + uint32_t default_case_fall_through = 0u; + uint32_t default_target = switch_inst->GetOperandAs(1u); + std::unordered_set seen; + for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) { + uint32_t target = switch_inst->GetOperandAs(i); + if (target == merge->id()) continue; + + if (!seen.insert(target).second) continue; + + const auto target_block = function->GetBlock(target).first; + // OpSwitch must dominate all its case constructs. + if (header->reachable() && target_block->reachable() && + !header->dominates(*target_block)) { + return _.diag(SPV_ERROR_INVALID_CFG, header->label()) + << "Selection header " << _.getIdName(header->id()) + << " does not dominate its case construct " << _.getIdName(target); + } + + uint32_t case_fall_through = 0u; + if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through, + merge, case_targets, function)) { + return error; + } + + // Track how many time the fall through case has been targeted. + if (case_fall_through != 0u) { + auto where = num_fall_through_targeted.lower_bound(case_fall_through); + if (where == num_fall_through_targeted.end() || + where->first != case_fall_through) { + num_fall_through_targeted.insert(where, + std::make_pair(case_fall_through, 1)); + } else { + where->second++; + } + } + + if (case_fall_through == default_target) { + case_fall_through = default_case_fall_through; + } + if (case_fall_through != 0u) { + bool is_default = i == 1; + if (is_default) { + default_case_fall_through = case_fall_through; + } else { + // Allow code like: + // case x: + // case y: + // ... + // case z: + // + // Where x and y target the same block and fall through to z. 
+ uint32_t j = i; + while ((j + 2 < switch_inst->operands().size()) && + target == switch_inst->GetOperandAs(j + 2)) { + j += 2; + } + // If Target T1 branches to Target T2, or if Target T1 branches to the + // Default target and the Default target branches to Target T2, then T1 + // must immediately precede T2 in the list of OpSwitch Target operands. + if ((switch_inst->operands().size() < j + 2) || + (case_fall_through != switch_inst->GetOperandAs(j + 2))) { + return _.diag(SPV_ERROR_INVALID_CFG, switch_inst) + << "Case construct that targets " << _.getIdName(target) + << " has branches to the case construct that targets " + << _.getIdName(case_fall_through) + << ", but does not immediately precede it in the " + "OpSwitch's target list"; + } + } + } + } + + // Each case construct must be branched to by at most one other case + // construct. + for (const auto& pair : num_fall_through_targeted) { + if (pair.second > 1) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(pair.first)) + << "Multiple case constructs have branches to the case construct " + "that targets " + << _.getIdName(pair.first); + } + } + + return SPV_SUCCESS; +} + +spv_result_t StructuredControlFlowChecks( + ValidationState_t& _, Function* function, + const std::vector>& back_edges) { + /// Check all backedges target only loop headers and have exactly one + /// back-edge branching to it + + // Map a loop header to blocks with back-edges to the loop header. 
+ std::map> loop_latch_blocks; + for (auto back_edge : back_edges) { + uint32_t back_edge_block; + uint32_t header_block; + std::tie(back_edge_block, header_block) = back_edge; + if (!function->IsBlockType(header_block, kBlockTypeLoop)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(back_edge_block)) + << "Back-edges (" << _.getIdName(back_edge_block) << " -> " + << _.getIdName(header_block) + << ") can only be formed between a block and a loop header."; + } + loop_latch_blocks[header_block].insert(back_edge_block); + } + + // Check the loop headers have exactly one back-edge branching to it + for (BasicBlock* loop_header : function->ordered_blocks()) { + if (!loop_header->reachable()) continue; + if (!loop_header->is_type(kBlockTypeLoop)) continue; + auto loop_header_id = loop_header->id(); + auto num_latch_blocks = loop_latch_blocks[loop_header_id].size(); + if (num_latch_blocks != 1) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(loop_header_id)) + << "Loop header " << _.getIdName(loop_header_id) + << " is targeted by " << num_latch_blocks + << " back-edge blocks but the standard requires exactly one"; + } + } + + // Check construct rules + for (const Construct& construct : function->constructs()) { + auto header = construct.entry_block(); + auto merge = construct.exit_block(); + + if (header->reachable() && !merge) { + std::string construct_name, header_name, exit_name; + std::tie(construct_name, header_name, exit_name) = + ConstructNames(construct.type()); + return _.diag(SPV_ERROR_INTERNAL, _.FindDef(header->id())) + << "Construct " + construct_name + " with " + header_name + " " + + _.getIdName(header->id()) + " does not have a " + + exit_name + ". This may be a bug in the validator."; + } + + // If the exit block is reachable then it's dominated by the + // header. 
+ if (merge && merge->reachable()) { + if (!header->dominates(*merge)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id())) + << ConstructErrorString(construct, _.getIdName(header->id()), + _.getIdName(merge->id()), + "does not dominate"); + } + // If it's really a merge block for a selection or loop, then it must be + // *strictly* dominated by the header. + if (construct.ExitBlockIsMergeBlock() && (header == merge)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id())) + << ConstructErrorString(construct, _.getIdName(header->id()), + _.getIdName(merge->id()), + "does not strictly dominate"); + } + } + // Check post-dominance for continue constructs. But dominance and + // post-dominance only make sense when the construct is reachable. + if (header->reachable() && construct.type() == ConstructType::kContinue) { + if (!merge->postdominates(*header)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(merge->id())) + << ConstructErrorString(construct, _.getIdName(header->id()), + _.getIdName(merge->id()), + "is not post dominated by"); + } + } + + // Check that for all non-header blocks, all predecessors are within this + // construct. + Construct::ConstructBlockSet construct_blocks = construct.blocks(function); + for (auto block : construct_blocks) { + if (block == header) continue; + for (auto pred : *block->predecessors()) { + if (pred->reachable() && !construct_blocks.count(pred)) { + std::string construct_name, header_name, exit_name; + std::tie(construct_name, header_name, exit_name) = + ConstructNames(construct.type()); + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(pred->id())) + << "block " << pred->id() << " branches to the " + << construct_name << " construct, but not to the " + << header_name << " " << header->id(); + } + } + } + + // Checks rules for case constructs. 
+ if (construct.type() == ConstructType::kSelection && + header->terminator()->opcode() == SpvOpSwitch) { + const auto terminator = header->terminator(); + if (auto error = + StructuredSwitchChecks(_, function, terminator, header, merge)) { + return error; + } + } + } + return SPV_SUCCESS; +} + +spv_result_t PerformCfgChecks(ValidationState_t& _) { + for (auto& function : _.functions()) { + // Check all referenced blocks are defined within a function + if (function.undefined_block_count() != 0) { + std::string undef_blocks("{"); + bool first = true; + for (auto undefined_block : function.undefined_blocks()) { + undef_blocks += _.getIdName(undefined_block); + if (!first) { + undef_blocks += " "; + } + first = false; + } + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(function.id())) + << "Block(s) " << undef_blocks << "}" + << " are referenced but not defined in function " + << _.getIdName(function.id()); + } + + // Set each block's immediate dominator and immediate postdominator, + // and find all back-edges. + // + // We want to analyze all the blocks in the function, even in degenerate + // control flow cases including unreachable blocks. So use the augmented + // CFG to ensure we cover all the blocks. 
+ std::vector postorder; + std::vector postdom_postorder; + std::vector> back_edges; + auto ignore_block = [](const BasicBlock*) {}; + auto ignore_edge = [](const BasicBlock*, const BasicBlock*) {}; + if (!function.ordered_blocks().empty()) { + /// calculate dominators + CFA::DepthFirstTraversal( + function.first_block(), function.AugmentedCFGSuccessorsFunction(), + ignore_block, [&](const BasicBlock* b) { postorder.push_back(b); }, + ignore_edge); + auto edges = CFA::CalculateDominators( + postorder, function.AugmentedCFGPredecessorsFunction()); + for (auto edge : edges) { + edge.first->SetImmediateDominator(edge.second); + } + + /// calculate post dominators + CFA::DepthFirstTraversal( + function.pseudo_exit_block(), + function.AugmentedCFGPredecessorsFunction(), ignore_block, + [&](const BasicBlock* b) { postdom_postorder.push_back(b); }, + ignore_edge); + auto postdom_edges = CFA::CalculateDominators( + postdom_postorder, function.AugmentedCFGSuccessorsFunction()); + for (auto edge : postdom_edges) { + edge.first->SetImmediatePostDominator(edge.second); + } + /// calculate back edges. 
+ CFA::DepthFirstTraversal( + function.pseudo_entry_block(), + function + .AugmentedCFGSuccessorsFunctionIncludingHeaderToContinueEdge(), + ignore_block, ignore_block, + [&](const BasicBlock* from, const BasicBlock* to) { + back_edges.emplace_back(from->id(), to->id()); + }); + } + UpdateContinueConstructExitBlocks(function, back_edges); + + auto& blocks = function.ordered_blocks(); + if (!blocks.empty()) { + // Check if the order of blocks in the binary appear before the blocks + // they dominate + for (auto block = begin(blocks) + 1; block != end(blocks); ++block) { + if (auto idom = (*block)->immediate_dominator()) { + if (idom != function.pseudo_entry_block() && + block == std::find(begin(blocks), block, idom)) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(idom->id())) + << "Block " << _.getIdName((*block)->id()) + << " appears in the binary before its dominator " + << _.getIdName(idom->id()); + } + } + } + // If we have structed control flow, check that no block has a control + // flow nesting depth larger than the limit. 
+ if (_.HasCapability(SpvCapabilityShader)) { + const int control_flow_nesting_depth_limit = + _.options()->universal_limits_.max_control_flow_nesting_depth; + for (auto block = begin(blocks); block != end(blocks); ++block) { + if (function.GetBlockDepth(*block) > + control_flow_nesting_depth_limit) { + return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef((*block)->id())) + << "Maximum Control Flow nesting depth exceeded."; + } + } + } + } + + /// Structured control flow checks are only required for shader capabilities + if (_.HasCapability(SpvCapabilityShader)) { + if (auto error = StructuredControlFlowChecks(_, &function, back_edges)) + return error; + } + } + return SPV_SUCCESS; +} + +spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) { + SpvOp opcode = inst->opcode(); + switch (opcode) { + case SpvOpLabel: + if (auto error = _.current_function().RegisterBlock(inst->id())) + return error; + + // TODO(github:1661) This should be done in the + // ValidationState::RegisterInstruction method but because of the order of + // passes the OpLabel ends up not being part of the basic block it starts. 
+ _.current_function().current_block()->set_label(inst); + break; + case SpvOpLoopMerge: { + uint32_t merge_block = inst->GetOperandAs(0); + uint32_t continue_block = inst->GetOperandAs(1); + CFG_ASSERT(MergeBlockAssert, merge_block); + + if (auto error = _.current_function().RegisterLoopMerge(merge_block, + continue_block)) + return error; + } break; + case SpvOpSelectionMerge: { + uint32_t merge_block = inst->GetOperandAs(0); + CFG_ASSERT(MergeBlockAssert, merge_block); + + if (auto error = _.current_function().RegisterSelectionMerge(merge_block)) + return error; + } break; + case SpvOpBranch: { + uint32_t target = inst->GetOperandAs(0); + CFG_ASSERT(FirstBlockAssert, target); + + _.current_function().RegisterBlockEnd({target}, opcode); + } break; + case SpvOpBranchConditional: { + uint32_t tlabel = inst->GetOperandAs(1); + uint32_t flabel = inst->GetOperandAs(2); + CFG_ASSERT(FirstBlockAssert, tlabel); + CFG_ASSERT(FirstBlockAssert, flabel); + + _.current_function().RegisterBlockEnd({tlabel, flabel}, opcode); + } break; + + case SpvOpSwitch: { + std::vector cases; + for (size_t i = 1; i < inst->operands().size(); i += 2) { + uint32_t target = inst->GetOperandAs(i); + CFG_ASSERT(FirstBlockAssert, target); + cases.push_back(target); + } + _.current_function().RegisterBlockEnd({cases}, opcode); + } break; + case SpvOpReturn: { + const uint32_t return_type = _.current_function().GetResultTypeId(); + const Instruction* return_type_inst = _.FindDef(return_type); + assert(return_type_inst); + if (return_type_inst->opcode() != SpvOpTypeVoid) + return _.diag(SPV_ERROR_INVALID_CFG, inst) + << "OpReturn can only be called from a function with void " + << "return type."; + } + // Fallthrough. 
+ case SpvOpKill: + case SpvOpReturnValue: + case SpvOpUnreachable: + _.current_function().RegisterBlockEnd(std::vector(), opcode); + if (opcode == SpvOpKill) { + _.current_function().RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "OpKill requires Fragment execution model"); + } + break; + default: + break; + } + return SPV_SUCCESS; +} + +spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpPhi: + if (auto error = ValidatePhi(_, inst)) return error; + break; + case SpvOpBranchConditional: + if (auto error = ValidateBranchConditional(_, inst)) return error; + break; + case SpvOpReturnValue: + if (auto error = ValidateReturnValue(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp new file mode 100644 index 000000000..ccc558773 --- /dev/null +++ b/source/val/validate_composites.cpp @@ -0,0 +1,521 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of composite SPIR-V instructions. 
+ +#include "source/val/validate.h" + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Returns the type of the value accessed by OpCompositeExtract or +// OpCompositeInsert instruction. The function traverses the hierarchy of +// nested data structures (structs, arrays, vectors, matrices) as directed by +// the sequence of indices in the instruction. May return error if traversal +// fails (encountered non-composite, out of bounds, nesting too deep). +// Returns the type of Composite operand if the instruction has no indices. +spv_result_t GetExtractInsertValueType(ValidationState_t& _, + const Instruction* inst, + uint32_t* member_type) { + const SpvOp opcode = inst->opcode(); + assert(opcode == SpvOpCompositeExtract || opcode == SpvOpCompositeInsert); + uint32_t word_index = opcode == SpvOpCompositeExtract ? 4 : 5; + const uint32_t num_words = static_cast(inst->words().size()); + const uint32_t composite_id_index = word_index - 1; + + const uint32_t num_indices = num_words - word_index; + const uint32_t kCompositeExtractInsertMaxNumIndices = 255; + if (num_indices > kCompositeExtractInsertMaxNumIndices) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The number of indexes in Op" << spvOpcodeString(opcode) + << " may not exceed " << kCompositeExtractInsertMaxNumIndices + << ". 
Found " << num_indices << " indexes."; + } + + *member_type = _.GetTypeId(inst->word(composite_id_index)); + if (*member_type == 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Composite to be an object of composite type"; + } + + for (; word_index < num_words; ++word_index) { + const uint32_t component_index = inst->word(word_index); + const Instruction* const type_inst = _.FindDef(*member_type); + assert(type_inst); + switch (type_inst->opcode()) { + case SpvOpTypeVector: { + *member_type = type_inst->word(2); + const uint32_t vector_size = type_inst->word(3); + if (component_index >= vector_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Vector access is out of bounds, vector size is " + << vector_size << ", but access index is " << component_index; + } + break; + } + case SpvOpTypeMatrix: { + *member_type = type_inst->word(2); + const uint32_t num_cols = type_inst->word(3); + if (component_index >= num_cols) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Matrix access is out of bounds, matrix has " << num_cols + << " columns, but access index is " << component_index; + } + break; + } + case SpvOpTypeArray: { + uint64_t array_size = 0; + auto size = _.FindDef(type_inst->word(3)); + *member_type = type_inst->word(2); + if (spvOpcodeIsSpecConstant(size->opcode())) { + // Cannot verify against the size of this array. + break; + } + + if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) { + assert(0 && "Array type definition is corrupt"); + } + if (component_index >= array_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Array access is out of bounds, array size is " + << array_size << ", but access index is " << component_index; + } + break; + } + case SpvOpTypeRuntimeArray: { + *member_type = type_inst->word(2); + // Array size is unknown. 
+ break; + } + case SpvOpTypeStruct: { + const size_t num_struct_members = type_inst->words().size() - 2; + if (component_index >= num_struct_members) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Index is out of bounds, can not find index " + << component_index << " in the structure '" + << type_inst->id() << "'. This structure has " + << num_struct_members << " members. Largest valid index is " + << num_struct_members - 1 << "."; + } + *member_type = type_inst->word(component_index + 2); + break; + } + default: + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Reached non-composite type while indexes still remain to " + "be traversed."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateVectorExtractDynamic(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + const SpvOp result_opcode = _.GetIdOpcode(result_type); + if (!spvOpcodeIsScalarType(result_opcode)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a scalar type"; + } + + const uint32_t vector_type = _.GetOperandTypeId(inst, 2); + const SpvOp vector_opcode = _.GetIdOpcode(vector_type); + if (vector_opcode != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Vector type to be OpTypeVector"; + } + + if (_.GetComponentType(vector_type) != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Vector component type to be equal to Result Type"; + } + + const auto index = _.FindDef(inst->GetOperandAs(3)); + if (!index || index->type_id() == 0 || !_.IsIntScalarType(index->type_id())) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Index to be int scalar"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateVectorInsertDyanmic(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + const SpvOp result_opcode = _.GetIdOpcode(result_type); + if (result_opcode != SpvOpTypeVector) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypeVector"; + } + + const uint32_t vector_type = _.GetOperandTypeId(inst, 2); + if (vector_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Vector type to be equal to Result Type"; + } + + const uint32_t component_type = _.GetOperandTypeId(inst, 3); + if (_.GetComponentType(result_type) != component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Component type to be equal to Result Type " + << "component type"; + } + + const uint32_t index_type = _.GetOperandTypeId(inst, 4); + if (!_.IsIntScalarType(index_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Index to be int scalar"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCompositeConstruct(ValidationState_t& _, + const Instruction* inst) { + const uint32_t num_operands = static_cast(inst->operands().size()); + const uint32_t result_type = inst->type_id(); + const SpvOp result_opcode = _.GetIdOpcode(result_type); + switch (result_opcode) { + case SpvOpTypeVector: { + const uint32_t num_result_components = _.GetDimension(result_type); + const uint32_t result_component_type = _.GetComponentType(result_type); + uint32_t given_component_count = 0; + + if (num_operands <= 3) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected number of constituents to be at least 2"; + } + + for (uint32_t operand_index = 2; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (operand_type == result_component_type) { + ++given_component_count; + } else { + if (_.GetIdOpcode(operand_type) != SpvOpTypeVector || + _.GetComponentType(operand_type) != result_component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituents to be scalars or vectors of" + << " the same type as Result Type components"; + } + + given_component_count += _.GetDimension(operand_type); + } + } + + 
if (num_result_components != given_component_count) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected total number of given components to be equal " + << "to the size of Result Type vector"; + } + + break; + } + case SpvOpTypeMatrix: { + uint32_t result_num_rows = 0; + uint32_t result_num_cols = 0; + uint32_t result_col_type = 0; + uint32_t result_component_type = 0; + if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols, + &result_col_type, &result_component_type)) { + assert(0); + } + + if (result_num_cols + 2 != num_operands) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected total number of Constituents to be equal " + << "to the number of columns of Result Type matrix"; + } + + for (uint32_t operand_index = 2; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (operand_type != result_col_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituent type to be equal to the column " + << "type Result Type matrix"; + } + } + + break; + } + case SpvOpTypeArray: { + const Instruction* const array_inst = _.FindDef(result_type); + assert(array_inst); + assert(array_inst->opcode() == SpvOpTypeArray); + + auto size = _.FindDef(array_inst->word(3)); + if (spvOpcodeIsSpecConstant(size->opcode())) { + // Cannot verify against the size of this array. 
+ break; + } + + uint64_t array_size = 0; + if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) { + assert(0 && "Array type definition is corrupt"); + } + + if (array_size + 2 != num_operands) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected total number of Constituents to be equal " + << "to the number of elements of Result Type array"; + } + + const uint32_t result_component_type = array_inst->word(2); + for (uint32_t operand_index = 2; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (operand_type != result_component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituent type to be equal to the column " + << "type Result Type array"; + } + } + + break; + } + case SpvOpTypeStruct: { + const Instruction* const struct_inst = _.FindDef(result_type); + assert(struct_inst); + assert(struct_inst->opcode() == SpvOpTypeStruct); + + if (struct_inst->operands().size() + 1 != num_operands) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected total number of Constituents to be equal " + << "to the number of members of Result Type struct"; + } + + for (uint32_t operand_index = 2; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + const uint32_t member_type = struct_inst->word(operand_index); + if (operand_type != member_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Constituent type to be equal to the " + << "corresponding member type of Result Type struct"; + } + } + + break; + } + default: { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a composite type"; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCompositeExtract(ValidationState_t& _, + const Instruction* inst) { + uint32_t member_type = 0; + if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) { + return error; + } + + 
const uint32_t result_type = inst->type_id(); + if (result_type != member_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result type (Op" << spvOpcodeString(_.GetIdOpcode(result_type)) + << ") does not match the type that results from indexing into " + "the composite (Op" + << spvOpcodeString(_.GetIdOpcode(member_type)) << ")."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCompositeInsert(ValidationState_t& _, + const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t object_type = _.GetOperandTypeId(inst, 2); + const uint32_t composite_type = _.GetOperandTypeId(inst, 3); + const uint32_t result_type = inst->type_id(); + if (result_type != composite_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The Result Type must be the same as Composite type in Op" + << spvOpcodeString(opcode) << " yielding Result Id " << result_type + << "."; + } + + uint32_t member_type = 0; + if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) { + return error; + } + + if (object_type != member_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The Object type (Op" + << spvOpcodeString(_.GetIdOpcode(object_type)) + << ") does not match the type that results from indexing into the " + "Composite (Op" + << spvOpcodeString(_.GetIdOpcode(member_type)) << ")."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateCopyObject(ValidationState_t& _, const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + const uint32_t operand_type = _.GetOperandTypeId(inst, 2); + if (operand_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type and Operand type to be the same"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) { + uint32_t result_num_rows = 0; + uint32_t result_num_cols = 0; + uint32_t result_col_type = 0; + uint32_t result_component_type = 0; + const uint32_t result_type = 
inst->type_id(); + if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols, + &result_col_type, &result_component_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a matrix type"; + } + + const uint32_t matrix_type = _.GetOperandTypeId(inst, 2); + uint32_t matrix_num_rows = 0; + uint32_t matrix_num_cols = 0; + uint32_t matrix_col_type = 0; + uint32_t matrix_component_type = 0; + if (!_.GetMatrixTypeInfo(matrix_type, &matrix_num_rows, &matrix_num_cols, + &matrix_col_type, &matrix_component_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Matrix to be of type OpTypeMatrix"; + } + + if (result_component_type != matrix_component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected component types of Matrix and Result Type to be " + << "identical"; + } + + if (result_num_rows != matrix_num_cols || + result_num_cols != matrix_num_rows) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected number of columns and the column size of Matrix " + << "to be the reverse of those of Result Type"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateVectorShuffle(ValidationState_t& _, + const Instruction* inst) { + auto resultType = _.FindDef(inst->type_id()); + if (!resultType || resultType->opcode() != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The Result Type of OpVectorShuffle must be" + << " OpTypeVector. Found Op" + << spvOpcodeString(static_cast(resultType->opcode())) << "."; + } + + // The number of components in Result Type must be the same as the number of + // Component operands. 
+ auto componentCount = inst->operands().size() - 4; + auto resultVectorDimension = resultType->GetOperandAs(2); + if (componentCount != resultVectorDimension) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVectorShuffle component literals count does not match " + "Result Type '" + << _.getIdName(resultType->id()) << "'s vector component count."; + } + + // Vector 1 and Vector 2 must both have vector types, with the same Component + // Type as Result Type. + auto vector1Object = _.FindDef(inst->GetOperandAs(2)); + auto vector1Type = _.FindDef(vector1Object->type_id()); + auto vector2Object = _.FindDef(inst->GetOperandAs(3)); + auto vector2Type = _.FindDef(vector2Object->type_id()); + if (!vector1Type || vector1Type->opcode() != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The type of Vector 1 must be OpTypeVector."; + } + if (!vector2Type || vector2Type->opcode() != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The type of Vector 2 must be OpTypeVector."; + } + + auto resultComponentType = resultType->GetOperandAs(1); + if (vector1Type->GetOperandAs(1) != resultComponentType) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The Component Type of Vector 1 must be the same as ResultType."; + } + if (vector2Type->GetOperandAs(1) != resultComponentType) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The Component Type of Vector 2 must be the same as ResultType."; + } + + // All Component literals must either be FFFFFFFF or in [0, N - 1]. + // For WebGPU specifically, Component literals cannot be FFFFFFFF. 
+ auto vector1ComponentCount = vector1Type->GetOperandAs(2); + auto vector2ComponentCount = vector2Type->GetOperandAs(2); + auto N = vector1ComponentCount + vector2ComponentCount; + auto firstLiteralIndex = 4; + const auto is_webgpu_env = spvIsWebGPUEnv(_.context()->target_env); + for (size_t i = firstLiteralIndex; i < inst->operands().size(); ++i) { + auto literal = inst->GetOperandAs(i); + if (literal != 0xFFFFFFFF && literal >= N) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Component index " << literal << " is out of bounds for " + << "combined (Vector1 + Vector2) size of " << N << "."; + } + + if (is_webgpu_env && literal == 0xFFFFFFFF) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Component literal at operand " << i - firstLiteralIndex + << " cannot be 0xFFFFFFFF in WebGPU execution environment."; + } + } + + return SPV_SUCCESS; +} + +} // anonymous namespace + +// Validates correctness of composite instructions. +spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpVectorExtractDynamic: + return ValidateVectorExtractDynamic(_, inst); + case SpvOpVectorInsertDynamic: + return ValidateVectorInsertDyanmic(_, inst); + case SpvOpVectorShuffle: + return ValidateVectorShuffle(_, inst); + case SpvOpCompositeConstruct: + return ValidateCompositeConstruct(_, inst); + case SpvOpCompositeExtract: + return ValidateCompositeExtract(_, inst); + case SpvOpCompositeInsert: + return ValidateCompositeInsert(_, inst); + case SpvOpCopyObject: + return ValidateCopyObject(_, inst); + case SpvOpTranspose: + return ValidateTranspose(_, inst); + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp new file mode 100644 index 000000000..e2f20f672 --- /dev/null +++ b/source/val/validate_constants.cpp @@ -0,0 +1,407 @@ +// Copyright (c) 2018 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateConstantBool(ValidationState_t& _, + const Instruction* inst) { + auto type = _.FindDef(inst->type_id()); + if (!type || type->opcode() != SpvOpTypeBool) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Op" << spvOpcodeString(inst->opcode()) << " Result Type '" + << _.getIdName(inst->type_id()) << "' is not a boolean type."; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateConstantComposite(ValidationState_t& _, + const Instruction* inst) { + std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode()); + + const auto result_type = _.FindDef(inst->type_id()); + if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Result Type '" + << _.getIdName(inst->type_id()) << "' is not a composite type."; + } + + const auto constituent_count = inst->words().size() - 3; + switch (result_type->opcode()) { + case SpvOpTypeVector: { + const auto component_count = result_type->GetOperandAs(2); + if (component_count != constituent_count) { + // TODO: Output ID's on diagnostic + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name + << " Constituent count does not match " + "Result Type '" + << 
_.getIdName(result_type->id()) + << "'s vector component count."; + } + const auto component_type = + _.FindDef(result_type->GetOperandAs(1)); + if (!component_type) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "Component type is not defined."; + } + for (size_t constituent_index = 2; + constituent_index < inst->operands().size(); constituent_index++) { + const auto constituent_id = + inst->GetOperandAs(constituent_index); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || + !spvOpcodeIsConstantOrUndef(constituent->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant or undef."; + } + const auto constituent_result_type = _.FindDef(constituent->type_id()); + if (!constituent_result_type || + component_type->opcode() != constituent_result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "'s type does not match Result Type '" + << _.getIdName(result_type->id()) << "'s vector element type."; + } + } + } break; + case SpvOpTypeMatrix: { + const auto column_count = result_type->GetOperandAs(2); + if (column_count != constituent_count) { + // TODO: Output ID's on diagnostic + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name + << " Constituent count does not match " + "Result Type '" + << _.getIdName(result_type->id()) << "'s matrix column count."; + } + + const auto column_type = _.FindDef(result_type->words()[2]); + if (!column_type) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "Column type is not defined."; + } + const auto component_count = column_type->GetOperandAs(2); + const auto component_type = + _.FindDef(column_type->GetOperandAs(1)); + if (!component_type) { + return _.diag(SPV_ERROR_INVALID_ID, column_type) + << "Component type is not defined."; + } + + for (size_t constituent_index = 2; + constituent_index < 
inst->operands().size(); constituent_index++) { + const auto constituent_id = + inst->GetOperandAs(constituent_index); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || + !(SpvOpConstantComposite == constituent->opcode() || + SpvOpSpecConstantComposite == constituent->opcode() || + SpvOpUndef == constituent->opcode())) { + // The message says "... or undef" because the spec does not say + // undef is a constant. + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant composite or undef."; + } + const auto vector = _.FindDef(constituent->type_id()); + if (!vector) { + return _.diag(SPV_ERROR_INVALID_ID, constituent) + << "Result type is not defined."; + } + if (column_type->opcode() != vector->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' type does not match Result Type '" + << _.getIdName(result_type->id()) << "'s matrix column type."; + } + const auto vector_component_type = + _.FindDef(vector->GetOperandAs(1)); + if (component_type->id() != vector_component_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' component type does not match Result Type '" + << _.getIdName(result_type->id()) + << "'s matrix column component type."; + } + if (component_count != vector->words()[3]) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' vector component count does not match Result Type '" + << _.getIdName(result_type->id()) + << "'s vector component count."; + } + } + } break; + case SpvOpTypeArray: { + auto element_type = _.FindDef(result_type->GetOperandAs(1)); + if (!element_type) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "Element type is not defined."; + } + const auto length = 
_.FindDef(result_type->GetOperandAs(2)); + if (!length) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "Length is not defined."; + } + bool is_int32; + bool is_const; + uint32_t value; + std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id()); + if (is_int32 && is_const && value != constituent_count) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name + << " Constituent count does not match " + "Result Type '" + << _.getIdName(result_type->id()) << "'s array length."; + } + for (size_t constituent_index = 2; + constituent_index < inst->operands().size(); constituent_index++) { + const auto constituent_id = + inst->GetOperandAs(constituent_index); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || + !spvOpcodeIsConstantOrUndef(constituent->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant or undef."; + } + const auto constituent_type = _.FindDef(constituent->type_id()); + if (!constituent_type) { + return _.diag(SPV_ERROR_INVALID_ID, constituent) + << "Result type is not defined."; + } + if (element_type->id() != constituent_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "'s type does not match Result Type '" + << _.getIdName(result_type->id()) << "'s array element type."; + } + } + } break; + case SpvOpTypeStruct: { + const auto member_count = result_type->words().size() - 2; + if (member_count != constituent_count) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(inst->type_id()) + << "' count does not match Result Type '" + << _.getIdName(result_type->id()) << "'s struct member count."; + } + for (uint32_t constituent_index = 2, member_index = 1; + constituent_index < inst->operands().size(); + constituent_index++, member_index++) { + const auto constituent_id = + 
inst->GetOperandAs(constituent_index); + const auto constituent = _.FindDef(constituent_id); + if (!constituent || + !spvOpcodeIsConstantOrUndef(constituent->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' is not a constant or undef."; + } + const auto constituent_type = _.FindDef(constituent->type_id()); + if (!constituent_type) { + return _.diag(SPV_ERROR_INVALID_ID, constituent) + << "Result type is not defined."; + } + + const auto member_type_id = + result_type->GetOperandAs(member_index); + const auto member_type = _.FindDef(member_type_id); + if (!member_type || member_type->id() != constituent_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << opcode_name << " Constituent '" + << _.getIdName(constituent_id) + << "' type does not match the Result Type '" + << _.getIdName(result_type->id()) << "'s member type."; + } + } + } break; + default: + break; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateConstantSampler(ValidationState_t& _, + const Instruction* inst) { + const auto result_type = _.FindDef(inst->type_id()); + if (!result_type || result_type->opcode() != SpvOpTypeSampler) { + return _.diag(SPV_ERROR_INVALID_ID, result_type) + << "OpConstantSampler Result Type '" + << _.getIdName(inst->type_id()) << "' is not a sampler type."; + } + + return SPV_SUCCESS; +} + +// True if instruction defines a type that can have a null value, as defined by +// the SPIR-V spec. Tracks composite-type components through module to check +// nullability transitively. 
+bool IsTypeNullable(const std::vector& instruction, + const ValidationState_t& _) { + uint16_t opcode; + uint16_t word_count; + spvOpcodeSplit(instruction[0], &word_count, &opcode); + switch (static_cast(opcode)) { + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + case SpvOpTypePointer: + case SpvOpTypeEvent: + case SpvOpTypeDeviceEvent: + case SpvOpTypeReserveId: + case SpvOpTypeQueue: + return true; + case SpvOpTypeArray: + case SpvOpTypeMatrix: + case SpvOpTypeVector: { + auto base_type = _.FindDef(instruction[2]); + return base_type && IsTypeNullable(base_type->words(), _); + } + case SpvOpTypeStruct: { + for (size_t elementIndex = 2; elementIndex < instruction.size(); + ++elementIndex) { + auto element = _.FindDef(instruction[elementIndex]); + if (!element || !IsTypeNullable(element->words(), _)) return false; + } + return true; + } + default: + return false; + } +} + +spv_result_t ValidateConstantNull(ValidationState_t& _, + const Instruction* inst) { + const auto result_type = _.FindDef(inst->type_id()); + if (!result_type || !IsTypeNullable(result_type->words(), _)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpConstantNull Result Type '" + << _.getIdName(inst->type_id()) << "' cannot have a null value."; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateSpecConstantOp(ValidationState_t& _, + const Instruction* inst) { + const auto op = inst->GetOperandAs(2); + + // The binary parser already ensures that the op is valid for *some* + // environment. Here we check restrictions. 
+ switch(op) { + case SpvOpQuantizeToF16: + if (!_.HasCapability(SpvCapabilityShader)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Specialization constant operation " << spvOpcodeString(op) + << " requires Shader capability"; + } + break; + + case SpvOpUConvert: + if (!_.features().uconvert_spec_constant_op && + !_.HasCapability(SpvCapabilityKernel)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "UConvert requires Kernel capability or extension " + "SPV_AMD_gpu_shader_int16"; + } + break; + + case SpvOpConvertFToS: + case SpvOpConvertSToF: + case SpvOpConvertFToU: + case SpvOpConvertUToF: + case SpvOpConvertPtrToU: + case SpvOpConvertUToPtr: + case SpvOpGenericCastToPtr: + case SpvOpPtrCastToGeneric: + case SpvOpBitcast: + case SpvOpFNegate: + case SpvOpFAdd: + case SpvOpFSub: + case SpvOpFMul: + case SpvOpFDiv: + case SpvOpFRem: + case SpvOpFMod: + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpPtrAccessChain: + case SpvOpInBoundsPtrAccessChain: + if (!_.HasCapability(SpvCapabilityKernel)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Specialization constant operation " << spvOpcodeString(op) + << " requires Kernel capability"; + } + break; + + default: + break; + } + + // TODO(dneto): Validate result type and arguments to the various operations. 
+ return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpConstantTrue: + case SpvOpConstantFalse: + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: + if (auto error = ValidateConstantBool(_, inst)) return error; + break; + case SpvOpConstantComposite: + case SpvOpSpecConstantComposite: + if (auto error = ValidateConstantComposite(_, inst)) return error; + break; + case SpvOpConstantSampler: + if (auto error = ValidateConstantSampler(_, inst)) return error; + break; + case SpvOpConstantNull: + if (auto error = ValidateConstantNull(_, inst)) return error; + break; + case SpvOpSpecConstantOp: + if (auto error = ValidateSpecConstantOp(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp new file mode 100644 index 000000000..73da58255 --- /dev/null +++ b/source/val/validate_conversion.cpp @@ -0,0 +1,462 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of conversion instructions. 
+ +#include "source/val/validate.h" + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +// Validates correctness of conversion instructions. +spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); + + switch (opcode) { + case SpvOpConvertFToU: { + if (!_.IsUnsignedIntScalarType(result_type) && + !_.IsUnsignedIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected unsigned int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type || (!_.IsFloatScalarType(input_type) && + !_.IsFloatVectorType(input_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be float scalar or vector: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + + if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid cast to 8-bit integer from a floating-point: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpConvertFToS: { + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type || (!_.IsFloatScalarType(input_type) && + !_.IsFloatVectorType(input_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be float scalar or vector: " + << spvOpcodeString(opcode); + + if 
(_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + + if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid cast to 8-bit integer from a floating-point: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpConvertSToF: + case SpvOpConvertUToF: { + if (!_.IsFloatScalarType(result_type) && + !_.IsFloatVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected float scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type || + (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be int scalar or vector: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + + if (!_.features().use_int8_type && (8 == _.GetBitWidth(input_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid cast to floating-point from an 8-bit integer: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpUConvert: { + if (!_.IsUnsignedIntScalarType(result_type) && + !_.IsUnsignedIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected unsigned int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type || + (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be int scalar or vector: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != 
_.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + + if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have different bit width from Result " + "Type: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpSConvert: { + if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type || + (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be int scalar or vector: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + + if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have different bit width from Result " + "Type: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpFConvert: { + if (!_.IsFloatScalarType(result_type) && + !_.IsFloatVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected float scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type || (!_.IsFloatScalarType(input_type) && + !_.IsFloatVectorType(input_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be float scalar or vector: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << 
"Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + + if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have different bit width from Result " + "Type: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpQuantizeToF16: { + if ((!_.IsFloatScalarType(result_type) && + !_.IsFloatVectorType(result_type)) || + _.GetBitWidth(result_type) != 32) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 32-bit float scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (input_type != result_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input type to be equal to Result Type: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpConvertPtrToU: { + if (!_.IsUnsignedIntScalarType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected unsigned int scalar type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!_.IsPointerType(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be a pointer: " << spvOpcodeString(opcode); + + if (_.addressing_model() == SpvAddressingModelLogical) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Logical addressing not supported: " + << spvOpcodeString(opcode); + + if (_.addressing_model() == + SpvAddressingModelPhysicalStorageBuffer64EXT) { + uint32_t input_storage_class = 0; + uint32_t input_data_type = 0; + _.GetPointerTypeInfo(input_type, &input_data_type, + &input_storage_class); + if (input_storage_class != SpvStorageClassPhysicalStorageBufferEXT) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Pointer storage class must be PhysicalStorageBufferEXT: " + << spvOpcodeString(opcode); + } + break; + } + + case SpvOpSatConvertSToU: + case SpvOpSatConvertUToS: { + if 
(!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type || + (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar or vector as input: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have the same dimension as Result Type: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpConvertUToPtr: { + if (!_.IsPointerType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a pointer: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type || !_.IsIntScalarType(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected int scalar as input: " << spvOpcodeString(opcode); + + if (_.addressing_model() == SpvAddressingModelLogical) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Logical addressing not supported: " + << spvOpcodeString(opcode); + + if (_.addressing_model() == + SpvAddressingModelPhysicalStorageBuffer64EXT) { + uint32_t result_storage_class = 0; + uint32_t result_data_type = 0; + _.GetPointerTypeInfo(result_type, &result_data_type, + &result_storage_class); + if (result_storage_class != SpvStorageClassPhysicalStorageBufferEXT) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Pointer storage class must be PhysicalStorageBufferEXT: " + << spvOpcodeString(opcode); + } + break; + } + + case SpvOpPtrCastToGeneric: { + uint32_t result_storage_class = 0; + uint32_t result_data_type = 0; + if (!_.GetPointerTypeInfo(result_type, &result_data_type, + &result_storage_class)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << 
"Expected Result Type to be a pointer: " + << spvOpcodeString(opcode); + + if (result_storage_class != SpvStorageClassGeneric) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to have storage class Generic: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + uint32_t input_storage_class = 0; + uint32_t input_data_type = 0; + if (!_.GetPointerTypeInfo(input_type, &input_data_type, + &input_storage_class)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be a pointer: " << spvOpcodeString(opcode); + + if (input_storage_class != SpvStorageClassWorkgroup && + input_storage_class != SpvStorageClassCrossWorkgroup && + input_storage_class != SpvStorageClassFunction) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have storage class Workgroup, " + << "CrossWorkgroup or Function: " << spvOpcodeString(opcode); + + if (result_data_type != input_data_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input and Result Type to point to the same type: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpGenericCastToPtr: { + uint32_t result_storage_class = 0; + uint32_t result_data_type = 0; + if (!_.GetPointerTypeInfo(result_type, &result_data_type, + &result_storage_class)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a pointer: " + << spvOpcodeString(opcode); + + if (result_storage_class != SpvStorageClassWorkgroup && + result_storage_class != SpvStorageClassCrossWorkgroup && + result_storage_class != SpvStorageClassFunction) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to have storage class Workgroup, " + << "CrossWorkgroup or Function: " << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + uint32_t input_storage_class = 0; + uint32_t input_data_type = 0; + if (!_.GetPointerTypeInfo(input_type, &input_data_type, + &input_storage_class)) + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be a pointer: " << spvOpcodeString(opcode); + + if (input_storage_class != SpvStorageClassGeneric) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have storage class Generic: " + << spvOpcodeString(opcode); + + if (result_data_type != input_data_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input and Result Type to point to the same type: " + << spvOpcodeString(opcode); + break; + } + + case SpvOpGenericCastToPtrExplicit: { + uint32_t result_storage_class = 0; + uint32_t result_data_type = 0; + if (!_.GetPointerTypeInfo(result_type, &result_data_type, + &result_storage_class)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a pointer: " + << spvOpcodeString(opcode); + + const uint32_t target_storage_class = inst->word(4); + if (result_storage_class != target_storage_class) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be of target storage class: " + << spvOpcodeString(opcode); + + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + uint32_t input_storage_class = 0; + uint32_t input_data_type = 0; + if (!_.GetPointerTypeInfo(input_type, &input_data_type, + &input_storage_class)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be a pointer: " << spvOpcodeString(opcode); + + if (input_storage_class != SpvStorageClassGeneric) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have storage class Generic: " + << spvOpcodeString(opcode); + + if (result_data_type != input_data_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input and Result Type to point to the same type: " + << spvOpcodeString(opcode); + + if (target_storage_class != SpvStorageClassWorkgroup && + target_storage_class != SpvStorageClassCrossWorkgroup && + target_storage_class != SpvStorageClassFunction) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected target storage 
class to be Workgroup, " + << "CrossWorkgroup or Function: " << spvOpcodeString(opcode); + break; + } + + case SpvOpBitcast: { + const uint32_t input_type = _.GetOperandTypeId(inst, 2); + if (!input_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to have a type: " << spvOpcodeString(opcode); + + const bool result_is_pointer = _.IsPointerType(result_type); + const bool result_is_int_scalar = _.IsIntScalarType(result_type); + const bool input_is_pointer = _.IsPointerType(input_type); + const bool input_is_int_scalar = _.IsIntScalarType(input_type); + + if (!result_is_pointer && !result_is_int_scalar && + !_.IsIntVectorType(result_type) && + !_.IsFloatScalarType(result_type) && + !_.IsFloatVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a pointer or int or float vector " + << "or scalar type: " << spvOpcodeString(opcode); + + if (!input_is_pointer && !input_is_int_scalar && + !_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) && + !_.IsFloatVectorType(input_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be a pointer or int or float vector " + << "or scalar: " << spvOpcodeString(opcode); + + if (result_is_pointer && !input_is_pointer && !input_is_int_scalar) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected input to be a pointer or int scalar if Result Type " + << "is pointer: " << spvOpcodeString(opcode); + + if (input_is_pointer && !result_is_pointer && !result_is_int_scalar) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Pointer can only be converted to another pointer or int " + << "scalar: " << spvOpcodeString(opcode); + + if (!result_is_pointer && !input_is_pointer) { + const uint32_t result_size = + _.GetBitWidth(result_type) * _.GetDimension(result_type); + const uint32_t input_size = + _.GetBitWidth(input_type) * _.GetDimension(input_type); + if (result_size != input_size) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + 
<< "Expected input to have the same total bit width as " + << "Result Type: " << spvOpcodeString(opcode); + } + break; + } + + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_datarules.cpp b/source/val/validate_datarules.cpp new file mode 100644 index 000000000..129b6bbf9 --- /dev/null +++ b/source/val/validate_datarules.cpp @@ -0,0 +1,267 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Ensures Data Rules are followed according to the specifications. + +#include "source/val/validate.h" + +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Validates that the number of components in the vector is valid. +// Vector types can only be parameterized as having 2, 3, or 4 components. +// If the Vector16 capability is added, 8 and 16 components are also allowed. +spv_result_t ValidateVecNumComponents(ValidationState_t& _, + const Instruction* inst) { + // Operand 2 specifies the number of components in the vector. 
+ auto num_components = inst->GetOperandAs(2); + if (num_components == 2 || num_components == 3 || num_components == 4) { + return SPV_SUCCESS; + } + if (num_components == 8 || num_components == 16) { + if (_.HasCapability(SpvCapabilityVector16)) { + return SPV_SUCCESS; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Having " << num_components << " components for " + << spvOpcodeString(inst->opcode()) + << " requires the Vector16 capability"; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Illegal number of components (" << num_components << ") for " + << spvOpcodeString(inst->opcode()); +} + +// Validates that the number of bits specifed for a float type is valid. +// Scalar floating-point types can be parameterized only with 32-bits. +// Float16 capability allows using a 16-bit OpTypeFloat. +// Float16Buffer capability allows creation of a 16-bit OpTypeFloat. +// Float64 capability allows using a 64-bit OpTypeFloat. +spv_result_t ValidateFloatSize(ValidationState_t& _, const Instruction* inst) { + // Operand 1 is the number of bits for this float + auto num_bits = inst->GetOperandAs(1); + if (num_bits == 32) { + return SPV_SUCCESS; + } + if (num_bits == 16) { + if (_.features().declare_float16_type) { + return SPV_SUCCESS; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Using a 16-bit floating point " + << "type requires the Float16 or Float16Buffer capability," + " or an extension that explicitly enables 16-bit floating point."; + } + if (num_bits == 64) { + if (_.HasCapability(SpvCapabilityFloat64)) { + return SPV_SUCCESS; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Using a 64-bit floating point " + << "type requires the Float64 capability."; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat."; +} + +// Validates that the number of bits specified for an Int type is valid. +// Scalar integer types can be parameterized only with 32-bits. 
+// Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit +// integers, respectively. +spv_result_t ValidateIntSize(ValidationState_t& _, const Instruction* inst) { + // Operand 1 is the number of bits for this integer. + auto num_bits = inst->GetOperandAs(1); + if (num_bits == 32) { + return SPV_SUCCESS; + } + if (num_bits == 8) { + if (_.features().declare_int8_type) { + return SPV_SUCCESS; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Using an 8-bit integer type requires the Int8 capability," + " or an extension that explicitly enables 8-bit integers."; + } + if (num_bits == 16) { + if (_.features().declare_int16_type) { + return SPV_SUCCESS; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Using a 16-bit integer type requires the Int16 capability," + " or an extension that explicitly enables 16-bit integers."; + } + if (num_bits == 64) { + if (_.HasCapability(SpvCapabilityInt64)) { + return SPV_SUCCESS; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Using a 64-bit integer type requires the Int64 capability."; + } + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid number of bits (" << num_bits << ") used for OpTypeInt."; +} + +// Validates that the matrix is parameterized with floating-point types. +spv_result_t ValidateMatrixColumnType(ValidationState_t& _, + const Instruction* inst) { + // Find the component type of matrix columns (must be vector). + // Operand 1 is the of the type specified for matrix columns. + auto type_id = inst->GetOperandAs(1); + auto col_type_instr = _.FindDef(type_id); + if (col_type_instr->opcode() != SpvOpTypeVector) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Columns in a matrix must be of type vector."; + } + + // Trace back once more to find out the type of components in the vector. + // Operand 1 is the of the type of data in the vector. 
+ auto comp_type_id = + col_type_instr->words()[col_type_instr->operands()[1].offset]; + auto comp_type_instruction = _.FindDef(comp_type_id); + if (comp_type_instruction->opcode() != SpvOpTypeFloat) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be " + "parameterized with " + "floating-point types."; + } + return SPV_SUCCESS; +} + +// Validates that the matrix has 2,3, or 4 columns. +spv_result_t ValidateMatrixNumCols(ValidationState_t& _, + const Instruction* inst) { + // Operand 2 is the number of columns in the matrix. + auto num_cols = inst->GetOperandAs(2); + if (num_cols != 2 && num_cols != 3 && num_cols != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be " + "parameterized as having " + "only 2, 3, or 4 columns."; + } + return SPV_SUCCESS; +} + +// Validates that OpSpecConstant specializes to either int or float type. +spv_result_t ValidateSpecConstNumerical(ValidationState_t& _, + const Instruction* inst) { + // Operand 0 is the of the type that we're specializing to. + auto type_id = inst->GetOperandAs(0); + auto type_instruction = _.FindDef(type_id); + auto type_opcode = type_instruction->opcode(); + if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant " + "must be an integer or " + "floating-point number."; + } + return SPV_SUCCESS; +} + +// Validates that OpSpecConstantTrue and OpSpecConstantFalse specialize to bool. +spv_result_t ValidateSpecConstBoolean(ValidationState_t& _, + const Instruction* inst) { + // Find out the type that we're specializing to. + auto type_instruction = _.FindDef(inst->type_id()); + if (type_instruction->opcode() != SpvOpTypeBool) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Specialization constant must be a boolean type."; + } + return SPV_SUCCESS; +} + +// Records the of the forward pointer to be used for validation. 
+spv_result_t ValidateForwardPointer(ValidationState_t& _,
+                                    const Instruction* inst) {
+  // Record the <id> (which is operand 0) to ensure it's used properly.
+  // OpTypeStruct can only include undefined pointers that are
+  // previously declared as a ForwardPointer.
+  return _.RegisterForwardPointer(inst->GetOperandAs<uint32_t>(0));
+}
+
+// Validates that any undefined component of the struct is a forward pointer.
+// It is valid to declare a forward pointer, and use its <id> as one of the
+// components of a struct.
+//
+// Reports SPV_ERROR_INVALID_ID against |inst| for the first member type <id>
+// that is neither defined nor registered as a forward pointer.
+spv_result_t ValidateStruct(ValidationState_t& _, const Instruction* inst) {
+  // Struct components are operands 1, 2, etc.
+  for (unsigned i = 1; i < inst->operands().size(); i++) {
+    const auto type_id = inst->GetOperandAs<uint32_t>(i);
+    const auto type_instruction = _.FindDef(type_id);
+    if (type_instruction == nullptr && !_.IsForwardPointer(type_id)) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "Forward reference operands in an OpTypeStruct must first be "
+                "declared using OpTypeForwardPointer.";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+}  // namespace
+
+// Validates that Data Rules are followed according to the specifications.
+// (Data Rules subsection of 2.16.1 Universal Validation Rules)
+//
+// Dispatches on the opcode of |inst| to the matching check(s) and returns the
+// first non-success result; instructions with no data rule return SPV_SUCCESS.
+spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst) {
+  switch (inst->opcode()) {
+    case SpvOpTypeVector: {
+      if (auto error = ValidateVecNumComponents(_, inst)) return error;
+      break;
+    }
+    case SpvOpTypeFloat: {
+      if (auto error = ValidateFloatSize(_, inst)) return error;
+      break;
+    }
+    case SpvOpTypeInt: {
+      if (auto error = ValidateIntSize(_, inst)) return error;
+      break;
+    }
+    case SpvOpTypeMatrix: {
+      // A matrix must pass both the column-type and the column-count rule.
+      if (auto error = ValidateMatrixColumnType(_, inst)) return error;
+      if (auto error = ValidateMatrixNumCols(_, inst)) return error;
+      break;
+    }
+    // TODO(ehsan): Add OpSpecConstantComposite validation code.
+    // TODO(ehsan): Add OpSpecConstantOp validation code (if any).
+    case SpvOpSpecConstant: {
+      if (auto error = ValidateSpecConstNumerical(_, inst)) return error;
+      break;
+    }
+    case SpvOpSpecConstantFalse:
+    case SpvOpSpecConstantTrue: {
+      if (auto error = ValidateSpecConstBoolean(_, inst)) return error;
+      break;
+    }
+    case SpvOpTypeForwardPointer: {
+      if (auto error = ValidateForwardPointer(_, inst)) return error;
+      break;
+    }
+    case SpvOpTypeStruct: {
+      if (auto error = ValidateStruct(_, inst)) return error;
+      break;
+    }
+    // TODO(ehsan): add more data rules validation here.
+    default: { break; }
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_debug.cpp b/source/val/validate_debug.cpp
new file mode 100644
index 000000000..b49890f26
--- /dev/null
+++ b/source/val/validate_debug.cpp
@@ -0,0 +1,81 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/val/validate.h"
+
+#include "source/opcode.h"
+#include "source/spirv_target_env.h"
+#include "source/val/instruction.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+// Validates an OpMemberName instruction: operand 0 must name a struct type and
+// operand 1 must be a member index within that struct.  Reports
+// SPV_ERROR_INVALID_ID against |inst| otherwise.
+spv_result_t ValidateMemberName(ValidationState_t& _, const Instruction* inst) {
+  const auto type_id = inst->GetOperandAs<uint32_t>(0);
+  const auto type = _.FindDef(type_id);
+  if (!type || SpvOpTypeStruct != type->opcode()) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpMemberName Type '" << _.getIdName(type_id)
+           << "' is not a struct type.";
+  }
+  // Operand 1 is a member index, not an <id>.
+  const auto member_index = inst->GetOperandAs<uint32_t>(1);
+  // Every word of the struct declaration after the first two is counted as a
+  // member.
+  const auto member_count = static_cast<uint32_t>(type->words().size() - 2);
+  if (member_count <= member_index) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpMemberName Member '" << _.getIdName(member_index)
+           << "' index is larger than Type '" << _.getIdName(type->id())
+           << "'s member count.";
+  }
+  return SPV_SUCCESS;
+}
+
+// Validates an OpLine instruction: operand 0 must name an OpString.  Reports
+// SPV_ERROR_INVALID_ID against |inst| otherwise.
+spv_result_t ValidateLine(ValidationState_t& _, const Instruction* inst) {
+  const auto file_id = inst->GetOperandAs<uint32_t>(0);
+  const auto file = _.FindDef(file_id);
+  if (!file || SpvOpString != file->opcode()) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpLine Target '" << _.getIdName(file_id)
+           << "' is not an OpString.";
+  }
+  return SPV_SUCCESS;
+}
+
+}  // namespace
+
+// Validates debug instructions.  In a WebGPU target environment every debug
+// opcode is rejected with SPV_ERROR_INVALID_BINARY; otherwise OpMemberName and
+// OpLine are checked and all other instructions are accepted.
+spv_result_t DebugPass(ValidationState_t& _, const Instruction* inst) {
+  if (spvIsWebGPUEnv(_.context()->target_env) &&
+      spvOpcodeIsDebug(inst->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_BINARY, inst)
+           << "Debugging instructions are not allowed in the WebGPU execution "
+           << "environment.";
+  }
+
+  switch (inst->opcode()) {
+    case SpvOpMemberName:
+      if (auto error = ValidateMemberName(_, inst)) return error;
+      break;
+    case SpvOpLine:
+      if (auto error = ValidateLine(_, inst)) return error;
+      break;
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_decorations.cpp
b/source/val/validate_decorations.cpp
new file mode 100644
index 000000000..7f150aafa
--- /dev/null
+++ b/source/val/validate_decorations.cpp
@@ -0,0 +1,1276 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/val/validate.h"
+
+#include <algorithm>
+#include <cassert>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "source/diagnostic.h"
+#include "source/opcode.h"
+#include "source/spirv_target_env.h"
+#include "source/spirv_validator_options.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+// Distinguish between row and column major matrix layouts.
+enum MatrixLayout { kRowMajor, kColumnMajor };
+
+// A functor for hashing a pair of integers.
+struct PairHash {
+  // Combines the two halves by rotating the second value right by two bits
+  // and XOR-ing it with the first, so that (a, b) and (b, a) generally hash
+  // differently.
+  std::size_t operator()(const std::pair<uint32_t, uint32_t> pair) const {
+    const uint32_t a = pair.first;
+    const uint32_t b = pair.second;
+    const uint32_t rotated_b = (b >> 2) | ((b & 3) << 30);
+    return a ^ rotated_b;
+  }
+};
+
+// A functor for hashing decoration types.
+struct SpvDecorationHash {
+  // The enumerator's numeric value is used directly as the hash.
+  std::size_t operator()(SpvDecoration dec) const {
+    return static_cast<std::size_t>(dec);
+  }
+};
+
+// Struct member layout attributes that are inherited through arrays.
+struct LayoutConstraints {
+  // Defaults to a column-major layout with a matrix stride of zero.
+  explicit LayoutConstraints(
+      MatrixLayout the_majorness = MatrixLayout::kColumnMajor,
+      uint32_t stride = 0)
+      : majorness(the_majorness), matrix_stride(stride) {}
+  // Whether matrices are laid out row-major or column-major.
+  MatrixLayout majorness;
+  // Value of the MatrixStride decoration; zero when none was recorded.
+  uint32_t matrix_stride;
+};
+
+// A type for mapping (struct id, member id) to layout constraints.
+using MemberConstraints = std::unordered_map<std::pair<uint32_t, uint32_t>,
+                                             LayoutConstraints, PairHash>;
+
+// Returns the array stride of the given array type, i.e. the first parameter
+// of its ArrayStride decoration, or 0 when the type has no such decoration.
+uint32_t GetArrayStride(uint32_t array_id, ValidationState_t& vstate) {
+  for (auto& decoration : vstate.id_decorations(array_id)) {
+    if (SpvDecorationArrayStride == decoration.dec_type()) {
+      return decoration.params()[0];
+    }
+  }
+  return 0;
+}
+
+// Returns true if the given variable has a BuiltIn decoration.
+bool isBuiltInVar(uint32_t var_id, ValidationState_t& vstate) {
+  const auto& decorations = vstate.id_decorations(var_id);
+  return std::any_of(
+      decorations.begin(), decorations.end(),
+      [](const Decoration& d) { return SpvDecorationBuiltIn == d.dec_type(); });
+}
+
+// Returns true if the given structure type has any members with BuiltIn
+// decoration.
+bool isBuiltInStruct(uint32_t struct_id, ValidationState_t& vstate) {
+  const auto& decorations = vstate.id_decorations(struct_id);
+  return std::any_of(
+      decorations.begin(), decorations.end(), [](const Decoration& d) {
+        // Only decorations that target a member (rather than the struct as a
+        // whole) count.
+        return SpvDecorationBuiltIn == d.dec_type() &&
+               Decoration::kInvalidMember != d.struct_member_index();
+      });
+}
+
+// Returns true if the given ID has the Import LinkageAttributes decoration.
+bool hasImportLinkageAttribute(uint32_t id, ValidationState_t& vstate) {
+  const auto& decorations = vstate.id_decorations(id);
+  return std::any_of(decorations.begin(), decorations.end(),
+                     [](const Decoration& d) {
+                       // The linkage type is the last parameter word; at
+                       // least two words are required before it is examined.
+                       return SpvDecorationLinkageAttributes == d.dec_type() &&
+                              d.params().size() >= 2u &&
+                              d.params().back() == SpvLinkageTypeImport;
+                     });
+}
+
+// Returns a vector of all members of a structure.
+std::vector<uint32_t> getStructMembers(uint32_t struct_id,
+                                       ValidationState_t& vstate) {
+  const auto inst = vstate.FindDef(struct_id);
+  // The member type ids are every word after the first two.
+  return std::vector<uint32_t>(inst->words().begin() + 2, inst->words().end());
+}
+
+// Returns a vector of all members of a structure that have specific type.
+// |type| is the opcode a member's defining instruction must have to be kept.
+std::vector<uint32_t> getStructMembers(uint32_t struct_id, SpvOp type,
+                                       ValidationState_t& vstate) {
+  std::vector<uint32_t> members;
+  for (auto id : getStructMembers(struct_id, vstate)) {
+    if (type == vstate.FindDef(id)->opcode()) {
+      members.push_back(id);
+    }
+  }
+  return members;
+}
+
+// Returns whether the given structure is missing Offset decoration for any
+// member. Handles also nested structures.
+bool isMissingOffsetInStruct(uint32_t struct_id, ValidationState_t& vstate) {
+  std::vector<bool> hasOffset(getStructMembers(struct_id, vstate).size(),
+                              false);
+  // Check offsets of member decorations
+  for (auto& decoration : vstate.id_decorations(struct_id)) {
+    const auto member_index = decoration.struct_member_index();
+    // Skip decorations whose member index does not name a member of this
+    // struct, so that a malformed index can never write out of bounds.
+    if (SpvDecorationOffset == decoration.dec_type() &&
+        Decoration::kInvalidMember != member_index &&
+        static_cast<size_t>(member_index) < hasOffset.size()) {
+      hasOffset[member_index] = true;
+    }
+  }
+  // Check also nested structures
+  bool nestedStructsMissingOffset = false;
+  for (auto id : getStructMembers(struct_id, SpvOpTypeStruct, vstate)) {
+    if (isMissingOffsetInStruct(id, vstate)) {
+      nestedStructsMissingOffset = true;
+      break;
+    }
+  }
+  return nestedStructsMissingOffset ||
+         !std::all_of(hasOffset.begin(), hasOffset.end(),
+                      [](const bool b) { return b; });
+}
+
+// Rounds x up to the next alignment. Assumes alignment is a power of two.
+uint32_t align(uint32_t x, uint32_t alignment) {
+  return (x + alignment - 1) & ~(alignment - 1);
+}
+
+// Returns base alignment of struct member. If |roundUp| is true, also
+// ensure that structs and arrays are aligned at least to a multiple of 16
+// bytes.
+uint32_t getBaseAlignment(uint32_t member_id, bool roundUp,
+                          const LayoutConstraints& inherited,
+                          MemberConstraints& constraints,
+                          ValidationState_t& vstate) {
+  const auto inst = vstate.FindDef(member_id);
+  const auto& words = inst->words();
+  // Minimal alignment is byte-aligned.
+  uint32_t baseAlignment = 1;
+  switch (inst->opcode()) {
+    case SpvOpTypeInt:
+    case SpvOpTypeFloat:
+      // Word 2 is the bit width.
+      baseAlignment = words[2] / 8;
+      break;
+    case SpvOpTypeVector: {
+      const auto componentId = words[2];
+      const auto numComponents = words[3];
+      const auto componentAlignment = getBaseAlignment(
+          componentId, roundUp, inherited, constraints, vstate);
+      // A three-component vector is aligned like a four-component one.
+      baseAlignment =
+          componentAlignment * (numComponents == 3 ? 4 : numComponents);
+      break;
+    }
+    case SpvOpTypeMatrix: {
+      const auto column_type = words[2];
+      if (inherited.majorness == kColumnMajor) {
+        baseAlignment = getBaseAlignment(column_type, roundUp, inherited,
+                                         constraints, vstate);
+      } else {
+        // A row-major matrix of C columns has a base alignment equal to the
+        // base alignment of a vector of C matrix components.
+        const auto num_columns = words[3];
+        const auto component_inst = vstate.FindDef(column_type);
+        const auto component_id = component_inst->words()[2];
+        const auto componentAlignment = getBaseAlignment(
+            component_id, roundUp, inherited, constraints, vstate);
+        baseAlignment =
+            componentAlignment * (num_columns == 3 ? 4 : num_columns);
+      }
+    } break;
+    case SpvOpTypeArray:
+    case SpvOpTypeRuntimeArray:
+      baseAlignment =
+          getBaseAlignment(words[2], roundUp, inherited, constraints, vstate);
+      if (roundUp) baseAlignment = align(baseAlignment, 16u);
+      break;
+    case SpvOpTypeStruct: {
+      // A struct is aligned to its most strictly aligned member; each member
+      // is evaluated under its own recorded constraints.
+      const auto members = getStructMembers(member_id, vstate);
+      for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size());
+           memberIdx < numMembers; ++memberIdx) {
+        const auto id = members[memberIdx];
+        const auto& constraint =
+            constraints[std::make_pair(member_id, memberIdx)];
+        baseAlignment = std::max(
+            baseAlignment,
+            getBaseAlignment(id, roundUp, constraint, constraints, vstate));
+      }
+      if (roundUp) baseAlignment = align(baseAlignment, 16u);
+      break;
+    }
+    case SpvOpTypePointer:
+      baseAlignment = vstate.pointer_size_and_alignment();
+      break;
+    default:
+      assert(0);
+      break;
+  }
+
+  return baseAlignment;
+}
+
+// Returns scalar alignment of a type: the bit width in bytes for int and
+// float types, the scalar alignment of the element type for vectors, matrices
+// and arrays, and the largest member scalar alignment for structs.
+uint32_t getScalarAlignment(uint32_t type_id, ValidationState_t& vstate) {
+  const auto inst = vstate.FindDef(type_id);
+  const auto& words = inst->words();
+  switch (inst->opcode()) {
+    case SpvOpTypeInt:
+    case SpvOpTypeFloat:
+      return words[2] / 8;
+    case SpvOpTypeVector:
+    case SpvOpTypeMatrix:
+    case SpvOpTypeArray:
+    case SpvOpTypeRuntimeArray: {
+      const auto compositeMemberTypeId = words[2];
+      return getScalarAlignment(compositeMemberTypeId, vstate);
+    }
+    case SpvOpTypeStruct: {
+      const auto members = getStructMembers(type_id, vstate);
+      uint32_t max_member_alignment = 1;
+      for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size());
+           memberIdx < numMembers; ++memberIdx) {
+        const auto id = members[memberIdx];
+        uint32_t member_alignment = getScalarAlignment(id, vstate);
+        if (member_alignment > max_member_alignment) {
+          max_member_alignment = member_alignment;
+        }
+      }
+      return max_member_alignment;
+    } break;
+    case SpvOpTypePointer:
+      return vstate.pointer_size_and_alignment();
+    default:
+      assert(0);
+      break;
+  }
+
+  return 1;
+}
+
+// Returns size of a struct member. Doesn't include padding at the end of struct
+// or array.  Assumes that in the struct case, all members have offsets.
+//
+// |inherited| carries the matrix majorness and stride that apply to
+// |member_id|; |constraints| maps (struct id, member index) to the constraints
+// of nested struct members.  Returns 0 for runtime arrays, for arrays whose
+// length is a specialization constant, and for empty structs.
+uint32_t getSize(uint32_t member_id, const LayoutConstraints& inherited,
+                 MemberConstraints& constraints, ValidationState_t& vstate) {
+  const auto inst = vstate.FindDef(member_id);
+  const auto& words = inst->words();
+  switch (inst->opcode()) {
+    case SpvOpTypeInt:
+    case SpvOpTypeFloat:
+      return words[2] / 8;
+    case SpvOpTypeVector: {
+      const auto componentId = words[2];
+      const auto numComponents = words[3];
+      const auto componentSize =
+          getSize(componentId, inherited, constraints, vstate);
+      const auto size = componentSize * numComponents;
+      return size;
+    }
+    case SpvOpTypeArray: {
+      const auto sizeInst = vstate.FindDef(words[3]);
+      if (spvOpcodeIsSpecConstant(sizeInst->opcode())) return 0;
+      assert(SpvOpConstant == sizeInst->opcode());
+      const uint32_t num_elem = sizeInst->words()[3];
+      // Guard the (num_elem - 1) computation below against unsigned
+      // wrap-around: an array with no elements occupies no bytes.
+      if (num_elem == 0) return 0;
+      const uint32_t elem_type = words[2];
+      const uint32_t elem_size =
+          getSize(elem_type, inherited, constraints, vstate);
+      // Account for gaps due to alignments in the first N-1 elements,
+      // then add the size of the last element.
+      const auto size =
+          (num_elem - 1) * GetArrayStride(member_id, vstate) + elem_size;
+      return size;
+    }
+    case SpvOpTypeRuntimeArray:
+      return 0;
+    case SpvOpTypeMatrix: {
+      const auto num_columns = words[3];
+      if (inherited.majorness == kColumnMajor) {
+        return num_columns * inherited.matrix_stride;
+      } else {
+        // Row major case.
+        const auto column_type = words[2];
+        const auto component_inst = vstate.FindDef(column_type);
+        const auto num_rows = component_inst->words()[3];
+        const auto scalar_elem_type = component_inst->words()[2];
+        const uint32_t scalar_elem_size =
+            getSize(scalar_elem_type, inherited, constraints, vstate);
+        return (num_rows - 1) * inherited.matrix_stride +
+               num_columns * scalar_elem_size;
+      }
+    }
+    case SpvOpTypeStruct: {
+      const auto& members = getStructMembers(member_id, vstate);
+      if (members.empty()) return 0;
+      const auto lastIdx = uint32_t(members.size() - 1);
+      const auto& lastMember = members.back();
+      uint32_t offset = 0xffffffff;
+      // Find the offset of the last element and add the size.
+      for (auto& decoration : vstate.id_decorations(member_id)) {
+        if (SpvDecorationOffset == decoration.dec_type() &&
+            decoration.struct_member_index() == (int)lastIdx) {
+          offset = decoration.params()[0];
+        }
+      }
+      // This check depends on the fact that all members have offsets.  This
+      // has been checked earlier in the flow.
+      assert(offset != 0xffffffff);
+      // Constraints are keyed by (struct id, member index), so the lookup
+      // must use this struct's id rather than the member's type id.
+      const auto& constraint = constraints[std::make_pair(member_id, lastIdx)];
+      return offset + getSize(lastMember, constraint, constraints, vstate);
+    }
+    case SpvOpTypePointer:
+      return vstate.pointer_size_and_alignment();
+    default:
+      assert(0);
+      return 0;
+  }
+}
+
+// A member is defined to improperly straddle if either of the following are
+// true:
+// - It is a vector with total size less than or equal to 16 bytes, and has
+// Offset decorations placing its first byte at F and its last byte at L, where
+// floor(F / 16) != floor(L / 16).
+// - It is a vector with total size greater than 16 bytes and has its Offset
+// decorations placing its first byte at a non-integer multiple of 16.
+bool hasImproperStraddle(uint32_t id, uint32_t offset,
+                         const LayoutConstraints& inherited,
+                         MemberConstraints& constraints,
+                         ValidationState_t& vstate) {
+  const auto size = getSize(id, inherited, constraints, vstate);
+  // F and L are the offsets of the first and last byte of the member.
+  const auto F = offset;
+  const auto L = offset + size - 1;
+  if (size <= 16) {
+    // First and last byte must fall in the same 16-byte slot.
+    if ((F >> 4) != (L >> 4)) return true;
+  } else {
+    if (F % 16 != 0) return true;
+  }
+  return false;
+}
+
+// Returns true if |offset| satisfies an alignment to |alignment|.  In the case
+// of |alignment| of zero, the |offset| must also be zero.
+bool IsAlignedTo(uint32_t offset, uint32_t alignment) {
+  if (alignment == 0) return offset == 0;
+  return 0 == (offset % alignment);
+}
+
+// Returns SPV_SUCCESS if the given struct satisfies standard layout rules for
+// Block or BufferBlocks in Vulkan.  Otherwise emits a diagnostic and returns
+// something other than SPV_SUCCESS.  Matrices inherit the specified column
+// or row major-ness.
+//
+// |storage_class_str| and |decoration_str| are only used to word the
+// diagnostic.  |blockRules| selects the uniform buffer rules (true) or the
+// storage buffer rules (false).  |constraints| maps (struct id, member index)
+// to the layout constraints of that member.
+spv_result_t checkLayout(uint32_t struct_id, const char* storage_class_str,
+                         const char* decoration_str, bool blockRules,
+                         MemberConstraints& constraints,
+                         ValidationState_t& vstate) {
+  if (vstate.options()->skip_block_layout) return SPV_SUCCESS;
+
+  // Relaxed layout and scalar layout can both be in effect at the same time.
+  // For example, relaxed layout is implied by Vulkan 1.1.  But scalar layout
+  // is more permissive than relaxed layout.
+  const bool relaxed_block_layout = vstate.IsRelaxedBlockLayout();
+  const bool scalar_block_layout = vstate.options()->scalar_block_layout;
+
+  // Starts a diagnostic naming the struct, the rule set in effect and the
+  // offending member; callers append the specific complaint.
+  auto fail = [&vstate, struct_id, storage_class_str, decoration_str,
+               blockRules, relaxed_block_layout,
+               scalar_block_layout](uint32_t member_idx) -> DiagnosticStream {
+    DiagnosticStream ds =
+        std::move(vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(struct_id))
+                  << "Structure id " << struct_id << " decorated as "
+                  << decoration_str << " for variable in " << storage_class_str
+                  << " storage class must follow "
+                  << (scalar_block_layout
+                          ? "scalar "
+                          : (relaxed_block_layout ? "relaxed " : "standard "))
+                  << (blockRules ? "uniform buffer" : "storage buffer")
+                  << " layout rules: member " << member_idx << " ");
+    return ds;
+  };
+
+  const auto& members = getStructMembers(struct_id, vstate);
+
+  // To check for member overlaps, we want to traverse the members in
+  // offset order.
+  struct MemberOffsetPair {
+    uint32_t member;
+    uint32_t offset;
+  };
+  std::vector<MemberOffsetPair> member_offsets;
+  member_offsets.reserve(members.size());
+  for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size());
+       memberIdx < numMembers; memberIdx++) {
+    // 0xffffffff marks a member for which no Offset decoration was found.
+    uint32_t offset = 0xffffffff;
+    for (auto& decoration : vstate.id_decorations(struct_id)) {
+      if (decoration.struct_member_index() == (int)memberIdx) {
+        switch (decoration.dec_type()) {
+          case SpvDecorationOffset:
+            offset = decoration.params()[0];
+            break;
+          default:
+            break;
+        }
+      }
+    }
+    member_offsets.push_back(MemberOffsetPair{memberIdx, offset});
+  }
+  std::stable_sort(
+      member_offsets.begin(), member_offsets.end(),
+      [](const MemberOffsetPair& lhs, const MemberOffsetPair& rhs) {
+        return lhs.offset < rhs.offset;
+      });
+
+  // Now scan from lowest offset to highest offset.
+  uint32_t nextValidOffset = 0;
+  for (size_t ordered_member_idx = 0;
+       ordered_member_idx < member_offsets.size(); ordered_member_idx++) {
+    const auto& member_offset = member_offsets[ordered_member_idx];
+    const auto memberIdx = member_offset.member;
+    const auto offset = member_offset.offset;
+    // Check offset.  This is done before anything is computed for the member
+    // so that a member without an Offset is reported immediately.
+    if (offset == 0xffffffff)
+      return fail(memberIdx) << "is missing an Offset decoration";
+    auto id = members[member_offset.member];
+    const LayoutConstraints& constraint =
+        constraints[std::make_pair(struct_id, uint32_t(memberIdx))];
+    // Scalar layout takes precedence because it's more permissive, and implying
+    // an alignment that divides evenly into the alignment that would otherwise
+    // be used.
+    const auto alignment =
+        scalar_block_layout
+            ? getScalarAlignment(id, vstate)
+            : getBaseAlignment(id, blockRules, constraint, constraints, vstate);
+    const auto inst = vstate.FindDef(id);
+    const auto opcode = inst->opcode();
+    const auto size = getSize(id, constraint, constraints, vstate);
+    if (!scalar_block_layout && relaxed_block_layout &&
+        opcode == SpvOpTypeVector) {
+      // In relaxed block layout, the vector offset must be aligned to the
+      // vector's scalar element type.
+      const auto componentId = inst->words()[2];
+      const auto scalar_alignment = getScalarAlignment(componentId, vstate);
+      if (!IsAlignedTo(offset, scalar_alignment)) {
+        return fail(memberIdx)
+               << "at offset " << offset
+               << " is not aligned to scalar element size " << scalar_alignment;
+      }
+    } else {
+      // Without relaxed block layout, the offset must be divisible by the
+      // alignment requirement.
+      if (!IsAlignedTo(offset, alignment)) {
+        return fail(memberIdx)
+               << "at offset " << offset << " is not aligned to " << alignment;
+      }
+    }
+    if (offset < nextValidOffset)
+      return fail(memberIdx) << "at offset " << offset
+                             << " overlaps previous member ending at offset "
+                             << nextValidOffset - 1;
+    if (!scalar_block_layout && relaxed_block_layout) {
+      // Check improper straddle of vectors.
+      if (SpvOpTypeVector == opcode &&
+          hasImproperStraddle(id, offset, constraint, constraints, vstate))
+        return fail(memberIdx)
+               << "is an improperly straddling vector at offset " << offset;
+    }
+    // Check struct members recursively.
+    spv_result_t recursive_status = SPV_SUCCESS;
+    if (SpvOpTypeStruct == opcode &&
+        SPV_SUCCESS != (recursive_status =
+                            checkLayout(id, storage_class_str, decoration_str,
+                                        blockRules, constraints, vstate)))
+      return recursive_status;
+    // Check matrix stride.
+    if (SpvOpTypeMatrix == opcode) {
+      for (auto& decoration : vstate.id_decorations(id)) {
+        if (SpvDecorationMatrixStride == decoration.dec_type() &&
+            !IsAlignedTo(decoration.params()[0], alignment))
+          return fail(memberIdx)
+                 << "is a matrix with stride " << decoration.params()[0]
+                 << " not satisfying alignment to " << alignment;
+      }
+    }
+    // Check arrays and runtime arrays.
+    if (SpvOpTypeArray == opcode || SpvOpTypeRuntimeArray == opcode) {
+      const auto typeId = inst->word(2);
+      const auto arrayInst = vstate.FindDef(typeId);
+      if (SpvOpTypeStruct == arrayInst->opcode() &&
+          SPV_SUCCESS != (recursive_status = checkLayout(
+                              typeId, storage_class_str, decoration_str,
+                              blockRules, constraints, vstate)))
+        return recursive_status;
+      // Check array stride.
+      for (auto& decoration : vstate.id_decorations(id)) {
+        if (SpvDecorationArrayStride == decoration.dec_type() &&
+            !IsAlignedTo(decoration.params()[0], alignment))
+          return fail(memberIdx)
+                 << "is an array with stride " << decoration.params()[0]
+                 << " not satisfying alignment to " << alignment;
+      }
+    }
+    nextValidOffset = offset + size;
+    if (!scalar_block_layout && blockRules &&
+        (SpvOpTypeArray == opcode || SpvOpTypeStruct == opcode)) {
+      // Uniform block rules don't permit anything in the padding of a struct
+      // or array.
+      nextValidOffset = align(nextValidOffset, alignment);
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+// Returns true if variable or structure id has given decoration. Handles also
+// nested structures.
+bool hasDecoration(uint32_t id, SpvDecoration decoration,
+                   ValidationState_t& vstate) {
+  // A decoration applied directly to |id| settles the question.
+  const auto& own_decorations = vstate.id_decorations(id);
+  const bool decorated_directly =
+      std::any_of(own_decorations.begin(), own_decorations.end(),
+                  [decoration](const Decoration& candidate) {
+                    return decoration == candidate.dec_type();
+                  });
+  if (decorated_directly) return true;
+
+  // Only struct types have nested struct members left to search.
+  if (SpvOpTypeStruct != vstate.FindDef(id)->opcode()) return false;
+
+  const auto nested_structs = getStructMembers(id, SpvOpTypeStruct, vstate);
+  return std::any_of(nested_structs.begin(), nested_structs.end(),
+                     [decoration, &vstate](uint32_t nested_id) {
+                       return hasDecoration(nested_id, decoration, vstate);
+                     });
+}
+
+// Returns true if every member of |struct_id| whose type has opcode |type|
+// carries |decoration|, either on the member's type id or as a member
+// decoration of the struct.  Struct-typed members are checked recursively.
+bool checkForRequiredDecoration(uint32_t struct_id, SpvDecoration decoration,
+                                SpvOp type, ValidationState_t& vstate) {
+  const auto member_type_ids = getStructMembers(struct_id, vstate);
+  for (size_t index = 0; index < member_type_ids.size(); ++index) {
+    const auto member_type_id = member_type_ids[index];
+    // Members of any other type are not subject to the requirement.
+    if (vstate.FindDef(member_type_id)->opcode() != type) continue;
+
+    const auto& type_decorations = vstate.id_decorations(member_type_id);
+    const bool on_type =
+        std::any_of(type_decorations.begin(), type_decorations.end(),
+                    [decoration](const Decoration& candidate) {
+                      return decoration == candidate.dec_type();
+                    });
+
+    const auto& struct_decorations = vstate.id_decorations(struct_id);
+    const bool on_member =
+        std::any_of(struct_decorations.begin(), struct_decorations.end(),
+                    [decoration, index](const Decoration& candidate) {
+                      return decoration == candidate.dec_type() &&
+                             (int)index == candidate.struct_member_index();
+                    });
+
+    if (!on_type && !on_member) return false;
+  }
+
+  const auto nested_structs =
+      getStructMembers(struct_id, SpvOpTypeStruct, vstate);
+  return std::all_of(
+      nested_structs.begin(), nested_structs.end(),
+      [decoration, type, &vstate](uint32_t nested_id) {
+        return checkForRequiredDecoration(nested_id, decoration, type, vstate);
+      });
+}
+
+// Checks that functions without a body are imported and that functions with a
+// body are not.  Reports SPV_ERROR_INVALID_BINARY for the first offender.
+spv_result_t CheckLinkageAttrOfFunctions(ValidationState_t& vstate) {
+  for (const auto& function : vstate.functions()) {
+    const bool is_declaration = function.block_count() == 0u;
+    const bool is_imported = hasImportLinkageAttribute(function.id(), vstate);
+
+    // A function declaration (an OpFunction with no basic blocks), must have
+    // a Linkage Attributes Decoration with the Import Linkage Type.
+    if (is_declaration && !is_imported) {
+      return vstate.diag(SPV_ERROR_INVALID_BINARY,
+                         vstate.FindDef(function.id()))
+             << "Function declaration (id " << function.id()
+             << ") must have a LinkageAttributes decoration with the Import "
+                "Linkage type.";
+    }
+
+    if (!is_declaration && is_imported) {
+      return vstate.diag(SPV_ERROR_INVALID_BINARY,
+                         vstate.FindDef(function.id()))
+             << "Function definition (id " << function.id()
+             << ") may not be decorated with Import Linkage type.";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+// Checks whether an imported variable is initialized by this module.
+spv_result_t CheckImportedVariableInitialization(ValidationState_t& vstate) {
+  // According to the SPIR-V Spec 2.16.1, it is illegal to initialize an
+  // imported variable. This means that a module-scope OpVariable with
+  // initialization value cannot be marked with the Import Linkage Type (import
+  // type id = 1).
+  for (auto global_var_id : vstate.global_vars()) {
+    auto variable_instr = vstate.FindDef(global_var_id);
+    // Initializer is an optional argument for OpVariable. If initializer
+    // is present, the instruction will have 5 words.
+    const bool has_initializer = variable_instr->words().size() == 5u;
+    if (!has_initializer) continue;
+    if (!hasImportLinkageAttribute(global_var_id, vstate)) continue;
+
+    return vstate.diag(SPV_ERROR_INVALID_ID, variable_instr)
+           << "A module-scope OpVariable with initialization value "
+              "cannot be marked with the Import Linkage Type.";
+  }
+  return SPV_SUCCESS;
+}
+
+// Checks whether a builtin variable is valid.
+spv_result_t CheckBuiltInVariable(uint32_t var_id, ValidationState_t& vstate) {
+  // The only rule enforced here applies to Vulkan environments, so there is
+  // nothing to scan for otherwise.
+  if (!spvIsVulkanEnv(vstate.context()->target_env)) return SPV_SUCCESS;
+
+  const auto& decorations = vstate.id_decorations(var_id);
+  for (const auto& d : decorations) {
+    if (d.dec_type() == SpvDecorationLocation ||
+        d.dec_type() == SpvDecorationComponent) {
+      return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
+             << "A BuiltIn variable (id " << var_id
+             << ") cannot have any Location or Component decorations";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+// Checks whether proper decorations have been applied to the entry points.
+//
+// For every entry point description this verifies that each interface <id> is
+// an OpVariable in the Input or Output storage class, that at most one
+// BuiltIn-decorated struct is used per storage class, that BuiltIn variables
+// pass CheckBuiltInVariable, and that the entry point itself carries no
+// LinkageAttributes decoration.  Returns the first failure encountered.
+spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) {
+  for (uint32_t entry_point : vstate.entry_points()) {
+    const auto& descs = vstate.entry_point_descriptions(entry_point);
+    int num_builtin_inputs = 0;
+    int num_builtin_outputs = 0;
+    for (const auto& desc : descs) {
+      for (auto interface : desc.interfaces) {
+        Instruction* var_instr = vstate.FindDef(interface);
+        if (!var_instr) {
+          // There is no instruction to name an opcode from, so report the
+          // undefined <id> instead of dereferencing a null pointer.
+          return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(entry_point))
+                 << "Interfaces passed to OpEntryPoint must be of type "
+                    "OpTypeVariable. Found undefined id "
+                 << interface << ".";
+        }
+        if (SpvOpVariable != var_instr->opcode()) {
+          return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
+                 << "Interfaces passed to OpEntryPoint must be of type "
+                    "OpTypeVariable. Found Op"
+                 << spvOpcodeString(var_instr->opcode()) << ".";
+        }
+        const SpvStorageClass storage_class =
+            var_instr->GetOperandAs<SpvStorageClass>(2);
+        if (storage_class != SpvStorageClassInput &&
+            storage_class != SpvStorageClassOutput) {
+          return vstate.diag(SPV_ERROR_INVALID_ID, var_instr)
+                 << "OpEntryPoint interfaces must be OpVariables with "
+                    "Storage Class of Input(1) or Output(3). Found Storage "
+                    "Class "
+                 << storage_class << " for Entry Point id " << entry_point
+                 << ".";
+        }
+
+        const uint32_t ptr_id = var_instr->word(1);
+        Instruction* ptr_instr = vstate.FindDef(ptr_id);
+        // It is guaranteed (by validator ID checks) that ptr_instr is
+        // OpTypePointer. Word 3 of this instruction is the type being pointed
+        // to.
+        const uint32_t type_id = ptr_instr->word(3);
+        Instruction* type_instr = vstate.FindDef(type_id);
+        if (type_instr && SpvOpTypeStruct == type_instr->opcode() &&
+            isBuiltInStruct(type_id, vstate)) {
+          if (storage_class == SpvStorageClassInput) ++num_builtin_inputs;
+          if (storage_class == SpvStorageClassOutput) ++num_builtin_outputs;
+          // Stop scanning; the limit violation is reported after the loop.
+          if (num_builtin_inputs > 1 || num_builtin_outputs > 1) break;
+          if (auto error = CheckBuiltInVariable(interface, vstate))
+            return error;
+        } else if (isBuiltInVar(interface, vstate)) {
+          if (auto error = CheckBuiltInVariable(interface, vstate))
+            return error;
+        }
+      }
+      if (num_builtin_inputs > 1 || num_builtin_outputs > 1) {
+        return vstate.diag(SPV_ERROR_INVALID_BINARY,
+                           vstate.FindDef(entry_point))
+               << "There must be at most one object per Storage Class that can "
+                  "contain a structure type containing members decorated with "
+                  "BuiltIn, consumed per entry-point. Entry Point id "
+               << entry_point << " does not meet this requirement.";
+      }
+      // The LinkageAttributes Decoration cannot be applied to functions
+      // targeted by an OpEntryPoint instruction
+      for (auto& decoration : vstate.id_decorations(entry_point)) {
+        if (SpvDecorationLinkageAttributes == decoration.dec_type()) {
+          // The decoration's parameter words are read as a C string holding
+          // the linkage name.
+          const char* linkage_name =
+              reinterpret_cast<const char*>(&decoration.params()[0]);
+          return vstate.diag(SPV_ERROR_INVALID_BINARY,
+                             vstate.FindDef(entry_point))
+                 << "The LinkageAttributes Decoration (Linkage name: "
+                 << linkage_name << ") cannot be applied to function id "
+                 << entry_point
+                 << " because it is targeted by an OpEntryPoint instruction.";
+        }
+      }
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+// Load |constraints| with all the member constraints for structs contained
+// within the given array type.
+void ComputeMemberConstraintsForArray(MemberConstraints* constraints,
+                                      uint32_t array_id,
+                                      const LayoutConstraints& inherited,
+                                      ValidationState_t& vstate);
+
+// Load |constraints| with all the member constraints for the given struct,
+// and all its contained structs.
+//
+// Each member's entry is keyed by (struct id, member index).  It starts as a
+// copy of |inherited| and is then overridden by the member's own RowMajor,
+// ColMajor and MatrixStride decorations.  |constraints| must not be null.
+void ComputeMemberConstraintsForStruct(MemberConstraints* constraints,
+                                       uint32_t struct_id,
+                                       const LayoutConstraints& inherited,
+                                       ValidationState_t& vstate) {
+  assert(constraints);
+  const auto& members = getStructMembers(struct_id, vstate);
+  for (uint32_t memberIdx = 0, numMembers = uint32_t(members.size());
+       memberIdx < numMembers; memberIdx++) {
+    // operator[] creates the entry if it does not exist yet.
+    LayoutConstraints& constraint =
+        (*constraints)[std::make_pair(struct_id, memberIdx)];
+    constraint = inherited;
+    for (auto& decoration : vstate.id_decorations(struct_id)) {
+      if (decoration.struct_member_index() == (int)memberIdx) {
+        switch (decoration.dec_type()) {
+          case SpvDecorationRowMajor:
+            constraint.majorness = kRowMajor;
+            break;
+          case SpvDecorationColMajor:
+            constraint.majorness = kColumnMajor;
+            break;
+          case SpvDecorationMatrixStride:
+            constraint.matrix_stride = decoration.params()[0];
+            break;
+          default:
+            break;
+        }
+      }
+    }
+
+    // Now recurse
+    // NOTE(review): the recursion forwards |inherited|, not the member's own
+    // |constraint| computed above -- confirm this is intended.
+    auto member_type_id = members[memberIdx];
+    const auto member_type_inst = vstate.FindDef(member_type_id);
+    const auto opcode = member_type_inst->opcode();
+    switch (opcode) {
+      case SpvOpTypeArray:
+      case SpvOpTypeRuntimeArray:
+        ComputeMemberConstraintsForArray(constraints, member_type_id, inherited,
+                                         vstate);
+        break;
+      case SpvOpTypeStruct:
+        ComputeMemberConstraintsForStruct(constraints, member_type_id,
+                                          inherited, vstate);
+        break;
+      default:
+        break;
+    }
+  }
+}
+
+// Descends through the element type of the array |array_id| (word 2 of its
+// declaration): nested arrays are followed recursively and a struct element
+// type is handed to ComputeMemberConstraintsForStruct.  Any other element type
+// adds nothing to |constraints|, which must not be null.
+void ComputeMemberConstraintsForArray(MemberConstraints* constraints,
+                                      uint32_t array_id,
+                                      const LayoutConstraints& inherited,
+                                      ValidationState_t& vstate) {
+  assert(constraints);
+  auto elem_type_id = vstate.FindDef(array_id)->words()[2];
+  const auto elem_type_inst = vstate.FindDef(elem_type_id);
+  const auto opcode = elem_type_inst->opcode();
+  switch (opcode) {
+    case SpvOpTypeArray:
+    case SpvOpTypeRuntimeArray:
+      ComputeMemberConstraintsForArray(constraints, elem_type_id, inherited,
+                                       vstate);
+      break;
+    case SpvOpTypeStruct:
+      ComputeMemberConstraintsForStruct(constraints, elem_type_id, inherited,
+                                        vstate);
+      break;
+    default:
+      break;
+  }
+}
+
+spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
+  // Set of entry points that are known to use a push constant.
+  // NOTE(review): the element type of this set appears to have been lost when
+  // this patch text was extracted (entry point ids are inserted below) --
+  // confirm against the upstream file.
+  std::unordered_set uses_push_constant;
+  for (const auto& inst : vstate.ordered_instructions()) {
+    const auto& words = inst.words();
+    if (SpvOpVariable == inst.opcode()) {
+      const auto var_id = inst.id();
+      // For storage class / decoration combinations, see Vulkan 14.5.4 "Offset
+      // and Stride Assignment".
+      const auto storageClass = words[3];
+      const bool uniform = storageClass == SpvStorageClassUniform;
+      const bool uniform_constant =
+          storageClass == SpvStorageClassUniformConstant;
+      const bool push_constant = storageClass == SpvStorageClassPushConstant;
+      const bool storage_buffer = storageClass == SpvStorageClassStorageBuffer;
+
+      if (spvIsVulkanEnv(vstate.context()->target_env)) {
+        // Vulkan 14.5.1: There must be no more than one PushConstant block
+        // per entry point.
+        if (push_constant) {
+          auto entry_points = vstate.EntryPointReferences(var_id);
+          for (auto ep_id : entry_points) {
+            // insert() reports whether the id was newly added; a repeat means
+            // this entry point already references another PushConstant
+            // variable.
+            const bool already_used = !uses_push_constant.insert(ep_id).second;
+            if (already_used) {
+              return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id))
+                     << "Entry point id '" << ep_id
+                     << "' uses more than one PushConstant interface.\n"
+                     << "From Vulkan spec, section 14.5.1:\n"
+                     << "There must be no more than one push constant block "
+                     << "statically used per shader entry point.";
+            }
+          }
+        }
+        // Vulkan 14.5.2: Check DescriptorSet and Binding decoration for
+        // UniformConstant which cannot be a struct.
+ if (uniform_constant) { + auto entry_points = vstate.EntryPointReferences(var_id); + if (!entry_points.empty() && + !hasDecoration(var_id, SpvDecorationDescriptorSet, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id)) + << "UniformConstant id '" << var_id + << "' is missing DescriptorSet decoration.\n" + << "From Vulkan spec, section 14.5.2:\n" + << "These variables must have DescriptorSet and Binding " + "decorations specified"; + } + if (!entry_points.empty() && + !hasDecoration(var_id, SpvDecorationBinding, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id)) + << "UniformConstant id '" << var_id + << "' is missing Binding decoration.\n" + << "From Vulkan spec, section 14.5.2:\n" + << "These variables must have DescriptorSet and Binding " + "decorations specified"; + } + } + } + + const bool phys_storage_buffer = + storageClass == SpvStorageClassPhysicalStorageBufferEXT; + if (uniform || push_constant || storage_buffer || phys_storage_buffer) { + const auto ptrInst = vstate.FindDef(words[1]); + assert(SpvOpTypePointer == ptrInst->opcode()); + const auto id = ptrInst->words()[3]; + if (SpvOpTypeStruct != vstate.FindDef(id)->opcode()) continue; + MemberConstraints constraints; + ComputeMemberConstraintsForStruct(&constraints, id, LayoutConstraints(), + vstate); + // Prepare for messages + const char* sc_str = + uniform ? "Uniform" + : (push_constant ? "PushConstant" : "StorageBuffer"); + + if (spvIsVulkanEnv(vstate.context()->target_env)) { + // Vulkan 14.5.1: Check Block decoration for PushConstant variables. 
+ if (push_constant && !hasDecoration(id, SpvDecorationBlock, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "PushConstant id '" << id + << "' is missing Block decoration.\n" + << "From Vulkan spec, section 14.5.1:\n" + << "Such variables must be identified with a Block " + "decoration"; + } + // Vulkan 14.5.2: Check DescriptorSet and Binding decoration for + // Uniform and StorageBuffer variables. + if (uniform || storage_buffer) { + auto entry_points = vstate.EntryPointReferences(var_id); + if (!entry_points.empty() && + !hasDecoration(var_id, SpvDecorationDescriptorSet, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id)) + << sc_str << " id '" << var_id + << "' is missing DescriptorSet decoration.\n" + << "From Vulkan spec, section 14.5.2:\n" + << "These variables must have DescriptorSet and Binding " + "decorations specified"; + } + if (!entry_points.empty() && + !hasDecoration(var_id, SpvDecorationBinding, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id)) + << sc_str << " id '" << var_id + << "' is missing Binding decoration.\n" + << "From Vulkan spec, section 14.5.2:\n" + << "These variables must have DescriptorSet and Binding " + "decorations specified"; + } + } + } + + for (const auto& dec : vstate.id_decorations(id)) { + const bool blockDeco = SpvDecorationBlock == dec.dec_type(); + const bool bufferDeco = SpvDecorationBufferBlock == dec.dec_type(); + const bool blockRules = uniform && blockDeco; + const bool bufferRules = + (uniform && bufferDeco) || (push_constant && blockDeco) || + ((storage_buffer || phys_storage_buffer) && blockDeco); + if (blockRules || bufferRules) { + const char* deco_str = blockDeco ? 
"Block" : "BufferBlock"; + spv_result_t recursive_status = SPV_SUCCESS; + if (isMissingOffsetInStruct(id, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must be explicitly laid out with Offset " + "decorations."; + } else if (hasDecoration(id, SpvDecorationGLSLShared, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must not use GLSLShared decoration."; + } else if (hasDecoration(id, SpvDecorationGLSLPacked, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must not use GLSLPacked decoration."; + } else if (!checkForRequiredDecoration(id, SpvDecorationArrayStride, + SpvOpTypeArray, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must be explicitly laid out with ArrayStride " + "decorations."; + } else if (!checkForRequiredDecoration(id, + SpvDecorationMatrixStride, + SpvOpTypeMatrix, vstate)) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "Structure id " << id << " decorated as " << deco_str + << " must be explicitly laid out with MatrixStride " + "decorations."; + } else if (blockRules && + (SPV_SUCCESS != (recursive_status = checkLayout( + id, sc_str, deco_str, true, + constraints, vstate)))) { + return recursive_status; + } else if (bufferRules && + (SPV_SUCCESS != (recursive_status = checkLayout( + id, sc_str, deco_str, false, + constraints, vstate)))) { + return recursive_status; + } + } + } + } + } + } + return SPV_SUCCESS; +} + +spv_result_t CheckDecorationsCompatibility(ValidationState_t& vstate) { + using AtMostOnceSet = std::unordered_set; + using MutuallyExclusiveSets = + std::vector>; + using PerIDKey = std::tuple; + using PerMemberKey = std::tuple; + using 
DecorationNameTable = + std::unordered_map; + + static const auto* const at_most_once_per_id = new AtMostOnceSet{ + SpvDecorationArrayStride, + }; + static const auto* const at_most_once_per_member = new AtMostOnceSet{ + SpvDecorationOffset, + SpvDecorationMatrixStride, + SpvDecorationRowMajor, + SpvDecorationColMajor, + }; + static const auto* const mutually_exclusive_per_id = + new MutuallyExclusiveSets{ + {SpvDecorationBlock, SpvDecorationBufferBlock}, + }; + static const auto* const mutually_exclusive_per_member = + new MutuallyExclusiveSets{ + {SpvDecorationRowMajor, SpvDecorationColMajor}, + }; + // For printing the decoration name. + static const auto* const decoration_name = new DecorationNameTable{ + {SpvDecorationArrayStride, "ArrayStride"}, + {SpvDecorationOffset, "Offset"}, + {SpvDecorationMatrixStride, "MatrixStride"}, + {SpvDecorationRowMajor, "RowMajor"}, + {SpvDecorationColMajor, "ColMajor"}, + {SpvDecorationBlock, "Block"}, + {SpvDecorationBufferBlock, "BufferBlock"}, + }; + + std::set seen_per_id; + std::set seen_per_member; + + for (const auto& inst : vstate.ordered_instructions()) { + const auto& words = inst.words(); + if (SpvOpDecorate == inst.opcode()) { + const auto id = words[1]; + const auto dec_type = static_cast(words[2]); + const auto k = PerIDKey(dec_type, id); + const auto already_used = !seen_per_id.insert(k).second; + if (already_used && + at_most_once_per_id->find(dec_type) != at_most_once_per_id->end()) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "ID '" << id << "' decorated with " + << decoration_name->at(dec_type) + << " multiple times is not allowed."; + } + // Verify certain mutually exclusive decorations are not both applied on + // an ID. 
+ for (const auto& s : *mutually_exclusive_per_id) { + if (s.find(dec_type) == s.end()) continue; + for (auto excl_dec_type : s) { + if (excl_dec_type == dec_type) continue; + const auto excl_k = PerIDKey(excl_dec_type, id); + if (seen_per_id.find(excl_k) != seen_per_id.end()) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "ID '" << id << "' decorated with both " + << decoration_name->at(dec_type) << " and " + << decoration_name->at(excl_dec_type) << " is not allowed."; + } + } + } + } else if (SpvOpMemberDecorate == inst.opcode()) { + const auto id = words[1]; + const auto member_id = words[2]; + const auto dec_type = static_cast(words[3]); + const auto k = PerMemberKey(dec_type, id, member_id); + const auto already_used = !seen_per_member.insert(k).second; + if (already_used && at_most_once_per_member->find(dec_type) != + at_most_once_per_member->end()) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "ID '" << id << "', member '" << member_id + << "' decorated with " << decoration_name->at(dec_type) + << " multiple times is not allowed."; + } + // Verify certain mutually exclusive decorations are not both applied on + // a (ID, member) tuple. 
+ for (const auto& s : *mutually_exclusive_per_member) { + if (s.find(dec_type) == s.end()) continue; + for (auto excl_dec_type : s) { + if (excl_dec_type == dec_type) continue; + const auto excl_k = PerMemberKey(excl_dec_type, id, member_id); + if (seen_per_member.find(excl_k) != seen_per_member.end()) { + return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id)) + << "ID '" << id << "', member '" << member_id + << "' decorated with both " << decoration_name->at(dec_type) + << " and " << decoration_name->at(excl_dec_type) + << " is not allowed."; + } + } + } + } + } + return SPV_SUCCESS; +} + +spv_result_t CheckVulkanMemoryModelDeprecatedDecorations( + ValidationState_t& vstate) { + if (vstate.memory_model() != SpvMemoryModelVulkanKHR) return SPV_SUCCESS; + + std::string msg; + std::ostringstream str(msg); + for (const auto& def : vstate.all_definitions()) { + const auto inst = def.second; + const auto id = inst->id(); + for (const auto& dec : vstate.id_decorations(id)) { + const auto member = dec.struct_member_index(); + if (dec.dec_type() == SpvDecorationCoherent || + dec.dec_type() == SpvDecorationVolatile) { + str << (dec.dec_type() == SpvDecorationCoherent ? "Coherent" + : "Volatile"); + str << " decoration targeting " << vstate.getIdName(id); + if (member != Decoration::kInvalidMember) { + str << " (member index " << member << ")"; + } + str << " is banned when using the Vulkan memory model."; + return vstate.diag(SPV_ERROR_INVALID_ID, inst) << str.str(); + } + } + } + return SPV_SUCCESS; +} + +// Returns SPV_SUCCESS if validation rules are satisfied for FPRoundingMode +// decorations. Otherwise emits a diagnostic and returns something other than +// SPV_SUCCESS. 
+spv_result_t CheckFPRoundingModeForShaders(ValidationState_t& vstate,
+                                           const Instruction& inst) {
+  // Validates width-only conversion instruction for floating-point object
+  // i.e., OpFConvert
+  if (inst.opcode() != SpvOpFConvert) {
+    return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+           << "FPRoundingMode decoration can be applied only to a "
+              "width-only conversion instruction for floating-point "
+              "object.";
+  }
+
+  // Validates Object operand of an OpStore; every use of |inst| must qualify.
+  for (const auto& use : inst.uses()) {
+    const auto store = use.first;  // the instruction that uses |inst|
+    if (store->opcode() != SpvOpStore) {
+      return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+             << "FPRoundingMode decoration can be applied only to the "
+                "Object operand of an OpStore.";
+    }
+
+    if (use.second != 2) {  // presumably the word index of the use; 2 would be Object (confirm)
+      return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+             << "FPRoundingMode decoration can be applied only to the "
+                "Object operand of an OpStore.";
+    }
+
+    const auto ptr_inst = vstate.FindDef(store->GetOperandAs(0));  // NOTE(review): GetOperandAs calls lack a template argument, likely lost in extraction; confirm
+    const auto ptr_type = vstate.FindDef(ptr_inst->GetOperandAs(0));
+
+    const auto half_float_id = ptr_type->GetOperandAs(2);  // pointee type of the store's pointer
+    if (!vstate.IsFloatScalarOrVectorType(half_float_id) ||
+        vstate.GetBitWidth(half_float_id) != 16) {
+      return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+             << "FPRoundingMode decoration can be applied only to the "
+                "Object operand of an OpStore storing through a pointer "
+                "to "
+                "a 16-bit floating-point scalar or vector object.";
+    }
+
+    // Validates storage class of the pointer to the OpStore
+    const auto storage = ptr_type->GetOperandAs(1);
+    if (storage != SpvStorageClassStorageBuffer &&
+        storage != SpvStorageClassUniform &&
+        storage != SpvStorageClassPushConstant &&
+        storage != SpvStorageClassInput && storage != SpvStorageClassOutput &&
+        storage != SpvStorageClassPhysicalStorageBufferEXT) {
+      return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+             << "FPRoundingMode decoration can be applied only to the "
+                "Object operand of an OpStore in the StorageBuffer, "
+                "PhysicalStorageBufferEXT, Uniform, PushConstant, Input, or "
+                "Output Storage Classes.";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+// Returns SPV_SUCCESS if validation rules are satisfied for Uniform
+// decorations. Otherwise emits a diagnostic and returns something other than
+// SPV_SUCCESS. Assumes each decoration on a group has been propagated down to
+// the group members.
+spv_result_t CheckUniformDecoration(ValidationState_t& vstate,
+                                    const Instruction& inst,
+                                    const Decoration&) {  // the decoration argument itself is unused
+  // Uniform must decorate an "object"
+  // - has a result ID
+  // - is an instantiation of a non-void type. So it has a type ID, and that
+  //   type is not void.
+
+  // We already know the result ID is non-zero.
+
+  if (inst.type_id() == 0) {
+    return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+           << "Uniform decoration applied to a non-object";
+  }
+  if (Instruction* type_inst = vstate.FindDef(inst.type_id())) {
+    if (type_inst->opcode() == SpvOpTypeVoid) {
+      return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+             << "Uniform decoration applied to a value with void type";
+    }
+  } else {
+    // We might never get here because this would have been rejected earlier in
+    // the flow.
+    return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+           << "Uniform decoration applied to an object with invalid type";
+  }
+  return SPV_SUCCESS;
+}
+
+// Returns SPV_SUCCESS if validation rules are satisfied for NoSignedWrap or
+// NoUnsignedWrap decorations. Otherwise emits a diagnostic and returns
+// something other than SPV_SUCCESS. Assumes each decoration on a group has been
+// propagated down to the group members.
+spv_result_t CheckIntegerWrapDecoration(ValidationState_t& vstate,
+                                        const Instruction& inst,
+                                        const Decoration& decoration) {
+  switch (inst.opcode()) {
+    case SpvOpIAdd:
+    case SpvOpISub:
+    case SpvOpIMul:
+    case SpvOpShiftLeftLogical:
+    case SpvOpSNegate:
+      return SPV_SUCCESS;
+    case SpvOpExtInst:
+      // TODO(dneto): Only certain extended instructions allow these
+      // decorations.  For now allow anything.
+      return SPV_SUCCESS;
+    default:
+      break;  // any other opcode falls through to the diagnostic below
+  }
+
+  return vstate.diag(SPV_ERROR_INVALID_ID, &inst)
+         << (decoration.dec_type() == SpvDecorationNoSignedWrap
+                 ? "NoSignedWrap"
+                 : "NoUnsignedWrap")
+         << " decoration may not be applied to "
+         << spvOpcodeString(inst.opcode());
+}
+// PASS_OR_BAIL(X): return X's result from the caller unless it is SPV_SUCCESS.
+#define PASS_OR_BAIL_AT_LINE(X, LINE)           \
+  {                                             \
+    spv_result_t e##LINE = (X);                 \
+    if (e##LINE != SPV_SUCCESS) return e##LINE; \
+  }
+#define PASS_OR_BAIL(X) PASS_OR_BAIL_AT_LINE(X, __LINE__)  // ## pastes the unexpanded token, so the local is named e__LINE__
+
+// Check rules for decorations where we start from the decoration rather
+// than the decorated object. Assumes each decoration on a group has been
+// propagated down to the group members.
+spv_result_t CheckDecorationsFromDecoration(ValidationState_t& vstate) {
+  // Some rules are only checked for shaders.
+  const bool is_shader = vstate.HasCapability(SpvCapabilityShader);
+
+  for (const auto& kv : vstate.id_decorations()) {
+    const uint32_t id = kv.first;
+    const auto& decorations = kv.second;
+    if (decorations.empty()) continue;
+
+    const Instruction* inst = vstate.FindDef(id);
+    assert(inst);
+
+    // We assume the decorations applied to a decoration group have already
+    // been propagated down to the group members.
+    if (inst->opcode() == SpvOpDecorationGroup) continue;
+
+    // Dispatch each decoration to its specific check; others are ignored.
+    for (const auto& decoration : decorations) {
+      switch (decoration.dec_type()) {
+        case SpvDecorationFPRoundingMode:
+          if (is_shader)
+            PASS_OR_BAIL(CheckFPRoundingModeForShaders(vstate, *inst));
+          break;
+        case SpvDecorationUniform:
+          PASS_OR_BAIL(CheckUniformDecoration(vstate, *inst, decoration));
+          break;
+        case SpvDecorationNoSignedWrap:
+        case SpvDecorationNoUnsignedWrap:
+          PASS_OR_BAIL(CheckIntegerWrapDecoration(vstate, *inst, decoration));
+          break;
+        default:
+          break;
+      }
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+}  // namespace
+// Runs every decoration check in order and stops at the first failure.
+spv_result_t ValidateDecorations(ValidationState_t& vstate) {
+  if (auto error = CheckImportedVariableInitialization(vstate)) return error;
+  if (auto error = CheckDecorationsOfEntryPoints(vstate)) return error;
+  if (auto error = CheckDecorationsOfBuffers(vstate)) return error;
+  if (auto error = CheckDecorationsCompatibility(vstate)) return error;
+  if (auto error = CheckLinkageAttrOfFunctions(vstate)) return error;
+  if (auto error = CheckVulkanMemoryModelDeprecatedDecorations(vstate))
+    return error;
+  if (auto error = CheckDecorationsFromDecoration(vstate)) return error;
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_derivatives.cpp b/source/val/validate_derivatives.cpp
new file mode 100644
index 000000000..718970769
--- /dev/null
+++ b/source/val/validate_derivatives.cpp
@@ -0,0 +1,98 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates correctness of derivative SPIR-V instructions.
+
+#include "source/val/validate.h"
+
+#include  // NOTE(review): header name missing here, likely lost in extraction; confirm against upstream
+
+#include "source/diagnostic.h"
+#include "source/opcode.h"
+#include "source/val/instruction.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+
+// Validates correctness of derivative instructions; other opcodes pass as-is.
+spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
+  const SpvOp opcode = inst->opcode();
+  const uint32_t result_type = inst->type_id();
+
+  switch (opcode) {
+    case SpvOpDPdx:
+    case SpvOpDPdy:
+    case SpvOpFwidth:
+    case SpvOpDPdxFine:
+    case SpvOpDPdyFine:
+    case SpvOpFwidthFine:
+    case SpvOpDPdxCoarse:
+    case SpvOpDPdyCoarse:
+    case SpvOpFwidthCoarse: {
+      if (!_.IsFloatScalarOrVectorType(result_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected Result Type to be float scalar or vector type: "
+               << spvOpcodeString(opcode);
+      }
+
+      const uint32_t p_type = _.GetOperandTypeId(inst, 2);  // type of operand 2, the P operand
+      if (p_type != result_type) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected P type and Result Type to be the same: "
+               << spvOpcodeString(opcode);
+      }
+
+      const spvtools::Extension compute_shader_derivatives_extension =
+          kSPV_NV_compute_shader_derivatives;
+      ExtensionSet exts(1, &compute_shader_derivatives_extension);
+      // With SPV_NV_compute_shader_derivatives, GLCompute is accepted as well.
+      if (_.HasAnyOfExtensions(exts)) {
+        _.function(inst->function()->id())
+            ->RegisterExecutionModelLimitation([opcode](SpvExecutionModel model,
+                                                        std::string* message) {
+              if (model != SpvExecutionModelFragment &&
+                  model != SpvExecutionModelGLCompute) {
+                if (message) {  // NOTE: the text below still names only Fragment
+                  *message =
+                      std::string(
+                          "Derivative instructions require Fragment execution "
+                          "model: ") +
+                      spvOpcodeString(opcode);
+                }
+                return false;
+              }
+              return true;
+            });
+      } else {
+        _.function(inst->function()->id())
+            ->RegisterExecutionModelLimitation(
+                SpvExecutionModelFragment,  // without the extension only Fragment is allowed
+                std::string(
+                    "Derivative instructions require Fragment execution "
+                    "model: ") +
+                    spvOpcodeString(opcode));
+      }
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_execution_limitations.cpp b/source/val/validate_execution_limitations.cpp
new file mode 100644
index 000000000..d44930770
--- /dev/null
+++ b/source/val/validate_execution_limitations.cpp
@@ -0,0 +1,61 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/val/validate.h"
+
+#include "source/val/function.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+// Rejects functions incompatible with an entry point's execution models.
+spv_result_t ValidateExecutionLimitations(ValidationState_t& _,
+                                          const Instruction* inst) {
+  if (inst->opcode() != SpvOpFunction) {  // only OpFunction instructions are examined
+    return SPV_SUCCESS;
+  }
+
+  const auto func = _.function(inst->id());
+  if (!func) {
+    return _.diag(SPV_ERROR_INTERNAL, inst)
+           << "Internal error: missing function id " << inst->id() << ".";
+  }
+  // Test the function against every model of each associated entry point.
+  for (uint32_t entry_id : _.FunctionEntryPoints(inst->id())) {
+    const auto* models = _.GetExecutionModels(entry_id);
+    if (models) {  // a null model set is skipped without a diagnostic
+      if (models->empty()) {
+        return _.diag(SPV_ERROR_INTERNAL, inst)
+               << "Internal error: empty execution models for function id "
+               << entry_id << ".";
+      }
+      for (const auto model : *models) {
+        std::string reason;  // filled in by IsCompatibleWithExecutionModel on failure
+        if (!func->IsCompatibleWithExecutionModel(model, &reason)) {
+          return _.diag(SPV_ERROR_INVALID_ID, inst)
+                 << "OpEntryPoint Entry Point '" << _.getIdName(entry_id)
+                 << "'s callgraph contains function "
+                 << _.getIdName(inst->id())
+                 << ", which cannot be used with the current execution "
+                    "model:\n"
+                 << reason;
+        }
+      }
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp
new file mode 100644
index 000000000..f264c8e76
--- /dev/null
+++ b/source/val/validate_extensions.cpp
@@ -0,0 +1,2029 @@
+// Copyright (c) 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of extension SPIR-V instructions. + +#include "source/val/validate.h" + +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/latest_version_opencl_std_header.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +uint32_t GetSizeTBitWidth(const ValidationState_t& _) { + if (_.addressing_model() == SpvAddressingModelPhysical32) return 32; + + if (_.addressing_model() == SpvAddressingModelPhysical64) return 64; + + return 0; +} + +} // anonymous namespace + +spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) { + if (spvIsWebGPUEnv(_.context()->target_env)) { + std::string extension = GetExtensionString(&(inst->c_inst())); + + if (extension != ExtensionToString(kSPV_KHR_vulkan_memory_model)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "For WebGPU, the only valid parameter to OpExtension is " + << "\"" << ExtensionToString(kSPV_KHR_vulkan_memory_model) + << "\"."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateExtInstImport(ValidationState_t& _, + const Instruction* inst) { + if (spvIsWebGPUEnv(_.context()->target_env)) { + const auto name_id = 1; + const std::string name(reinterpret_cast( + inst->words().data() + inst->operands()[name_id].offset)); + if (name != "GLSL.std.450") { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "For WebGPU, the only valid parameter to OpExtInstImport is " + "\"GLSL.std.450\"."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + const 
uint32_t num_operands = static_cast(inst->operands().size()); + + const uint32_t ext_inst_set = inst->word(3); + const uint32_t ext_inst_index = inst->word(4); + const spv_ext_inst_type_t ext_inst_type = + spv_ext_inst_type_t(inst->ext_inst_type()); + + auto ext_inst_name = [&_, ext_inst_set, ext_inst_type, ext_inst_index]() { + spv_ext_inst_desc desc = nullptr; + if (_.grammar().lookupExtInst(ext_inst_type, ext_inst_index, &desc) != + SPV_SUCCESS || + !desc) { + return std::string("Unknown ExtInst"); + } + + auto* import_inst = _.FindDef(ext_inst_set); + assert(import_inst); + + std::ostringstream ss; + ss << reinterpret_cast(import_inst->words().data() + 2); + ss << " "; + ss << desc->name; + + return ss.str(); + }; + + if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) { + const GLSLstd450 ext_inst_key = GLSLstd450(ext_inst_index); + switch (ext_inst_key) { + case GLSLstd450Round: + case GLSLstd450RoundEven: + case GLSLstd450FAbs: + case GLSLstd450Trunc: + case GLSLstd450FSign: + case GLSLstd450Floor: + case GLSLstd450Ceil: + case GLSLstd450Fract: + case GLSLstd450Sqrt: + case GLSLstd450InverseSqrt: + case GLSLstd450FMin: + case GLSLstd450FMax: + case GLSLstd450FClamp: + case GLSLstd450FMix: + case GLSLstd450Step: + case GLSLstd450SmoothStep: + case GLSLstd450Fma: + case GLSLstd450Normalize: + case GLSLstd450FaceForward: + case GLSLstd450Reflect: + case GLSLstd450NMin: + case GLSLstd450NMax: + case GLSLstd450NClamp: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar or vector type"; + } + + for (uint32_t operand_index = 4; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (result_type != operand_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected types of all operands to be equal to Result " + "Type"; + } + } + 
break; + } + + case GLSLstd450SAbs: + case GLSLstd450SSign: + case GLSLstd450UMin: + case GLSLstd450SMin: + case GLSLstd450UMax: + case GLSLstd450SMax: + case GLSLstd450UClamp: + case GLSLstd450SClamp: + case GLSLstd450FindILsb: + case GLSLstd450FindUMsb: + case GLSLstd450FindSMsb: { + if (!_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be an int scalar or vector type"; + } + + const uint32_t result_type_bit_width = _.GetBitWidth(result_type); + const uint32_t result_type_dimension = _.GetDimension(result_type); + + for (uint32_t operand_index = 4; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (!_.IsIntScalarOrVectorType(operand_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected all operands to be int scalars or vectors"; + } + + if (result_type_dimension != _.GetDimension(operand_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected all operands to have the same dimension as " + << "Result Type"; + } + + if (result_type_bit_width != _.GetBitWidth(operand_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected all operands to have the same bit width as " + << "Result Type"; + } + + if (ext_inst_key == GLSLstd450FindUMsb || + ext_inst_key == GLSLstd450FindSMsb) { + if (result_type_bit_width != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "this instruction is currently limited to 32-bit width " + << "components"; + } + } + } + break; + } + + case GLSLstd450Radians: + case GLSLstd450Degrees: + case GLSLstd450Sin: + case GLSLstd450Cos: + case GLSLstd450Tan: + case GLSLstd450Asin: + case GLSLstd450Acos: + case GLSLstd450Atan: + case GLSLstd450Sinh: + case GLSLstd450Cosh: + case GLSLstd450Tanh: + case GLSLstd450Asinh: 
+ case GLSLstd450Acosh: + case GLSLstd450Atanh: + case GLSLstd450Exp: + case GLSLstd450Exp2: + case GLSLstd450Log: + case GLSLstd450Log2: + case GLSLstd450Atan2: + case GLSLstd450Pow: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 16 or 32-bit scalar or " + "vector float type"; + } + + const uint32_t result_type_bit_width = _.GetBitWidth(result_type); + if (result_type_bit_width != 16 && result_type_bit_width != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 16 or 32-bit scalar or " + "vector float type"; + } + + for (uint32_t operand_index = 4; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (result_type != operand_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected types of all operands to be equal to Result " + "Type"; + } + } + break; + } + + case GLSLstd450Determinant: { + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + uint32_t num_rows = 0; + uint32_t num_cols = 0; + uint32_t col_type = 0; + uint32_t component_type = 0; + if (!_.GetMatrixTypeInfo(x_type, &num_rows, &num_cols, &col_type, + &component_type) || + num_rows != num_cols) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X to be a square matrix"; + } + + if (result_type != component_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X component type to be equal to " + << "Result Type"; + } + break; + } + + case GLSLstd450MatrixInverse: { + uint32_t num_rows = 0; + uint32_t num_cols = 0; + uint32_t col_type = 0; + uint32_t component_type = 0; + if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type, + &component_type) || + num_rows != num_cols) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a square matrix"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + if (result_type != x_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X type to be equal to Result Type"; + } + break; + } + + case GLSLstd450Modf: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or vector float type"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + const uint32_t i_type = _.GetOperandTypeId(inst, 5); + + if (x_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X type to be equal to Result Type"; + } + + uint32_t i_storage_class = 0; + uint32_t i_data_type = 0; + if (!_.GetPointerTypeInfo(i_type, &i_data_type, &i_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand I to be a pointer"; + } + + if (i_data_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand I data type to be equal to Result Type"; + } + + break; + } + + case GLSLstd450ModfStruct: { + std::vector result_types; + if (!_.GetStructMemberTypes(result_type, &result_types) || + result_types.size() != 2 || + !_.IsFloatScalarOrVectorType(result_types[0]) || + result_types[1] != result_types[0]) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a struct with two identical " + << "scalar or vector float type members"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + if (x_type != result_types[0]) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X type to be equal to members of " + << "Result Type struct"; + 
} + break; + } + + case GLSLstd450Frexp: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or vector float type"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + const uint32_t exp_type = _.GetOperandTypeId(inst, 5); + + if (x_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X type to be equal to Result Type"; + } + + uint32_t exp_storage_class = 0; + uint32_t exp_data_type = 0; + if (!_.GetPointerTypeInfo(exp_type, &exp_data_type, + &exp_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Exp to be a pointer"; + } + + if (!_.IsIntScalarOrVectorType(exp_data_type) || + (!_.HasExtension(kSPV_AMD_gpu_shader_int16) && + _.GetBitWidth(exp_data_type) != 32) || + (_.HasExtension(kSPV_AMD_gpu_shader_int16) && + _.GetBitWidth(exp_data_type) != 16 && + _.GetBitWidth(exp_data_type) != 32)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Exp data type to be a " + << (_.HasExtension(kSPV_AMD_gpu_shader_int16) + ? 
"16-bit or 32-bit " + : "32-bit ") + << "int scalar or vector type"; + } + + if (_.GetDimension(result_type) != _.GetDimension(exp_data_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Exp data type to have the same component " + << "number as Result Type"; + } + + break; + } + + case GLSLstd450Ldexp: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or vector float type"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + const uint32_t exp_type = _.GetOperandTypeId(inst, 5); + + if (x_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X type to be equal to Result Type"; + } + + if (!_.IsIntScalarOrVectorType(exp_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Exp to be a 32-bit int scalar " + << "or vector type"; + } + + if (_.GetDimension(result_type) != _.GetDimension(exp_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Exp to have the same component " + << "number as Result Type"; + } + + break; + } + + case GLSLstd450FrexpStruct: { + std::vector result_types; + if (!_.GetStructMemberTypes(result_type, &result_types) || + result_types.size() != 2 || + !_.IsFloatScalarOrVectorType(result_types[0]) || + !_.IsIntScalarOrVectorType(result_types[1]) || + (!_.HasExtension(kSPV_AMD_gpu_shader_int16) && + _.GetBitWidth(result_types[1]) != 32) || + (_.HasExtension(kSPV_AMD_gpu_shader_int16) && + _.GetBitWidth(result_types[1]) != 16 && + _.GetBitWidth(result_types[1]) != 32) || + _.GetDimension(result_types[0]) != + _.GetDimension(result_types[1])) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a struct with two members, " + << "first 
member a float scalar or vector, second member a " + << (_.HasExtension(kSPV_AMD_gpu_shader_int16) + ? "16-bit or 32-bit " + : "32-bit ") + << "int scalar or vector with the same number of " + << "components as the first member"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + if (x_type != result_types[0]) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X type to be equal to the first member " + << "of Result Type struct"; + } + break; + } + + case GLSLstd450PackSnorm4x8: + case GLSLstd450PackUnorm4x8: { + if (!_.IsIntScalarType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be 32-bit int scalar type"; + } + + const uint32_t v_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatVectorType(v_type) || _.GetDimension(v_type) != 4 || + _.GetBitWidth(v_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand V to be a 32-bit float vector of size 4"; + } + break; + } + + case GLSLstd450PackSnorm2x16: + case GLSLstd450PackUnorm2x16: + case GLSLstd450PackHalf2x16: { + if (!_.IsIntScalarType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be 32-bit int scalar type"; + } + + const uint32_t v_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatVectorType(v_type) || _.GetDimension(v_type) != 2 || + _.GetBitWidth(v_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand V to be a 32-bit float vector of size 2"; + } + break; + } + + case GLSLstd450PackDouble2x32: { + if (!_.IsFloatScalarType(result_type) || + _.GetBitWidth(result_type) != 64) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be 64-bit float scalar type"; + } + 
+ const uint32_t v_type = _.GetOperandTypeId(inst, 4); + if (!_.IsIntVectorType(v_type) || _.GetDimension(v_type) != 2 || + _.GetBitWidth(v_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand V to be a 32-bit int vector of size 2"; + } + break; + } + + case GLSLstd450UnpackSnorm4x8: + case GLSLstd450UnpackUnorm4x8: { + if (!_.IsFloatVectorType(result_type) || + _.GetDimension(result_type) != 4 || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 32-bit float vector of size " + "4"; + } + + const uint32_t v_type = _.GetOperandTypeId(inst, 4); + if (!_.IsIntScalarType(v_type) || _.GetBitWidth(v_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P to be a 32-bit int scalar"; + } + break; + } + + case GLSLstd450UnpackSnorm2x16: + case GLSLstd450UnpackUnorm2x16: + case GLSLstd450UnpackHalf2x16: { + if (!_.IsFloatVectorType(result_type) || + _.GetDimension(result_type) != 2 || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 32-bit float vector of size " + "2"; + } + + const uint32_t v_type = _.GetOperandTypeId(inst, 4); + if (!_.IsIntScalarType(v_type) || _.GetBitWidth(v_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P to be a 32-bit int scalar"; + } + break; + } + + case GLSLstd450UnpackDouble2x32: { + if (!_.IsIntVectorType(result_type) || + _.GetDimension(result_type) != 2 || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 32-bit int vector of size " + "2"; + } + + const uint32_t v_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarType(v_type) || _.GetBitWidth(v_type) != 64) { + 
return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand V to be a 64-bit float scalar"; + } + break; + } + + case GLSLstd450Length: { + if (!_.IsFloatScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar type"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarOrVectorType(x_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X to be of float scalar or vector type"; + } + + if (result_type != _.GetComponentType(x_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X component type to be equal to Result " + "Type"; + } + break; + } + + case GLSLstd450Distance: { + if (!_.IsFloatScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar type"; + } + + const uint32_t p0_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarOrVectorType(p0_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P0 to be of float scalar or vector type"; + } + + if (result_type != _.GetComponentType(p0_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P0 component type to be equal to " + << "Result Type"; + } + + const uint32_t p1_type = _.GetOperandTypeId(inst, 5); + if (!_.IsFloatScalarOrVectorType(p1_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P1 to be of float scalar or vector type"; + } + + if (result_type != _.GetComponentType(p1_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P1 component type to be equal to " + << "Result Type"; + } + + if (_.GetDimension(p0_type) != 
_.GetDimension(p1_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operands P0 and P1 to have the same number of " + << "components"; + } + break; + } + + case GLSLstd450Cross: { + if (!_.IsFloatVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float vector type"; + } + + if (_.GetDimension(result_type) != 3) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to have 3 components"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + const uint32_t y_type = _.GetOperandTypeId(inst, 5); + + if (x_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X type to be equal to Result Type"; + } + + if (y_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Y type to be equal to Result Type"; + } + break; + } + + case GLSLstd450Refract: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar or vector type"; + } + + const uint32_t i_type = _.GetOperandTypeId(inst, 4); + const uint32_t n_type = _.GetOperandTypeId(inst, 5); + const uint32_t eta_type = _.GetOperandTypeId(inst, 6); + + if (result_type != i_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand I to be of type equal to Result Type"; + } + + if (result_type != n_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand N to be of type equal to Result Type"; + } + + if (!_.IsFloatScalarType(eta_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Eta to be a float scalar"; + } + break; + } + + case 
GLSLstd450InterpolateAtCentroid: + case GLSLstd450InterpolateAtSample: + case GLSLstd450InterpolateAtOffset: { + if (!_.HasCapability(SpvCapabilityInterpolationFunction)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << ext_inst_name() + << " requires capability InterpolationFunction"; + } + + if (!_.IsFloatScalarOrVectorType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 32-bit float scalar " + << "or vector type"; + } + + const uint32_t interpolant_type = _.GetOperandTypeId(inst, 4); + uint32_t interpolant_storage_class = 0; + uint32_t interpolant_data_type = 0; + if (!_.GetPointerTypeInfo(interpolant_type, &interpolant_data_type, + &interpolant_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Interpolant to be a pointer"; + } + + if (result_type != interpolant_data_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Interpolant data type to be equal to Result Type"; + } + + if (interpolant_storage_class != SpvStorageClassInput) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Interpolant storage class to be Input"; + } + + if (ext_inst_key == GLSLstd450InterpolateAtSample) { + const uint32_t sample_type = _.GetOperandTypeId(inst, 5); + if (!_.IsIntScalarType(sample_type) || + _.GetBitWidth(sample_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Sample to be 32-bit integer"; + } + } + + if (ext_inst_key == GLSLstd450InterpolateAtOffset) { + const uint32_t offset_type = _.GetOperandTypeId(inst, 5); + if (!_.IsFloatVectorType(offset_type) || + _.GetDimension(offset_type) != 2 || + _.GetBitWidth(offset_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Offset to be a vector of 2 32-bit 
floats"; + } + } + + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + ext_inst_name() + + std::string(" requires Fragment execution model")); + break; + } + + case GLSLstd450IMix: { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Extended instruction GLSLstd450IMix is not supported"; + } + + case GLSLstd450Bad: { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Encountered extended instruction GLSLstd450Bad"; + } + + case GLSLstd450Count: { + assert(0); + break; + } + } + } else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) { + const OpenCLLIB::Entrypoints ext_inst_key = + OpenCLLIB::Entrypoints(ext_inst_index); + switch (ext_inst_key) { + case OpenCLLIB::Acos: + case OpenCLLIB::Acosh: + case OpenCLLIB::Acospi: + case OpenCLLIB::Asin: + case OpenCLLIB::Asinh: + case OpenCLLIB::Asinpi: + case OpenCLLIB::Atan: + case OpenCLLIB::Atan2: + case OpenCLLIB::Atanh: + case OpenCLLIB::Atanpi: + case OpenCLLIB::Atan2pi: + case OpenCLLIB::Cbrt: + case OpenCLLIB::Ceil: + case OpenCLLIB::Copysign: + case OpenCLLIB::Cos: + case OpenCLLIB::Cosh: + case OpenCLLIB::Cospi: + case OpenCLLIB::Erfc: + case OpenCLLIB::Erf: + case OpenCLLIB::Exp: + case OpenCLLIB::Exp2: + case OpenCLLIB::Exp10: + case OpenCLLIB::Expm1: + case OpenCLLIB::Fabs: + case OpenCLLIB::Fdim: + case OpenCLLIB::Floor: + case OpenCLLIB::Fma: + case OpenCLLIB::Fmax: + case OpenCLLIB::Fmin: + case OpenCLLIB::Fmod: + case OpenCLLIB::Hypot: + case OpenCLLIB::Lgamma: + case OpenCLLIB::Log: + case OpenCLLIB::Log2: + case OpenCLLIB::Log10: + case OpenCLLIB::Log1p: + case OpenCLLIB::Logb: + case OpenCLLIB::Mad: + case OpenCLLIB::Maxmag: + case OpenCLLIB::Minmag: + case OpenCLLIB::Nextafter: + case OpenCLLIB::Pow: + case OpenCLLIB::Powr: + case OpenCLLIB::Remainder: + case OpenCLLIB::Rint: + case OpenCLLIB::Round: + case OpenCLLIB::Rsqrt: + case OpenCLLIB::Sin: + case OpenCLLIB::Sinh: + case OpenCLLIB::Sinpi: + case OpenCLLIB::Sqrt: + case OpenCLLIB::Tan: + 
case OpenCLLIB::Tanh: + case OpenCLLIB::Tanpi: + case OpenCLLIB::Tgamma: + case OpenCLLIB::Trunc: + case OpenCLLIB::Half_cos: + case OpenCLLIB::Half_divide: + case OpenCLLIB::Half_exp: + case OpenCLLIB::Half_exp2: + case OpenCLLIB::Half_exp10: + case OpenCLLIB::Half_log: + case OpenCLLIB::Half_log2: + case OpenCLLIB::Half_log10: + case OpenCLLIB::Half_powr: + case OpenCLLIB::Half_recip: + case OpenCLLIB::Half_rsqrt: + case OpenCLLIB::Half_sin: + case OpenCLLIB::Half_sqrt: + case OpenCLLIB::Half_tan: + case OpenCLLIB::Native_cos: + case OpenCLLIB::Native_divide: + case OpenCLLIB::Native_exp: + case OpenCLLIB::Native_exp2: + case OpenCLLIB::Native_exp10: + case OpenCLLIB::Native_log: + case OpenCLLIB::Native_log2: + case OpenCLLIB::Native_log10: + case OpenCLLIB::Native_powr: + case OpenCLLIB::Native_recip: + case OpenCLLIB::Native_rsqrt: + case OpenCLLIB::Native_sin: + case OpenCLLIB::Native_sqrt: + case OpenCLLIB::Native_tan: + case OpenCLLIB::FClamp: + case OpenCLLIB::Degrees: + case OpenCLLIB::FMax_common: + case OpenCLLIB::FMin_common: + case OpenCLLIB::Mix: + case OpenCLLIB::Radians: + case OpenCLLIB::Step: + case OpenCLLIB::Smoothstep: + case OpenCLLIB::Sign: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar or vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + for (uint32_t operand_index = 4; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (result_type != operand_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected types of all 
operands to be equal to Result " + "Type"; + } + } + break; + } + + case OpenCLLIB::Fract: + case OpenCLLIB::Modf: + case OpenCLLIB::Sincos: + case OpenCLLIB::Remquo: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar or vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + uint32_t operand_index = 4; + const uint32_t x_type = _.GetOperandTypeId(inst, operand_index++); + if (result_type != x_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected type of operand X to be equal to Result Type"; + } + + if (ext_inst_key == OpenCLLIB::Remquo) { + const uint32_t y_type = _.GetOperandTypeId(inst, operand_index++); + if (result_type != y_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected type of operand Y to be equal to Result Type"; + } + } + + const uint32_t p_type = _.GetOperandTypeId(inst, operand_index++); + uint32_t p_storage_class = 0; + uint32_t p_data_type = 0; + if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected the last operand to be a pointer"; + } + + if (p_storage_class != SpvStorageClassGeneric && + p_storage_class != SpvStorageClassCrossWorkgroup && + p_storage_class != SpvStorageClassWorkgroup && + p_storage_class != SpvStorageClassFunction) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected storage class of the pointer to be Generic, " + "CrossWorkgroup, Workgroup or Function"; + } + + if (result_type != p_data_type) { + 
return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected data type of the pointer to be equal to Result " + "Type"; + } + break; + } + + case OpenCLLIB::Frexp: + case OpenCLLIB::Lgamma_r: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar or vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + if (result_type != x_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected type of operand X to be equal to Result Type"; + } + + const uint32_t p_type = _.GetOperandTypeId(inst, 5); + uint32_t p_storage_class = 0; + uint32_t p_data_type = 0; + if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected the last operand to be a pointer"; + } + + if (p_storage_class != SpvStorageClassGeneric && + p_storage_class != SpvStorageClassCrossWorkgroup && + p_storage_class != SpvStorageClassWorkgroup && + p_storage_class != SpvStorageClassFunction) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected storage class of the pointer to be Generic, " + "CrossWorkgroup, Workgroup or Function"; + } + + if (!_.IsIntScalarOrVectorType(p_data_type) || + _.GetBitWidth(p_data_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected data type of the pointer to be a 32-bit int " + "scalar or vector type"; + } + + if (_.GetDimension(p_data_type) != num_components) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected data type of the pointer to have the same number " + "of components as Result Type"; + } + break; + } + + case OpenCLLIB::Ilogb: { + if (!_.IsIntScalarOrVectorType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 32-bit int scalar or vector " + "type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarOrVectorType(x_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X to be a float scalar or vector"; + } + + if (_.GetDimension(x_type) != num_components) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X to have the same number of components " + "as Result Type"; + } + break; + } + + case OpenCLLIB::Ldexp: + case OpenCLLIB::Pown: + case OpenCLLIB::Rootn: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar or vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + if (result_type != x_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected 
type of operand X to be equal to Result Type"; + } + + const uint32_t exp_type = _.GetOperandTypeId(inst, 5); + if (!_.IsIntScalarOrVectorType(exp_type) || + _.GetBitWidth(exp_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected the exponent to be a 32-bit int scalar or vector"; + } + + if (_.GetDimension(exp_type) != num_components) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected the exponent to have the same number of " + "components as Result Type"; + } + break; + } + + case OpenCLLIB::Nan: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar or vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + const uint32_t nancode_type = _.GetOperandTypeId(inst, 4); + if (!_.IsIntScalarOrVectorType(nancode_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Nancode to be an int scalar or vector type"; + } + + if (_.GetDimension(nancode_type) != num_components) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Nancode to have the same number of components as " + "Result Type"; + } + + if (_.GetBitWidth(result_type) != _.GetBitWidth(nancode_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Nancode to have the same bit width as Result " + "Type"; + } + break; + } + + case OpenCLLIB::SAbs: + case OpenCLLIB::SAbs_diff: + case OpenCLLIB::SAdd_sat: + case OpenCLLIB::UAdd_sat: + case OpenCLLIB::SHadd: + case OpenCLLIB::UHadd: + case OpenCLLIB::SRhadd: + 
case OpenCLLIB::URhadd: + case OpenCLLIB::SClamp: + case OpenCLLIB::UClamp: + case OpenCLLIB::Clz: + case OpenCLLIB::Ctz: + case OpenCLLIB::SMad_hi: + case OpenCLLIB::UMad_sat: + case OpenCLLIB::SMad_sat: + case OpenCLLIB::SMax: + case OpenCLLIB::UMax: + case OpenCLLIB::SMin: + case OpenCLLIB::UMin: + case OpenCLLIB::SMul_hi: + case OpenCLLIB::Rotate: + case OpenCLLIB::SSub_sat: + case OpenCLLIB::USub_sat: + case OpenCLLIB::Popcount: + case OpenCLLIB::UAbs: + case OpenCLLIB::UAbs_diff: + case OpenCLLIB::UMul_hi: + case OpenCLLIB::UMad_hi: { + if (!_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be an int scalar or vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + for (uint32_t operand_index = 4; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (result_type != operand_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected types of all operands to be equal to Result " + "Type"; + } + } + break; + } + + case OpenCLLIB::U_Upsample: + case OpenCLLIB::S_Upsample: { + if (!_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be an int scalar or vector " + "type"; + } + + const uint32_t result_num_components = _.GetDimension(result_type); + if (result_num_components > 4 && result_num_components != 8 && + result_num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 
or 16 components"; + } + + const uint32_t result_bit_width = _.GetBitWidth(result_type); + if (result_bit_width != 16 && result_bit_width != 32 && + result_bit_width != 64) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected bit width of Result Type components to be 16, 32 " + "or 64"; + } + + const uint32_t hi_type = _.GetOperandTypeId(inst, 4); + const uint32_t lo_type = _.GetOperandTypeId(inst, 5); + + if (hi_type != lo_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Hi and Lo operands to have the same type"; + } + + if (result_num_components != _.GetDimension(hi_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Hi and Lo operands to have the same number of " + "components as Result Type"; + } + + if (result_bit_width != 2 * _.GetBitWidth(hi_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected bit width of components of Hi and Lo operands to " + "be half of the bit width of components of Result Type"; + } + break; + } + + case OpenCLLIB::SMad24: + case OpenCLLIB::UMad24: + case OpenCLLIB::SMul24: + case OpenCLLIB::UMul24: { + if (!_.IsIntScalarOrVectorType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 32-bit int scalar or vector " + "type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + for (uint32_t operand_index = 4; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (result_type != operand_type) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected types of all operands to be equal to Result " + "Type"; + } + } + break; + } + + case OpenCLLIB::Cross: { + if (!_.IsFloatVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components != 3 && num_components != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to have 3 or 4 components"; + } + + const uint32_t x_type = _.GetOperandTypeId(inst, 4); + const uint32_t y_type = _.GetOperandTypeId(inst, 5); + + if (x_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X type to be equal to Result Type"; + } + + if (y_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Y type to be equal to Result Type"; + } + break; + } + + case OpenCLLIB::Distance: + case OpenCLLIB::Fast_distance: { + if (!_.IsFloatScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar type"; + } + + const uint32_t p0_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarOrVectorType(p0_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P0 to be of float scalar or vector type"; + } + + const uint32_t num_components = _.GetDimension(p0_type); + if (num_components > 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P0 to have no more than 4 components"; + } + + if (result_type != _.GetComponentType(p0_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P0 component type to be equal to " + << 
"Result Type"; + } + + const uint32_t p1_type = _.GetOperandTypeId(inst, 5); + if (p0_type != p1_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operands P0 and P1 to be of the same type"; + } + break; + } + + case OpenCLLIB::Length: + case OpenCLLIB::Fast_length: { + if (!_.IsFloatScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar type"; + } + + const uint32_t p_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarOrVectorType(p_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P to be a float scalar or vector"; + } + + const uint32_t num_components = _.GetDimension(p_type); + if (num_components > 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P to have no more than 4 components"; + } + + if (result_type != _.GetComponentType(p_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P component type to be equal to Result " + "Type"; + } + break; + } + + case OpenCLLIB::Normalize: + case OpenCLLIB::Fast_normalize: { + if (!_.IsFloatScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar or vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to have no more than 4 components"; + } + + const uint32_t p_type = _.GetOperandTypeId(inst, 4); + if (p_type != result_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P type to be equal to Result Type"; + } + break; + } + + case OpenCLLIB::Bitselect: { + if (!_.IsFloatScalarOrVectorType(result_type) 
&& + !_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be an int or float scalar or " + "vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + for (uint32_t operand_index = 4; operand_index < num_operands; + ++operand_index) { + const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index); + if (result_type != operand_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected types of all operands to be equal to Result " + "Type"; + } + } + break; + } + + case OpenCLLIB::Select: { + if (!_.IsFloatScalarOrVectorType(result_type) && + !_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be an int or float scalar or " + "vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + const uint32_t a_type = _.GetOperandTypeId(inst, 4); + const uint32_t b_type = _.GetOperandTypeId(inst, 5); + const uint32_t c_type = _.GetOperandTypeId(inst, 6); + + if (result_type != a_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand A type to be equal to Result Type"; + } + + if (result_type != b_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand B type to be equal to Result Type"; + } + + if 
(!_.IsIntScalarOrVectorType(c_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand C to be an int scalar or vector"; + } + + if (num_components != _.GetDimension(c_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand C to have the same number of components " + "as Result Type"; + } + + if (_.GetBitWidth(result_type) != _.GetBitWidth(c_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand C to have the same bit width as Result " + "Type"; + } + break; + } + + case OpenCLLIB::Vloadn: { + if (!_.IsFloatVectorType(result_type) && + !_.IsIntVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be an int or float vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to have 2, 3, 4, 8 or 16 components"; + } + + const uint32_t offset_type = _.GetOperandTypeId(inst, 4); + const uint32_t p_type = _.GetOperandTypeId(inst, 5); + + const uint32_t size_t_bit_width = GetSizeTBitWidth(_); + if (!size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() + << " can only be used with physical addressing models"; + } + + if (!_.IsIntScalarType(offset_type) || + _.GetBitWidth(offset_type) != size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Offset to be of type size_t (" + << size_t_bit_width + << "-bit integer for the addressing model used in the module)"; + } + + uint32_t p_storage_class = 0; + uint32_t p_data_type = 0; + if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << 
ext_inst_name() << ": " + << "expected operand P to be a pointer"; + } + + if (p_storage_class != SpvStorageClassUniformConstant && + p_storage_class != SpvStorageClassGeneric) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P storage class to be UniformConstant or " + "Generic"; + } + + if (_.GetComponentType(result_type) != p_data_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P data type to be equal to component " + "type of Result Type"; + } + + const uint32_t n_value = inst->word(7); + if (num_components != n_value) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected literal N to be equal to the number of " + "components of Result Type"; + } + break; + } + + case OpenCLLIB::Vstoren: { + if (_.GetIdOpcode(result_type) != SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": expected Result Type to be void"; + } + + const uint32_t data_type = _.GetOperandTypeId(inst, 4); + const uint32_t offset_type = _.GetOperandTypeId(inst, 5); + const uint32_t p_type = _.GetOperandTypeId(inst, 6); + + if (!_.IsFloatVectorType(data_type) && !_.IsIntVectorType(data_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Data to be an int or float vector"; + } + + const uint32_t num_components = _.GetDimension(data_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Data to have 2, 3, 4, 8 or 16 components"; + } + + const uint32_t size_t_bit_width = GetSizeTBitWidth(_); + if (!size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() + << " can only be used with physical addressing models"; + } + + if (!_.IsIntScalarType(offset_type) || + _.GetBitWidth(offset_type) != size_t_bit_width) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Offset to be of type size_t (" + << size_t_bit_width + << "-bit integer for the addressing model used in the module)"; + } + + uint32_t p_storage_class = 0; + uint32_t p_data_type = 0; + if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P to be a pointer"; + } + + if (p_storage_class != SpvStorageClassGeneric) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P storage class to be Generic"; + } + + if (_.GetComponentType(data_type) != p_data_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P data type to be equal to the type of " + "operand Data components"; + } + break; + } + + case OpenCLLIB::Vload_half: { + if (!_.IsFloatScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float scalar type"; + } + + const uint32_t offset_type = _.GetOperandTypeId(inst, 4); + const uint32_t p_type = _.GetOperandTypeId(inst, 5); + + const uint32_t size_t_bit_width = GetSizeTBitWidth(_); + if (!size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() + << " can only be used with physical addressing models"; + } + + if (!_.IsIntScalarType(offset_type) || + _.GetBitWidth(offset_type) != size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Offset to be of type size_t (" + << size_t_bit_width + << "-bit integer for the addressing model used in the module)"; + } + + uint32_t p_storage_class = 0; + uint32_t p_data_type = 0; + if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P to be a 
pointer"; + } + + if (p_storage_class != SpvStorageClassUniformConstant && + p_storage_class != SpvStorageClassGeneric && + p_storage_class != SpvStorageClassCrossWorkgroup && + p_storage_class != SpvStorageClassWorkgroup && + p_storage_class != SpvStorageClassFunction) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P storage class to be UniformConstant, " + "Generic, CrossWorkgroup, Workgroup or Function"; + } + + if (!_.IsFloatScalarType(p_data_type) || + _.GetBitWidth(p_data_type) != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P data type to be 16-bit float scalar"; + } + break; + } + + case OpenCLLIB::Vload_halfn: + case OpenCLLIB::Vloada_halfn: { + if (!_.IsFloatVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a float vector type"; + } + + const uint32_t num_components = _.GetDimension(result_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to have 2, 3, 4, 8 or 16 components"; + } + + const uint32_t offset_type = _.GetOperandTypeId(inst, 4); + const uint32_t p_type = _.GetOperandTypeId(inst, 5); + + const uint32_t size_t_bit_width = GetSizeTBitWidth(_); + if (!size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() + << " can only be used with physical addressing models"; + } + + if (!_.IsIntScalarType(offset_type) || + _.GetBitWidth(offset_type) != size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Offset to be of type size_t (" + << size_t_bit_width + << "-bit integer for the addressing model used in the module)"; + } + + uint32_t p_storage_class = 0; + uint32_t p_data_type = 0; + if (!_.GetPointerTypeInfo(p_type, &p_data_type, 
&p_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P to be a pointer"; + } + + if (p_storage_class != SpvStorageClassUniformConstant && + p_storage_class != SpvStorageClassGeneric && + p_storage_class != SpvStorageClassCrossWorkgroup && + p_storage_class != SpvStorageClassWorkgroup && + p_storage_class != SpvStorageClassFunction) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P storage class to be UniformConstant, " + "Generic, CrossWorkgroup, Workgroup or Function"; + } + + if (!_.IsFloatScalarType(p_data_type) || + _.GetBitWidth(p_data_type) != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P data type to be 16-bit float scalar"; + } + + const uint32_t n_value = inst->word(7); + if (num_components != n_value) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected literal N to be equal to the number of " + "components of Result Type"; + } + break; + } + + case OpenCLLIB::Vstore_half: + case OpenCLLIB::Vstore_half_r: + case OpenCLLIB::Vstore_halfn: + case OpenCLLIB::Vstore_halfn_r: + case OpenCLLIB::Vstorea_halfn: + case OpenCLLIB::Vstorea_halfn_r: { + if (_.GetIdOpcode(result_type) != SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": expected Result Type to be void"; + } + + const uint32_t data_type = _.GetOperandTypeId(inst, 4); + const uint32_t offset_type = _.GetOperandTypeId(inst, 5); + const uint32_t p_type = _.GetOperandTypeId(inst, 6); + const uint32_t data_type_bit_width = _.GetBitWidth(data_type); + + if (ext_inst_key == OpenCLLIB::Vstore_half || + ext_inst_key == OpenCLLIB::Vstore_half_r) { + if (!_.IsFloatScalarType(data_type) || + (data_type_bit_width != 32 && data_type_bit_width != 64)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Data to be a 32 or 
64-bit float scalar"; + } + } else { + if (!_.IsFloatVectorType(data_type) || + (data_type_bit_width != 32 && data_type_bit_width != 64)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Data to be a 32 or 64-bit float vector"; + } + + const uint32_t num_components = _.GetDimension(data_type); + if (num_components > 4 && num_components != 8 && + num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Data to have 2, 3, 4, 8 or 16 components"; + } + } + + const uint32_t size_t_bit_width = GetSizeTBitWidth(_); + if (!size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() + << " can only be used with physical addressing models"; + } + + if (!_.IsIntScalarType(offset_type) || + _.GetBitWidth(offset_type) != size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Offset to be of type size_t (" + << size_t_bit_width + << "-bit integer for the addressing model used in the module)"; + } + + uint32_t p_storage_class = 0; + uint32_t p_data_type = 0; + if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P to be a pointer"; + } + + if (p_storage_class != SpvStorageClassGeneric && + p_storage_class != SpvStorageClassCrossWorkgroup && + p_storage_class != SpvStorageClassWorkgroup && + p_storage_class != SpvStorageClassFunction) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P storage class to be Generic, " + "CrossWorkgroup, Workgroup or Function"; + } + + if (!_.IsFloatScalarType(p_data_type) || + _.GetBitWidth(p_data_type) != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand P data type to be 16-bit float scalar"; + } + + // Rounding mode enum is checked by assembler. 
+ break; + } + + case OpenCLLIB::Shuffle: + case OpenCLLIB::Shuffle2: { + if (!_.IsFloatVectorType(result_type) && + !_.IsIntVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be an int or float vector type"; + } + + const uint32_t result_num_components = _.GetDimension(result_type); + if (result_num_components != 2 && result_num_components != 4 && + result_num_components != 8 && result_num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to have 2, 4, 8 or 16 components"; + } + + uint32_t operand_index = 4; + const uint32_t x_type = _.GetOperandTypeId(inst, operand_index++); + + if (ext_inst_key == OpenCLLIB::Shuffle2) { + const uint32_t y_type = _.GetOperandTypeId(inst, operand_index++); + if (x_type != y_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operands X and Y to be of the same type"; + } + } + + const uint32_t shuffle_mask_type = + _.GetOperandTypeId(inst, operand_index++); + + if (!_.IsFloatVectorType(x_type) && !_.IsIntVectorType(x_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X to be an int or float vector"; + } + + const uint32_t x_num_components = _.GetDimension(x_type); + if (x_num_components != 2 && x_num_components != 4 && + x_num_components != 8 && x_num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X to have 2, 4, 8 or 16 components"; + } + + const uint32_t result_component_type = _.GetComponentType(result_type); + + if (result_component_type != _.GetComponentType(x_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand X and Result Type to have equal " + "component types"; + } + + if (!_.IsIntVectorType(shuffle_mask_type)) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Shuffle Mask to be an int vector"; + } + + if (result_num_components != _.GetDimension(shuffle_mask_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Shuffle Mask to have the same number of " + "components as Result Type"; + } + + if (_.GetBitWidth(result_component_type) != + _.GetBitWidth(shuffle_mask_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Shuffle Mask components to have the same " + "bit width as Result Type components"; + } + break; + } + + case OpenCLLIB::Printf: { + if (!_.IsIntScalarType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a 32-bit int type"; + } + + const uint32_t format_type = _.GetOperandTypeId(inst, 4); + uint32_t format_storage_class = 0; + uint32_t format_data_type = 0; + if (!_.GetPointerTypeInfo(format_type, &format_data_type, + &format_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Format to be a pointer"; + } + + if (format_storage_class != SpvStorageClassUniformConstant) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Format storage class to be UniformConstant"; + } + + if (!_.IsIntScalarType(format_data_type) || + _.GetBitWidth(format_data_type) != 8) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Format data type to be 8-bit int"; + } + break; + } + + case OpenCLLIB::Prefetch: { + if (_.GetIdOpcode(result_type) != SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": expected Result Type to be void"; + } + + const uint32_t p_type = _.GetOperandTypeId(inst, 4); + const uint32_t num_elements_type = _.GetOperandTypeId(inst, 
5); + + uint32_t p_storage_class = 0; + uint32_t p_data_type = 0; + if (!_.GetPointerTypeInfo(p_type, &p_data_type, &p_storage_class)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Ptr to be a pointer"; + } + + if (p_storage_class != SpvStorageClassCrossWorkgroup) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Ptr storage class to be CrossWorkgroup"; + } + + if (!_.IsFloatScalarOrVectorType(p_data_type) && + !_.IsIntScalarOrVectorType(p_data_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Ptr data type to be int or float scalar or " + "vector"; + } + + const uint32_t num_components = _.GetDimension(p_data_type); + if (num_components > 4 && num_components != 8 && num_components != 16) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected Result Type to be a scalar or a vector with 2, " + "3, 4, 8 or 16 components"; + } + + const uint32_t size_t_bit_width = GetSizeTBitWidth(_); + if (!size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() + << " can only be used with physical addressing models"; + } + + if (!_.IsIntScalarType(num_elements_type) || + _.GetBitWidth(num_elements_type) != size_t_bit_width) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << ext_inst_name() << ": " + << "expected operand Num Elements to be of type size_t (" + << size_t_bit_width + << "-bit integer for the addressing model used in the module)"; + } + break; + } + } + } + + return SPV_SUCCESS; +} + +spv_result_t ExtensionPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + if (opcode == SpvOpExtension) return ValidateExtension(_, inst); + if (opcode == SpvOpExtInstImport) return ValidateExtInstImport(_, inst); + if (opcode == SpvOpExtInst) return ValidateExtInst(_, inst); + + return SPV_SUCCESS; +} + +} // namespace val +} 
// namespace spvtools diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp new file mode 100644 index 000000000..de41b27c3 --- /dev/null +++ b/source/val/validate_function.cpp @@ -0,0 +1,275 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { + const auto function_type_id = inst->GetOperandAs(3); + const auto function_type = _.FindDef(function_type_id); + if (!function_type || SpvOpTypeFunction != function_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunction Function Type '" << _.getIdName(function_type_id) + << "' is not a function type."; + } + + const auto return_id = function_type->GetOperandAs(1); + if (return_id != inst->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunction Result Type '" << _.getIdName(inst->type_id()) + << "' does not match the Function Type's return type '" + << _.getIdName(return_id) << "'."; + } + + for (auto& pair : inst->uses()) { + const auto* use = pair.first; + const std::vector acceptable = { + SpvOpFunctionCall, + SpvOpEntryPoint, + SpvOpEnqueueKernel, + SpvOpGetKernelNDrangeSubGroupCount, + SpvOpGetKernelNDrangeMaxSubGroupSize, 
+ SpvOpGetKernelWorkGroupSize, + SpvOpGetKernelPreferredWorkGroupSizeMultiple, + SpvOpGetKernelLocalSizeForSubgroupCount, + SpvOpGetKernelMaxNumSubgroups}; + if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == + acceptable.end()) { + return _.diag(SPV_ERROR_INVALID_ID, use) + << "Invalid use of function result id " << _.getIdName(inst->id()) + << "."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateFunctionParameter(ValidationState_t& _, + const Instruction* inst) { + // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. + size_t param_index = 0; + size_t inst_num = inst->LineNum() - 1; + if (inst_num == 0) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function parameter cannot be the first instruction."; + } + + auto func_inst = &_.ordered_instructions()[inst_num]; + while (--inst_num) { + func_inst = &_.ordered_instructions()[inst_num]; + if (func_inst->opcode() == SpvOpFunction) { + break; + } else if (func_inst->opcode() == SpvOpFunctionParameter) { + ++param_index; + } + } + + if (func_inst->opcode() != SpvOpFunction) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function parameter must be preceded by a function."; + } + + const auto function_type_id = func_inst->GetOperandAs(3); + const auto function_type = _.FindDef(function_type_id); + if (!function_type) { + return _.diag(SPV_ERROR_INVALID_ID, func_inst) + << "Missing function type definition."; + } + if (param_index >= function_type->words().size() - 3) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Too many OpFunctionParameters for " << func_inst->id() + << ": expected " << function_type->words().size() - 3 + << " based on the function's type"; + } + + const auto param_type = + _.FindDef(function_type->GetOperandAs(param_index + 2)); + if (!param_type || inst->type_id() != param_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionParameter Result Type '" + << _.getIdName(inst->type_id()) + << "' does not 
match the OpTypeFunction parameter " + "type of the same index."; + } + + // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased, + // RestrictPointerEXT, or AliasedPointerEXT. + auto param_nonarray_type_id = param_type->id(); + while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) { + param_nonarray_type_id = + _.FindDef(param_nonarray_type_id)->GetOperandAs(1u); + } + if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) { + auto param_nonarray_type = _.FindDef(param_nonarray_type_id); + if (param_nonarray_type->GetOperandAs(1u) == + SpvStorageClassPhysicalStorageBufferEXT) { + // check for Aliased or Restrict + const auto& decorations = _.id_decorations(inst->id()); + + bool foundAliased = std::any_of( + decorations.begin(), decorations.end(), [](const Decoration& d) { + return SpvDecorationAliased == d.dec_type(); + }); + + bool foundRestrict = std::any_of( + decorations.begin(), decorations.end(), [](const Decoration& d) { + return SpvDecorationRestrict == d.dec_type(); + }); + + if (!foundAliased && !foundRestrict) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionParameter " << inst->id() + << ": expected Aliased or Restrict for PhysicalStorageBufferEXT " + "pointer."; + } + if (foundAliased && foundRestrict) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionParameter " << inst->id() + << ": can't specify both Aliased and Restrict for " + "PhysicalStorageBufferEXT pointer."; + } + } else { + const auto pointee_type_id = + param_nonarray_type->GetOperandAs(2); + const auto pointee_type = _.FindDef(pointee_type_id); + if (SpvOpTypePointer == pointee_type->opcode() && + pointee_type->GetOperandAs(1u) == + SpvStorageClassPhysicalStorageBufferEXT) { + // check for AliasedPointerEXT/RestrictPointerEXT + const auto& decorations = _.id_decorations(inst->id()); + + bool foundAliased = std::any_of( + decorations.begin(), decorations.end(), [](const Decoration& d) { + return SpvDecorationAliasedPointerEXT == 
d.dec_type(); + }); + + bool foundRestrict = std::any_of( + decorations.begin(), decorations.end(), [](const Decoration& d) { + return SpvDecorationRestrictPointerEXT == d.dec_type(); + }); + + if (!foundAliased && !foundRestrict) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionParameter " << inst->id() + << ": expected AliasedPointerEXT or RestrictPointerEXT for " + "PhysicalStorageBufferEXT pointer."; + } + if (foundAliased && foundRestrict) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionParameter " << inst->id() + << ": can't specify both AliasedPointerEXT and " + "RestrictPointerEXT for PhysicalStorageBufferEXT pointer."; + } + } + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateFunctionCall(ValidationState_t& _, + const Instruction* inst) { + const auto function_id = inst->GetOperandAs(2); + const auto function = _.FindDef(function_id); + if (!function || SpvOpFunction != function->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Function '" << _.getIdName(function_id) + << "' is not a function."; + } + + auto return_type = _.FindDef(function->type_id()); + if (!return_type || return_type->id() != inst->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Result Type '" + << _.getIdName(inst->type_id()) + << "'s type does not match Function '" + << _.getIdName(return_type->id()) << "'s return type."; + } + + const auto function_type_id = function->GetOperandAs(3); + const auto function_type = _.FindDef(function_type_id); + if (!function_type || function_type->opcode() != SpvOpTypeFunction) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Missing function type definition."; + } + + const auto function_call_arg_count = inst->words().size() - 4; + const auto function_param_count = function_type->words().size() - 3; + if (function_param_count != function_call_arg_count) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Function 's parameter count does not 
match " + "the argument count."; + } + + for (size_t argument_index = 3, param_index = 2; + argument_index < inst->operands().size(); + argument_index++, param_index++) { + const auto argument_id = inst->GetOperandAs(argument_index); + const auto argument = _.FindDef(argument_id); + if (!argument) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Missing argument " << argument_index - 3 << " definition."; + } + + const auto argument_type = _.FindDef(argument->type_id()); + if (!argument_type) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Missing argument " << argument_index - 3 + << " type definition."; + } + + const auto parameter_type_id = + function_type->GetOperandAs(param_index); + const auto parameter_type = _.FindDef(parameter_type_id); + if (!parameter_type || argument_type->id() != parameter_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Argument '" << _.getIdName(argument_id) + << "'s type does not match Function '" + << _.getIdName(parameter_type_id) << "'s parameter type."; + } + } + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpFunction: + if (auto error = ValidateFunction(_, inst)) return error; + break; + case SpvOpFunctionParameter: + if (auto error = ValidateFunctionParameter(_, inst)) return error; + break; + case SpvOpFunctionCall: + if (auto error = ValidateFunctionCall(_, inst)) return error; + break; + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_id.cpp b/source/val/validate_id.cpp new file mode 100644 index 000000000..21a04113d --- /dev/null +++ b/source/val/validate_id.cpp @@ -0,0 +1,222 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_validator_options.h" +#include "source/val/function.h" +#include "source/val/validation_state.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace val { + +// Registers |inst| as a user of every id operand, other than its result id, +// whose definition FindDef can locate. +spv_result_t UpdateIdUse(ValidationState_t& _, const Instruction* inst) { + for (const auto& op : inst->operands()) { + const spv_operand_type_t& kind = op.type; + const uint32_t id = inst->word(op.offset); + if (!spvIsIdType(kind) || kind == SPV_OPERAND_TYPE_RESULT_ID) continue; + auto def = _.FindDef(id); + if (def) def->RegisterUse(inst, op.offset); + } + return SPV_SUCCESS; +} + +/// This function checks all ID definitions dominate their use in the CFG. +/// +/// This function will iterate over all ID definitions that are defined in the +/// functions of a module and make sure that the definitions appear in a +/// block that dominates their use. 
+/// +/// NOTE: This function does NOT check module scoped functions which are +/// checked during the initial binary parse in the IdPass below +spv_result_t CheckIdDefinitionDominateUse(ValidationState_t& _) { + std::vector phi_instructions; + std::unordered_set phi_ids; + for (const auto& inst : _.ordered_instructions()) { + if (inst.id() == 0) continue; + if (const Function* func = inst.function()) { + if (const BasicBlock* block = inst.block()) { + // If the Id is defined within a block then make sure all references to + // that Id appear in a blocks that are dominated by the defining block + for (auto& use_index_pair : inst.uses()) { + const Instruction* use = use_index_pair.first; + if (const BasicBlock* use_block = use->block()) { + if (use_block->reachable() == false) continue; + if (use->opcode() == SpvOpPhi) { + if (phi_ids.insert(use->id()).second) { + phi_instructions.push_back(use); + } + } else if (!block->dominates(*use->block())) { + return _.diag(SPV_ERROR_INVALID_ID, use_block->label()) + << "ID " << _.getIdName(inst.id()) << " defined in block " + << _.getIdName(block->id()) + << " does not dominate its use in block " + << _.getIdName(use_block->id()); + } + } + } + } else { + // If the Ids defined within a function but not in a block(i.e. 
function + // parameters, block ids), then make sure all references to that Id + // appear within the same function + for (auto use : inst.uses()) { + const Instruction* user = use.first; + if (user->function() && user->function() != func) { + return _.diag(SPV_ERROR_INVALID_ID, _.FindDef(func->id())) + << "ID " << _.getIdName(inst.id()) << " used in function " + << _.getIdName(user->function()->id()) + << " is used outside of it's defining function " + << _.getIdName(func->id()); + } + } + } + } + // NOTE: Ids defined outside of functions must appear before they are used + // This check is being performed in the IdPass function + } + + // Check all OpPhi parent blocks are dominated by the variable's defining + // blocks + for (const Instruction* phi : phi_instructions) { + if (phi->block()->reachable() == false) continue; + for (size_t i = 3; i < phi->operands().size(); i += 2) { + const Instruction* variable = _.FindDef(phi->word(i)); + const BasicBlock* parent = + phi->function()->GetBlock(phi->word(i + 1)).first; + if (variable->block() && parent->reachable() && + !variable->block()->dominates(*parent)) { + return _.diag(SPV_ERROR_INVALID_ID, phi) + << "In OpPhi instruction " << _.getIdName(phi->id()) << ", ID " + << _.getIdName(variable->id()) + << " definition does not dominate its parent " + << _.getIdName(parent->id()); + } + } + } + + return SPV_SUCCESS; +} + +// Performs SSA validation on the IDs of an instruction. The +// can_have_forward_declared_ids functor should return true if the +// instruction operand's ID can be forward referenced. +spv_result_t IdPass(ValidationState_t& _, Instruction* inst) { + auto can_have_forward_declared_ids = + spvOperandCanBeForwardDeclaredFunction(inst->opcode()); + + // Keep track of a result id defined by this instruction. 0 means it + // does not define an id. 
+ uint32_t result_id = 0; + + for (unsigned i = 0; i < inst->operands().size(); i++) { + const spv_parsed_operand_t& operand = inst->operand(i); + const spv_operand_type_t& type = operand.type; + // We only care about Id operands, which are a single word. + const uint32_t operand_word = inst->word(operand.offset); + + auto ret = SPV_ERROR_INTERNAL; + switch (type) { + case SPV_OPERAND_TYPE_RESULT_ID: + // NOTE: Multiple Id definitions are being checked by the binary parser. + // + // Defer undefined-forward-reference removal until after we've analyzed + // the remaining operands to this instruction. Deferral only matters + // for OpPhi since it's the only case where it defines its own forward + // reference. Other instructions that can have forward references + // either don't define a value or the forward reference is to a function + // Id (and hence defined outside of a function body). + result_id = operand_word; + // NOTE: The result Id is added (in RegisterInstruction) *after* all of + // the other Ids have been checked to avoid premature use in the same + // instruction. 
+ ret = SPV_SUCCESS; + break; + case SPV_OPERAND_TYPE_ID: + case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: + case SPV_OPERAND_TYPE_SCOPE_ID: + if (const auto def = _.FindDef(operand_word)) { + const auto opcode = inst->opcode(); + if (spvOpcodeGeneratesType(def->opcode()) && + !spvOpcodeGeneratesType(opcode) && !spvOpcodeIsDebug(opcode) && + !spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Operand " << _.getIdName(operand_word) + << " cannot be a type"; + } else if (def->type_id() == 0 && !spvOpcodeGeneratesType(opcode) && + !spvOpcodeIsDebug(opcode) && + !spvOpcodeIsDecoration(opcode) && + !spvOpcodeIsBranch(opcode) && opcode != SpvOpPhi && + opcode != SpvOpExtInst && opcode != SpvOpExtInstImport && + opcode != SpvOpSelectionMerge && + opcode != SpvOpLoopMerge && opcode != SpvOpFunction) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Operand " << _.getIdName(operand_word) + << " requires a type"; + } else { + ret = SPV_SUCCESS; + } + } else if (can_have_forward_declared_ids(i)) { + ret = _.ForwardDeclareId(operand_word); + } else { + ret = _.diag(SPV_ERROR_INVALID_ID, inst) + << "ID " << _.getIdName(operand_word) + << " has not been defined"; + } + break; + case SPV_OPERAND_TYPE_TYPE_ID: + if (_.IsDefinedId(operand_word)) { + auto* def = _.FindDef(operand_word); + if (!spvOpcodeGeneratesType(def->opcode())) { + ret = _.diag(SPV_ERROR_INVALID_ID, inst) + << "ID " << _.getIdName(operand_word) << " is not a type id"; + } else { + ret = SPV_SUCCESS; + } + } else { + ret = _.diag(SPV_ERROR_INVALID_ID, inst) + << "ID " << _.getIdName(operand_word) + << " has not been defined"; + } + break; + default: + ret = SPV_SUCCESS; + break; + } + if (SPV_SUCCESS != ret) return ret; + } + if (result_id) _.RemoveIfForwardDeclared(result_id); + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp new file mode 100644 index 
000000000..8a357ff34 --- /dev/null +++ b/source/val/validate_image.cpp @@ -0,0 +1,1789 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of image instructions. + +#include "source/val/validate.h" + +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validate_scopes.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Performs compile time check that all SpvImageOperandsXXX cases are handled in +// this module. If SpvImageOperandsXXX list changes, this function will fail the +// build. +// For all other purposes this is a dummy function. +bool CheckAllImageOperandsHandled() { + SpvImageOperandsMask enum_val = SpvImageOperandsBiasMask; + + // Some improvised code to prevent the compiler from considering enum_val + // constant and optimizing the switch away. + uint32_t stack_var = 0; + if (reinterpret_cast(&stack_var) % 256) + enum_val = SpvImageOperandsLodMask; + + switch (enum_val) { + // Please update the validation rules in this module if you are changing + // the list of image operands, and add new enum values to this switch. 
+ case SpvImageOperandsMaskNone: + return false; + case SpvImageOperandsBiasMask: + case SpvImageOperandsLodMask: + case SpvImageOperandsGradMask: + case SpvImageOperandsConstOffsetMask: + case SpvImageOperandsOffsetMask: + case SpvImageOperandsConstOffsetsMask: + case SpvImageOperandsSampleMask: + case SpvImageOperandsMinLodMask: + + // TODO(dneto): Support image operands related to the Vulkan memory model. + // https://gitlab.khronos.org/spirv/spirv-tools/issues/32 + case SpvImageOperandsMakeTexelAvailableKHRMask: + case SpvImageOperandsMakeTexelVisibleKHRMask: + case SpvImageOperandsNonPrivateTexelKHRMask: + case SpvImageOperandsVolatileTexelKHRMask: + return true; + } + return false; +} + +// Used by GetImageTypeInfo. See OpTypeImage spec for more information. +struct ImageTypeInfo { + uint32_t sampled_type = 0; + SpvDim dim = SpvDimMax; + uint32_t depth = 0; + uint32_t arrayed = 0; + uint32_t multisampled = 0; + uint32_t sampled = 0; + SpvImageFormat format = SpvImageFormatMax; + SpvAccessQualifier access_qualifier = SpvAccessQualifierMax; +}; + +// Provides information on image type. |id| should be object of either +// OpTypeImage or OpTypeSampledImage type. Returns false in case of failure +// (not a valid id, failed to parse the instruction, etc). 
+bool GetImageTypeInfo(const ValidationState_t& _, uint32_t id, + ImageTypeInfo* info) { + if (!id || !info) return false; + + const Instruction* inst = _.FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeSampledImage) { + inst = _.FindDef(inst->word(2)); + assert(inst); + } + + if (inst->opcode() != SpvOpTypeImage) return false; + + const size_t num_words = inst->words().size(); + if (num_words != 9 && num_words != 10) return false; + + info->sampled_type = inst->word(2); + info->dim = static_cast(inst->word(3)); + info->depth = inst->word(4); + info->arrayed = inst->word(5); + info->multisampled = inst->word(6); + info->sampled = inst->word(7); + info->format = static_cast(inst->word(8)); + info->access_qualifier = num_words < 10 + ? SpvAccessQualifierMax + : static_cast(inst->word(9)); + return true; +} + +bool IsImplicitLod(SpvOp opcode) { + switch (opcode) { + case SpvOpImageSampleImplicitLod: + case SpvOpImageSampleDrefImplicitLod: + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + return true; + default: + break; + } + return false; +} + +bool IsExplicitLod(SpvOp opcode) { + switch (opcode) { + case SpvOpImageSampleExplicitLod: + case SpvOpImageSampleDrefExplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageSparseSampleExplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + return true; + default: + break; + } + return false; +} + +// Returns true if the opcode is a Image instruction which applies +// homogenous projection to the coordinates. 
+bool IsProj(SpvOp opcode) { + switch (opcode) { + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + return true; + default: + break; + } + return false; +} + +// Returns the number of components in a coordinate used to access a texel in +// a single plane of an image with the given parameters. +uint32_t GetPlaneCoordSize(const ImageTypeInfo& info) { + uint32_t plane_size = 0; + // If this switch breaks your build, please add new values below. + switch (info.dim) { + case SpvDim1D: + case SpvDimBuffer: + plane_size = 1; + break; + case SpvDim2D: + case SpvDimRect: + case SpvDimSubpassData: + plane_size = 2; + break; + case SpvDim3D: + case SpvDimCube: + // For Cube direction vector is used instead of UV. + plane_size = 3; + break; + case SpvDimMax: + assert(0); + break; + } + + return plane_size; +} + +// Returns minimal number of coordinates based on image dim, arrayed and whether +// the instruction uses projection coordinates. +uint32_t GetMinCoordSize(SpvOp opcode, const ImageTypeInfo& info) { + if (info.dim == SpvDimCube && + (opcode == SpvOpImageRead || opcode == SpvOpImageWrite || + opcode == SpvOpImageSparseRead)) { + // These opcodes use UV for Cube, not direction vector. + return 3; + } + + return GetPlaneCoordSize(info) + info.arrayed + (IsProj(opcode) ? 1 : 0); +} + +// Checks ImageOperand bitfield and respective operands. 
+spv_result_t ValidateImageOperands(ValidationState_t& _, + const Instruction* inst, + const ImageTypeInfo& info, uint32_t mask, + uint32_t word_index) { + static const bool kAllImageOperandsHandled = CheckAllImageOperandsHandled(); + (void)kAllImageOperandsHandled; + + const SpvOp opcode = inst->opcode(); + const size_t num_words = inst->words().size(); + + // NonPrivate and Volatile take no operand words. + const uint32_t mask_bits_having_operands = + mask & ~uint32_t(SpvImageOperandsNonPrivateTexelKHRMask | + SpvImageOperandsVolatileTexelKHRMask); + size_t expected_num_image_operand_words = + spvtools::utils::CountSetBits(mask_bits_having_operands); + if (mask & SpvImageOperandsGradMask) { + // Grad uses two words. + ++expected_num_image_operand_words; + } + + if (expected_num_image_operand_words != num_words - word_index) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Number of image operand ids doesn't correspond to the bit mask"; + } + + if (spvtools::utils::CountSetBits( + mask & (SpvImageOperandsOffsetMask | SpvImageOperandsConstOffsetMask | + SpvImageOperandsConstOffsetsMask)) > 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operands Offset, ConstOffset, ConstOffsets cannot be used " + << "together"; + } + + const bool is_implicit_lod = IsImplicitLod(opcode); + const bool is_explicit_lod = IsExplicitLod(opcode); + + // The checks should be done in the order of definition of OperandImage. 
+ + if (mask & SpvImageOperandsBiasMask) { + if (!is_implicit_lod) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Bias can only be used with ImplicitLod opcodes"; + } + + const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); + if (!_.IsFloatScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Bias to be float scalar"; + } + + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Bias requires 'Dim' parameter to be 1D, 2D, 3D " + "or Cube"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Bias requires 'MS' parameter to be 0"; + } + } + + if (mask & SpvImageOperandsLodMask) { + if (!is_explicit_lod && opcode != SpvOpImageFetch && + opcode != SpvOpImageSparseFetch) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Lod can only be used with ExplicitLod opcodes " + << "and OpImageFetch"; + } + + if (mask & SpvImageOperandsGradMask) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand bits Lod and Grad cannot be set at the same " + "time"; + } + + const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); + if (is_explicit_lod) { + if (!_.IsFloatScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Lod to be float scalar when used " + << "with ExplicitLod"; + } + } else { + if (!_.IsIntScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Lod to be int scalar when used with " + << "OpImageFetch"; + } + } + + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Lod requires 'Dim' parameter to be 1D, 2D, 3D " + "or Cube"; + } + + if (info.multisampled != 0) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Lod requires 'MS' parameter to be 0"; + } + } + + if (mask & SpvImageOperandsGradMask) { + if (!is_explicit_lod) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Grad can only be used with ExplicitLod opcodes"; + } + + const uint32_t dx_type_id = _.GetTypeId(inst->word(word_index++)); + const uint32_t dy_type_id = _.GetTypeId(inst->word(word_index++)); + if (!_.IsFloatScalarOrVectorType(dx_type_id) || + !_.IsFloatScalarOrVectorType(dy_type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected both Image Operand Grad ids to be float scalars or " + << "vectors"; + } + + const uint32_t plane_size = GetPlaneCoordSize(info); + const uint32_t dx_size = _.GetDimension(dx_type_id); + const uint32_t dy_size = _.GetDimension(dy_type_id); + if (plane_size != dx_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Grad dx to have " << plane_size + << " components, but given " << dx_size; + } + + if (plane_size != dy_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Grad dy to have " << plane_size + << " components, but given " << dy_size; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Grad requires 'MS' parameter to be 0"; + } + } + + if (mask & SpvImageOperandsConstOffsetMask) { + if (info.dim == SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand ConstOffset cannot be used with Cube Image " + "'Dim'"; + } + + const uint32_t id = inst->word(word_index++); + const uint32_t type_id = _.GetTypeId(id); + if (!_.IsIntScalarOrVectorType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffset to be int scalar or " + << "vector"; + } + + if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffset to be a const object"; + } + + const 
uint32_t plane_size = GetPlaneCoordSize(info); + const uint32_t offset_size = _.GetDimension(type_id); + if (plane_size != offset_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffset to have " << plane_size + << " components, but given " << offset_size; + } + } + + if (mask & SpvImageOperandsOffsetMask) { + if (info.dim == SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Offset cannot be used with Cube Image 'Dim'"; + } + + const uint32_t id = inst->word(word_index++); + const uint32_t type_id = _.GetTypeId(id); + if (!_.IsIntScalarOrVectorType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Offset to be int scalar or " + << "vector"; + } + + const uint32_t plane_size = GetPlaneCoordSize(info); + const uint32_t offset_size = _.GetDimension(type_id); + if (plane_size != offset_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Offset to have " << plane_size + << " components, but given " << offset_size; + } + } + + if (mask & SpvImageOperandsConstOffsetsMask) { + if (opcode != SpvOpImageGather && opcode != SpvOpImageDrefGather && + opcode != SpvOpImageSparseGather && + opcode != SpvOpImageSparseDrefGather) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand ConstOffsets can only be used with " + "OpImageGather and OpImageDrefGather"; + } + + if (info.dim == SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand ConstOffsets cannot be used with Cube Image " + "'Dim'"; + } + + const uint32_t id = inst->word(word_index++); + const uint32_t type_id = _.GetTypeId(id); + const Instruction* type_inst = _.FindDef(type_id); + assert(type_inst); + + if (type_inst->opcode() != SpvOpTypeArray) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffsets to be an array of size 4"; + } + + uint64_t array_size = 0; + if (!_.GetConstantValUint64(type_inst->word(3), 
&array_size)) { + assert(0 && "Array type definition is corrupt"); + } + + if (array_size != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffsets to be an array of size 4"; + } + + const uint32_t component_type = type_inst->word(2); + if (!_.IsIntVectorType(component_type) || + _.GetDimension(component_type) != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffsets array componenets to be " + "int vectors of size 2"; + } + + if (!spvOpcodeIsConstant(_.GetIdOpcode(id))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand ConstOffsets to be a const object"; + } + } + + if (mask & SpvImageOperandsSampleMask) { + if (opcode != SpvOpImageFetch && opcode != SpvOpImageRead && + opcode != SpvOpImageWrite && opcode != SpvOpImageSparseFetch && + opcode != SpvOpImageSparseRead) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Sample can only be used with OpImageFetch, " + << "OpImageRead, OpImageWrite, OpImageSparseFetch and " + << "OpImageSparseRead"; + } + + if (info.multisampled == 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand Sample requires non-zero 'MS' parameter"; + } + + const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); + if (!_.IsIntScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand Sample to be int scalar"; + } + } + + if (mask & SpvImageOperandsMinLodMask) { + if (!is_implicit_lod && !(mask & SpvImageOperandsGradMask)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MinLod can only be used with ImplicitLod " + << "opcodes or together with Image Operand Grad"; + } + + const uint32_t type_id = _.GetTypeId(inst->word(word_index++)); + if (!_.IsFloatScalarType(type_id)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image Operand MinLod to be float scalar"; + } + + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != 
SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MinLod requires 'Dim' parameter to be 1D, 2D, " + "3D or Cube"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MinLod requires 'MS' parameter to be 0"; + } + } + + if (mask & SpvImageOperandsMakeTexelAvailableKHRMask) { + // Checked elsewhere: capability and memory model are correct. + if (opcode != SpvOpImageWrite) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MakeTexelAvailableKHR can only be used with Op" + << spvOpcodeString(SpvOpImageWrite) << ": Op" + << spvOpcodeString(opcode); + } + + if (!(mask & SpvImageOperandsNonPrivateTexelKHRMask)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MakeTexelAvailableKHR requires " + "NonPrivateTexelKHR is also specified: Op" + << spvOpcodeString(opcode); + } + + const auto available_scope = inst->word(word_index++); + if (auto error = ValidateMemoryScope(_, inst, available_scope)) + return error; + } + + if (mask & SpvImageOperandsMakeTexelVisibleKHRMask) { + // Checked elsewhere: capability and memory model are correct. + if (opcode != SpvOpImageRead && opcode != SpvOpImageSparseRead) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MakeTexelVisibleKHR can only be used with Op" + << spvOpcodeString(SpvOpImageRead) << " or Op" + << spvOpcodeString(SpvOpImageSparseRead) << ": Op" + << spvOpcodeString(opcode); + } + + if (!(mask & SpvImageOperandsNonPrivateTexelKHRMask)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Operand MakeTexelVisibleKHR requires NonPrivateTexelKHR " + "is also specified: Op" + << spvOpcodeString(opcode); + } + + const auto visible_scope = inst->word(word_index++); + if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error; + } + + return SPV_SUCCESS; +} + +// Checks some of the validation rules which are common to multiple opcodes. 
+spv_result_t ValidateImageCommon(ValidationState_t& _, const Instruction* inst, + const ImageTypeInfo& info) { + const SpvOp opcode = inst->opcode(); + if (IsProj(opcode)) { + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimRect) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Dim' parameter to be 1D, 2D, 3D or Rect"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Image 'MS' parameter to be 0"; + } + + if (info.arrayed != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Image 'arrayed' parameter to be 0"; + } + } + + if (opcode == SpvOpImageRead || opcode == SpvOpImageSparseRead || + opcode == SpvOpImageWrite) { + if (info.sampled == 0) { + } else if (info.sampled == 2) { + if (info.dim == SpvDim1D && !_.HasCapability(SpvCapabilityImage1D)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability Image1D is required to access storage image"; + } else if (info.dim == SpvDimRect && + !_.HasCapability(SpvCapabilityImageRect)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability ImageRect is required to access storage image"; + } else if (info.dim == SpvDimBuffer && + !_.HasCapability(SpvCapabilityImageBuffer)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability ImageBuffer is required to access storage image"; + } else if (info.dim == SpvDimCube && info.arrayed == 1 && + !_.HasCapability(SpvCapabilityImageCubeArray)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability ImageCubeArray is required to access " + << "storage image"; + } + + if (info.multisampled == 1 && + !_.HasCapability(SpvCapabilityImageMSArray)) { +#if 0 + // TODO(atgoo@github.com) The description of this rule in the spec + // is unclear and Glslang doesn't declare ImageMSArray. Need to clarify + // and reenable. 
+ return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability ImageMSArray is required to access storage " + << "image"; +#endif + } + } else { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled' parameter to be 0 or 2"; + } + } + + return SPV_SUCCESS; +} + +// Returns true if opcode is *ImageSparse*, false otherwise. +bool IsSparse(SpvOp opcode) { + switch (opcode) { + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleExplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + case SpvOpImageSparseFetch: + case SpvOpImageSparseGather: + case SpvOpImageSparseDrefGather: + case SpvOpImageSparseTexelsResident: + case SpvOpImageSparseRead: { + return true; + } + + default: { return false; } + } + + return false; +} + +// Checks sparse image opcode result type and returns the second struct member. +// Returns inst.type_id for non-sparse image opcodes. +// Not valid for sparse image opcodes which do not return a struct. 
+spv_result_t GetActualResultType(ValidationState_t& _, const Instruction* inst, + uint32_t* actual_result_type) { + const SpvOp opcode = inst->opcode(); + + if (IsSparse(opcode)) { + const Instruction* const type_inst = _.FindDef(inst->type_id()); + assert(type_inst); + + if (!type_inst || type_inst->opcode() != SpvOpTypeStruct) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypeStruct"; + } + + if (type_inst->words().size() != 4 || + !_.IsIntScalarType(type_inst->word(2))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be a struct containing an int " + "scalar and a texel"; + } + + *actual_result_type = type_inst->word(3); + } else { + *actual_result_type = inst->type_id(); + } + + return SPV_SUCCESS; +} + +// Returns a string describing actual result type of an opcode. +// Not valid for sparse image opcodes which do not return a struct. +const char* GetActualResultTypeStr(SpvOp opcode) { + if (IsSparse(opcode)) return "Result Type's second member"; + return "Result Type"; +} + +spv_result_t ValidateTypeImage(ValidationState_t& _, const Instruction* inst) { + assert(inst->type_id() == 0); + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, inst->word(1), &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (spvIsVulkanEnv(_.context()->target_env)) { + if ((!_.IsFloatScalarType(info.sampled_type) && + !_.IsIntScalarType(info.sampled_type)) || + 32 != _.GetBitWidth(info.sampled_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Type to be a 32-bit int or float " + "scalar type for Vulkan environment"; + } + } else { + const SpvOp sampled_type_opcode = _.GetIdOpcode(info.sampled_type); + if (sampled_type_opcode != SpvOpTypeVoid && + sampled_type_opcode != SpvOpTypeInt && + sampled_type_opcode != SpvOpTypeFloat) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Type to be either void or" + << " numerical 
scalar type"; + } + } + + // Dim is checked elsewhere. + + if (info.depth > 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid Depth " << info.depth << " (must be 0, 1 or 2)"; + } + + if (info.arrayed > 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid Arrayed " << info.arrayed << " (must be 0 or 1)"; + } + + if (info.multisampled > 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid MS " << info.multisampled << " (must be 0 or 1)"; + } + + if (info.sampled > 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Invalid Sampled " << info.sampled << " (must be 0, 1 or 2)"; + } + + if (info.dim == SpvDimSubpassData) { + if (info.sampled != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Dim SubpassData requires Sampled to be 2"; + } + + if (info.format != SpvImageFormatUnknown) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Dim SubpassData requires format Unknown"; + } + } + + // Format and Access Qualifier are checked elsewhere. 
+ + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeSampledImage(ValidationState_t& _, + const Instruction* inst) { + const uint32_t image_type = inst->word(2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateSampledImage(ValidationState_t& _, + const Instruction* inst) { + if (_.GetIdOpcode(inst->type_id()) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypeSampledImage."; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage."; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + // TODO(atgoo@github.com) Check compatibility of result type and received + // image. + + if (spvIsVulkanEnv(_.context()->target_env)) { + if (info.sampled != 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled' parameter to be 1 " + << "for Vulkan environment."; + } + } else { + if (info.sampled != 0 && info.sampled != 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled' parameter to be 0 or 1"; + } + } + + if (info.dim == SpvDimSubpassData) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Dim' parameter to be not SubpassData."; + } + + if (_.GetIdOpcode(_.GetOperandTypeId(inst, 3)) != SpvOpTypeSampler) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampler to be of type OpTypeSampler"; + } + + // We need to validate 2 things: + // * All OpSampledImage instructions must be in the same block in which their + // Result are consumed. 
+ // * Result from OpSampledImage instructions must not appear as operands + // to OpPhi instructions or OpSelect instructions, or any instructions other + // than the image lookup and query instructions specified to take an operand + // whose type is OpTypeSampledImage. + std::vector consumers = _.getSampledImageConsumers(inst->id()); + if (!consumers.empty()) { + for (auto consumer_id : consumers) { + const auto consumer_instr = _.FindDef(consumer_id); + const auto consumer_opcode = consumer_instr->opcode(); + if (consumer_instr->block() != inst->block()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "All OpSampledImage instructions must be in the same block " + "in " + "which their Result are consumed. OpSampledImage Result " + "Type '" + << _.getIdName(inst->id()) + << "' has a consumer in a different basic " + "block. The consumer instruction is '" + << _.getIdName(consumer_id) << "'."; + } + // TODO: The following check is incomplete. We should also check that the + // Sampled Image is not used by instructions that should not take + // SampledImage as an argument. We could find the list of valid + // instructions by scanning for "Sampled Image" in the operand description + // field in the grammar file. + if (consumer_opcode == SpvOpPhi || consumer_opcode == SpvOpSelect) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Result from OpSampledImage instruction must not appear " + "as " + "operands of Op" + << spvOpcodeString(static_cast(consumer_opcode)) << "." 
+ << " Found result '" << _.getIdName(inst->id()) + << "' as an operand of '" << _.getIdName(consumer_id) + << "'."; + } + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageTexelPointer(ValidationState_t& _, + const Instruction* inst) { + const auto result_type = _.FindDef(inst->type_id()); + if (result_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypePointer"; + } + + const auto storage_class = result_type->GetOperandAs(1); + if (storage_class != SpvStorageClassImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypePointer whose Storage Class " + "operand is Image"; + } + + const auto ptr_type = result_type->GetOperandAs(2); + const auto ptr_opcode = _.GetIdOpcode(ptr_type); + if (ptr_opcode != SpvOpTypeInt && ptr_opcode != SpvOpTypeFloat && + ptr_opcode != SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be OpTypePointer whose Type operand " + "must be a scalar numerical type or OpTypeVoid"; + } + + const auto image_ptr = _.FindDef(_.GetOperandTypeId(inst, 2)); + if (!image_ptr || image_ptr->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be OpTypePointer"; + } + + const auto image_type = image_ptr->GetOperandAs(2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be OpTypePointer with Type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (info.sampled_type != ptr_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as the Type " + "pointed to by Result Type"; + } + + if (info.dim == SpvDimSubpassData) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Dim SubpassData cannot be 
used with OpImageTexelPointer"; + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!coord_type || !_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be integer scalar or vector"; + } + + uint32_t expected_coord_size = 0; + if (info.arrayed == 0) { + expected_coord_size = GetPlaneCoordSize(info); + } else if (info.arrayed == 1) { + switch (info.dim) { + case SpvDim1D: + expected_coord_size = 2; + break; + case SpvDimCube: + case SpvDim2D: + expected_coord_size = 3; + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Dim' must be one of 1D, 2D, or Cube when " + "Arrayed is 1"; + break; + } + } + + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (expected_coord_size != actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have " << expected_coord_size + << " components, but given " << actual_coord_size; + } + + const uint32_t sample_type = _.GetOperandTypeId(inst, 4); + if (!sample_type || !_.IsIntScalarType(sample_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sample to be integer scalar"; + } + + if (info.multisampled == 0) { + uint64_t ms = 0; + if (!_.GetConstantValUint64(inst->GetOperandAs(4), &ms) || + ms != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sample for Image with MS 0 to be a valid for " + "the value 0"; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageLod(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) { + return error; + } + + if (!_.IsIntVectorType(actual_result_type) && + !_.IsFloatVectorType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float vector type"; + } + + if 
(_.GetDimension(actual_result_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to have 4 components"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Image to be of type OpTypeSampledImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (spv_result_t result = ValidateImageCommon(_, inst, info)) return result; + + if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t texel_component_type = + _.GetComponentType(actual_result_type); + if (texel_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode) << " components"; + } + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if ((opcode == SpvOpImageSampleExplicitLod || + opcode == SpvOpImageSparseSampleExplicitLod) && + _.HasCapability(SpvCapabilityKernel)) { + if (!_.IsFloatScalarOrVectorType(coord_type) && + !_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int or float scalar or vector"; + } + } else { + if (!_.IsFloatScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be float scalar or vector"; + } + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + if (inst->words().size() <= 5) { + 
assert(IsImplicitLod(opcode)); + return SPV_SUCCESS; + } + + const uint32_t mask = inst->word(5); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 6)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageDrefLod(ValidationState_t& _, + const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) { + return error; + } + + if (!_.IsIntScalarType(actual_result_type) && + !_.IsFloatScalarType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float scalar type"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Image to be of type OpTypeSampledImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (spv_result_t result = ValidateImageCommon(_, inst, info)) return result; + + if (actual_result_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode); + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!_.IsFloatScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be float scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + const uint32_t 
dref_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarType(dref_type) || _.GetBitWidth(dref_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Dref to be of 32-bit float type"; + } + + if (inst->words().size() <= 6) { + assert(IsImplicitLod(opcode)); + return SPV_SUCCESS; + } + + const uint32_t mask = inst->word(6); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 7)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageFetch(ValidationState_t& _, const Instruction* inst) { + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) { + return error; + } + + const SpvOp opcode = inst->opcode(); + if (!_.IsIntVectorType(actual_result_type) && + !_.IsFloatVectorType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float vector type"; + } + + if (_.GetDimension(actual_result_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to have 4 components"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t result_component_type = + _.GetComponentType(actual_result_type); + if (result_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode) << " components"; + } + } + + if (info.dim == SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image 'Dim' 
cannot be Cube"; + } + + if (info.sampled != 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled' parameter to be 1"; + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + if (inst->words().size() <= 5) return SPV_SUCCESS; + + const uint32_t mask = inst->word(5); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 6)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageGather(ValidationState_t& _, + const Instruction* inst) { + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) + return error; + + const SpvOp opcode = inst->opcode(); + if (!_.IsIntVectorType(actual_result_type) && + !_.IsFloatVectorType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float vector type"; + } + + if (_.GetDimension(actual_result_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to have 4 components"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Sampled Image to be of type OpTypeSampledImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image 
type definition"; + } + + if (opcode == SpvOpImageDrefGather || opcode == SpvOpImageSparseDrefGather || + _.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t result_component_type = + _.GetComponentType(actual_result_type); + if (result_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << GetActualResultTypeStr(opcode) << " components"; + } + } + + if (info.dim != SpvDim2D && info.dim != SpvDimCube && + info.dim != SpvDimRect) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Dim' cannot be Cube"; + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!_.IsFloatScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be float scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + if (opcode == SpvOpImageGather || opcode == SpvOpImageSparseGather) { + const uint32_t component_index_type = _.GetOperandTypeId(inst, 4); + if (!_.IsIntScalarType(component_index_type) || + _.GetBitWidth(component_index_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Component to be 32-bit int scalar"; + } + } else { + assert(opcode == SpvOpImageDrefGather || + opcode == SpvOpImageSparseDrefGather); + const uint32_t dref_type = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatScalarType(dref_type) || _.GetBitWidth(dref_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Dref to be of 32-bit float type"; + } + } + + if (inst->words().size() <= 6) return SPV_SUCCESS; + + const uint32_t mask = inst->word(6); + if (spv_result_t result = 
+ ValidateImageOperands(_, inst, info, mask, /* word_index = */ 7)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageRead(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + uint32_t actual_result_type = 0; + if (spv_result_t error = GetActualResultType(_, inst, &actual_result_type)) { + return error; + } + + if (!_.IsIntScalarOrVectorType(actual_result_type) && + !_.IsFloatScalarOrVectorType(actual_result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to be int or float scalar or vector type"; + } + +#if 0 + // TODO(atgoo@github.com) Disabled until the spec is clarified. + if (_.GetDimension(actual_result_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected " << GetActualResultTypeStr(opcode) + << " to have 4 components"; + } +#endif + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (info.dim == SpvDimSubpassData) { + if (opcode == SpvOpImageSparseRead) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image Dim SubpassData cannot be used with ImageSparseRead"; + } + + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + std::string("Dim SubpassData requires Fragment execution model: ") + + spvOpcodeString(opcode)); + } + + if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t result_component_type = + _.GetComponentType(actual_result_type); + if (result_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as " + << 
GetActualResultTypeStr(opcode) << " components"; + } + } + + if (spv_result_t result = ValidateImageCommon(_, inst, info)) return result; + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (!_.IsIntScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(opcode, info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + if (info.format == SpvImageFormatUnknown && info.dim != SpvDimSubpassData && + !_.HasCapability(SpvCapabilityStorageImageReadWithoutFormat)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability StorageImageReadWithoutFormat is required to " + << "read storage image"; + } + + if (inst->words().size() <= 5) return SPV_SUCCESS; + + const uint32_t mask = inst->word(5); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 6)) + return result; + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageWrite(ValidationState_t& _, const Instruction* inst) { + const uint32_t image_type = _.GetOperandTypeId(inst, 0); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (info.dim == SpvDimSubpassData) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' cannot be SubpassData"; + } + + if (spv_result_t result = ValidateImageCommon(_, inst, info)) return result; + + const uint32_t coord_type = _.GetOperandTypeId(inst, 1); + if (!_.IsIntScalarOrVectorType(coord_type)) 
{ + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int scalar or vector"; + } + + const uint32_t min_coord_size = GetMinCoordSize(inst->opcode(), info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + + // TODO(atgoo@github.com) The spec doesn't explicitely say what the type + // of texel should be. + const uint32_t texel_type = _.GetOperandTypeId(inst, 2); + if (!_.IsIntScalarOrVectorType(texel_type) && + !_.IsFloatScalarOrVectorType(texel_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Texel to be int or float vector or scalar"; + } + +#if 0 + // TODO: See above. + if (_.GetDimension(texel_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Texel to have 4 components"; + } +#endif + + if (_.GetIdOpcode(info.sampled_type) != SpvOpTypeVoid) { + const uint32_t texel_component_type = _.GetComponentType(texel_type); + if (texel_component_type != info.sampled_type) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image 'Sampled Type' to be the same as Texel " + << "components"; + } + } + + if (info.format == SpvImageFormatUnknown && info.dim != SpvDimSubpassData && + !_.HasCapability(SpvCapabilityStorageImageWriteWithoutFormat)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Capability StorageImageWriteWithoutFormat is required to " + "write " + << "to storage image"; + } + + if (inst->words().size() <= 4) return SPV_SUCCESS; + + const uint32_t mask = inst->word(4); + if (spv_result_t result = + ValidateImageOperands(_, inst, info, mask, /* word_index = */ 5)) + return result; + + return SPV_SUCCESS; +} +

// Validates OpImage.
//
// Result Type must be an OpTypeImage. The operand's type must be an
// OpTypeSampledImage whose image type (word 2 of the type instruction) is
// that same Result Type.
//
// Returns SPV_SUCCESS, or the result of _.diag(SPV_ERROR_INVALID_DATA, inst)
// for the first violated rule.
spv_result_t ValidateImage(ValidationState_t& _, const Instruction* inst) {
  const uint32_t result_type = inst->type_id();
  if (_.GetIdOpcode(result_type) != SpvOpTypeImage) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected Result Type to be OpTypeImage";
  }

  const uint32_t sampled_image_type = _.GetOperandTypeId(inst, 2);
  const Instruction* sampled_image_type_inst = _.FindDef(sampled_image_type);

  // An unresolved operand type is reported as a validation error rather than
  // only asserted on, so builds with assertions disabled do not dereference
  // a null pointer.
  if (!sampled_image_type_inst ||
      sampled_image_type_inst->opcode() != SpvOpTypeSampledImage) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected Sample Image to be of type OpTypeSampledImage";
  }

  if (sampled_image_type_inst->word(2) != result_type) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected Sample Image image type to be equal to Result Type";
  }

  return SPV_SUCCESS;
}

+ +spv_result_t ValidateImageQuerySizeLod(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + if (!_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be int scalar or vector type"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + uint32_t expected_num_components = info.arrayed; + switch (info.dim) { + case SpvDim1D: + expected_num_components += 1; + break; + case SpvDim2D: + case SpvDimCube: + expected_num_components += 2; + break; + case SpvDim3D: + expected_num_components += 3; + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' must be 1D, 2D, 3D or Cube"; + } + + if (info.multisampled != 0) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image 'MS' must be 0"; + } + + uint32_t result_num_components = _.GetDimension(result_type); + if (result_num_components != expected_num_components) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type has " << result_num_components << " components, " + << "but " << expected_num_components << " expected"; + } + + const uint32_t lod_type = _.GetOperandTypeId(inst, 3); + if (!_.IsIntScalarType(lod_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Level of Detail to be int scalar"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageQuerySize(ValidationState_t& _, + const Instruction* inst) { + const uint32_t result_type = inst->type_id(); + if (!_.IsIntScalarOrVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be int scalar or vector type"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + uint32_t expected_num_components = info.arrayed; + switch (info.dim) { + case SpvDim1D: + case SpvDimBuffer: + expected_num_components += 1; + break; + case SpvDim2D: + case SpvDimCube: + case SpvDimRect: + expected_num_components += 2; + break; + case SpvDim3D: + expected_num_components += 3; + break; + default: + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' must be 1D, Buffer, 2D, Cube, 3D or Rect"; + } + + if (info.dim == SpvDim1D || info.dim == SpvDim2D || info.dim == SpvDim3D || + info.dim == SpvDimCube) { + if (info.multisampled != 1 && info.sampled != 0 && info.sampled != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image must have either 'MS'=1 or 'Sampled'=0 or 'Sampled'=2"; + } + } + + uint32_t result_num_components = _.GetDimension(result_type); + if (result_num_components != expected_num_components) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Result Type has " << 
result_num_components << " components, " + << "but " << expected_num_components << " expected"; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateImageQueryFormatOrOrder(ValidationState_t& _, + const Instruction* inst) { + if (!_.IsIntScalarType(inst->type_id())) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be int scalar type"; + } + + if (_.GetIdOpcode(_.GetOperandTypeId(inst, 2)) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operand to be of type OpTypeImage"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageQueryLod(ValidationState_t& _, + const Instruction* inst) { + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "OpImageQueryLod requires Fragment execution model"); + + const uint32_t result_type = inst->type_id(); + if (!_.IsFloatVectorType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be float vector type"; + } + + if (_.GetDimension(result_type) != 2) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to have 2 components"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeSampledImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image operand to be of type OpTypeSampledImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' must be 1D, 2D, 3D or Cube"; + } + + const uint32_t coord_type = _.GetOperandTypeId(inst, 3); + if (_.HasCapability(SpvCapabilityKernel)) { + if (!_.IsFloatScalarOrVectorType(coord_type) && + !_.IsIntScalarOrVectorType(coord_type)) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be int or float scalar or vector"; + } + } else { + if (!_.IsFloatScalarOrVectorType(coord_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to be float scalar or vector"; + } + } + + const uint32_t min_coord_size = GetPlaneCoordSize(info); + const uint32_t actual_coord_size = _.GetDimension(coord_type); + if (min_coord_size > actual_coord_size) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Coordinate to have at least " << min_coord_size + << " components, but given only " << actual_coord_size; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageSparseLod(ValidationState_t& _, + const Instruction* inst) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Instruction reserved for future use, use of this instruction " + << "is invalid"; +} + +spv_result_t ValidateImageQueryLevelsOrSamples(ValidationState_t& _, + const Instruction* inst) { + if (!_.IsIntScalarType(inst->type_id())) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be int scalar type"; + } + + const uint32_t image_type = _.GetOperandTypeId(inst, 2); + if (_.GetIdOpcode(image_type) != SpvOpTypeImage) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Image to be of type OpTypeImage"; + } + + ImageTypeInfo info; + if (!GetImageTypeInfo(_, image_type, &info)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Corrupt image type definition"; + } + + const SpvOp opcode = inst->opcode(); + if (opcode == SpvOpImageQueryLevels) { + if (info.dim != SpvDim1D && info.dim != SpvDim2D && info.dim != SpvDim3D && + info.dim != SpvDimCube) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Image 'Dim' must be 1D, 2D, 3D or Cube"; + } + } else { + assert(opcode == SpvOpImageQuerySamples); + if (info.dim != SpvDim2D) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Image 'Dim' must be 2D"; + } + + if (info.multisampled != 1) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) << "Image 'MS' must be 1"; + } + } + return SPV_SUCCESS; +} + +spv_result_t ValidateImageSparseTexelsResident(ValidationState_t& _, + const Instruction* inst) { + if (!_.IsBoolScalarType(inst->type_id())) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be bool scalar type"; + } + + const uint32_t resident_code_type = _.GetOperandTypeId(inst, 2); + if (!_.IsIntScalarType(resident_code_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Resident Code to be int scalar"; + } + + return SPV_SUCCESS; +} + +} // namespace + +// Validates correctness of image instructions. +spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + if (IsImplicitLod(opcode)) { + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelFragment, + "ImplicitLod instructions require Fragment execution model"); + } + + switch (opcode) { + case SpvOpTypeImage: + return ValidateTypeImage(_, inst); + case SpvOpTypeSampledImage: + return ValidateTypeSampledImage(_, inst); + case SpvOpSampledImage: + return ValidateSampledImage(_, inst); + case SpvOpImageTexelPointer: + return ValidateImageTexelPointer(_, inst); + + case SpvOpImageSampleImplicitLod: + case SpvOpImageSampleExplicitLod: + case SpvOpImageSampleProjImplicitLod: + case SpvOpImageSampleProjExplicitLod: + case SpvOpImageSparseSampleImplicitLod: + case SpvOpImageSparseSampleExplicitLod: + return ValidateImageLod(_, inst); + + case SpvOpImageSampleDrefImplicitLod: + case SpvOpImageSampleDrefExplicitLod: + case SpvOpImageSampleProjDrefImplicitLod: + case SpvOpImageSampleProjDrefExplicitLod: + case SpvOpImageSparseSampleDrefImplicitLod: + case SpvOpImageSparseSampleDrefExplicitLod: + return ValidateImageDrefLod(_, inst); + + case SpvOpImageFetch: + case SpvOpImageSparseFetch: + return ValidateImageFetch(_, inst); + + case SpvOpImageGather: + case SpvOpImageDrefGather: + 
case SpvOpImageSparseGather: + case SpvOpImageSparseDrefGather: + return ValidateImageGather(_, inst); + + case SpvOpImageRead: + case SpvOpImageSparseRead: + return ValidateImageRead(_, inst); + + case SpvOpImageWrite: + return ValidateImageWrite(_, inst); + + case SpvOpImage: + return ValidateImage(_, inst); + + case SpvOpImageQueryFormat: + case SpvOpImageQueryOrder: + return ValidateImageQueryFormatOrOrder(_, inst); + + case SpvOpImageQuerySizeLod: + return ValidateImageQuerySizeLod(_, inst); + case SpvOpImageQuerySize: + return ValidateImageQuerySize(_, inst); + case SpvOpImageQueryLod: + return ValidateImageQueryLod(_, inst); + + case SpvOpImageQueryLevels: + case SpvOpImageQuerySamples: + return ValidateImageQueryLevelsOrSamples(_, inst); + + case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: + return ValidateImageSparseLod(_, inst); + + case SpvOpImageSparseTexelsResident: + return ValidateImageSparseTexelsResident(_, inst); + + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_instruction.cpp b/source/val/validate_instruction.cpp new file mode 100644 index 000000000..17949d22a --- /dev/null +++ b/source/val/validate_instruction.cpp @@ -0,0 +1,535 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Performs validation on instructions that appear inside of a SPIR-V block. + +#include "source/val/validate.h" + +#include +#include +#include +#include +#include + +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/enum_set.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" +#include "source/spirv_definition.h" +#include "source/spirv_target_env.h" +#include "source/spirv_validator_options.h" +#include "source/util/string_utils.h" +#include "source/val/function.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +std::string ToString(const CapabilitySet& capabilities, + const AssemblyGrammar& grammar) { + std::stringstream ss; + capabilities.ForEach([&grammar, &ss](SpvCapability cap) { + spv_operand_desc desc; + if (SPV_SUCCESS == + grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) + ss << desc->name << " "; + else + ss << cap << " "; + }); + return ss.str(); +} + +// Returns capabilities that enable an opcode. An empty result is interpreted +// as no prohibition of use of the opcode. If the result is non-empty, then +// the opcode may only be used if at least one of the capabilities is specified +// by the module. 
+CapabilitySet EnablingCapabilitiesForOp(const ValidationState_t& state, + SpvOp opcode) { + // Exceptions for SPV_AMD_shader_ballot + switch (opcode) { + // Normally these would require Group capability + case SpvOpGroupIAddNonUniformAMD: + case SpvOpGroupFAddNonUniformAMD: + case SpvOpGroupFMinNonUniformAMD: + case SpvOpGroupUMinNonUniformAMD: + case SpvOpGroupSMinNonUniformAMD: + case SpvOpGroupFMaxNonUniformAMD: + case SpvOpGroupUMaxNonUniformAMD: + case SpvOpGroupSMaxNonUniformAMD: + if (state.HasExtension(kSPV_AMD_shader_ballot)) return CapabilitySet(); + break; + default: + break; + } + // Look it up in the grammar + spv_opcode_desc opcode_desc = {}; + if (SPV_SUCCESS == state.grammar().lookupOpcode(opcode, &opcode_desc)) { + return state.grammar().filterCapsAgainstTargetEnv( + opcode_desc->capabilities, opcode_desc->numCapabilities); + } + return CapabilitySet(); +} + +// Returns SPV_SUCCESS if the given operand is enabled by capabilities declared +// in the module. Otherwise issues an error message and returns +// SPV_ERROR_INVALID_CAPABILITY. +spv_result_t CheckRequiredCapabilities(ValidationState_t& state, + const Instruction* inst, + size_t which_operand, + spv_operand_type_t type, + uint32_t operand) { + // Mere mention of PointSize, ClipDistance, or CullDistance in a Builtin + // decoration does not require the associated capability. The use of such + // a variable value should trigger the capability requirement, but that's + // not implemented yet. This rule is independent of target environment. 
+ // See https://github.com/KhronosGroup/SPIRV-Tools/issues/365 + if (type == SPV_OPERAND_TYPE_BUILT_IN) { + switch (operand) { + case SpvBuiltInPointSize: + case SpvBuiltInClipDistance: + case SpvBuiltInCullDistance: + return SPV_SUCCESS; + default: + break; + } + } else if (type == SPV_OPERAND_TYPE_FP_ROUNDING_MODE) { + // Allow all FP rounding modes if requested + if (state.features().free_fp_rounding_mode) { + return SPV_SUCCESS; + } + } else if (type == SPV_OPERAND_TYPE_GROUP_OPERATION && + state.features().group_ops_reduce_and_scans && + (operand <= uint32_t(SpvGroupOperationExclusiveScan))) { + // Allow certain group operations if requested. + return SPV_SUCCESS; + } + + CapabilitySet enabling_capabilities; + spv_operand_desc operand_desc = nullptr; + const auto lookup_result = + state.grammar().lookupOperand(type, operand, &operand_desc); + if (lookup_result == SPV_SUCCESS) { + // Allow FPRoundingMode decoration if requested. + if (type == SPV_OPERAND_TYPE_DECORATION && + operand_desc->value == SpvDecorationFPRoundingMode) { + if (state.features().free_fp_rounding_mode) return SPV_SUCCESS; + + // Vulkan API requires more capabilities on rounding mode. 
+ if (spvIsVulkanEnv(state.context()->target_env)) { + enabling_capabilities.Add(SpvCapabilityStorageUniformBufferBlock16); + enabling_capabilities.Add(SpvCapabilityStorageUniform16); + enabling_capabilities.Add(SpvCapabilityStoragePushConstant16); + enabling_capabilities.Add(SpvCapabilityStorageInputOutput16); + } + } else { + enabling_capabilities = state.grammar().filterCapsAgainstTargetEnv( + operand_desc->capabilities, operand_desc->numCapabilities); + } + + if (!state.HasAnyOfCapabilities(enabling_capabilities)) { + return state.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Operand " << which_operand << " of " + << spvOpcodeString(inst->opcode()) + << " requires one of these capabilities: " + << ToString(enabling_capabilities, state.grammar()); + } + } + + return SPV_SUCCESS; +} + +// Returns operand's required extensions. +ExtensionSet RequiredExtensions(const ValidationState_t& state, + spv_operand_type_t type, uint32_t operand) { + spv_operand_desc operand_desc; + if (state.grammar().lookupOperand(type, operand, &operand_desc) == + SPV_SUCCESS) { + assert(operand_desc); + // If this operand is incorporated into core SPIR-V before or in the current + // target environment, we don't require extensions anymore. + if (spvVersionForTargetEnv(state.grammar().target_env()) >= + operand_desc->minVersion) + return {}; + return {operand_desc->numExtensions, operand_desc->extensions}; + } + + return {}; +} + +// Returns SPV_ERROR_INVALID_BINARY and emits a diagnostic if the instruction +// is explicitly reserved in the SPIR-V core spec. Otherwise return +// SPV_SUCCESS. +spv_result_t ReservedCheck(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + switch (opcode) { + // These instructions are enabled by a capability, but should never + // be used anyway. 
+ case SpvOpImageSparseSampleProjImplicitLod: + case SpvOpImageSparseSampleProjExplicitLod: + case SpvOpImageSparseSampleProjDrefImplicitLod: + case SpvOpImageSparseSampleProjDrefExplicitLod: { + spv_opcode_desc inst_desc; + _.grammar().lookupOpcode(opcode, &inst_desc); + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Invalid Opcode name 'Op" << inst_desc->name << "'"; + } + default: + break; + } + return SPV_SUCCESS; +} + +// Returns SPV_ERROR_INVALID_BINARY and emits a diagnostic if the instruction +// is invalid because of an execution environment constraint. +spv_result_t EnvironmentCheck(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + switch (opcode) { + case SpvOpUndef: + if (_.features().bans_op_undef) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "OpUndef is disallowed"; + } + break; + default: + break; + } + return SPV_SUCCESS; +} + +// Returns SPV_ERROR_INVALID_CAPABILITY and emits a diagnostic if the +// instruction is invalid because the required capability isn't declared +// in the module. +spv_result_t CapabilityCheck(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + CapabilitySet opcode_caps = EnablingCapabilitiesForOp(_, opcode); + if (!_.HasAnyOfCapabilities(opcode_caps)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "Opcode " << spvOpcodeString(opcode) + << " requires one of these capabilities: " + << ToString(opcode_caps, _.grammar()); + } + for (size_t i = 0; i < inst->operands().size(); ++i) { + const auto& operand = inst->operand(i); + const auto word = inst->word(operand.offset); + if (spvOperandIsConcreteMask(operand.type)) { + // Check for required capabilities for each bit position of the mask. 
+ for (uint32_t mask_bit = 0x80000000; mask_bit; mask_bit >>= 1) { + if (word & mask_bit) { + spv_result_t status = + CheckRequiredCapabilities(_, inst, i + 1, operand.type, mask_bit); + if (status != SPV_SUCCESS) return status; + } + } + } else if (spvIsIdType(operand.type)) { + // TODO(dneto): Check the value referenced by this Id, if we can compute + // it. For now, just punt, to fix issue 248: + // https://github.com/KhronosGroup/SPIRV-Tools/issues/248 + } else { + // Check the operand word as a whole. + spv_result_t status = + CheckRequiredCapabilities(_, inst, i + 1, operand.type, word); + if (status != SPV_SUCCESS) return status; + } + } + return SPV_SUCCESS; +} + +// Checks that all extensions required by the given instruction's operands were +// declared in the module. +spv_result_t ExtensionCheck(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + for (size_t operand_index = 0; operand_index < inst->operands().size(); + ++operand_index) { + const auto& operand = inst->operand(operand_index); + const uint32_t word = inst->word(operand.offset); + const ExtensionSet required_extensions = + RequiredExtensions(_, operand.type, word); + if (!_.HasAnyOfExtensions(required_extensions)) { + return _.diag(SPV_ERROR_MISSING_EXTENSION, inst) + << spvtools::utils::CardinalToOrdinal(operand_index + 1) + << " operand of " << spvOpcodeString(opcode) << ": operand " + << word << " requires one of these extensions: " + << ExtensionSetToString(required_extensions); + } + } + return SPV_SUCCESS; +} + +// Checks that the instruction can be used in this target environment's base +// version. Assumes that CapabilityCheck has checked direct capability +// dependencies for the opcode. 
+spv_result_t VersionCheck(ValidationState_t& _, const Instruction* inst) { + const auto opcode = inst->opcode(); + spv_opcode_desc inst_desc; + const spv_result_t r = _.grammar().lookupOpcode(opcode, &inst_desc); + assert(r == SPV_SUCCESS); + (void)r; + + const auto min_version = inst_desc->minVersion; + + if (inst_desc->numCapabilities > 0u) { + // We already checked that the direct capability dependency has been + // satisfied. We don't need to check any further. + return SPV_SUCCESS; + } + + ExtensionSet exts(inst_desc->numExtensions, inst_desc->extensions); + if (exts.IsEmpty()) { + // If no extensions can enable this instruction, then emit error messages + // only concerning core SPIR-V versions if errors happen. + if (min_version == ~0u) { + return _.diag(SPV_ERROR_WRONG_VERSION, inst) + << spvOpcodeString(opcode) << " is reserved for future use."; + } + + if (spvVersionForTargetEnv(_.grammar().target_env()) < min_version) { + return _.diag(SPV_ERROR_WRONG_VERSION, inst) + << spvOpcodeString(opcode) << " requires " + << spvTargetEnvDescription( + static_cast(min_version)) + << " at minimum."; + } + // Otherwise, we only error out when no enabling extensions are registered. + } else if (!_.HasAnyOfExtensions(exts)) { + if (min_version == ~0u) { + return _.diag(SPV_ERROR_MISSING_EXTENSION, inst) + << spvOpcodeString(opcode) + << " requires one of the following extensions: " + << ExtensionSetToString(exts); + } + + if (static_cast(_.grammar().target_env()) < min_version) { + return _.diag(SPV_ERROR_WRONG_VERSION, inst) + << spvOpcodeString(opcode) << " requires " + << spvTargetEnvDescription( + static_cast(min_version)) + << " at minimum or one of the following extensions: " + << ExtensionSetToString(exts); + } + } + + return SPV_SUCCESS; +} + +// Checks that the Resuld is within the valid bound. 
+spv_result_t LimitCheckIdBound(ValidationState_t& _, const Instruction* inst) { + if (inst->id() >= _.getIdBound()) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Result '" << inst->id() + << "' must be less than the ID bound '" << _.getIdBound() << "'."; + } + return SPV_SUCCESS; +} + +// Checks that the number of OpTypeStruct members is within the limit. +spv_result_t LimitCheckStruct(ValidationState_t& _, const Instruction* inst) { + if (SpvOpTypeStruct != inst->opcode()) { + return SPV_SUCCESS; + } + + // Number of members is the number of operands of the instruction minus 1. + // One operand is the result ID. + const uint16_t limit = + static_cast(_.options()->universal_limits_.max_struct_members); + if (inst->operands().size() - 1 > limit) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Number of OpTypeStruct members (" << inst->operands().size() - 1 + << ") has exceeded the limit (" << limit << ")."; + } + + // Section 2.17 of SPIRV Spec specifies that the "Structure Nesting Depth" + // must be less than or equal to 255. + // This is interpreted as structures including other structures as members. + // The code does not follow pointers or look into arrays to see if we reach a + // structure downstream. + // The nesting depth of a struct is 1+(largest depth of any member). + // Scalars are at depth 0. + uint32_t max_member_depth = 0; + // Struct members start at word 2 of OpTypeStruct instruction. 
+ for (size_t word_i = 2; word_i < inst->words().size(); ++word_i) { + auto member = inst->word(word_i); + auto memberTypeInstr = _.FindDef(member); + if (memberTypeInstr && SpvOpTypeStruct == memberTypeInstr->opcode()) { + max_member_depth = std::max( + max_member_depth, _.struct_nesting_depth(memberTypeInstr->id())); + } + } + + const uint32_t depth_limit = _.options()->universal_limits_.max_struct_depth; + const uint32_t cur_depth = 1 + max_member_depth; + _.set_struct_nesting_depth(inst->id(), cur_depth); + if (cur_depth > depth_limit) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Structure Nesting Depth may not be larger than " << depth_limit + << ". Found " << cur_depth << "."; + } + return SPV_SUCCESS; +} + +// Checks that the number of (literal, label) pairs in OpSwitch is within the +// limit. +spv_result_t LimitCheckSwitch(ValidationState_t& _, const Instruction* inst) { + if (SpvOpSwitch == inst->opcode()) { + // The instruction syntax is as follows: + // OpSwitch literal label literal label ... + // literal,label pairs come after the first 2 operands. + // It is guaranteed at this point that num_operands is an even numner. + size_t num_pairs = (inst->operands().size() - 2) / 2; + const unsigned int num_pairs_limit = + _.options()->universal_limits_.max_switch_branches; + if (num_pairs > num_pairs_limit) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "Number of (literal, label) pairs in OpSwitch (" << num_pairs + << ") exceeds the limit (" << num_pairs_limit << ")."; + } + } + return SPV_SUCCESS; +} + +// Ensure the number of variables of the given class does not exceed the limit. 
+spv_result_t LimitCheckNumVars(ValidationState_t& _, const uint32_t var_id, + const SpvStorageClass storage_class) { + if (SpvStorageClassFunction == storage_class) { + _.registerLocalVariable(var_id); + const uint32_t num_local_vars_limit = + _.options()->universal_limits_.max_local_variables; + if (_.num_local_vars() > num_local_vars_limit) { + return _.diag(SPV_ERROR_INVALID_BINARY, nullptr) + << "Number of local variables ('Function' Storage Class) " + "exceeded the valid limit (" + << num_local_vars_limit << ")."; + } + } else { + _.registerGlobalVariable(var_id); + const uint32_t num_global_vars_limit = + _.options()->universal_limits_.max_global_variables; + if (_.num_global_vars() > num_global_vars_limit) { + return _.diag(SPV_ERROR_INVALID_BINARY, nullptr) + << "Number of Global Variables (Storage Class other than " + "'Function') exceeded the valid limit (" + << num_global_vars_limit << ")."; + } + } + return SPV_SUCCESS; +} + +// Parses OpExtension instruction and logs warnings if unsuccessful. 
+spv_result_t CheckIfKnownExtension(ValidationState_t& _, + const Instruction* inst) { + const std::string extension_str = GetExtensionString(&(inst->c_inst())); + Extension extension; + if (!GetExtensionFromString(extension_str.c_str(), &extension)) { + return _.diag(SPV_WARNING, inst) + << "Found unrecognized extension " << extension_str; + } + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + if (opcode == SpvOpExtension) { + CheckIfKnownExtension(_, inst); + } else if (opcode == SpvOpCapability) { + _.RegisterCapability(inst->GetOperandAs(0)); + } else if (opcode == SpvOpMemoryModel) { + if (_.has_memory_model_specified()) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "OpMemoryModel should only be provided once."; + } + _.set_addressing_model(inst->GetOperandAs(0)); + _.set_memory_model(inst->GetOperandAs(1)); + + if (_.memory_model() != SpvMemoryModelVulkanKHR && + _.HasCapability(SpvCapabilityVulkanMemoryModelKHR)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "VulkanMemoryModelKHR capability must only be specified if the " + "VulkanKHR memory model is used."; + } + + if (spvIsWebGPUEnv(_.context()->target_env)) { + if (_.addressing_model() != SpvAddressingModelLogical) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Addressing model must be Logical for WebGPU environment."; + } + if (_.memory_model() != SpvMemoryModelVulkanKHR) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Memory model must be VulkanKHR for WebGPU environment."; + } + } + } else if (opcode == SpvOpExecutionMode) { + const uint32_t entry_point = inst->word(1); + _.RegisterExecutionModeForEntryPoint(entry_point, + SpvExecutionMode(inst->word(2))); + } else if (opcode == SpvOpVariable) { + const auto storage_class = inst->GetOperandAs(2); + if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) { + return error; + } + if (storage_class == 
SpvStorageClassGeneric) + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "OpVariable storage class cannot be Generic"; + if (_.current_layout_section() == kLayoutFunctionDefinitions) { + if (storage_class != SpvStorageClassFunction) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Variables must have a function[7] storage class inside" + " of a function"; + } + if (_.current_function().IsFirstBlock( + _.current_function().current_block()->id()) == false) { + return _.diag(SPV_ERROR_INVALID_CFG, inst) + << "Variables can only be defined " + "in the first block of a " + "function"; + } + } else { + if (storage_class == SpvStorageClassFunction) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Variables can not have a function[7] storage class " + "outside of a function"; + } + } + } + + // SPIR-V Spec 2.16.3: Validation Rules for Kernel Capabilities: The + // Signedness in OpTypeInt must always be 0. + if (SpvOpTypeInt == inst->opcode() && _.HasCapability(SpvCapabilityKernel) && + inst->GetOperandAs(2) != 0u) { + return _.diag(SPV_ERROR_INVALID_BINARY, inst) + << "The Signedness in OpTypeInt " + "must always be 0 when Kernel " + "capability is used."; + } + + if (auto error = ExtensionCheck(_, inst)) return error; + if (auto error = ReservedCheck(_, inst)) return error; + if (auto error = EnvironmentCheck(_, inst)) return error; + if (auto error = CapabilityCheck(_, inst)) return error; + if (auto error = LimitCheckIdBound(_, inst)) return error; + if (auto error = LimitCheckStruct(_, inst)) return error; + if (auto error = LimitCheckSwitch(_, inst)) return error; + if (auto error = VersionCheck(_, inst)) return error; + + // All instruction checks have passed. 
+ return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_interfaces.cpp b/source/val/validate_interfaces.cpp new file mode 100644 index 000000000..fffc6da1a --- /dev/null +++ b/source/val/validate_interfaces.cpp @@ -0,0 +1,114 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include +#include + +#include "source/diagnostic.h" +#include "source/val/function.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Returns true if \c inst is an input or output variable. +bool is_interface_variable(const Instruction* inst) { + return inst->opcode() == SpvOpVariable && + (inst->word(3u) == SpvStorageClassInput || + inst->word(3u) == SpvStorageClassOutput); +} + +// Checks that \c var is listed as an interface in all the entry points that use +// it. 
+spv_result_t check_interface_variable(ValidationState_t& _, + const Instruction* var) { + std::vector functions; + std::vector uses; + for (auto use : var->uses()) { + uses.push_back(use.first); + } + for (uint32_t i = 0; i < uses.size(); ++i) { + const auto user = uses[i]; + if (const Function* func = user->function()) { + functions.push_back(func); + } else { + // In the rare case that the variable is used by another instruction in + // the global scope, continue searching for an instruction used in a + // function. + for (auto use : user->uses()) { + uses.push_back(use.first); + } + } + } + + std::sort(functions.begin(), functions.end(), + [](const Function* lhs, const Function* rhs) { + return lhs->id() < rhs->id(); + }); + functions.erase(std::unique(functions.begin(), functions.end()), + functions.end()); + + std::vector entry_points; + for (const auto func : functions) { + for (auto id : _.FunctionEntryPoints(func->id())) { + entry_points.push_back(id); + } + } + + std::sort(entry_points.begin(), entry_points.end()); + entry_points.erase(std::unique(entry_points.begin(), entry_points.end()), + entry_points.end()); + + for (auto id : entry_points) { + for (const auto& desc : _.entry_point_descriptions(id)) { + bool found = false; + for (auto interface : desc.interfaces) { + if (var->id() == interface) { + found = true; + break; + } + } + if (!found) { + return _.diag(SPV_ERROR_INVALID_ID, var) + << (var->word(3u) == SpvStorageClassInput ? 
"Input" : "Output") + << " variable id <" << var->id() << "> is used by entry point '" + << desc.name << "' id <" << id + << ">, but is not listed as an interface"; + } + } + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ValidateInterfaces(ValidationState_t& _) { + for (auto& inst : _.ordered_instructions()) { + if (is_interface_variable(&inst)) { + if (auto error = check_interface_variable(_, &inst)) { + return error; + } + } + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_layout.cpp b/source/val/validate_layout.cpp new file mode 100644 index 000000000..53c28355f --- /dev/null +++ b/source/val/validate_layout.cpp @@ -0,0 +1,202 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Source code for logical layout validation as described in section 2.4 + +#include "source/val/validate.h" + +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/val/function.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Module scoped instructions are processed by determining if the opcode +// is part of the current layout section. If it is not then the next sections is +// checked. 
+spv_result_t ModuleScopedInstructions(ValidationState_t& _, + const Instruction* inst, SpvOp opcode) { + while (_.IsOpcodeInCurrentLayoutSection(opcode) == false) { + _.ProgressToNextLayoutSectionOrder(); + + switch (_.current_layout_section()) { + case kLayoutMemoryModel: + if (opcode != SpvOpMemoryModel) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << spvOpcodeString(opcode) + << " cannot appear before the memory model instruction"; + } + break; + case kLayoutFunctionDeclarations: + // All module sections have been processed. Recursively call + // ModuleLayoutPass to process the next section of the module + return ModuleLayoutPass(_, inst); + default: + break; + } + } + return SPV_SUCCESS; +} + +// Function declaration validation is performed by making sure that the +// FunctionParameter and FunctionEnd instructions only appear inside of +// functions. It also ensures that the Function instruction does not appear +// inside of another function. This stage ends when the first label is +// encountered inside of a function. 
+spv_result_t FunctionScopedInstructions(ValidationState_t& _, + const Instruction* inst, SpvOp opcode) { + if (_.IsOpcodeInCurrentLayoutSection(opcode)) { + switch (opcode) { + case SpvOpFunction: { + if (_.in_function_body()) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Cannot declare a function in a function body"; + } + auto control_mask = inst->GetOperandAs(2); + if (auto error = + _.RegisterFunction(inst->id(), inst->type_id(), control_mask, + inst->GetOperandAs(3))) + return error; + if (_.current_layout_section() == kLayoutFunctionDefinitions) { + if (auto error = _.current_function().RegisterSetFunctionDeclType( + FunctionDecl::kFunctionDeclDefinition)) + return error; + } + } break; + + case SpvOpFunctionParameter: + if (_.in_function_body() == false) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function parameter instructions must be in a " + "function body"; + } + if (_.current_function().block_count() != 0) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function parameters must only appear immediately after " + "the function definition"; + } + if (auto error = _.current_function().RegisterFunctionParameter( + inst->id(), inst->type_id())) + return error; + break; + + case SpvOpFunctionEnd: + if (_.in_function_body() == false) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function end instructions must be in a function body"; + } + if (_.in_block()) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function end cannot be called in blocks"; + } + if (_.current_function().block_count() == 0 && + _.current_layout_section() == kLayoutFunctionDefinitions) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Function declarations must appear before " + "function definitions."; + } + if (_.current_layout_section() == kLayoutFunctionDeclarations) { + if (auto error = _.current_function().RegisterSetFunctionDeclType( + FunctionDecl::kFunctionDeclDeclaration)) + return error; + } + if (auto error = 
_.RegisterFunctionEnd()) return error; + break; + + case SpvOpLine: + case SpvOpNoLine: + break; + case SpvOpLabel: + // If the label is encountered then the current function is a + // definition so set the function to a declaration and update the + // module section + if (_.in_function_body() == false) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "Label instructions must be in a function body"; + } + if (_.in_block()) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "A block must end with a branch instruction."; + } + if (_.current_layout_section() == kLayoutFunctionDeclarations) { + _.ProgressToNextLayoutSectionOrder(); + if (auto error = _.current_function().RegisterSetFunctionDeclType( + FunctionDecl::kFunctionDeclDefinition)) + return error; + } + break; + + default: + if (_.current_layout_section() == kLayoutFunctionDeclarations && + _.in_function_body()) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << "A function must begin with a label"; + } else { + if (_.in_block() == false) { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << spvOpcodeString(opcode) << " must appear in a block"; + } + } + break; + } + } else { + return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) + << spvOpcodeString(opcode) + << " cannot appear in a function declaration"; + } + return SPV_SUCCESS; +} + +} // namespace + +// TODO(umar): Check linkage capabilities for function declarations +// TODO(umar): Better error messages +// NOTE: This function does not handle CFG related validation +// Performs logical layout validation. 
See Section 2.4 +spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + + switch (_.current_layout_section()) { + case kLayoutCapabilities: + case kLayoutExtensions: + case kLayoutExtInstImport: + case kLayoutMemoryModel: + case kLayoutEntryPoint: + case kLayoutExecutionMode: + case kLayoutDebug1: + case kLayoutDebug2: + case kLayoutDebug3: + case kLayoutAnnotations: + case kLayoutTypes: + if (auto error = ModuleScopedInstructions(_, inst, opcode)) return error; + break; + case kLayoutFunctionDeclarations: + case kLayoutFunctionDefinitions: + if (auto error = FunctionScopedInstructions(_, inst, opcode)) { + return error; + } + break; + } + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_literals.cpp b/source/val/validate_literals.cpp new file mode 100644 index 000000000..53aae0767 --- /dev/null +++ b/source/val/validate_literals.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates literal numbers. 
+ +#include "source/val/validate.h" + +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// Returns true if the operand holds a literal number +bool IsLiteralNumber(const spv_parsed_operand_t& operand) { + switch (operand.number_kind) { + case SPV_NUMBER_SIGNED_INT: + case SPV_NUMBER_UNSIGNED_INT: + case SPV_NUMBER_FLOATING: + return true; + default: + return false; + } +} + +// Verifies that the upper bits of the given upper |word| with given +// lower |width| are zero- or sign-extended when |signed_int| is true +bool VerifyUpperBits(uint32_t word, uint32_t width, bool signed_int) { + assert(width < 32); + assert(0 < width); + const uint32_t upper_mask = 0xFFFFFFFFu << width; + const uint32_t upper_bits = word & upper_mask; + + bool result = false; + if (signed_int) { + const uint32_t sign_bit = word & (1u << (width - 1)); + if (sign_bit) { + result = upper_bits == upper_mask; + } else { + result = upper_bits == 0; + } + } else { + result = upper_bits == 0; + } + return result; +} + +} // namespace + +// Validates that literal numbers are represented according to the spec +spv_result_t LiteralsPass(ValidationState_t& _, const Instruction* inst) { + // For every operand that is a literal number + for (size_t i = 0; i < inst->operands().size(); i++) { + const spv_parsed_operand_t& operand = inst->operand(i); + if (!IsLiteralNumber(operand)) continue; + + // The upper bits are always in the last word (little-endian) + int last_index = operand.offset + operand.num_words - 1; + const uint32_t upper_word = inst->word(last_index); + + // TODO(jcaraban): is the |word size| defined in some header? 
+ const uint32_t word_size = 32; + uint32_t bit_width = operand.number_bit_width; + + // Bit widths that are a multiple of the word size have no upper bits + const auto remaining_value_bits = bit_width % word_size; + if (remaining_value_bits == 0) continue; + + const bool signedness = operand.number_kind == SPV_NUMBER_SIGNED_INT; + + if (!VerifyUpperBits(upper_word, remaining_value_bits, signedness)) { + return _.diag(SPV_ERROR_INVALID_VALUE, inst) + << "The high-order bits of a literal number in instruction " + << inst->id() << " must be 0 for a floating-point type, " + << "or 0 for an integer type with Signedness of 0, " + << "or sign extended when Signedness is 1"; + } + } + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_logicals.cpp b/source/val/validate_logicals.cpp new file mode 100644 index 000000000..a25460b51 --- /dev/null +++ b/source/val/validate_logicals.cpp @@ -0,0 +1,265 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of logical SPIR-V instructions. + +#include "source/val/validate.h" + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +// Validates correctness of logical instructions. 
+spv_result_t LogicalsPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); + + switch (opcode) { + case SpvOpAny: + case SpvOpAll: { + if (!_.IsBoolScalarType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected bool scalar type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t vector_type = _.GetOperandTypeId(inst, 2); + if (!vector_type || !_.IsBoolVectorType(vector_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operand to be vector bool: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpIsNan: + case SpvOpIsInf: + case SpvOpIsFinite: + case SpvOpIsNormal: + case SpvOpSignBitSet: { + if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected bool scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t operand_type = _.GetOperandTypeId(inst, 2); + if (!operand_type || (!_.IsFloatScalarType(operand_type) && + !_.IsFloatVectorType(operand_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operand to be scalar or vector float: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(operand_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected vector sizes of Result Type and the operand to be " + "equal: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpFOrdEqual: + case SpvOpFUnordEqual: + case SpvOpFOrdNotEqual: + case SpvOpFUnordNotEqual: + case SpvOpFOrdLessThan: + case SpvOpFUnordLessThan: + case SpvOpFOrdGreaterThan: + case SpvOpFUnordGreaterThan: + case SpvOpFOrdLessThanEqual: + case SpvOpFUnordLessThanEqual: + case SpvOpFOrdGreaterThanEqual: + case SpvOpFUnordGreaterThanEqual: + case SpvOpLessOrGreater: + case SpvOpOrdered: + case SpvOpUnordered: { + if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) + 
return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected bool scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t left_operand_type = _.GetOperandTypeId(inst, 2); + if (!left_operand_type || (!_.IsFloatScalarType(left_operand_type) && + !_.IsFloatVectorType(left_operand_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operands to be scalar or vector float: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(left_operand_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected vector sizes of Result Type and the operands to be " + "equal: " + << spvOpcodeString(opcode); + + if (left_operand_type != _.GetOperandTypeId(inst, 3)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected left and right operands to have the same type: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpLogicalEqual: + case SpvOpLogicalNotEqual: + case SpvOpLogicalOr: + case SpvOpLogicalAnd: { + if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected bool scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + if (result_type != _.GetOperandTypeId(inst, 2) || + result_type != _.GetOperandTypeId(inst, 3)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected both operands to be of Result Type: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpLogicalNot: { + if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected bool scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + if (result_type != _.GetOperandTypeId(inst, 2)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operand to be of Result Type: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpSelect: { + uint32_t dimension = 1; + { + const Instruction* type_inst = _.FindDef(result_type); + 
assert(type_inst); + + const SpvOp type_opcode = type_inst->opcode(); + switch (type_opcode) { + case SpvOpTypePointer: { + if (_.addressing_model() == SpvAddressingModelLogical && + !_.features().variable_pointers && + !_.features().variable_pointers_storage_buffer) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Using pointers with OpSelect requires capability " + << "VariablePointers or VariablePointersStorageBuffer"; + break; + } + + case SpvOpTypeVector: { + dimension = type_inst->word(3); + break; + } + + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: { + break; + } + + default: { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + } + } + } + + const uint32_t condition_type = _.GetOperandTypeId(inst, 2); + const uint32_t left_type = _.GetOperandTypeId(inst, 3); + const uint32_t right_type = _.GetOperandTypeId(inst, 4); + + if (!condition_type || (!_.IsBoolScalarType(condition_type) && + !_.IsBoolVectorType(condition_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected bool scalar or vector type as condition: " + << spvOpcodeString(opcode); + + if (_.GetDimension(condition_type) != dimension) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected vector sizes of Result Type and the condition to be" + << " equal: " << spvOpcodeString(opcode); + + if (result_type != left_type || result_type != right_type) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected both objects to be of Result Type: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpIEqual: + case SpvOpINotEqual: + case SpvOpUGreaterThan: + case SpvOpUGreaterThanEqual: + case SpvOpULessThan: + case SpvOpULessThanEqual: + case SpvOpSGreaterThan: + case SpvOpSGreaterThanEqual: + case SpvOpSLessThan: + case SpvOpSLessThanEqual: { + if (!_.IsBoolScalarType(result_type) && !_.IsBoolVectorType(result_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected 
bool scalar or vector type as Result Type: " + << spvOpcodeString(opcode); + + const uint32_t left_type = _.GetOperandTypeId(inst, 2); + const uint32_t right_type = _.GetOperandTypeId(inst, 3); + + if (!left_type || + (!_.IsIntScalarType(left_type) && !_.IsIntVectorType(left_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operands to be scalar or vector int: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(left_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected vector sizes of Result Type and the operands to be" + << " equal: " << spvOpcodeString(opcode); + + if (!right_type || + (!_.IsIntScalarType(right_type) && !_.IsIntVectorType(right_type))) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected operands to be scalar or vector int: " + << spvOpcodeString(opcode); + + if (_.GetDimension(result_type) != _.GetDimension(right_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected vector sizes of Result Type and the operands to be" + << " equal: " << spvOpcodeString(opcode); + + if (_.GetBitWidth(left_type) != _.GetBitWidth(right_type)) + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected both operands to have the same component bit " + "width: " + << spvOpcodeString(opcode); + + break; + } + + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp new file mode 100644 index 000000000..3b104be8d --- /dev/null +++ b/source/val/validate_memory.cpp @@ -0,0 +1,1186 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate.h" + +#include +#include +#include + +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validate_scopes.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*, + const Instruction*); +bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*, + const Instruction*); +bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*, + const Instruction*); +bool HasConflictingMemberOffsets(const std::vector&, + const std::vector&); + +bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type, + std::initializer_list allowed) { + if (std::find(allowed.begin(), allowed.end(), type->opcode()) != + allowed.end()) { + return true; + } + if (type->opcode() == SpvOpTypeArray || + type->opcode() == SpvOpTypeRuntimeArray) { + auto elem_type = _.FindDef(type->word(2)); + return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) != + allowed.end(); + } + return false; +} + +// Returns true if the two instructions represent structs that, as far as the +// validator can tell, have the exact same data layout. 
+bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1, + const Instruction* type2) { + if (type1->opcode() != SpvOpTypeStruct) { + return false; + } + if (type2->opcode() != SpvOpTypeStruct) { + return false; + } + + if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false; + + return HaveSameLayoutDecorations(_, type1, type2); +} + +// Returns true if the operands to the OpTypeStruct instruction defining the +// types are the same or are layout compatible types. |type1| and |type2| must +// be OpTypeStruct instructions. +bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1, + const Instruction* type2) { + assert(type1->opcode() == SpvOpTypeStruct && + "type1 must be an OpTypeStruct instruction."); + assert(type2->opcode() == SpvOpTypeStruct && + "type2 must be an OpTypeStruct instruction."); + const auto& type1_operands = type1->operands(); + const auto& type2_operands = type2->operands(); + if (type1_operands.size() != type2_operands.size()) { + return false; + } + + for (size_t operand = 2; operand < type1_operands.size(); ++operand) { + if (type1->word(operand) != type2->word(operand)) { + auto def1 = _.FindDef(type1->word(operand)); + auto def2 = _.FindDef(type2->word(operand)); + if (!AreLayoutCompatibleStructs(_, def1, def2)) { + return false; + } + } + } + return true; +} + +// Returns true if all decorations that affect the data layout of the struct +// (like Offset), are the same for the two types. |type1| and |type2| must be +// OpTypeStruct instructions. 
+bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1, + const Instruction* type2) { + assert(type1->opcode() == SpvOpTypeStruct && + "type1 must be an OpTypeStruct instruction."); + assert(type2->opcode() == SpvOpTypeStruct && + "type2 must be an OpTypeStruct instruction."); + const std::vector& type1_decorations = + _.id_decorations(type1->id()); + const std::vector& type2_decorations = + _.id_decorations(type2->id()); + + // TODO: Will have to add other check for arrays an matricies if we want to + // handle them. + if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) { + return false; + } + + return true; +} + +bool HasConflictingMemberOffsets( + const std::vector& type1_decorations, + const std::vector& type2_decorations) { + { + // We are interested in conflicting decoration. If a decoration is in one + // list but not the other, then we will assume the code is correct. We are + // looking for things we know to be wrong. + // + // We do not have to traverse type2_decoration because, after traversing + // type1_decorations, anything new will not be found in + // type1_decoration. Therefore, it cannot lead to a conflict. + for (const Decoration& decoration : type1_decorations) { + switch (decoration.dec_type()) { + case SpvDecorationOffset: { + // Since these affect the layout of the struct, they must be present + // in both structs. + auto compare = [&decoration](const Decoration& rhs) { + if (rhs.dec_type() != SpvDecorationOffset) return false; + return decoration.struct_member_index() == + rhs.struct_member_index(); + }; + auto i = std::find_if(type2_decorations.begin(), + type2_decorations.end(), compare); + if (i != type2_decorations.end() && + decoration.params().front() != i->params().front()) { + return true; + } + } break; + default: + // This decoration does not affect the layout of the structure, so + // just moving on. 
+ break; + } + } + } + return false; +} + +// If |skip_builtin| is true, returns true if |storage| contains bool within +// it and no storage that contains the bool is builtin. +// If |skip_builtin| is false, returns true if |storage| contains bool within +// it. +bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage, + bool skip_builtin) { + if (skip_builtin) { + for (const Decoration& decoration : _.id_decorations(storage->id())) { + if (decoration.dec_type() == SpvDecorationBuiltIn) return false; + } + } + + const size_t elem_type_index = 1; + uint32_t elem_type_id; + Instruction* elem_type; + + switch (storage->opcode()) { + case SpvOpTypeBool: + return true; + case SpvOpTypeVector: + case SpvOpTypeMatrix: + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + elem_type_id = storage->GetOperandAs(elem_type_index); + elem_type = _.FindDef(elem_type_id); + return ContainsInvalidBool(_, elem_type, skip_builtin); + case SpvOpTypeStruct: + for (size_t member_type_index = 1; + member_type_index < storage->operands().size(); + ++member_type_index) { + auto member_type_id = + storage->GetOperandAs(member_type_index); + auto member_type = _.FindDef(member_type_id); + if (ContainsInvalidBool(_, member_type, skip_builtin)) return true; + } + default: + break; + } + return false; +} + +std::pair GetStorageClass( + ValidationState_t& _, const Instruction* inst) { + SpvStorageClass dst_sc = SpvStorageClassMax; + SpvStorageClass src_sc = SpvStorageClassMax; + switch (inst->opcode()) { + case SpvOpLoad: { + auto load_pointer = _.FindDef(inst->GetOperandAs(2)); + auto load_pointer_type = _.FindDef(load_pointer->type_id()); + dst_sc = load_pointer_type->GetOperandAs(1); + break; + } + case SpvOpStore: { + auto store_pointer = _.FindDef(inst->GetOperandAs(0)); + auto store_pointer_type = _.FindDef(store_pointer->type_id()); + dst_sc = store_pointer_type->GetOperandAs(1); + break; + } + case SpvOpCopyMemory: + case SpvOpCopyMemorySized: { + auto dst = 
_.FindDef(inst->GetOperandAs(0)); + auto dst_type = _.FindDef(dst->type_id()); + dst_sc = dst_type->GetOperandAs(1); + auto src = _.FindDef(inst->GetOperandAs(1)); + auto src_type = _.FindDef(src->type_id()); + src_sc = src_type->GetOperandAs(1); + break; + } + default: + break; + } + + return std::make_pair(dst_sc, src_sc); +} + +// This function is only called for OpLoad, OpStore, OpCopyMemory and +// OpCopyMemorySized. +uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) { + uint32_t offset = 1; + if (mask & SpvMemoryAccessAlignedMask) ++offset; + + uint32_t scope_id = 0; + switch (inst->opcode()) { + case SpvOpLoad: + case SpvOpCopyMemorySized: + return inst->GetOperandAs(3 + offset); + case SpvOpStore: + case SpvOpCopyMemory: + return inst->GetOperandAs(2 + offset); + default: + assert(false && "unexpected opcode"); + break; + } + + return scope_id; +} + +// This function is only called for OpLoad, OpStore, OpCopyMemory and +// OpCopyMemorySized. +uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) { + uint32_t offset = 1; + if (mask & SpvMemoryAccessAlignedMask) ++offset; + if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) ++offset; + + uint32_t scope_id = 0; + switch (inst->opcode()) { + case SpvOpLoad: + case SpvOpCopyMemorySized: + return inst->GetOperandAs(3 + offset); + case SpvOpStore: + case SpvOpCopyMemory: + return inst->GetOperandAs(2 + offset); + default: + assert(false && "unexpected opcode"); + break; + } + + return scope_id; +} + +bool DoesStructContainRTA(const ValidationState_t& _, const Instruction* inst) { + for (size_t member_index = 1; member_index < inst->operands().size(); + ++member_index) { + const auto member_id = inst->GetOperandAs(member_index); + const auto member_type = _.FindDef(member_id); + if (member_type->opcode() == SpvOpTypeRuntimeArray) return true; + } + return false; +} + +spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst, + uint32_t index) { + 
SpvStorageClass dst_sc, src_sc; + std::tie(dst_sc, src_sc) = GetStorageClass(_, inst); + if (inst->operands().size() <= index) { + if (src_sc == SpvStorageClassPhysicalStorageBufferEXT || + dst_sc == SpvStorageClassPhysicalStorageBufferEXT) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Memory accesses with PhysicalStorageBufferEXT must use " + "Aligned."; + } + return SPV_SUCCESS; + } + + uint32_t mask = inst->GetOperandAs(index); + if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) { + if (inst->opcode() == SpvOpLoad) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "MakePointerAvailableKHR cannot be used with OpLoad."; + } + + if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "NonPrivatePointerKHR must be specified if " + "MakePointerAvailableKHR is specified."; + } + + // Check the associated scope for MakeAvailableKHR. + const auto available_scope = GetMakeAvailableScope(inst, mask); + if (auto error = ValidateMemoryScope(_, inst, available_scope)) + return error; + } + + if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) { + if (inst->opcode() == SpvOpStore) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "MakePointerVisibleKHR cannot be used with OpStore."; + } + + if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "NonPrivatePointerKHR must be specified if " + << "MakePointerVisibleKHR is specified."; + } + + // Check the associated scope for MakeVisibleKHR. 
+ const auto visible_scope = GetMakeVisibleScope(inst, mask); + if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error; + } + + if (mask & SpvMemoryAccessNonPrivatePointerKHRMask) { + if (dst_sc != SpvStorageClassUniform && + dst_sc != SpvStorageClassWorkgroup && + dst_sc != SpvStorageClassCrossWorkgroup && + dst_sc != SpvStorageClassGeneric && dst_sc != SpvStorageClassImage && + dst_sc != SpvStorageClassStorageBuffer && + dst_sc != SpvStorageClassPhysicalStorageBufferEXT) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "NonPrivatePointerKHR requires a pointer in Uniform, " + << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer " + << "storage classes."; + } + if (src_sc != SpvStorageClassMax && src_sc != SpvStorageClassUniform && + src_sc != SpvStorageClassWorkgroup && + src_sc != SpvStorageClassCrossWorkgroup && + src_sc != SpvStorageClassGeneric && src_sc != SpvStorageClassImage && + src_sc != SpvStorageClassStorageBuffer && + src_sc != SpvStorageClassPhysicalStorageBufferEXT) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "NonPrivatePointerKHR requires a pointer in Uniform, " + << "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer " + << "storage classes."; + } + } + + if (!(mask & SpvMemoryAccessAlignedMask)) { + if (src_sc == SpvStorageClassPhysicalStorageBufferEXT || + dst_sc == SpvStorageClassPhysicalStorageBufferEXT) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Memory accesses with PhysicalStorageBufferEXT must use " + "Aligned."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) { + auto result_type = _.FindDef(inst->type_id()); + if (!result_type || result_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable Result Type '" << _.getIdName(inst->type_id()) + << "' is not a pointer type."; + } + + const auto initializer_index = 3; + const auto storage_class_index = 2; + if (initializer_index < 
inst->operands().size()) { + const auto initializer_id = inst->GetOperandAs(initializer_index); + const auto initializer = _.FindDef(initializer_id); + const auto is_module_scope_var = + initializer && (initializer->opcode() == SpvOpVariable) && + (initializer->GetOperandAs(storage_class_index) != + SpvStorageClassFunction); + const auto is_constant = + initializer && spvOpcodeIsConstant(initializer->opcode()); + if (!initializer || !(is_constant || is_module_scope_var)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable Initializer '" << _.getIdName(initializer_id) + << "' is not a constant or module-scope variable."; + } + } + + const auto storage_class = + inst->GetOperandAs(storage_class_index); + if (storage_class != SpvStorageClassWorkgroup && + storage_class != SpvStorageClassCrossWorkgroup && + storage_class != SpvStorageClassPrivate && + storage_class != SpvStorageClassFunction && + storage_class != SpvStorageClassRayPayloadNV && + storage_class != SpvStorageClassIncomingRayPayloadNV && + storage_class != SpvStorageClassHitAttributeNV && + storage_class != SpvStorageClassCallableDataNV && + storage_class != SpvStorageClassIncomingCallableDataNV) { + const auto storage_index = 2; + const auto storage_id = result_type->GetOperandAs(storage_index); + const auto storage = _.FindDef(storage_id); + bool storage_input_or_output = storage_class == SpvStorageClassInput || + storage_class == SpvStorageClassOutput; + bool builtin = false; + if (storage_input_or_output) { + for (const Decoration& decoration : _.id_decorations(inst->id())) { + if (decoration.dec_type() == SpvDecorationBuiltIn) { + builtin = true; + break; + } + } + } + if (!(storage_input_or_output && builtin) && + ContainsInvalidBool(_, storage, storage_input_or_output)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "If OpTypeBool is stored in conjunction with OpVariable, it " + << "can only be used with non-externally visible shader Storage " + << "Classes: Workgroup, CrossWorkgroup, 
Private, and Function"; + } + } + + // SPIR-V 3.32.8: Check that pointer type and variable type have the same + // storage class. + const auto result_storage_class_index = 1; + const auto result_storage_class = + result_type->GetOperandAs(result_storage_class_index); + if (storage_class != result_storage_class) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "From SPIR-V spec, section 3.32.8 on OpVariable:\n" + << "Its Storage Class operand must be the same as the Storage Class " + << "operand of the result type."; + } + + // Variable pointer related restrictions. + const auto pointee = _.FindDef(result_type->word(3)); + if (_.addressing_model() == SpvAddressingModelLogical && + !_.options()->relax_logical_pointer) { + // VariablePointersStorageBuffer is implied by VariablePointers. + if (pointee->opcode() == SpvOpTypePointer) { + if (!_.HasCapability(SpvCapabilityVariablePointersStorageBuffer)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "In Logical addressing, variables may not allocate a pointer " + << "type"; + } else if (storage_class != SpvStorageClassFunction && + storage_class != SpvStorageClassPrivate) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "In Logical addressing with variable pointers, variables " + << "that allocate pointers must be in Function or Private " + << "storage classes"; + } + } + } + + // Vulkan 14.5.1: Check type of PushConstant variables. + // Vulkan 14.5.2: Check type of UniformConstant and Uniform variables. 
+ if (spvIsVulkanEnv(_.context()->target_env)) { + if (storage_class == SpvStorageClassPushConstant) { + if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "PushConstant OpVariable '" << _.getIdName(inst->id()) + << "' has illegal type.\n" + << "From Vulkan spec, section 14.5.1:\n" + << "Such variables must be typed as OpTypeStruct, " + << "or an array of this type"; + } + } + + if (storage_class == SpvStorageClassUniformConstant) { + if (!IsAllowedTypeOrArrayOfSame( + _, pointee, + {SpvOpTypeImage, SpvOpTypeSampler, SpvOpTypeSampledImage, + SpvOpTypeAccelerationStructureNV})) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "UniformConstant OpVariable '" << _.getIdName(inst->id()) + << "' has illegal type.\n" + << "From Vulkan spec, section 14.5.2:\n" + << "Variables identified with the UniformConstant storage class " + << "are used only as handles to refer to opaque resources. Such " + << "variables must be typed as OpTypeImage, OpTypeSampler, " + << "OpTypeSampledImage, OpTypeAccelerationStructureNV, " + << "or an array of one of these types."; + } + } + + if (storage_class == SpvStorageClassUniform) { + if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Uniform OpVariable '" << _.getIdName(inst->id()) + << "' has illegal type.\n" + << "From Vulkan spec, section 14.5.2:\n" + << "Variables identified with the Uniform storage class are " + << "used to access transparent buffer backed resources. Such " + << "variables must be typed as OpTypeStruct, or an array of " + << "this type"; + } + } + } + + // WebGPU & Vulkan Appendix A: Check that if contains initializer, then + // storage class is Output, Private, or Function. 
+ if (inst->operands().size() > 3 && storage_class != SpvStorageClassOutput && + storage_class != SpvStorageClassPrivate && + storage_class != SpvStorageClassFunction) { + if (spvIsVulkanEnv(_.context()->target_env)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable, '" << _.getIdName(inst->id()) + << "', has a disallowed initializer & storage class " + << "combination.\n" + << "From Vulkan spec, Appendix A:\n" + << "Variable declarations that include initializers must have " + << "one of the following storage classes: Output, Private, or " + << "Function"; + } + + if (spvIsWebGPUEnv(_.context()->target_env)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable, '" << _.getIdName(inst->id()) + << "', has a disallowed initializer & storage class " + << "combination.\n" + << "From WebGPU execution environment spec:\n" + << "Variable declarations that include initializers must have " + << "one of the following storage classes: Output, Private, or " + << "Function"; + } + } + + // WebGPU: All variables with storage class Output, Private, or Function MUST + // have an initializer. 
+ if (spvIsWebGPUEnv(_.context()->target_env) && inst->operands().size() <= 3 && + (storage_class == SpvStorageClassOutput || + storage_class == SpvStorageClassPrivate || + storage_class == SpvStorageClassFunction)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable, '" << _.getIdName(inst->id()) + << "', must have an initializer.\n" + << "From WebGPU execution environment spec:\n" + << "All variables in the following storage classes must have an " + << "initializer: Output, Private, or Function"; + } + + if (storage_class == SpvStorageClassPhysicalStorageBufferEXT) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "PhysicalStorageBufferEXT must not be used with OpVariable."; + } + + auto pointee_base = pointee; + while (pointee_base->opcode() == SpvOpTypeArray) { + pointee_base = _.FindDef(pointee_base->GetOperandAs(1u)); + } + if (pointee_base->opcode() == SpvOpTypePointer) { + if (pointee_base->GetOperandAs(1u) == + SpvStorageClassPhysicalStorageBufferEXT) { + // check for AliasedPointerEXT/RestrictPointerEXT + bool foundAliased = + _.HasDecoration(inst->id(), SpvDecorationAliasedPointerEXT); + bool foundRestrict = + _.HasDecoration(inst->id(), SpvDecorationRestrictPointerEXT); + if (!foundAliased && !foundRestrict) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable " << inst->id() + << ": expected AliasedPointerEXT or RestrictPointerEXT for " + << "PhysicalStorageBufferEXT pointer."; + } + if (foundAliased && foundRestrict) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable " << inst->id() + << ": can't specify both AliasedPointerEXT and " + << "RestrictPointerEXT for PhysicalStorageBufferEXT pointer."; + } + } + } + + // Vulkan specific validation rules for OpTypeRuntimeArray + if (spvIsVulkanEnv(_.context()->target_env)) { + const auto type_index = 2; + const auto value_id = result_type->GetOperandAs(type_index); + auto value_type = _.FindDef(value_id); + // OpTypeRuntimeArray should only ever be in a container like 
OpTypeStruct, + // so should never appear as a bare variable. + // Unless the module has the RuntimeDescriptorArrayEXT capability. + if (value_type && value_type->opcode() == SpvOpTypeRuntimeArray) { + if (!_.HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable, '" << _.getIdName(inst->id()) + << "', is attempting to create memory for an illegal type, " + << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only " + << "appear as the final member of an OpTypeStruct, thus cannot " + << "be instantiated via OpVariable"; + } else { + // A bare variable OpTypeRuntimeArray is allowed in this context, but + // still need to check the storage class. + if (storage_class != SpvStorageClassStorageBuffer && + storage_class != SpvStorageClassUniform && + storage_class != SpvStorageClassUniformConstant) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "For Vulkan with RuntimeDescriptorArrayEXT, a variable " + << "containing OpTypeRuntimeArray must have storage class of " + << "StorageBuffer, Uniform, or UniformConstant."; + } + } + } + + // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it + // must either have the storage class StorageBuffer and be decorated + // with Block, or it must be in the Uniform storage class and be decorated + // as BufferBlock. 
+ if (value_type && value_type->opcode() == SpvOpTypeStruct) { + if (DoesStructContainRTA(_, value_type)) { + if (storage_class == SpvStorageClassStorageBuffer) { + if (!_.HasDecoration(value_id, SpvDecorationBlock)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "For Vulkan, an OpTypeStruct variable containing an " + << "OpTypeRuntimeArray must be decorated with Block if it " + << "has storage class StorageBuffer."; + } + } else if (storage_class == SpvStorageClassUniform) { + if (!_.HasDecoration(value_id, SpvDecorationBufferBlock)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "For Vulkan, an OpTypeStruct variable containing an " + << "OpTypeRuntimeArray must be decorated with BufferBlock " + << "if it has storage class Uniform."; + } + } else { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "For Vulkan, OpTypeStruct variables containing " + << "OpTypeRuntimeArray must have storage class of " + << "StorageBuffer or Uniform."; + } + } + } + } + + // WebGPU specific validation rules for OpTypeRuntimeArray + if (spvIsWebGPUEnv(_.context()->target_env)) { + const auto type_index = 2; + const auto value_id = result_type->GetOperandAs(type_index); + auto value_type = _.FindDef(value_id); + // OpTypeRuntimeArray should only ever be in an OpTypeStruct, + // so should never appear as a bare variable. + if (value_type && value_type->opcode() == SpvOpTypeRuntimeArray) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpVariable, '" << _.getIdName(inst->id()) + << "', is attempting to create memory for an illegal type, " + << "OpTypeRuntimeArray.\nFor WebGPU OpTypeRuntimeArray can only " + << "appear as the final member of an OpTypeStruct, thus cannot " + << "be instantiated via OpVariable"; + } + + // If an OpStruct has an OpTypeRuntimeArray somewhere within it, then it + // must have the storage class StorageBuffer and be decorated + // with Block. 
+ if (value_type && value_type->opcode() == SpvOpTypeStruct) { + if (DoesStructContainRTA(_, value_type)) { + if (storage_class == SpvStorageClassStorageBuffer) { + if (!_.HasDecoration(value_id, SpvDecorationBlock)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "For WebGPU, an OpTypeStruct variable containing an " + << "OpTypeRuntimeArray must be decorated with Block if it " + << "has storage class StorageBuffer."; + } + } else { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "For WebGPU, OpTypeStruct variables containing " + << "OpTypeRuntimeArray must have storage class of " + << "StorageBuffer"; + } + } + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) { + const auto result_type = _.FindDef(inst->type_id()); + if (!result_type) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpLoad Result Type '" << _.getIdName(inst->type_id()) + << "' is not defined."; + } + + const bool uses_variable_pointers = + _.features().variable_pointers || + _.features().variable_pointers_storage_buffer; + const auto pointer_index = 2; + const auto pointer_id = inst->GetOperandAs(pointer_index); + const auto pointer = _.FindDef(pointer_id); + if (!pointer || + ((_.addressing_model() == SpvAddressingModelLogical) && + ((!uses_variable_pointers && + !spvOpcodeReturnsLogicalPointer(pointer->opcode())) || + (uses_variable_pointers && + !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpLoad Pointer '" << _.getIdName(pointer_id) + << "' is not a logical pointer."; + } + + const auto pointer_type = _.FindDef(pointer->type_id()); + if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpLoad type for pointer '" << _.getIdName(pointer_id) + << "' is not a pointer type."; + } + + const auto pointee_type = _.FindDef(pointer_type->GetOperandAs(2)); + if (!pointee_type || result_type->id() != 
pointee_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpLoad Result Type '" << _.getIdName(inst->type_id()) + << "' does not match Pointer '" << _.getIdName(pointer->id()) + << "'s type."; + } + + if (auto error = CheckMemoryAccess(_, inst, 3)) return error; + + return SPV_SUCCESS; +} + +spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) { + const bool uses_variable_pointer = + _.features().variable_pointers || + _.features().variable_pointers_storage_buffer; + const auto pointer_index = 0; + const auto pointer_id = inst->GetOperandAs(pointer_index); + const auto pointer = _.FindDef(pointer_id); + if (!pointer || + (_.addressing_model() == SpvAddressingModelLogical && + ((!uses_variable_pointer && + !spvOpcodeReturnsLogicalPointer(pointer->opcode())) || + (uses_variable_pointer && + !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "' is not a logical pointer."; + } + const auto pointer_type = _.FindDef(pointer->type_id()); + if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore type for pointer '" << _.getIdName(pointer_id) + << "' is not a pointer type."; + } + const auto type_id = pointer_type->GetOperandAs(2); + const auto type = _.FindDef(type_id); + if (!type || SpvOpTypeVoid == type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "'s type is void."; + } + + // validate storage class + { + uint32_t data_type; + uint32_t storage_class; + if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "' is not pointer type"; + } + + if (storage_class == SpvStorageClassUniformConstant || + storage_class == SpvStorageClassInput || + storage_class == 
SpvStorageClassPushConstant) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "' storage class is read-only"; + } + } + + const auto object_index = 1; + const auto object_id = inst->GetOperandAs(object_index); + const auto object = _.FindDef(object_id); + if (!object || !object->type_id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore Object '" << _.getIdName(object_id) + << "' is not an object."; + } + const auto object_type = _.FindDef(object->type_id()); + if (!object_type || SpvOpTypeVoid == object_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore Object '" << _.getIdName(object_id) + << "'s type is void."; + } + + if (type->id() != object_type->id()) { + if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct || + object_type->opcode() != SpvOpTypeStruct) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "'s type does not match Object '" + << _.getIdName(object->id()) << "'s type."; + } + + // TODO: Check for layout compatible matricies and arrays as well. 
+ if (!AreLayoutCompatibleStructs(_, type, object_type)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpStore Pointer '" << _.getIdName(pointer_id) + << "'s layout does not match Object '" + << _.getIdName(object->id()) << "'s layout."; + } + } + + if (auto error = CheckMemoryAccess(_, inst, 2)) return error; + + return SPV_SUCCESS; +} + +spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) { + const auto target_index = 0; + const auto target_id = inst->GetOperandAs(target_index); + const auto target = _.FindDef(target_id); + if (!target) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Target operand '" << _.getIdName(target_id) + << "' is not defined."; + } + + const auto source_index = 1; + const auto source_id = inst->GetOperandAs(source_index); + const auto source = _.FindDef(source_id); + if (!source) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Source operand '" << _.getIdName(source_id) + << "' is not defined."; + } + + const auto target_pointer_type = _.FindDef(target->type_id()); + if (!target_pointer_type || + target_pointer_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Target operand '" << _.getIdName(target_id) + << "' is not a pointer."; + } + + const auto source_pointer_type = _.FindDef(source->type_id()); + if (!source_pointer_type || + source_pointer_type->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Source operand '" << _.getIdName(source_id) + << "' is not a pointer."; + } + + if (inst->opcode() == SpvOpCopyMemory) { + const auto target_type = + _.FindDef(target_pointer_type->GetOperandAs(2)); + if (!target_type || target_type->opcode() == SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Target operand '" << _.getIdName(target_id) + << "' cannot be a void pointer."; + } + + const auto source_type = + _.FindDef(source_pointer_type->GetOperandAs(2)); + if (!source_type || source_type->opcode() == SpvOpTypeVoid) 
{ + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Source operand '" << _.getIdName(source_id) + << "' cannot be a void pointer."; + } + + if (target_type->id() != source_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Target '" << _.getIdName(source_id) + << "'s type does not match Source '" + << _.getIdName(source_type->id()) << "'s type."; + } + + if (auto error = CheckMemoryAccess(_, inst, 2)) return error; + } else { + const auto size_id = inst->GetOperandAs(2); + const auto size = _.FindDef(size_id); + if (!size) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Size operand '" << _.getIdName(size_id) + << "' is not defined."; + } + + const auto size_type = _.FindDef(size->type_id()); + if (!_.IsIntScalarType(size_type->id())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Size operand '" << _.getIdName(size_id) + << "' must be a scalar integer type."; + } + + bool is_zero = true; + switch (size->opcode()) { + case SpvOpConstantNull: + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Size operand '" << _.getIdName(size_id) + << "' cannot be a constant zero."; + case SpvOpConstant: + if (size_type->word(3) == 1 && + size->word(size->words().size() - 1) & 0x80000000) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Size operand '" << _.getIdName(size_id) + << "' cannot have the sign bit set to 1."; + } + for (size_t i = 3; is_zero && i < size->words().size(); ++i) { + is_zero &= (size->word(i) == 0); + } + if (is_zero) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Size operand '" << _.getIdName(size_id) + << "' cannot be a constant zero."; + } + break; + default: + // Cannot infer any other opcodes. + break; + } + + if (auto error = CheckMemoryAccess(_, inst, 3)) return error; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateAccessChain(ValidationState_t& _, + const Instruction* inst) { + std::string instr_name = + "Op" + std::string(spvOpcodeString(static_cast(inst->opcode()))); + + // The result type must be OpTypePointer. 
+ auto result_type = _.FindDef(inst->type_id()); + if (SpvOpTypePointer != result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The Result Type of " << instr_name << " '" + << _.getIdName(inst->id()) << "' must be OpTypePointer. Found Op" + << spvOpcodeString(static_cast(result_type->opcode())) << "."; + } + + // Result type is a pointer. Find out what it's pointing to. + // This will be used to make sure the indexing results in the same type. + // OpTypePointer word 3 is the type being pointed to. + const auto result_type_pointee = _.FindDef(result_type->word(3)); + + // Base must be a pointer, pointing to the base of a composite object. + const auto base_index = 2; + const auto base_id = inst->GetOperandAs(base_index); + const auto base = _.FindDef(base_id); + const auto base_type = _.FindDef(base->type_id()); + if (!base_type || SpvOpTypePointer != base_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The Base '" << _.getIdName(base_id) << "' in " << instr_name + << " instruction must be a pointer."; + } + + // The result pointer storage class and base pointer storage class must match. + // Word 2 of OpTypePointer is the Storage Class. + auto result_type_storage_class = result_type->word(2); + auto base_type_storage_class = base_type->word(2); + if (result_type_storage_class != base_type_storage_class) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The result pointer storage class and base " + "pointer storage class in " + << instr_name << " do not match."; + } + + // The type pointed to by OpTypePointer (word 3) must be a composite type. + auto type_pointee = _.FindDef(base_type->word(3)); + + // Check Universal Limit (SPIR-V Spec. Section 2.17). 
+ // The number of indexes passed to OpAccessChain may not exceed 255 + // The instruction includes 4 words + N words (for N indexes) + size_t num_indexes = inst->words().size() - 4; + if (inst->opcode() == SpvOpPtrAccessChain || + inst->opcode() == SpvOpInBoundsPtrAccessChain) { + // In pointer access chains, the element operand is required, but not + // counted as an index. + --num_indexes; + } + const size_t num_indexes_limit = + _.options()->universal_limits_.max_access_chain_indexes; + if (num_indexes > num_indexes_limit) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "The number of indexes in " << instr_name << " may not exceed " + << num_indexes_limit << ". Found " << num_indexes << " indexes."; + } + // Indexes walk the type hierarchy to the desired depth, potentially down to + // scalar granularity. The first index in Indexes will select the top-level + // member/element/component/element of the base composite. All composite + // constituents use zero-based numbering, as described by their OpType... + // instruction. The second index will apply similarly to that result, and so + // on. Once any non-composite type is reached, there must be no remaining + // (unused) indexes. + auto starting_index = 4; + if (inst->opcode() == SpvOpPtrAccessChain || + inst->opcode() == SpvOpInBoundsPtrAccessChain) { + ++starting_index; + } + for (size_t i = starting_index; i < inst->words().size(); ++i) { + const uint32_t cur_word = inst->words()[i]; + // Earlier ID checks ensure that cur_word definition exists. + auto cur_word_instr = _.FindDef(cur_word); + // The index must be a scalar integer type (See OpAccessChain in the Spec.) 
+ auto index_type = _.FindDef(cur_word_instr->type_id()); + if (!index_type || SpvOpTypeInt != index_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Indexes passed to " << instr_name + << " must be of type integer."; + } + switch (type_pointee->opcode()) { + case SpvOpTypeMatrix: + case SpvOpTypeVector: + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: { + // In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray, + // word 2 is the Element Type. + type_pointee = _.FindDef(type_pointee->word(2)); + break; + } + case SpvOpTypeStruct: { + // In case of structures, there is an additional constraint on the + // index: the index must be an OpConstant. + if (SpvOpConstant != cur_word_instr->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) + << "The passed to " << instr_name + << " to index into a " + "structure must be an OpConstant."; + } + // Get the index value from the OpConstant (word 3 of OpConstant). + // OpConstant could be a signed integer. But it's okay to treat it as + // unsigned because a negative constant int would never be seen as + // correct as a struct offset, since structs can't have more than 2 + // billion members. + const uint32_t cur_index = cur_word_instr->word(3); + // The index points to the struct member we want, therefore, the index + // should be less than the number of struct members. + const uint32_t num_struct_members = + static_cast(type_pointee->words().size() - 2); + if (cur_index >= num_struct_members) { + return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) + << "Index is out of bounds: " << instr_name + << " can not find index " << cur_index + << " into the structure '" + << _.getIdName(type_pointee->id()) << "'. This structure has " + << num_struct_members << " members. Largest valid index is " + << num_struct_members - 1 << "."; + } + // Struct members IDs start at word 2 of OpTypeStruct. 
+ auto structMemberId = type_pointee->word(cur_index + 2); + type_pointee = _.FindDef(structMemberId); + break; + } + default: { + // Give an error. reached non-composite type while indexes still remain. + return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) + << instr_name + << " reached non-composite type while indexes " + "still remain to be traversed."; + } + } + } + // At this point, we have fully walked down from the base using the indeces. + // The type being pointed to should be the same as the result type. + if (type_pointee->id() != result_type_pointee->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << instr_name << " result type (Op" + << spvOpcodeString(static_cast(result_type_pointee->opcode())) + << ") does not match the type that results from indexing into the " + "base " + " (Op" + << spvOpcodeString(static_cast(type_pointee->opcode())) + << ")."; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidatePtrAccessChain(ValidationState_t& _, + const Instruction* inst) { + if (_.addressing_model() == SpvAddressingModelLogical) { + if (!_.features().variable_pointers && + !_.features().variable_pointers_storage_buffer) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Generating variable pointers requires capability " + << "VariablePointers or VariablePointersStorageBuffer"; + } + } + return ValidateAccessChain(_, inst); +} + +spv_result_t ValidateArrayLength(ValidationState_t& state, + const Instruction* inst) { + std::string instr_name = + "Op" + std::string(spvOpcodeString(static_cast(inst->opcode()))); + + // Result type must be a 32-bit unsigned int. 
+ auto result_type = state.FindDef(inst->type_id()); + if (result_type->opcode() != SpvOpTypeInt || + result_type->GetOperandAs(1) != 32 || + result_type->GetOperandAs(2) != 0) { + return state.diag(SPV_ERROR_INVALID_ID, inst) + << "The Result Type of " << instr_name << " '" + << state.getIdName(inst->id()) + << "' must be OpTypeInt with width 32 and signedness 0."; + } + + // The structure that is passed in must be an pointer to a structure, whose + // last element is a runtime array. + auto pointer = state.FindDef(inst->GetOperandAs(2)); + auto pointer_type = state.FindDef(pointer->type_id()); + if (pointer_type->opcode() != SpvOpTypePointer) { + return state.diag(SPV_ERROR_INVALID_ID, inst) + << "The Struture's type in " << instr_name << " '" + << state.getIdName(inst->id()) + << "' must be a pointer to an OpTypeStruct."; + } + + auto structure_type = state.FindDef(pointer_type->GetOperandAs(2)); + if (structure_type->opcode() != SpvOpTypeStruct) { + return state.diag(SPV_ERROR_INVALID_ID, inst) + << "The Struture's type in " << instr_name << " '" + << state.getIdName(inst->id()) + << "' must be a pointer to an OpTypeStruct."; + } + + auto num_of_members = structure_type->operands().size() - 1; + auto last_member = + state.FindDef(structure_type->GetOperandAs(num_of_members)); + if (last_member->opcode() != SpvOpTypeRuntimeArray) { + return state.diag(SPV_ERROR_INVALID_ID, inst) + << "The Struture's last member in " << instr_name << " '" + << state.getIdName(inst->id()) << "' must be an OpTypeRuntimeArray."; + } + + // The array member must the the index of the last element (the run time + // array). 
+ if (inst->GetOperandAs(3) != num_of_members - 1) { + return state.diag(SPV_ERROR_INVALID_ID, inst) + << "The array member in " << instr_name << " '" + << state.getIdName(inst->id()) + << "' must be an the last member of the struct."; + } + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpVariable: + if (auto error = ValidateVariable(_, inst)) return error; + break; + case SpvOpLoad: + if (auto error = ValidateLoad(_, inst)) return error; + break; + case SpvOpStore: + if (auto error = ValidateStore(_, inst)) return error; + break; + case SpvOpCopyMemory: + case SpvOpCopyMemorySized: + if (auto error = ValidateCopyMemory(_, inst)) return error; + break; + case SpvOpPtrAccessChain: + if (auto error = ValidatePtrAccessChain(_, inst)) return error; + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + case SpvOpInBoundsPtrAccessChain: + if (auto error = ValidateAccessChain(_, inst)) return error; + break; + case SpvOpArrayLength: + if (auto error = ValidateArrayLength(_, inst)) return error; + break; + case SpvOpImageTexelPointer: + case SpvOpGenericPtrMemSemantics: + default: + break; + } + + return SPV_SUCCESS; +} +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_memory_semantics.cpp b/source/val/validate_memory_semantics.cpp new file mode 100644 index 000000000..f90b5e4cd --- /dev/null +++ b/source/val/validate_memory_semantics.cpp @@ -0,0 +1,241 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validate_memory_semantics.h" + +#include "source/diagnostic.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +spv_result_t ValidateMemorySemantics(ValidationState_t& _, + const Instruction* inst, + uint32_t operand_index) { + const SpvOp opcode = inst->opcode(); + const auto id = inst->GetOperandAs(operand_index); + bool is_int32 = false, is_const_int32 = false; + uint32_t value = 0; + std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(id); + + if (!is_int32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Memory Semantics to be a 32-bit int"; + } + + if (!is_const_int32) { + if (_.HasCapability(SpvCapabilityShader)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Memory Semantics ids must be OpConstant when Shader " + "capability is present"; + } + return SPV_SUCCESS; + } + + if (spvIsWebGPUEnv(_.context()->target_env)) { + uint32_t valid_bits = SpvMemorySemanticsAcquireMask | + SpvMemorySemanticsReleaseMask | + SpvMemorySemanticsAcquireReleaseMask | + SpvMemorySemanticsUniformMemoryMask | + SpvMemorySemanticsWorkgroupMemoryMask | + SpvMemorySemanticsImageMemoryMask | + SpvMemorySemanticsOutputMemoryKHRMask | + SpvMemorySemanticsMakeAvailableKHRMask | + SpvMemorySemanticsMakeVisibleKHRMask; + if (value & ~valid_bits) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "WebGPU spec disallows any bit masks in Memory 
Semantics that " + "are not Acquire, Release, AcquireRelease, UniformMemory, " + "WorkgroupMemory, ImageMemory, OutputMemoryKHR, " + "MakeAvailableKHR, or MakeVisibleKHR"; + } + } + + const size_t num_memory_order_set_bits = spvtools::utils::CountSetBits( + value & (SpvMemorySemanticsAcquireMask | SpvMemorySemanticsReleaseMask | + SpvMemorySemanticsAcquireReleaseMask | + SpvMemorySemanticsSequentiallyConsistentMask)); + + if (num_memory_order_set_bits > 1) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Memory Semantics can have at most one of the following " + "bits " + "set: Acquire, Release, AcquireRelease or " + "SequentiallyConsistent"; + } + + if (_.memory_model() == SpvMemoryModelVulkanKHR && + value & SpvMemorySemanticsSequentiallyConsistentMask) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "SequentiallyConsistent memory " + "semantics cannot be used with " + "the VulkanKHR memory model."; + } + + if (value & SpvMemorySemanticsMakeAvailableKHRMask && + !_.HasCapability(SpvCapabilityVulkanMemoryModelKHR)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Memory Semantics MakeAvailableKHR requires capability " + << "VulkanMemoryModelKHR"; + } + + if (value & SpvMemorySemanticsMakeVisibleKHRMask && + !_.HasCapability(SpvCapabilityVulkanMemoryModelKHR)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Memory Semantics MakeVisibleKHR requires capability " + << "VulkanMemoryModelKHR"; + } + + if (value & SpvMemorySemanticsOutputMemoryKHRMask && + !_.HasCapability(SpvCapabilityVulkanMemoryModelKHR)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Memory Semantics OutputMemoryKHR requires capability " + << "VulkanMemoryModelKHR"; + } + + if (value & SpvMemorySemanticsUniformMemoryMask && + !_.HasCapability(SpvCapabilityShader)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": 
Memory Semantics UniformMemory requires capability Shader"; + } + + // Checking for SpvCapabilityAtomicStorage is intentionally not done here. See + // https://github.com/KhronosGroup/glslang/issues/1618 for the reasoning why. + + if (value & (SpvMemorySemanticsMakeAvailableKHRMask | + SpvMemorySemanticsMakeVisibleKHRMask)) { + const bool includes_storage_class = + value & (SpvMemorySemanticsUniformMemoryMask | + SpvMemorySemanticsSubgroupMemoryMask | + SpvMemorySemanticsWorkgroupMemoryMask | + SpvMemorySemanticsCrossWorkgroupMemoryMask | + SpvMemorySemanticsAtomicCounterMemoryMask | + SpvMemorySemanticsImageMemoryMask | + SpvMemorySemanticsOutputMemoryKHRMask); + + if (!includes_storage_class) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Memory Semantics to include a storage class"; + } + } + + if (value & SpvMemorySemanticsMakeVisibleKHRMask && + !(value & (SpvMemorySemanticsAcquireMask | + SpvMemorySemanticsAcquireReleaseMask))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": MakeVisibleKHR Memory Semantics also requires either Acquire " + "or AcquireRelease Memory Semantics"; + } + + if (value & SpvMemorySemanticsMakeAvailableKHRMask && + !(value & (SpvMemorySemanticsReleaseMask | + SpvMemorySemanticsAcquireReleaseMask))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": MakeAvailableKHR Memory Semantics also requires either " + "Release or AcquireRelease Memory Semantics"; + } + + if (spvIsVulkanEnv(_.context()->target_env)) { + const bool includes_storage_class = + value & (SpvMemorySemanticsUniformMemoryMask | + SpvMemorySemanticsWorkgroupMemoryMask | + SpvMemorySemanticsImageMemoryMask | + SpvMemorySemanticsOutputMemoryKHRMask); + + if (opcode == SpvOpMemoryBarrier && !num_memory_order_set_bits) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Vulkan specification requires Memory Semantics to have " + 
"one " + "of the following bits set: Acquire, Release, " + "AcquireRelease " + "or SequentiallyConsistent"; + } + + if (opcode == SpvOpMemoryBarrier && !includes_storage_class) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Memory Semantics to include a Vulkan-supported " + "storage class"; + } + +#if 0 + // TODO(atgoo@github.com): this check fails Vulkan CTS, reenable once fixed. + if (opcode == SpvOpControlBarrier && value && !includes_storage_class) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Memory Semantics to include a Vulkan-supported " + "storage class if Memory Semantics is not None"; + } +#endif + } + + if (opcode == SpvOpAtomicFlagClear && + (value & SpvMemorySemanticsAcquireMask || + value & SpvMemorySemanticsAcquireReleaseMask)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Memory Semantics Acquire and AcquireRelease cannot be used " + "with " + << spvOpcodeString(opcode); + } + + if (opcode == SpvOpAtomicCompareExchange && operand_index == 5 && + (value & SpvMemorySemanticsReleaseMask || + value & SpvMemorySemanticsAcquireReleaseMask)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Memory Semantics Release and AcquireRelease cannot be " + "used " + "for operand Unequal"; + } + + if (spvIsVulkanEnv(_.context()->target_env)) { + if (opcode == SpvOpAtomicLoad && + (value & SpvMemorySemanticsReleaseMask || + value & SpvMemorySemanticsAcquireReleaseMask || + value & SpvMemorySemanticsSequentiallyConsistentMask)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Vulkan spec disallows OpAtomicLoad with Memory Semantics " + "Release, AcquireRelease and SequentiallyConsistent"; + } + + if (opcode == SpvOpAtomicStore && + (value & SpvMemorySemanticsAcquireMask || + value & SpvMemorySemanticsAcquireReleaseMask || + value & SpvMemorySemanticsSequentiallyConsistentMask)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + 
<< "Vulkan spec disallows OpAtomicStore with Memory Semantics " + "Acquire, AcquireRelease and SequentiallyConsistent"; + } + } + + // TODO(atgoo@github.com) Add checks for OpenCL and OpenGL environments. + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_memory_semantics.h b/source/val/validate_memory_semantics.h new file mode 100644 index 000000000..72a3e1004 --- /dev/null +++ b/source/val/validate_memory_semantics.h @@ -0,0 +1,28 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of memory semantics for SPIR-V instructions. + +#include "source/opcode.h" +#include "source/val/validate.h" + +namespace spvtools { +namespace val { + +spv_result_t ValidateMemorySemantics(ValidationState_t& _, + const Instruction* inst, + uint32_t operand_index); + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_mode_setting.cpp b/source/val/validate_mode_setting.cpp new file mode 100644 index 000000000..c1bfc2740 --- /dev/null +++ b/source/val/validate_mode_setting.cpp @@ -0,0 +1,435 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#include "source/val/validate.h" + +#include + +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) { + const auto entry_point_id = inst->GetOperandAs(1); + auto entry_point = _.FindDef(entry_point_id); + if (!entry_point || SpvOpFunction != entry_point->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpEntryPoint Entry Point '" << _.getIdName(entry_point_id) + << "' is not a function."; + } + // don't check kernel function signatures + const SpvExecutionModel execution_model = + inst->GetOperandAs(0); + if (execution_model != SpvExecutionModelKernel) { + // TODO: Check the entry point signature is void main(void), may be subject + // to change + const auto entry_point_type_id = entry_point->GetOperandAs(3); + const auto entry_point_type = _.FindDef(entry_point_type_id); + if (!entry_point_type || 3 != entry_point_type->words().size()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpEntryPoint Entry Point '" << _.getIdName(entry_point_id) + << "'s function parameter count is not zero."; + } + } + + auto return_type = _.FindDef(entry_point->type_id()); + if (!return_type || SpvOpTypeVoid != return_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpEntryPoint Entry Point '" << _.getIdName(entry_point_id) + << "'s function return type is not void."; + } + + const auto* execution_modes 
= _.GetExecutionModes(entry_point_id); + if (_.HasCapability(SpvCapabilityShader)) { + switch (execution_model) { + case SpvExecutionModelFragment: + if (execution_modes && + execution_modes->count(SpvExecutionModeOriginUpperLeft) && + execution_modes->count(SpvExecutionModeOriginLowerLeft)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Fragment execution model entry points can only specify " + "one of OriginUpperLeft or OriginLowerLeft execution " + "modes."; + } + if (!execution_modes || + (!execution_modes->count(SpvExecutionModeOriginUpperLeft) && + !execution_modes->count(SpvExecutionModeOriginLowerLeft))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Fragment execution model entry points require either an " + "OriginUpperLeft or OriginLowerLeft execution mode."; + } + if (execution_modes && + 1 < std::count_if(execution_modes->begin(), execution_modes->end(), + [](const SpvExecutionMode& mode) { + switch (mode) { + case SpvExecutionModeDepthGreater: + case SpvExecutionModeDepthLess: + case SpvExecutionModeDepthUnchanged: + return true; + default: + return false; + } + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Fragment execution model entry points can specify at most " + "one of DepthGreater, DepthLess or DepthUnchanged " + "execution modes."; + } + break; + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: + if (execution_modes && + 1 < std::count_if(execution_modes->begin(), execution_modes->end(), + [](const SpvExecutionMode& mode) { + switch (mode) { + case SpvExecutionModeSpacingEqual: + case SpvExecutionModeSpacingFractionalEven: + case SpvExecutionModeSpacingFractionalOdd: + return true; + default: + return false; + } + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Tessellation execution model entry points can specify at " + "most one of SpacingEqual, SpacingFractionalOdd or " + "SpacingFractionalEven execution modes."; + } + if (execution_modes && + 1 < 
std::count_if(execution_modes->begin(), execution_modes->end(), + [](const SpvExecutionMode& mode) { + switch (mode) { + case SpvExecutionModeTriangles: + case SpvExecutionModeQuads: + case SpvExecutionModeIsolines: + return true; + default: + return false; + } + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Tessellation execution model entry points can specify at " + "most one of Triangles, Quads or Isolines execution modes."; + } + if (execution_modes && + 1 < std::count_if(execution_modes->begin(), execution_modes->end(), + [](const SpvExecutionMode& mode) { + switch (mode) { + case SpvExecutionModeVertexOrderCw: + case SpvExecutionModeVertexOrderCcw: + return true; + default: + return false; + } + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Tessellation execution model entry points can specify at " + "most one of VertexOrderCw or VertexOrderCcw execution " + "modes."; + } + break; + case SpvExecutionModelGeometry: + if (!execution_modes || + 1 != std::count_if(execution_modes->begin(), execution_modes->end(), + [](const SpvExecutionMode& mode) { + switch (mode) { + case SpvExecutionModeInputPoints: + case SpvExecutionModeInputLines: + case SpvExecutionModeInputLinesAdjacency: + case SpvExecutionModeTriangles: + case SpvExecutionModeInputTrianglesAdjacency: + return true; + default: + return false; + } + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Geometry execution model entry points must specify " + "exactly one of InputPoints, InputLines, " + "InputLinesAdjacency, Triangles or InputTrianglesAdjacency " + "execution modes."; + } + if (!execution_modes || + 1 != std::count_if(execution_modes->begin(), execution_modes->end(), + [](const SpvExecutionMode& mode) { + switch (mode) { + case SpvExecutionModeOutputPoints: + case SpvExecutionModeOutputLineStrip: + case SpvExecutionModeOutputTriangleStrip: + return true; + default: + return false; + } + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Geometry execution 
model entry points must specify " + "exactly one of OutputPoints, OutputLineStrip or " + "OutputTriangleStrip execution modes."; + } + break; + default: + break; + } + } + + if (spvIsVulkanEnv(_.context()->target_env)) { + switch (execution_model) { + case SpvExecutionModelGLCompute: + if (!execution_modes || + !execution_modes->count(SpvExecutionModeLocalSize)) { + bool ok = false; + for (auto& i : _.ordered_instructions()) { + if (i.opcode() == SpvOpDecorate) { + if (i.operands().size() > 2) { + if (i.GetOperandAs(1) == SpvDecorationBuiltIn && + i.GetOperandAs(2) == SpvBuiltInWorkgroupSize) { + ok = true; + break; + } + } + } + } + if (!ok) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "In the Vulkan environment, GLCompute execution model " + "entry points require either the LocalSize execution " + "mode or an object decorated with WorkgroupSize must be " + "specified."; + } + } + break; + default: + break; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateExecutionMode(ValidationState_t& _, + const Instruction* inst) { + const auto entry_point_id = inst->GetOperandAs(0); + const auto found = std::find(_.entry_points().cbegin(), + _.entry_points().cend(), entry_point_id); + if (found == _.entry_points().cend()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpExecutionMode Entry Point '" + << _.getIdName(entry_point_id) + << "' is not the Entry Point " + "operand of an OpEntryPoint."; + } + + const auto mode = inst->GetOperandAs(1); + const auto* models = _.GetExecutionModels(entry_point_id); + switch (mode) { + case SpvExecutionModeInvocations: + case SpvExecutionModeInputPoints: + case SpvExecutionModeInputLines: + case SpvExecutionModeInputLinesAdjacency: + case SpvExecutionModeInputTrianglesAdjacency: + case SpvExecutionModeOutputLineStrip: + case SpvExecutionModeOutputTriangleStrip: + if (!std::all_of(models->begin(), models->end(), + [](const SpvExecutionModel& model) { + return model == SpvExecutionModelGeometry; + })) { + return 
_.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with the Geometry execution " + "model."; + } + break; + case SpvExecutionModeOutputPoints: + if (!std::all_of(models->begin(), models->end(), + [&_](const SpvExecutionModel& model) { + switch (model) { + case SpvExecutionModelGeometry: + return true; + case SpvExecutionModelMeshNV: + return _.HasCapability(SpvCapabilityMeshShadingNV); + default: + return false; + } + })) { + if (_.HasCapability(SpvCapabilityMeshShadingNV)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with the Geometry or " + "MeshNV execution model."; + } else { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with the Geometry " + "execution " + "model."; + } + } + break; + case SpvExecutionModeSpacingEqual: + case SpvExecutionModeSpacingFractionalEven: + case SpvExecutionModeSpacingFractionalOdd: + case SpvExecutionModeVertexOrderCw: + case SpvExecutionModeVertexOrderCcw: + case SpvExecutionModePointMode: + case SpvExecutionModeQuads: + case SpvExecutionModeIsolines: + if (!std::all_of( + models->begin(), models->end(), + [](const SpvExecutionModel& model) { + return (model == SpvExecutionModelTessellationControl) || + (model == SpvExecutionModelTessellationEvaluation); + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with a tessellation " + "execution model."; + } + break; + case SpvExecutionModeTriangles: + if (!std::all_of(models->begin(), models->end(), + [](const SpvExecutionModel& model) { + switch (model) { + case SpvExecutionModelGeometry: + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: + return true; + default: + return false; + } + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with a Geometry or " + "tessellation execution model."; + } + break; + case SpvExecutionModeOutputVertices: + if 
(!std::all_of(models->begin(), models->end(), + [&_](const SpvExecutionModel& model) { + switch (model) { + case SpvExecutionModelGeometry: + case SpvExecutionModelTessellationControl: + case SpvExecutionModelTessellationEvaluation: + return true; + case SpvExecutionModelMeshNV: + return _.HasCapability(SpvCapabilityMeshShadingNV); + default: + return false; + } + })) { + if (_.HasCapability(SpvCapabilityMeshShadingNV)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with a Geometry, " + "tessellation or MeshNV execution model."; + } else { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with a Geometry or " + "tessellation execution model."; + } + } + break; + case SpvExecutionModePixelCenterInteger: + case SpvExecutionModeOriginUpperLeft: + case SpvExecutionModeOriginLowerLeft: + case SpvExecutionModeEarlyFragmentTests: + case SpvExecutionModeDepthReplacing: + case SpvExecutionModeDepthLess: + case SpvExecutionModeDepthUnchanged: + if (!std::all_of(models->begin(), models->end(), + [](const SpvExecutionModel& model) { + return model == SpvExecutionModelFragment; + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with the Fragment execution " + "model."; + } + break; + case SpvExecutionModeLocalSizeHint: + case SpvExecutionModeVecTypeHint: + case SpvExecutionModeContractionOff: + case SpvExecutionModeLocalSizeHintId: + if (!std::all_of(models->begin(), models->end(), + [](const SpvExecutionModel& model) { + return model == SpvExecutionModelKernel; + })) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with the Kernel execution " + "model."; + } + break; + case SpvExecutionModeLocalSize: + case SpvExecutionModeLocalSizeId: + if (!std::all_of(models->begin(), models->end(), + [&_](const SpvExecutionModel& model) { + switch (model) { + case SpvExecutionModelKernel: + case SpvExecutionModelGLCompute: + return true; + 
case SpvExecutionModelTaskNV: + case SpvExecutionModelMeshNV: + return _.HasCapability(SpvCapabilityMeshShadingNV); + default: + return false; + } + })) { + if (_.HasCapability(SpvCapabilityMeshShadingNV)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with a Kernel, GLCompute, " + "MeshNV, or TaskNV execution model."; + } else { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Execution mode can only be used with a Kernel or " + "GLCompute " + "execution model."; + } + } + default: + break; + } + + if (spvIsVulkanEnv(_.context()->target_env)) { + if (mode == SpvExecutionModeOriginLowerLeft) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "In the Vulkan environment, the OriginLowerLeft execution mode " + "must not be used."; + } + if (mode == SpvExecutionModePixelCenterInteger) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "In the Vulkan environment, the PixelCenterInteger execution " + "mode must not be used."; + } + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t ModeSettingPass(ValidationState_t& _, const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpEntryPoint: + if (auto error = ValidateEntryPoint(_, inst)) return error; + break; + case SpvOpExecutionMode: + case SpvOpExecutionModeId: + if (auto error = ValidateExecutionMode(_, inst)) return error; + break; + default: + break; + } + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_non_uniform.cpp b/source/val/validate_non_uniform.cpp new file mode 100644 index 000000000..8dcf9743f --- /dev/null +++ b/source/val/validate_non_uniform.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of barrier SPIR-V instructions. + +#include "source/val/validate.h" + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/spirv_constant.h" +#include "source/spirv_target_env.h" +#include "source/util/bitutils.h" +#include "source/val/instruction.h" +#include "source/val/validate_scopes.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateGroupNonUniformBallotBitCount(ValidationState_t& _, + const Instruction* inst) { + // Scope is already checked by ValidateExecutionScope() above. + + const uint32_t result_type = inst->type_id(); + if (!_.IsUnsignedIntScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Result Type to be an unsigned integer type scalar."; + } + + const auto value = inst->GetOperandAs(4); + const auto value_type = _.FindDef(value)->type_id(); + if (!_.IsUnsignedIntVectorType(value_type) || + _.GetDimension(value_type) != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Value to be a " + "vector of four components " + "of integer type scalar"; + } + return SPV_SUCCESS; +} + +} // namespace + +// Validates correctness of non-uniform group instructions. 
+spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + + if (spvOpcodeIsNonUniformGroupOperation(opcode)) { + const uint32_t execution_scope = inst->word(3); + if (auto error = ValidateExecutionScope(_, inst, execution_scope)) { + return error; + } + } + + switch (opcode) { + case SpvOpGroupNonUniformBallotBitCount: + return ValidateGroupNonUniformBallotBitCount(_, inst); + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_primitives.cpp b/source/val/validate_primitives.cpp new file mode 100644 index 000000000..7d11f2e7a --- /dev/null +++ b/source/val/validate_primitives.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of primitive SPIR-V instructions. + +#include "source/val/validate.h" + +#include + +#include "source/diagnostic.h" +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +// Validates correctness of primitive instructions. 
+spv_result_t PrimitivesPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + + switch (opcode) { + case SpvOpEmitVertex: + case SpvOpEndPrimitive: + case SpvOpEmitStreamVertex: + case SpvOpEndStreamPrimitive: + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + SpvExecutionModelGeometry, + std::string(spvOpcodeString(opcode)) + + " instructions require Geometry execution model"); + break; + default: + break; + } + + switch (opcode) { + case SpvOpEmitStreamVertex: + case SpvOpEndStreamPrimitive: { + const uint32_t stream_id = inst->word(1); + const uint32_t stream_type = _.GetTypeId(stream_id); + if (!_.IsIntScalarType(stream_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Stream to be int scalar"; + } + + const SpvOp stream_opcode = _.GetIdOpcode(stream_id); + if (!spvOpcodeIsConstant(stream_opcode)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Stream to be constant instruction"; + } + } + + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_scopes.cpp b/source/val/validate_scopes.cpp new file mode 100644 index 000000000..b6401310d --- /dev/null +++ b/source/val/validate_scopes.cpp @@ -0,0 +1,204 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/val/validate_scopes.h" + +#include "source/diagnostic.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { + +spv_result_t ValidateExecutionScope(ValidationState_t& _, + const Instruction* inst, uint32_t scope) { + SpvOp opcode = inst->opcode(); + bool is_int32 = false, is_const_int32 = false; + uint32_t value = 0; + std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(scope); + + if (!is_int32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Execution Scope to be a 32-bit int"; + } + + if (!is_const_int32) { + if (_.HasCapability(SpvCapabilityShader)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Scope ids must be OpConstant when Shader capability is " + << "present"; + } + return SPV_SUCCESS; + } + + // Vulkan specific rules + if (spvIsVulkanEnv(_.context()->target_env)) { + // Vulkan 1.1 specific rules + if (_.context()->target_env != SPV_ENV_VULKAN_1_0) { + // Scope for Non Uniform Group Operations must be limited to Subgroup + if (spvOpcodeIsNonUniformGroupOperation(opcode) && + value != SpvScopeSubgroup) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in Vulkan environment Execution scope is limited to " + << "Subgroup"; + } + } + + // If OpControlBarrier is used in fragment, vertex, tessellation evaluation, + // or geometry stages, the execution Scope must be Subgroup. 
+ if (opcode == SpvOpControlBarrier && value != SpvScopeSubgroup) { + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation([](SpvExecutionModel model, + std::string* message) { + if (model == SpvExecutionModelFragment || + model == SpvExecutionModelVertex || + model == SpvExecutionModelGeometry || + model == SpvExecutionModelTessellationEvaluation) { + if (message) { + *message = + "in Vulkan evironment, OpControlBarrier execution scope " + "must be Subgroup for Fragment, Vertex, Geometry and " + "TessellationEvaluation execution models"; + } + return false; + } + return true; + }); + } + + // Vulkan generic rules + // Scope for execution must be limited to Workgroup or Subgroup + if (value != SpvScopeWorkgroup && value != SpvScopeSubgroup) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in Vulkan environment Execution Scope is limited to " + << "Workgroup and Subgroup"; + } + } + + // WebGPU Specific rules + if (spvIsWebGPUEnv(_.context()->target_env)) { + // Scope for execution must be limited to Workgroup or Subgroup + if (value != SpvScopeWorkgroup && value != SpvScopeSubgroup) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in WebGPU environment Execution Scope is limited to " + << "Workgroup and Subgroup"; + } + } + + // TODO(atgoo@github.com) Add checks for OpenCL and OpenGL environments. 
+ + // General SPIRV rules + // Scope for execution must be limited to Workgroup or Subgroup for + // non-uniform operations + if (spvOpcodeIsNonUniformGroupOperation(opcode) && + value != SpvScopeSubgroup && value != SpvScopeWorkgroup) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Execution scope is limited to Subgroup or Workgroup"; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateMemoryScope(ValidationState_t& _, const Instruction* inst, + uint32_t scope) { + const SpvOp opcode = inst->opcode(); + bool is_int32 = false, is_const_int32 = false; + uint32_t value = 0; + std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(scope); + + if (!is_int32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": expected Memory Scope to be a 32-bit int"; + } + + if (!is_const_int32) { + if (_.HasCapability(SpvCapabilityShader)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Scope ids must be OpConstant when Shader capability is " + << "present"; + } + return SPV_SUCCESS; + } + + if (value == SpvScopeQueueFamilyKHR) { + if (_.HasCapability(SpvCapabilityVulkanMemoryModelKHR)) { + return SPV_SUCCESS; + } else { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": Memory Scope QueueFamilyKHR requires capability " + << "VulkanMemoryModelKHR"; + } + } + + if (value == SpvScopeDevice && + _.HasCapability(SpvCapabilityVulkanMemoryModelKHR) && + !_.HasCapability(SpvCapabilityVulkanMemoryModelDeviceScopeKHR)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Use of device scope with VulkanKHR memory model requires the " + << "VulkanMemoryModelDeviceScopeKHR capability"; + } + + // Vulkan Specific rules + if (spvIsVulkanEnv(_.context()->target_env)) { + if (value == SpvScopeCrossDevice) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in Vulkan environment, Memory Scope cannot be CrossDevice"; + } + // Vulkan 1.0 specifc rules 
+ if (_.context()->target_env == SPV_ENV_VULKAN_1_0 && + value != SpvScopeDevice && value != SpvScopeWorkgroup && + value != SpvScopeInvocation) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in Vulkan 1.0 environment Memory Scope is limited to " + << "Device, Workgroup and Invocation"; + } + // Vulkan 1.1 specifc rules + if (_.context()->target_env == SPV_ENV_VULKAN_1_1 && + value != SpvScopeDevice && value != SpvScopeWorkgroup && + value != SpvScopeSubgroup && value != SpvScopeInvocation) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in Vulkan 1.1 environment Memory Scope is limited to " + << "Device, Workgroup and Invocation"; + } + } + + // WebGPU specific rules + if (spvIsWebGPUEnv(_.context()->target_env)) { + if (value != SpvScopeWorkgroup && value != SpvScopeSubgroup && + value != SpvScopeQueueFamilyKHR) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << spvOpcodeString(opcode) + << ": in WebGPU environment Memory Scope is limited to " + << "Workgroup, Subgroup and QueuFamilyKHR"; + } + } + + // TODO(atgoo@github.com) Add checks for OpenCL and OpenGL environments. + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_scopes.h b/source/val/validate_scopes.h new file mode 100644 index 000000000..311ca7ff2 --- /dev/null +++ b/source/val/validate_scopes.h @@ -0,0 +1,30 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates correctness of scopes for SPIR-V instructions. + +#include "source/opcode.h" +#include "source/val/validate.h" + +namespace spvtools { +namespace val { + +spv_result_t ValidateExecutionScope(ValidationState_t& _, + const Instruction* inst, uint32_t scope); + +spv_result_t ValidateMemoryScope(ValidationState_t& _, const Instruction* inst, + uint32_t scope); + +} // namespace val +} // namespace spvtools diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp new file mode 100644 index 000000000..94ea66034 --- /dev/null +++ b/source/val/validate_type.cpp @@ -0,0 +1,393 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Ensures type declarations are unique unless allowed by the specification. + +#include "source/val/validate.h" + +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/instruction.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// True if the integer constant is > 0. |const_words| are words of the +// constant-defining instruction (either OpConstant or +// OpSpecConstant). typeWords are the words of the constant's-type-defining +// OpTypeInt. 
+bool AboveZero(const std::vector& const_words, + const std::vector& type_words) { + const uint32_t width = type_words[2]; + const bool is_signed = type_words[3] > 0; + const uint32_t lo_word = const_words[3]; + if (width > 32) { + // The spec currently doesn't allow integers wider than 64 bits. + const uint32_t hi_word = const_words[4]; // Must exist, per spec. + if (is_signed && (hi_word >> 31)) return false; + return (lo_word | hi_word) > 0; + } else { + if (is_signed && (lo_word >> 31)) return false; + return lo_word > 0; + } +} + +// Validates that type declarations are unique, unless multiple declarations +// of the same data type are allowed by the specification. +// (see section 2.8 Types and Variables) +// Doesn't do anything if SPV_VAL_ignore_type_decl_unique was declared in the +// module. +spv_result_t ValidateUniqueness(ValidationState_t& _, const Instruction* inst) { + if (_.HasExtension(Extension::kSPV_VALIDATOR_ignore_type_decl_unique)) + return SPV_SUCCESS; + + const auto opcode = inst->opcode(); + if (opcode != SpvOpTypeArray && opcode != SpvOpTypeRuntimeArray && + opcode != SpvOpTypeStruct && opcode != SpvOpTypePointer && + !_.RegisterUniqueTypeDeclaration(inst)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Duplicate non-aggregate type declarations are not allowed. 
" + "Opcode: " + << spvOpcodeString(opcode) << " id: " << inst->id(); + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) { + const auto component_index = 1; + const auto component_id = inst->GetOperandAs(component_index); + const auto component_type = _.FindDef(component_id); + if (!component_type || !spvOpcodeIsScalarType(component_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeVector Component Type '" << _.getIdName(component_id) + << "' is not a scalar type."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeMatrix(ValidationState_t& _, const Instruction* inst) { + const auto column_type_index = 1; + const auto column_type_id = inst->GetOperandAs(column_type_index); + const auto column_type = _.FindDef(column_type_id); + if (!column_type || SpvOpTypeVector != column_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeMatrix Column Type '" << _.getIdName(column_type_id) + << "' is not a vector."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) { + const auto element_type_index = 1; + const auto element_type_id = inst->GetOperandAs(element_type_index); + const auto element_type = _.FindDef(element_type_id); + if (!element_type || !spvOpcodeGeneratesType(element_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Element Type '" << _.getIdName(element_type_id) + << "' is not a type."; + } + + if (element_type->opcode() == SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Element Type '" << _.getIdName(element_type_id) + << "' is a void type."; + } + + if ((spvIsVulkanEnv(_.context()->target_env) || + spvIsWebGPUEnv(_.context()->target_env)) && + element_type->opcode() == SpvOpTypeRuntimeArray) { + const char* env_text = + spvIsVulkanEnv(_.context()->target_env) ? 
"Vulkan" : "WebGPU"; + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Element Type '" << _.getIdName(element_type_id) + << "' is not valid in " << env_text << " environment."; + } + + const auto length_index = 2; + const auto length_id = inst->GetOperandAs(length_index); + const auto length = _.FindDef(length_id); + if (!length || !spvOpcodeIsConstant(length->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Length '" << _.getIdName(length_id) + << "' is not a scalar constant type."; + } + + // NOTE: Check the initialiser value of the constant + const auto const_inst = length->words(); + const auto const_result_type_index = 1; + const auto const_result_type = _.FindDef(const_inst[const_result_type_index]); + if (!const_result_type || SpvOpTypeInt != const_result_type->opcode()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Length '" << _.getIdName(length_id) + << "' is not a constant integer type."; + } + + switch (length->opcode()) { + case SpvOpSpecConstant: + case SpvOpConstant: + if (AboveZero(length->words(), const_result_type->words())) break; + // Else fall through! + case SpvOpConstantNull: { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Length '" << _.getIdName(length_id) + << "' default value must be at least 1."; + } + case SpvOpSpecConstantOp: + // Assume it's OK, rather than try to evaluate the operation. 
+ break; + default: + assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int"); + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeRuntimeArray(ValidationState_t& _, + const Instruction* inst) { + const auto element_type_index = 1; + const auto element_id = inst->GetOperandAs(element_type_index); + const auto element_type = _.FindDef(element_id); + if (!element_type || !spvOpcodeGeneratesType(element_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeRuntimeArray Element Type '" + << _.getIdName(element_id) << "' is not a type."; + } + + if (element_type->opcode() == SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeRuntimeArray Element Type '" + << _.getIdName(element_id) << "' is a void type."; + } + + if ((spvIsVulkanEnv(_.context()->target_env) || + spvIsWebGPUEnv(_.context()->target_env)) && + element_type->opcode() == SpvOpTypeRuntimeArray) { + const char* env_text = + spvIsVulkanEnv(_.context()->target_env) ? "Vulkan" : "WebGPU"; + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeRuntimeArray Element Type '" + << _.getIdName(element_id) << "' is not valid in " << env_text + << " environment."; + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) { + const uint32_t struct_id = inst->GetOperandAs(0); + for (size_t member_type_index = 1; + member_type_index < inst->operands().size(); ++member_type_index) { + auto member_type_id = inst->GetOperandAs(member_type_index); + auto member_type = _.FindDef(member_type_id); + if (!member_type || !spvOpcodeGeneratesType(member_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeStruct Member Type '" << _.getIdName(member_type_id) + << "' is not a type."; + } + if (member_type->opcode() == SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Structures cannot contain a void type."; + } + if (SpvOpTypeStruct == member_type->opcode() && + 
_.IsStructTypeWithBuiltInMember(member_type_id)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Structure " << _.getIdName(member_type_id) + << " contains members with BuiltIn decoration. Therefore this " + << "structure may not be contained as a member of another " + << "structure " + << "type. Structure " << _.getIdName(struct_id) + << " contains structure " << _.getIdName(member_type_id) + << "."; + } + if (_.IsForwardPointer(member_type_id)) { + // If we're dealing with a forward pointer: + // Find out the type that the pointer is pointing to (must be struct) + // word 3 is the of the type being pointed to. + auto type_pointing_to = _.FindDef(member_type->words()[3]); + if (type_pointing_to && type_pointing_to->opcode() != SpvOpTypeStruct) { + // Forward declared operands of a struct may only point to a struct. + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "A forward reference operand in an OpTypeStruct must be an " + "OpTypePointer that points to an OpTypeStruct. " + "Found OpTypePointer that points to Op" + << spvOpcodeString( + static_cast(type_pointing_to->opcode())) + << "."; + } + } + + if ((spvIsVulkanEnv(_.context()->target_env) || + spvIsWebGPUEnv(_.context()->target_env)) && + member_type->opcode() == SpvOpTypeRuntimeArray) { + const bool is_last_member = + member_type_index == inst->operands().size() - 1; + if (!is_last_member) { + const char* env_text = + spvIsVulkanEnv(_.context()->target_env) ? 
"Vulkan" : "WebGPU"; + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "In " << env_text << ", OpTypeRuntimeArray must only be used " + << "for the last member of an OpTypeStruct"; + } + } + } + + std::unordered_set built_in_members; + for (auto decoration : _.id_decorations(struct_id)) { + if (decoration.dec_type() == SpvDecorationBuiltIn && + decoration.struct_member_index() != Decoration::kInvalidMember) { + built_in_members.insert(decoration.struct_member_index()); + } + } + int num_struct_members = static_cast(inst->operands().size() - 1); + int num_builtin_members = static_cast(built_in_members.size()); + if (num_builtin_members > 0 && num_builtin_members != num_struct_members) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "When BuiltIn decoration is applied to a structure-type member, " + << "all members of that structure type must also be decorated with " + << "BuiltIn (No allowed mixing of built-in variables and " + << "non-built-in variables within a single structure). Structure id " + << struct_id << " does not meet this requirement."; + } + if (num_builtin_members > 0) { + _.RegisterStructTypeWithBuiltInMember(struct_id); + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypePointer(ValidationState_t& _, + const Instruction* inst) { + const auto type_id = inst->GetOperandAs(2); + const auto type = _.FindDef(type_id); + if (!type || !spvOpcodeGeneratesType(type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypePointer Type '" << _.getIdName(type_id) + << "' is not a type."; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeFunction(ValidationState_t& _, + const Instruction* inst) { + const auto return_type_id = inst->GetOperandAs(1); + const auto return_type = _.FindDef(return_type_id); + if (!return_type || !spvOpcodeGeneratesType(return_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeFunction Return Type '" << _.getIdName(return_type_id) + << "' is not a type."; + } + size_t num_args = 0; + 
for (size_t param_type_index = 2; param_type_index < inst->operands().size(); + ++param_type_index, ++num_args) { + const auto param_id = inst->GetOperandAs(param_type_index); + const auto param_type = _.FindDef(param_id); + if (!param_type || !spvOpcodeGeneratesType(param_type->opcode())) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeFunction Parameter Type '" << _.getIdName(param_id) + << "' is not a type."; + } + + if (param_type->opcode() == SpvOpTypeVoid) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeFunction Parameter Type '" << _.getIdName(param_id) + << "' cannot be OpTypeVoid."; + } + } + const uint32_t num_function_args_limit = + _.options()->universal_limits_.max_function_args; + if (num_args > num_function_args_limit) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeFunction may not take more than " + << num_function_args_limit << " arguments. OpTypeFunction '" + << _.getIdName(inst->GetOperandAs(0)) << "' has " + << num_args << " arguments."; + } + + // The only valid uses of OpTypeFunction are in an OpFunction instruction. 
+ for (auto& pair : inst->uses()) { + const auto* use = pair.first; + if (use->opcode() != SpvOpFunction) { + return _.diag(SPV_ERROR_INVALID_ID, use) + << "Invalid use of function type result id " + << _.getIdName(inst->id()) << "."; + } + } + + return SPV_SUCCESS; +} + +spv_result_t ValidateTypeForwardPointer(ValidationState_t& _, + const Instruction* inst) { + const auto pointer_type_id = inst->GetOperandAs(0); + const auto pointer_type_inst = _.FindDef(pointer_type_id); + if (pointer_type_inst->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Pointer type in OpTypeForwardPointer is not a pointer type."; + } + + if (inst->GetOperandAs(1) != + pointer_type_inst->GetOperandAs(1)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Storage class in OpTypeForwardPointer does not match the " + << "pointer definition."; + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) { + if (!spvOpcodeGeneratesType(inst->opcode()) && + inst->opcode() != SpvOpTypeForwardPointer) { + return SPV_SUCCESS; + } + + if (auto error = ValidateUniqueness(_, inst)) return error; + + switch (inst->opcode()) { + case SpvOpTypeVector: + if (auto error = ValidateTypeVector(_, inst)) return error; + break; + case SpvOpTypeMatrix: + if (auto error = ValidateTypeMatrix(_, inst)) return error; + break; + case SpvOpTypeArray: + if (auto error = ValidateTypeArray(_, inst)) return error; + break; + case SpvOpTypeRuntimeArray: + if (auto error = ValidateTypeRuntimeArray(_, inst)) return error; + break; + case SpvOpTypeStruct: + if (auto error = ValidateTypeStruct(_, inst)) return error; + break; + case SpvOpTypePointer: + if (auto error = ValidateTypePointer(_, inst)) return error; + break; + case SpvOpTypeFunction: + if (auto error = ValidateTypeFunction(_, inst)) return error; + break; + case SpvOpTypeForwardPointer: + if (auto error = ValidateTypeForwardPointer(_, inst)) return error; + break; + 
default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp new file mode 100644 index 000000000..4a4293546 --- /dev/null +++ b/source/val/validation_state.cpp @@ -0,0 +1,1025 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/val/validation_state.h" + +#include +#include +#include + +#include "source/opcode.h" +#include "source/spirv_target_env.h" +#include "source/val/basic_block.h" +#include "source/val/construct.h" +#include "source/val/function.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace val { +namespace { + +bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) { + // See Section 2.4 + bool out = false; + // clang-format off + switch (layout) { + case kLayoutCapabilities: out = op == SpvOpCapability; break; + case kLayoutExtensions: out = op == SpvOpExtension; break; + case kLayoutExtInstImport: out = op == SpvOpExtInstImport; break; + case kLayoutMemoryModel: out = op == SpvOpMemoryModel; break; + case kLayoutEntryPoint: out = op == SpvOpEntryPoint; break; + case kLayoutExecutionMode: + out = op == SpvOpExecutionMode || op == SpvOpExecutionModeId; + break; + case kLayoutDebug1: + switch (op) { + case SpvOpSourceContinued: + case SpvOpSource: + case SpvOpSourceExtension: + case SpvOpString: + out = true; + break; + 
default: break; + } + break; + case kLayoutDebug2: + switch (op) { + case SpvOpName: + case SpvOpMemberName: + out = true; + break; + default: break; + } + break; + case kLayoutDebug3: + // Only OpModuleProcessed is allowed here. + out = (op == SpvOpModuleProcessed); + break; + case kLayoutAnnotations: + switch (op) { + case SpvOpDecorate: + case SpvOpMemberDecorate: + case SpvOpGroupDecorate: + case SpvOpGroupMemberDecorate: + case SpvOpDecorationGroup: + case SpvOpDecorateId: + case SpvOpDecorateStringGOOGLE: + case SpvOpMemberDecorateStringGOOGLE: + out = true; + break; + default: break; + } + break; + case kLayoutTypes: + if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) { + out = true; + break; + } + switch (op) { + case SpvOpTypeForwardPointer: + case SpvOpVariable: + case SpvOpLine: + case SpvOpNoLine: + case SpvOpUndef: + out = true; + break; + default: break; + } + break; + case kLayoutFunctionDeclarations: + case kLayoutFunctionDefinitions: + // NOTE: These instructions should NOT be in these layout sections + if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) { + out = false; + break; + } + switch (op) { + case SpvOpCapability: + case SpvOpExtension: + case SpvOpExtInstImport: + case SpvOpMemoryModel: + case SpvOpEntryPoint: + case SpvOpExecutionMode: + case SpvOpExecutionModeId: + case SpvOpSourceContinued: + case SpvOpSource: + case SpvOpSourceExtension: + case SpvOpString: + case SpvOpName: + case SpvOpMemberName: + case SpvOpModuleProcessed: + case SpvOpDecorate: + case SpvOpMemberDecorate: + case SpvOpGroupDecorate: + case SpvOpGroupMemberDecorate: + case SpvOpDecorationGroup: + case SpvOpTypeForwardPointer: + out = false; + break; + default: + out = true; + break; + } + } + // clang-format on + return out; +} + +// Counts the number of instructions and functions in the file. 
+spv_result_t CountInstructions(void* user_data, + const spv_parsed_instruction_t* inst) { + ValidationState_t& _ = *(reinterpret_cast(user_data)); + if (inst->opcode == SpvOpFunction) _.increment_total_functions(); + _.increment_total_instructions(); + + return SPV_SUCCESS; +} + +} // namespace + +ValidationState_t::ValidationState_t(const spv_const_context ctx, + const spv_const_validator_options opt, + const uint32_t* words, + const size_t num_words, + const uint32_t max_warnings) + : context_(ctx), + options_(opt), + words_(words), + num_words_(num_words), + unresolved_forward_ids_{}, + operand_names_{}, + current_layout_section_(kLayoutCapabilities), + module_functions_(), + module_capabilities_(), + module_extensions_(), + ordered_instructions_(), + all_definitions_(), + global_vars_(), + local_vars_(), + struct_nesting_depth_(), + grammar_(ctx), + addressing_model_(SpvAddressingModelMax), + memory_model_(SpvMemoryModelMax), + pointer_size_and_alignment_(0), + in_function_(false), + num_of_warnings_(0), + max_num_of_warnings_(max_warnings) { + assert(opt && "Validator options may not be Null."); + + const auto env = context_->target_env; + + if (spvIsVulkanEnv(env)) { + // Vulkan 1.1 includes VK_KHR_relaxed_block_layout in core. + if (env != SPV_ENV_VULKAN_1_0) { + features_.env_relaxed_block_layout = true; + } + } + + switch (env) { + case SPV_ENV_WEBGPU_0: + features_.bans_op_undef = true; + break; + default: + break; + } + + // Only attempt to count if we have words, otherwise let the other validation + // fail and generate an error. + if (num_words > 0) { + // Count the number of instructions in the binary. + // This parse should not produce any error messages. Hijack the context and + // replace the message consumer so that we do not pollute any state in input + // consumer. 
+ spv_context_t hijacked_context = *ctx; + hijacked_context.consumer = [](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}; + spvBinaryParse(&hijacked_context, this, words, num_words, + /* parsed_header = */ nullptr, CountInstructions, + /* diagnostic = */ nullptr); + preallocateStorage(); + } + + friendly_mapper_ = spvtools::MakeUnique( + context_, words_, num_words_); + name_mapper_ = friendly_mapper_->GetNameMapper(); +} + +void ValidationState_t::preallocateStorage() { + ordered_instructions_.reserve(total_instructions_); + module_functions_.reserve(total_functions_); +} + +spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) { + unresolved_forward_ids_.insert(id); + return SPV_SUCCESS; +} + +spv_result_t ValidationState_t::RemoveIfForwardDeclared(uint32_t id) { + unresolved_forward_ids_.erase(id); + return SPV_SUCCESS; +} + +spv_result_t ValidationState_t::RegisterForwardPointer(uint32_t id) { + forward_pointer_ids_.insert(id); + return SPV_SUCCESS; +} + +bool ValidationState_t::IsForwardPointer(uint32_t id) const { + return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end()); +} + +void ValidationState_t::AssignNameToId(uint32_t id, std::string name) { + operand_names_[id] = name; +} + +std::string ValidationState_t::getIdName(uint32_t id) const { + const std::string id_name = name_mapper_(id); + + std::stringstream out; + out << id << "[%" << id_name << "]"; + return out.str(); +} + +size_t ValidationState_t::unresolved_forward_id_count() const { + return unresolved_forward_ids_.size(); +} + +std::vector ValidationState_t::UnresolvedForwardIds() const { + std::vector out(std::begin(unresolved_forward_ids_), + std::end(unresolved_forward_ids_)); + return out; +} + +bool ValidationState_t::IsDefinedId(uint32_t id) const { + return all_definitions_.find(id) != std::end(all_definitions_); +} + +const Instruction* ValidationState_t::FindDef(uint32_t id) const { + auto it = all_definitions_.find(id); + if (it == 
all_definitions_.end()) return nullptr; + return it->second; +} + +Instruction* ValidationState_t::FindDef(uint32_t id) { + auto it = all_definitions_.find(id); + if (it == all_definitions_.end()) return nullptr; + return it->second; +} + +ModuleLayoutSection ValidationState_t::current_layout_section() const { + return current_layout_section_; +} + +void ValidationState_t::ProgressToNextLayoutSectionOrder() { + // Guard against going past the last element(kLayoutFunctionDefinitions) + if (current_layout_section_ <= kLayoutFunctionDefinitions) { + current_layout_section_ = + static_cast(current_layout_section_ + 1); + } +} + +bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) { + return IsInstructionInLayoutSection(current_layout_section_, op); +} + +DiagnosticStream ValidationState_t::diag(spv_result_t error_code, + const Instruction* inst) { + if (error_code == SPV_WARNING) { + if (num_of_warnings_ == max_num_of_warnings_) { + DiagnosticStream({0, 0, 0}, context_->consumer, "", error_code) + << "Other warnings have been suppressed.\n"; + } + if (num_of_warnings_ >= max_num_of_warnings_) { + return DiagnosticStream({0, 0, 0}, nullptr, "", error_code); + } + ++num_of_warnings_; + } + + std::string disassembly; + if (inst) disassembly = Disassemble(*inst); + + return DiagnosticStream({0, 0, inst ? 
inst->LineNum() : 0}, + context_->consumer, disassembly, error_code); +} + +std::vector& ValidationState_t::functions() { + return module_functions_; +} + +Function& ValidationState_t::current_function() { + assert(in_function_body()); + return module_functions_.back(); +} + +const Function& ValidationState_t::current_function() const { + assert(in_function_body()); + return module_functions_.back(); +} + +const Function* ValidationState_t::function(uint32_t id) const { + const auto it = id_to_function_.find(id); + if (it == id_to_function_.end()) return nullptr; + return it->second; +} + +Function* ValidationState_t::function(uint32_t id) { + auto it = id_to_function_.find(id); + if (it == id_to_function_.end()) return nullptr; + return it->second; +} + +bool ValidationState_t::in_function_body() const { return in_function_; } + +bool ValidationState_t::in_block() const { + return module_functions_.empty() == false && + module_functions_.back().current_block() != nullptr; +} + +void ValidationState_t::RegisterCapability(SpvCapability cap) { + // Avoid redundant work. Otherwise the recursion could induce work + // quadratic in the capability dependency depth. (Ok, not much, but + // it's something.) 
+ if (module_capabilities_.Contains(cap)) return; + + module_capabilities_.Add(cap); + spv_operand_desc desc; + if (SPV_SUCCESS == + grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) { + CapabilitySet(desc->numCapabilities, desc->capabilities) + .ForEach([this](SpvCapability c) { RegisterCapability(c); }); + } + + switch (cap) { + case SpvCapabilityKernel: + features_.group_ops_reduce_and_scans = true; + break; + case SpvCapabilityInt8: + features_.use_int8_type = true; + features_.declare_int8_type = true; + break; + case SpvCapabilityStorageBuffer8BitAccess: + case SpvCapabilityUniformAndStorageBuffer8BitAccess: + case SpvCapabilityStoragePushConstant8: + features_.declare_int8_type = true; + break; + case SpvCapabilityInt16: + features_.declare_int16_type = true; + break; + case SpvCapabilityFloat16: + case SpvCapabilityFloat16Buffer: + features_.declare_float16_type = true; + break; + case SpvCapabilityStorageUniformBufferBlock16: + case SpvCapabilityStorageUniform16: + case SpvCapabilityStoragePushConstant16: + case SpvCapabilityStorageInputOutput16: + features_.declare_int16_type = true; + features_.declare_float16_type = true; + features_.free_fp_rounding_mode = true; + break; + case SpvCapabilityVariablePointers: + features_.variable_pointers = true; + features_.variable_pointers_storage_buffer = true; + break; + case SpvCapabilityVariablePointersStorageBuffer: + features_.variable_pointers_storage_buffer = true; + break; + default: + break; + } +} + +void ValidationState_t::RegisterExtension(Extension ext) { + if (module_extensions_.Contains(ext)) return; + + module_extensions_.Add(ext); + + switch (ext) { + case kSPV_AMD_gpu_shader_half_float: + case kSPV_AMD_gpu_shader_half_float_fetch: + // SPV_AMD_gpu_shader_half_float enables float16 type. 
+ // https://github.com/KhronosGroup/SPIRV-Tools/issues/1375 + features_.declare_float16_type = true; + break; + case kSPV_AMD_gpu_shader_int16: + // This is not yet in the extension, but it's recommended for it. + // See https://github.com/KhronosGroup/glslang/issues/848 + features_.uconvert_spec_constant_op = true; + break; + case kSPV_AMD_shader_ballot: + // The grammar doesn't encode the fact that SPV_AMD_shader_ballot + // enables the use of group operations Reduce, InclusiveScan, + // and ExclusiveScan. Enable it manually. + // https://github.com/KhronosGroup/SPIRV-Tools/issues/991 + features_.group_ops_reduce_and_scans = true; + break; + default: + break; + } +} + +bool ValidationState_t::HasAnyOfCapabilities( + const CapabilitySet& capabilities) const { + return module_capabilities_.HasAnyOf(capabilities); +} + +bool ValidationState_t::HasAnyOfExtensions( + const ExtensionSet& extensions) const { + return module_extensions_.HasAnyOf(extensions); +} + +void ValidationState_t::set_addressing_model(SpvAddressingModel am) { + addressing_model_ = am; + switch (am) { + case SpvAddressingModelPhysical32: + pointer_size_and_alignment_ = 4; + break; + default: + // fall through + case SpvAddressingModelPhysical64: + case SpvAddressingModelPhysicalStorageBuffer64EXT: + pointer_size_and_alignment_ = 8; + break; + } +} + +SpvAddressingModel ValidationState_t::addressing_model() const { + return addressing_model_; +} + +void ValidationState_t::set_memory_model(SpvMemoryModel mm) { + memory_model_ = mm; +} + +SpvMemoryModel ValidationState_t::memory_model() const { return memory_model_; } + +spv_result_t ValidationState_t::RegisterFunction( + uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control, + uint32_t function_type_id) { + assert(in_function_body() == false && + "RegisterFunction can only be called when parsing the binary outside " + "of another function"); + in_function_ = true; + module_functions_.emplace_back(id, ret_type_id, 
function_control, + function_type_id); + id_to_function_.emplace(id, &current_function()); + + // TODO(umar): validate function type and type_id + + return SPV_SUCCESS; +} + +spv_result_t ValidationState_t::RegisterFunctionEnd() { + assert(in_function_body() == true && + "RegisterFunctionEnd can only be called when parsing the binary " + "inside of another function"); + assert(in_block() == false && + "RegisterFunctionParameter can only be called when parsing the binary " + "ouside of a block"); + current_function().RegisterFunctionEnd(); + in_function_ = false; + return SPV_SUCCESS; +} + +Instruction* ValidationState_t::AddOrderedInstruction( + const spv_parsed_instruction_t* inst) { + ordered_instructions_.emplace_back(inst); + ordered_instructions_.back().SetLineNum(ordered_instructions_.size()); + return &ordered_instructions_.back(); +} + +// Improves diagnostic messages by collecting names of IDs +void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) { + switch (inst->opcode()) { + case SpvOpName: { + const auto target = inst->GetOperandAs(0); + const auto* str = reinterpret_cast(inst->words().data() + + inst->operand(1).offset); + AssignNameToId(target, str); + break; + } + case SpvOpMemberName: { + const auto target = inst->GetOperandAs(0); + const auto* str = reinterpret_cast(inst->words().data() + + inst->operand(2).offset); + AssignNameToId(target, str); + break; + } + case SpvOpSourceContinued: + case SpvOpSource: + case SpvOpSourceExtension: + case SpvOpString: + case SpvOpLine: + case SpvOpNoLine: + default: + break; + } +} + +void ValidationState_t::RegisterInstruction(Instruction* inst) { + if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst)); + + // If the instruction is using an OpTypeSampledImage as an operand, it should + // be recorded. The validator will ensure that all usages of an + // OpTypeSampledImage and its definition are in the same basic block. 
+ for (uint16_t i = 0; i < inst->operands().size(); ++i) { + const spv_parsed_operand_t& operand = inst->operand(i); + if (SPV_OPERAND_TYPE_ID == operand.type) { + const uint32_t operand_word = inst->word(operand.offset); + Instruction* operand_inst = FindDef(operand_word); + if (operand_inst && SpvOpSampledImage == operand_inst->opcode()) { + RegisterSampledImageConsumer(operand_word, inst->id()); + } + } + } +} + +std::vector ValidationState_t::getSampledImageConsumers( + uint32_t sampled_image_id) const { + std::vector result; + auto iter = sampled_image_consumers_.find(sampled_image_id); + if (iter != sampled_image_consumers_.end()) { + result = iter->second; + } + return result; +} + +void ValidationState_t::RegisterSampledImageConsumer(uint32_t sampled_image_id, + uint32_t consumer_id) { + sampled_image_consumers_[sampled_image_id].push_back(consumer_id); +} + +uint32_t ValidationState_t::getIdBound() const { return id_bound_; } + +void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; } + +bool ValidationState_t::RegisterUniqueTypeDeclaration(const Instruction* inst) { + std::vector key; + key.push_back(static_cast(inst->opcode())); + for (size_t index = 0; index < inst->operands().size(); ++index) { + const spv_parsed_operand_t& operand = inst->operand(index); + + if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue; + + const int words_begin = operand.offset; + const int words_end = words_begin + operand.num_words; + assert(words_end <= static_cast(inst->words().size())); + + key.insert(key.end(), inst->words().begin() + words_begin, + inst->words().begin() + words_end); + } + + return unique_type_declarations_.insert(std::move(key)).second; +} + +uint32_t ValidationState_t::GetTypeId(uint32_t id) const { + const Instruction* inst = FindDef(id); + return inst ? inst->type_id() : 0; +} + +SpvOp ValidationState_t::GetIdOpcode(uint32_t id) const { + const Instruction* inst = FindDef(id); + return inst ? 
inst->opcode() : SpvOpNop; +} + +uint32_t ValidationState_t::GetComponentType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + switch (inst->opcode()) { + case SpvOpTypeFloat: + case SpvOpTypeInt: + case SpvOpTypeBool: + return id; + + case SpvOpTypeVector: + return inst->word(2); + + case SpvOpTypeMatrix: + return GetComponentType(inst->word(2)); + + default: + break; + } + + if (inst->type_id()) return GetComponentType(inst->type_id()); + + assert(0); + return 0; +} + +uint32_t ValidationState_t::GetDimension(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + switch (inst->opcode()) { + case SpvOpTypeFloat: + case SpvOpTypeInt: + case SpvOpTypeBool: + return 1; + + case SpvOpTypeVector: + case SpvOpTypeMatrix: + return inst->word(3); + + default: + break; + } + + if (inst->type_id()) return GetDimension(inst->type_id()); + + assert(0); + return 0; +} + +uint32_t ValidationState_t::GetBitWidth(uint32_t id) const { + const uint32_t component_type_id = GetComponentType(id); + const Instruction* inst = FindDef(component_type_id); + assert(inst); + + if (inst->opcode() == SpvOpTypeFloat || inst->opcode() == SpvOpTypeInt) + return inst->word(2); + + if (inst->opcode() == SpvOpTypeBool) return 1; + + assert(0); + return 0; +} + +bool ValidationState_t::IsFloatScalarType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + return inst->opcode() == SpvOpTypeFloat; +} + +bool ValidationState_t::IsFloatVectorType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeVector) { + return IsFloatScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeFloat) { + return true; + } + + if (inst->opcode() == SpvOpTypeVector) { + return 
IsFloatScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::IsIntScalarType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + return inst->opcode() == SpvOpTypeInt; +} + +bool ValidationState_t::IsIntVectorType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeVector) { + return IsIntScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeInt) { + return true; + } + + if (inst->opcode() == SpvOpTypeVector) { + return IsIntScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + return inst->opcode() == SpvOpTypeInt && inst->word(3) == 0; +} + +bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeVector) { + return IsUnsignedIntScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::IsSignedIntScalarType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + return inst->opcode() == SpvOpTypeInt && inst->word(3) == 1; +} + +bool ValidationState_t::IsSignedIntVectorType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeVector) { + return IsSignedIntScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::IsBoolScalarType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + return inst->opcode() == SpvOpTypeBool; +} + +bool ValidationState_t::IsBoolVectorType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == 
SpvOpTypeVector) { + return IsBoolScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::IsBoolScalarOrVectorType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeBool) { + return true; + } + + if (inst->opcode() == SpvOpTypeVector) { + return IsBoolScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::IsFloatMatrixType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeMatrix) { + return IsFloatScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::GetMatrixTypeInfo(uint32_t id, uint32_t* num_rows, + uint32_t* num_cols, + uint32_t* column_type, + uint32_t* component_type) const { + if (!id) return false; + + const Instruction* mat_inst = FindDef(id); + assert(mat_inst); + if (mat_inst->opcode() != SpvOpTypeMatrix) return false; + + const uint32_t vec_type = mat_inst->word(2); + const Instruction* vec_inst = FindDef(vec_type); + assert(vec_inst); + + if (vec_inst->opcode() != SpvOpTypeVector) { + assert(0); + return false; + } + + *num_cols = mat_inst->word(3); + *num_rows = vec_inst->word(3); + *column_type = mat_inst->word(2); + *component_type = vec_inst->word(2); + + return true; +} + +bool ValidationState_t::GetStructMemberTypes( + uint32_t struct_type_id, std::vector* member_types) const { + member_types->clear(); + if (!struct_type_id) return false; + + const Instruction* inst = FindDef(struct_type_id); + assert(inst); + if (inst->opcode() != SpvOpTypeStruct) return false; + + *member_types = + std::vector(inst->words().cbegin() + 2, inst->words().cend()); + + if (member_types->empty()) return false; + + return true; +} + +bool ValidationState_t::IsPointerType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + return inst->opcode() == SpvOpTypePointer; +} + +bool ValidationState_t::GetPointerTypeInfo(uint32_t id, 
uint32_t* data_type, + uint32_t* storage_class) const { + if (!id) return false; + + const Instruction* inst = FindDef(id); + assert(inst); + if (inst->opcode() != SpvOpTypePointer) return false; + + *storage_class = inst->word(2); + *data_type = inst->word(3); + return true; +} + +uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst, + size_t operand_index) const { + return GetTypeId(inst->GetOperandAs(operand_index)); +} + +bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const { + const Instruction* inst = FindDef(id); + if (!inst) { + assert(0 && "Instruction not found"); + return false; + } + + if (inst->opcode() != SpvOpConstant && inst->opcode() != SpvOpSpecConstant) + return false; + + if (!IsIntScalarType(inst->type_id())) return false; + + if (inst->words().size() == 4) { + *val = inst->word(3); + } else { + assert(inst->words().size() == 5); + *val = inst->word(3); + *val |= uint64_t(inst->word(4)) << 32; + } + return true; +} + +std::tuple ValidationState_t::EvalInt32IfConst( + uint32_t id) { + const Instruction* const inst = FindDef(id); + assert(inst); + const uint32_t type = inst->type_id(); + + if (type == 0 || !IsIntScalarType(type) || GetBitWidth(type) != 32) { + return std::make_tuple(false, false, 0); + } + + // Spec constant values cannot be evaluated so don't consider constant for + // the purpose of this method. 
+ if (!spvOpcodeIsConstant(inst->opcode()) || + spvOpcodeIsSpecConstant(inst->opcode())) { + return std::make_tuple(true, false, 0); + } + + if (inst->opcode() == SpvOpConstantNull) { + return std::make_tuple(true, true, 0); + } + + assert(inst->words().size() == 4); + return std::make_tuple(true, true, inst->word(3)); +} + +void ValidationState_t::ComputeFunctionToEntryPointMapping() { + for (const uint32_t entry_point : entry_points()) { + std::stack call_stack; + std::set visited; + call_stack.push(entry_point); + while (!call_stack.empty()) { + const uint32_t called_func_id = call_stack.top(); + call_stack.pop(); + if (!visited.insert(called_func_id).second) continue; + + function_to_entry_points_[called_func_id].push_back(entry_point); + + const Function* called_func = function(called_func_id); + if (called_func) { + // Other checks should error out on this invalid SPIR-V. + for (const uint32_t new_call : called_func->function_call_targets()) { + call_stack.push(new_call); + } + } + } + } +} + +void ValidationState_t::ComputeRecursiveEntryPoints() { + for (const Function func : functions()) { + std::stack call_stack; + std::set visited; + + for (const uint32_t new_call : func.function_call_targets()) { + call_stack.push(new_call); + } + + while (!call_stack.empty()) { + const uint32_t called_func_id = call_stack.top(); + call_stack.pop(); + + if (!visited.insert(called_func_id).second) continue; + + if (called_func_id == func.id()) { + for (const uint32_t entry_point : + function_to_entry_points_[called_func_id]) + recursive_entry_points_.insert(entry_point); + break; + } + + const Function* called_func = function(called_func_id); + if (called_func) { + // Other checks should error out on this invalid SPIR-V. 
+ for (const uint32_t new_call : called_func->function_call_targets()) { + call_stack.push(new_call); + } + } + } + } +} + +const std::vector& ValidationState_t::FunctionEntryPoints( + uint32_t func) const { + auto iter = function_to_entry_points_.find(func); + if (iter == function_to_entry_points_.end()) { + return empty_ids_; + } else { + return iter->second; + } +} + +std::set ValidationState_t::EntryPointReferences(uint32_t id) const { + std::set referenced_entry_points; + const auto inst = FindDef(id); + if (!inst) return referenced_entry_points; + + std::vector stack; + stack.push_back(inst); + while (!stack.empty()) { + const auto current_inst = stack.back(); + stack.pop_back(); + + if (const auto func = current_inst->function()) { + // Instruction lives in a function, we can stop searching. + const auto function_entry_points = FunctionEntryPoints(func->id()); + referenced_entry_points.insert(function_entry_points.begin(), + function_entry_points.end()); + } else { + // Instruction is in the global scope, keep searching its uses. 
+ for (auto pair : current_inst->uses()) { + const auto next_inst = pair.first; + stack.push_back(next_inst); + } + } + } + + return referenced_entry_points; +} + +std::string ValidationState_t::Disassemble(const Instruction& inst) const { + const spv_parsed_instruction_t& c_inst(inst.c_inst()); + return Disassemble(c_inst.words, c_inst.num_words); +} + +std::string ValidationState_t::Disassemble(const uint32_t* words, + uint16_t num_words) const { + uint32_t disassembly_options = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES; + + return spvInstructionBinaryToText(context()->target_env, words, num_words, + words_, num_words_, disassembly_options); +} + +} // namespace val +} // namespace spvtools diff --git a/source/val/validation_state.h b/source/val/validation_state.h new file mode 100644 index 000000000..2c94906b9 --- /dev/null +++ b/source/val/validation_state.h @@ -0,0 +1,716 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_VAL_VALIDATION_STATE_H_ +#define SOURCE_VAL_VALIDATION_STATE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/diagnostic.h" +#include "source/disassemble.h" +#include "source/enum_set.h" +#include "source/latest_version_spirv_header.h" +#include "source/name_mapper.h" +#include "source/spirv_definition.h" +#include "source/spirv_validator_options.h" +#include "source/val/decoration.h" +#include "source/val/function.h" +#include "source/val/instruction.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace val { + +/// This enum represents the sections of a SPIRV module. See section 2.4 +/// of the SPIRV spec for additional details of the order. The enumerant values +/// are in the same order as the vector returned by GetModuleOrder +enum ModuleLayoutSection { + kLayoutCapabilities, /// < Section 2.4 #1 + kLayoutExtensions, /// < Section 2.4 #2 + kLayoutExtInstImport, /// < Section 2.4 #3 + kLayoutMemoryModel, /// < Section 2.4 #4 + kLayoutEntryPoint, /// < Section 2.4 #5 + kLayoutExecutionMode, /// < Section 2.4 #6 + kLayoutDebug1, /// < Section 2.4 #7 > 1 + kLayoutDebug2, /// < Section 2.4 #7 > 2 + kLayoutDebug3, /// < Section 2.4 #7 > 3 + kLayoutAnnotations, /// < Section 2.4 #8 + kLayoutTypes, /// < Section 2.4 #9 + kLayoutFunctionDeclarations, /// < Section 2.4 #10 + kLayoutFunctionDefinitions /// < Section 2.4 #11 +}; + +/// This class manages the state of the SPIR-V validation as it is being parsed. +class ValidationState_t { + public: + // Features that can optionally be turned on by a capability or environment. + struct Feature { + bool declare_int16_type = false; // Allow OpTypeInt with 16 bit width? + bool declare_float16_type = false; // Allow OpTypeFloat with 16 bit width? 
+ bool free_fp_rounding_mode = false; // Allow the FPRoundingMode decoration + // and its values to be used without + // requiring any capability + + // Allow functionalities enabled by VariablePointers capability. + bool variable_pointers = false; + // Allow functionalities enabled by VariablePointersStorageBuffer + // capability. + bool variable_pointers_storage_buffer = false; + + // Permit group operations Reduce, InclusiveScan, ExclusiveScan + bool group_ops_reduce_and_scans = false; + + // Disallows the use of OpUndef + bool bans_op_undef = false; + + // Allow OpTypeInt with 8 bit width? + bool declare_int8_type = false; + + // Target environment uses relaxed block layout. + // This is true for Vulkan 1.1 or later. + bool env_relaxed_block_layout = false; + + // Allow an OpTypeInt with 8 bit width to be used in more than just int + // conversion opcodes + bool use_int8_type = false; + + // Use scalar block layout. See VK_EXT_scalar_block_layout: + // Defines scalar alignment: + // - scalar alignment equals the scalar size in bytes + // - array alignment is same as its element alignment + // - struct alignment is max alignment of any of its members + // - vector alignment is same as component alignment + // - matrix alignment is same as component alignment + // For struct in Uniform, StorageBuffer, PushConstant: + // - Offset of a member is multiple of scalar alignment of that member + // - ArrayStride and MatrixStride are multiples of scalar alignment + // Members need not be listed in offset order + bool scalar_block_layout = false; + + // Permit UConvert as an OpSpecConstantOp operation. + // The Kernel capability already enables it, separately from this flag. 
+ bool uconvert_spec_constant_op = false; + }; + + ValidationState_t(const spv_const_context context, + const spv_const_validator_options opt, + const uint32_t* words, const size_t num_words, + const uint32_t max_warnings); + + /// Returns the context + spv_const_context context() const { return context_; } + + /// Returns the command line options + spv_const_validator_options options() const { return options_; } + + /// Sets the ID of the generator for this module. + void setGenerator(uint32_t gen) { generator_ = gen; } + + /// Returns the ID of the generator for this module. + uint32_t generator() const { return generator_; } + + /// Sets the SPIR-V version of this module. + void setVersion(uint32_t ver) { version_ = ver; } + + /// Gets the SPIR-V version of this module. + uint32_t version() const { return version_; } + + /// Forward declares the id in the module + spv_result_t ForwardDeclareId(uint32_t id); + + /// Removes a forward declared ID if it has been defined + spv_result_t RemoveIfForwardDeclared(uint32_t id); + + /// Registers an ID as a forward pointer + spv_result_t RegisterForwardPointer(uint32_t id); + + /// Returns whether or not an ID is a forward pointer + bool IsForwardPointer(uint32_t id) const; + + /// Assigns a name to an ID + void AssignNameToId(uint32_t id, std::string name); + + /// Returns a string representation of the ID in the format <id>[Name] where + /// the <id> is the numeric value of the id and the Name is a name assigned by + /// the OpName instruction + std::string getIdName(uint32_t id) const; + + /// Accessor function for ID bound. + uint32_t getIdBound() const; + + /// Mutator function for ID bound. + void setIdBound(uint32_t bound); + + /// Returns the number of ID which have been forward referenced but not + /// defined + size_t unresolved_forward_id_count() const; + + /// Returns a vector of unresolved forward ids. 
+ std::vector UnresolvedForwardIds() const; + + /// Returns true if the id has been defined + bool IsDefinedId(uint32_t id) const; + + /// Increments the total number of instructions in the file. + void increment_total_instructions() { total_instructions_++; } + + /// Increments the total number of functions in the file. + void increment_total_functions() { total_functions_++; } + + /// Allocates internal storage. Note, calling this will invalidate any + /// pointers to |ordered_instructions_| or |module_functions_| and, hence, + /// should only be called at the beginning of validation. + void preallocateStorage(); + + /// Returns the current layout section which is being processed + ModuleLayoutSection current_layout_section() const; + + /// Increments the module_layout_order_section_ + void ProgressToNextLayoutSectionOrder(); + + /// Determines if the op instruction is part of the current section + bool IsOpcodeInCurrentLayoutSection(SpvOp op); + + DiagnosticStream diag(spv_result_t error_code, const Instruction* inst); + + /// Returns the function states + std::vector& functions(); + + /// Returns the function states + Function& current_function(); + const Function& current_function() const; + + /// Returns function state with the given id, or nullptr if no such function. + const Function* function(uint32_t id) const; + Function* function(uint32_t id); + + /// Returns true if the called after a function instruction but before the + /// function end instruction + bool in_function_body() const; + + /// Returns true if called after a label instruction but before a branch + /// instruction + bool in_block() const; + + struct EntryPointDescription { + std::string name; + std::vector interfaces; + }; + + /// Registers |id| as an entry point with |execution_model| and |interfaces|. 
+ void RegisterEntryPoint(const uint32_t id, SpvExecutionModel execution_model, + EntryPointDescription&& desc) { + entry_points_.push_back(id); + entry_point_to_execution_models_[id].insert(execution_model); + entry_point_descriptions_[id].emplace_back(desc); + } + + /// Returns a list of entry point function ids + const std::vector& entry_points() const { return entry_points_; } + + /// Returns the set of entry points that root call graphs that contain + /// recursion. + const std::set& recursive_entry_points() const { + return recursive_entry_points_; + } + + /// Registers execution mode for the given entry point. + void RegisterExecutionModeForEntryPoint(uint32_t entry_point, + SpvExecutionMode execution_mode) { + entry_point_to_execution_modes_[entry_point].insert(execution_mode); + } + + /// Returns the interface descriptions of a given entry point. + const std::vector& entry_point_descriptions( + uint32_t entry_point) { + return entry_point_descriptions_.at(entry_point); + } + + /// Returns Execution Models for the given Entry Point. + /// Returns nullptr if none found (would trigger assertion). + const std::set* GetExecutionModels( + uint32_t entry_point) const { + const auto it = entry_point_to_execution_models_.find(entry_point); + if (it == entry_point_to_execution_models_.end()) { + assert(0); + return nullptr; + } + return &it->second; + } + + /// Returns Execution Modes for the given Entry Point. + /// Returns nullptr if none found. + const std::set* GetExecutionModes( + uint32_t entry_point) const { + const auto it = entry_point_to_execution_modes_.find(entry_point); + if (it == entry_point_to_execution_modes_.end()) { + return nullptr; + } + return &it->second; + } + + /// Traverses call tree and computes function_to_entry_points_. + /// Note: called after fully parsing the binary. + void ComputeFunctionToEntryPointMapping(); + + /// Traverse call tree and computes recursive_entry_points_. 
+ /// Note: called after fully parsing the binary and calling + /// ComputeFunctionToEntryPointMapping. + void ComputeRecursiveEntryPoints(); + + /// Returns all the entry points that can call |func|. + const std::vector& FunctionEntryPoints(uint32_t func) const; + + /// Returns all the entry points that statically use |id|. + /// + /// Note: requires ComputeFunctionToEntryPointMapping to have been called. + std::set EntryPointReferences(uint32_t id) const; + + /// Inserts an to the set of functions that are target of OpFunctionCall. + void AddFunctionCallTarget(const uint32_t id) { + function_call_targets_.insert(id); + current_function().AddFunctionCallTarget(id); + } + + /// Returns whether or not a function is the target of OpFunctionCall. + bool IsFunctionCallTarget(const uint32_t id) { + return (function_call_targets_.find(id) != function_call_targets_.end()); + } + + bool IsFunctionCallDefined(const uint32_t id) { + return (id_to_function_.find(id) != id_to_function_.end()); + } + /// Registers the capability and its dependent capabilities + void RegisterCapability(SpvCapability cap); + + /// Registers the extension. + void RegisterExtension(Extension ext); + + /// Registers the function in the module. Subsequent instructions will be + /// called against this function + spv_result_t RegisterFunction(uint32_t id, uint32_t ret_type_id, + SpvFunctionControlMask function_control, + uint32_t function_type_id); + + /// Register a function end instruction + spv_result_t RegisterFunctionEnd(); + + /// Returns true if the capability is enabled in the module. + bool HasCapability(SpvCapability cap) const { + return module_capabilities_.Contains(cap); + } + + /// Returns true if the extension is enabled in the module. + bool HasExtension(Extension ext) const { + return module_extensions_.Contains(ext); + } + + /// Returns true if any of the capabilities is enabled, or if |capabilities| + /// is an empty set. 
+ bool HasAnyOfCapabilities(const CapabilitySet& capabilities) const; + + /// Returns true if any of the extensions is enabled, or if |extensions| + /// is an empty set. + bool HasAnyOfExtensions(const ExtensionSet& extensions) const; + + /// Sets the addressing model of this module (logical/physical). + void set_addressing_model(SpvAddressingModel am); + + /// Returns true if the OpMemoryModel was found. + bool has_memory_model_specified() const { + return addressing_model_ != SpvAddressingModelMax && + memory_model_ != SpvMemoryModelMax; + } + + /// Returns the addressing model of this module, or Logical if uninitialized. + SpvAddressingModel addressing_model() const; + + /// Returns the pointer size and alignment derived from the addressing model. + uint32_t pointer_size_and_alignment() const { + return pointer_size_and_alignment_; + } + + /// Sets the memory model of this module. + void set_memory_model(SpvMemoryModel mm); + + /// Returns the memory model of this module, or Simple if uninitialized. + SpvMemoryModel memory_model() const; + + const AssemblyGrammar& grammar() const { return grammar_; } + + /// Inserts the instruction into the list of ordered instructions in the file. + Instruction* AddOrderedInstruction(const spv_parsed_instruction_t* inst); + + /// Registers the instruction. This will add the instruction to the list of + /// definitions and register sampled image consumers. + void RegisterInstruction(Instruction* inst); + + /// Registers the debug instruction information. 
+ void RegisterDebugInstruction(const Instruction* inst); + + /// Registers the decoration for the given + void RegisterDecorationForId(uint32_t id, const Decoration& dec) { + id_decorations_[id].push_back(dec); + } + + /// Registers the list of decorations for the given + template + void RegisterDecorationsForId(uint32_t id, InputIt begin, InputIt end) { + std::vector& cur_decs = id_decorations_[id]; + cur_decs.insert(cur_decs.end(), begin, end); + } + + /// Registers the list of decorations for the given member of the given + /// structure. + template + void RegisterDecorationsForStructMember(uint32_t struct_id, + uint32_t member_index, InputIt begin, + InputIt end) { + RegisterDecorationsForId(struct_id, begin, end); + for (auto& decoration : id_decorations_[struct_id]) { + decoration.set_struct_member_index(member_index); + } + } + + /// Returns all the decorations for the given . If no decorations exist + /// for the , it registers an empty vector for it in the map and + /// returns the empty vector. + std::vector& id_decorations(uint32_t id) { + return id_decorations_[id]; + } + + // Returns const pointer to the internal decoration container. + const std::map>& id_decorations() const { + return id_decorations_; + } + + /// Returns true if the given id has the given decoration , + /// otherwise returns false. + bool HasDecoration(uint32_t id, SpvDecoration dec) { + const auto& decorations = id_decorations_.find(id); + if (decorations == id_decorations_.end()) return false; + + return std::any_of( + decorations->second.begin(), decorations->second.end(), + [dec](const Decoration& d) { return dec == d.dec_type(); }); + } + + /// Finds id's def, if it exists. If found, returns the definition otherwise + /// nullptr + const Instruction* FindDef(uint32_t id) const; + + /// Finds id's def, if it exists. 
If found, returns the definition otherwise + /// nullptr + Instruction* FindDef(uint32_t id); + + /// Returns the instructions in the order they appear in the binary + const std::vector& ordered_instructions() const { + return ordered_instructions_; + } + + /// Returns a map of instructions mapped by their result id + const std::unordered_map& all_definitions() const { + return all_definitions_; + } + + /// Returns a vector containing the Ids of instructions that consume the given + /// SampledImage id. + std::vector getSampledImageConsumers(uint32_t id) const; + + /// Records cons_id as a consumer of sampled_image_id. + void RegisterSampledImageConsumer(uint32_t sampled_image_id, + uint32_t cons_id); + + /// Returns the set of Global Variables. + std::unordered_set& global_vars() { return global_vars_; } + + /// Returns the set of Local Variables. + std::unordered_set& local_vars() { return local_vars_; } + + /// Returns the number of Global Variables. + size_t num_global_vars() { return global_vars_.size(); } + + /// Returns the number of Local Variables. + size_t num_local_vars() { return local_vars_.size(); } + + /// Inserts a new to the set of Global Variables. + void registerGlobalVariable(const uint32_t id) { global_vars_.insert(id); } + + /// Inserts a new to the set of Local Variables. + void registerLocalVariable(const uint32_t id) { local_vars_.insert(id); } + + // Returns true if using relaxed block layout, equivalent to + // VK_KHR_relaxed_block_layout. + bool IsRelaxedBlockLayout() const { + return features_.env_relaxed_block_layout || options()->relax_block_layout; + } + + /// Sets the struct nesting depth for a given struct ID + void set_struct_nesting_depth(uint32_t id, uint32_t depth) { + struct_nesting_depth_[id] = depth; + } + + /// Returns the nesting depth of a given structure ID + uint32_t struct_nesting_depth(uint32_t id) { + return struct_nesting_depth_[id]; + } + + /// Records that the structure type has a member decorated with a built-in. 
+ void RegisterStructTypeWithBuiltInMember(uint32_t id) { + builtin_structs_.insert(id); + } + + /// Returns true if the struct type with the given Id has a BuiltIn member. + bool IsStructTypeWithBuiltInMember(uint32_t id) const { + return (builtin_structs_.find(id) != builtin_structs_.end()); + } + + // Returns the state of optional features. + const Feature& features() const { return features_; } + + /// Adds the instruction data to unique_type_declarations_. + /// Returns false if an identical type declaration already exists. + bool RegisterUniqueTypeDeclaration(const Instruction* inst); + + // Returns type_id of the scalar component of |id|. + // |id| can be either + // - scalar, vector or matrix type + // - object of either scalar, vector or matrix type + uint32_t GetComponentType(uint32_t id) const; + + // Returns + // - 1 for scalar types or objects + // - vector size for vector types or objects + // - num columns for matrix types or objects + // Should not be called with any other arguments (will return zero and invoke + // assertion). + uint32_t GetDimension(uint32_t id) const; + + // Returns bit width of scalar or component. + // |id| can be + // - scalar, vector or matrix type + // - object of either scalar, vector or matrix type + // Will invoke assertion and return 0 if |id| is none of the above. + uint32_t GetBitWidth(uint32_t id) const; + + // Provides detailed information on matrix type. + // Returns false iff |id| is not matrix type. + bool GetMatrixTypeInfo(uint32_t id, uint32_t* num_rows, uint32_t* num_cols, + uint32_t* column_type, uint32_t* component_type) const; + + // Collects struct member types into |member_types|. + // Returns false iff not struct type or has no members. + // Deletes prior contents of |member_types|. + bool GetStructMemberTypes(uint32_t struct_type_id, + std::vector* member_types) const; + + // Returns true iff |id| is a type corresponding to the name of the function. + // Only works for types not for objects. 
+ bool IsFloatScalarType(uint32_t id) const; + bool IsFloatVectorType(uint32_t id) const; + bool IsFloatScalarOrVectorType(uint32_t id) const; + bool IsFloatMatrixType(uint32_t id) const; + bool IsIntScalarType(uint32_t id) const; + bool IsIntVectorType(uint32_t id) const; + bool IsIntScalarOrVectorType(uint32_t id) const; + bool IsUnsignedIntScalarType(uint32_t id) const; + bool IsUnsignedIntVectorType(uint32_t id) const; + bool IsSignedIntScalarType(uint32_t id) const; + bool IsSignedIntVectorType(uint32_t id) const; + bool IsBoolScalarType(uint32_t id) const; + bool IsBoolVectorType(uint32_t id) const; + bool IsBoolScalarOrVectorType(uint32_t id) const; + bool IsPointerType(uint32_t id) const; + + // Gets value from OpConstant and OpSpecConstant as uint64. + // Returns false on failure (no instruction, wrong instruction, not int). + bool GetConstantValUint64(uint32_t id, uint64_t* val) const; + + // Returns type_id if id has type or zero otherwise. + uint32_t GetTypeId(uint32_t id) const; + + // Returns opcode of the instruction which issued the id or OpNop if the + // instruction is not registered. + SpvOp GetIdOpcode(uint32_t id) const; + + // Returns type_id for given id operand if it has a type or zero otherwise. + // |operand_index| is expected to be pointing towards an operand which is an + // id. + uint32_t GetOperandTypeId(const Instruction* inst, + size_t operand_index) const; + + // Provides information on pointer type. Returns false iff not pointer type. + bool GetPointerTypeInfo(uint32_t id, uint32_t* data_type, + uint32_t* storage_class) const; + + // Tries to evaluate a 32-bit signed or unsigned scalar integer constant. + // Returns tuple . + // OpSpecConstant* return |is_const_int32| as false since their values cannot + // be relied upon during validation. + std::tuple EvalInt32IfConst(uint32_t id); + + // Returns the disassembly string for the given instruction. 
+ std::string Disassemble(const Instruction& inst) const; + + // Returns the disassembly string for the given instruction. + std::string Disassemble(const uint32_t* words, uint16_t num_words) const; + + private: + ValidationState_t(const ValidationState_t&); + + const spv_const_context context_; + + /// Stores the Validator command line options. Must be a valid options object. + const spv_const_validator_options options_; + + /// The SPIR-V binary module we're validating. + const uint32_t* words_; + const size_t num_words_; + + /// The generator of the SPIR-V. + uint32_t generator_ = 0; + + /// The version of the SPIR-V. + uint32_t version_ = 0; + + /// The total number of instructions in the binary. + size_t total_instructions_ = 0; + /// The total number of functions in the binary. + size_t total_functions_ = 0; + + /// IDs which have been forward declared but have not been defined + std::unordered_set unresolved_forward_ids_; + + /// IDs that have been declared as forward pointers. + std::unordered_set forward_pointer_ids_; + + /// Stores a vector of instructions that use the result of a given + /// OpSampledImage instruction. + std::unordered_map> sampled_image_consumers_; + + /// A map of operand IDs and their names defined by the OpName instruction + std::unordered_map operand_names_; + + /// The section of the code being processed + ModuleLayoutSection current_layout_section_; + + /// A list of functions in the module. + /// Pointers to objects in this container are guaranteed to be stable and + /// valid until the end of lifetime of the validation state. 
+ std::vector module_functions_; + + /// Capabilities declared in the module + CapabilitySet module_capabilities_; + + /// Extensions declared in the module + ExtensionSet module_extensions_; + + /// List of all instructions in the order they appear in the binary + std::vector ordered_instructions_; + + /// Instructions that can be referenced by Ids + std::unordered_map all_definitions_; + + /// IDs that are entry points, ie, arguments to OpEntryPoint. + std::vector entry_points_; + + /// Maps an entry point id to its descriptions. + std::unordered_map> + entry_point_descriptions_; + + /// IDs that are entry points, ie, arguments to OpEntryPoint, and root a call + /// graph that recurses. + std::set recursive_entry_points_; + + /// Function IDs that are targets of OpFunctionCall. + std::unordered_set function_call_targets_; + + /// ID Bound from the Header + uint32_t id_bound_; + + /// Set of Global Variable IDs (Storage Class other than 'Function') + std::unordered_set global_vars_; + + /// Set of Local Variable IDs ('Function' Storage Class) + std::unordered_set local_vars_; + + /// Set of struct types that have members with a BuiltIn decoration. + std::unordered_set builtin_structs_; + + /// Structure Nesting Depth + std::unordered_map struct_nesting_depth_; + + /// Stores the list of decorations for a given + std::map> id_decorations_; + + /// Stores type declarations which need to be unique (i.e. non-aggregates), + /// in the form [opcode, operand words], result_id is not stored. + /// Using ordered set to avoid the need for a vector hash function. + /// The size of this container is expected not to exceed double-digits. + std::set> unique_type_declarations_; + + AssemblyGrammar grammar_; + + SpvAddressingModel addressing_model_; + SpvMemoryModel memory_model_; + // pointer size derived from addressing model. Assumes all storage classes + // have the same pointer size (for physical pointer types). 
+ uint32_t pointer_size_and_alignment_; + + /// NOTE: See corresponding getter functions + bool in_function_; + + /// The state of optional features. These are determined by capabilities + /// declared by the module and the environment. + Feature features_; + + /// Maps function ids to function state objects. + std::unordered_map id_to_function_; + + /// Mapping entry point -> execution models. It is presumed that the same + /// function could theoretically be used as 'main' by multiple OpEntryPoint + /// instructions. + std::unordered_map> + entry_point_to_execution_models_; + + /// Mapping entry point -> execution modes. + std::unordered_map> + entry_point_to_execution_modes_; + + /// Mapping function -> array of entry points inside this + /// module which can (indirectly) call the function. + std::unordered_map> function_to_entry_points_; + const std::vector empty_ids_; + + /// Maps ids to friendly names. + std::unique_ptr friendly_mapper_; + spvtools::NameMapper name_mapper_; + + /// Variables used to reduce the number of diagnostic messages. + uint32_t num_of_warnings_; + uint32_t max_num_of_warnings_; +}; + +} // namespace val +} // namespace spvtools + +#endif // SOURCE_VAL_VALIDATION_STATE_H_ diff --git a/syntax.md b/syntax.md new file mode 100644 index 000000000..be3a5d5ad --- /dev/null +++ b/syntax.md @@ -0,0 +1,238 @@ +# SPIR-V Assembly language syntax + +## Overview + +The assembly attempts to adhere to the binary form from Section 3 of the SPIR-V +spec as closely as possible, with one exception aiming at improving the text's +readability. The `<result-id>` generated by an instruction is moved to the +beginning of that instruction and followed by an `=` sign. This allows us to +distinguish between variable definitions and uses and locate value definitions +more easily. 
+ +Here is an example: + +``` + OpCapability Shader + OpMemoryModel Logical Simple + OpEntryPoint GLCompute %3 "main" + OpExecutionMode %3 LocalSize 64 64 1 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd +``` + +A module is a sequence of instructions, separated by whitespace. +An instruction is an opcode name followed by operands, separated by +whitespace. Typically each instruction is presented on its own line, +but the assembler does not enforce this rule. + +The opcode names and expected operands are described in Section 3 of +the SPIR-V specification. An operand is one of: +* a literal integer: A decimal integer, or a hexadecimal integer. + A hexadecimal integer is indicated by a leading `0x` or `0X`. A hex + integer supplied for a signed integer value will be sign-extended. + For example, `0xffff` supplied as the literal for an `OpConstant` + on a signed 16-bit integer type will be interpreted as the value `-1`. +* a literal floating point number, in decimal or hexadecimal form. + See [below](#floats). +* a literal string. + * A literal string is everything following a double-quote `"` until the + following un-escaped double-quote. This includes special characters such + as newlines. + * A backslash `\` may be used to escape characters in the string. The `\` + may be used to escape a double-quote or a `\` but is simply ignored when + preceding any other character. +* a named enumerated value, specific to that operand position. For example, + the `OpMemoryModel` takes a named Addressing Model operand (e.g. `Logical` or + `Physical32`), and a named Memory Model operand (e.g. `Simple` or `OpenCL`). + Named enumerated values are only meaningful in specific positions, and will + otherwise generate an error. +* a mask expression, consisting of one or more mask enum names separated + by `|`. 
For example, the expression `NotNaN|NotInf|NSZ` denotes the mask + which is the combination of the `NotNaN`, `NotInf`, and `NSZ` flags. +* an injected immediate integer: `!`. See [below](#immediate). +* an ID, e.g. `%foo`. See [below](#id). +* the name of an extended instruction. For example, `sqrt` in an extended + instruction such as `%f = OpExtInst %f32 %OpenCLImport sqrt %arg` +* the name of an opcode for OpSpecConstantOp, but where the `Op` prefix + is removed. For example, the following indicates the use of an integer + addition in a specialization constant computation: + `%sum = OpSpecConstantOp %i32 IAdd %a %b` + +## ID Definitions & Usage + + +An ID _definition_ pertains to the `` of an instruction, and ID +_usage_ is a use of an ID as an input to an instruction. + +An ID in the assembly language begins with `%` and must be followed by a name +consisting of one or more letters, numbers or underscore characters. + +For every ID in the assembly program, the assembler generates a unique number +called the ID's internal number. Then each ID reference translates into its +internal number in the SPIR-V output. Internal numbers are unique within the +compilation unit: no two IDs in the same unit will share internal numbers. + +The disassembler generates IDs where the name is always a decimal number +greater than 0. + +So the example can be rewritten using more user-friendly names, as follows: +``` + OpCapability Shader + OpMemoryModel Logical Simple + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 64 64 1 + %void = OpTypeVoid +%fnMain = OpTypeFunction %void + %main = OpFunction %void None %fnMain +%lbMain = OpLabel + OpReturn + OpFunctionEnd +``` + +## Floating point literals + + +The assembler and disassembler support floating point literals in both +decimal and hexadecimal form. 
+ +The syntax for a floating point literal is the same as floating point +constants in the C programming language, except: +* An optional leading minus (`-`) is part of the literal. +* An optional type specifier suffix is not allowed. +Infinity and NaN values are expressed in hexadecimal float literals +by using the maximum representable exponent for the bit width. + +For example, in 32-bit floating point, 8 bits are used for the exponent, and the +exponent bias is 127. So the maximum representable unbiased exponent is 128. +Therefore, we represent the infinities and some NaNs as follows: + +``` +%float32 = OpTypeFloat 32 +%inf = OpConstant %float32 0x1p+128 +%neginf = OpConstant %float32 -0x1p+128 +%aNaN = OpConstant %float32 0x1.8p+128 +%moreNaN = OpConstant %float32 -0x1.0002p+128 +``` +The assembler preserves all the bits of a NaN value. For example, the encoding +of `%aNaN` in the previous example is the same as the word with bits +`0x7fc00000`, and `%moreNaN` is encoded as `0xff800100`. + +The disassembler prints infinite, NaN, and subnormal values in hexadecimal form. +Zero and normal values are printed in decimal form with enough digits +to preserve all significand bits. + +## Arbitrary Integers + + +When writing tests it can be useful to emit an invalid 32 bit word into the +binary stream at arbitrary positions within the assembly. To specify an +arbitrary word into the stream the prefix `!` is used, this takes the form +`!`. Here is an example. + +``` +OpCapability !0x0000FF00 +``` + +Any token in a valid assembly program may be replaced by `!` -- even +tokens that dictate how the rest of the instruction is parsed. Consider, for +example, the following assembly program: + +``` +%4 = OpConstant %1 123 456 789 OpExecutionMode %2 LocalSize 11 22 33 +OpExecutionMode %3 InputLines +``` + +The tokens `OpConstant`, `LocalSize`, and `InputLines` may be replaced by random +`!` values, and the assembler will still assemble an output binary with +three instructions. 
It will not necessarily be valid SPIR-V, but it will +faithfully reflect the input text. + +You may wonder how the assembler recognizes the instruction structure (including +instruction boundaries) in the text with certain crucial tokens replaced by +arbitrary integers. If, say, `OpConstant` becomes a `!` whose value +differs from the binary representation of `OpConstant` (remember that this +feature is intended for fine-grain control in SPIR-V testing), the assembler +generally has no idea what that value stands for. So how does it know there is +exactly one `` and three number literals following in that instruction, +before the next one begins? And if `LocalSize` is replaced by an arbitrary +`!`, how does it know to take the next three tokens (instead of zero or +one, both of which are possible in the absence of certainty that `LocalSize` +provided)? The answer is a simple rule governing the parsing of instructions +with `!` in them: + +When a token in the assembly program is a `!`, that integer value is +emitted into the binary output, and parsing proceeds differently than before: +each subsequent token not recognized as an OpCode or a is emitted +into the binary output without any checking; when a recognizable OpCode or a + is eventually encountered, it begins a new instruction and parsing +returns to normal. (If a subsequent OpCode is never found, then this alternate +parsing mode handles all the remaining tokens in the program.) + +The assembler processes the tokens encountered in alternate parsing mode as +follows: + +* If the token is a number literal, since context may be lost, the number + is interpreted as a 32-bit value and output as a single word. In order to + specify multiple-word literals in alternate-parsing mode, further uses of + `!` tokens may be required. + All formats supported by `strtoul()` are accepted. 
+* If the token is a string literal, it outputs a sequence of words representing + the string as defined in the SPIR-V specification for Literal String. +* If the token is an ID, it outputs the ID's internal number. +* If the token is another `!<integer>`, it outputs that integer. +* Any other token causes the assembler to quit with an error. + +Note that this has some interesting consequences, including: + +* When an OpCode is replaced by `!<integer>`, the integer value should encode + the instruction's word count, as specified in the physical-layout section of + the SPIR-V specification. + +* Consecutive instructions may have their OpCode replaced by `!<integer>` and + still produce valid SPIR-V. For example, `!262187 %1 %2 "abc" !327739 %1 %3 6 + %2` will successfully assemble into SPIR-V declaring a constant and a + PrivateGlobal variable. + +* Enums (such as `DontInline` or `SubgroupMemory`, for instance) are not handled + by the alternate parsing mode. They must be replaced by `!<integer>` for + successful assembly. + +* The `<result-id>` on the left-hand side of an assignment cannot be a + `!<integer>`. The `<result-id>` can still be manually controlled if desired + by expressing the entire instruction as `!<integer>` tokens for its opcode and + operands. + +* The `=` sign cannot be processed by the alternate parsing mode if the OpCode + following it is a `!<integer>`. + +* When replacing a named ID with `!<integer>`, it is possible to generate + unintentionally valid SPIR-V. If the integer provided happens to equal a + number generated for an existing named ID, it will result in a reference to + that named ID being output. This may be valid SPIR-V, contrary to the + presumed intention of the writer. + +## Notes + +* Some enumerants cannot be used by name, because the target instruction +in which they are meaningful takes an ID reference instead of a literal value. 
+For example: + * Named enumerated value `CmdExecTime` from section 3.30 Kernel + Profiling Info is used in constructing a mask value supplied as + an ID for `OpCaptureEventProfilingInfo`. But no other instruction + has enough context to bring the enumerant names from section 3.30 + into scope. + * Similarly, the names in section 3.29 Kernel Enqueue Flags are used to + construct a value supplied as an ID to the Flags argument of + OpEnqueueKernel. + * Similarly for the names in section 3.25 Memory Semantics. + * Similarly for the names in section 3.27 Scope. +* Some enumerants cannot be used by name, because they only name values +returned by an instruction: + * Enumerants from 3.12 Image Channel Order name possible values returned + by the `OpImageQueryOrder` instruction. + * Enumerants from 3.13 Image Channel Data Type name possible values + returned by the `OpImageQueryFormat` instruction. diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 000000000..9226ea7f7 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,214 @@ +# Copyright (c) 2015-2016 The Khronos Group Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Add a SPIR-V Tools unit test. 
Signature: +# add_spvtools_unittest( +# TARGET target_name +# SRCS src_file.h src_file.cpp +# LIBS lib1 lib2 +# ) + +if (NOT "${SPIRV_SKIP_TESTS}") + if (TARGET gmock_main) + message(STATUS "Found Google Mock, building tests.") + else() + message(STATUS "Did not find googletest, tests will not be built. " + "To enable tests place googletest in '/external/googletest'.") + endif() +endif() + +function(add_spvtools_unittest) + if (NOT "${SPIRV_SKIP_TESTS}" AND TARGET gmock_main) + set(one_value_args TARGET PCH_FILE) + set(multi_value_args SRCS LIBS ENVIRONMENT) + cmake_parse_arguments( + ARG "" "${one_value_args}" "${multi_value_args}" ${ARGN}) + set(target test_${ARG_TARGET}) + set(SRC_COPY ${ARG_SRCS}) + if (DEFINED ARG_PCH_FILE) + spvtools_pch(SRC_COPY ${ARG_PCH_FILE}) + endif() + add_executable(${target} ${SRC_COPY}) + spvtools_default_compile_options(${target}) + if(${COMPILER_IS_LIKE_GNU}) + target_compile_options(${target} PRIVATE -Wno-undef) + # Effcee and RE2 headers exhibit shadowing. + target_compile_options(${target} PRIVATE -Wno-shadow) + endif() + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") + # Disable C4503 "decorated name length exceeded" warning, + # triggered by some heavily templated types. + # We don't care much about that in test code. + # Important to do since we have warnings-as-errors. + target_compile_options(${target} PRIVATE /wd4503) + # Googletest accidentally turns off support for ::testing::Combine + # in VS 2017. See https://github.com/google/googletest/issues/1352 + # Forcibly turn it on again. + target_compile_options(${target} PRIVATE /DGTEST_HAS_COMBINE=1) + endif() + target_include_directories(${target} PRIVATE + ${SPIRV_HEADER_INCLUDE_DIR} + ${spirv-tools_SOURCE_DIR} + ${spirv-tools_SOURCE_DIR}/include + ${spirv-tools_SOURCE_DIR}/test + ${spirv-tools_BINARY_DIR} + ${gtest_SOURCE_DIR}/include + ${gmock_SOURCE_DIR}/include + ) + if (TARGET effcee) + # If using Effcee for testing, then add its include directory. 
+ target_include_directories(${target} PRIVATE ${effcee_SOURCE_DIR}) + endif() + target_link_libraries(${target} PRIVATE ${ARG_LIBS}) + if (TARGET effcee) + target_link_libraries(${target} PRIVATE effcee) + endif() + target_link_libraries(${target} PRIVATE gmock_main) + add_test(NAME spirv-tools-${target} COMMAND ${target}) + if (DEFINED ARG_ENVIRONMENT) + set_tests_properties(spirv-tools-${target} PROPERTIES ENVIRONMENT ${ARG_ENVIRONMENT}) + endif() + set_property(TARGET ${target} PROPERTY FOLDER "SPIRV-Tools tests") + endif() +endfunction() + +set(TEST_SOURCES + test_fixture.h + unit_spirv.h + + assembly_context_test.cpp + assembly_format_test.cpp + binary_destroy_test.cpp + binary_endianness_test.cpp + binary_header_get_test.cpp + binary_parse_test.cpp + binary_strnlen_s_test.cpp + binary_to_text_test.cpp + binary_to_text.literal_test.cpp + comment_test.cpp + diagnostic_test.cpp + enum_string_mapping_test.cpp + enum_set_test.cpp + ext_inst.debuginfo_test.cpp + ext_inst.glsl_test.cpp + ext_inst.opencl_test.cpp + fix_word_test.cpp + generator_magic_number_test.cpp + hex_float_test.cpp + immediate_int_test.cpp + libspirv_macros_test.cpp + log_test.cpp + named_id_test.cpp + name_mapper_test.cpp + opcode_make_test.cpp + opcode_require_capabilities_test.cpp + opcode_split_test.cpp + opcode_table_get_test.cpp + operand_capabilities_test.cpp + operand_test.cpp + operand_pattern_test.cpp + parse_number_test.cpp + preserve_numeric_ids_test.cpp + software_version_test.cpp + string_utils_test.cpp + target_env_test.cpp + text_advance_test.cpp + text_destroy_test.cpp + text_literal_test.cpp + text_start_new_inst_test.cpp + text_to_binary.annotation_test.cpp + text_to_binary.barrier_test.cpp + text_to_binary.constant_test.cpp + text_to_binary.control_flow_test.cpp + text_to_binary_test.cpp + text_to_binary.debug_test.cpp + text_to_binary.device_side_enqueue_test.cpp + text_to_binary.extension_test.cpp + text_to_binary.function_test.cpp + text_to_binary.group_test.cpp + 
text_to_binary.image_test.cpp + text_to_binary.literal_test.cpp + text_to_binary.memory_test.cpp + text_to_binary.misc_test.cpp + text_to_binary.mode_setting_test.cpp + text_to_binary.pipe_storage_test.cpp + text_to_binary.type_declaration_test.cpp + text_to_binary.subgroup_dispatch_test.cpp + text_to_binary.reserved_sampling_test.cpp + text_word_get_test.cpp + + unit_spirv.cpp +) + +spvtools_pch(TEST_SOURCES pch_test) + +add_spvtools_unittest( + TARGET spirv_unit_tests + SRCS ${TEST_SOURCES} + LIBS ${SPIRV_TOOLS}) + +add_spvtools_unittest( + TARGET c_interface + SRCS c_interface_test.cpp + LIBS ${SPIRV_TOOLS}) + +add_spvtools_unittest( + TARGET c_interface_shared + SRCS c_interface_test.cpp + LIBS ${SPIRV_TOOLS}-shared + ENVIRONMENT PATH=$) + +add_spvtools_unittest( + TARGET cpp_interface + SRCS cpp_interface_test.cpp + LIBS SPIRV-Tools-opt) + +if (${SPIRV_TIMER_ENABLED}) +add_spvtools_unittest( + TARGET timer + SRCS timer_test.cpp + LIBS ${SPIRV_TOOLS}) +endif() + + +add_spvtools_unittest( + TARGET bit_stream + SRCS bit_stream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.h + LIBS ${SPIRV_TOOLS}) + +add_spvtools_unittest( + TARGET huffman_codec + SRCS huffman_codec.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/bit_stream.h + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/huffman_codec.h + LIBS ${SPIRV_TOOLS}) + +add_spvtools_unittest( + TARGET move_to_front + SRCS move_to_front_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/move_to_front.h + ${CMAKE_CURRENT_SOURCE_DIR}/../source/comp/move_to_front.cpp + LIBS ${SPIRV_TOOLS}) + +add_subdirectory(comp) +add_subdirectory(link) +add_subdirectory(opt) +add_subdirectory(reduce) +add_subdirectory(stats) +add_subdirectory(tools) +add_subdirectory(util) +add_subdirectory(val) diff --git a/test/assembly_context_test.cpp b/test/assembly_context_test.cpp new file mode 100644 index 
000000000..b6d60b95d --- /dev/null +++ b/test/assembly_context_test.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "source/instruction.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::AutoText; +using spvtest::Concatenate; +using ::testing::Eq; + +struct EncodeStringCase { + std::string str; + std::vector initial_contents; +}; + +using EncodeStringTest = ::testing::TestWithParam; + +TEST_P(EncodeStringTest, Sample) { + AssemblyContext context(AutoText(""), nullptr); + spv_instruction_t inst; + inst.words = GetParam().initial_contents; + ASSERT_EQ(SPV_SUCCESS, + context.binaryEncodeString(GetParam().str.c_str(), &inst)); + // We already trust MakeVector + EXPECT_THAT(inst.words, + Eq(Concatenate({GetParam().initial_contents, + spvtest::MakeVector(GetParam().str)}))); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + BinaryEncodeString, EncodeStringTest, + ::testing::ValuesIn(std::vector{ + // Use cases that exercise at least one to two words, + // and both empty and non-empty initial contents. 
+ {"", {}}, + {"", {1,2,3}}, + {"a", {}}, + {"a", {4}}, + {"ab", {4}}, + {"abc", {}}, + {"abc", {18}}, + {"abcd", {}}, + {"abcd", {22}}, + {"abcde", {4}}, + {"abcdef", {}}, + {"abcdef", {99,42}}, + {"abcdefg", {}}, + {"abcdefg", {101}}, + {"abcdefgh", {}}, + {"abcdefgh", {102, 103, 104}}, + // A very long string, encoded after an initial word. + // SPIR-V limits strings to 65535 characters. + {std::string(65535, 'a'), {1}}, + }),); +// clang-format on + +} // namespace +} // namespace spvtools diff --git a/test/assembly_format_test.cpp b/test/assembly_format_test.cpp new file mode 100644 index 000000000..59e500b81 --- /dev/null +++ b/test/assembly_format_test.cpp @@ -0,0 +1,37 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "test/test_fixture.h" + +namespace svptools { +namespace { + +using spvtest::ScopedContext; +using spvtest::TextToBinaryTest; + +TEST_F(TextToBinaryTest, NotPlacingResultIDAtTheBeginning) { + SetText("OpTypeMatrix %1 %2 1000"); + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(ScopedContext().context, text.str, text.length, + &binary, &diagnostic)); + ASSERT_NE(nullptr, diagnostic); + EXPECT_STREQ( + "Expected at the beginning of an instruction, found " + "'OpTypeMatrix'.", + diagnostic->error); + EXPECT_EQ(0u, diagnostic->position.line); +} + +} // namespace +} // namespace svptools diff --git a/test/binary_destroy_test.cpp b/test/binary_destroy_test.cpp new file mode 100644 index 000000000..e3870c9f0 --- /dev/null +++ b/test/binary_destroy_test.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/unit_spirv.h" + +#include "test/test_fixture.h" + +namespace spvtools { +namespace { + +using spvtest::ScopedContext; + +TEST(BinaryDestroy, Null) { + // There is no state or return value to check. Just check + // for the ability to call the API without abnormal termination. + spvBinaryDestroy(nullptr); +} + +using BinaryDestroySomething = spvtest::TextToBinaryTest; + +// Checks safety of destroying a validly constructed binary. +TEST_F(BinaryDestroySomething, Default) { + // Use a binary object constructed by the API instead of rolling our own. 
+ SetText("OpSource OpenCL_C 120"); + spv_binary my_binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(ScopedContext().context, text.str, + text.length, &my_binary, &diagnostic)); + ASSERT_NE(nullptr, my_binary); + spvBinaryDestroy(my_binary); +} + +} // namespace +} // namespace spvtools diff --git a/test/binary_endianness_test.cpp b/test/binary_endianness_test.cpp new file mode 100644 index 000000000..3cd405d52 --- /dev/null +++ b/test/binary_endianness_test.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +TEST(BinaryEndianness, InvalidCode) { + uint32_t invalidMagicNumber[] = {0}; + spv_const_binary_t binary = {invalidMagicNumber, 1}; + spv_endianness_t endian; + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, spvBinaryEndianness(&binary, &endian)); +} + +TEST(BinaryEndianness, Little) { + uint32_t magicNumber; + if (I32_ENDIAN_HOST == I32_ENDIAN_LITTLE) { + magicNumber = 0x07230203; + } else { + magicNumber = 0x03022307; + } + spv_const_binary_t binary = {&magicNumber, 1}; + spv_endianness_t endian; + ASSERT_EQ(SPV_SUCCESS, spvBinaryEndianness(&binary, &endian)); + ASSERT_EQ(SPV_ENDIANNESS_LITTLE, endian); +} + +TEST(BinaryEndianness, Big) { + uint32_t magicNumber; + if (I32_ENDIAN_HOST == I32_ENDIAN_BIG) { + magicNumber = 0x07230203; + } else { + magicNumber = 0x03022307; + } + spv_const_binary_t binary = {&magicNumber, 1}; + spv_endianness_t endian; + ASSERT_EQ(SPV_SUCCESS, spvBinaryEndianness(&binary, &endian)); + ASSERT_EQ(SPV_ENDIANNESS_BIG, endian); +} + +} // namespace +} // namespace spvtools diff --git a/test/binary_header_get_test.cpp b/test/binary_header_get_test.cpp new file mode 100644 index 000000000..e771f1a39 --- /dev/null +++ b/test/binary_header_get_test.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/spirv_constant.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +class BinaryHeaderGet : public ::testing::Test { + public: + BinaryHeaderGet() { memset(code, 0, sizeof(code)); } + + virtual void SetUp() { + code[0] = SpvMagicNumber; + code[1] = SpvVersion; + code[2] = SPV_GENERATOR_CODEPLAY; + code[3] = 1; // NOTE: Bound + code[4] = 0; // NOTE: Schema; reserved + code[5] = 0; // NOTE: Instructions + + binary.code = code; + binary.wordCount = 6; + } + spv_const_binary_t get_const_binary() { + return spv_const_binary_t{binary.code, binary.wordCount}; + } + virtual void TearDown() {} + + uint32_t code[6]; + spv_binary_t binary; +}; + +TEST_F(BinaryHeaderGet, Default) { + spv_endianness_t endian; + spv_const_binary_t const_bin = get_const_binary(); + ASSERT_EQ(SPV_SUCCESS, spvBinaryEndianness(&const_bin, &endian)); + + spv_header_t header; + ASSERT_EQ(SPV_SUCCESS, spvBinaryHeaderGet(&const_bin, endian, &header)); + + ASSERT_EQ(static_cast(SpvMagicNumber), header.magic); + ASSERT_EQ(0x00010300u, header.version); + ASSERT_EQ(static_cast(SPV_GENERATOR_CODEPLAY), header.generator); + ASSERT_EQ(1u, header.bound); + ASSERT_EQ(0u, header.schema); + ASSERT_EQ(&code[5], header.instructions); +} + +TEST_F(BinaryHeaderGet, InvalidCode) { + spv_const_binary_t my_binary = {nullptr, 0}; + spv_header_t header; + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryHeaderGet(&my_binary, SPV_ENDIANNESS_LITTLE, &header)); +} + +TEST_F(BinaryHeaderGet, InvalidPointerHeader) { + spv_const_binary_t const_bin = get_const_binary(); + ASSERT_EQ(SPV_ERROR_INVALID_POINTER, + spvBinaryHeaderGet(&const_bin, SPV_ENDIANNESS_LITTLE, nullptr)); +} + +TEST_F(BinaryHeaderGet, TruncatedHeader) { + for (uint8_t i = 1; i < SPV_INDEX_INSTRUCTION; i++) { + binary.wordCount = i; + spv_const_binary_t const_bin = get_const_binary(); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryHeaderGet(&const_bin, SPV_ENDIANNESS_LITTLE, nullptr)); + } +} + +} // namespace +} // 
namespace spvtools diff --git a/test/binary_parse_test.cpp b/test/binary_parse_test.cpp new file mode 100644 index 000000000..749e12fd6 --- /dev/null +++ b/test/binary_parse_test.cpp @@ -0,0 +1,892 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/latest_version_opencl_std_header.h" +#include "source/table.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +// Returns true if two spv_parsed_operand_t values are equal. +// To use this operator, this definition must appear in the same namespace +// as spv_parsed_operand_t. +static bool operator==(const spv_parsed_operand_t& a, + const spv_parsed_operand_t& b) { + return a.offset == b.offset && a.num_words == b.num_words && + a.type == b.type && a.number_kind == b.number_kind && + a.number_bit_width == b.number_bit_width; +} + +namespace spvtools { +namespace { + +using ::spvtest::Concatenate; +using ::spvtest::MakeInstruction; +using ::spvtest::MakeVector; +using ::spvtest::ScopedContext; +using ::testing::_; +using ::testing::AnyOf; +using ::testing::Eq; +using ::testing::InSequence; +using ::testing::Return; + +// An easily-constructible and comparable object for the contents of an +// spv_parsed_instruction_t. Unlike spv_parsed_instruction_t, owns the memory +// of its components. 
+struct ParsedInstruction { + explicit ParsedInstruction(const spv_parsed_instruction_t& inst) + : words(inst.words, inst.words + inst.num_words), + opcode(static_cast(inst.opcode)), + ext_inst_type(inst.ext_inst_type), + type_id(inst.type_id), + result_id(inst.result_id), + operands(inst.operands, inst.operands + inst.num_operands) {} + + std::vector words; + SpvOp opcode; + spv_ext_inst_type_t ext_inst_type; + uint32_t type_id; + uint32_t result_id; + std::vector operands; + + bool operator==(const ParsedInstruction& b) const { + return words == b.words && opcode == b.opcode && + ext_inst_type == b.ext_inst_type && type_id == b.type_id && + result_id == b.result_id && operands == b.operands; + } +}; + +// Prints a ParsedInstruction object to the given output stream, and returns +// the stream. +std::ostream& operator<<(std::ostream& os, const ParsedInstruction& inst) { + os << "\nParsedInstruction( {"; + spvtest::PrintTo(spvtest::WordVector(inst.words), &os); + os << "}, opcode: " << int(inst.opcode) + << " ext_inst_type: " << int(inst.ext_inst_type) + << " type_id: " << inst.type_id << " result_id: " << inst.result_id; + for (const auto& operand : inst.operands) { + os << " { offset: " << operand.offset << " num_words: " << operand.num_words + << " type: " << int(operand.type) + << " number_kind: " << int(operand.number_kind) + << " number_bit_width: " << int(operand.number_bit_width) << "}"; + } + os << ")"; + return os; +} + +// Sanity check for the equality operator on ParsedInstruction. +TEST(ParsedInstruction, ZeroInitializedAreEqual) { + spv_parsed_instruction_t pi = {}; + ParsedInstruction a(pi); + ParsedInstruction b(pi); + EXPECT_THAT(a, ::testing::TypedEq(b)); +} + +// Googlemock class receiving Header/Instruction calls from spvBinaryParse(). 
+class MockParseClient { + public: + MOCK_METHOD6(Header, spv_result_t(spv_endianness_t endian, uint32_t magic, + uint32_t version, uint32_t generator, + uint32_t id_bound, uint32_t reserved)); + MOCK_METHOD1(Instruction, spv_result_t(const ParsedInstruction&)); +}; + +// Casts user_data as MockParseClient and invokes its Header(). +spv_result_t invoke_header(void* user_data, spv_endianness_t endian, + uint32_t magic, uint32_t version, uint32_t generator, + uint32_t id_bound, uint32_t reserved) { + return static_cast(user_data)->Header( + endian, magic, version, generator, id_bound, reserved); +} + +// Casts user_data as MockParseClient and invokes its Instruction(). +spv_result_t invoke_instruction( + void* user_data, const spv_parsed_instruction_t* parsed_instruction) { + return static_cast(user_data)->Instruction( + ParsedInstruction(*parsed_instruction)); +} + +// The SPIR-V module header words for the Khronos Assembler generator, +// for a module with an ID bound of 1. +const uint32_t kHeaderForBound1[] = { + SpvMagicNumber, SpvVersion, + SPV_GENERATOR_WORD(SPV_GENERATOR_KHRONOS_ASSEMBLER, 0), 1 /*bound*/, + 0 /*schema*/}; + +// Returns the expected SPIR-V module header words for the Khronos +// Assembler generator, and with a given Id bound. +std::vector ExpectedHeaderForBound(uint32_t bound) { + return {SpvMagicNumber, 0x10000, + SPV_GENERATOR_WORD(SPV_GENERATOR_KHRONOS_ASSEMBLER, 0), bound, 0}; +} + +// Returns a parsed operand for a non-number value at the given word offset +// within an instruction. +spv_parsed_operand_t MakeSimpleOperand(uint16_t offset, + spv_operand_type_t type) { + return {offset, 1, type, SPV_NUMBER_NONE, 0}; +} + +// Returns a parsed operand for a literal unsigned integer value at the given +// word offset within an instruction. 
+spv_parsed_operand_t MakeLiteralNumberOperand(uint16_t offset) { + return {offset, 1, SPV_OPERAND_TYPE_LITERAL_INTEGER, SPV_NUMBER_UNSIGNED_INT, + 32}; +} + +// Returns a parsed operand for a literal string value at the given +// word offset within an instruction. +spv_parsed_operand_t MakeLiteralStringOperand(uint16_t offset, + uint16_t length) { + return {offset, length, SPV_OPERAND_TYPE_LITERAL_STRING, SPV_NUMBER_NONE, 0}; +} + +// Returns a ParsedInstruction for an OpTypeVoid instruction that would +// generate the given result Id. +ParsedInstruction MakeParsedVoidTypeInstruction(uint32_t result_id) { + const auto void_inst = MakeInstruction(SpvOpTypeVoid, {result_id}); + const auto void_operands = std::vector{ + MakeSimpleOperand(1, SPV_OPERAND_TYPE_RESULT_ID)}; + const spv_parsed_instruction_t parsed_void_inst = { + void_inst.data(), + static_cast(void_inst.size()), + SpvOpTypeVoid, + SPV_EXT_INST_TYPE_NONE, + 0, // type id + result_id, + void_operands.data(), + static_cast(void_operands.size())}; + return ParsedInstruction(parsed_void_inst); +} + +// Returns a ParsedInstruction for an OpTypeInt instruction that generates +// the given result Id for a 32-bit signed integer scalar type. 
+ParsedInstruction MakeParsedInt32TypeInstruction(uint32_t result_id) { + const auto i32_inst = MakeInstruction(SpvOpTypeInt, {result_id, 32, 1}); + const auto i32_operands = std::vector{ + MakeSimpleOperand(1, SPV_OPERAND_TYPE_RESULT_ID), + MakeLiteralNumberOperand(2), MakeLiteralNumberOperand(3)}; + spv_parsed_instruction_t parsed_i32_inst = { + i32_inst.data(), + static_cast(i32_inst.size()), + SpvOpTypeInt, + SPV_EXT_INST_TYPE_NONE, + 0, // type id + result_id, + i32_operands.data(), + static_cast(i32_operands.size())}; + return ParsedInstruction(parsed_i32_inst); +} + +class BinaryParseTest : public spvtest::TextToBinaryTestBase<::testing::Test> { + protected: + ~BinaryParseTest() { spvDiagnosticDestroy(diagnostic_); } + + void Parse(const SpirvVector& words, spv_result_t expected_result, + bool flip_words = false) { + SpirvVector flipped_words(words); + SCOPED_TRACE(flip_words ? "Flipped Endianness" : "Normal Endianness"); + if (flip_words) { + std::transform(flipped_words.begin(), flipped_words.end(), + flipped_words.begin(), [](const uint32_t raw_word) { + return spvFixWord(raw_word, + I32_ENDIAN_HOST == I32_ENDIAN_BIG + ? SPV_ENDIANNESS_LITTLE + : SPV_ENDIANNESS_BIG); + }); + } + EXPECT_EQ(expected_result, + spvBinaryParse(ScopedContext().context, &client_, + flipped_words.data(), flipped_words.size(), + invoke_header, invoke_instruction, &diagnostic_)); + } + + spv_diagnostic diagnostic_ = nullptr; + MockParseClient client_; +}; + +// Adds an EXPECT_CALL to client_->Header() with appropriate parameters, +// including bound. Returns the EXPECT_CALL result. 
+#define EXPECT_HEADER(bound) \ + EXPECT_CALL( \ + client_, \ + Header(AnyOf(SPV_ENDIANNESS_LITTLE, SPV_ENDIANNESS_BIG), SpvMagicNumber, \ + 0x10000, SPV_GENERATOR_WORD(SPV_GENERATOR_KHRONOS_ASSEMBLER, 0), \ + bound, 0 /*reserved*/)) + +static const bool kSwapEndians[] = {false, true}; + +TEST_F(BinaryParseTest, EmptyModuleHasValidHeaderAndNoInstructionCallbacks) { + for (bool endian_swap : kSwapEndians) { + const auto words = CompileSuccessfully(""); + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + Parse(words, SPV_SUCCESS, endian_swap); + EXPECT_EQ(nullptr, diagnostic_); + } +} + +TEST_F(BinaryParseTest, NullDiagnosticsIsOkForGoodParse) { + const auto words = CompileSuccessfully(""); + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ( + SPV_SUCCESS, + spvBinaryParse(ScopedContext().context, &client_, words.data(), + words.size(), invoke_header, invoke_instruction, nullptr)); +} + +TEST_F(BinaryParseTest, NullDiagnosticsIsOkForBadParse) { + auto words = CompileSuccessfully(""); + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ( + SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ScopedContext().context, &client_, words.data(), + words.size(), invoke_header, invoke_instruction, nullptr)); +} + +// Make sure that we don't blow up when both the consumer and the diagnostic are +// null. +TEST_F(BinaryParseTest, NullConsumerNullDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvtools::Context(SPV_ENV_UNIVERSAL_1_1); + ctx.SetMessageConsumer(nullptr); + + words.push_back(0xffffffff); // Certainly invalid instruction header. 
+ EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx.CContext(), &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); +} + +TEST_F(BinaryParseTest, SpecifyConsumerNullDiagnosticsForGoodParse) { + const auto words = CompileSuccessfully(""); + + auto ctx = spvtools::Context(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + ctx.SetMessageConsumer([&invocation](spv_message_level_t, const char*, + const spv_position_t&, + const char*) { ++invocation; }); + + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_SUCCESS, + spvBinaryParse(ctx.CContext(), &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); + EXPECT_EQ(0, invocation); +} + +TEST_F(BinaryParseTest, SpecifyConsumerNullDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvtools::Context(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + ctx.SetMessageConsumer( + [&invocation](spv_message_level_t level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(SPV_MSG_ERROR, level); + EXPECT_STREQ("input", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(1u, position.index); + EXPECT_STREQ("Invalid opcode: 65535", message); + }); + + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. 
+ EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx.CContext(), &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); + EXPECT_EQ(1, invocation); +} + +TEST_F(BinaryParseTest, SpecifyConsumerSpecifyDiagnosticsForGoodParse) { + const auto words = CompileSuccessfully(""); + + auto ctx = spvtools::Context(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + ctx.SetMessageConsumer([&invocation](spv_message_level_t, const char*, + const spv_position_t&, + const char*) { ++invocation; }); + + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_SUCCESS, + spvBinaryParse(ctx.CContext(), &client_, words.data(), words.size(), + invoke_header, invoke_instruction, &diagnostic_)); + EXPECT_EQ(0, invocation); + EXPECT_EQ(nullptr, diagnostic_); +} + +TEST_F(BinaryParseTest, SpecifyConsumerSpecifyDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvtools::Context(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + ctx.SetMessageConsumer([&invocation](spv_message_level_t, const char*, + const spv_position_t&, + const char*) { ++invocation; }); + + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. 
+ EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx.CContext(), &client_, words.data(), words.size(), + invoke_header, invoke_instruction, &diagnostic_)); + EXPECT_EQ(0, invocation); + EXPECT_STREQ("Invalid opcode: 65535", diagnostic_->error); +} + +TEST_F(BinaryParseTest, + ModuleWithSingleInstructionHasValidHeaderAndInstructionCallback) { + for (bool endian_swap : kSwapEndians) { + const auto words = CompileSuccessfully("%1 = OpTypeVoid"); + InSequence calls_expected_in_specific_order; + EXPECT_HEADER(2).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(MakeParsedVoidTypeInstruction(1))) + .WillOnce(Return(SPV_SUCCESS)); + Parse(words, SPV_SUCCESS, endian_swap); + EXPECT_EQ(nullptr, diagnostic_); + } +} + +TEST_F(BinaryParseTest, NullHeaderCallbackIsIgnored) { + const auto words = CompileSuccessfully("%1 = OpTypeVoid"); + EXPECT_CALL(client_, Header(_, _, _, _, _, _)) + .Times(0); // No header callback. + EXPECT_CALL(client_, Instruction(MakeParsedVoidTypeInstruction(1))) + .WillOnce(Return(SPV_SUCCESS)); + EXPECT_EQ(SPV_SUCCESS, spvBinaryParse(ScopedContext().context, &client_, + words.data(), words.size(), nullptr, + invoke_instruction, &diagnostic_)); + EXPECT_EQ(nullptr, diagnostic_); +} + +TEST_F(BinaryParseTest, NullInstructionCallbackIsIgnored) { + const auto words = CompileSuccessfully("%1 = OpTypeVoid"); + EXPECT_HEADER((2)).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_SUCCESS, + spvBinaryParse(ScopedContext().context, &client_, words.data(), + words.size(), invoke_header, nullptr, &diagnostic_)); + EXPECT_EQ(nullptr, diagnostic_); +} + +// Check the result of multiple instruction callbacks. +// +// This test exercises non-default values for the following members of the +// spv_parsed_instruction_t struct: words, num_words, opcode, result_id, +// operands, num_operands. 
+TEST_F(BinaryParseTest, TwoScalarTypesGenerateTwoInstructionCallbacks) { + for (bool endian_swap : kSwapEndians) { + const auto words = CompileSuccessfully( + "%1 = OpTypeVoid " + "%2 = OpTypeInt 32 1"); + InSequence calls_expected_in_specific_order; + EXPECT_HEADER(3).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(MakeParsedVoidTypeInstruction(1))) + .WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(MakeParsedInt32TypeInstruction(2))) + .WillOnce(Return(SPV_SUCCESS)); + Parse(words, SPV_SUCCESS, endian_swap); + EXPECT_EQ(nullptr, diagnostic_); + } +} + +TEST_F(BinaryParseTest, EarlyReturnWithZeroPassingCallbacks) { + for (bool endian_swap : kSwapEndians) { + const auto words = CompileSuccessfully( + "%1 = OpTypeVoid " + "%2 = OpTypeInt 32 1"); + InSequence calls_expected_in_specific_order; + EXPECT_HEADER(3).WillOnce(Return(SPV_ERROR_INVALID_BINARY)); + // Early exit means no calls to Instruction(). + EXPECT_CALL(client_, Instruction(_)).Times(0); + Parse(words, SPV_ERROR_INVALID_BINARY, endian_swap); + // On error, the binary parser doesn't generate its own diagnostics. + EXPECT_EQ(nullptr, diagnostic_); + } +} + +TEST_F(BinaryParseTest, + EarlyReturnWithZeroPassingCallbacksAndSpecifiedResultCode) { + for (bool endian_swap : kSwapEndians) { + const auto words = CompileSuccessfully( + "%1 = OpTypeVoid " + "%2 = OpTypeInt 32 1"); + InSequence calls_expected_in_specific_order; + EXPECT_HEADER(3).WillOnce(Return(SPV_REQUESTED_TERMINATION)); + // Early exit means no calls to Instruction(). + EXPECT_CALL(client_, Instruction(_)).Times(0); + Parse(words, SPV_REQUESTED_TERMINATION, endian_swap); + // On early termination, the binary parser doesn't generate its own + // diagnostics. 
+ EXPECT_EQ(nullptr, diagnostic_); + } +} + +TEST_F(BinaryParseTest, EarlyReturnWithOnePassingCallback) { + for (bool endian_swap : kSwapEndians) { + const auto words = CompileSuccessfully( + "%1 = OpTypeVoid " + "%2 = OpTypeInt 32 1 " + "%3 = OpTypeFloat 32"); + InSequence calls_expected_in_specific_order; + EXPECT_HEADER(4).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(MakeParsedVoidTypeInstruction(1))) + .WillOnce(Return(SPV_REQUESTED_TERMINATION)); + Parse(words, SPV_REQUESTED_TERMINATION, endian_swap); + // On early termination, the binary parser doesn't generate its own + // diagnostics. + EXPECT_EQ(nullptr, diagnostic_); + } +} + +TEST_F(BinaryParseTest, EarlyReturnWithTwoPassingCallbacks) { + for (bool endian_swap : kSwapEndians) { + const auto words = CompileSuccessfully( + "%1 = OpTypeVoid " + "%2 = OpTypeInt 32 1 " + "%3 = OpTypeFloat 32"); + InSequence calls_expected_in_specific_order; + EXPECT_HEADER(4).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(MakeParsedVoidTypeInstruction(1))) + .WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(MakeParsedInt32TypeInstruction(2))) + .WillOnce(Return(SPV_REQUESTED_TERMINATION)); + Parse(words, SPV_REQUESTED_TERMINATION, endian_swap); + // On early termination, the binary parser doesn't generate its own + // diagnostics. 
+ EXPECT_EQ(nullptr, diagnostic_); + } +} + +TEST_F(BinaryParseTest, InstructionWithStringOperand) { + const std::string str = + "the future is already here, it's just not evenly distributed"; + const auto str_words = MakeVector(str); + const auto instruction = MakeInstruction(SpvOpName, {99}, str_words); + const auto words = Concatenate({ExpectedHeaderForBound(100), instruction}); + InSequence calls_expected_in_specific_order; + EXPECT_HEADER(100).WillOnce(Return(SPV_SUCCESS)); + const auto operands = std::vector{ + MakeSimpleOperand(1, SPV_OPERAND_TYPE_ID), + MakeLiteralStringOperand(2, static_cast(str_words.size()))}; + EXPECT_CALL(client_, + Instruction(ParsedInstruction(spv_parsed_instruction_t{ + instruction.data(), static_cast(instruction.size()), + SpvOpName, SPV_EXT_INST_TYPE_NONE, 0 /*type id*/, + 0 /* No result id for OpName*/, operands.data(), + static_cast(operands.size())}))) + .WillOnce(Return(SPV_SUCCESS)); + // Since we are actually checking the output, don't test the + // endian-swapped version. + Parse(words, SPV_SUCCESS, false); + EXPECT_EQ(nullptr, diagnostic_); +} + +// Checks for non-zero values for the result_id and ext_inst_type members of +// spv_parsed_instruction_t.
+TEST_F(BinaryParseTest, ExtendedInstruction) { + const auto words = CompileSuccessfully( + "%extcl = OpExtInstImport \"OpenCL.std\" " + "%result = OpExtInst %float %extcl sqrt %x"); + EXPECT_HEADER(5).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).WillOnce(Return(SPV_SUCCESS)); + // We're only interested in the second call to Instruction(): + const auto operands = std::vector{ + MakeSimpleOperand(1, SPV_OPERAND_TYPE_TYPE_ID), + MakeSimpleOperand(2, SPV_OPERAND_TYPE_RESULT_ID), + MakeSimpleOperand(3, + SPV_OPERAND_TYPE_ID), // Extended instruction set Id + MakeSimpleOperand(4, SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER), + MakeSimpleOperand(5, SPV_OPERAND_TYPE_ID), // Id of the argument + }; + const auto instruction = MakeInstruction( + SpvOpExtInst, + {2, 3, 1, static_cast(OpenCLLIB::Entrypoints::Sqrt), 4}); + EXPECT_CALL(client_, + Instruction(ParsedInstruction(spv_parsed_instruction_t{ + instruction.data(), static_cast(instruction.size()), + SpvOpExtInst, SPV_EXT_INST_TYPE_OPENCL_STD, 2 /*type id*/, + 3 /*result id*/, operands.data(), + static_cast(operands.size())}))) + .WillOnce(Return(SPV_SUCCESS)); + // Since we are actually checking the output, don't test the + // endian-swapped version. + Parse(words, SPV_SUCCESS, false); + EXPECT_EQ(nullptr, diagnostic_); +} + +// A binary parser diagnostic test case where we provide the words array +// pointer and word count explicitly. 
+struct WordsAndCountDiagnosticCase { + const uint32_t* words; + size_t num_words; + std::string expected_diagnostic; +}; + +using BinaryParseWordsAndCountDiagnosticTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>; + +TEST_P(BinaryParseWordsAndCountDiagnosticTest, WordAndCountCases) { + EXPECT_EQ( + SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ScopedContext().context, nullptr, GetParam().words, + GetParam().num_words, nullptr, nullptr, &diagnostic)); + ASSERT_NE(nullptr, diagnostic); + EXPECT_THAT(diagnostic->error, Eq(GetParam().expected_diagnostic)); +} + +INSTANTIATE_TEST_CASE_P( + BinaryParseDiagnostic, BinaryParseWordsAndCountDiagnosticTest, + ::testing::ValuesIn(std::vector{ + {nullptr, 0, "Missing module."}, + {kHeaderForBound1, 0, + "Module has incomplete header: only 0 words instead of 5"}, + {kHeaderForBound1, 1, + "Module has incomplete header: only 1 words instead of 5"}, + {kHeaderForBound1, 2, + "Module has incomplete header: only 2 words instead of 5"}, + {kHeaderForBound1, 3, + "Module has incomplete header: only 3 words instead of 5"}, + {kHeaderForBound1, 4, + "Module has incomplete header: only 4 words instead of 5"}, + }), ); + +// A binary parser diagnostic test case where a vector of words is +// provided. We'll use this to express cases that can't be created +// via the assembler. Either we want to make a malformed instruction, +// or an invalid case the assembler would reject. 
+struct WordVectorDiagnosticCase { + std::vector words; + std::string expected_diagnostic; +}; + +using BinaryParseWordVectorDiagnosticTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>; + +TEST_P(BinaryParseWordVectorDiagnosticTest, WordVectorCases) { + const auto& words = GetParam().words; + EXPECT_THAT(spvBinaryParse(ScopedContext().context, nullptr, words.data(), + words.size(), nullptr, nullptr, &diagnostic), + AnyOf(SPV_ERROR_INVALID_BINARY, SPV_ERROR_INVALID_ID)); + ASSERT_NE(nullptr, diagnostic); + EXPECT_THAT(diagnostic->error, Eq(GetParam().expected_diagnostic)); +} + +INSTANTIATE_TEST_CASE_P( + BinaryParseDiagnostic, BinaryParseWordVectorDiagnosticTest, + ::testing::ValuesIn(std::vector{ + {Concatenate({ExpectedHeaderForBound(1), {spvOpcodeMake(0, SpvOpNop)}}), + "Invalid instruction word count: 0"}, + {Concatenate( + {ExpectedHeaderForBound(1), + {spvOpcodeMake(1, static_cast( + std::numeric_limits::max()))}}), + "Invalid opcode: 65535"}, + {Concatenate({ExpectedHeaderForBound(1), + MakeInstruction(SpvOpNop, {42})}), + "Invalid instruction OpNop starting at word 5: expected " + "no more operands after 1 words, but stated word count is 2."}, + // Supply several more unexpected words.
+ {Concatenate({ExpectedHeaderForBound(1), + MakeInstruction(SpvOpNop, {42, 43, 44, 45, 46, 47})}), + "Invalid instruction OpNop starting at word 5: expected " + "no more operands after 1 words, but stated word count is 7."}, + {Concatenate({ExpectedHeaderForBound(1), + MakeInstruction(SpvOpTypeVoid, {1, 2})}), + "Invalid instruction OpTypeVoid starting at word 5: expected " + "no more operands after 2 words, but stated word count is 3."}, + {Concatenate({ExpectedHeaderForBound(1), + MakeInstruction(SpvOpTypeVoid, {1, 2, 5, 9, 10})}), + "Invalid instruction OpTypeVoid starting at word 5: expected " + "no more operands after 2 words, but stated word count is 6."}, + {Concatenate({ExpectedHeaderForBound(1), + MakeInstruction(SpvOpTypeInt, {1, 32, 1, 9})}), + "Invalid instruction OpTypeInt starting at word 5: expected " + "no more operands after 4 words, but stated word count is 5."}, + {Concatenate({ExpectedHeaderForBound(1), + MakeInstruction(SpvOpTypeInt, {1})}), + "End of input reached while decoding OpTypeInt starting at word 5:" + " expected more operands after 2 words."}, + + // Check several cases for running off the end of input. + + // Detect a missing single word operand. + {Concatenate({ExpectedHeaderForBound(1), + {spvOpcodeMake(2, SpvOpTypeStruct)}}), + "End of input reached while decoding OpTypeStruct starting at word" + " 5: missing result ID operand at word offset 1."}, + // Detect a missing multi-word operand to OpConstant. + // We also lie and say the OpConstant instruction has 5 words when + // it only has 3.
Corresponds to something like this: + // %1 = OpTypeInt 64 0 + // %2 = OpConstant %1 + {Concatenate({ExpectedHeaderForBound(3), + {MakeInstruction(SpvOpTypeInt, {1, 64, 0})}, + {spvOpcodeMake(5, SpvOpConstant), 1, 2}}), + "End of input reached while decoding OpConstant starting at word" + " 9: missing possibly multi-word literal number operand at word " + "offset 3."}, + // Detect when we provide only one word from the 64-bit literal, + // and again lie about the number of words in the instruction. + {Concatenate({ExpectedHeaderForBound(3), + {MakeInstruction(SpvOpTypeInt, {1, 64, 0})}, + {spvOpcodeMake(5, SpvOpConstant), 1, 2, 42}}), + "End of input reached while decoding OpConstant starting at word" + " 9: truncated possibly multi-word literal number operand at word " + "offset 3."}, + // Detect when a required string operand is missing. + // Also, lie about the length of the instruction. + {Concatenate({ExpectedHeaderForBound(3), + {spvOpcodeMake(3, SpvOpString), 1}}), + "End of input reached while decoding OpString starting at word" + " 5: missing literal string operand at word offset 2."}, + // Detect when a required string operand is truncated: it's missing + // a null terminator. Catching the error avoids a buffer overrun. + {Concatenate({ExpectedHeaderForBound(3), + {spvOpcodeMake(4, SpvOpString), 1, 0x41414141, + 0x41414141}}), + "End of input reached while decoding OpString starting at word" + " 5: truncated literal string operand at word offset 2."}, + // Detect when an optional string operand is truncated: it's missing + // a null terminator. Catching the error avoids a buffer overrun. + // (It is valid for an optional string operand to be absent.) 
+ {Concatenate({ExpectedHeaderForBound(3), + {spvOpcodeMake(6, SpvOpSource), + static_cast(SpvSourceLanguageOpenCL_C), 210, + 1 /* file id */, + /*start of string*/ 0x41414141, 0x41414141}}), + "End of input reached while decoding OpSource starting at word" + " 5: truncated literal string operand at word offset 4."}, + + // (End of input exhaustion test cases.) + + // In this case the instruction word count is too small, where + // it would truncate a multi-word operand to OpConstant. + {Concatenate({ExpectedHeaderForBound(3), + {MakeInstruction(SpvOpTypeInt, {1, 64, 0})}, + {spvOpcodeMake(4, SpvOpConstant), 1, 2, 44, 44}}), + "Invalid word count: OpConstant starting at word 9 says it has 4" + " words, but found 5 words instead."}, + // Word count is too small, where it would truncate a literal string. + {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(3, SpvOpString), 1, 0x41414141, 0}}), + "Invalid word count: OpString starting at word 5 says it has 3" + " words, but found 4 words instead."}, + // Word count is too large. The string terminates before the last + // word. + {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(4, SpvOpString), 1 /* result id */}, + MakeVector("abc"), + {0 /* this word does not belong*/}}), + "Invalid instruction OpString starting at word 5: expected no more" + " operands after 3 words, but stated word count is 4."}, + // Word count is too large. There are too many words after the string + // literal. A linkage attribute decoration is the only case in SPIR-V + // where a string operand is followed by another operand.
+ {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(6, SpvOpDecorate), 1 /* target id */, + static_cast(SpvDecorationLinkageAttributes)}, + MakeVector("abc"), + {static_cast(SpvLinkageTypeImport), + 0 /* does not belong */}}), + "Invalid instruction OpDecorate starting at word 5: expected no more" + " operands after 5 words, but stated word count is 6."}, + // Like the previous case, but with 5 extra words. + {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(10, SpvOpDecorate), 1 /* target id */, + static_cast(SpvDecorationLinkageAttributes)}, + MakeVector("abc"), + {static_cast(SpvLinkageTypeImport), + /* don't belong */ 0, 1, 2, 3, 4}}), + "Invalid instruction OpDecorate starting at word 5: expected no more" + " operands after 5 words, but stated word count is 10."}, + // Like the previous two cases, but with OpMemberDecorate. + {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(7, SpvOpMemberDecorate), 1 /* target id */, + 42 /* member index */, + static_cast(SpvDecorationLinkageAttributes)}, + MakeVector("abc"), + {static_cast(SpvLinkageTypeImport), + 0 /* does not belong */}}), + "Invalid instruction OpMemberDecorate starting at word 5: expected no" + " more operands after 6 words, but stated word count is 7."}, + {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(11, SpvOpMemberDecorate), + 1 /* target id */, 42 /* member index */, + static_cast(SpvDecorationLinkageAttributes)}, + MakeVector("abc"), + {static_cast(SpvLinkageTypeImport), + /* don't belong */ 0, 1, 2, 3, 4}}), + "Invalid instruction OpMemberDecorate starting at word 5: expected no" + " more operands after 6 words, but stated word count is 11."}, + // Word count is too large. There should be no more words + // after the RelaxedPrecision decoration. 
+ {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(4, SpvOpDecorate), 1 /* target id */, + static_cast(SpvDecorationRelaxedPrecision), + 0 /* does not belong */}}), + "Invalid instruction OpDecorate starting at word 5: expected no" + " more operands after 3 words, but stated word count is 4."}, + // Word count is too large. There should be only one word after + // the SpecId decoration enum word. + {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(5, SpvOpDecorate), 1 /* target id */, + static_cast(SpvDecorationSpecId), + 42 /* the spec id */, 0 /* does not belong */}}), + "Invalid instruction OpDecorate starting at word 5: expected no" + " more operands after 4 words, but stated word count is 5."}, + {Concatenate({ExpectedHeaderForBound(2), + {spvOpcodeMake(2, SpvOpTypeVoid), 0}}), + "Error: Result Id is 0"}, + {Concatenate({ + ExpectedHeaderForBound(2), + {spvOpcodeMake(2, SpvOpTypeVoid), 1}, + {spvOpcodeMake(2, SpvOpTypeBool), 1}, + }), + "Id 1 is defined more than once"}, + {Concatenate({ExpectedHeaderForBound(3), + MakeInstruction(SpvOpExtInst, {2, 3, 100, 4, 5})}), + "OpExtInst set Id 100 does not reference an OpExtInstImport result " + "Id"}, + {Concatenate({ExpectedHeaderForBound(101), + MakeInstruction(SpvOpExtInstImport, {100}, + MakeVector("OpenCL.std")), + // OpenCL cos is #14 + MakeInstruction(SpvOpExtInst, {2, 3, 100, 14, 5, 999})}), + "Invalid instruction OpExtInst starting at word 10: expected no " + "more operands after 6 words, but stated word count is 7."}, + // In this case, the OpSwitch selector refers to an invalid ID. + {Concatenate({ExpectedHeaderForBound(3), + MakeInstruction(SpvOpSwitch, {1, 2, 42, 3})}), + "Invalid OpSwitch: selector id 1 has no type"}, + // In this case, the OpSwitch selector refers to an ID that has + // no type. 
+ {Concatenate({ExpectedHeaderForBound(3), + MakeInstruction(SpvOpLabel, {1}), + MakeInstruction(SpvOpSwitch, {1, 2, 42, 3})}), + "Invalid OpSwitch: selector id 1 has no type"}, + {Concatenate({ExpectedHeaderForBound(3), + MakeInstruction(SpvOpTypeInt, {1, 32, 0}), + MakeInstruction(SpvOpSwitch, {1, 3, 42, 3})}), + "Invalid OpSwitch: selector id 1 is a type, not a value"}, + {Concatenate({ExpectedHeaderForBound(3), + MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpConstant, {1, 2, 0x78f00000}), + MakeInstruction(SpvOpSwitch, {2, 3, 42, 3})}), + "Invalid OpSwitch: selector id 2 is not a scalar integer"}, + {Concatenate({ExpectedHeaderForBound(3), + MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("invalid-import"))}), + "Invalid extended instruction import 'invalid-import'"}, + {Concatenate({ + ExpectedHeaderForBound(3), + MakeInstruction(SpvOpTypeInt, {1, 32, 0}), + MakeInstruction(SpvOpConstant, {2, 2, 42}), + }), + "Type Id 2 is not a type"}, + {Concatenate({ + ExpectedHeaderForBound(3), + MakeInstruction(SpvOpTypeBool, {1}), + MakeInstruction(SpvOpConstant, {1, 2, 42}), + }), + "Type Id 1 is not a scalar numeric type"}, + }), ); + +// A binary parser diagnostic case generated from an assembly text input. 
+struct AssemblyDiagnosticCase { + std::string assembly; + std::string expected_diagnostic; +}; + +using BinaryParseAssemblyDiagnosticTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>; + +TEST_P(BinaryParseAssemblyDiagnosticTest, AssemblyCases) { + auto words = CompileSuccessfully(GetParam().assembly); + EXPECT_THAT(spvBinaryParse(ScopedContext().context, nullptr, words.data(), + words.size(), nullptr, nullptr, &diagnostic), + AnyOf(SPV_ERROR_INVALID_BINARY, SPV_ERROR_INVALID_ID)); + ASSERT_NE(nullptr, diagnostic); + EXPECT_THAT(diagnostic->error, Eq(GetParam().expected_diagnostic)); +} + +INSTANTIATE_TEST_CASE_P( + BinaryParseDiagnostic, BinaryParseAssemblyDiagnosticTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpConstant !0 42", "Error: Type Id is 0"}, + // A required id is 0. + {"OpName !0 \"foo\"", "Id is 0"}, + // An optional id is 0, in this case the optional + // initializer. + {"%2 = OpVariable %1 CrossWorkgroup !0", "Id is 0"}, + {"OpControlBarrier !0 %1 %2", "scope ID is 0"}, + {"OpControlBarrier %1 !0 %2", "scope ID is 0"}, + {"OpControlBarrier %1 %2 !0", "memory semantics ID is 0"}, + {"%import = OpExtInstImport \"GLSL.std.450\" " + "%result = OpExtInst %type %import !999999 %x", + "Invalid extended instruction number: 999999"}, + {"%2 = OpSpecConstantOp %1 !1000 %2", + "Invalid OpSpecConstantOp opcode: 1000"}, + {"OpCapability !9999", "Invalid capability operand: 9999"}, + {"OpSource !9999 100", "Invalid source language operand: 9999"}, + {"OpEntryPoint !9999", "Invalid execution model operand: 9999"}, + {"OpMemoryModel !9999", "Invalid addressing model operand: 9999"}, + {"OpMemoryModel Logical !9999", "Invalid memory model operand: 9999"}, + {"OpExecutionMode %1 !9999", "Invalid execution mode operand: 9999"}, + {"OpTypeForwardPointer %1 !9999", + "Invalid storage class operand: 9999"}, + {"%2 = OpTypeImage %1 !9999", "Invalid dimensionality operand: 9999"}, + {"%2 = OpTypeImage %1 1D 0 0 0 0 !9999", + "Invalid image format 
operand: 9999"}, + {"OpDecorate %1 FPRoundingMode !9999", + "Invalid floating-point rounding mode operand: 9999"}, + {"OpDecorate %1 LinkageAttributes \"C\" !9999", + "Invalid linkage type operand: 9999"}, + {"%1 = OpTypePipe !9999", "Invalid access qualifier operand: 9999"}, + {"OpDecorate %1 FuncParamAttr !9999", + "Invalid function parameter attribute operand: 9999"}, + {"OpDecorate %1 !9999", "Invalid decoration operand: 9999"}, + {"OpDecorate %1 BuiltIn !9999", "Invalid built-in operand: 9999"}, + {"%2 = OpGroupIAdd %1 %3 !9999", + "Invalid group operation operand: 9999"}, + {"OpDecorate %1 FPFastMathMode !63", + "Invalid floating-point fast math mode operand: 63 has invalid mask " + "component 32"}, + {"%2 = OpFunction %2 !31", + "Invalid function control operand: 31 has invalid mask component 16"}, + {"OpLoopMerge %1 %2 !1027", + "Invalid loop control operand: 1027 has invalid mask component 1024"}, + {"%2 = OpImageFetch %1 %image %coord !32770", + "Invalid image operand: 32770 has invalid mask component 32768"}, + {"OpSelectionMerge %1 !7", + "Invalid selection control operand: 7 has invalid mask component 4"}, + }), ); + +} // namespace +} // namespace spvtools diff --git a/test/binary_strnlen_s_test.cpp b/test/binary_strnlen_s_test.cpp new file mode 100644 index 000000000..5f43bde67 --- /dev/null +++ b/test/binary_strnlen_s_test.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +TEST(Strnlen, Samples) { + EXPECT_EQ(0u, spv_strnlen_s(nullptr, 0)); + EXPECT_EQ(0u, spv_strnlen_s(nullptr, 5)); + EXPECT_EQ(0u, spv_strnlen_s("abc", 0)); + EXPECT_EQ(1u, spv_strnlen_s("abc", 1)); + EXPECT_EQ(3u, spv_strnlen_s("abc", 3)); + EXPECT_EQ(3u, spv_strnlen_s("abc\0", 5)); + EXPECT_EQ(0u, spv_strnlen_s("\0", 5)); + EXPECT_EQ(1u, spv_strnlen_s("a\0c", 5)); +} + +} // namespace +} // namespace spvtools diff --git a/test/binary_to_text.literal_test.cpp b/test/binary_to_text.literal_test.cpp new file mode 100644 index 000000000..bcfb0f016 --- /dev/null +++ b/test/binary_to_text.literal_test.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using ::testing::Eq; +using RoundTripLiteralsTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(RoundTripLiteralsTest, Sample) { + EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam())); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + StringLiterals, RoundTripLiteralsTest, + ::testing::ValuesIn(std::vector{ + "OpName %1 \"\"\n", // empty + "OpName %1 \"foo\"\n", // normal + "OpName %1 \"foo bar\"\n", // string with spaces + "OpName %1 \"foo\tbar\"\n", // string with tab + "OpName %1 \"\tfoo\"\n", // starts with tab + "OpName %1 \" foo\"\n", // starts with space + "OpName %1 \"foo \"\n", // ends with space + "OpName %1 \"foo\t\"\n", // ends with tab + "OpName %1 \"foo\nbar\"\n", // contains newline + "OpName %1 \"\nfoo\nbar\"\n", // starts with newline + "OpName %1 \"\n\n\nfoo\nbar\"\n", // multiple newlines + "OpName %1 \"\\\"foo\nbar\\\"\"\n", // escaped quote + "OpName %1 \"\\\\foo\nbar\\\\\"\n", // escaped backslash + "OpName %1 \"\xE4\xBA\xB2\"\n", // UTF-8 + }),); +// clang-format on + +using RoundTripSpecialCaseLiteralsTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +// Test case where the generated disassembly is not the same as the +// assembly passed in. 
+TEST_P(RoundTripSpecialCaseLiteralsTest, Sample) { + EXPECT_THAT(EncodeAndDecodeSuccessfully(std::get<0>(GetParam())), + Eq(std::get<1>(GetParam()))); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + StringLiterals, RoundTripSpecialCaseLiteralsTest, + ::testing::ValuesIn(std::vector>{ + {"OpName %1 \"\\foo\"\n", "OpName %1 \"foo\"\n"}, // Escape f + {"OpName %1 \"\\\nfoo\"\n", "OpName %1 \"\nfoo\"\n"}, // Escape newline + {"OpName %1 \"\\\xE4\xBA\xB2\"\n", "OpName %1 \"\xE4\xBA\xB2\"\n"}, // Escape utf-8 + }),); +// clang-format on + +} // namespace +} // namespace spvtools diff --git a/test/binary_to_text_test.cpp b/test/binary_to_text_test.cpp new file mode 100644 index 000000000..00ac86bbd --- /dev/null +++ b/test/binary_to_text_test.cpp @@ -0,0 +1,561 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/spirv_constant.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::AutoText; +using spvtest::ScopedContext; +using spvtest::TextToBinaryTest; +using ::testing::Combine; +using ::testing::Eq; +using ::testing::HasSubstr; + +class BinaryToText : public ::testing::Test { + public: + BinaryToText() + : context(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)), binary(nullptr) {} + ~BinaryToText() { + spvBinaryDestroy(binary); + spvContextDestroy(context); + } + + virtual void SetUp() { + const char* textStr = R"( + OpSource OpenCL_C 12 + OpMemoryModel Physical64 OpenCL + OpSourceExtension "PlaceholderExtensionName" + OpEntryPoint Kernel %1 "foo" + OpExecutionMode %1 LocalSizeHint 1 1 1 + %2 = OpTypeVoid + %3 = OpTypeBool + %4 = OpTypeInt 8 0 + %5 = OpTypeInt 8 1 + %6 = OpTypeInt 16 0 + %7 = OpTypeInt 16 1 + %8 = OpTypeInt 32 0 + %9 = OpTypeInt 32 1 +%10 = OpTypeInt 64 0 +%11 = OpTypeInt 64 1 +%12 = OpTypeFloat 16 +%13 = OpTypeFloat 32 +%14 = OpTypeFloat 64 +%15 = OpTypeVector %4 2 +)"; + spv_text_t text = {textStr, strlen(textStr)}; + spv_diagnostic diagnostic = nullptr; + spv_result_t error = + spvTextToBinary(context, text.str, text.length, &binary, &diagnostic); + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + ASSERT_EQ(SPV_SUCCESS, error); + } + + virtual void TearDown() { + spvBinaryDestroy(binary); + binary = nullptr; + } + + // Compiles the given assembly text, and saves it into 'binary'. 
+  void CompileSuccessfully(std::string text) {
+    spvBinaryDestroy(binary);
+    binary = nullptr;
+    spv_diagnostic diagnostic = nullptr;
+    EXPECT_EQ(SPV_SUCCESS, spvTextToBinary(context, text.c_str(), text.size(),
+                                           &binary, &diagnostic));
+    // Free whatever diagnostic a failed assembly produced so it is not
+    // leaked. SetUp() makes the same unconditional call on its success path.
+    spvDiagnosticDestroy(diagnostic);
+  }
+
+  // Context shared by every test in the fixture; created in the constructor.
+  spv_context context;
+  // Most recently assembled module; owned and destroyed by the fixture.
+  spv_binary binary;
+};
+
+// Disassembles the module assembled in SetUp() with default options and
+// prints the resulting text.
+TEST_F(BinaryToText, Default) {
+  spv_text text = nullptr;
+  spv_diagnostic diagnostic = nullptr;
+  ASSERT_EQ(
+      SPV_SUCCESS,
+      spvBinaryToText(context, binary->code, binary->wordCount,
+                      SPV_BINARY_TO_TEXT_OPTION_NONE, &text, &diagnostic));
+  printf("%s", text->str);
+  spvTextDestroy(text);
+}
+
+// A null code pointer must be rejected with a "Missing module." diagnostic.
+TEST_F(BinaryToText, MissingModule) {
+  spv_text text = nullptr;
+  spv_diagnostic diagnostic = nullptr;
+  EXPECT_EQ(
+      SPV_ERROR_INVALID_BINARY,
+      spvBinaryToText(context, nullptr, 42, SPV_BINARY_TO_TEXT_OPTION_NONE,
+                      &text, &diagnostic));
+  // Stop here if no diagnostic was produced: the next check dereferences it.
+  ASSERT_NE(nullptr, diagnostic);
+  EXPECT_THAT(diagnostic->error, Eq(std::string("Missing module.")));
+  spvDiagnosticPrint(diagnostic);
+  spvDiagnosticDestroy(diagnostic);
+}
+
+TEST_F(BinaryToText, TruncatedModule) {
+  // Make a valid module with zero instructions.
+ CompileSuccessfully(""); + EXPECT_EQ(SPV_INDEX_INSTRUCTION, binary->wordCount); + + for (size_t length = 0; length < SPV_INDEX_INSTRUCTION; length++) { + spv_text text = nullptr; + spv_diagnostic diagnostic = nullptr; + EXPECT_EQ( + SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, length, + SPV_BINARY_TO_TEXT_OPTION_NONE, &text, &diagnostic)); + ASSERT_NE(nullptr, diagnostic); + std::stringstream expected; + expected << "Module has incomplete header: only " << length + << " words instead of " << SPV_INDEX_INSTRUCTION; + EXPECT_THAT(diagnostic->error, Eq(expected.str())); + spvDiagnosticDestroy(diagnostic); + } +} + +TEST_F(BinaryToText, InvalidMagicNumber) { + CompileSuccessfully(""); + std::vector damaged_binary(binary->code, + binary->code + binary->wordCount); + damaged_binary[SPV_INDEX_MAGIC_NUMBER] ^= 123; + + spv_diagnostic diagnostic = nullptr; + spv_text text; + EXPECT_EQ( + SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, damaged_binary.data(), damaged_binary.size(), + SPV_BINARY_TO_TEXT_OPTION_NONE, &text, &diagnostic)); + ASSERT_NE(nullptr, diagnostic); + std::stringstream expected; + expected << "Invalid SPIR-V magic number '" << std::hex + << damaged_binary[SPV_INDEX_MAGIC_NUMBER] << "'."; + EXPECT_THAT(diagnostic->error, Eq(expected.str())); + spvDiagnosticDestroy(diagnostic); +} + +struct FailedDecodeCase { + std::string source_text; + std::vector appended_instruction; + std::string expected_error_message; +}; + +using BinaryToTextFail = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(BinaryToTextFail, EncodeSuccessfullyDecodeFailed) { + EXPECT_THAT(EncodeSuccessfullyDecodeFailed(GetParam().source_text, + GetParam().appended_instruction), + Eq(GetParam().expected_error_message)); +} + +INSTANTIATE_TEST_CASE_P( + InvalidIds, BinaryToTextFail, + ::testing::ValuesIn(std::vector{ + {"", spvtest::MakeInstruction(SpvOpTypeVoid, {0}), + "Error: Result Id is 0"}, + {"", spvtest::MakeInstruction(SpvOpConstant, {0, 
1, 42}), + "Error: Type Id is 0"}, + {"%1 = OpTypeVoid", spvtest::MakeInstruction(SpvOpTypeVoid, {1}), + "Id 1 is defined more than once"}, + {"%1 = OpTypeVoid\n" + "%2 = OpNot %1 %foo", + spvtest::MakeInstruction(SpvOpNot, {1, 2, 3}), + "Id 2 is defined more than once"}, + {"%1 = OpTypeVoid\n" + "%2 = OpNot %1 %foo", + spvtest::MakeInstruction(SpvOpNot, {1, 1, 3}), + "Id 1 is defined more than once"}, + // The following are the two failure cases for + // Parser::setNumericTypeInfoForType. + {"", spvtest::MakeInstruction(SpvOpConstant, {500, 1, 42}), + "Type Id 500 is not a type"}, + {"%1 = OpTypeInt 32 0\n" + "%2 = OpTypeVector %1 4", + spvtest::MakeInstruction(SpvOpConstant, {2, 3, 999}), + "Type Id 2 is not a scalar numeric type"}, + }), ); + +INSTANTIATE_TEST_CASE_P( + InvalidIdsCheckedDuringLiteralCaseParsing, BinaryToTextFail, + ::testing::ValuesIn(std::vector{ + {"", spvtest::MakeInstruction(SpvOpSwitch, {1, 2, 3, 4}), + "Invalid OpSwitch: selector id 1 has no type"}, + {"%1 = OpTypeVoid\n", + spvtest::MakeInstruction(SpvOpSwitch, {1, 2, 3, 4}), + "Invalid OpSwitch: selector id 1 is a type, not a value"}, + {"%1 = OpConstantTrue !500", + spvtest::MakeInstruction(SpvOpSwitch, {1, 2, 3, 4}), + "Type Id 500 is not a type"}, + {"%1 = OpTypeFloat 32\n%2 = OpConstant %1 1.5", + spvtest::MakeInstruction(SpvOpSwitch, {2, 3, 4, 5}), + "Invalid OpSwitch: selector id 2 is not a scalar integer"}, + }), ); + +TEST_F(TextToBinaryTest, OneInstruction) { + const std::string input = "OpSource OpenCL_C 12\n"; + EXPECT_EQ(input, EncodeAndDecodeSuccessfully(input)); +} + +// Exercise the case where an operand itself has operands. +// This could detect problems in updating the expected-set-of-operands +// list. 
+TEST_F(TextToBinaryTest, OperandWithOperands) { + const std::string input = R"(OpEntryPoint Kernel %1 "foo" +OpExecutionMode %1 LocalSizeHint 100 200 300 +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %1 None %3 +)"; + EXPECT_EQ(input, EncodeAndDecodeSuccessfully(input)); +} + +using RoundTripInstructionsTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(RoundTripInstructionsTest, Sample) { + EXPECT_THAT(EncodeAndDecodeSuccessfully(std::get<1>(GetParam()), + SPV_BINARY_TO_TEXT_OPTION_NONE, + std::get<0>(GetParam())), + Eq(std::get<1>(GetParam()))); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + NumericLiterals, RoundTripInstructionsTest, + // This test is independent of environment, so just test the one. + Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3), + ::testing::ValuesIn(std::vector{ + "%1 = OpTypeInt 12 0\n%2 = OpConstant %1 1867\n", + "%1 = OpTypeInt 12 1\n%2 = OpConstant %1 1867\n", + "%1 = OpTypeInt 12 1\n%2 = OpConstant %1 -1867\n", + "%1 = OpTypeInt 32 0\n%2 = OpConstant %1 1867\n", + "%1 = OpTypeInt 32 1\n%2 = OpConstant %1 1867\n", + "%1 = OpTypeInt 32 1\n%2 = OpConstant %1 -1867\n", + "%1 = OpTypeInt 64 0\n%2 = OpConstant %1 18446744073709551615\n", + "%1 = OpTypeInt 64 1\n%2 = OpConstant %1 9223372036854775807\n", + "%1 = OpTypeInt 64 1\n%2 = OpConstant %1 -9223372036854775808\n", + // 16-bit floats print as hex floats. 
+ "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ff4p+16\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.d2cp-10\n", + // 32-bit floats + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -3.125\n", + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.8p+128\n", // NaN + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1.0002p+128\n", // NaN + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1p+128\n", // Inf + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1p+128\n", // -Inf + // 64-bit floats + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -3.125\n", + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1.ffffffffffffap-1023\n", // small normal + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.ffffffffffffap-1023\n", + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1.8p+1024\n", // NaN + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.0002p+1024\n", // NaN + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1p+1024\n", // Inf + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1p+1024\n", // -Inf + })), ); +// clang-format on + +INSTANTIATE_TEST_CASE_P( + MemoryAccessMasks, RoundTripInstructionsTest, + Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3), + ::testing::ValuesIn(std::vector{ + "OpStore %1 %2\n", // 3 words long. + "OpStore %1 %2 None\n", // 4 words long, explicit final 0. 
+ "OpStore %1 %2 Volatile\n", + "OpStore %1 %2 Aligned 8\n", + "OpStore %1 %2 Nontemporal\n", + // Combinations show the names from LSB to MSB + "OpStore %1 %2 Volatile|Aligned 16\n", + "OpStore %1 %2 Volatile|Nontemporal\n", + "OpStore %1 %2 Volatile|Aligned|Nontemporal 32\n", + })), ); + +INSTANTIATE_TEST_CASE_P( + FPFastMathModeMasks, RoundTripInstructionsTest, + Combine( + ::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3), + ::testing::ValuesIn(std::vector{ + "OpDecorate %1 FPFastMathMode None\n", + "OpDecorate %1 FPFastMathMode NotNaN\n", + "OpDecorate %1 FPFastMathMode NotInf\n", + "OpDecorate %1 FPFastMathMode NSZ\n", + "OpDecorate %1 FPFastMathMode AllowRecip\n", + "OpDecorate %1 FPFastMathMode Fast\n", + // Combinations show the names from LSB to MSB + "OpDecorate %1 FPFastMathMode NotNaN|NotInf\n", + "OpDecorate %1 FPFastMathMode NSZ|AllowRecip\n", + "OpDecorate %1 FPFastMathMode NotNaN|NotInf|NSZ|AllowRecip|Fast\n", + })), ); + +INSTANTIATE_TEST_CASE_P( + LoopControlMasks, RoundTripInstructionsTest, + Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_2), + ::testing::ValuesIn(std::vector{ + "OpLoopMerge %1 %2 None\n", + "OpLoopMerge %1 %2 Unroll\n", + "OpLoopMerge %1 %2 DontUnroll\n", + "OpLoopMerge %1 %2 Unroll|DontUnroll\n", + })), ); + +INSTANTIATE_TEST_CASE_P(LoopControlMasksV11, RoundTripInstructionsTest, + Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, + SPV_ENV_UNIVERSAL_1_3), + ::testing::ValuesIn(std::vector{ + "OpLoopMerge %1 %2 DependencyInfinite\n", + "OpLoopMerge %1 %2 DependencyLength 8\n", + })), ); + +INSTANTIATE_TEST_CASE_P( + SelectionControlMasks, RoundTripInstructionsTest, + Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_2), + ::testing::ValuesIn(std::vector{ + "OpSelectionMerge %1 None\n", + "OpSelectionMerge %1 
Flatten\n", + "OpSelectionMerge %1 DontFlatten\n", + "OpSelectionMerge %1 Flatten|DontFlatten\n", + })), ); + +INSTANTIATE_TEST_CASE_P( + FunctionControlMasks, RoundTripInstructionsTest, + Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3), + ::testing::ValuesIn(std::vector{ + "%2 = OpFunction %1 None %3\n", + "%2 = OpFunction %1 Inline %3\n", + "%2 = OpFunction %1 DontInline %3\n", + "%2 = OpFunction %1 Pure %3\n", + "%2 = OpFunction %1 Const %3\n", + "%2 = OpFunction %1 Inline|Pure|Const %3\n", + "%2 = OpFunction %1 DontInline|Const %3\n", + })), ); + +INSTANTIATE_TEST_CASE_P( + ImageMasks, RoundTripInstructionsTest, + Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3), + ::testing::ValuesIn(std::vector{ + "%2 = OpImageFetch %1 %3 %4\n", + "%2 = OpImageFetch %1 %3 %4 None\n", + "%2 = OpImageFetch %1 %3 %4 Bias %5\n", + "%2 = OpImageFetch %1 %3 %4 Lod %5\n", + "%2 = OpImageFetch %1 %3 %4 Grad %5 %6\n", + "%2 = OpImageFetch %1 %3 %4 ConstOffset %5\n", + "%2 = OpImageFetch %1 %3 %4 Offset %5\n", + "%2 = OpImageFetch %1 %3 %4 ConstOffsets %5\n", + "%2 = OpImageFetch %1 %3 %4 Sample %5\n", + "%2 = OpImageFetch %1 %3 %4 MinLod %5\n", + "%2 = OpImageFetch %1 %3 %4 Bias|Lod|Grad %5 %6 %7 %8\n", + "%2 = OpImageFetch %1 %3 %4 ConstOffset|Offset|ConstOffsets" + " %5 %6 %7\n", + "%2 = OpImageFetch %1 %3 %4 Sample|MinLod %5 %6\n", + "%2 = OpImageFetch %1 %3 %4" + " Bias|Lod|Grad|ConstOffset|Offset|ConstOffsets|Sample|MinLod" + " %5 %6 %7 %8 %9 %10 %11 %12 %13\n"})), ); + +INSTANTIATE_TEST_CASE_P( + NewInstructionsInSPIRV1_2, RoundTripInstructionsTest, + Combine(::testing::Values(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3), + ::testing::ValuesIn(std::vector{ + "OpExecutionModeId %1 SubgroupsPerWorkgroupId %2\n", + "OpExecutionModeId %1 LocalSizeId %2 %3 %4\n", + "OpExecutionModeId %1 LocalSizeHintId %2\n", + "OpDecorateId %1 AlignmentId 
%2\n", + "OpDecorateId %1 MaxByteOffsetId %2\n", + })), ); + +using MaskSorting = TextToBinaryTest; + +TEST_F(MaskSorting, MasksAreSortedFromLSBToMSB) { + EXPECT_THAT(EncodeAndDecodeSuccessfully( + "OpStore %1 %2 Nontemporal|Aligned|Volatile 32"), + Eq("OpStore %1 %2 Volatile|Aligned|Nontemporal 32\n")); + EXPECT_THAT( + EncodeAndDecodeSuccessfully( + "OpDecorate %1 FPFastMathMode NotInf|Fast|AllowRecip|NotNaN|NSZ"), + Eq("OpDecorate %1 FPFastMathMode NotNaN|NotInf|NSZ|AllowRecip|Fast\n")); + EXPECT_THAT( + EncodeAndDecodeSuccessfully("OpLoopMerge %1 %2 DontUnroll|Unroll"), + Eq("OpLoopMerge %1 %2 Unroll|DontUnroll\n")); + EXPECT_THAT( + EncodeAndDecodeSuccessfully("OpSelectionMerge %1 DontFlatten|Flatten"), + Eq("OpSelectionMerge %1 Flatten|DontFlatten\n")); + EXPECT_THAT(EncodeAndDecodeSuccessfully( + "%2 = OpFunction %1 DontInline|Const|Pure|Inline %3"), + Eq("%2 = OpFunction %1 Inline|DontInline|Pure|Const %3\n")); + EXPECT_THAT(EncodeAndDecodeSuccessfully( + "%2 = OpImageFetch %1 %3 %4" + " MinLod|Sample|Offset|Lod|Grad|ConstOffsets|ConstOffset|Bias" + " %5 %6 %7 %8 %9 %10 %11 %12 %13\n"), + Eq("%2 = OpImageFetch %1 %3 %4" + " Bias|Lod|Grad|ConstOffset|Offset|ConstOffsets|Sample|MinLod" + " %5 %6 %7 %8 %9 %10 %11 %12 %13\n")); +} + +using OperandTypeTest = TextToBinaryTest; + +TEST_F(OperandTypeTest, OptionalTypedLiteralNumber) { + const std::string input = + "%1 = OpTypeInt 32 0\n" + "%2 = OpConstant %1 42\n" + "OpSwitch %2 %3 100 %4\n"; + EXPECT_EQ(input, EncodeAndDecodeSuccessfully(input)); +} + +using IndentTest = spvtest::TextToBinaryTest; + +TEST_F(IndentTest, Sample) { + const std::string input = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 %3 %4 %5 %6 %7 %8 %9 %10 ; force IDs into double digits +%11 = OpConstant %1 42 +OpStore %2 %3 Aligned|Volatile 4 ; bogus, but not indented +)"; + const std::string expected = + R"( OpCapability Shader + OpMemoryModel Logical GLSL450 + %1 = OpTypeInt 32 0 + %2 = 
OpTypeStruct %1 %3 %4 %5 %6 %7 %8 %9 %10 + %11 = OpConstant %1 42 + OpStore %2 %3 Volatile|Aligned 4 +)"; + EXPECT_THAT( + EncodeAndDecodeSuccessfully(input, SPV_BINARY_TO_TEXT_OPTION_INDENT), + expected); +} + +using FriendlyNameDisassemblyTest = spvtest::TextToBinaryTest; + +TEST_F(FriendlyNameDisassemblyTest, Sample) { + const std::string input = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 %3 %4 %5 %6 %7 %8 %9 %10 ; force IDs into double digits +%11 = OpConstant %1 42 +)"; + const std::string expected = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +%uint = OpTypeInt 32 0 +%_struct_2 = OpTypeStruct %uint %3 %4 %5 %6 %7 %8 %9 %10 +%uint_42 = OpConstant %uint 42 +)"; + EXPECT_THAT(EncodeAndDecodeSuccessfully( + input, SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES), + expected); +} + +TEST_F(TextToBinaryTest, ShowByteOffsetsWhenRequested) { + const std::string input = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeVoid +)"; + const std::string expected = + R"(OpCapability Shader ; 0x00000014 +OpMemoryModel Logical GLSL450 ; 0x0000001c +%1 = OpTypeInt 32 0 ; 0x00000028 +%2 = OpTypeVoid ; 0x00000038 +)"; + EXPECT_THAT(EncodeAndDecodeSuccessfully( + input, SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET), + expected); +} + +// Test version string. +TEST_F(TextToBinaryTest, VersionString) { + auto words = CompileSuccessfully(""); + spv_text decoded_text = nullptr; + EXPECT_THAT(spvBinaryToText(ScopedContext().context, words.data(), + words.size(), SPV_BINARY_TO_TEXT_OPTION_NONE, + &decoded_text, &diagnostic), + Eq(SPV_SUCCESS)); + EXPECT_EQ(nullptr, diagnostic); + + EXPECT_THAT(decoded_text->str, HasSubstr("Version: 1.0\n")) + << EncodeAndDecodeSuccessfully(""); + spvTextDestroy(decoded_text); +} + +// Test generator string. + +// A test case for the generator string. This allows us to +// test both of the 16-bit components of the generator word. 
+struct GeneratorStringCase { + uint16_t generator; + uint16_t misc; + std::string expected; +}; + +using GeneratorStringTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>; + +TEST_P(GeneratorStringTest, Sample) { + auto words = CompileSuccessfully(""); + EXPECT_EQ(2u, SPV_INDEX_GENERATOR_NUMBER); + words[SPV_INDEX_GENERATOR_NUMBER] = + SPV_GENERATOR_WORD(GetParam().generator, GetParam().misc); + + spv_text decoded_text = nullptr; + EXPECT_THAT(spvBinaryToText(ScopedContext().context, words.data(), + words.size(), SPV_BINARY_TO_TEXT_OPTION_NONE, + &decoded_text, &diagnostic), + Eq(SPV_SUCCESS)); + EXPECT_THAT(diagnostic, Eq(nullptr)); + EXPECT_THAT(std::string(decoded_text->str), HasSubstr(GetParam().expected)); + spvTextDestroy(decoded_text); +} + +INSTANTIATE_TEST_CASE_P(GeneratorStrings, GeneratorStringTest, + ::testing::ValuesIn(std::vector{ + {SPV_GENERATOR_KHRONOS, 12, "Khronos; 12"}, + {SPV_GENERATOR_LUNARG, 99, "LunarG; 99"}, + {SPV_GENERATOR_VALVE, 1, "Valve; 1"}, + {SPV_GENERATOR_CODEPLAY, 65535, "Codeplay; 65535"}, + {SPV_GENERATOR_NVIDIA, 19, "NVIDIA; 19"}, + {SPV_GENERATOR_ARM, 1000, "ARM; 1000"}, + {SPV_GENERATOR_KHRONOS_LLVM_TRANSLATOR, 38, + "Khronos LLVM/SPIR-V Translator; 38"}, + {SPV_GENERATOR_KHRONOS_ASSEMBLER, 2, + "Khronos SPIR-V Tools Assembler; 2"}, + {SPV_GENERATOR_KHRONOS_GLSLANG, 1, + "Khronos Glslang Reference Front End; 1"}, + {1000, 18, "Unknown(1000); 18"}, + {65535, 32767, "Unknown(65535); 32767"}, + }), ); + +// TODO(dneto): Test new instructions and enums in SPIR-V 1.3 + +} // namespace +} // namespace spvtools diff --git a/test/bit_stream.cpp b/test/bit_stream.cpp new file mode 100644 index 000000000..f02faf3c6 --- /dev/null +++ b/test/bit_stream.cpp @@ -0,0 +1,1025 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/comp/bit_stream.h" + +namespace spvtools { +namespace comp { +namespace { + +// Converts |buffer| to a stream of '0' and '1'. +template +std::string BufferToStream(const std::vector& buffer) { + std::stringstream ss; + for (auto it = buffer.begin(); it != buffer.end(); ++it) { + std::string str = std::bitset(*it).to_string(); + // Strings generated by std::bitset::to_string are read right to left. + // Reversing to left to right. + std::reverse(str.begin(), str.end()); + ss << str; + } + return ss.str(); +} + +// Converts a left-to-right input string of '0' and '1' to a buffer of |T| +// words. +template +std::vector StreamToBuffer(std::string str) { + // The input string is left-to-right, the input argument of std::bitset needs + // to be right-to-left. Instead of reversing tokens, reverse the entire string + // and iterate tokens from end to begin. 
+ std::reverse(str.begin(), str.end()); + const int word_size = static_cast(sizeof(T) * 8); + const int str_length = static_cast(str.length()); + std::vector buffer; + buffer.reserve(NumBitsToNumWords(str.length())); + for (int index = str_length - word_size; index >= 0; index -= word_size) { + buffer.push_back(static_cast( + std::bitset(str, index, word_size).to_ullong())); + } + const size_t suffix_length = str.length() % word_size; + if (suffix_length != 0) { + buffer.push_back(static_cast( + std::bitset(str, 0, suffix_length).to_ullong())); + } + return buffer; +} + +// Adds '0' chars at the end of the string until the size is a multiple of N. +template +std::string PadToWord(std::string&& str) { + const size_t tail_length = str.size() % N; + if (tail_length != 0) str += std::string(N - tail_length, '0'); + return std::move(str); +} + +// Adds '0' chars at the end of the string until the size is a multiple of N. +template +std::string PadToWord(const std::string& str) { + return PadToWord(std::string(str)); +} + +// Converts a left-to-right stream of bits to std::bitset. +template +std::bitset StreamToBitset(std::string str) { + std::reverse(str.begin(), str.end()); + return std::bitset(str); +} + +// Converts a left-to-right stream of bits to uint64. +uint64_t StreamToBits(std::string str) { + std::reverse(str.begin(), str.end()); + return std::bitset<64>(str).to_ullong(); +} + +// A simple and inefficient implementation of BitWriterInterface, +// using std::stringstream. Intended for tests only. 
+class BitWriterStringStream : public BitWriterInterface { + public: + void WriteBits(uint64_t bits, size_t num_bits) override { + assert(num_bits <= 64); + ss_ << BitsToStream(bits, num_bits); + } + + size_t GetNumBits() const override { return ss_.str().size(); } + + std::vector GetDataCopy() const override { + return StreamToBuffer(ss_.str()); + } + + std::string GetStreamRaw() const { return ss_.str(); } + + private: + std::stringstream ss_; +}; + +// A simple and inefficient implementation of BitReaderInterface. +// Intended for tests only. +class BitReaderFromString : public BitReaderInterface { + public: + explicit BitReaderFromString(std::string&& str) + : str_(std::move(str)), pos_(0) {} + + explicit BitReaderFromString(const std::vector& buffer) + : str_(BufferToStream(buffer)), pos_(0) {} + + explicit BitReaderFromString(const std::vector& buffer) + : str_(PadToWord<64>(BufferToStream(buffer))), pos_(0) {} + + size_t ReadBits(uint64_t* bits, size_t num_bits) override { + if (ReachedEnd()) return 0; + std::string sub = str_.substr(pos_, num_bits); + *bits = StreamToBits(sub); + pos_ += sub.length(); + return sub.length(); + } + + size_t GetNumReadBits() const override { return pos_; } + + bool ReachedEnd() const override { return pos_ >= str_.length(); } + + private: + std::string str_; + size_t pos_; +}; + +TEST(NumBitsToNumWords, Word8) { + EXPECT_EQ(0u, NumBitsToNumWords<8>(0)); + EXPECT_EQ(1u, NumBitsToNumWords<8>(1)); + EXPECT_EQ(1u, NumBitsToNumWords<8>(7)); + EXPECT_EQ(1u, NumBitsToNumWords<8>(8)); + EXPECT_EQ(2u, NumBitsToNumWords<8>(9)); + EXPECT_EQ(2u, NumBitsToNumWords<8>(16)); + EXPECT_EQ(3u, NumBitsToNumWords<8>(17)); + EXPECT_EQ(3u, NumBitsToNumWords<8>(23)); + EXPECT_EQ(3u, NumBitsToNumWords<8>(24)); + EXPECT_EQ(4u, NumBitsToNumWords<8>(25)); +} + +TEST(NumBitsToNumWords, Word64) { + EXPECT_EQ(0u, NumBitsToNumWords<64>(0)); + EXPECT_EQ(1u, NumBitsToNumWords<64>(1)); + EXPECT_EQ(1u, NumBitsToNumWords<64>(64)); + EXPECT_EQ(2u, 
NumBitsToNumWords<64>(65)); + EXPECT_EQ(2u, NumBitsToNumWords<64>(128)); + EXPECT_EQ(3u, NumBitsToNumWords<64>(129)); +} + +TEST(ZigZagCoding, Encode0) { + EXPECT_EQ(0u, EncodeZigZag(0, 0)); + EXPECT_EQ(1u, EncodeZigZag(-1, 0)); + EXPECT_EQ(2u, EncodeZigZag(1, 0)); + EXPECT_EQ(3u, EncodeZigZag(-2, 0)); + EXPECT_EQ(std::numeric_limits::max() - 1, + EncodeZigZag(std::numeric_limits::max(), 0)); + EXPECT_EQ(std::numeric_limits::max(), + EncodeZigZag(std::numeric_limits::min(), 0)); +} + +TEST(ZigZagCoding, Decode0) { + EXPECT_EQ(0, DecodeZigZag(0, 0)); + EXPECT_EQ(-1, DecodeZigZag(1, 0)); + EXPECT_EQ(1, DecodeZigZag(2, 0)); + EXPECT_EQ(-2, DecodeZigZag(3, 0)); + EXPECT_EQ(std::numeric_limits::min(), + DecodeZigZag(std::numeric_limits::max(), 0)); + EXPECT_EQ(std::numeric_limits::max(), + DecodeZigZag(std::numeric_limits::max() - 1, 0)); +} + +TEST(ZigZagCoding, Encode1) { + EXPECT_EQ(0u, EncodeZigZag(0, 1)); + EXPECT_EQ(1u, EncodeZigZag(1, 1)); + EXPECT_EQ(2u, EncodeZigZag(-1, 1)); + EXPECT_EQ(3u, EncodeZigZag(-2, 1)); + EXPECT_EQ(4u, EncodeZigZag(2, 1)); + EXPECT_EQ(5u, EncodeZigZag(3, 1)); + EXPECT_EQ(6u, EncodeZigZag(-3, 1)); + EXPECT_EQ(7u, EncodeZigZag(-4, 1)); + EXPECT_EQ(std::numeric_limits::max() - 2, + EncodeZigZag(std::numeric_limits::max(), 1)); + EXPECT_EQ(std::numeric_limits::max() - 1, + EncodeZigZag(std::numeric_limits::min() + 1, 1)); + EXPECT_EQ(std::numeric_limits::max(), + EncodeZigZag(std::numeric_limits::min(), 1)); +} + +TEST(ZigZagCoding, Decode1) { + EXPECT_EQ(0, DecodeZigZag(0, 1)); + EXPECT_EQ(1, DecodeZigZag(1, 1)); + EXPECT_EQ(-1, DecodeZigZag(2, 1)); + EXPECT_EQ(-2, DecodeZigZag(3, 1)); + EXPECT_EQ(2, DecodeZigZag(4, 1)); + EXPECT_EQ(3, DecodeZigZag(5, 1)); + EXPECT_EQ(-3, DecodeZigZag(6, 1)); + EXPECT_EQ(-4, DecodeZigZag(7, 1)); + EXPECT_EQ(std::numeric_limits::min(), + DecodeZigZag(std::numeric_limits::max(), 1)); + EXPECT_EQ(std::numeric_limits::min() + 1, + DecodeZigZag(std::numeric_limits::max() - 1, 1)); + 
EXPECT_EQ(std::numeric_limits::max(), + DecodeZigZag(std::numeric_limits::max() - 2, 1)); +} + +TEST(ZigZagCoding, Encode2) { + EXPECT_EQ(0u, EncodeZigZag(0, 2)); + EXPECT_EQ(1u, EncodeZigZag(1, 2)); + EXPECT_EQ(2u, EncodeZigZag(2, 2)); + EXPECT_EQ(3u, EncodeZigZag(3, 2)); + EXPECT_EQ(4u, EncodeZigZag(-1, 2)); + EXPECT_EQ(5u, EncodeZigZag(-2, 2)); + EXPECT_EQ(6u, EncodeZigZag(-3, 2)); + EXPECT_EQ(7u, EncodeZigZag(-4, 2)); + EXPECT_EQ(8u, EncodeZigZag(4, 2)); + EXPECT_EQ(9u, EncodeZigZag(5, 2)); + EXPECT_EQ(10u, EncodeZigZag(6, 2)); + EXPECT_EQ(11u, EncodeZigZag(7, 2)); + EXPECT_EQ(12u, EncodeZigZag(-5, 2)); + EXPECT_EQ(13u, EncodeZigZag(-6, 2)); + EXPECT_EQ(14u, EncodeZigZag(-7, 2)); + EXPECT_EQ(15u, EncodeZigZag(-8, 2)); + EXPECT_EQ(std::numeric_limits::max() - 4, + EncodeZigZag(std::numeric_limits::max(), 2)); + EXPECT_EQ(std::numeric_limits::max() - 3, + EncodeZigZag(std::numeric_limits::min() + 3, 2)); + EXPECT_EQ(std::numeric_limits::max() - 2, + EncodeZigZag(std::numeric_limits::min() + 2, 2)); + EXPECT_EQ(std::numeric_limits::max() - 1, + EncodeZigZag(std::numeric_limits::min() + 1, 2)); + EXPECT_EQ(std::numeric_limits::max(), + EncodeZigZag(std::numeric_limits::min(), 2)); +} + +TEST(ZigZagCoding, Decode2) { + EXPECT_EQ(0, DecodeZigZag(0, 2)); + EXPECT_EQ(1, DecodeZigZag(1, 2)); + EXPECT_EQ(2, DecodeZigZag(2, 2)); + EXPECT_EQ(3, DecodeZigZag(3, 2)); + EXPECT_EQ(-1, DecodeZigZag(4, 2)); + EXPECT_EQ(-2, DecodeZigZag(5, 2)); + EXPECT_EQ(-3, DecodeZigZag(6, 2)); + EXPECT_EQ(-4, DecodeZigZag(7, 2)); + EXPECT_EQ(4, DecodeZigZag(8, 2)); + EXPECT_EQ(5, DecodeZigZag(9, 2)); + EXPECT_EQ(6, DecodeZigZag(10, 2)); + EXPECT_EQ(7, DecodeZigZag(11, 2)); + EXPECT_EQ(-5, DecodeZigZag(12, 2)); + EXPECT_EQ(-6, DecodeZigZag(13, 2)); + EXPECT_EQ(-7, DecodeZigZag(14, 2)); + EXPECT_EQ(-8, DecodeZigZag(15, 2)); + EXPECT_EQ(std::numeric_limits::min(), + DecodeZigZag(std::numeric_limits::max(), 2)); + EXPECT_EQ(std::numeric_limits::min() + 1, + DecodeZigZag(std::numeric_limits::max() 
- 1, 2)); + EXPECT_EQ(std::numeric_limits::min() + 2, + DecodeZigZag(std::numeric_limits::max() - 2, 2)); + EXPECT_EQ(std::numeric_limits::min() + 3, + DecodeZigZag(std::numeric_limits::max() - 3, 2)); + EXPECT_EQ(std::numeric_limits::max(), + DecodeZigZag(std::numeric_limits::max() - 4, 2)); +} + +TEST(ZigZagCoding, Encode63) { + EXPECT_EQ(0u, EncodeZigZag(0, 63)); + + for (int64_t i = 0; i < 0xFFFFFFFF; i += 1234567) { + const int64_t positive_val = GetLowerBits(i * i * i + i * i, 63) | 1UL; + ASSERT_EQ(static_cast(positive_val), + EncodeZigZag(positive_val, 63)); + ASSERT_EQ((1ULL << 63) - 1 + positive_val, EncodeZigZag(-positive_val, 63)); + } + + EXPECT_EQ((1ULL << 63) - 1, + EncodeZigZag(std::numeric_limits::max(), 63)); + EXPECT_EQ(std::numeric_limits::max() - 1, + EncodeZigZag(std::numeric_limits::min() + 1, 63)); + EXPECT_EQ(std::numeric_limits::max(), + EncodeZigZag(std::numeric_limits::min(), 63)); +} + +TEST(BufToStream, UInt8_Empty) { + const std::string expected_bits = ""; + std::vector buffer = StreamToBuffer(expected_bits); + EXPECT_TRUE(buffer.empty()); + const std::string result_bits = BufferToStream(buffer); + EXPECT_EQ(expected_bits, result_bits); +} + +TEST(BufToStream, UInt8_OneWord) { + const std::string expected_bits = "00101100"; + std::vector buffer = StreamToBuffer(expected_bits); + EXPECT_EQ(std::vector({static_cast( + StreamToBitset<8>(expected_bits).to_ulong())}), + buffer); + const std::string result_bits = BufferToStream(buffer); + EXPECT_EQ(expected_bits, result_bits); +} + +TEST(BufToStream, UInt8_MultipleWords) { + const std::string expected_bits = + "00100010" + "01101010" + "01111101" + "00100010"; + std::vector buffer = StreamToBuffer(expected_bits); + EXPECT_EQ(std::vector({ + static_cast(StreamToBitset<8>("00100010").to_ulong()), + static_cast(StreamToBitset<8>("01101010").to_ulong()), + static_cast(StreamToBitset<8>("01111101").to_ulong()), + static_cast(StreamToBitset<8>("00100010").to_ulong()), + }), + buffer); + const 
std::string result_bits = BufferToStream(buffer); + EXPECT_EQ(expected_bits, result_bits); +} + +TEST(BufToStream, UInt64_Empty) { + const std::string expected_bits = ""; + std::vector buffer = StreamToBuffer(expected_bits); + EXPECT_TRUE(buffer.empty()); + const std::string result_bits = BufferToStream(buffer); + EXPECT_EQ(expected_bits, result_bits); +} + +TEST(BufToStream, UInt64_OneWord) { + const std::string expected_bits = + "0001000111101110011001101010101000100010110011000100010010001000"; + std::vector buffer = StreamToBuffer(expected_bits); + ASSERT_EQ(1u, buffer.size()); + EXPECT_EQ(0x1122334455667788u, buffer[0]); + const std::string result_bits = BufferToStream(buffer); + EXPECT_EQ(expected_bits, result_bits); +} + +TEST(BufToStream, UInt64_Unaligned) { + const std::string expected_bits = + "0010001001101010011111010010001001001010000111110010010010010101" + "0010001001101010011111111111111111111111"; + std::vector buffer = StreamToBuffer(expected_bits); + EXPECT_EQ(std::vector({ + StreamToBits(expected_bits.substr(0, 64)), + StreamToBits(expected_bits.substr(64, 64)), + }), + buffer); + const std::string result_bits = BufferToStream(buffer); + EXPECT_EQ(PadToWord<64>(expected_bits), result_bits); +} + +TEST(BufToStream, UInt64_MultipleWords) { + const std::string expected_bits = + "0010001001101010011111010010001001001010000111110010010010010101" + "0010001001101010011111111111111111111111000111110010010010010111" + "0000000000000000000000000000000000000000000000000010010011111111"; + std::vector buffer = StreamToBuffer(expected_bits); + EXPECT_EQ(std::vector({ + StreamToBits(expected_bits.substr(0, 64)), + StreamToBits(expected_bits.substr(64, 64)), + StreamToBits(expected_bits.substr(128, 64)), + }), + buffer); + const std::string result_bits = BufferToStream(buffer); + EXPECT_EQ(expected_bits, result_bits); +} + +TEST(PadToWord, Test) { + EXPECT_EQ("10100000", PadToWord<8>("101")); + EXPECT_EQ( + "10100000" + "00000000", + PadToWord<16>("101")); + 
EXPECT_EQ( + "10100000" + "00000000" + "00000000" + "00000000", + PadToWord<32>("101")); + EXPECT_EQ( + "10100000" + "00000000" + "00000000" + "00000000" + "00000000" + "00000000" + "00000000" + "00000000", + PadToWord<64>("101")); +} + +TEST(BitWriterStringStream, Empty) { + BitWriterStringStream writer; + EXPECT_EQ(0u, writer.GetNumBits()); + EXPECT_EQ(0u, writer.GetDataSizeBytes()); + EXPECT_EQ("", writer.GetStreamRaw()); +} + +TEST(BitWriterStringStream, WriteBits) { + BitWriterStringStream writer; + const uint64_t bits1 = 0x1 | 0x2 | 0x10; + writer.WriteBits(bits1, 5); + EXPECT_EQ(5u, writer.GetNumBits()); + EXPECT_EQ(1u, writer.GetDataSizeBytes()); + EXPECT_EQ("11001", writer.GetStreamRaw()); +} + +TEST(BitWriterStringStream, WriteUnencodedU8) { + BitWriterStringStream writer; + const uint8_t bits = 127; + writer.WriteUnencoded(bits); + EXPECT_EQ(8u, writer.GetNumBits()); + EXPECT_EQ("11111110", writer.GetStreamRaw()); +} + +TEST(BitWriterStringStream, WriteUnencodedS64) { + BitWriterStringStream writer; + const int64_t bits = std::numeric_limits::min() + 7; + writer.WriteUnencoded(bits); + EXPECT_EQ(64u, writer.GetNumBits()); + EXPECT_EQ("1110000000000000000000000000000000000000000000000000000000000001", + writer.GetStreamRaw()); +} + +TEST(BitWriterStringStream, WriteMultiple) { + BitWriterStringStream writer; + + std::string expected_result; + + const uint64_t b2_val = 0x4 | 0x2 | 0x40; + const std::string bits2 = BitsToStream(b2_val, 8); + writer.WriteBits(b2_val, 8); + + const uint64_t val = 0x1 | 0x2 | 0x10; + const std::string bits3 = BitsToStream(val, 8); + writer.WriteBits(val, 8); + + const std::string expected = bits2 + bits3; + + EXPECT_EQ(expected.length(), writer.GetNumBits()); + EXPECT_EQ(2u, writer.GetDataSizeBytes()); + EXPECT_EQ(expected, writer.GetStreamRaw()); + + EXPECT_EQ(PadToWord<8>(expected), BufferToStream(writer.GetDataCopy())); +} + +TEST(BitWriterWord64, Empty) { + BitWriterWord64 writer; + EXPECT_EQ(0u, writer.GetNumBits()); + 
EXPECT_EQ(0u, writer.GetDataSizeBytes()); +} + +TEST(BitWriterWord64, WriteBits) { + BitWriterWord64 writer; + const uint64_t bits1 = 0x1 | 0x2 | 0x10; + writer.WriteBits(bits1, 5); + writer.WriteBits(bits1, 5); + writer.WriteBits(bits1, 5); + EXPECT_EQ(15u, writer.GetNumBits()); + EXPECT_EQ(2u, writer.GetDataSizeBytes()); +} + +TEST(BitWriterWord64, WriteZeroBits) { + BitWriterWord64 writer; + writer.WriteBits(0, 0); + writer.WriteBits(1, 0); + EXPECT_EQ(0u, writer.GetNumBits()); + writer.WriteBits(1, 1); + writer.WriteBits(0, 0); + writer.WriteBits(0, 63); + EXPECT_EQ(64u, writer.GetNumBits()); + writer.WriteBits(0, 0); + writer.WriteBits(7, 3); + writer.WriteBits(0, 0); +} + +TEST(BitWriterWord64, ComparisonTestWriteLotsOfBits) { + BitWriterStringStream writer1; + BitWriterWord64 writer2(16384); + + for (uint64_t i = 0; i < 65000; i += 25) { + writer1.WriteBits(i, 16); + writer2.WriteBits(i, 16); + ASSERT_EQ(writer1.GetNumBits(), writer2.GetNumBits()); + } +} + +TEST(GetLowerBits, Test) { + EXPECT_EQ(0u, GetLowerBits(255, 0)); + EXPECT_EQ(1u, GetLowerBits(255, 1)); + EXPECT_EQ(3u, GetLowerBits(255, 2)); + EXPECT_EQ(7u, GetLowerBits(255, 3)); + EXPECT_EQ(15u, GetLowerBits(255, 4)); + EXPECT_EQ(31u, GetLowerBits(255, 5)); + EXPECT_EQ(63u, GetLowerBits(255, 6)); + EXPECT_EQ(127u, GetLowerBits(255, 7)); + EXPECT_EQ(255u, GetLowerBits(255, 8)); + EXPECT_EQ(0xFFu, GetLowerBits(0xFFFFFFFF, 8)); + EXPECT_EQ(0xFFFFu, GetLowerBits(0xFFFFFFFF, 16)); + EXPECT_EQ(0xFFFFFFu, GetLowerBits(0xFFFFFFFF, 24)); + EXPECT_EQ(0xFFFFFFu, GetLowerBits(0xFFFFFFFFFFFF, 24)); + EXPECT_EQ(0xFFFFFFFFFFFFFFFFu, + GetLowerBits(0xFFFFFFFFFFFFFFFFu, 64)); + EXPECT_EQ(StreamToBits("1010001110"), + GetLowerBits(StreamToBits("1010001110111101111111"), 10)); +} + +TEST(BitReaderFromString, FromU8) { + std::vector buffer = { + 0xAA, + 0xBB, + 0xCC, + 0xDD, + }; + + const std::string total_stream = + "01010101" + "11011101" + "00110011" + "10111011"; + + BitReaderFromString reader(buffer); + + 
uint64_t bits = 0; + EXPECT_EQ(2u, reader.ReadBits(&bits, 2)); + EXPECT_EQ(PadToWord<64>("01"), BitsToStream(bits)); + EXPECT_EQ(20u, reader.ReadBits(&bits, 20)); + EXPECT_EQ(PadToWord<64>("01010111011101001100"), BitsToStream(bits)); + EXPECT_EQ(20u, reader.ReadBits(&bits, 20)); + EXPECT_EQ(PadToWord<64>("11101110110000000000"), BitsToStream(bits)); + EXPECT_EQ(22u, reader.ReadBits(&bits, 30)); + EXPECT_EQ(PadToWord<64>("0000000000000000000000"), BitsToStream(bits)); + EXPECT_TRUE(reader.ReachedEnd()); +} + +TEST(BitReaderFromString, FromU64) { + std::vector buffer = { + 0xAAAAAAAAAAAAAAAA, + 0xBBBBBBBBBBBBBBBB, + 0xCCCCCCCCCCCCCCCC, + 0xDDDDDDDDDDDDDDDD, + }; + + const std::string total_stream = + "0101010101010101010101010101010101010101010101010101010101010101" + "1101110111011101110111011101110111011101110111011101110111011101" + "0011001100110011001100110011001100110011001100110011001100110011" + "1011101110111011101110111011101110111011101110111011101110111011"; + + BitReaderFromString reader(buffer); + + uint64_t bits = 0; + size_t pos = 0; + size_t to_read = 5; + while (reader.ReadBits(&bits, to_read) > 0) { + EXPECT_EQ(BitsToStream(bits), + PadToWord<64>(total_stream.substr(pos, to_read))); + pos += to_read; + to_read = (to_read + 35) % 64 + 1; + } + EXPECT_TRUE(reader.ReachedEnd()); +} + +TEST(BitReaderWord64, ReadBitsSingleByte) { + BitReaderWord64 reader(std::vector({uint8_t(0xF0)})); + EXPECT_FALSE(reader.ReachedEnd()); + + uint64_t bits = 0; + EXPECT_EQ(1u, reader.ReadBits(&bits, 1)); + EXPECT_EQ(0u, bits); + EXPECT_EQ(2u, reader.ReadBits(&bits, 2)); + EXPECT_EQ(0u, bits); + EXPECT_EQ(2u, reader.ReadBits(&bits, 2)); + EXPECT_EQ(2u, bits); + EXPECT_EQ(2u, reader.ReadBits(&bits, 2)); + EXPECT_EQ(3u, bits); + EXPECT_FALSE(reader.OnlyZeroesLeft()); + EXPECT_FALSE(reader.ReachedEnd()); + EXPECT_EQ(2u, reader.ReadBits(&bits, 2)); + EXPECT_EQ(1u, bits); + EXPECT_TRUE(reader.OnlyZeroesLeft()); + EXPECT_FALSE(reader.ReachedEnd()); + EXPECT_EQ(55u, 
reader.ReadBits(&bits, 64)); + EXPECT_EQ(0u, bits); + EXPECT_TRUE(reader.ReachedEnd()); +} + +TEST(BitReaderWord64, ReadBitsTwoWords) { + std::vector buffer = {0x0000000000000001, 0x0000000000FFFFFF}; + + BitReaderWord64 reader(std::move(buffer)); + + uint64_t bits = 0; + EXPECT_EQ(1u, reader.ReadBits(&bits, 1)); + EXPECT_EQ(1u, bits); + EXPECT_EQ(62u, reader.ReadBits(&bits, 62)); + EXPECT_EQ(0u, bits); + EXPECT_EQ(2u, reader.ReadBits(&bits, 2)); + EXPECT_EQ(2u, bits); + EXPECT_EQ(3u, reader.ReadBits(&bits, 3)); + EXPECT_EQ(7u, bits); + EXPECT_FALSE(reader.OnlyZeroesLeft()); + EXPECT_EQ(32u, reader.ReadBits(&bits, 32)); + EXPECT_EQ(0xFFFFFu, bits); + EXPECT_TRUE(reader.OnlyZeroesLeft()); + EXPECT_FALSE(reader.ReachedEnd()); + EXPECT_EQ(28u, reader.ReadBits(&bits, 32)); + EXPECT_EQ(0u, bits); + EXPECT_TRUE(reader.ReachedEnd()); +} + +TEST(BitReaderFromString, ReadUnencodedU8) { + BitReaderFromString reader("11111110"); + uint8_t val = 0; + ASSERT_TRUE(reader.ReadUnencoded(&val)); + EXPECT_EQ(8u, reader.GetNumReadBits()); + EXPECT_EQ(127, val); +} + +TEST(BitReaderFromString, ReadUnencodedU16Fail) { + BitReaderFromString reader("11111110"); + uint16_t val = 0; + ASSERT_FALSE(reader.ReadUnencoded(&val)); +} + +TEST(BitReaderFromString, ReadUnencodedS64) { + BitReaderFromString reader( + "1110000000000000000000000000000000000000000000000000000000000001"); + int64_t val = 0; + ASSERT_TRUE(reader.ReadUnencoded(&val)); + EXPECT_EQ(64u, reader.GetNumReadBits()); + EXPECT_EQ(std::numeric_limits::min() + 7, val); +} + +TEST(BitReaderWord64, FromU8) { + std::vector buffer = { + 0xAA, + 0xBB, + 0xCC, + 0xDD, + }; + + BitReaderWord64 reader(std::move(buffer)); + + uint64_t bits = 0; + EXPECT_EQ(2u, reader.ReadBits(&bits, 2)); + EXPECT_EQ(PadToWord<64>("01"), BitsToStream(bits)); + EXPECT_EQ(20u, reader.ReadBits(&bits, 20)); + EXPECT_EQ(PadToWord<64>("01010111011101001100"), BitsToStream(bits)); + EXPECT_EQ(20u, reader.ReadBits(&bits, 20)); + 
EXPECT_EQ(PadToWord<64>("11101110110000000000"), BitsToStream(bits)); + EXPECT_EQ(22u, reader.ReadBits(&bits, 30)); + EXPECT_EQ(PadToWord<64>("0000000000000000000000"), BitsToStream(bits)); + EXPECT_TRUE(reader.ReachedEnd()); +} + +TEST(BitReaderWord64, FromU64) { + std::vector buffer = { + 0xAAAAAAAAAAAAAAAA, + 0xBBBBBBBBBBBBBBBB, + 0xCCCCCCCCCCCCCCCC, + 0xDDDDDDDDDDDDDDDD, + }; + + const std::string total_stream = + "0101010101010101010101010101010101010101010101010101010101010101" + "1101110111011101110111011101110111011101110111011101110111011101" + "0011001100110011001100110011001100110011001100110011001100110011" + "1011101110111011101110111011101110111011101110111011101110111011"; + + BitReaderWord64 reader(std::move(buffer)); + + uint64_t bits = 0; + size_t pos = 0; + size_t to_read = 5; + while (reader.ReadBits(&bits, to_read) > 0) { + EXPECT_EQ(BitsToStream(bits), + PadToWord<64>(total_stream.substr(pos, to_read))); + pos += to_read; + to_read = (to_read + 35) % 64 + 1; + } + EXPECT_TRUE(reader.ReachedEnd()); +} + +TEST(BitReaderWord64, ComparisonLotsOfU8) { + std::vector buffer; + for (uint32_t i = 0; i < 10003; ++i) { + buffer.push_back(static_cast(i % 255)); + } + + BitReaderFromString reader1(buffer); + BitReaderWord64 reader2(std::move(buffer)); + + uint64_t bits1 = 0, bits2 = 0; + size_t to_read = 5; + while (reader1.ReadBits(&bits1, to_read) > 0) { + reader2.ReadBits(&bits2, to_read); + EXPECT_EQ(bits1, bits2); + to_read = (to_read + 35) % 64 + 1; + } + + EXPECT_EQ(0u, reader2.ReadBits(&bits2, 1)); +} + +TEST(BitReaderWord64, ComparisonLotsOfU64) { + std::vector buffer; + for (uint64_t i = 0; i < 1000; ++i) { + buffer.push_back(i); + } + + BitReaderFromString reader1(buffer); + BitReaderWord64 reader2(std::move(buffer)); + + uint64_t bits1 = 0, bits2 = 0; + size_t to_read = 5; + while (reader1.ReadBits(&bits1, to_read) > 0) { + reader2.ReadBits(&bits2, to_read); + EXPECT_EQ(bits1, bits2); + to_read = (to_read + 35) % 64 + 1; + } + + EXPECT_EQ(0u, 
reader2.ReadBits(&bits2, 1)); +} + +TEST(ReadWriteWord64, ReadWriteLotsOfBits) { + BitWriterWord64 writer(16384); + for (uint64_t i = 0; i < 65000; i += 25) { + const uint64_t num_bits = i % 64 + 1; + const uint64_t bits = i >> (64 - num_bits); + writer.WriteBits(bits, size_t(num_bits)); + } + + BitReaderWord64 reader(writer.GetDataCopy()); + for (uint64_t i = 0; i < 65000; i += 25) { + const uint64_t num_bits = i % 64 + 1; + const uint64_t expected_bits = i >> (64 - num_bits); + uint64_t bits = 0; + reader.ReadBits(&bits, size_t(num_bits)); + EXPECT_EQ(expected_bits, bits); + } + + EXPECT_TRUE(reader.OnlyZeroesLeft()); +} + +TEST(VariableWidthWrite, Write0U) { + BitWriterStringStream writer; + writer.WriteVariableWidthU64(0, 2); + EXPECT_EQ("000", writer.GetStreamRaw()); + writer.WriteVariableWidthU32(0, 2); + EXPECT_EQ( + "000" + "000", + writer.GetStreamRaw()); + writer.WriteVariableWidthU16(0, 2); + EXPECT_EQ( + "000" + "000" + "000", + writer.GetStreamRaw()); +} + +TEST(VariableWidthWrite, WriteSmallUnsigned) { + BitWriterStringStream writer; + writer.WriteVariableWidthU64(1, 2); + EXPECT_EQ("100", writer.GetStreamRaw()); + writer.WriteVariableWidthU32(2, 2); + EXPECT_EQ( + "100" + "010", + writer.GetStreamRaw()); + writer.WriteVariableWidthU16(3, 2); + EXPECT_EQ( + "100" + "010" + "110", + writer.GetStreamRaw()); +} + +TEST(VariableWidthWrite, WriteSmallSigned) { + BitWriterStringStream writer; + writer.WriteVariableWidthS64(1, 2, 0); + EXPECT_EQ("010", writer.GetStreamRaw()); + writer.WriteVariableWidthS64(-1, 2, 0); + EXPECT_EQ( + "010" + "100", + writer.GetStreamRaw()); +} + +TEST(VariableWidthWrite, U64Val127ChunkLength7) { + BitWriterStringStream writer; + writer.WriteVariableWidthU64(127, 7); + EXPECT_EQ( + "1111111" + "0", + writer.GetStreamRaw()); +} + +TEST(VariableWidthWrite, U32Val255ChunkLength7) { + BitWriterStringStream writer; + writer.WriteVariableWidthU32(255, 7); + EXPECT_EQ( + "1111111" + "1" + "1000000" + "0", + writer.GetStreamRaw()); +} 
+ +TEST(VariableWidthWrite, U16Val2ChunkLength4) { + BitWriterStringStream writer; + writer.WriteVariableWidthU16(2, 4); + EXPECT_EQ( + "0100" + "0", + writer.GetStreamRaw()); +} + +TEST(VariableWidthWrite, U64ValAAAAChunkLength2) { + BitWriterStringStream writer; + writer.WriteVariableWidthU64(0xAAAA, 2); + EXPECT_EQ( + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "0", + writer.GetStreamRaw()); +} + +TEST(VariableWidthRead, U64Val127ChunkLength7) { + BitReaderFromString reader( + "1111111" + "0"); + uint64_t val = 0; + ASSERT_TRUE(reader.ReadVariableWidthU64(&val, 7)); + EXPECT_EQ(127u, val); +} + +TEST(VariableWidthRead, U32Val255ChunkLength7) { + BitReaderFromString reader( + "1111111" + "1" + "1000000" + "0"); + uint32_t val = 0; + ASSERT_TRUE(reader.ReadVariableWidthU32(&val, 7)); + EXPECT_EQ(255u, val); +} + +TEST(VariableWidthRead, U16Val2ChunkLength4) { + BitReaderFromString reader( + "0100" + "0"); + uint16_t val = 0; + ASSERT_TRUE(reader.ReadVariableWidthU16(&val, 4)); + EXPECT_EQ(2u, val); +} + +TEST(VariableWidthRead, U64ValAAAAChunkLength2) { + BitReaderFromString reader( + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "1" + "01" + "0"); + uint64_t val = 0; + ASSERT_TRUE(reader.ReadVariableWidthU64(&val, 2)); + EXPECT_EQ(0xAAAAu, val); +} + +TEST(VariableWidthRead, FailTooShort) { + BitReaderFromString reader("00000001100000"); + uint64_t val = 0; + ASSERT_FALSE(reader.ReadVariableWidthU64(&val, 7)); +} + +TEST(VariableWidthWriteRead, SingleWriteReadU64) { + for (uint64_t i = 0; i < 1000000; i += 1234) { + const uint64_t val = i * i * i; + const size_t chunk_length = size_t(i % 16 + 1); + + BitWriterWord64 writer; + writer.WriteVariableWidthU64(val, chunk_length); + + BitReaderWord64 reader(writer.GetDataCopy()); + uint64_t read_val = 0; + ASSERT_TRUE(reader.ReadVariableWidthU64(&read_val, chunk_length)); + + ASSERT_EQ(val, read_val) << "Chunk length " << 
chunk_length; + } +} + +TEST(VariableWidthWriteRead, SingleWriteReadS64) { + for (int64_t i = 0; i < 1000000; i += 4321) { + const int64_t val = i * i * (i % 2 ? -i : i); + const size_t chunk_length = size_t(i % 16 + 1); + const size_t zigzag_exponent = size_t(i % 13); + + BitWriterWord64 writer; + writer.WriteVariableWidthS64(val, chunk_length, zigzag_exponent); + + BitReaderWord64 reader(writer.GetDataCopy()); + int64_t read_val = 0; + ASSERT_TRUE( + reader.ReadVariableWidthS64(&read_val, chunk_length, zigzag_exponent)); + + ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length; + } +} + +TEST(VariableWidthWriteRead, SingleWriteReadU32) { + for (uint32_t i = 0; i < 100000; i += 123) { + const uint32_t val = i * i; + const size_t chunk_length = i % 16 + 1; + + BitWriterWord64 writer; + writer.WriteVariableWidthU32(val, chunk_length); + + BitReaderWord64 reader(writer.GetDataCopy()); + uint32_t read_val = 0; + ASSERT_TRUE(reader.ReadVariableWidthU32(&read_val, chunk_length)); + + ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length; + } +} + +TEST(VariableWidthWriteRead, SingleWriteReadU16) { + for (int i = 0; i < 65536; i += 123) { + const uint16_t val = static_cast(i); + const size_t chunk_length = val % 10 + 1; + + BitWriterWord64 writer; + writer.WriteVariableWidthU16(val, chunk_length); + + BitReaderWord64 reader(writer.GetDataCopy()); + uint16_t read_val = 0; + ASSERT_TRUE(reader.ReadVariableWidthU16(&read_val, chunk_length)); + + ASSERT_EQ(val, read_val) << "Chunk length " << chunk_length; + } +} + +TEST(VariableWidthWriteRead, SmallNumbersChunkLength4) { + const std::vector expected_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + + BitWriterWord64 writer; + for (uint64_t val : expected_values) { + writer.WriteVariableWidthU64(val, 4); + } + + EXPECT_EQ(50u, writer.GetNumBits()); + + std::vector actual_values; + BitReaderWord64 reader(writer.GetDataCopy()); + while (!reader.OnlyZeroesLeft()) { + uint64_t val = 0; + 
ASSERT_TRUE(reader.ReadVariableWidthU64(&val, 4)); + actual_values.push_back(val); + } + + EXPECT_EQ(expected_values, actual_values); +} + +TEST(VariableWidthWriteRead, VariedNumbersChunkLength8) { + const std::vector expected_values = {1000, 0, 255, 4294967296}; + const size_t kExpectedNumBits = 9 * (2 + 1 + 1 + 5); + + BitWriterWord64 writer; + for (uint64_t val : expected_values) { + writer.WriteVariableWidthU64(val, 8); + } + + EXPECT_EQ(kExpectedNumBits, writer.GetNumBits()); + + std::vector actual_values; + BitReaderWord64 reader(writer.GetDataCopy()); + while (!reader.OnlyZeroesLeft()) { + uint64_t val = 0; + ASSERT_TRUE(reader.ReadVariableWidthU64(&val, 8)); + actual_values.push_back(val); + } + + EXPECT_EQ(expected_values, actual_values); +} + +} // namespace +} // namespace comp +} // namespace spvtools diff --git a/test/c_interface_test.cpp b/test/c_interface_test.cpp new file mode 100644 index 000000000..c644fb9ba --- /dev/null +++ b/test/c_interface_test.cpp @@ -0,0 +1,299 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "source/table.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace { + +// TODO(antiagainst): Use public C API for setting the consumer once it exists.
+#ifndef SPIRV_TOOLS_SHAREDLIB +void SetContextMessageConsumer(spv_context context, MessageConsumer consumer) { + spvtools::SetContextMessageConsumer(context, consumer); +} +#else +void SetContextMessageConsumer(spv_context, MessageConsumer) {} +#endif + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForValidInput) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = + "OpCapability Shader\n" + "OpCapability Linkage\n" + "OpMemoryModel Logical GLSL450"; + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + { + // Sadly the compiler doesn't allow me to feed binary directly to + // spvValidate(). + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_SUCCESS, spvValidate(context, &b, nullptr)); + } + + spv_text text = nullptr; + EXPECT_EQ(SPV_SUCCESS, spvBinaryToText(context, binary->code, + binary->wordCount, 0, &text, nullptr)); + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidAssembling) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "%1 = OpName"; + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + nullptr)); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidDiassembling) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "OpNop"; + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word.
+ binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + nullptr)); + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidValidating) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "OpNop"; + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, nullptr)); + + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForAssembling) { + const char input_text[] = "%1 = OpName\n"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](spv_message_level_t level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(SPV_MSG_ERROR, level); + // The error happens when scanning the beginning of the second line.
+ EXPECT_STREQ("input", source); + EXPECT_EQ(1u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(12u, position.index); + EXPECT_STREQ("Expected operand, found end of stream.", message); + }); + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + nullptr)); +#ifndef SPIRV_TOOLS_SHAREDLIB + EXPECT_EQ(1, invocation); +#endif + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForDisassembling) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](spv_message_level_t level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(SPV_MSG_ERROR, level); + EXPECT_STREQ("input", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(1u, position.index); + EXPECT_STREQ("Invalid opcode: 65535", message); + }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word. 
+ binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + nullptr)); +#ifndef SPIRV_TOOLS_SHAREDLIB + EXPECT_EQ(1, invocation); +#endif + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForValidating) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](spv_message_level_t level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(SPV_MSG_ERROR, level); + EXPECT_STREQ("input", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + // TODO(antiagainst): what validation reports is not a word offset here. + // It is inconsistent with disassembler. Should be fixed. + EXPECT_EQ(1u, position.index); + EXPECT_STREQ( + "Nop cannot appear before the memory model instruction\n" + " OpNop\n", + message); + }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, nullptr)); +#ifndef SPIRV_TOOLS_SHAREDLIB + EXPECT_EQ(1, invocation); +#endif + + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// When having both a consumer and a diagnostic object, the diagnostic object +// should take priority.
+TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForAssembling) { + const char input_text[] = "%1 = OpName"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](spv_message_level_t, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + spv_diagnostic diagnostic = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + &diagnostic)); + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. + EXPECT_STREQ("Expected operand, found end of stream.", diagnostic->error); + + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForDisassembling) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](spv_message_level_t, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word. + binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_diagnostic diagnostic = nullptr; + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + &diagnostic)); + + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. 
+ EXPECT_STREQ("Invalid opcode: 65535", diagnostic->error); + + spvTextDestroy(text); + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForValidating) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](spv_message_level_t, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, &diagnostic)); + + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. + EXPECT_STREQ( + "Nop cannot appear before the memory model instruction\n" + " OpNop\n", + diagnostic->error); + + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +} // namespace +} // namespace spvtools diff --git a/test/comment_test.cpp b/test/comment_test.cpp new file mode 100644 index 000000000..f46b72ac5 --- /dev/null +++ b/test/comment_test.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::Concatenate; +using spvtest::MakeInstruction; +using spvtest::MakeVector; +using spvtest::TextToBinaryTest; +using testing::Eq; + +TEST_F(TextToBinaryTest, Whitespace) { + std::string input = R"( +; I'm a proud comment at the beginning of the file +; I hide: OpCapability Shader + OpMemoryModel Logical Simple ; comment after instruction +;;;;;;;; many ;'s + %glsl450 = OpExtInstImport "GLSL.std.450" + ; comment indented +)"; + + EXPECT_THAT( + CompiledInstructions(input), + Eq(Concatenate({MakeInstruction(SpvOpMemoryModel, + {uint32_t(SpvAddressingModelLogical), + uint32_t(SpvMemoryModelSimple)}), + MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("GLSL.std.450"))}))); +} + +} // namespace +} // namespace spvtools diff --git a/test/comp/CMakeLists.txt b/test/comp/CMakeLists.txt new file mode 100644 index 000000000..c947fde0c --- /dev/null +++ b/test/comp/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set(VAL_TEST_COMMON_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h + ${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h +) + +if(SPIRV_BUILD_COMPRESSION) + add_spvtools_unittest(TARGET markv_codec + SRCS + markv_codec_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../tools/comp/markv_model_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../tools/comp/markv_model_shader.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS SPIRV-Tools-comp ${SPIRV_TOOLS} + ) +endif(SPIRV_BUILD_COMPRESSION) diff --git a/test/comp/markv_codec_test.cpp b/test/comp/markv_codec_test.cpp new file mode 100644 index 000000000..283fcd3ca --- /dev/null +++ b/test/comp/markv_codec_test.cpp @@ -0,0 +1,829 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for the markv codec.
+ +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/comp/markv.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "tools/comp/markv_model_factory.h" + +namespace spvtools { +namespace comp { +namespace { + +using spvtest::ScopedContext; +using MarkvTest = ::testing::TestWithParam; + +void DiagnosticsMessageHandler(spv_message_level_t level, const char*, + const spv_position_t& position, + const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + std::cerr << "error: " << position.index << ": " << message << std::endl; + break; + case SPV_MSG_WARNING: + std::cout << "warning: " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + std::cout << "info: " << position.index << ": " << message << std::endl; + break; + default: + break; + } +} + +// Compiles |code| to SPIR-V |words|. +void Compile(const std::string& code, std::vector* words, + uint32_t options = SPV_TEXT_TO_BINARY_OPTION_NONE, + spv_target_env env = SPV_ENV_UNIVERSAL_1_2) { + spvtools::Context ctx(env); + ctx.SetMessageConsumer(DiagnosticsMessageHandler); + + spv_binary spirv_binary; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinaryWithOptions( + ctx.CContext(), code.c_str(), code.size(), options, + &spirv_binary, nullptr)); + + *words = std::vector(spirv_binary->code, + spirv_binary->code + spirv_binary->wordCount); + + spvBinaryDestroy(spirv_binary); +} + +// Disassembles SPIR-V |words| to |out_text|. 
+void Disassemble(const std::vector& words, std::string* out_text, + spv_target_env env = SPV_ENV_UNIVERSAL_1_2) { + spvtools::Context ctx(env); + ctx.SetMessageConsumer(DiagnosticsMessageHandler); + + spv_text text = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvBinaryToText(ctx.CContext(), words.data(), + words.size(), 0, &text, nullptr)); + assert(text); + + *out_text = std::string(text->str, text->length); + spvTextDestroy(text); +} + +// Encodes/decodes |original_text|, assembles/disassembles |original_text|, then compares +// the results of the two operations. +void TestEncodeDecode(MarkvModelType model_type, + const std::string& original_text) { + spvtools::Context ctx(SPV_ENV_UNIVERSAL_1_2); + std::unique_ptr model = CreateMarkvModel(model_type); + MarkvCodecOptions options; + + std::vector expected_binary; + Compile(original_text, &expected_binary); + ASSERT_FALSE(expected_binary.empty()); + + std::string expected_text; + Disassemble(expected_binary, &expected_text); + ASSERT_FALSE(expected_text.empty()); + + std::vector binary_to_encode; + Compile(original_text, &binary_to_encode, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_FALSE(binary_to_encode.empty()); + + std::stringstream encoder_comments; + const auto output_to_string_stream = + [&encoder_comments](const std::string& str) { encoder_comments << str; }; + + std::vector markv; + ASSERT_EQ(SPV_SUCCESS, + SpirvToMarkv(ctx.CContext(), binary_to_encode, options, *model, + DiagnosticsMessageHandler, output_to_string_stream, + MarkvDebugConsumer(), &markv)); + ASSERT_FALSE(markv.empty()); + + std::vector decoded_binary; + ASSERT_EQ(SPV_SUCCESS, + MarkvToSpirv(ctx.CContext(), markv, options, *model, + DiagnosticsMessageHandler, MarkvLogConsumer(), + MarkvDebugConsumer(), &decoded_binary)); + ASSERT_FALSE(decoded_binary.empty()); + + EXPECT_EQ(expected_binary, decoded_binary) << encoder_comments.str(); + + std::string decoded_text; + Disassemble(decoded_binary, &decoded_text); +
ASSERT_FALSE(decoded_text.empty()); + + EXPECT_EQ(expected_text, decoded_text) << encoder_comments.str(); +} + +void TestEncodeDecodeShaderMainBody(MarkvModelType model_type, + const std::string& body) { + const std::string prefix = + R"( +OpCapability Shader +OpCapability Int64 +OpCapability Float64 +%ext_inst = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%f64 = OpTypeFloat 64 +%u64 = OpTypeInt 64 0 +%s64 = OpTypeInt 64 1 +%boolvec2 = OpTypeVector %bool 2 +%s32vec2 = OpTypeVector %s32 2 +%u32vec2 = OpTypeVector %u32 2 +%f32vec2 = OpTypeVector %f32 2 +%f64vec2 = OpTypeVector %f64 2 +%boolvec3 = OpTypeVector %bool 3 +%u32vec3 = OpTypeVector %u32 3 +%s32vec3 = OpTypeVector %s32 3 +%f32vec3 = OpTypeVector %f32 3 +%f64vec3 = OpTypeVector %f64 3 +%boolvec4 = OpTypeVector %bool 4 +%u32vec4 = OpTypeVector %u32 4 +%s32vec4 = OpTypeVector %s32 4 +%f32vec4 = OpTypeVector %f32 4 +%f64vec4 = OpTypeVector %f64 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +%f32_3 = OpConstant %f32 3 +%f32_4 = OpConstant %f32 4 +%f32_pi = OpConstant %f32 3.14159 + +%s32_0 = OpConstant %s32 0 +%s32_1 = OpConstant %s32 1 +%s32_2 = OpConstant %s32 2 +%s32_3 = OpConstant %s32 3 +%s32_4 = OpConstant %s32 4 +%s32_m1 = OpConstant %s32 -1 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + 
+%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1 +%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2 +%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2 +%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3 +%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3 +%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4 + +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2 +%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4 + +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string suffix = + R"( +OpReturn +OpFunctionEnd)"; + + TestEncodeDecode(model_type, prefix + body + suffix); +} + +TEST_P(MarkvTest, U32Literal) { + TestEncodeDecode(GetParam(), R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%u32 = OpTypeInt 32 0 +%100 = OpConstant %u32 0 +%200 = OpConstant %u32 1 +%300 = OpConstant %u32 4294967295 +)"); +} + +TEST_P(MarkvTest, S32Literal) { + TestEncodeDecode(GetParam(), R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%s32 = OpTypeInt 32 1 +%100 = OpConstant %s32 0 +%200 = OpConstant %s32 1 +%300 = OpConstant %s32 -1 +%400 = OpConstant %s32 2147483647 +%500 = OpConstant %s32 -2147483648 +)"); +} + +TEST_P(MarkvTest, U64Literal) { + TestEncodeDecode(GetParam(), R"( +OpCapability Shader +OpCapability Linkage +OpCapability Int64 +OpMemoryModel Logical GLSL450 +%u64 = OpTypeInt 64 0 +%100 = OpConstant %u64 0 +%200 = OpConstant %u64 1 +%300 = OpConstant %u64 18446744073709551615 +)"); +} + +TEST_P(MarkvTest, S64Literal) { + TestEncodeDecode(GetParam(), R"( +OpCapability Shader +OpCapability Linkage +OpCapability Int64 
+OpMemoryModel Logical GLSL450
+%s64 = OpTypeInt 64 1
+%100 = OpConstant %s64 0
+%200 = OpConstant %s64 1
+%300 = OpConstant %s64 -1
+%400 = OpConstant %s64 9223372036854775807
+%500 = OpConstant %s64 -9223372036854775808
+)");
+}
+
+TEST_P(MarkvTest, U16Literal) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Int16
+OpMemoryModel Logical GLSL450
+%u16 = OpTypeInt 16 0
+%100 = OpConstant %u16 0
+%200 = OpConstant %u16 1
+%300 = OpConstant %u16 65535
+)");
+}
+
+TEST_P(MarkvTest, S16Literal) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Int16
+OpMemoryModel Logical GLSL450
+%s16 = OpTypeInt 16 1
+%100 = OpConstant %s16 0
+%200 = OpConstant %s16 1
+%300 = OpConstant %s16 -1
+%400 = OpConstant %s16 32767
+%500 = OpConstant %s16 -32768
+)");
+}
+
+TEST_P(MarkvTest, F32Literal) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%f32 = OpTypeFloat 32
+%100 = OpConstant %f32 0
+%200 = OpConstant %f32 1
+%300 = OpConstant %f32 0.1
+%400 = OpConstant %f32 -0.1
+)");
+}
+
+TEST_P(MarkvTest, F64Literal) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Float64
+OpMemoryModel Logical GLSL450
+%f64 = OpTypeFloat 64
+%100 = OpConstant %f64 0
+%200 = OpConstant %f64 1
+%300 = OpConstant %f64 0.1
+%400 = OpConstant %f64 -0.1
+)");
+}
+
+TEST_P(MarkvTest, F16Literal) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Float16
+OpMemoryModel Logical GLSL450
+%f16 = OpTypeFloat 16
+%100 = OpConstant %f16 0
+%200 = OpConstant %f16 1
+%300 = OpConstant %f16 0.1
+%400 = OpConstant %f16 -0.1
+)");
+}
+
+TEST_P(MarkvTest, StringLiteral) {  // Covers short, long and empty string operands.
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_16bit_storage"
+OpExtension "xxx"
+OpExtension "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+OpExtension ""
+OpMemoryModel Logical GLSL450
+)");
+}
+
+TEST_P(MarkvTest, WithFunction) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpExtension "SPV_KHR_16bit_storage"
+OpMemoryModel Physical32 OpenCL
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%100 = OpConstant %u32 1
+%200 = OpConstant %u32 2
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+%300 = OpIAdd %u32 %100 %200
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_P(MarkvTest, WithMultipleFunctions) {  // %plus_one is called before its definition appears.
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%f32 = OpTypeFloat 32
+%one = OpConstant %f32 1
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%f32_func = OpTypeFunction %f32 %f32
+%sqr_plus_one = OpFunction %f32 None %f32_func
+%x = OpFunctionParameter %f32
+%100 = OpLabel
+%x2 = OpFMul %f32 %x %x
+%x2p1 = OpFunctionCall %f32 %plus_one %x2
+OpReturnValue %x2p1
+OpFunctionEnd
+%plus_one = OpFunction %f32 None %f32_func
+%y = OpFunctionParameter %f32
+%200 = OpLabel
+%yp1 = OpFAdd %f32 %y %one
+OpReturnValue %yp1
+OpFunctionEnd
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+%1p1 = OpFunctionCall %f32 %sqr_plus_one %one
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_P(MarkvTest, ForwardDeclaredId) {  // OpEntryPoint names %1 before the OpFunction that defines it.
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %1 "simple_kernel"
+%2 = OpTypeInt 32 0
+%3 = OpTypeVector %2 2
+%4 = OpConstant %2 2
+%5 = OpTypeArray %2 %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%1 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_P(MarkvTest, WithSwitch) {  // 64-bit selector; the case literal 1000000000000 does not fit in 32 bits.
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Int64
+OpMemoryModel Physical32 OpenCL
+%u64 = OpTypeInt 64 0
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%val = OpConstant %u64 1
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+OpSwitch %val %default 1 %case1 1000000000000 %case2
+%case1 = OpLabel
+OpNop
+OpBranch %after_switch
+%case2 = OpLabel
+OpNop
+OpBranch %after_switch
+%default = OpLabel
+OpNop
+OpBranch %after_switch
+%after_switch = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_P(MarkvTest, WithLoop) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+OpLoopMerge %merge %continue DontUnroll|DependencyLength 10
+OpBranch %begin_loop
+%begin_loop = OpLabel
+OpNop
+OpBranch %continue
+%continue = OpLabel
+OpNop
+OpBranch %begin_loop
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_P(MarkvTest, WithDecorate) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpDecorate %1 ArrayStride 4
+OpDecorate %1 Uniform
+%2 = OpTypeFloat 32
+%1 = OpTypeRuntimeArray %2
+)");
+}
+
+TEST_P(MarkvTest, WithExtInst) {
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+%opencl = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical32 OpenCL
+%f32 = OpTypeFloat 32
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%100 = OpConstant %f32 1.1
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+%200 = OpExtInst %f32 %opencl cos %100
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST_P(MarkvTest, F32Mul) {  // From here on, bodies rely on ids (%f32_0, %u32_1, ...) supplied by TestEncodeDecodeShaderMainBody, defined earlier in this file.
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%val1 = OpFMul %f32 %f32_0 %f32_1
+%val2 = OpFMul %f32 %f32_2 %f32_0
+%val3 = OpFMul %f32 %f32_pi %f32_2
+%val4 = OpFMul %f32 %f32_1 %f32_1
+)");
+}
+
+TEST_P(MarkvTest, U32Mul) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%val1 = OpIMul %u32 %u32_0 %u32_1
+%val2 = OpIMul %u32 %u32_2 %u32_0
+%val3 = OpIMul %u32 %u32_3 %u32_2
+%val4 = OpIMul %u32 %u32_1 %u32_1
+)");
+}
+
+TEST_P(MarkvTest, S32Mul) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%val1 = OpIMul %s32 %s32_0 %s32_1
+%val2 = OpIMul %s32 %s32_2 %s32_0
+%val3 = OpIMul %s32 %s32_m1 %s32_2
+%val4 = OpIMul %s32 %s32_1 %s32_1
+)");
+}
+
+TEST_P(MarkvTest, F32Add) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%val1 = OpFAdd %f32 %f32_0 %f32_1
+%val2 = OpFAdd %f32 %f32_2 %f32_0
+%val3 = OpFAdd %f32 %f32_pi %f32_2
+%val4 = OpFAdd %f32 %f32_1 %f32_1
+)");
+}
+
+TEST_P(MarkvTest, U32Add) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%val1 = OpIAdd %u32 %u32_0 %u32_1
+%val2 = OpIAdd %u32 %u32_2 %u32_0
+%val3 = OpIAdd %u32 %u32_3 %u32_2
+%val4 = OpIAdd %u32 %u32_1 %u32_1
+)");
+}
+
+TEST_P(MarkvTest, S32Add) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%val1 = OpIAdd %s32 %s32_0 %s32_1
+%val2 = OpIAdd %s32 %s32_2 %s32_0
+%val3 = OpIAdd %s32 %s32_m1 %s32_2
+%val4 = OpIAdd %s32 %s32_1 %s32_1
+)");
+}
+
+TEST_P(MarkvTest, F32Dot) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%dot2_1 = OpDot %f32 %f32vec2_01 %f32vec2_12
+%dot2_2 = OpDot %f32 %f32vec2_01 %f32vec2_01
+%dot2_3 = OpDot %f32 %f32vec2_12 %f32vec2_12
+%dot3_1 = OpDot %f32 %f32vec3_012 %f32vec3_123
+%dot3_2 = OpDot %f32 %f32vec3_012 %f32vec3_012
+%dot3_3 = OpDot %f32 %f32vec3_123 %f32vec3_123
+%dot4_1 = OpDot %f32 %f32vec4_0123 %f32vec4_1234
+%dot4_2 = OpDot %f32 %f32vec4_0123 %f32vec4_0123
+%dot4_3 = OpDot %f32 %f32vec4_1234 %f32vec4_1234
+)");
+}
+
+TEST_P(MarkvTest, F32VectorCompositeConstruct) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%cc1 = OpCompositeConstruct %f32vec4 %f32vec2_01 %f32vec2_12
+%cc2 = OpCompositeConstruct %f32vec3 %f32vec2_01 %f32_2
+%cc3 = OpCompositeConstruct %f32vec2 %f32_1 %f32_2
+%cc4 = OpCompositeConstruct %f32vec4 %f32_1 %f32_2 %cc3
+)");
+}
+
+TEST_P(MarkvTest, U32VectorCompositeConstruct) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%cc1 = OpCompositeConstruct %u32vec4 %u32vec2_01 %u32vec2_12
+%cc2 = OpCompositeConstruct %u32vec3 %u32vec2_01 %u32_2
+%cc3 = OpCompositeConstruct %u32vec2 %u32_1 %u32_2
+%cc4 = OpCompositeConstruct %u32vec4 %u32_1 %u32_2 %cc3
+)");
+}
+
+TEST_P(MarkvTest, S32VectorCompositeConstruct) {  // NOTE(review): body is identical to U32VectorCompositeConstruct (every operand is %u32*); presumably %s32* ids were intended - confirm they exist in the shared preamble before changing.
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%cc1 = OpCompositeConstruct %u32vec4 %u32vec2_01 %u32vec2_12
+%cc2 = OpCompositeConstruct %u32vec3 %u32vec2_01 %u32_2
+%cc3 = OpCompositeConstruct %u32vec2 %u32_1 %u32_2
+%cc4 = OpCompositeConstruct %u32vec4 %u32_1 %u32_2 %cc3
+)");
+}
+
+TEST_P(MarkvTest, F32VectorCompositeExtract) {  // NOTE(review): %f32vec4_3210 is constructed but the extract below reads %f32vec4_0123.
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
+%f32vec3_013 = OpCompositeExtract %f32vec3 %f32vec4_0123 0 1 3
+)");
+}
+
+TEST_P(MarkvTest, F32VectorComparison) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
+%c1 = OpFOrdEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c2 = OpFUnordEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c3 = OpFOrdNotEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c4 = OpFUnordNotEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c5 = OpFOrdLessThan %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c6 = OpFUnordLessThan %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c7 = OpFOrdGreaterThan %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c8 = OpFUnordGreaterThan %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c9 = OpFOrdLessThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c10 = OpFUnordLessThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c11 = OpFOrdGreaterThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
+%c12 = OpFUnordGreaterThanEqual %boolvec4 %f32vec4_0123 %f32vec4_3210
+)");
+}
+
+TEST_P(MarkvTest, VectorShuffle) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
+%sh1 = OpVectorShuffle %f32vec2 %f32vec4_0123 %f32vec4_3210 3 6
+%sh2 = OpVectorShuffle %f32vec3 %f32vec2_01 %f32vec4_3210 0 3 4
+)");
+}
+
+TEST_P(MarkvTest, VectorTimesScalar) {
+  TestEncodeDecodeShaderMainBody(GetParam(), R"(
+%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
+%res1 = OpVectorTimesScalar %f32vec4 %f32vec4_0123 %f32_2
+%res2 = OpVectorTimesScalar %f32vec4 %f32vec4_3210 %f32_2
+)");
+}
+
+TEST_P(MarkvTest, SpirvSpecSample) {  // Full fragment-shader module with debug names, decorations, a structured if and a structured loop.
+  TestEncodeDecode(GetParam(), R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main" %31 %33 %42 %57
+ OpExecutionMode %4 OriginLowerLeft
+
+; Debug information
+ OpSource GLSL 450
+ OpName %4 "main"
+ OpName %9 "scale"
+ OpName %17 "S"
+ OpMemberName %17 0 "b"
+ OpMemberName %17 1 "v"
+ OpMemberName %17 2 "i"
+ OpName %18 "blockName"
+ OpMemberName %18 0 "s"
+ OpMemberName %18 1 "cond"
+ OpName %20 ""
+ OpName %31 "color"
+ OpName %33 "color1"
+ OpName %42 "color2"
+ OpName %48 "i"
+ OpName %57 "multiplier"
+
+; Annotations (non-debug)
+ OpDecorate %15 ArrayStride 16
+ OpMemberDecorate %17 0 Offset 0
+ OpMemberDecorate %17 1 Offset 16
+ OpMemberDecorate %17 2 Offset 96
+ OpMemberDecorate %18 0 Offset 0
+ OpMemberDecorate %18 1 Offset 112
+ OpDecorate %18 Block
+ OpDecorate %20 DescriptorSet 0
+ OpDecorate %42 NoPerspective
+
+; All types, variables, and constants
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2 ; void ()
+ %6 = OpTypeFloat 32 ; 32-bit float
+ %7 = OpTypeVector %6 4 ; vec4
+ %8 = OpTypePointer Function %7 ; function-local vec4*
+ %10 = OpConstant %6 1
+ %11 = OpConstant %6 2
+ %12 = OpConstantComposite %7 %10 %10 %11 %10 ; vec4(1.0, 1.0, 2.0, 1.0)
+ %13 = OpTypeInt 32 0 ; 32-bit int, sign-less
+ %14 = OpConstant %13 5
+ %15 = OpTypeArray %7 %14
+ %16 = OpTypeInt 32 1
+ %17 = OpTypeStruct %13 %15 %16
+ %18 = OpTypeStruct %17 %13
+ %19 = OpTypePointer Uniform %18
+ %20 = OpVariable %19 Uniform
+ %21 = OpConstant %16 1
+ %22 = OpTypePointer Uniform %13
+ %25 = OpTypeBool
+ %26 = OpConstant %13 0
+ %30 = OpTypePointer Output %7
+ %31 = OpVariable %30 Output
+ %32 = OpTypePointer Input %7
+ %33 = OpVariable %32 Input
+ %35 = OpConstant %16 0
+ %36 = OpConstant %16 2
+ %37 = OpTypePointer Uniform %7
+ %42 = OpVariable %32 Input
+ %47 = OpTypePointer Function %16
+ %55 = OpConstant %16 4
+ %57 = OpVariable %32 Input
+
+; All functions
+ %4 = OpFunction %2 None %3 ; main()
+ %5 = OpLabel
+ %9 = OpVariable %8 Function
+ %48 = OpVariable %47 Function
+ OpStore %9 %12
+ %23 = OpAccessChain %22 %20 %21 ; location of cond
+ %24 = OpLoad %13 %23 ; load 32-bit int from cond
+ %27 = OpINotEqual %25 %24 %26 ; convert to bool
+ OpSelectionMerge %29 None ; structured if
+ OpBranchConditional %27 %28 %41 ; if cond
+ %28 = OpLabel ; then
+ %34 = OpLoad %7 %33
+ %38 = OpAccessChain %37 %20 %35 %21 %36 ; s.v[2]
+ %39 = OpLoad %7 %38
+ %40 = OpFAdd %7 %34 %39
+ OpStore %31 %40
+ OpBranch %29
+ %41 = OpLabel ; else
+ %43 = OpLoad %7 %42
+ %44 = OpExtInst %7 %1 Sqrt %43 ; extended instruction sqrt
+ %45 = OpLoad %7 %9
+ %46 = OpFMul %7 %44 %45
+ OpStore %31 %46
+ OpBranch %29
+ %29 = OpLabel ; endif
+ OpStore %48 %35
+ OpBranch %49
+ %49 = OpLabel
+ OpLoopMerge %51 %52 None ; structured loop
+ OpBranch %53
+ %53 = OpLabel
+ %54 = OpLoad %16 %48
+ %56 = OpSLessThan %25 %54 %55 ; i < 4 ?
+ OpBranchConditional %56 %50 %51 ; body or break
+ %50 = OpLabel ; body
+ %58 = OpLoad %7 %57
+ %59 = OpLoad %7 %31
+ %60 = OpFMul %7 %59 %58
+ OpStore %31 %60
+ OpBranch %52
+ %52 = OpLabel ; continue target
+ %61 = OpLoad %16 %48
+ %62 = OpIAdd %16 %61 %21 ; ++i
+ OpStore %48 %62
+ OpBranch %49 ; loop back
+ %51 = OpLabel ; loop merge point
+ OpReturn
+ OpFunctionEnd
+)");
+}
+
+TEST_P(MarkvTest, SampleFromDeadBranchEliminationTest) {  // Includes an OpPhi that merges values from both branch arms.
+  TestEncodeDecode(GetParam(), R"(
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%5 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%float_0 = OpConstant %float 0
+%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%float_1 = OpConstant %float 1
+%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%main = OpFunction %void None %5
+%17 = OpLabel
+OpSelectionMerge %18 None
+OpBranchConditional %true %19 %20
+%19 = OpLabel
+OpBranch %18
+%20 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%21 = OpPhi %v4float %12 %19 %14 %20
+OpStore %gl_FragColor %21
+OpReturn
+OpFunctionEnd
+)");
+}
+
+INSTANTIATE_TEST_CASE_P(AllMarkvModels, MarkvTest,  // Runs every TEST_P(MarkvTest, ...) once per model listed below.
+                        ::testing::ValuesIn(std::vector{
+                            kMarkvModelShaderLite,
+                            kMarkvModelShaderMid,
+                            kMarkvModelShaderMax,
+                        }), );
+
+}  // namespace
+}  // namespace comp
+}  // namespace spvtools
diff --git a/test/cpp_interface_test.cpp b/test/cpp_interface_test.cpp
new file mode 100644
index 000000000..538d40fd4
--- /dev/null
+++
b/test/cpp_interface_test.cpp @@ -0,0 +1,328 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "spirv-tools/optimizer.hpp" +#include "spirv/1.1/spirv.h" + +namespace spvtools { +namespace { + +using ::testing::ContainerEq; +using ::testing::HasSubstr; + +// Return a string that contains the minimum instructions needed to form +// a valid module. Other instructions can be appended to this string. +std::string Header() { + return R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; +} + +// When we assemble with a target environment of SPIR-V 1.1, we expect +// the following in the module header version word. +const uint32_t kExpectedSpvVersion = 0x10100; + +TEST(CppInterface, SuccessfulRoundTrip) { + const std::string input_text = "%2 = OpSizeOf %1 %3\n"; + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + + std::vector binary; + EXPECT_TRUE(t.Assemble(input_text, &binary)); + EXPECT_TRUE(binary.size() > 5u); + EXPECT_EQ(SpvMagicNumber, binary[0]); + EXPECT_EQ(kExpectedSpvVersion, binary[1]); + + // This cannot pass validation since %1 is not defined. 
+ t.SetMessageConsumer([](spv_message_level_t level, const char* source, + const spv_position_t& position, const char* message) { + EXPECT_EQ(SPV_MSG_ERROR, level); + EXPECT_STREQ("input", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(1u, position.index); + EXPECT_STREQ("ID 1[%1] has not been defined\n %2 = OpSizeOf %1 %3\n", + message); + }); + EXPECT_FALSE(t.Validate(binary)); + + std::string output_text; + EXPECT_TRUE(t.Disassemble(binary, &output_text)); + EXPECT_EQ(input_text, output_text); +} + +TEST(CppInterface, AssembleEmptyModule) { + std::vector binary(10, 42); + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + EXPECT_TRUE(t.Assemble("", &binary)); + // We only have the header. + EXPECT_EQ(5u, binary.size()); + EXPECT_EQ(SpvMagicNumber, binary[0]); + EXPECT_EQ(kExpectedSpvVersion, binary[1]); +} + +TEST(CppInterface, AssembleOverloads) { + const std::string input_text = "%2 = OpSizeOf %1 %3\n"; + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + { + std::vector binary; + EXPECT_TRUE(t.Assemble(input_text, &binary)); + EXPECT_TRUE(binary.size() > 5u); + EXPECT_EQ(SpvMagicNumber, binary[0]); + EXPECT_EQ(kExpectedSpvVersion, binary[1]); + } + { + std::vector binary; + EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size(), &binary)); + EXPECT_TRUE(binary.size() > 5u); + EXPECT_EQ(SpvMagicNumber, binary[0]); + EXPECT_EQ(kExpectedSpvVersion, binary[1]); + } + { // Ignore the last newline. 
+ std::vector binary; + EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size() - 1, &binary)); + EXPECT_TRUE(binary.size() > 5u); + EXPECT_EQ(SpvMagicNumber, binary[0]); + EXPECT_EQ(kExpectedSpvVersion, binary[1]); + } +} + +TEST(CppInterface, DisassembleEmptyModule) { + std::string text(10, 'x'); + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + int invocation_count = 0; + t.SetMessageConsumer( + [&invocation_count](spv_message_level_t level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation_count; + EXPECT_EQ(SPV_MSG_ERROR, level); + EXPECT_STREQ("input", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(0u, position.index); + EXPECT_STREQ("Missing module.", message); + }); + EXPECT_FALSE(t.Disassemble({}, &text)); + EXPECT_EQ("xxxxxxxxxx", text); // The original string is unmodified. + EXPECT_EQ(1, invocation_count); +} + +TEST(CppInterface, DisassembleOverloads) { + const std::string input_text = "%2 = OpSizeOf %1 %3\n"; + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + + std::vector binary; + EXPECT_TRUE(t.Assemble(input_text, &binary)); + + { + std::string output_text; + EXPECT_TRUE(t.Disassemble(binary, &output_text)); + EXPECT_EQ(input_text, output_text); + } + { + std::string output_text; + EXPECT_TRUE(t.Disassemble(binary.data(), binary.size(), &output_text)); + EXPECT_EQ(input_text, output_text); + } +} + +TEST(CppInterface, SuccessfulValidation) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + int invocation_count = 0; + t.SetMessageConsumer([&invocation_count](spv_message_level_t, const char*, + const spv_position_t&, const char*) { + ++invocation_count; + }); + + std::vector binary; + EXPECT_TRUE(t.Assemble(Header(), &binary)); + EXPECT_TRUE(t.Validate(binary)); + EXPECT_EQ(0, invocation_count); +} + +TEST(CppInterface, ValidateOverloads) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::vector binary; + EXPECT_TRUE(t.Assemble(Header(), &binary)); + + { EXPECT_TRUE(t.Validate(binary)); } + { 
EXPECT_TRUE(t.Validate(binary.data(), binary.size())); } +} + +TEST(CppInterface, ValidateEmptyModule) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + int invocation_count = 0; + t.SetMessageConsumer( + [&invocation_count](spv_message_level_t level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation_count; + EXPECT_EQ(SPV_MSG_ERROR, level); + EXPECT_STREQ("input", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(0u, position.index); + EXPECT_STREQ("Invalid SPIR-V magic number.", message); + }); + EXPECT_FALSE(t.Validate({})); + EXPECT_EQ(1, invocation_count); +} + +// Returns the assembly for a SPIR-V module with a struct declaration +// with the given number of members. +std::string MakeModuleHavingStruct(int num_members) { + std::stringstream os; + os << Header(); + os << R"(%1 = OpTypeInt 32 0 + %2 = OpTypeStruct)"; + for (int i = 0; i < num_members; i++) os << " %1"; + return os.str(); +} + +TEST(CppInterface, ValidateWithOptionsPass) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::vector binary; + EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary)); + const ValidatorOptions opts; + + EXPECT_TRUE(t.Validate(binary.data(), binary.size(), opts)); +} + +TEST(CppInterface, ValidateWithOptionsFail) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::vector binary; + EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary)); + ValidatorOptions opts; + opts.SetUniversalLimit(spv_validator_limit_max_struct_members, 9); + std::stringstream os; + t.SetMessageConsumer([&os](spv_message_level_t, const char*, + const spv_position_t&, + const char* message) { os << message; }); + + EXPECT_FALSE(t.Validate(binary.data(), binary.size(), opts)); + EXPECT_THAT( + os.str(), + HasSubstr( + "Number of OpTypeStruct members (10) has exceeded the limit (9)")); +} + +// Checks that after running the given optimizer |opt| on the given |original| +// source code, we can get the given |optimized| source 
code. +void CheckOptimization(const std::string& original, + const std::string& optimized, const Optimizer& opt) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::vector original_binary; + ASSERT_TRUE(t.Assemble(original, &original_binary)); + + std::vector optimized_binary; + EXPECT_TRUE(opt.Run(original_binary.data(), original_binary.size(), + &optimized_binary)); + + std::string optimized_text; + EXPECT_TRUE(t.Disassemble(optimized_binary, &optimized_text)); + EXPECT_EQ(optimized, optimized_text); +} + +TEST(CppInterface, OptimizeEmptyModule) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::vector binary; + EXPECT_TRUE(t.Assemble("", &binary)); + + Optimizer o(SPV_ENV_UNIVERSAL_1_1); + o.RegisterPass(CreateStripDebugInfoPass()); + + // Fails to validate. + EXPECT_FALSE(o.Run(binary.data(), binary.size(), &binary)); +} + +TEST(CppInterface, OptimizeModifiedModule) { + Optimizer o(SPV_ENV_UNIVERSAL_1_1); + o.RegisterPass(CreateStripDebugInfoPass()); + CheckOptimization(Header() + "OpSource GLSL 450", Header(), o); +} + +TEST(CppInterface, OptimizeMulitplePasses) { + std::string original_text = Header() + + "OpSource GLSL 450 " + "OpDecorate %true SpecId 1 " + "%bool = OpTypeBool " + "%true = OpSpecConstantTrue %bool"; + + Optimizer o(SPV_ENV_UNIVERSAL_1_1); + o.RegisterPass(CreateStripDebugInfoPass()) + .RegisterPass(CreateFreezeSpecConstantValuePass()); + + std::string expected_text = Header() + + "%bool = OpTypeBool\n" + "%true = OpConstantTrue %bool\n"; + + CheckOptimization(original_text, expected_text, o); +} + +TEST(CppInterface, OptimizeDoNothingWithPassToken) { + CreateFreezeSpecConstantValuePass(); + auto token = CreateUnifyConstantPass(); +} + +TEST(CppInterface, OptimizeReassignPassToken) { + auto token = CreateNullPass(); + token = CreateStripDebugInfoPass(); + + CheckOptimization( + Header() + "OpSource GLSL 450", Header(), + Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token))); +} + +TEST(CppInterface, OptimizeMoveConstructPassToken) { + 
auto token1 = CreateStripDebugInfoPass(); + Optimizer::PassToken token2(std::move(token1)); + + CheckOptimization( + Header() + "OpSource GLSL 450", Header(), + Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2))); +} + +TEST(CppInterface, OptimizeMoveAssignPassToken) { + auto token1 = CreateStripDebugInfoPass(); + auto token2 = CreateNullPass(); + token2 = std::move(token1); + + CheckOptimization( + Header() + "OpSource GLSL 450", Header(), + Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2))); +} + +TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::vector binary; + ASSERT_TRUE(t.Assemble(Header() + "OpSource GLSL 450", &binary)); + + EXPECT_TRUE(Optimizer(SPV_ENV_UNIVERSAL_1_1) + .RegisterPass(CreateStripDebugInfoPass()) + .Run(binary.data(), binary.size(), &binary)); + + std::string optimized_text; + EXPECT_TRUE(t.Disassemble(binary, &optimized_text)); + EXPECT_EQ(Header(), optimized_text); +} + +// TODO(antiagainst): tests for SetMessageConsumer(). + +} // namespace +} // namespace spvtools diff --git a/test/diagnostic_test.cpp b/test/diagnostic_test.cpp new file mode 100644 index 000000000..f86bae113 --- /dev/null +++ b/test/diagnostic_test.cpp @@ -0,0 +1,150 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using ::testing::Eq; + +// Returns a newly created diagnostic value. +spv_diagnostic MakeValidDiagnostic() { + spv_position_t position = {}; + spv_diagnostic diagnostic = spvDiagnosticCreate(&position, ""); + EXPECT_NE(nullptr, diagnostic); + return diagnostic; +} + +TEST(Diagnostic, DestroyNull) { spvDiagnosticDestroy(nullptr); } + +TEST(Diagnostic, DestroyValidDiagnostic) { + spv_diagnostic diagnostic = MakeValidDiagnostic(); + spvDiagnosticDestroy(diagnostic); + // We aren't allowed to use the diagnostic pointer anymore. + // So we can't test its behaviour. +} + +TEST(Diagnostic, DestroyValidDiagnosticAfterReassignment) { + spv_diagnostic diagnostic = MakeValidDiagnostic(); + spv_diagnostic second_diagnostic = MakeValidDiagnostic(); + EXPECT_TRUE(diagnostic != second_diagnostic); + spvDiagnosticDestroy(diagnostic); + diagnostic = second_diagnostic; + spvDiagnosticDestroy(diagnostic); +} + +TEST(Diagnostic, PrintDefault) { + char message[] = "Test Diagnostic!"; + spv_diagnostic_t diagnostic = {{2, 3, 5}, message}; + // TODO: Redirect stderr + ASSERT_EQ(SPV_SUCCESS, spvDiagnosticPrint(&diagnostic)); + // TODO: Validate the output of spvDiagnosticPrint() + // TODO: Remove the redirection of stderr +} + +TEST(Diagnostic, PrintInvalidDiagnostic) { + ASSERT_EQ(SPV_ERROR_INVALID_DIAGNOSTIC, spvDiagnosticPrint(nullptr)); +} + +// TODO(dneto): We should be able to redirect the diagnostic printing. +// Once we do that, we can test diagnostic corner cases. + +TEST(DiagnosticStream, ConversionToResultType) { + // Check after the DiagnosticStream object is destroyed. + spv_result_t value; + { value = DiagnosticStream({}, nullptr, "", SPV_ERROR_INVALID_TEXT); } + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, value); + + // Check implicit conversion via plain assignment. 
+ value = DiagnosticStream({}, nullptr, "", SPV_SUCCESS); + EXPECT_EQ(SPV_SUCCESS, value); + + // Check conversion via constructor. + EXPECT_EQ(SPV_FAILED_MATCH, + spv_result_t(DiagnosticStream({}, nullptr, "", SPV_FAILED_MATCH))); +} + +TEST( + DiagnosticStream, + MoveConstructorPreservesPreviousMessagesAndPreventsOutputFromExpiringValue) { + std::ostringstream messages; + int message_count = 0; + auto consumer = [&messages, &message_count](spv_message_level_t, const char*, + const spv_position_t&, + const char* msg) { + message_count++; + messages << msg; + }; + + // Enclose the DiagnosticStream variables in a scope to force destruction. + { + DiagnosticStream ds0({}, consumer, "", SPV_ERROR_INVALID_BINARY); + ds0 << "First"; + DiagnosticStream ds1(std::move(ds0)); + ds1 << "Second"; + } + EXPECT_THAT(message_count, Eq(1)); + EXPECT_THAT(messages.str(), Eq("FirstSecond")); +} + +TEST(DiagnosticStream, MoveConstructorCanBeDirectlyShiftedTo) { + std::ostringstream messages; + int message_count = 0; + auto consumer = [&messages, &message_count](spv_message_level_t, const char*, + const spv_position_t&, + const char* msg) { + message_count++; + messages << msg; + }; + + // Enclose the DiagnosticStream variables in a scope to force destruction. 
+ { + DiagnosticStream ds0({}, consumer, "", SPV_ERROR_INVALID_BINARY); + ds0 << "First"; + std::move(ds0) << "Second"; + } + EXPECT_THAT(message_count, Eq(1)); + EXPECT_THAT(messages.str(), Eq("FirstSecond")); +} + +TEST(DiagnosticStream, DiagnosticFromLambdaReturnCanStillBeUsed) { + std::ostringstream messages; + int message_count = 0; + auto consumer = [&messages, &message_count](spv_message_level_t, const char*, + const spv_position_t&, + const char* msg) { + message_count++; + messages << msg; + }; + + { + auto emitter = [&consumer]() -> DiagnosticStream { + DiagnosticStream ds0({}, consumer, "", SPV_ERROR_INVALID_BINARY); + ds0 << "First"; + return ds0; + }; + emitter() << "Second"; + } + EXPECT_THAT(message_count, Eq(1)); + EXPECT_THAT(messages.str(), Eq("FirstSecond")); +} + +} // namespace +} // namespace spvtools diff --git a/test/enum_set_test.cpp b/test/enum_set_test.cpp new file mode 100644 index 000000000..ddacd4214 --- /dev/null +++ b/test/enum_set_test.cpp @@ -0,0 +1,290 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/enum_set.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::ElementsIn; +using ::testing::Eq; +using ::testing::ValuesIn; + +TEST(EnumSet, IsEmpty1) { + EnumSet set; + EXPECT_TRUE(set.IsEmpty()); + set.Add(0); + EXPECT_FALSE(set.IsEmpty()); +} + +TEST(EnumSet, IsEmpty2) { + EnumSet set; + EXPECT_TRUE(set.IsEmpty()); + set.Add(150); + EXPECT_FALSE(set.IsEmpty()); +} + +TEST(EnumSet, IsEmpty3) { + EnumSet set(4); + EXPECT_FALSE(set.IsEmpty()); +} + +TEST(EnumSet, IsEmpty4) { + EnumSet set(300); + EXPECT_FALSE(set.IsEmpty()); +} + +TEST(EnumSetHasAnyOf, EmptySetEmptyQuery) { + const EnumSet set; + const EnumSet empty; + EXPECT_TRUE(set.HasAnyOf(empty)); + EXPECT_TRUE(EnumSet().HasAnyOf(EnumSet())); +} + +TEST(EnumSetHasAnyOf, MaskSetEmptyQuery) { + EnumSet set; + const EnumSet empty; + set.Add(5); + set.Add(8); + EXPECT_TRUE(set.HasAnyOf(empty)); +} + +TEST(EnumSetHasAnyOf, OverflowSetEmptyQuery) { + EnumSet set; + const EnumSet empty; + set.Add(200); + set.Add(300); + EXPECT_TRUE(set.HasAnyOf(empty)); +} + +TEST(EnumSetHasAnyOf, EmptyQuery) { + EnumSet set; + const EnumSet empty; + set.Add(5); + set.Add(8); + set.Add(200); + set.Add(300); + EXPECT_TRUE(set.HasAnyOf(empty)); +} + +TEST(EnumSetHasAnyOf, EmptyQueryAlwaysTrue) { + EnumSet set; + const EnumSet empty; + EXPECT_TRUE(set.HasAnyOf(empty)); + set.Add(5); + EXPECT_TRUE(set.HasAnyOf(empty)); + + EXPECT_TRUE(EnumSet(100).HasAnyOf(EnumSet())); +} + +TEST(EnumSetHasAnyOf, ReflexiveMask) { + EnumSet set(3); + set.Add(24); + set.Add(30); + EXPECT_TRUE(set.HasAnyOf(set)); +} + +TEST(EnumSetHasAnyOf, ReflexiveOverflow) { + EnumSet set(200); + set.Add(300); + set.Add(400); + EXPECT_TRUE(set.HasAnyOf(set)); +} + +TEST(EnumSetHasAnyOf, Reflexive) { + EnumSet set(3); + set.Add(24); + set.Add(300); + set.Add(400); + EXPECT_TRUE(set.HasAnyOf(set)); +} + +TEST(EnumSetHasAnyOf, EmptySetHasNone) { + EnumSet set; + 
EnumSet items; + for (uint32_t i = 0; i < 200; ++i) { + items.Add(i); + EXPECT_FALSE(set.HasAnyOf(items)); + EXPECT_FALSE(set.HasAnyOf(EnumSet(i))); + } +} + +TEST(EnumSetHasAnyOf, MaskSetMaskQuery) { + EnumSet set(0); + EnumSet items(1); + EXPECT_FALSE(set.HasAnyOf(items)); + set.Add(2); + items.Add(3); + EXPECT_FALSE(set.HasAnyOf(items)); + set.Add(3); + EXPECT_TRUE(set.HasAnyOf(items)); + set.Add(4); + EXPECT_TRUE(set.HasAnyOf(items)); +} + +TEST(EnumSetHasAnyOf, OverflowSetOverflowQuery) { + EnumSet set(100); + EnumSet items(200); + EXPECT_FALSE(set.HasAnyOf(items)); + set.Add(300); + items.Add(400); + EXPECT_FALSE(set.HasAnyOf(items)); + set.Add(200); + EXPECT_TRUE(set.HasAnyOf(items)); + set.Add(500); + EXPECT_TRUE(set.HasAnyOf(items)); +} + +TEST(EnumSetHasAnyOf, GeneralCase) { + EnumSet set(0); + EnumSet items(100); + EXPECT_FALSE(set.HasAnyOf(items)); + set.Add(300); + items.Add(4); + EXPECT_FALSE(set.HasAnyOf(items)); + set.Add(5); + items.Add(500); + EXPECT_FALSE(set.HasAnyOf(items)); + set.Add(500); + EXPECT_TRUE(set.HasAnyOf(items)); + EXPECT_FALSE(set.HasAnyOf(EnumSet(20))); + EXPECT_FALSE(set.HasAnyOf(EnumSet(600))); + EXPECT_TRUE(set.HasAnyOf(EnumSet(5))); + EXPECT_TRUE(set.HasAnyOf(EnumSet(300))); + EXPECT_TRUE(set.HasAnyOf(EnumSet(0))); +} + +TEST(EnumSet, DefaultIsEmpty) { + EnumSet set; + for (uint32_t i = 0; i < 1000; ++i) { + EXPECT_FALSE(set.Contains(i)); + } +} + +TEST(CapabilitySet, ConstructSingleMemberMatrix) { + CapabilitySet s(SpvCapabilityMatrix); + EXPECT_TRUE(s.Contains(SpvCapabilityMatrix)); + EXPECT_FALSE(s.Contains(SpvCapabilityShader)); + EXPECT_FALSE(s.Contains(static_cast(1000))); +} + +TEST(CapabilitySet, ConstructSingleMemberMaxInMask) { + CapabilitySet s(static_cast(63)); + EXPECT_FALSE(s.Contains(SpvCapabilityMatrix)); + EXPECT_FALSE(s.Contains(SpvCapabilityShader)); + EXPECT_TRUE(s.Contains(static_cast(63))); + EXPECT_FALSE(s.Contains(static_cast(64))); + EXPECT_FALSE(s.Contains(static_cast(1000))); +} + 
+TEST(CapabilitySet, ConstructSingleMemberMinOverflow) { + // Check the first one that forces overflow beyond the mask. + CapabilitySet s(static_cast(64)); + EXPECT_FALSE(s.Contains(SpvCapabilityMatrix)); + EXPECT_FALSE(s.Contains(SpvCapabilityShader)); + EXPECT_FALSE(s.Contains(static_cast(63))); + EXPECT_TRUE(s.Contains(static_cast(64))); + EXPECT_FALSE(s.Contains(static_cast(1000))); +} + +TEST(CapabilitySet, ConstructSingleMemberMaxOverflow) { + // Check the max 32-bit signed int. + CapabilitySet s(static_cast(0x7fffffffu)); + EXPECT_FALSE(s.Contains(SpvCapabilityMatrix)); + EXPECT_FALSE(s.Contains(SpvCapabilityShader)); + EXPECT_FALSE(s.Contains(static_cast(1000))); + EXPECT_TRUE(s.Contains(static_cast(0x7fffffffu))); +} + +TEST(CapabilitySet, AddEnum) { + CapabilitySet s(SpvCapabilityShader); + s.Add(SpvCapabilityKernel); + s.Add(static_cast(42)); + EXPECT_FALSE(s.Contains(SpvCapabilityMatrix)); + EXPECT_TRUE(s.Contains(SpvCapabilityShader)); + EXPECT_TRUE(s.Contains(SpvCapabilityKernel)); + EXPECT_TRUE(s.Contains(static_cast(42))); +} + +TEST(CapabilitySet, InitializerListEmpty) { + CapabilitySet s{}; + for (uint32_t i = 0; i < 1000; i++) { + EXPECT_FALSE(s.Contains(static_cast(i))); + } +} + +struct ForEachCase { + CapabilitySet capabilities; + std::vector expected; +}; + +using CapabilitySetForEachTest = ::testing::TestWithParam; + +TEST_P(CapabilitySetForEachTest, CallsAsExpected) { + EXPECT_THAT(ElementsIn(GetParam().capabilities), Eq(GetParam().expected)); +} + +TEST_P(CapabilitySetForEachTest, CopyConstructor) { + CapabilitySet copy(GetParam().capabilities); + EXPECT_THAT(ElementsIn(copy), Eq(GetParam().expected)); +} + +TEST_P(CapabilitySetForEachTest, MoveConstructor) { + // We need a writable copy to move from. + CapabilitySet copy(GetParam().capabilities); + CapabilitySet moved(std::move(copy)); + EXPECT_THAT(ElementsIn(moved), Eq(GetParam().expected)); + + // The moved-from set is empty. 
+ EXPECT_THAT(ElementsIn(copy), Eq(std::vector{})); +} + +TEST_P(CapabilitySetForEachTest, OperatorEquals) { + CapabilitySet assigned = GetParam().capabilities; + EXPECT_THAT(ElementsIn(assigned), Eq(GetParam().expected)); +} + +TEST_P(CapabilitySetForEachTest, OperatorEqualsSelfAssign) { + CapabilitySet assigned{GetParam().capabilities}; + assigned = assigned; + EXPECT_THAT(ElementsIn(assigned), Eq(GetParam().expected)); +} + +INSTANTIATE_TEST_CASE_P(Samples, CapabilitySetForEachTest, + ValuesIn(std::vector{ + {{}, {}}, + {{SpvCapabilityMatrix}, {SpvCapabilityMatrix}}, + {{SpvCapabilityKernel, SpvCapabilityShader}, + {SpvCapabilityShader, SpvCapabilityKernel}}, + {{static_cast(999)}, + {static_cast(999)}}, + {{static_cast(0x7fffffff)}, + {static_cast(0x7fffffff)}}, + // Mixture and out of order + {{static_cast(0x7fffffff), + static_cast(100), + SpvCapabilityShader, SpvCapabilityMatrix}, + {SpvCapabilityMatrix, SpvCapabilityShader, + static_cast(100), + static_cast(0x7fffffff)}}, + }), ); + +} // namespace +} // namespace spvtools diff --git a/test/enum_string_mapping_test.cpp b/test/enum_string_mapping_test.cpp new file mode 100644 index 000000000..b525d6014 --- /dev/null +++ b/test/enum_string_mapping_test.cpp @@ -0,0 +1,195 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for OpExtension validator rules. 
+ +#include +#include +#include + +#include "gtest/gtest.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" + +namespace spvtools { +namespace { + +using ::testing::Values; +using ::testing::ValuesIn; + +using ExtensionTest = + ::testing::TestWithParam>; +using UnknownExtensionTest = ::testing::TestWithParam; +using CapabilityTest = + ::testing::TestWithParam>; + +TEST_P(ExtensionTest, TestExtensionFromString) { + const std::pair& param = GetParam(); + const Extension extension = param.first; + const std::string extension_str = param.second; + Extension result_extension; + ASSERT_TRUE(GetExtensionFromString(extension_str.c_str(), &result_extension)); + EXPECT_EQ(extension, result_extension); +} + +TEST_P(ExtensionTest, TestExtensionToString) { + const std::pair& param = GetParam(); + const Extension extension = param.first; + const std::string extension_str = param.second; + const std::string result_str = ExtensionToString(extension); + EXPECT_EQ(extension_str, result_str); +} + +TEST_P(UnknownExtensionTest, TestExtensionFromStringFails) { + Extension result_extension; + ASSERT_FALSE(GetExtensionFromString(GetParam().c_str(), &result_extension)); +} + +TEST_P(CapabilityTest, TestCapabilityToString) { + const std::pair& param = GetParam(); + const SpvCapability capability = param.first; + const std::string capability_str = param.second; + const std::string result_str = CapabilityToString(capability); + EXPECT_EQ(capability_str, result_str); +} + +INSTANTIATE_TEST_CASE_P( + AllExtensions, ExtensionTest, + ValuesIn(std::vector>({ + {Extension::kSPV_KHR_16bit_storage, "SPV_KHR_16bit_storage"}, + {Extension::kSPV_KHR_device_group, "SPV_KHR_device_group"}, + {Extension::kSPV_KHR_multiview, "SPV_KHR_multiview"}, + {Extension::kSPV_KHR_shader_ballot, "SPV_KHR_shader_ballot"}, + {Extension::kSPV_KHR_shader_draw_parameters, + "SPV_KHR_shader_draw_parameters"}, + {Extension::kSPV_KHR_subgroup_vote, "SPV_KHR_subgroup_vote"}, + 
{Extension::kSPV_NVX_multiview_per_view_attributes, + "SPV_NVX_multiview_per_view_attributes"}, + {Extension::kSPV_NV_geometry_shader_passthrough, + "SPV_NV_geometry_shader_passthrough"}, + {Extension::kSPV_NV_sample_mask_override_coverage, + "SPV_NV_sample_mask_override_coverage"}, + {Extension::kSPV_NV_stereo_view_rendering, + "SPV_NV_stereo_view_rendering"}, + {Extension::kSPV_NV_viewport_array2, "SPV_NV_viewport_array2"}, + {Extension::kSPV_GOOGLE_decorate_string, "SPV_GOOGLE_decorate_string"}, + {Extension::kSPV_GOOGLE_hlsl_functionality1, + "SPV_GOOGLE_hlsl_functionality1"}, + {Extension::kSPV_KHR_8bit_storage, "SPV_KHR_8bit_storage"}, + }))); + +INSTANTIATE_TEST_CASE_P(UnknownExtensions, UnknownExtensionTest, + Values("", "SPV_KHR_", "SPV_KHR_device_group_ERROR", + /*alphabetically before all extensions*/ "A", + /*alphabetically after all extensions*/ "Z", + "SPV_ERROR_random_string_hfsdklhlktherh")); + +INSTANTIATE_TEST_CASE_P( + AllCapabilities, CapabilityTest, + ValuesIn(std::vector>( + {{SpvCapabilityMatrix, "Matrix"}, + {SpvCapabilityShader, "Shader"}, + {SpvCapabilityGeometry, "Geometry"}, + {SpvCapabilityTessellation, "Tessellation"}, + {SpvCapabilityAddresses, "Addresses"}, + {SpvCapabilityLinkage, "Linkage"}, + {SpvCapabilityKernel, "Kernel"}, + {SpvCapabilityVector16, "Vector16"}, + {SpvCapabilityFloat16Buffer, "Float16Buffer"}, + {SpvCapabilityFloat16, "Float16"}, + {SpvCapabilityFloat64, "Float64"}, + {SpvCapabilityInt64, "Int64"}, + {SpvCapabilityInt64Atomics, "Int64Atomics"}, + {SpvCapabilityImageBasic, "ImageBasic"}, + {SpvCapabilityImageReadWrite, "ImageReadWrite"}, + {SpvCapabilityImageMipmap, "ImageMipmap"}, + {SpvCapabilityPipes, "Pipes"}, + {SpvCapabilityGroups, "Groups"}, + {SpvCapabilityDeviceEnqueue, "DeviceEnqueue"}, + {SpvCapabilityLiteralSampler, "LiteralSampler"}, + {SpvCapabilityAtomicStorage, "AtomicStorage"}, + {SpvCapabilityInt16, "Int16"}, + {SpvCapabilityTessellationPointSize, "TessellationPointSize"}, + 
{SpvCapabilityGeometryPointSize, "GeometryPointSize"}, + {SpvCapabilityImageGatherExtended, "ImageGatherExtended"}, + {SpvCapabilityStorageImageMultisample, "StorageImageMultisample"}, + {SpvCapabilityUniformBufferArrayDynamicIndexing, + "UniformBufferArrayDynamicIndexing"}, + {SpvCapabilitySampledImageArrayDynamicIndexing, + "SampledImageArrayDynamicIndexing"}, + {SpvCapabilityStorageBufferArrayDynamicIndexing, + "StorageBufferArrayDynamicIndexing"}, + {SpvCapabilityStorageImageArrayDynamicIndexing, + "StorageImageArrayDynamicIndexing"}, + {SpvCapabilityClipDistance, "ClipDistance"}, + {SpvCapabilityCullDistance, "CullDistance"}, + {SpvCapabilityImageCubeArray, "ImageCubeArray"}, + {SpvCapabilitySampleRateShading, "SampleRateShading"}, + {SpvCapabilityImageRect, "ImageRect"}, + {SpvCapabilitySampledRect, "SampledRect"}, + {SpvCapabilityGenericPointer, "GenericPointer"}, + {SpvCapabilityInt8, "Int8"}, + {SpvCapabilityInputAttachment, "InputAttachment"}, + {SpvCapabilitySparseResidency, "SparseResidency"}, + {SpvCapabilityMinLod, "MinLod"}, + {SpvCapabilitySampled1D, "Sampled1D"}, + {SpvCapabilityImage1D, "Image1D"}, + {SpvCapabilitySampledCubeArray, "SampledCubeArray"}, + {SpvCapabilitySampledBuffer, "SampledBuffer"}, + {SpvCapabilityImageBuffer, "ImageBuffer"}, + {SpvCapabilityImageMSArray, "ImageMSArray"}, + {SpvCapabilityStorageImageExtendedFormats, + "StorageImageExtendedFormats"}, + {SpvCapabilityImageQuery, "ImageQuery"}, + {SpvCapabilityDerivativeControl, "DerivativeControl"}, + {SpvCapabilityInterpolationFunction, "InterpolationFunction"}, + {SpvCapabilityTransformFeedback, "TransformFeedback"}, + {SpvCapabilityGeometryStreams, "GeometryStreams"}, + {SpvCapabilityStorageImageReadWithoutFormat, + "StorageImageReadWithoutFormat"}, + {SpvCapabilityStorageImageWriteWithoutFormat, + "StorageImageWriteWithoutFormat"}, + {SpvCapabilityMultiViewport, "MultiViewport"}, + {SpvCapabilitySubgroupDispatch, "SubgroupDispatch"}, + {SpvCapabilityNamedBarrier, 
"NamedBarrier"}, + {SpvCapabilityPipeStorage, "PipeStorage"}, + {SpvCapabilitySubgroupBallotKHR, "SubgroupBallotKHR"}, + {SpvCapabilityDrawParameters, "DrawParameters"}, + {SpvCapabilitySubgroupVoteKHR, "SubgroupVoteKHR"}, + {SpvCapabilityStorageBuffer16BitAccess, "StorageBuffer16BitAccess"}, + {SpvCapabilityStorageUniformBufferBlock16, + "StorageBuffer16BitAccess"}, // Preferred name + {SpvCapabilityUniformAndStorageBuffer16BitAccess, + "UniformAndStorageBuffer16BitAccess"}, + {SpvCapabilityStorageUniform16, + "UniformAndStorageBuffer16BitAccess"}, // Preferred name + {SpvCapabilityStoragePushConstant16, "StoragePushConstant16"}, + {SpvCapabilityStorageInputOutput16, "StorageInputOutput16"}, + {SpvCapabilityDeviceGroup, "DeviceGroup"}, + {SpvCapabilityMultiView, "MultiView"}, + {SpvCapabilitySampleMaskOverrideCoverageNV, + "SampleMaskOverrideCoverageNV"}, + {SpvCapabilityGeometryShaderPassthroughNV, + "GeometryShaderPassthroughNV"}, + // The next two are different names for the same token. + {SpvCapabilityShaderViewportIndexLayerNV, + "ShaderViewportIndexLayerEXT"}, + {SpvCapabilityShaderViewportIndexLayerEXT, + "ShaderViewportIndexLayerEXT"}, + {SpvCapabilityShaderViewportMaskNV, "ShaderViewportMaskNV"}, + {SpvCapabilityShaderStereoViewNV, "ShaderStereoViewNV"}, + {SpvCapabilityPerViewAttributesNV, "PerViewAttributesNV"}})), ); + +} // namespace +} // namespace spvtools diff --git a/test/ext_inst.debuginfo_test.cpp b/test/ext_inst.debuginfo_test.cpp new file mode 100644 index 000000000..15fa8f765 --- /dev/null +++ b/test/ext_inst.debuginfo_test.cpp @@ -0,0 +1,812 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "DebugInfo.h" +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +// This file tests the correctness of encoding and decoding of instructions +// involving the DebugInfo extended instruction set. +// Semantic correctness should be the responsibility of validator. +// +// See https://www.khronos.org/registry/spir-v/specs/1.0/DebugInfo.html + +namespace spvtools { +namespace { + +using spvtest::Concatenate; +using spvtest::MakeInstruction; +using spvtest::MakeVector; +using testing::Eq; + +struct InstructionCase { + uint32_t opcode; + std::string name; + std::string operands; + std::vector expected_operands; +}; + +using ExtInstDebugInfoRoundTripTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; +using ExtInstDebugInfoRoundTripTestExplicit = spvtest::TextToBinaryTest; + +TEST_P(ExtInstDebugInfoRoundTripTest, ParameterizedExtInst) { + const std::string input = + "%1 = OpExtInstImport \"DebugInfo\"\n" + "%3 = OpExtInst %2 %1 " + + GetParam().name + GetParam().operands + "\n"; + // First make sure it assembles correctly. + EXPECT_THAT( + CompiledInstructions(input), + Eq(Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, MakeVector("DebugInfo")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, GetParam().opcode}, + GetParam().expected_operands)}))) + << input; + // Now check the round trip through the disassembler. 
+ EXPECT_THAT(EncodeAndDecodeSuccessfully(input), input) << input; +} + +#define CASE_0(Enum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, "", {} \ + } + +#define CASE_ILL(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 " #L0 " " #L1, { \ + 4, L0, L1 \ + } \ + } + +#define CASE_IL(Enum, L0) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 " #L0, { 4, L0 } \ + } + +#define CASE_I(Enum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4", { 4 } \ + } + +#define CASE_II(Enum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5", { 4, 5 } \ + } + +#define CASE_III(Enum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5 %6", { 4, 5, 6 } \ + } + +#define CASE_IIII(Enum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5 %6 %7", { \ + 4, 5, 6, 7 \ + } \ + } + +#define CASE_IIIII(Enum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5 %6 %7 %8", { \ + 4, 5, 6, 7, 8 \ + } \ + } + +#define CASE_IIIIII(Enum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5 %6 %7 %8 %9", { \ + 4, 5, 6, 7, 8, 9 \ + } \ + } + +#define CASE_IIIIIII(Enum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5 %6 %7 %8 %9 %10", { \ + 4, 5, 6, 7, 8, 9, 10 \ + } \ + } + +#define CASE_IIILLI(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7", { \ + 4, 5, 6, L0, L1, 7 \ + } \ + } + +#define CASE_IIILLIL(Enum, L0, L1, L2) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 " #L2, { \ + 4, 5, 6, L0, L1, 7, L2 \ + } \ + } + +#define CASE_IE(Enum, E0) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 " #E0, { \ + 4, uint32_t(DebugInfo##E0) \ + } \ + } + +#define CASE_IIE(Enum, E0) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5 " #E0, { \ + 4, 5, uint32_t(DebugInfo##E0) \ + } \ + } + +#define CASE_ISF(Enum, S0, Fstr, Fnum) \ + { \ + 
uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 " #S0 " " Fstr, { \ + 4, uint32_t(SpvStorageClass##S0), Fnum \ + } \ + } + +#define CASE_LII(Enum, L0) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " " #L0 " %4 %5", { \ + L0, 4, 5 \ + } \ + } + +#define CASE_ILI(Enum, L0) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 " #L0 " %5", { \ + 4, L0, 5 \ + } \ + } + +#define CASE_ILII(Enum, L0) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 " #L0 " %5 %6", { \ + 4, L0, 5, 6 \ + } \ + } + +#define CASE_ILLII(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 " #L0 " " #L1 " %5 %6", { \ + 4, L0, L1, 5, 6 \ + } \ + } + +#define CASE_IIILLIIF(Enum, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 %8 " Fstr, { \ + 4, 5, 6, L0, L1, 7, 8, Fnum \ + } \ + } + +#define CASE_IIILLIIFII(Enum, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 %8 " Fstr " %9 %10", { \ + 4, 5, 6, L0, L1, 7, 8, Fnum, 9, 10 \ + } \ + } + +#define CASE_IIILLIIFIIII(Enum, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 %8 " Fstr " %9 %10 %11 %12", { \ + 4, 5, 6, L0, L1, 7, 8, Fnum, 9, 10, 11, 12 \ + } \ + } + +#define CASE_IIILLIIFIIIIII(Enum, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 %8 " Fstr " %9 %10 %11 %12 %13 %14", { \ + 4, 5, 6, L0, L1, 7, 8, Fnum, 9, 10, 11, 12, 13, 14 \ + } \ + } + +#define CASE_IEILLIIF(Enum, E0, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 " #E0 " %5 " #L0 " " #L1 " %6 %7 " Fstr, { \ + 4, uint32_t(DebugInfo##E0), 5, L0, L1, 6, 7, Fnum \ + } \ + } + +#define CASE_IEILLIIFI(Enum, E0, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 " #E0 " %5 " #L0 " " #L1 " %6 %7 " Fstr " %8", { \ + 
4, uint32_t(DebugInfo##E0), 5, L0, L1, 6, 7, Fnum, 8 \ + } \ + } + +#define CASE_IEILLIIFII(Enum, E0, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 " #E0 " %5 " #L0 " " #L1 " %6 %7 " Fstr " %8 %9", { \ + 4, uint32_t(DebugInfo##E0), 5, L0, L1, 6, 7, Fnum, 8, 9 \ + } \ + } + +#define CASE_IEILLIIFIII(Enum, E0, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 " #E0 " %5 " #L0 " " #L1 " %6 %7 " Fstr " %8 %9 %10", { \ + 4, uint32_t(DebugInfo##E0), 5, L0, L1, 6, 7, Fnum, 8, 9, 10 \ + } \ + } + +#define CASE_IEILLIIFIIII(Enum, E0, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 " #E0 " %5 " #L0 " " #L1 " %6 %7 " Fstr " %8 %9 %10 %11", { \ + 4, uint32_t(DebugInfo##E0), 5, L0, L1, 6, 7, Fnum, 8, 9, 10, 11 \ + } \ + } + +#define CASE_IIILLIIIF(Enum, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 %8 %9 " Fstr, { \ + 4, 5, 6, L0, L1, 7, 8, 9, Fnum \ + } \ + } + +#define CASE_IIILLIIIFI(Enum, L0, L1, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 %8 %9 " Fstr " %10", { \ + 4, 5, 6, L0, L1, 7, 8, 9, Fnum, 10 \ + } \ + } + +#define CASE_IIIIF(Enum, Fstr, Fnum) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5 %6 %7 " Fstr, { \ + 4, 5, 6, 7, Fnum \ + } \ + } + +#define CASE_IIILL(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " %4 %5 %6 " #L0 " " #L1, { \ + 4, 5, 6, L0, L1 \ + } \ + } + +#define CASE_IIIILL(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 %7 " #L0 " " #L1, { \ + 4, 5, 6, 7, L0, L1 \ + } \ + } + +#define CASE_IILLI(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 " #L0 " " #L1 " %6", { \ + 4, 5, L0, L1, 6 \ + } \ + } + +#define CASE_IILLII(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 " #L0 
" " #L1 " %6 %7", { \ + 4, 5, L0, L1, 6, 7 \ + } \ + } + +#define CASE_IILLIII(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 " #L0 " " #L1 " %6 %7 %8", { \ + 4, 5, L0, L1, 6, 7, 8 \ + } \ + } + +#define CASE_IILLIIII(Enum, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 " #L0 " " #L1 " %6 %7 %8 %9", { \ + 4, 5, L0, L1, 6, 7, 8, 9 \ + } \ + } + +#define CASE_IIILLIIFLI(Enum, L0, L1, Fstr, Fnum, L2) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 %8 " Fstr " " #L2 " %9", { \ + 4, 5, 6, L0, L1, 7, 8, Fnum, L2, 9 \ + } \ + } + +#define CASE_IIILLIIFLII(Enum, L0, L1, Fstr, Fnum, L2) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, \ + " %4 %5 %6 " #L0 " " #L1 " %7 %8 " Fstr " " #L2 " %9 %10", { \ + 4, 5, 6, L0, L1, 7, 8, Fnum, L2, 9, 10 \ + } \ + } + +#define CASE_E(Enum, E0) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " " #E0, { \ + uint32_t(DebugInfo##E0) \ + } \ + } + +#define CASE_EL(Enum, E0, L0) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " " #E0 " " #L0, { \ + uint32_t(DebugInfo##E0), L0 \ + } \ + } + +#define CASE_ELL(Enum, E0, L0, L1) \ + { \ + uint32_t(DebugInfoDebug##Enum), "Debug" #Enum, " " #E0 " " #L0 " " #L1, { \ + uint32_t(DebugInfo##E0), L0, L1 \ + } \ + } + +// DebugInfo 4.1 Absent Debugging Information +INSTANTIATE_TEST_CASE_P(DebugInfoDebugInfoNone, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_0(InfoNone), // enum value 0 + })), ); + +// DebugInfo 4.2 Compilation Unit +INSTANTIATE_TEST_CASE_P(DebugInfoDebugCompilationUnit, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_ILL(CompilationUnit, 100, 42), + })), ); + +// DebugInfo 4.3 Type instructions +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeBasic, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIE(TypeBasic, Unspecified), + CASE_IIE(TypeBasic, Address), + CASE_IIE(TypeBasic, 
Boolean), + CASE_IIE(TypeBasic, Float), + CASE_IIE(TypeBasic, Signed), + CASE_IIE(TypeBasic, SignedChar), + CASE_IIE(TypeBasic, Unsigned), + CASE_IIE(TypeBasic, UnsignedChar), + })), ); + +// The FlagIsPublic is value is (1 << 0) | (1 << 2) which is the same +// as the bitwise-OR of FlagIsProtected and FlagIsPrivate. +// The disassembler will emit the compound expression instead. +// There is no simple fix for this. This enum is not really a mask +// for the bottom two bits. +TEST_F(ExtInstDebugInfoRoundTripTestExplicit, FlagIsPublic) { + const std::string prefix = + "%1 = OpExtInstImport \"DebugInfo\"\n" + "%3 = OpExtInst %2 %1 DebugTypePointer %4 Private "; + const std::string input = prefix + "FlagIsPublic\n"; + const std::string expected = prefix + "FlagIsProtected|FlagIsPrivate\n"; + // First make sure it assembles correctly. + EXPECT_THAT( + CompiledInstructions(input), + Eq(Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, MakeVector("DebugInfo")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, DebugInfoDebugTypePointer, 4, + uint32_t(SpvStorageClassPrivate), + DebugInfoFlagIsPublic})}))) + << input; + // Now check the round trip through the disassembler. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input), Eq(expected)) << input; +} + +INSTANTIATE_TEST_CASE_P( + DebugInfoDebugTypePointer, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + + //// Use each flag independently. + CASE_ISF(TypePointer, Private, "FlagIsProtected", + uint32_t(DebugInfoFlagIsProtected)), + CASE_ISF(TypePointer, Private, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + + // FlagIsPublic is tested above. 
+ + CASE_ISF(TypePointer, Private, "FlagIsLocal", + uint32_t(DebugInfoFlagIsLocal)), + CASE_ISF(TypePointer, Private, "FlagIsDefinition", + uint32_t(DebugInfoFlagIsDefinition)), + CASE_ISF(TypePointer, Private, "FlagFwdDecl", + uint32_t(DebugInfoFlagFwdDecl)), + CASE_ISF(TypePointer, Private, "FlagArtificial", + uint32_t(DebugInfoFlagArtificial)), + CASE_ISF(TypePointer, Private, "FlagExplicit", + uint32_t(DebugInfoFlagExplicit)), + CASE_ISF(TypePointer, Private, "FlagPrototyped", + uint32_t(DebugInfoFlagPrototyped)), + CASE_ISF(TypePointer, Private, "FlagObjectPointer", + uint32_t(DebugInfoFlagObjectPointer)), + CASE_ISF(TypePointer, Private, "FlagStaticMember", + uint32_t(DebugInfoFlagStaticMember)), + CASE_ISF(TypePointer, Private, "FlagIndirectVariable", + uint32_t(DebugInfoFlagIndirectVariable)), + CASE_ISF(TypePointer, Private, "FlagLValueReference", + uint32_t(DebugInfoFlagLValueReference)), + CASE_ISF(TypePointer, Private, "FlagIsOptimized", + uint32_t(DebugInfoFlagIsOptimized)), + + //// Use flags in combination, and try different storage classes. 
+ CASE_ISF(TypePointer, Function, "FlagIsProtected|FlagIsPrivate", + uint32_t(DebugInfoFlagIsProtected) | + uint32_t(DebugInfoFlagIsPrivate)), + CASE_ISF( + TypePointer, Workgroup, + "FlagIsPrivate|FlagFwdDecl|FlagIndirectVariable|FlagIsOptimized", + uint32_t(DebugInfoFlagIsPrivate) | uint32_t(DebugInfoFlagFwdDecl) | + uint32_t(DebugInfoFlagIndirectVariable) | + uint32_t(DebugInfoFlagIsOptimized)), + + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeQualifier, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IE(TypeQualifier, ConstType), + CASE_IE(TypeQualifier, VolatileType), + CASE_IE(TypeQualifier, RestrictType), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeArray, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_II(TypeArray), + CASE_III(TypeArray), + CASE_IIII(TypeArray), + CASE_IIIII(TypeArray), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeVector, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IL(TypeVector, 2), + CASE_IL(TypeVector, 3), + CASE_IL(TypeVector, 4), + CASE_IL(TypeVector, 16), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypedef, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIILLI(Typedef, 12, 13), + CASE_IIILLI(Typedef, 14, 99), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeFunction, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_I(TypeFunction), + CASE_II(TypeFunction), + CASE_III(TypeFunction), + CASE_IIII(TypeFunction), + CASE_IIIII(TypeFunction), + })), ); + +INSTANTIATE_TEST_CASE_P( + DebugInfoDebugTypeEnum, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIILLIIFII( + TypeEnum, 12, 13, + "FlagIsPrivate|FlagFwdDecl|FlagIndirectVariable|FlagIsOptimized", + uint32_t(DebugInfoFlagIsPrivate) | uint32_t(DebugInfoFlagFwdDecl) | + uint32_t(DebugInfoFlagIndirectVariable) | + uint32_t(DebugInfoFlagIsOptimized)), + CASE_IIILLIIFIIII(TypeEnum, 17, 18, 
"FlagStaticMember", + uint32_t(DebugInfoFlagStaticMember)), + CASE_IIILLIIFIIIIII(TypeEnum, 99, 1, "FlagStaticMember", + uint32_t(DebugInfoFlagStaticMember)), + })), ); + +INSTANTIATE_TEST_CASE_P( + DebugInfoDebugTypeComposite, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IEILLIIF( + TypeComposite, Class, 12, 13, + "FlagIsPrivate|FlagFwdDecl|FlagIndirectVariable|FlagIsOptimized", + uint32_t(DebugInfoFlagIsPrivate) | uint32_t(DebugInfoFlagFwdDecl) | + uint32_t(DebugInfoFlagIndirectVariable) | + uint32_t(DebugInfoFlagIsOptimized)), + // Cover all tag values: Class, Structure, Union + CASE_IEILLIIF(TypeComposite, Class, 12, 13, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + CASE_IEILLIIF(TypeComposite, Structure, 12, 13, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + CASE_IEILLIIF(TypeComposite, Union, 12, 13, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + // Now add members + CASE_IEILLIIFI(TypeComposite, Class, 9, 10, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + CASE_IEILLIIFII(TypeComposite, Class, 9, 10, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + CASE_IEILLIIFIII(TypeComposite, Class, 9, 10, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + CASE_IEILLIIFIIII(TypeComposite, Class, 9, 10, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeMember, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIILLIIIF(TypeMember, 12, 13, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + CASE_IIILLIIIF(TypeMember, 99, 100, + "FlagIsPrivate|FlagFwdDecl", + uint32_t(DebugInfoFlagIsPrivate) | + uint32_t(DebugInfoFlagFwdDecl)), + // Add the optional Id argument. 
+ CASE_IIILLIIIFI(TypeMember, 12, 13, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + })), ); + +INSTANTIATE_TEST_CASE_P( + DebugInfoDebugTypeInheritance, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIIIF(TypeInheritance, "FlagIsPrivate", + uint32_t(DebugInfoFlagIsPrivate)), + CASE_IIIIF(TypeInheritance, "FlagIsPrivate|FlagFwdDecl", + uint32_t(DebugInfoFlagIsPrivate) | + uint32_t(DebugInfoFlagFwdDecl)), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypePtrToMember, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_II(TypePtrToMember), + })), ); + +// DebugInfo 4.4 Templates + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeTemplate, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_II(TypeTemplate), + CASE_III(TypeTemplate), + CASE_IIII(TypeTemplate), + CASE_IIIII(TypeTemplate), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeTemplateParameter, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIIILL(TypeTemplateParameter, 1, 2), + CASE_IIIILL(TypeTemplateParameter, 99, 102), + CASE_IIIILL(TypeTemplateParameter, 10, 7), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeTemplateTemplateParameter, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIILL(TypeTemplateTemplateParameter, 1, 2), + CASE_IIILL(TypeTemplateTemplateParameter, 99, 102), + CASE_IIILL(TypeTemplateTemplateParameter, 10, 7), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugTypeTemplateParameterPack, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IILLI(TypeTemplateParameterPack, 1, 2), + CASE_IILLII(TypeTemplateParameterPack, 99, 102), + CASE_IILLIII(TypeTemplateParameterPack, 10, 7), + CASE_IILLIIII(TypeTemplateParameterPack, 10, 7), + })), ); + +// DebugInfo 4.5 Global Variables + +INSTANTIATE_TEST_CASE_P( + DebugInfoDebugGlobalVariable, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + 
CASE_IIILLIIIF(GlobalVariable, 1, 2, "FlagIsOptimized", + uint32_t(DebugInfoFlagIsOptimized)), + CASE_IIILLIIIF(GlobalVariable, 42, 43, "FlagIsOptimized", + uint32_t(DebugInfoFlagIsOptimized)), + CASE_IIILLIIIFI(GlobalVariable, 1, 2, "FlagIsOptimized", + uint32_t(DebugInfoFlagIsOptimized)), + CASE_IIILLIIIFI(GlobalVariable, 42, 43, "FlagIsOptimized", + uint32_t(DebugInfoFlagIsOptimized)), + })), ); + +// DebugInfo 4.6 Functions + +INSTANTIATE_TEST_CASE_P( + DebugInfoDebugFunctionDeclaration, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIILLIIF(FunctionDeclaration, 1, 2, "FlagIsOptimized", + uint32_t(DebugInfoFlagIsOptimized)), + CASE_IIILLIIF(FunctionDeclaration, 42, 43, "FlagFwdDecl", + uint32_t(DebugInfoFlagFwdDecl)), + })), ); + +INSTANTIATE_TEST_CASE_P( + DebugInfoDebugFunction, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIILLIIFLI(Function, 1, 2, "FlagIsOptimized", + uint32_t(DebugInfoFlagIsOptimized), 3), + CASE_IIILLIIFLI(Function, 42, 43, "FlagFwdDecl", + uint32_t(DebugInfoFlagFwdDecl), 44), + // Add the optional declaration Id. 
+ CASE_IIILLIIFLII(Function, 1, 2, "FlagIsOptimized", + uint32_t(DebugInfoFlagIsOptimized), 3), + CASE_IIILLIIFLII(Function, 42, 43, "FlagFwdDecl", + uint32_t(DebugInfoFlagFwdDecl), 44), + })), ); + +// DebugInfo 4.7 Local Information + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugLexicalBlock, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_ILLII(LexicalBlock, 1, 2), + CASE_ILLII(LexicalBlock, 42, 43), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugLexicalBlockDiscriminator, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_ILI(LexicalBlockDiscriminator, 1), + CASE_ILI(LexicalBlockDiscriminator, 42), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugScope, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_I(Scope), + CASE_II(Scope), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugNoScope, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_0(NoScope), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugInlinedAt, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_LII(InlinedAt, 1), + CASE_LII(InlinedAt, 42), + })), ); + +// DebugInfo 4.8 Local Variables + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugLocalVariable, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_IIILLI(LocalVariable, 1, 2), + CASE_IIILLI(LocalVariable, 42, 43), + CASE_IIILLIL(LocalVariable, 1, 2, 3), + CASE_IIILLIL(LocalVariable, 42, 43, 44), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugInlinedVariable, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_II(InlinedVariable), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugDebugDeclare, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_III(Declare), + })), ); + +INSTANTIATE_TEST_CASE_P( + DebugInfoDebugDebugValue, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_III(Value), + CASE_IIII(Value), + CASE_IIIII(Value), + 
CASE_IIIIII(Value), + // Test up to 4 id parameters. We can always try more. + CASE_IIIIIII(Value), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugDebugOperation, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_E(Operation, Deref), + CASE_E(Operation, Plus), + CASE_E(Operation, Minus), + CASE_EL(Operation, PlusUconst, 1), + CASE_EL(Operation, PlusUconst, 42), + CASE_ELL(Operation, BitPiece, 1, 2), + CASE_ELL(Operation, BitPiece, 4, 5), + CASE_E(Operation, Swap), + CASE_E(Operation, Xderef), + CASE_E(Operation, StackValue), + CASE_EL(Operation, Constu, 1), + CASE_EL(Operation, Constu, 42), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugDebugExpression, + ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_0(Expression), + CASE_I(Expression), + CASE_II(Expression), + CASE_III(Expression), + CASE_IIII(Expression), + CASE_IIIII(Expression), + CASE_IIIIII(Expression), + CASE_IIIIIII(Expression), + })), ); + +// DebugInfo 4.9 Macros + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugMacroDef, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_ILI(MacroDef, 1), + CASE_ILI(MacroDef, 42), + CASE_ILII(MacroDef, 1), + CASE_ILII(MacroDef, 42), + })), ); + +INSTANTIATE_TEST_CASE_P(DebugInfoDebugMacroUndef, ExtInstDebugInfoRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE_ILI(MacroUndef, 1), + CASE_ILI(MacroUndef, 42), + })), ); + +#undef CASE_0 +#undef CASE_ILL +#undef CASE_IL +#undef CASE_I +#undef CASE_II +#undef CASE_III +#undef CASE_IIII +#undef CASE_IIIII +#undef CASE_IIIIII +#undef CASE_IIIIIII +#undef CASE_IIILLI +#undef CASE_IIILLIL +#undef CASE_IE +#undef CASE_IIE +#undef CASE_ISF +#undef CASE_LII +#undef CASE_ILI +#undef CASE_ILII +#undef CASE_ILLII +#undef CASE_IIILLIIF +#undef CASE_IIILLIIFII +#undef CASE_IIILLIIFIIII +#undef CASE_IIILLIIFIIIIII +#undef CASE_IEILLIIF +#undef CASE_IEILLIIFI +#undef CASE_IEILLIIFII +#undef CASE_IEILLIIFIII +#undef CASE_IEILLIIFIIII +#undef CASE_IIILLIIIF 
+#undef CASE_IIILLIIIFI +#undef CASE_IIIIF +#undef CASE_IIILL +#undef CASE_IIIILL +#undef CASE_IILLI +#undef CASE_IILLII +#undef CASE_IILLIII +#undef CASE_IILLIIII +#undef CASE_IIILLIIFLI +#undef CASE_IIILLIIFLII +#undef CASE_E +#undef CASE_EL +#undef CASE_ELL + +} // namespace +} // namespace spvtools diff --git a/test/ext_inst.glsl_test.cpp b/test/ext_inst.glsl_test.cpp new file mode 100644 index 000000000..f52337c61 --- /dev/null +++ b/test/ext_inst.glsl_test.cpp @@ -0,0 +1,204 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "source/latest_version_glsl_std_450_header.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +/// Context for an extended instruction. +/// +/// Information about a GLSL extended instruction (including its opname, return +/// type, etc.) and related instructions used to generate the return type and +/// constant as the operands. Used in generating extended instruction tests. +struct ExtInstContext { + const char* extInstOpName; + const char* extInstOperandVars; + /// The following fields are used to check the SPIR-V binary representation + /// of this instruction. + uint32_t extInstOpcode; ///< Opcode value for this extended instruction. + uint32_t extInstLength; ///< Wordcount of this extended instruction. + std::vector extInstOperandIds; ///< Ids for operands. 
+}; + +using ExtInstGLSLstd450RoundTripTest = ::testing::TestWithParam; + +TEST_P(ExtInstGLSLstd450RoundTripTest, ParameterizedExtInst) { + spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_0); + const std::string spirv = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical Simple +OpEntryPoint Vertex %2 "main" +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%2 = OpFunction %3 None %5 +%6 = OpLabel +%8 = OpExtInst %7 %1 )" + std::string(GetParam().extInstOpName) + + " " + GetParam().extInstOperandVars + R"( +OpReturn +OpFunctionEnd +)"; + const std::string spirv_header = + R"(; SPIR-V +; Version: 1.0 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 9 +; Schema: 0)"; + spv_binary binary = nullptr; + spv_diagnostic diagnostic; + spv_result_t error = spvTextToBinary(context, spirv.c_str(), spirv.size(), + &binary, &diagnostic); + if (error) { + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + ASSERT_EQ(SPV_SUCCESS, error) + << "Source was: " << std::endl + << spirv << std::endl + << "Test case for : " << GetParam().extInstOpName << std::endl; + } + + // Check we do have the extended instruction's corresponding binary code in + // the generated SPIR-V binary. + std::vector expected_contains( + {12 /*OpExtInst*/ | GetParam().extInstLength << 16, 7 /*return type*/, + 8 /*result id*/, 1 /*glsl450 import*/, GetParam().extInstOpcode}); + for (uint32_t operand : GetParam().extInstOperandIds) { + expected_contains.push_back(operand); + } + EXPECT_NE(binary->code + binary->wordCount, + std::search(binary->code, binary->code + binary->wordCount, + expected_contains.begin(), expected_contains.end())) + << "Cannot find\n" + << spvtest::WordVector(expected_contains).str() << "in\n" + << spvtest::WordVector(*binary).str(); + + // Check round trip gives the same text. 
+ spv_text output_text = nullptr; + error = spvBinaryToText(context, binary->code, binary->wordCount, + SPV_BINARY_TO_TEXT_OPTION_NONE, &output_text, + &diagnostic); + + if (error) { + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + ASSERT_EQ(SPV_SUCCESS, error); + } + EXPECT_EQ(spirv_header + spirv, output_text->str); + spvTextDestroy(output_text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +INSTANTIATE_TEST_CASE_P( + ExtInstParameters, ExtInstGLSLstd450RoundTripTest, + ::testing::ValuesIn(std::vector({ + // We are only testing the correctness of encoding and decoding here. + // Semantic correctness should be the responsibility of validator. So + // some of the instructions below have incorrect operand and/or return + // types, e.g, Modf, ModfStruct, etc. + {"Round", "%5", 1, 6, {5}}, + {"RoundEven", "%5", 2, 6, {5}}, + {"Trunc", "%5", 3, 6, {5}}, + {"FAbs", "%5", 4, 6, {5}}, + {"SAbs", "%5", 5, 6, {5}}, + {"FSign", "%5", 6, 6, {5}}, + {"SSign", "%5", 7, 6, {5}}, + {"Floor", "%5", 8, 6, {5}}, + {"Ceil", "%5", 9, 6, {5}}, + {"Fract", "%5", 10, 6, {5}}, + {"Radians", "%5", 11, 6, {5}}, + {"Degrees", "%5", 12, 6, {5}}, + {"Sin", "%5", 13, 6, {5}}, + {"Cos", "%5", 14, 6, {5}}, + {"Tan", "%5", 15, 6, {5}}, + {"Asin", "%5", 16, 6, {5}}, + {"Acos", "%5", 17, 6, {5}}, + {"Atan", "%5", 18, 6, {5}}, + {"Sinh", "%5", 19, 6, {5}}, + {"Cosh", "%5", 20, 6, {5}}, + {"Tanh", "%5", 21, 6, {5}}, + {"Asinh", "%5", 22, 6, {5}}, + {"Acosh", "%5", 23, 6, {5}}, + {"Atanh", "%5", 24, 6, {5}}, + {"Atan2", "%5 %5", 25, 7, {5, 5}}, + {"Pow", "%5 %5", 26, 7, {5, 5}}, + {"Exp", "%5", 27, 6, {5}}, + {"Log", "%5", 28, 6, {5}}, + {"Exp2", "%5", 29, 6, {5}}, + {"Log2", "%5", 30, 6, {5}}, + {"Sqrt", "%5", 31, 6, {5}}, + {"InverseSqrt", "%5", 32, 6, {5}}, + {"Determinant", "%5", 33, 6, {5}}, + {"MatrixInverse", "%5", 34, 6, {5}}, + {"Modf", "%5 %5", 35, 7, {5, 5}}, + {"ModfStruct", "%5", 36, 6, {5}}, + {"FMin", "%5 %5", 37, 7, {5, 5}}, + {"UMin", "%5 %5", 
38, 7, {5, 5}}, + {"SMin", "%5 %5", 39, 7, {5, 5}}, + {"FMax", "%5 %5", 40, 7, {5, 5}}, + {"UMax", "%5 %5", 41, 7, {5, 5}}, + {"SMax", "%5 %5", 42, 7, {5, 5}}, + {"FClamp", "%5 %5 %5", 43, 8, {5, 5, 5}}, + {"UClamp", "%5 %5 %5", 44, 8, {5, 5, 5}}, + {"SClamp", "%5 %5 %5", 45, 8, {5, 5, 5}}, + {"FMix", "%5 %5 %5", 46, 8, {5, 5, 5}}, + {"IMix", "%5 %5 %5", 47, 8, {5, 5, 5}}, // Bug 15452. Reserved. + {"Step", "%5 %5", 48, 7, {5, 5}}, + {"SmoothStep", "%5 %5 %5", 49, 8, {5, 5, 5}}, + {"Fma", "%5 %5 %5", 50, 8, {5, 5, 5}}, + {"Frexp", "%5 %5", 51, 7, {5, 5}}, + {"FrexpStruct", "%5", 52, 6, {5}}, + {"Ldexp", "%5 %5", 53, 7, {5, 5}}, + {"PackSnorm4x8", "%5", 54, 6, {5}}, + {"PackUnorm4x8", "%5", 55, 6, {5}}, + {"PackSnorm2x16", "%5", 56, 6, {5}}, + {"PackUnorm2x16", "%5", 57, 6, {5}}, + {"PackHalf2x16", "%5", 58, 6, {5}}, + {"PackDouble2x32", "%5", 59, 6, {5}}, + {"UnpackSnorm2x16", "%5", 60, 6, {5}}, + {"UnpackUnorm2x16", "%5", 61, 6, {5}}, + {"UnpackHalf2x16", "%5", 62, 6, {5}}, + {"UnpackSnorm4x8", "%5", 63, 6, {5}}, + {"UnpackUnorm4x8", "%5", 64, 6, {5}}, + {"UnpackDouble2x32", "%5", 65, 6, {5}}, + {"Length", "%5", 66, 6, {5}}, + {"Distance", "%5 %5", 67, 7, {5, 5}}, + {"Cross", "%5 %5", 68, 7, {5, 5}}, + {"Normalize", "%5", 69, 6, {5}}, + // clang-format off + {"FaceForward", "%5 %5 %5", 70, 8, {5, 5, 5}}, + // clang-format on + {"Reflect", "%5 %5", 71, 7, {5, 5}}, + {"Refract", "%5 %5 %5", 72, 8, {5, 5, 5}}, + {"FindILsb", "%5", 73, 6, {5}}, + {"FindSMsb", "%5", 74, 6, {5}}, + {"FindUMsb", "%5", 75, 6, {5}}, + {"InterpolateAtCentroid", "%5", 76, 6, {5}}, + // clang-format off + {"InterpolateAtSample", "%5 %5", 77, 7, {5, 5}}, + {"InterpolateAtOffset", "%5 %5", 78, 7, {5, 5}}, + // clang-format on + {"NMin", "%5 %5", 79, 7, {5, 5}}, + {"NMax", "%5 %5", 80, 7, {5, 5}}, + {"NClamp", "%5 %5 %5", 81, 8, {5, 5, 5}}, + })), ); + +} // namespace +} // namespace spvtools diff --git a/test/ext_inst.opencl_test.cpp b/test/ext_inst.opencl_test.cpp new file mode 100644 index 
000000000..06bc5e848 --- /dev/null +++ b/test/ext_inst.opencl_test.cpp @@ -0,0 +1,373 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "source/latest_version_opencl_std_header.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::Concatenate; +using spvtest::MakeInstruction; +using spvtest::MakeVector; +using spvtest::TextToBinaryTest; +using testing::Eq; + +struct InstructionCase { + uint32_t opcode; + std::string name; + std::string operands; + std::vector expected_operands; +}; + +using ExtInstOpenCLStdRoundTripTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(ExtInstOpenCLStdRoundTripTest, ParameterizedExtInst) { + // This example should not validate. + const std::string input = + "%1 = OpExtInstImport \"OpenCL.std\"\n" + "%3 = OpExtInst %2 %1 " + + GetParam().name + " " + GetParam().operands + "\n"; + // First make sure it assembles correctly. + EXPECT_THAT( + CompiledInstructions(input), + Eq(Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, MakeVector("OpenCL.std")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, GetParam().opcode}, + GetParam().expected_operands)}))) + << input; + // Now check the round trip through the disassembler. 
+ EXPECT_THAT(EncodeAndDecodeSuccessfully(input), input) << input; +} + +#define CASE1(Enum, Name) \ + { \ + uint32_t(OpenCLLIB::Entrypoints::Enum), #Name, "%4", { 4 } \ + } +#define CASE2(Enum, Name) \ + { \ + uint32_t(OpenCLLIB::Entrypoints::Enum), #Name, "%4 %5", { 4, 5 } \ + } +#define CASE3(Enum, Name) \ + { \ + uint32_t(OpenCLLIB::Entrypoints::Enum), #Name, "%4 %5 %6", { 4, 5, 6 } \ + } +#define CASE4(Enum, Name) \ + { \ + uint32_t(OpenCLLIB::Entrypoints::Enum), #Name, "%4 %5 %6 %7", { \ + 4, 5, 6, 7 \ + } \ + } +#define CASE2Lit(Enum, Name, LiteralNumber) \ + { \ + uint32_t(OpenCLLIB::Entrypoints::Enum), #Name, "%4 %5 " #LiteralNumber, { \ + 4, 5, LiteralNumber \ + } \ + } +#define CASE3Round(Enum, Name, Mode) \ + { \ + uint32_t(OpenCLLIB::Entrypoints::Enum), #Name, "%4 %5 %6 " #Mode, { \ + 4, 5, 6, uint32_t(SpvFPRoundingMode##Mode) \ + } \ + } + +// clang-format off +// OpenCL.std: 2.1 Math extended instructions +INSTANTIATE_TEST_CASE_P( + OpenCLMath, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + // We are only testing the correctness of encoding and decoding here. + // Semantic correctness should be the responsibility of validator. 
+ CASE1(Acos, acos), // enum value 0 + CASE1(Acosh, acosh), + CASE1(Acospi, acospi), + CASE1(Asin, asin), + CASE1(Asinh, asinh), + CASE1(Asinh, asinh), + CASE1(Asinpi, asinpi), + CASE1(Atan, atan), + CASE2(Atan2, atan2), + CASE1(Atanh, atanh), + CASE1(Atanpi, atanpi), + CASE2(Atan2pi, atan2pi), + CASE1(Cbrt, cbrt), + CASE1(Ceil, ceil), + CASE1(Ceil, ceil), + CASE2(Copysign, copysign), + CASE1(Cos, cos), + CASE1(Cosh, cosh), + CASE1(Cospi, cospi), + CASE1(Erfc, erfc), + CASE1(Erf, erf), + CASE1(Exp, exp), + CASE1(Exp2, exp2), + CASE1(Exp10, exp10), + CASE1(Expm1, expm1), + CASE1(Fabs, fabs), + CASE2(Fdim, fdim), + CASE1(Floor, floor), + CASE3(Fma, fma), + CASE2(Fmax, fmax), + CASE2(Fmin, fmin), + CASE2(Fmod, fmod), + CASE2(Fract, fract), + CASE2(Frexp, frexp), + CASE2(Hypot, hypot), + CASE1(Ilogb, ilogb), + CASE2(Ldexp, ldexp), + CASE1(Lgamma, lgamma), + CASE2(Lgamma_r, lgamma_r), + CASE1(Log, log), + CASE1(Log2, log2), + CASE1(Log10, log10), + CASE1(Log1p, log1p), + CASE3(Mad, mad), + CASE2(Maxmag, maxmag), + CASE2(Minmag, minmag), + CASE2(Modf, modf), + CASE1(Nan, nan), + CASE2(Nextafter, nextafter), + CASE2(Pow, pow), + CASE2(Pown, pown), + CASE2(Powr, powr), + CASE2(Remainder, remainder), + CASE3(Remquo, remquo), + CASE1(Rint, rint), + CASE2(Rootn, rootn), + CASE1(Round, round), + CASE1(Rsqrt, rsqrt), + CASE1(Sin, sin), + CASE2(Sincos, sincos), + CASE1(Sinh, sinh), + CASE1(Sinpi, sinpi), + CASE1(Sqrt, sqrt), + CASE1(Tan, tan), + CASE1(Tanh, tanh), + CASE1(Tanpi, tanpi), + CASE1(Tgamma, tgamma), + CASE1(Trunc, trunc), + CASE1(Half_cos, half_cos), + CASE2(Half_divide, half_divide), + CASE1(Half_exp, half_exp), + CASE1(Half_exp2, half_exp2), + CASE1(Half_exp10, half_exp10), + CASE1(Half_log, half_log), + CASE1(Half_log2, half_log2), + CASE1(Half_log10, half_log10), + CASE2(Half_powr, half_powr), + CASE1(Half_recip, half_recip), + CASE1(Half_rsqrt, half_rsqrt), + CASE1(Half_sin, half_sin), + CASE1(Half_sqrt, half_sqrt), + CASE1(Half_tan, half_tan), + 
CASE1(Native_cos, native_cos), + CASE2(Native_divide, native_divide), + CASE1(Native_exp, native_exp), + CASE1(Native_exp2, native_exp2), + CASE1(Native_exp10, native_exp10), + CASE1(Native_log, native_log), + CASE1(Native_log10, native_log10), + CASE2(Native_powr, native_powr), + CASE1(Native_recip, native_recip), + CASE1(Native_rsqrt, native_rsqrt), + CASE1(Native_sin, native_sin), + CASE1(Native_sqrt, native_sqrt), + CASE1(Native_tan, native_tan), // enum value 94 + })),); + +// OpenCL.std: 2.1 Integer instructions +INSTANTIATE_TEST_CASE_P( + OpenCLInteger, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE1(SAbs, s_abs), // enum value 141 + CASE2(SAbs_diff, s_abs_diff), + CASE2(SAdd_sat, s_add_sat), + CASE2(UAdd_sat, u_add_sat), + CASE2(SHadd, s_hadd), + CASE2(UHadd, u_hadd), + CASE2(SRhadd, s_rhadd), + CASE2(SRhadd, s_rhadd), + CASE3(SClamp, s_clamp), + CASE3(UClamp, u_clamp), + CASE1(Clz, clz), + CASE1(Ctz, ctz), + CASE3(SMad_hi, s_mad_hi), + CASE3(UMad_sat, u_mad_sat), + CASE3(SMad_sat, s_mad_sat), + CASE2(SMax, s_max), + CASE2(UMax, u_max), + CASE2(SMin, s_min), + CASE2(UMin, u_min), + CASE2(SMul_hi, s_mul_hi), + CASE2(Rotate, rotate), + CASE2(SSub_sat, s_sub_sat), + CASE2(USub_sat, u_sub_sat), + CASE2(U_Upsample, u_upsample), + CASE2(S_Upsample, s_upsample), + CASE1(Popcount, popcount), + CASE3(SMad24, s_mad24), + CASE3(UMad24, u_mad24), + CASE2(SMul24, s_mul24), + CASE2(UMul24, u_mul24), // enum value 170 + CASE1(UAbs, u_abs), // enum value 201 + CASE2(UAbs_diff, u_abs_diff), + CASE2(UMul_hi, u_mul_hi), + CASE3(UMad_hi, u_mad_hi), // enum value 204 + })),); + +// OpenCL.std: 2.3 Common instrucitons +INSTANTIATE_TEST_CASE_P( + OpenCLCommon, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE3(FClamp, fclamp), // enum value 95 + CASE1(Degrees, degrees), + CASE2(FMax_common, fmax_common), + CASE2(FMin_common, fmin_common), + CASE3(Mix, mix), + CASE1(Radians, radians), + CASE2(Step, step), + CASE3(Smoothstep, 
smoothstep), + CASE1(Sign, sign), // enum value 103 + })),); + +// OpenCL.std: 2.4 Geometric instructions +INSTANTIATE_TEST_CASE_P( + OpenCLGeometric, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE2(Cross, cross), // enum value 104 + CASE2(Distance, distance), + CASE1(Length, length), + CASE1(Normalize, normalize), + CASE2(Fast_distance, fast_distance), + CASE1(Fast_length, fast_length), + CASE1(Fast_normalize, fast_normalize), // enum value 110 + })),); + +// OpenCL.std: 2.5 Relational instructions +INSTANTIATE_TEST_CASE_P( + OpenCLRelational, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE3(Bitselect, bitselect), // enum value 186 + CASE3(Select, select), // enum value 187 + })),); + +// OpenCL.std: 2.6 Vector data load and store instructions +INSTANTIATE_TEST_CASE_P( + OpenCLVectorLoadStore, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + // The last argument to Vloadn must be one of 2, 3, 4, 8, 16. + CASE2Lit(Vloadn, vloadn, 2), + CASE2Lit(Vloadn, vloadn, 3), + CASE2Lit(Vloadn, vloadn, 4), + CASE2Lit(Vloadn, vloadn, 8), + CASE2Lit(Vloadn, vloadn, 16), + CASE3(Vstoren, vstoren), + CASE2(Vload_half, vload_half), + CASE2Lit(Vload_halfn, vload_halfn, 2), + CASE2Lit(Vload_halfn, vload_halfn, 3), + CASE2Lit(Vload_halfn, vload_halfn, 4), + CASE2Lit(Vload_halfn, vload_halfn, 8), + CASE2Lit(Vload_halfn, vload_halfn, 16), + CASE3(Vstore_half, vstore_half), + // Try all the rounding modes. 
+ CASE3Round(Vstore_half_r, vstore_half_r, RTE), + CASE3Round(Vstore_half_r, vstore_half_r, RTZ), + CASE3Round(Vstore_half_r, vstore_half_r, RTP), + CASE3Round(Vstore_half_r, vstore_half_r, RTN), + CASE3(Vstore_halfn, vstore_halfn), + CASE3Round(Vstore_halfn_r, vstore_halfn_r, RTE), + CASE3Round(Vstore_halfn_r, vstore_halfn_r, RTZ), + CASE3Round(Vstore_halfn_r, vstore_halfn_r, RTP), + CASE3Round(Vstore_halfn_r, vstore_halfn_r, RTN), + CASE2Lit(Vloada_halfn, vloada_halfn, 2), + CASE2Lit(Vloada_halfn, vloada_halfn, 3), + CASE2Lit(Vloada_halfn, vloada_halfn, 4), + CASE2Lit(Vloada_halfn, vloada_halfn, 8), + CASE2Lit(Vloada_halfn, vloada_halfn, 16), + CASE3(Vstorea_halfn, vstorea_halfn), + CASE3Round(Vstorea_halfn_r, vstorea_halfn_r, RTE), + CASE3Round(Vstorea_halfn_r, vstorea_halfn_r, RTZ), + CASE3Round(Vstorea_halfn_r, vstorea_halfn_r, RTP), + CASE3Round(Vstorea_halfn_r, vstorea_halfn_r, RTN), + })),); + +// OpenCL.std: 2.7 Miscellaneous vector instructions +INSTANTIATE_TEST_CASE_P( + OpenCLMiscellaneousVector, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE2(Shuffle, shuffle), + CASE3(Shuffle2, shuffle2), + })),); + +// OpenCL.std: 2.8 Miscellaneous instructions + +#define PREFIX uint32_t(OpenCLLIB::Entrypoints::Printf), "printf" +INSTANTIATE_TEST_CASE_P( + OpenCLMiscPrintf, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + // Printf is interesting because it takes a variable number of arguments. + // Start with zero optional arguments. 
+ {PREFIX, "%4", {4}}, + {PREFIX, "%4 %5", {4, 5}}, + {PREFIX, "%4 %5 %6", {4, 5, 6}}, + {PREFIX, "%4 %5 %6 %7", {4, 5, 6, 7}}, + {PREFIX, "%4 %5 %6 %7 %8", {4, 5, 6, 7, 8}}, + {PREFIX, "%4 %5 %6 %7 %8 %9", {4, 5, 6, 7, 8, 9}}, + {PREFIX, "%4 %5 %6 %7 %8 %9 %10", {4, 5, 6, 7, 8, 9, 10}}, + {PREFIX, "%4 %5 %6 %7 %8 %9 %10 %11", {4, 5, 6, 7, 8, 9, 10, 11}}, + {PREFIX, "%4 %5 %6 %7 %8 %9 %10 %11 %12", + {4, 5, 6, 7, 8, 9, 10, 11, 12}}, + {PREFIX, "%4 %5 %6 %7 %8 %9 %10 %11 %12 %13", + {4, 5, 6, 7, 8, 9, 10, 11, 12, 13}}, + {PREFIX, "%4 %5 %6 %7 %8 %9 %10 %11 %12 %13 %14", + {4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}}, + })),); +#undef PREFIX + +INSTANTIATE_TEST_CASE_P( + OpenCLMiscPrefetch, ExtInstOpenCLStdRoundTripTest, + ::testing::ValuesIn(std::vector({ + CASE2(Prefetch, prefetch), + })),); + +// OpenCL.std: 2.9.1 Image encoding +// No new instructions defined in this section. + +// OpenCL.std: 2.9.2 Sampler encoding +// No new instructions defined in this section. + +// OpenCL.std: 2.9.3 Image read +// No new instructions defined in this section. +// Use core instruction OpImageSampleExplicitLod instead. + +// OpenCL.std: 2.9.4 Image write +// No new instructions defined in this section. + +// clang-format on + +#undef CASE1 +#undef CASE2 +#undef CASE3 +#undef CASE4 +#undef CASE2Lit +#undef CASE3Round + +} // namespace +} // namespace spvtools diff --git a/test/fix_word_test.cpp b/test/fix_word_test.cpp new file mode 100644 index 000000000..b8c3a33d5 --- /dev/null +++ b/test/fix_word_test.cpp @@ -0,0 +1,64 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +TEST(FixWord, Default) { + spv_endianness_t endian; + if (I32_ENDIAN_HOST == I32_ENDIAN_LITTLE) { + endian = SPV_ENDIANNESS_LITTLE; + } else { + endian = SPV_ENDIANNESS_BIG; + } + uint32_t word = 0x53780921; + ASSERT_EQ(word, spvFixWord(word, endian)); +} + +TEST(FixWord, Reorder) { + spv_endianness_t endian; + if (I32_ENDIAN_HOST == I32_ENDIAN_LITTLE) { + endian = SPV_ENDIANNESS_BIG; + } else { + endian = SPV_ENDIANNESS_LITTLE; + } + uint32_t word = 0x53780921; + uint32_t result = 0x21097853; + ASSERT_EQ(result, spvFixWord(word, endian)); +} + +TEST(FixDoubleWord, Default) { + spv_endianness_t endian = + (I32_ENDIAN_HOST == I32_ENDIAN_LITTLE ? SPV_ENDIANNESS_LITTLE + : SPV_ENDIANNESS_BIG); + uint32_t low = 0x53780921; + uint32_t high = 0xdeadbeef; + uint64_t result = 0xdeadbeef53780921; + ASSERT_EQ(result, spvFixDoubleWord(low, high, endian)); +} + +TEST(FixDoubleWord, Reorder) { + spv_endianness_t endian = + (I32_ENDIAN_HOST == I32_ENDIAN_LITTLE ? SPV_ENDIANNESS_BIG + : SPV_ENDIANNESS_LITTLE); + uint32_t low = 0x53780921; + uint32_t high = 0xdeadbeef; + uint64_t result = 0xefbeadde21097853; + ASSERT_EQ(result, spvFixDoubleWord(low, high, endian)); +} + +} // namespace +} // namespace spvtools diff --git a/test/fuzzers/BUILD.gn b/test/fuzzers/BUILD.gn new file mode 100644 index 000000000..69659c661 --- /dev/null +++ b/test/fuzzers/BUILD.gn @@ -0,0 +1,137 @@ +# Copyright 2018 Google Inc. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import("//testing/libfuzzer/fuzzer_test.gni") +import("//testing/test.gni") + +config("fuzzer_config") { + configs = [ "../..:spvtools_internal_config" ] +} + +group("fuzzers") { + testonly = true + deps = [] + + if (!build_with_chromium || use_fuzzing_engine) { + deps += [ ":fuzzers_bin" ] + } +} + +if (!build_with_chromium || use_fuzzing_engine) { + group("fuzzers_bin") { + testonly = true + + deps = [ + ":spvtools_binary_parser_fuzzer", + ":spvtools_opt_legalization_fuzzer", + ":spvtools_opt_performance_fuzzer", + ":spvtools_opt_size_fuzzer", + ":spvtools_val_fuzzer", + ] + } +} + +template("spvtools_fuzzer") { + source_set(target_name) { + testonly = true + sources = invoker.sources + deps = [ + "../..:spvtools", + "../..:spvtools_opt", + "../..:spvtools_val", + ] + if (defined(invoker.deps)) { + deps += invoker.deps + } + + configs -= [ "//build/config/compiler:chromium_code" ] + configs += [ + "//build/config/compiler:no_chromium_code", + ":fuzzer_config", + ] + } +} + +spvtools_fuzzer("spvtools_binary_parser_fuzzer_src") { + sources = [ + "spvtools_binary_parser_fuzzer.cpp", + ] +} + +spvtools_fuzzer("spvtools_opt_performance_fuzzer_src") { + sources = [ + "spvtools_opt_performance_fuzzer.cpp", + ] +} + +spvtools_fuzzer("spvtools_opt_legalization_fuzzer_src") { + sources = [ + "spvtools_opt_legalization_fuzzer.cpp", + ] +} + +spvtools_fuzzer("spvtools_opt_size_fuzzer_src") { + sources = [ + 
"spvtools_opt_size_fuzzer.cpp", + ] +} + +spvtools_fuzzer("spvtools_val_fuzzer_src") { + sources = [ + "spvtools_val_fuzzer.cpp", + ] +} + +if (!build_with_chromium || use_fuzzing_engine) { + fuzzer_test("spvtools_binary_parser_fuzzer") { + sources = [] + deps = [ + ":spvtools_binary_parser_fuzzer_src", + ] + # Intentionally doesn't use the seed corpus, because it consumes + # part of the input as not part of the file. + } + + fuzzer_test("spvtools_opt_performance_fuzzer") { + sources = [] + deps = [ + ":spvtools_opt_performance_fuzzer_src", + ] + seed_corpus = "corpora/spv" + } + + fuzzer_test("spvtools_opt_legalization_fuzzer") { + sources = [] + deps = [ + ":spvtools_opt_legalization_fuzzer_src", + ] + seed_corpus = "corpora/spv" + } + + fuzzer_test("spvtools_opt_size_fuzzer") { + sources = [] + deps = [ + ":spvtools_opt_size_fuzzer_src", + ] + seed_corpus = "corpora/spv" + } + + fuzzer_test("spvtools_val_fuzzer") { + sources = [] + deps = [ + ":spvtools_val_fuzzer_src", + ] + seed_corpus = "corpora/spv" + } +} diff --git a/test/fuzzers/corpora/spv/simple.spv b/test/fuzzers/corpora/spv/simple.spv new file mode 100644 index 0000000000000000000000000000000000000000..f972a56fd96468e5973c1d7943ff28be523effe3 GIT binary patch literal 728 zcmYk4+e-pb5XL9h)vl(QT0Mx;W3Z21hysa_pioqzw^*eIp>}1{p#Qv=qVIP)3mq6| zzHjcc28H9Y*_xSE>`1@7niV_~IE;7K%2t#{hU15Lo|tXRu1TsEI9`Qh1r|y_6-+v` zec97#^pY&IPnnLR-g`ES@;afi|JMbtD`q&pOBSo8m6QeM&C=I2lMcmhN-yFpev3Zh zESkGmc;=cXj=UM%4@QEDFS}XO)zUe}kI$3dsDC?gzM|D#4&BW9>Q9Zfga>BN9M^Rp zt!G`gtvS0XyCkk7X(^stfZ3}pAM?X?#B6vQ4!kAJO>%!#HkSu?<=^u*c@5!)bVowY ziN{UZJnFTi$we-Osn0&dgQ?Fx^f=XJ*o8THa9?8hhQ!j7J5%dQ7`-E*-cXnwDxSMm zHnHFfcX9qVe|_Qc(--br)0|7_PfcdyvFn-Q=*u@@7XCv$#j)3c=h-)(=g{*%ip8ex EFITED5dZ)H literal 0 HcmV?d00001 diff --git a/test/fuzzers/spvtools_binary_parser_fuzzer.cpp b/test/fuzzers/spvtools_binary_parser_fuzzer.cpp new file mode 100644 index 000000000..76ba4d9e9 --- /dev/null +++ b/test/fuzzers/spvtools_binary_parser_fuzzer.cpp @@ -0,0 +1,44 @@ 
+// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "spirv-tools/libspirv.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + if (size < sizeof(spv_target_env) + 1) return 0; + + const spv_context context = + spvContextCreate(*reinterpret_cast(data)); + if (context == nullptr) return 0; + + data += sizeof(spv_target_env); + size -= sizeof(spv_target_env); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + spvBinaryParse(context, nullptr, input.data(), input.size(), nullptr, nullptr, + nullptr); + + spvContextDestroy(context); + return 0; +} diff --git a/test/fuzzers/spvtools_opt_legalization_fuzzer.cpp b/test/fuzzers/spvtools_opt_legalization_fuzzer.cpp new file mode 100644 index 000000000..b45a98c37 --- /dev/null +++ b/test/fuzzers/spvtools_opt_legalization_fuzzer.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "spirv-tools/optimizer.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + spvtools::Optimizer optimizer(SPV_ENV_UNIVERSAL_1_3); + optimizer.SetMessageConsumer([](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + optimizer.RegisterLegalizationPasses(); + optimizer.Run(input.data(), input.size(), &input); + + return 0; +} diff --git a/test/fuzzers/spvtools_opt_performance_fuzzer.cpp b/test/fuzzers/spvtools_opt_performance_fuzzer.cpp new file mode 100644 index 000000000..6c3bd6aba --- /dev/null +++ b/test/fuzzers/spvtools_opt_performance_fuzzer.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "spirv-tools/optimizer.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + spvtools::Optimizer optimizer(SPV_ENV_UNIVERSAL_1_3); + optimizer.SetMessageConsumer([](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + optimizer.RegisterPerformancePasses(); + optimizer.Run(input.data(), input.size(), &input); + + return 0; +} diff --git a/test/fuzzers/spvtools_opt_size_fuzzer.cpp b/test/fuzzers/spvtools_opt_size_fuzzer.cpp new file mode 100644 index 000000000..68c797477 --- /dev/null +++ b/test/fuzzers/spvtools_opt_size_fuzzer.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "spirv-tools/optimizer.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + spvtools::Optimizer optimizer(SPV_ENV_UNIVERSAL_1_3); + optimizer.SetMessageConsumer([](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + optimizer.RegisterSizePasses(); + optimizer.Run(input.data(), input.size(), &input); + + return 0; +} diff --git a/test/fuzzers/spvtools_val_fuzzer.cpp b/test/fuzzers/spvtools_val_fuzzer.cpp new file mode 100644 index 000000000..5dc4303b4 --- /dev/null +++ b/test/fuzzers/spvtools_val_fuzzer.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "spirv-tools/libspirv.hpp" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_3); + tools.SetMessageConsumer([](spv_message_level_t, const char*, + const spv_position_t&, const char*) {}); + + std::vector input; + input.resize(size >> 2); + + size_t count = 0; + for (size_t i = 0; (i + 3) < size; i += 4) { + input[count++] = data[i] | (data[i + 1] << 8) | (data[i + 2] << 16) | + (data[i + 3]) << 24; + } + + tools.Validate(input); + return 0; +} diff --git a/test/generator_magic_number_test.cpp b/test/generator_magic_number_test.cpp new file mode 100644 index 000000000..bc5fdf57a --- /dev/null +++ b/test/generator_magic_number_test.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opcode.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using ::spvtest::EnumCase; +using ::testing::Eq; +using GeneratorMagicNumberTest = + ::testing::TestWithParam>; + +TEST_P(GeneratorMagicNumberTest, Single) { + EXPECT_THAT(std::string(spvGeneratorStr(GetParam().value())), + GetParam().name()); +} + +INSTANTIATE_TEST_CASE_P( + Registered, GeneratorMagicNumberTest, + ::testing::ValuesIn(std::vector>{ + {SPV_GENERATOR_KHRONOS, "Khronos"}, + {SPV_GENERATOR_LUNARG, "LunarG"}, + {SPV_GENERATOR_VALVE, "Valve"}, + {SPV_GENERATOR_CODEPLAY, "Codeplay"}, + {SPV_GENERATOR_NVIDIA, "NVIDIA"}, + {SPV_GENERATOR_ARM, "ARM"}, + {SPV_GENERATOR_KHRONOS_LLVM_TRANSLATOR, + "Khronos LLVM/SPIR-V Translator"}, + {SPV_GENERATOR_KHRONOS_ASSEMBLER, "Khronos SPIR-V Tools Assembler"}, + {SPV_GENERATOR_KHRONOS_GLSLANG, "Khronos Glslang Reference Front End"}, + }), ); + +INSTANTIATE_TEST_CASE_P( + Unregistered, GeneratorMagicNumberTest, + ::testing::ValuesIn(std::vector>{ + // We read registered entries from the SPIR-V XML Registry file + // which can change over time. + {spv_generator_t(1000), "Unknown"}, + {spv_generator_t(9999), "Unknown"}, + }), ); + +} // namespace +} // namespace spvtools diff --git a/test/hex_float_test.cpp b/test/hex_float_test.cpp new file mode 100644 index 000000000..87450609f --- /dev/null +++ b/test/hex_float_test.cpp @@ -0,0 +1,1331 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/util/hex_float.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace utils { +namespace { + +using ::testing::Eq; + +// In this file "encode" means converting a number into a string, +// and "decode" means converting a string into a number. + +using HexFloatTest = + ::testing::TestWithParam, std::string>>; +using DecodeHexFloatTest = + ::testing::TestWithParam>>; +using HexDoubleTest = + ::testing::TestWithParam, std::string>>; +using DecodeHexDoubleTest = + ::testing::TestWithParam>>; +using RoundTripFloatTest = ::testing::TestWithParam; +using RoundTripDoubleTest = ::testing::TestWithParam; + +// Hex-encodes a float value. +template +std::string EncodeViaHexFloat(const T& value) { + std::stringstream ss; + ss << HexFloat(value); + return ss.str(); +} + +// The following two tests can't be DRY because they take different parameter +// types. + +TEST_P(HexFloatTest, EncodeCorrectly) { + EXPECT_THAT(EncodeViaHexFloat(GetParam().first), Eq(GetParam().second)); +} + +TEST_P(HexDoubleTest, EncodeCorrectly) { + EXPECT_THAT(EncodeViaHexFloat(GetParam().first), Eq(GetParam().second)); +} + +// Decodes a hex-float string. 
+template +FloatProxy Decode(const std::string& str) { + HexFloat> decoded(0.f); + EXPECT_TRUE((std::stringstream(str) >> decoded).eof()); + return decoded.value(); +} + +TEST_P(HexFloatTest, DecodeCorrectly) { + EXPECT_THAT(Decode(GetParam().second), Eq(GetParam().first)); +} + +TEST_P(HexDoubleTest, DecodeCorrectly) { + EXPECT_THAT(Decode(GetParam().second), Eq(GetParam().first)); +} + +INSTANTIATE_TEST_CASE_P( + Float32Tests, HexFloatTest, + ::testing::ValuesIn(std::vector, std::string>>({ + {0.f, "0x0p+0"}, + {1.f, "0x1p+0"}, + {2.f, "0x1p+1"}, + {3.f, "0x1.8p+1"}, + {0.5f, "0x1p-1"}, + {0.25f, "0x1p-2"}, + {0.75f, "0x1.8p-1"}, + {-0.f, "-0x0p+0"}, + {-1.f, "-0x1p+0"}, + {-0.5f, "-0x1p-1"}, + {-0.25f, "-0x1p-2"}, + {-0.75f, "-0x1.8p-1"}, + + // Larger numbers + {512.f, "0x1p+9"}, + {-512.f, "-0x1p+9"}, + {1024.f, "0x1p+10"}, + {-1024.f, "-0x1p+10"}, + {1024.f + 8.f, "0x1.02p+10"}, + {-1024.f - 8.f, "-0x1.02p+10"}, + + // Small numbers + {1.0f / 512.f, "0x1p-9"}, + {1.0f / -512.f, "-0x1p-9"}, + {1.0f / 1024.f, "0x1p-10"}, + {1.0f / -1024.f, "-0x1p-10"}, + {1.0f / 1024.f + 1.0f / 8.f, "0x1.02p-3"}, + {1.0f / -1024.f - 1.0f / 8.f, "-0x1.02p-3"}, + + // lowest non-denorm + {float(ldexp(1.0f, -126)), "0x1p-126"}, + {float(ldexp(-1.0f, -126)), "-0x1p-126"}, + + // Denormalized values + {float(ldexp(1.0f, -127)), "0x1p-127"}, + {float(ldexp(1.0f, -127) / 2.0f), "0x1p-128"}, + {float(ldexp(1.0f, -127) / 4.0f), "0x1p-129"}, + {float(ldexp(1.0f, -127) / 8.0f), "0x1p-130"}, + {float(ldexp(-1.0f, -127)), "-0x1p-127"}, + {float(ldexp(-1.0f, -127) / 2.0f), "-0x1p-128"}, + {float(ldexp(-1.0f, -127) / 4.0f), "-0x1p-129"}, + {float(ldexp(-1.0f, -127) / 8.0f), "-0x1p-130"}, + + {float(ldexp(1.0, -127) + (ldexp(1.0, -127) / 2.0f)), "0x1.8p-127"}, + {float(ldexp(1.0, -127) / 2.0 + (ldexp(1.0, -127) / 4.0f)), + "0x1.8p-128"}, + + })), ); + +INSTANTIATE_TEST_CASE_P( + Float32NanTests, HexFloatTest, + ::testing::ValuesIn(std::vector, std::string>>({ + // Various NAN and INF cases + 
{uint32_t(0xFF800000), "-0x1p+128"}, // -inf + {uint32_t(0x7F800000), "0x1p+128"}, // inf + {uint32_t(0xFFC00000), "-0x1.8p+128"}, // -nan + {uint32_t(0xFF800100), "-0x1.0002p+128"}, // -nan + {uint32_t(0xFF800c00), "-0x1.0018p+128"}, // -nan + {uint32_t(0xFF80F000), "-0x1.01ep+128"}, // -nan + {uint32_t(0xFFFFFFFF), "-0x1.fffffep+128"}, // -nan + {uint32_t(0x7FC00000), "0x1.8p+128"}, // +nan + {uint32_t(0x7F800100), "0x1.0002p+128"}, // +nan + {uint32_t(0x7f800c00), "0x1.0018p+128"}, // +nan + {uint32_t(0x7F80F000), "0x1.01ep+128"}, // +nan + {uint32_t(0x7FFFFFFF), "0x1.fffffep+128"}, // +nan + })), ); + +INSTANTIATE_TEST_CASE_P( + Float64Tests, HexDoubleTest, + ::testing::ValuesIn( + std::vector, std::string>>({ + {0., "0x0p+0"}, + {1., "0x1p+0"}, + {2., "0x1p+1"}, + {3., "0x1.8p+1"}, + {0.5, "0x1p-1"}, + {0.25, "0x1p-2"}, + {0.75, "0x1.8p-1"}, + {-0., "-0x0p+0"}, + {-1., "-0x1p+0"}, + {-0.5, "-0x1p-1"}, + {-0.25, "-0x1p-2"}, + {-0.75, "-0x1.8p-1"}, + + // Larger numbers + {512., "0x1p+9"}, + {-512., "-0x1p+9"}, + {1024., "0x1p+10"}, + {-1024., "-0x1p+10"}, + {1024. + 8., "0x1.02p+10"}, + {-1024. - 8., "-0x1.02p+10"}, + + // Large outside the range of normal floats + {ldexp(1.0, 128), "0x1p+128"}, + {ldexp(1.0, 129), "0x1p+129"}, + {ldexp(-1.0, 128), "-0x1p+128"}, + {ldexp(-1.0, 129), "-0x1p+129"}, + {ldexp(1.0, 128) + ldexp(1.0, 90), "0x1.0000000004p+128"}, + {ldexp(1.0, 129) + ldexp(1.0, 120), "0x1.008p+129"}, + {ldexp(-1.0, 128) + ldexp(1.0, 90), "-0x1.fffffffff8p+127"}, + {ldexp(-1.0, 129) + ldexp(1.0, 120), "-0x1.ffp+128"}, + + // Small numbers + {1.0 / 512., "0x1p-9"}, + {1.0 / -512., "-0x1p-9"}, + {1.0 / 1024., "0x1p-10"}, + {1.0 / -1024., "-0x1p-10"}, + {1.0 / 1024. + 1.0 / 8., "0x1.02p-3"}, + {1.0 / -1024. 
- 1.0 / 8., "-0x1.02p-3"}, + + // Small outside the range of normal floats + {ldexp(1.0, -128), "0x1p-128"}, + {ldexp(1.0, -129), "0x1p-129"}, + {ldexp(-1.0, -128), "-0x1p-128"}, + {ldexp(-1.0, -129), "-0x1p-129"}, + {ldexp(1.0, -128) + ldexp(1.0, -90), "0x1.0000000004p-90"}, + {ldexp(1.0, -129) + ldexp(1.0, -120), "0x1.008p-120"}, + {ldexp(-1.0, -128) + ldexp(1.0, -90), "0x1.fffffffff8p-91"}, + {ldexp(-1.0, -129) + ldexp(1.0, -120), "0x1.ffp-121"}, + + // lowest non-denorm + {ldexp(1.0, -1022), "0x1p-1022"}, + {ldexp(-1.0, -1022), "-0x1p-1022"}, + + // Denormalized values + {ldexp(1.0, -1023), "0x1p-1023"}, + {ldexp(1.0, -1023) / 2.0, "0x1p-1024"}, + {ldexp(1.0, -1023) / 4.0, "0x1p-1025"}, + {ldexp(1.0, -1023) / 8.0, "0x1p-1026"}, + {ldexp(-1.0, -1024), "-0x1p-1024"}, + {ldexp(-1.0, -1024) / 2.0, "-0x1p-1025"}, + {ldexp(-1.0, -1024) / 4.0, "-0x1p-1026"}, + {ldexp(-1.0, -1024) / 8.0, "-0x1p-1027"}, + + {ldexp(1.0, -1023) + (ldexp(1.0, -1023) / 2.0), "0x1.8p-1023"}, + {ldexp(1.0, -1023) / 2.0 + (ldexp(1.0, -1023) / 4.0), + "0x1.8p-1024"}, + + })), ); + +INSTANTIATE_TEST_CASE_P( + Float64NanTests, HexDoubleTest, + ::testing::ValuesIn(std::vector< + std::pair, std::string>>({ + // Various NAN and INF cases + {uint64_t(0xFFF0000000000000LL), "-0x1p+1024"}, // -inf + {uint64_t(0x7FF0000000000000LL), "0x1p+1024"}, // +inf + {uint64_t(0xFFF8000000000000LL), "-0x1.8p+1024"}, // -nan + {uint64_t(0xFFF0F00000000000LL), "-0x1.0fp+1024"}, // -nan + {uint64_t(0xFFF0000000000001LL), "-0x1.0000000000001p+1024"}, // -nan + {uint64_t(0xFFF0000300000000LL), "-0x1.00003p+1024"}, // -nan + {uint64_t(0xFFFFFFFFFFFFFFFFLL), "-0x1.fffffffffffffp+1024"}, // -nan + {uint64_t(0x7FF8000000000000LL), "0x1.8p+1024"}, // +nan + {uint64_t(0x7FF0F00000000000LL), "0x1.0fp+1024"}, // +nan + {uint64_t(0x7FF0000000000001LL), "0x1.0000000000001p+1024"}, // +nan + {uint64_t(0x7FF0000300000000LL), "0x1.00003p+1024"}, // +nan + {uint64_t(0x7FFFFFFFFFFFFFFFLL), "0x1.fffffffffffffp+1024"}, // +nan + })), 
); + +// Tests that encoding a value and decoding it again restores +// the same value. +TEST_P(RoundTripFloatTest, CanStoreAccurately) { + std::stringstream ss; + ss << FloatProxy(GetParam()); + ss.seekg(0); + FloatProxy res; + ss >> res; + EXPECT_THAT(GetParam(), Eq(res.getAsFloat())); +} + +TEST_P(RoundTripDoubleTest, CanStoreAccurately) { + std::stringstream ss; + ss << FloatProxy(GetParam()); + ss.seekg(0); + FloatProxy res; + ss >> res; + EXPECT_THAT(GetParam(), Eq(res.getAsFloat())); +} + +INSTANTIATE_TEST_CASE_P( + Float32StoreTests, RoundTripFloatTest, + ::testing::ValuesIn(std::vector( + {// Value requiring more than 6 digits of precision to be + // represented accurately. + 3.0000002f}))); + +INSTANTIATE_TEST_CASE_P( + Float64StoreTests, RoundTripDoubleTest, + ::testing::ValuesIn(std::vector( + {// Value requiring more than 15 digits of precision to be + // represented accurately. + 1.5000000000000002}))); + +TEST(HexFloatStreamTest, OperatorLeftShiftPreservesFloatAndFill) { + std::stringstream s; + s << std::setw(4) << std::oct << std::setfill('x') << 8 << " " + << FloatProxy(uint32_t(0xFF800100)) << " " << std::setw(4) << 9; + EXPECT_THAT(s.str(), Eq(std::string("xx10 -0x1.0002p+128 xx11"))); +} + +TEST(HexDoubleStreamTest, OperatorLeftShiftPreservesFloatAndFill) { + std::stringstream s; + s << std::setw(4) << std::oct << std::setfill('x') << 8 << " " + << FloatProxy(uint64_t(0x7FF0F00000000000LL)) << " " << std::setw(4) + << 9; + EXPECT_THAT(s.str(), Eq(std::string("xx10 0x1.0fp+1024 xx11"))); +} + +TEST_P(DecodeHexFloatTest, DecodeCorrectly) { + EXPECT_THAT(Decode(GetParam().first), Eq(GetParam().second)); +} + +TEST_P(DecodeHexDoubleTest, DecodeCorrectly) { + EXPECT_THAT(Decode(GetParam().first), Eq(GetParam().second)); +} + +INSTANTIATE_TEST_CASE_P( + Float32DecodeTests, DecodeHexFloatTest, + ::testing::ValuesIn(std::vector>>({ + {"0x0p+000", 0.f}, + {"0x0p0", 0.f}, + {"0x0p-0", 0.f}, + + // flush to zero cases + {"0x1p-500", 0.f}, // Exponent 
underflows. + {"-0x1p-500", -0.f}, + {"0x0.00000000001p-126", 0.f}, // Fraction causes underflow. + {"-0x0.0000000001p-127", -0.f}, + {"-0x0.01p-142", -0.f}, // Fraction causes additional underflow. + {"0x0.01p-142", 0.f}, + + // Some floats that do not encode the same way as they decode. + {"0x2p+0", 2.f}, + {"0xFFp+0", 255.f}, + {"0x0.8p+0", 0.5f}, + {"0x0.4p+0", 0.25f}, + })), ); + +INSTANTIATE_TEST_CASE_P( + Float32DecodeInfTests, DecodeHexFloatTest, + ::testing::ValuesIn(std::vector>>({ + // inf cases + {"-0x1p+128", uint32_t(0xFF800000)}, // -inf + {"0x32p+127", uint32_t(0x7F800000)}, // inf + {"0x32p+500", uint32_t(0x7F800000)}, // inf + {"-0x32p+127", uint32_t(0xFF800000)}, // -inf + })), ); + +INSTANTIATE_TEST_CASE_P( + Float64DecodeTests, DecodeHexDoubleTest, + ::testing::ValuesIn( + std::vector>>({ + {"0x0p+000", 0.}, + {"0x0p0", 0.}, + {"0x0p-0", 0.}, + + // flush to zero cases + {"0x1p-5000", 0.}, // Exponent underflows. + {"-0x1p-5000", -0.}, + {"0x0.0000000000000001p-1023", 0.}, // Fraction causes underflow. + {"-0x0.000000000000001p-1024", -0.}, + {"-0x0.01p-1090", -0.f}, // Fraction causes additional underflow. + {"0x0.01p-1090", 0.}, + + // Some floats that do not encode the same way as they decode. 
+ {"0x2p+0", 2.}, + {"0xFFp+0", 255.}, + {"0x0.8p+0", 0.5}, + {"0x0.4p+0", 0.25}, + })), ); + +INSTANTIATE_TEST_CASE_P( + Float64DecodeInfTests, DecodeHexDoubleTest, + ::testing::ValuesIn( + std::vector>>({ + // inf cases + {"-0x1p+1024", uint64_t(0xFFF0000000000000)}, // -inf + {"0x32p+1023", uint64_t(0x7FF0000000000000)}, // inf + {"0x32p+5000", uint64_t(0x7FF0000000000000)}, // inf + {"-0x32p+1023", uint64_t(0xFFF0000000000000)}, // -inf + })), ); + +TEST(FloatProxy, ValidConversion) { + EXPECT_THAT(FloatProxy(1.f).getAsFloat(), Eq(1.0f)); + EXPECT_THAT(FloatProxy(32.f).getAsFloat(), Eq(32.0f)); + EXPECT_THAT(FloatProxy(-1.f).getAsFloat(), Eq(-1.0f)); + EXPECT_THAT(FloatProxy(0.f).getAsFloat(), Eq(0.0f)); + EXPECT_THAT(FloatProxy(-0.f).getAsFloat(), Eq(-0.0f)); + EXPECT_THAT(FloatProxy(1.2e32f).getAsFloat(), Eq(1.2e32f)); + + EXPECT_TRUE(std::isinf(FloatProxy(uint32_t(0xFF800000)).getAsFloat())); + EXPECT_TRUE(std::isinf(FloatProxy(uint32_t(0x7F800000)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0xFFC00000)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0xFF800100)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0xFF800c00)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0xFF80F000)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0xFFFFFFFF)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0x7FC00000)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0x7F800100)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0x7f800c00)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0x7F80F000)).getAsFloat())); + EXPECT_TRUE(std::isnan(FloatProxy(uint32_t(0x7FFFFFFF)).getAsFloat())); + + EXPECT_THAT(FloatProxy(uint32_t(0xFF800000)).data(), Eq(0xFF800000u)); + EXPECT_THAT(FloatProxy(uint32_t(0x7F800000)).data(), Eq(0x7F800000u)); + EXPECT_THAT(FloatProxy(uint32_t(0xFFC00000)).data(), Eq(0xFFC00000u)); + 
EXPECT_THAT(FloatProxy(uint32_t(0xFF800100)).data(), Eq(0xFF800100u)); + EXPECT_THAT(FloatProxy(uint32_t(0xFF800c00)).data(), Eq(0xFF800c00u)); + EXPECT_THAT(FloatProxy(uint32_t(0xFF80F000)).data(), Eq(0xFF80F000u)); + EXPECT_THAT(FloatProxy(uint32_t(0xFFFFFFFF)).data(), Eq(0xFFFFFFFFu)); + EXPECT_THAT(FloatProxy(uint32_t(0x7FC00000)).data(), Eq(0x7FC00000u)); + EXPECT_THAT(FloatProxy(uint32_t(0x7F800100)).data(), Eq(0x7F800100u)); + EXPECT_THAT(FloatProxy(uint32_t(0x7f800c00)).data(), Eq(0x7f800c00u)); + EXPECT_THAT(FloatProxy(uint32_t(0x7F80F000)).data(), Eq(0x7F80F000u)); + EXPECT_THAT(FloatProxy(uint32_t(0x7FFFFFFF)).data(), Eq(0x7FFFFFFFu)); +} + +TEST(FloatProxy, Nan) { + EXPECT_TRUE(FloatProxy(uint32_t(0xFFC00000)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0xFF800100)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0xFF800c00)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0xFF80F000)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0xFFFFFFFF)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0x7FC00000)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0x7F800100)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0x7f800c00)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0x7F80F000)).isNan()); + EXPECT_TRUE(FloatProxy(uint32_t(0x7FFFFFFF)).isNan()); +} + +TEST(FloatProxy, Negation) { + EXPECT_THAT((-FloatProxy(1.f)).getAsFloat(), Eq(-1.0f)); + EXPECT_THAT((-FloatProxy(0.f)).getAsFloat(), Eq(-0.0f)); + + EXPECT_THAT((-FloatProxy(-1.f)).getAsFloat(), Eq(1.0f)); + EXPECT_THAT((-FloatProxy(-0.f)).getAsFloat(), Eq(0.0f)); + + EXPECT_THAT((-FloatProxy(32.f)).getAsFloat(), Eq(-32.0f)); + EXPECT_THAT((-FloatProxy(-32.f)).getAsFloat(), Eq(32.0f)); + + EXPECT_THAT((-FloatProxy(1.2e32f)).getAsFloat(), Eq(-1.2e32f)); + EXPECT_THAT((-FloatProxy(-1.2e32f)).getAsFloat(), Eq(1.2e32f)); + + EXPECT_THAT( + (-FloatProxy(std::numeric_limits::infinity())).getAsFloat(), + Eq(-std::numeric_limits::infinity())); + EXPECT_THAT((-FloatProxy(-std::numeric_limits::infinity())) + .getAsFloat(), + 
Eq(std::numeric_limits::infinity())); +} + +// Test conversion of FloatProxy values to strings. +// +// In previous cases, we always wrapped the FloatProxy value in a HexFloat +// before conversion to a string. In the following cases, the FloatProxy +// decides for itself whether to print as a regular number or as a hex float. + +using FloatProxyFloatTest = + ::testing::TestWithParam, std::string>>; +using FloatProxyDoubleTest = + ::testing::TestWithParam, std::string>>; + +// Converts a float value to a string via a FloatProxy. +template +std::string EncodeViaFloatProxy(const T& value) { + std::stringstream ss; + ss << value; + return ss.str(); +} + +// Converts a floating point string so that the exponent prefix +// is 'e', and the exponent value does not have leading zeros. +// The Microsoft runtime library likes to write things like "2.5E+010". +// Convert that to "2.5e+10". +// We don't care what happens to strings that are not floating point +// strings. +std::string NormalizeExponentInFloatString(std::string in) { + std::string result; + // Reserve one spot for the terminating null, even when the sscanf fails. + std::vector prefix(in.size() + 1); + char e; + char plus_or_minus; + int exponent; // in base 10 + if ((4 == std::sscanf(in.c_str(), "%[-+.0123456789]%c%c%d", prefix.data(), &e, + &plus_or_minus, &exponent)) && + (e == 'e' || e == 'E') && + (plus_or_minus == '-' || plus_or_minus == '+')) { + // It looks like a floating point value with exponent. 
+ std::stringstream out; + out << prefix.data() << 'e' << plus_or_minus << exponent; + result = out.str(); + } else { + result = in; + } + return result; +} + +TEST(NormalizeFloat, Sample) { + EXPECT_THAT(NormalizeExponentInFloatString(""), Eq("")); + EXPECT_THAT(NormalizeExponentInFloatString("1e-12"), Eq("1e-12")); + EXPECT_THAT(NormalizeExponentInFloatString("1E+14"), Eq("1e+14")); + EXPECT_THAT(NormalizeExponentInFloatString("1e-0012"), Eq("1e-12")); + EXPECT_THAT(NormalizeExponentInFloatString("1.263E+014"), Eq("1.263e+14")); +} + +// The following two tests can't be DRY because they take different parameter +// types. +TEST_P(FloatProxyFloatTest, EncodeCorrectly) { + EXPECT_THAT( + NormalizeExponentInFloatString(EncodeViaFloatProxy(GetParam().first)), + Eq(GetParam().second)); +} + +TEST_P(FloatProxyDoubleTest, EncodeCorrectly) { + EXPECT_THAT( + NormalizeExponentInFloatString(EncodeViaFloatProxy(GetParam().first)), + Eq(GetParam().second)); +} + +INSTANTIATE_TEST_CASE_P( + Float32Tests, FloatProxyFloatTest, + ::testing::ValuesIn(std::vector, std::string>>({ + // Zero + {0.f, "0"}, + // Normal numbers + {1.f, "1"}, + {-0.25f, "-0.25"}, + {1000.0f, "1000"}, + + // Still normal numbers, but with large magnitude exponents. + {float(ldexp(1.f, 126)), "8.50705917e+37"}, + {float(ldexp(-1.f, -126)), "-1.17549435e-38"}, + + // denormalized values are printed as hex floats. 
+ {float(ldexp(1.0f, -127)), "0x1p-127"}, + {float(ldexp(1.5f, -128)), "0x1.8p-128"}, + {float(ldexp(1.25, -129)), "0x1.4p-129"}, + {float(ldexp(1.125, -130)), "0x1.2p-130"}, + {float(ldexp(-1.0f, -127)), "-0x1p-127"}, + {float(ldexp(-1.0f, -128)), "-0x1p-128"}, + {float(ldexp(-1.0f, -129)), "-0x1p-129"}, + {float(ldexp(-1.5f, -130)), "-0x1.8p-130"}, + + // NaNs + {FloatProxy(uint32_t(0xFFC00000)), "-0x1.8p+128"}, + {FloatProxy(uint32_t(0xFF800100)), "-0x1.0002p+128"}, + + {std::numeric_limits::infinity(), "0x1p+128"}, + {-std::numeric_limits::infinity(), "-0x1p+128"}, + })), ); + +INSTANTIATE_TEST_CASE_P( + Float64Tests, FloatProxyDoubleTest, + ::testing::ValuesIn( + std::vector, std::string>>({ + {0., "0"}, + {1., "1"}, + {-0.25, "-0.25"}, + {1000.0, "1000"}, + + // Large outside the range of normal floats + {ldexp(1.0, 128), "3.4028236692093846e+38"}, + {ldexp(1.5, 129), "1.0208471007628154e+39"}, + {ldexp(-1.0, 128), "-3.4028236692093846e+38"}, + {ldexp(-1.5, 129), "-1.0208471007628154e+39"}, + + // Small outside the range of normal floats + {ldexp(1.5, -129), "2.2040519077917891e-39"}, + {ldexp(-1.5, -129), "-2.2040519077917891e-39"}, + + // lowest non-denorm + {ldexp(1.0, -1022), "2.2250738585072014e-308"}, + {ldexp(-1.0, -1022), "-2.2250738585072014e-308"}, + + // Denormalized values + {ldexp(1.125, -1023), "0x1.2p-1023"}, + {ldexp(-1.375, -1024), "-0x1.6p-1024"}, + + // NaNs + {uint64_t(0x7FF8000000000000LL), "0x1.8p+1024"}, + {uint64_t(0xFFF0F00000000000LL), "-0x1.0fp+1024"}, + + // Infinity + {std::numeric_limits::infinity(), "0x1p+1024"}, + {-std::numeric_limits::infinity(), "-0x1p+1024"}, + + })), ); + +// double is used so that unbiased_exponent can be used with the output +// of ldexp directly. 
+int32_t unbiased_exponent(double f) { + return HexFloat>(static_cast(f)) + .getUnbiasedNormalizedExponent(); +} + +int16_t unbiased_half_exponent(uint16_t f) { + return HexFloat>(f).getUnbiasedNormalizedExponent(); +} + +TEST(HexFloatOperationTest, UnbiasedExponent) { + // Float cases + EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, 0))); + EXPECT_EQ(-32, unbiased_exponent(ldexp(1.0f, -32))); + EXPECT_EQ(42, unbiased_exponent(ldexp(1.0f, 42))); + EXPECT_EQ(125, unbiased_exponent(ldexp(1.0f, 125))); + + EXPECT_EQ(128, + HexFloat>(std::numeric_limits::infinity()) + .getUnbiasedNormalizedExponent()); + + EXPECT_EQ(-100, unbiased_exponent(ldexp(1.0f, -100))); + EXPECT_EQ(-127, unbiased_exponent(ldexp(1.0f, -127))); // First denorm + EXPECT_EQ(-128, unbiased_exponent(ldexp(1.0f, -128))); + EXPECT_EQ(-129, unbiased_exponent(ldexp(1.0f, -129))); + EXPECT_EQ(-140, unbiased_exponent(ldexp(1.0f, -140))); + // Smallest representable number + EXPECT_EQ(-126 - 23, unbiased_exponent(ldexp(1.0f, -126 - 23))); + // Should get rounded to 0 first. + EXPECT_EQ(0, unbiased_exponent(ldexp(1.0f, -127 - 23))); + + // Float16 cases + // The exponent is represented in the bits 0x7C00 + // The offset is -15 + EXPECT_EQ(0, unbiased_half_exponent(0x3C00)); + EXPECT_EQ(3, unbiased_half_exponent(0x4800)); + EXPECT_EQ(-1, unbiased_half_exponent(0x3800)); + EXPECT_EQ(-14, unbiased_half_exponent(0x0400)); + EXPECT_EQ(16, unbiased_half_exponent(0x7C00)); + EXPECT_EQ(10, unbiased_half_exponent(0x6400)); + + // Smallest representable number + EXPECT_EQ(-24, unbiased_half_exponent(0x0001)); +} + +// Creates a float that is the sum of 1/(2 ^ fractions[i]) for i in fractions +float float_fractions(const std::vector& fractions) { + float f = 0; + for (int32_t i : fractions) { + f += std::ldexp(1.0f, -i); + } + return f; +} + +// Returns the normalized significand of a HexFloat> +// that was created by calling float_fractions with the input fractions, +// raised to the power of exp. 
+uint32_t normalized_significand(const std::vector& fractions, + uint32_t exp) { + return HexFloat>( + static_cast(ldexp(float_fractions(fractions), exp))) + .getNormalizedSignificand(); +} + +// Sets the bits from MSB to LSB of the significand part of a float. +// For example 0 would set the bit 23 (counting from LSB to MSB), +// and 1 would set the 22nd bit. +uint32_t bits_set(const std::vector& bits) { + const uint32_t top_bit = 1u << 22u; + uint32_t val = 0; + for (uint32_t i : bits) { + val |= top_bit >> i; + } + return val; +} + +// The same as bits_set but for a Float16 value instead of 32-bit floating +// point. +uint16_t half_bits_set(const std::vector& bits) { + const uint32_t top_bit = 1u << 9u; + uint32_t val = 0; + for (uint32_t i : bits) { + val |= top_bit >> i; + } + return static_cast(val); +} + +TEST(HexFloatOperationTest, NormalizedSignificand) { + // For normalized numbers (the following) it should be a simple matter + // of getting rid of the top implicit bit + EXPECT_EQ(bits_set({}), normalized_significand({0}, 0)); + EXPECT_EQ(bits_set({0}), normalized_significand({0, 1}, 0)); + EXPECT_EQ(bits_set({0, 1}), normalized_significand({0, 1, 2}, 0)); + EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 0)); + EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 32)); + EXPECT_EQ(bits_set({1}), normalized_significand({0, 2}, 126)); + + // For denormalized numbers we expect the normalized significand to + // shift as if it were normalized. This means, in practice that the + // top_most set bit will be cut off. 
Looks very similar to above (on purpose) + EXPECT_EQ(bits_set({}), + normalized_significand({0}, static_cast(-127))); + EXPECT_EQ(bits_set({3}), + normalized_significand({0, 4}, static_cast(-128))); + EXPECT_EQ(bits_set({3}), + normalized_significand({0, 4}, static_cast(-127))); + EXPECT_EQ(bits_set({}), + normalized_significand({22}, static_cast(-127))); + EXPECT_EQ(bits_set({0}), + normalized_significand({21, 22}, static_cast(-127))); +} + +// Returns the 32-bit floating point value created by +// calling setFromSignUnbiasedExponentAndNormalizedSignificand +// on a HexFloat> +float set_from_sign(bool negative, int32_t unbiased_exponent, + uint32_t significand, bool round_denorm_up) { + HexFloat> f(0.f); + f.setFromSignUnbiasedExponentAndNormalizedSignificand( + negative, unbiased_exponent, significand, round_denorm_up); + return f.value().getAsFloat(); +} + +TEST(HexFloatOperationTests, + SetFromSignUnbiasedExponentAndNormalizedSignificand) { + EXPECT_EQ(1.f, set_from_sign(false, 0, 0, false)); + + // Tests insertion of various denormalized numbers with and without round up. + EXPECT_EQ(static_cast(ldexp(1.f, -149)), + set_from_sign(false, -149, 0, false)); + EXPECT_EQ(static_cast(ldexp(1.f, -149)), + set_from_sign(false, -149, 0, true)); + EXPECT_EQ(0.f, set_from_sign(false, -150, 1, false)); + EXPECT_EQ(static_cast(ldexp(1.f, -149)), + set_from_sign(false, -150, 1, true)); + + EXPECT_EQ(ldexp(1.0f, -127), set_from_sign(false, -127, 0, false)); + EXPECT_EQ(ldexp(1.0f, -128), set_from_sign(false, -128, 0, false)); + EXPECT_EQ(float_fractions({0, 1, 2, 5}), + set_from_sign(false, 0, bits_set({0, 1, 4}), false)); + EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -32), + set_from_sign(false, -32, bits_set({0, 1, 4}), false)); + EXPECT_EQ(ldexp(float_fractions({0, 1, 2, 5}), -128), + set_from_sign(false, -128, bits_set({0, 1, 4}), false)); + + // The negative cases from above. 
+ EXPECT_EQ(-1.f, set_from_sign(true, 0, 0, false)); + EXPECT_EQ(-ldexp(1.0, -127), set_from_sign(true, -127, 0, false)); + EXPECT_EQ(-ldexp(1.0, -128), set_from_sign(true, -128, 0, false)); + EXPECT_EQ(-float_fractions({0, 1, 2, 5}), + set_from_sign(true, 0, bits_set({0, 1, 4}), false)); + EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -32), + set_from_sign(true, -32, bits_set({0, 1, 4}), false)); + EXPECT_EQ(-ldexp(float_fractions({0, 1, 2, 5}), -128), + set_from_sign(true, -128, bits_set({0, 1, 4}), false)); +} + +TEST(HexFloatOperationTests, NonRounding) { + // Rounding from 32-bit hex-float to 32-bit hex-float should be trivial, + // except in the denorm case which is a bit more complex. + using HF = HexFloat>; + bool carry_bit = false; + + round_direction rounding[] = {round_direction::kToZero, + round_direction::kToNearestEven, + round_direction::kToPositiveInfinity, + round_direction::kToNegativeInfinity}; + + // Everything fits, so this should be straight-forward + for (round_direction round : rounding) { + EXPECT_EQ(bits_set({}), + HF(0.f).getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + + EXPECT_EQ(bits_set({0}), + HF(float_fractions({0, 1})) + .getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + + EXPECT_EQ(bits_set({1, 3}), + HF(float_fractions({0, 2, 4})) + .getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + + EXPECT_EQ( + bits_set({0, 1, 4}), + HF(static_cast(-ldexp(float_fractions({0, 1, 2, 5}), -128))) + .getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + + EXPECT_EQ(bits_set({0, 1, 4, 22}), + HF(static_cast(float_fractions({0, 1, 2, 5, 23}))) + .getRoundedNormalizedSignificand(round, &carry_bit)); + EXPECT_FALSE(carry_bit); + } +} + +using RD = round_direction; +struct RoundSignificandCase { + float source_float; + std::pair expected_results; + round_direction round; +}; + +using HexFloatRoundTest = 
::testing::TestWithParam; + +TEST_P(HexFloatRoundTest, RoundDownToFP16) { + using HF = HexFloat>; + using HF16 = HexFloat>; + + HF input_value(GetParam().source_float); + bool carry_bit = false; + EXPECT_EQ(GetParam().expected_results.first, + input_value.getRoundedNormalizedSignificand(GetParam().round, + &carry_bit)); + EXPECT_EQ(carry_bit, GetParam().expected_results.second); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(F32ToF16, HexFloatRoundTest, + ::testing::ValuesIn(std::vector( + { + {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToZero}, + {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToNearestEven}, + {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToPositiveInfinity}, + {float_fractions({0}), std::make_pair(half_bits_set({}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + + {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 11}), std::make_pair(half_bits_set({0}), false), RD::kToNearestEven}, + + {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToZero}, + {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 10, 11}), std::make_pair(half_bits_set({0, 8}), false), RD::kToNearestEven}, + + {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 11, 12}), 
std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven}, + + {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0}), false), RD::kToPositiveInfinity}, + {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNegativeInfinity}, + {-float_fractions({0, 1, 11, 12}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven}, + + {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 11, 22}), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven}, + + // Carries + {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), RD::kToZero}, + {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), RD::kToPositiveInfinity}, + {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), false), RD::kToNegativeInfinity}, + {float_fractions({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), std::make_pair(half_bits_set({}), true), RD::kToNearestEven}, + + // Cases where original number was denorm. Note: this should have no effect + // the number is pre-normalized. 
+ {static_cast(ldexp(float_fractions({0, 1, 11, 13}), -128)), std::make_pair(half_bits_set({0}), false), RD::kToZero}, + {static_cast(ldexp(float_fractions({0, 1, 11, 13}), -129)), std::make_pair(half_bits_set({0, 9}), false), RD::kToPositiveInfinity}, + {static_cast(ldexp(float_fractions({0, 1, 11, 13}), -131)), std::make_pair(half_bits_set({0}), false), RD::kToNegativeInfinity}, + {static_cast(ldexp(float_fractions({0, 1, 11, 13}), -130)), std::make_pair(half_bits_set({0, 9}), false), RD::kToNearestEven}, + })),); +// clang-format on + +struct UpCastSignificandCase { + uint16_t source_half; + uint32_t expected_result; +}; + +using HexFloatRoundUpSignificandTest = + ::testing::TestWithParam; +TEST_P(HexFloatRoundUpSignificandTest, Widening) { + using HF = HexFloat>; + using HF16 = HexFloat>; + bool carry_bit = false; + + round_direction rounding[] = {round_direction::kToZero, + round_direction::kToNearestEven, + round_direction::kToPositiveInfinity, + round_direction::kToNegativeInfinity}; + + // Everything fits, so everything should just be bit-shifts. + for (round_direction round : rounding) { + carry_bit = false; + HF16 input_value(GetParam().source_half); + EXPECT_EQ( + GetParam().expected_result, + input_value.getRoundedNormalizedSignificand(round, &carry_bit)) + << std::hex << "0x" + << input_value.getRoundedNormalizedSignificand(round, &carry_bit) + << " 0x" << GetParam().expected_result; + EXPECT_FALSE(carry_bit); + } +} + +INSTANTIATE_TEST_CASE_P( + F16toF32, HexFloatRoundUpSignificandTest, + // 0xFC00 of the source 16-bit hex value cover the sign and the exponent. + // They are ignored for this test. 
+ ::testing::ValuesIn(std::vector({ + {0x3F00, 0x600000}, + {0x0F00, 0x600000}, + {0x0F01, 0x602000}, + {0x0FFF, 0x7FE000}, + })), ); + +struct DownCastTest { + float source_float; + uint16_t expected_half; + std::vector directions; +}; + +std::string get_round_text(round_direction direction) { +#define CASE(round_direction) \ + case round_direction: \ + return #round_direction + + switch (direction) { + CASE(round_direction::kToZero); + CASE(round_direction::kToPositiveInfinity); + CASE(round_direction::kToNegativeInfinity); + CASE(round_direction::kToNearestEven); + } +#undef CASE + return ""; +} + +using HexFloatFP32To16Tests = ::testing::TestWithParam; + +TEST_P(HexFloatFP32To16Tests, NarrowingCasts) { + using HF = HexFloat>; + using HF16 = HexFloat>; + HF f(GetParam().source_float); + for (auto round : GetParam().directions) { + HF16 half(0); + f.castTo(half, round); + EXPECT_EQ(GetParam().expected_half, half.value().getAsFloat().get_value()) + << get_round_text(round) << " " << std::hex + << BitwiseCast(GetParam().source_float) + << " cast to: " << half.value().getAsFloat().get_value(); + } +} + +const uint16_t positive_infinity = 0x7C00; +const uint16_t negative_infinity = 0xFC00; + +INSTANTIATE_TEST_CASE_P( + F32ToF16, HexFloatFP32To16Tests, + ::testing::ValuesIn(std::vector({ + // Exactly representable as half. 
+ {0.f, + 0x0, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {-0.f, + 0x8000, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {1.0f, + 0x3C00, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {-1.0f, + 0xBC00, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + + {float_fractions({0, 1, 10}), + 0x3E01, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {-float_fractions({0, 1, 10}), + 0xBE01, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {static_cast(ldexp(float_fractions({0, 1, 10}), 3)), + 0x4A01, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {static_cast(-ldexp(float_fractions({0, 1, 10}), 3)), + 0xCA01, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + + // Underflow + {static_cast(ldexp(1.0f, -25)), + 0x0, + {RD::kToZero, RD::kToNegativeInfinity, RD::kToNearestEven}}, + {static_cast(ldexp(1.0f, -25)), 0x1, {RD::kToPositiveInfinity}}, + {static_cast(-ldexp(1.0f, -25)), + 0x8000, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNearestEven}}, + {static_cast(-ldexp(1.0f, -25)), + 0x8001, + {RD::kToNegativeInfinity}}, + {static_cast(ldexp(1.0f, -24)), + 0x1, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + + // Overflow + {static_cast(ldexp(1.0f, 16)), + positive_infinity, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {static_cast(ldexp(1.0f, 18)), + positive_infinity, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {static_cast(ldexp(1.3f, 16)), + positive_infinity, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {static_cast(-ldexp(1.0f, 16)), + 
negative_infinity, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {static_cast(-ldexp(1.0f, 18)), + negative_infinity, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {static_cast(-ldexp(1.3f, 16)), + negative_infinity, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + + // Transfer of Infinities + {std::numeric_limits::infinity(), + positive_infinity, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + {-std::numeric_limits::infinity(), + negative_infinity, + {RD::kToZero, RD::kToPositiveInfinity, RD::kToNegativeInfinity, + RD::kToNearestEven}}, + + // Nans are below because we cannot test for equality. + })), ); + +struct UpCastCase { + uint16_t source_half; + float expected_float; +}; + +using HexFloatFP16To32Tests = ::testing::TestWithParam; +TEST_P(HexFloatFP16To32Tests, WideningCasts) { + using HF = HexFloat>; + using HF16 = HexFloat>; + HF16 f(GetParam().source_half); + + round_direction rounding[] = {round_direction::kToZero, + round_direction::kToNearestEven, + round_direction::kToPositiveInfinity, + round_direction::kToNegativeInfinity}; + + // Everything fits, so everything should just be bit-shifts. 
+ for (round_direction round : rounding) { + HF flt(0.f); + f.castTo(flt, round); + EXPECT_EQ(GetParam().expected_float, flt.value().getAsFloat()) + << get_round_text(round) << " " << std::hex + << BitwiseCast(GetParam().source_half) + << " cast to: " << flt.value().getAsFloat(); + } +} + +INSTANTIATE_TEST_CASE_P( + F16ToF32, HexFloatFP16To32Tests, + ::testing::ValuesIn(std::vector({ + {0x0000, 0.f}, + {0x8000, -0.f}, + {0x3C00, 1.0f}, + {0xBC00, -1.0f}, + {0x3F00, float_fractions({0, 1, 2})}, + {0xBF00, -float_fractions({0, 1, 2})}, + {0x3F01, float_fractions({0, 1, 2, 10})}, + {0xBF01, -float_fractions({0, 1, 2, 10})}, + + // denorm + {0x0001, static_cast(ldexp(1.0, -24))}, + {0x0002, static_cast(ldexp(1.0, -23))}, + {0x8001, static_cast(-ldexp(1.0, -24))}, + {0x8011, static_cast(-ldexp(1.0, -20) + -ldexp(1.0, -24))}, + + // inf + {0x7C00, std::numeric_limits::infinity()}, + {0xFC00, -std::numeric_limits::infinity()}, + })), ); + +TEST(HexFloatOperationTests, NanTests) { + using HF = HexFloat>; + using HF16 = HexFloat>; + round_direction rounding[] = {round_direction::kToZero, + round_direction::kToNearestEven, + round_direction::kToPositiveInfinity, + round_direction::kToNegativeInfinity}; + + // Everything fits, so everything should just be bit-shifts. + for (round_direction round : rounding) { + HF16 f16(0); + HF f(0.f); + HF(std::numeric_limits::quiet_NaN()).castTo(f16, round); + EXPECT_TRUE(f16.value().isNan()); + HF(std::numeric_limits::signaling_NaN()).castTo(f16, round); + EXPECT_TRUE(f16.value().isNan()); + + HF16(0x7C01).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + HF16(0x7C11).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + HF16(0xFC01).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + HF16(0x7C10).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + HF16(0xFF00).castTo(f, round); + EXPECT_TRUE(f.value().isNan()); + } +} + +// A test case for parsing good and bad HexFloat> literals. 
+template +struct FloatParseCase { + std::string literal; + bool negate_value; + bool expect_success; + HexFloat> expected_value; +}; + +using ParseNormalFloatTest = ::testing::TestWithParam>; + +TEST_P(ParseNormalFloatTest, Samples) { + std::stringstream input(GetParam().literal); + HexFloat> parsed_value(0.0f); + ParseNormalFloat(input, GetParam().negate_value, parsed_value); + EXPECT_NE(GetParam().expect_success, input.fail()) + << " literal: " << GetParam().literal + << " negate: " << GetParam().negate_value; + if (GetParam().expect_success) { + EXPECT_THAT(parsed_value.value(), Eq(GetParam().expected_value.value())) + << " literal: " << GetParam().literal + << " negate: " << GetParam().negate_value; + } +} + +// Returns a FloatParseCase with expected failure. +template +FloatParseCase BadFloatParseCase(std::string literal, bool negate_value, + T expected_value) { + HexFloat> proxy_expected_value(expected_value); + return FloatParseCase{literal, negate_value, false, proxy_expected_value}; +} + +// Returns a FloatParseCase that should successfully parse to a given value. +template +FloatParseCase GoodFloatParseCase(std::string literal, bool negate_value, + T expected_value) { + HexFloat> proxy_expected_value(expected_value); + return FloatParseCase{literal, negate_value, true, proxy_expected_value}; +} + +INSTANTIATE_TEST_CASE_P( + FloatParse, ParseNormalFloatTest, + ::testing::ValuesIn(std::vector>{ + // Failing cases due to trivially incorrect syntax. + BadFloatParseCase("abc", false, 0.0f), + BadFloatParseCase("abc", true, 0.0f), + + // Valid cases. + GoodFloatParseCase("0", false, 0.0f), + GoodFloatParseCase("0.0", false, 0.0f), + GoodFloatParseCase("-0.0", false, -0.0f), + GoodFloatParseCase("2.0", false, 2.0f), + GoodFloatParseCase("-2.0", false, -2.0f), + GoodFloatParseCase("+2.0", false, 2.0f), + // Cases with negate_value being true. 
+ GoodFloatParseCase("0.0", true, -0.0f), + GoodFloatParseCase("2.0", true, -2.0f), + + // When negate_value is true, we should not accept a + // leading minus or plus. + BadFloatParseCase("-0.0", true, 0.0f), + BadFloatParseCase("-2.0", true, 0.0f), + BadFloatParseCase("+0.0", true, 0.0f), + BadFloatParseCase("+2.0", true, 0.0f), + + // Overflow is an error for 32-bit float parsing. + BadFloatParseCase("1e40", false, FLT_MAX), + BadFloatParseCase("1e40", true, -FLT_MAX), + BadFloatParseCase("-1e40", false, -FLT_MAX), + // We can't have -1e40 and negate_value == true since + // that represents an original case of "--1e40" which + // is invalid. + }), ); + +using ParseNormalFloat16Test = + ::testing::TestWithParam>; + +TEST_P(ParseNormalFloat16Test, Samples) { + std::stringstream input(GetParam().literal); + HexFloat> parsed_value(0); + ParseNormalFloat(input, GetParam().negate_value, parsed_value); + EXPECT_NE(GetParam().expect_success, input.fail()) + << " literal: " << GetParam().literal + << " negate: " << GetParam().negate_value; + if (GetParam().expect_success) { + EXPECT_THAT(parsed_value.value(), Eq(GetParam().expected_value.value())) + << " literal: " << GetParam().literal + << " negate: " << GetParam().negate_value; + } +} + +INSTANTIATE_TEST_CASE_P( + Float16Parse, ParseNormalFloat16Test, + ::testing::ValuesIn(std::vector>{ + // Failing cases due to trivially incorrect syntax. + BadFloatParseCase("abc", false, uint16_t{0}), + BadFloatParseCase("abc", true, uint16_t{0}), + + // Valid cases. + GoodFloatParseCase("0", false, uint16_t{0}), + GoodFloatParseCase("0.0", false, uint16_t{0}), + GoodFloatParseCase("-0.0", false, uint16_t{0x8000}), + GoodFloatParseCase("2.0", false, uint16_t{0x4000}), + GoodFloatParseCase("-2.0", false, uint16_t{0xc000}), + GoodFloatParseCase("+2.0", false, uint16_t{0x4000}), + // Cases with negate_value being true. 
+ GoodFloatParseCase("0.0", true, uint16_t{0x8000}), + GoodFloatParseCase("2.0", true, uint16_t{0xc000}), + + // When negate_value is true, we should not accept a leading minus or + // plus. + BadFloatParseCase("-0.0", true, uint16_t{0}), + BadFloatParseCase("-2.0", true, uint16_t{0}), + BadFloatParseCase("+0.0", true, uint16_t{0}), + BadFloatParseCase("+2.0", true, uint16_t{0}), + }), ); + +// A test case for detecting infinities. +template +struct OverflowParseCase { + std::string input; + bool expect_success; + T expected_value; +}; + +using FloatProxyParseOverflowFloatTest = + ::testing::TestWithParam>; + +TEST_P(FloatProxyParseOverflowFloatTest, Sample) { + std::istringstream input(GetParam().input); + HexFloat> value(0.0f); + input >> value; + EXPECT_NE(GetParam().expect_success, input.fail()); + if (GetParam().expect_success) { + EXPECT_THAT(value.value().getAsFloat(), GetParam().expected_value); + } +} + +INSTANTIATE_TEST_CASE_P( + FloatOverflow, FloatProxyParseOverflowFloatTest, + ::testing::ValuesIn(std::vector>({ + {"0", true, 0.0f}, + {"0.0", true, 0.0f}, + {"1.0", true, 1.0f}, + {"1e38", true, 1e38f}, + {"-1e38", true, -1e38f}, + {"1e40", false, FLT_MAX}, + {"-1e40", false, -FLT_MAX}, + {"1e400", false, FLT_MAX}, + {"-1e400", false, -FLT_MAX}, + })), ); + +using FloatProxyParseOverflowDoubleTest = + ::testing::TestWithParam>; + +TEST_P(FloatProxyParseOverflowDoubleTest, Sample) { + std::istringstream input(GetParam().input); + HexFloat> value(0.0); + input >> value; + EXPECT_NE(GetParam().expect_success, input.fail()); + if (GetParam().expect_success) { + EXPECT_THAT(value.value().getAsFloat(), Eq(GetParam().expected_value)); + } +} + +INSTANTIATE_TEST_CASE_P( + DoubleOverflow, FloatProxyParseOverflowDoubleTest, + ::testing::ValuesIn(std::vector>({ + {"0", true, 0.0}, + {"0.0", true, 0.0}, + {"1.0", true, 1.0}, + {"1e38", true, 1e38}, + {"-1e38", true, -1e38}, + {"1e40", true, 1e40}, + {"-1e40", true, -1e40}, + {"1e400", false, DBL_MAX}, + {"-1e400", 
false, -DBL_MAX}, + })), ); + +using FloatProxyParseOverflowFloat16Test = + ::testing::TestWithParam>; + +TEST_P(FloatProxyParseOverflowFloat16Test, Sample) { + std::istringstream input(GetParam().input); + HexFloat> value(0); + input >> value; + EXPECT_NE(GetParam().expect_success, input.fail()) + << " literal: " << GetParam().input; + if (GetParam().expect_success) { + EXPECT_THAT(value.value().data(), Eq(GetParam().expected_value)) + << " literal: " << GetParam().input; + } +} + +INSTANTIATE_TEST_CASE_P( + Float16Overflow, FloatProxyParseOverflowFloat16Test, + ::testing::ValuesIn(std::vector>({ + {"0", true, uint16_t{0}}, + {"0.0", true, uint16_t{0}}, + {"1.0", true, uint16_t{0x3c00}}, + // Overflow for 16-bit float is an error, and returns max or + // lowest value. + {"1e38", false, uint16_t{0x7bff}}, + {"1e40", false, uint16_t{0x7bff}}, + {"1e400", false, uint16_t{0x7bff}}, + {"-1e38", false, uint16_t{0xfbff}}, + {"-1e40", false, uint16_t{0xfbff}}, + {"-1e400", false, uint16_t{0xfbff}}, + })), ); + +TEST(FloatProxy, Max) { + EXPECT_THAT(FloatProxy::max().getAsFloat().get_value(), + Eq(uint16_t{0x7bff})); + EXPECT_THAT(FloatProxy::max().getAsFloat(), + Eq(std::numeric_limits::max())); + EXPECT_THAT(FloatProxy::max().getAsFloat(), + Eq(std::numeric_limits::max())); +} + +TEST(FloatProxy, Lowest) { + EXPECT_THAT(FloatProxy::lowest().getAsFloat().get_value(), + Eq(uint16_t{0xfbff})); + EXPECT_THAT(FloatProxy::lowest().getAsFloat(), + Eq(std::numeric_limits::lowest())); + EXPECT_THAT(FloatProxy::lowest().getAsFloat(), + Eq(std::numeric_limits::lowest())); +} + +// TODO(awoloszyn): Add fp16 tests and HexFloatTraits. +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/test/huffman_codec.cpp b/test/huffman_codec.cpp new file mode 100644 index 000000000..58a781061 --- /dev/null +++ b/test/huffman_codec.cpp @@ -0,0 +1,317 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/comp/bit_stream.h" +#include "source/comp/huffman_codec.h" + +namespace spvtools { +namespace comp { +namespace { + +const std::map& GetTestSet() { + static const std::map hist = { + {"a", 4}, {"e", 7}, {"f", 3}, {"h", 2}, {"i", 3}, + {"m", 2}, {"n", 2}, {"s", 2}, {"t", 2}, {"l", 1}, + {"o", 2}, {"p", 1}, {"r", 1}, {"u", 1}, {"x", 1}, + }; + + return hist; +} + +class TestBitReader { + public: + TestBitReader(const std::string& bits) : bits_(bits) {} + + bool ReadBit(bool* bit) { + if (pos_ < bits_.length()) { + *bit = bits_[pos_++] == '0' ? 
false : true; + return true; + } + return false; + } + + private: + std::string bits_; + size_t pos_ = 0; +}; + +TEST(Huffman, PrintTree) { + HuffmanCodec huffman(GetTestSet()); + std::stringstream ss; + huffman.PrintTree(ss); + + // clang-format off + const std::string expected = std::string(R"( +15-----7------e + 8------4------a + 4------2------m + 2------n +19-----8------4------2------o + 2------s + 4------2------t + 2------1------l + 1------p + 11-----5------2------1------r + 1------u + 3------f + 6------3------i + 3------1------x + 2------h +)").substr(1); + // clang-format on + + EXPECT_EQ(expected, ss.str()); +} + +TEST(Huffman, PrintTable) { + HuffmanCodec huffman(GetTestSet()); + std::stringstream ss; + huffman.PrintTable(ss); + + const std::string expected = std::string(R"( +e 7 11 +a 4 101 +i 3 0001 +f 3 0010 +t 2 0101 +s 2 0110 +o 2 0111 +n 2 1000 +m 2 1001 +h 2 00000 +x 1 00001 +u 1 00110 +r 1 00111 +p 1 01000 +l 1 01001 +)") + .substr(1); + + EXPECT_EQ(expected, ss.str()); +} + +TEST(Huffman, TestValidity) { + HuffmanCodec huffman(GetTestSet()); + const auto& encoding_table = huffman.GetEncodingTable(); + std::vector codes; + for (const auto& entry : encoding_table) { + codes.push_back(BitsToStream(entry.second.first, entry.second.second)); + } + + std::sort(codes.begin(), codes.end()); + + ASSERT_LT(codes.size(), 20u) << "Inefficient test ahead"; + + for (size_t i = 0; i < codes.size(); ++i) { + for (size_t j = i + 1; j < codes.size(); ++j) { + ASSERT_FALSE(codes[i] == codes[j].substr(0, codes[i].length())) + << codes[i] << " is prefix of " << codes[j]; + } + } +} + +TEST(Huffman, TestEncode) { + HuffmanCodec huffman(GetTestSet()); + + uint64_t bits = 0; + size_t num_bits = 0; + + EXPECT_TRUE(huffman.Encode("e", &bits, &num_bits)); + EXPECT_EQ(2u, num_bits); + EXPECT_EQ("11", BitsToStream(bits, num_bits)); + + EXPECT_TRUE(huffman.Encode("a", &bits, &num_bits)); + EXPECT_EQ(3u, num_bits); + EXPECT_EQ("101", BitsToStream(bits, num_bits)); + + 
EXPECT_TRUE(huffman.Encode("x", &bits, &num_bits)); + EXPECT_EQ(5u, num_bits); + EXPECT_EQ("00001", BitsToStream(bits, num_bits)); + + EXPECT_FALSE(huffman.Encode("y", &bits, &num_bits)); +} + +TEST(Huffman, TestDecode) { + HuffmanCodec huffman(GetTestSet()); + TestBitReader bit_reader( + "01001" + "0001" + "1000" + "00110" + "00001" + "00"); + auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); }; + + std::string decoded; + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ("l", decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ("i", decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ("n", decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ("u", decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ("x", decoded); + + ASSERT_FALSE(huffman.DecodeFromStream(read_bit, &decoded)); +} + +TEST(Huffman, TestDecodeNumbers) { + const std::map hist = {{1, 10}, {2, 5}, {3, 15}}; + HuffmanCodec huffman(hist); + + TestBitReader bit_reader( + "1" + "1" + "01" + "00" + "01" + "1"); + auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); }; + + uint32_t decoded; + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ(3u, decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ(3u, decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ(2u, decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ(1u, decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ(2u, decoded); + + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ(3u, decoded); +} + +TEST(Huffman, SerializeToTextU64) { + const std::map hist = {{1001, 10}, {1002, 5}, {1003, 15}}; + HuffmanCodec huffman(hist); + + const std::string code = huffman.SerializeToText(2); + + const 
std::string expected = R"((5, { + {0, 0, 0}, + {1001, 0, 0}, + {1002, 0, 0}, + {1003, 0, 0}, + {0, 1, 2}, + {0, 4, 3}, + }))"; + + ASSERT_EQ(expected, code); +} + +TEST(Huffman, SerializeToTextString) { + const std::map hist = { + {"aaa", 10}, {"bbb", 20}, {"ccc", 15}}; + HuffmanCodec huffman(hist); + + const std::string code = huffman.SerializeToText(4); + + const std::string expected = R"((5, { + {"", 0, 0}, + {"aaa", 0, 0}, + {"bbb", 0, 0}, + {"ccc", 0, 0}, + {"", 3, 1}, + {"", 4, 2}, + }))"; + + ASSERT_EQ(expected, code); +} + +TEST(Huffman, CreateFromTextString) { + std::vector::Node> nodes = { + {}, + {"root", 2, 3}, + {"left", 0, 0}, + {"right", 0, 0}, + }; + + HuffmanCodec huffman(1, std::move(nodes)); + + std::stringstream ss; + huffman.PrintTree(ss); + + const std::string expected = std::string(R"( +0------right +0------left +)") + .substr(1); + + EXPECT_EQ(expected, ss.str()); +} + +TEST(Huffman, CreateFromTextU64) { + HuffmanCodec huffman(5, { + {0, 0, 0}, + {1001, 0, 0}, + {1002, 0, 0}, + {1003, 0, 0}, + {0, 1, 2}, + {0, 4, 3}, + }); + + std::stringstream ss; + huffman.PrintTree(ss); + + const std::string expected = std::string(R"( +0------1003 +0------0------1002 + 0------1001 +)") + .substr(1); + + EXPECT_EQ(expected, ss.str()); + + TestBitReader bit_reader("01"); + auto read_bit = [&bit_reader](bool* bit) { return bit_reader.ReadBit(bit); }; + + uint64_t decoded = 0; + ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded)); + EXPECT_EQ(1002u, decoded); + + uint64_t bits = 0; + size_t num_bits = 0; + + EXPECT_TRUE(huffman.Encode(1001, &bits, &num_bits)); + EXPECT_EQ(2u, num_bits); + EXPECT_EQ("00", BitsToStream(bits, num_bits)); +} + +} // namespace +} // namespace comp +} // namespace spvtools diff --git a/test/immediate_int_test.cpp b/test/immediate_int_test.cpp new file mode 100644 index 000000000..393075a4e --- /dev/null +++ b/test/immediate_int_test.cpp @@ -0,0 +1,291 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/util/bitutils.h" +#include "test/test_fixture.h" + +namespace spvtools { +namespace utils { +namespace { + +using spvtest::Concatenate; +using spvtest::MakeInstruction; +using spvtest::ScopedContext; +using spvtest::TextToBinaryTest; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::StrEq; + +TEST_F(TextToBinaryTest, ImmediateIntOpCode) { + SetText("!0x00FF00FF"); + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(ScopedContext().context, text.str, + text.length, &binary, &diagnostic)); + EXPECT_EQ(0x00FF00FFu, binary->code[5]); + if (diagnostic) { + spvDiagnosticPrint(diagnostic); + } +} + +TEST_F(TextToBinaryTest, ImmediateIntOperand) { + SetText("OpCapability !0x00FF00FF"); + EXPECT_EQ(SPV_SUCCESS, spvTextToBinary(ScopedContext().context, text.str, + text.length, &binary, &diagnostic)); + EXPECT_EQ(0x00FF00FFu, binary->code[6]); + if (diagnostic) { + spvDiagnosticPrint(diagnostic); + } +} + +using ImmediateIntTest = TextToBinaryTest; + +TEST_F(ImmediateIntTest, AnyWordInSimpleStatement) { + EXPECT_THAT(CompiledInstructions("!0x00040018 %a %b %123"), + Eq(MakeInstruction(SpvOpTypeMatrix, {1, 2, 3}))); + EXPECT_THAT(CompiledInstructions("!0x00040018 !1 %b %123"), + Eq(MakeInstruction(SpvOpTypeMatrix, {1, 1, 2}))); + EXPECT_THAT(CompiledInstructions("%a = OpTypeMatrix !2 %123"), + 
Eq(MakeInstruction(SpvOpTypeMatrix, {1, 2, 2}))); + EXPECT_THAT(CompiledInstructions("%a = OpTypeMatrix %b !123"), + Eq(MakeInstruction(SpvOpTypeMatrix, {1, 2, 123}))); + EXPECT_THAT(CompiledInstructions("!0x00040018 %a !2 %123"), + Eq(MakeInstruction(SpvOpTypeMatrix, {1, 2, 2}))); + EXPECT_THAT(CompiledInstructions("!0x00040018 !1 %b !123"), + Eq(MakeInstruction(SpvOpTypeMatrix, {1, 1, 123}))); + EXPECT_THAT(CompiledInstructions("!0x00040018 !1 !2 !123"), + Eq(MakeInstruction(SpvOpTypeMatrix, {1, 2, 123}))); +} + +TEST_F(ImmediateIntTest, AnyWordAfterEqualsAndOpCode) { + EXPECT_THAT(CompiledInstructions("%a = OpArrayLength !2 %c 123"), + Eq(MakeInstruction(SpvOpArrayLength, {2, 1, 2, 123}))); + EXPECT_THAT(CompiledInstructions("%a = OpArrayLength %b !3 123"), + Eq(MakeInstruction(SpvOpArrayLength, {1, 2, 3, 123}))); + EXPECT_THAT(CompiledInstructions("%a = OpArrayLength %b %c !123"), + Eq(MakeInstruction(SpvOpArrayLength, {1, 2, 3, 123}))); + EXPECT_THAT(CompiledInstructions("%a = OpArrayLength %b !3 !123"), + Eq(MakeInstruction(SpvOpArrayLength, {1, 2, 3, 123}))); + EXPECT_THAT(CompiledInstructions("%a = OpArrayLength !2 !3 123"), + Eq(MakeInstruction(SpvOpArrayLength, {2, 1, 3, 123}))); + EXPECT_THAT(CompiledInstructions("%a = OpArrayLength !2 !3 !123"), + Eq(MakeInstruction(SpvOpArrayLength, {2, 1, 3, 123}))); +} + +TEST_F(ImmediateIntTest, ResultIdInAssignment) { + EXPECT_EQ("!2 not allowed before =.", + CompileFailure("!2 = OpArrayLength %12 %1 123")); + EXPECT_EQ("!2 not allowed before =.", + CompileFailure("!2 = !0x00040044 %12 %1 123")); +} + +TEST_F(ImmediateIntTest, OpCodeInAssignment) { + EXPECT_EQ("Invalid Opcode prefix '!0x00040044'.", + CompileFailure("%2 = !0x00040044 %12 %1 123")); +} + +// Literal integers after ! are handled correctly. 
+TEST_F(ImmediateIntTest, IntegerFollowingImmediate) { + const SpirvVector original = CompiledInstructions("%1 = OpTypeInt 8 1"); + EXPECT_EQ(original, CompiledInstructions("!0x00040015 1 8 1")); + EXPECT_EQ(original, CompiledInstructions("!0x00040015 !1 8 1")); + + // With !, we can (and can only) accept 32-bit number literals, + // even when we declare the return type is 64-bit. + EXPECT_EQ(Concatenate({ + MakeInstruction(SpvOpTypeInt, {1, 64, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 4294967295}), + }), + CompiledInstructions("%i64 = OpTypeInt 64 0\n" + "!0x0004002b %i64 !2 4294967295")); + // 64-bit integer literal. + EXPECT_EQ("Invalid word following !: 5000000000", + CompileFailure("%2 = OpConstant !1 5000000000")); + EXPECT_EQ("Invalid word following !: 5000000000", + CompileFailure("%i64 = OpTypeInt 64 0\n" + "!0x0005002b %i64 !2 5000000000")); + + // Negative integer. + EXPECT_EQ(CompiledInstructions("%i64 = OpTypeInt 32 1\n" + "%2 = OpConstant %i64 -123"), + CompiledInstructions("%i64 = OpTypeInt 32 1\n" + "!0x0004002b %i64 !2 -123")); + + // TODO(deki): uncomment assertions below and make them pass. + // Hex value(s). + // EXPECT_EQ(CompileSuccessfully("%1 = OpConstant %10 0x12345678"), + // CompileSuccessfully("OpConstant %10 !1 0x12345678", kCAF)); + // EXPECT_EQ( + // CompileSuccessfully("%1 = OpConstant %10 0x12345678 0x87654321"), + // CompileSuccessfully("OpConstant %10 !1 0x12345678 0x87654321", kCAF)); +} + +// Literal floats after ! are handled correctly. 
+TEST_F(ImmediateIntTest, FloatFollowingImmediate) { + EXPECT_EQ( + CompiledInstructions("%1 = OpTypeFloat 32\n%2 = OpConstant %1 0.123"), + CompiledInstructions("%1 = OpTypeFloat 32\n!0x0004002b %1 !2 0.123")); + EXPECT_EQ( + CompiledInstructions("%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0.5"), + CompiledInstructions("%1 = OpTypeFloat 32\n!0x0004002b %1 !2 -0.5")); + EXPECT_EQ( + CompiledInstructions("%1 = OpTypeFloat 32\n%2 = OpConstant %1 0.123"), + CompiledInstructions("%1 = OpTypeFloat 32\n!0x0004002b %1 %2 0.123")); + EXPECT_EQ( + CompiledInstructions("%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0.5"), + CompiledInstructions("%1 = OpTypeFloat 32\n!0x0004002b %1 %2 -0.5")); + + EXPECT_EQ(Concatenate({ + MakeInstruction(SpvOpTypeInt, {1, 64, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 0xb, 0xa}), + MakeInstruction(SpvOpSwitch, + {2, 1234, BitwiseCast(2.5f), 3}), + }), + CompiledInstructions("%i64 = OpTypeInt 64 0\n" + "%big = OpConstant %i64 0xa0000000b\n" + "OpSwitch %big !1234 2.5 %target\n")); +} + +// Literal strings after ! are handled correctly. +TEST_F(ImmediateIntTest, StringFollowingImmediate) { + // Try a variety of strings, including empty and single-character. + for (std::string name : {"", "s", "longish", "really looooooooooooooooong"}) { + const SpirvVector original = + CompiledInstructions("OpMemberName %10 4 \"" + name + "\""); + EXPECT_EQ(original, + CompiledInstructions("OpMemberName %10 !4 \"" + name + "\"")) + << name; + EXPECT_EQ(original, + CompiledInstructions("OpMemberName !1 !4 \"" + name + "\"")) + << name; + const uint16_t wordCount = static_cast(4 + name.size() / 4); + const uint32_t firstWord = spvOpcodeMake(wordCount, SpvOpMemberName); + EXPECT_EQ(original, CompiledInstructions("!" + std::to_string(firstWord) + + " %10 !4 \"" + name + "\"")) + << name; + } +} + +// IDs after ! are handled correctly. 
+TEST_F(ImmediateIntTest, IdFollowingImmediate) { + EXPECT_EQ(CompileSuccessfully("%123 = OpDecorationGroup"), + CompileSuccessfully("!0x00020049 %123")); + EXPECT_EQ(CompileSuccessfully("%group = OpDecorationGroup"), + CompileSuccessfully("!0x00020049 %group")); +} + +// ! after ! is handled correctly. +TEST_F(ImmediateIntTest, ImmediateFollowingImmediate) { + const SpirvVector original = CompiledInstructions("%a = OpTypeMatrix %b 7"); + EXPECT_EQ(original, CompiledInstructions("%a = OpTypeMatrix !2 !7")); + EXPECT_EQ(original, CompiledInstructions("!0x00040018 %a !2 !7")); +} + +TEST_F(ImmediateIntTest, InvalidStatement) { + EXPECT_THAT(Subvector(CompileSuccessfully("!4 !3 !2 !1"), kFirstInstruction), + ElementsAre(4, 3, 2, 1)); +} + +TEST_F(ImmediateIntTest, InvalidStatementBetweenValidOnes) { + EXPECT_THAT(Subvector(CompileSuccessfully( + "%10 = OpTypeFloat 32 !5 !6 !7 OpEmitVertex"), + kFirstInstruction), + ElementsAre(spvOpcodeMake(3, SpvOpTypeFloat), 1, 32, 5, 6, 7, + spvOpcodeMake(1, SpvOpEmitVertex))); +} + +TEST_F(ImmediateIntTest, NextOpcodeRecognized) { + const SpirvVector original = CompileSuccessfully(R"( +%1 = OpLoad %10 %2 Volatile +%4 = OpCompositeInsert %11 %1 %3 0 1 2 +)"); + const SpirvVector alternate = CompileSuccessfully(R"( +%1 = OpLoad %10 %2 !1 +%4 = OpCompositeInsert %11 %1 %3 0 1 2 +)"); + EXPECT_EQ(original, alternate); +} + +TEST_F(ImmediateIntTest, WrongLengthButNextOpcodeStillRecognized) { + const SpirvVector original = CompileSuccessfully(R"( +%1 = OpLoad %10 %2 Volatile +OpCopyMemorySized %3 %4 %1 +)"); + const SpirvVector alternate = CompileSuccessfully(R"( +!0x0002003D %10 %1 %2 !1 +OpCopyMemorySized %3 %4 %1 +)"); + EXPECT_EQ(0x0002003Du, alternate[kFirstInstruction]); + EXPECT_EQ(Subvector(original, kFirstInstruction + 1), + Subvector(alternate, kFirstInstruction + 1)); +} + +// Like NextOpcodeRecognized, but next statement is in assignment form. 
+TEST_F(ImmediateIntTest, NextAssignmentRecognized) { + const SpirvVector original = CompileSuccessfully(R"( +%1 = OpLoad %10 %2 None +%4 = OpFunctionCall %10 %3 %123 +)"); + const SpirvVector alternate = CompileSuccessfully(R"( +%1 = OpLoad %10 %2 !0 +%4 = OpFunctionCall %10 %3 %123 +)"); + EXPECT_EQ(original, alternate); +} + +// Two instructions in a row each have ! opcode. +TEST_F(ImmediateIntTest, ConsecutiveImmediateOpcodes) { + const SpirvVector original = CompileSuccessfully(R"( +%1 = OpConstantSampler %10 Clamp 78 Linear +%4 = OpFRem %11 %3 %2 +%5 = OpIsValidEvent %12 %2 +)"); + const SpirvVector alternate = CompileSuccessfully(R"( +!0x0006002D %10 %1 !2 78 !1 +!0x0005008C %11 %4 %3 %2 +%5 = OpIsValidEvent %12 %2 +)"); + EXPECT_EQ(original, alternate); +} + +// ! followed by, eg, an enum or '=' or a random bareword. +TEST_F(ImmediateIntTest, ForbiddenOperands) { + EXPECT_THAT(CompileFailure("OpMemoryModel !0 OpenCL"), HasSubstr("OpenCL")); + EXPECT_THAT(CompileFailure("!1 %0 = !2"), HasSubstr("=")); + EXPECT_THAT(CompileFailure("OpMemoryModel !0 random_bareword"), + HasSubstr("random_bareword")); + // Immediate integers longer than one 32-bit word. + EXPECT_THAT(CompileFailure("!5000000000"), HasSubstr("5000000000")); + EXPECT_THAT(CompileFailure("!999999999999999999"), + HasSubstr("999999999999999999")); + EXPECT_THAT(CompileFailure("!0x00020049 !5000000000"), + HasSubstr("5000000000")); + // Negative numbers. 
+ EXPECT_THAT(CompileFailure("!0x00020049 !-123"), HasSubstr("-123")); +} + +TEST_F(ImmediateIntTest, NotInteger) { + EXPECT_THAT(CompileFailure("!abc"), StrEq("Invalid immediate integer: !abc")); + EXPECT_THAT(CompileFailure("!12.3"), + StrEq("Invalid immediate integer: !12.3")); + EXPECT_THAT(CompileFailure("!12K"), StrEq("Invalid immediate integer: !12K")); +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/test/libspirv_macros_test.cpp b/test/libspirv_macros_test.cpp new file mode 100644 index 000000000..bf5add671 --- /dev/null +++ b/test/libspirv_macros_test.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +TEST(Macros, BitShiftInnerParens) { ASSERT_EQ(65536, SPV_BIT(2 << 3)); } + +TEST(Macros, BitShiftOuterParens) { ASSERT_EQ(15, SPV_BIT(4) - 1); } + +} // namespace +} // namespace spvtools diff --git a/test/link/CMakeLists.txt b/test/link/CMakeLists.txt new file mode 100644 index 000000000..06aeb9164 --- /dev/null +++ b/test/link/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017 Pierre Moreau +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +add_spvtools_unittest(TARGET link + SRCS + binary_version_test.cpp + entry_points_test.cpp + global_values_amount_test.cpp + ids_limit_test.cpp + matching_imports_to_exports_test.cpp + memory_model_test.cpp + partial_linkage_test.cpp + unique_ids_test.cpp + LIBS SPIRV-Tools-opt SPIRV-Tools-link +) diff --git a/test/link/binary_version_test.cpp b/test/link/binary_version_test.cpp new file mode 100644 index 000000000..0ceeebae2 --- /dev/null +++ b/test/link/binary_version_test.cpp @@ -0,0 +1,60 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { +namespace { + +using BinaryVersion = spvtest::LinkerTest; + +TEST_F(BinaryVersion, LinkerChoosesMaxSpirvVersion) { + // clang-format off + spvtest::Binaries binaries = { + { + SpvMagicNumber, + 0x00000300u, + SPV_GENERATOR_CODEPLAY, + 1u, // NOTE: Bound + 0u // NOTE: Schema; reserved + }, + { + SpvMagicNumber, + 0x00000600u, + SPV_GENERATOR_CODEPLAY, + 1u, // NOTE: Bound + 0u // NOTE: Schema; reserved + }, + { + SpvMagicNumber, + 0x00000100u, + SPV_GENERATOR_CODEPLAY, + 1u, // NOTE: Bound + 0u // NOTE: Schema; reserved + } + }; + // clang-format on + spvtest::Binary linked_binary; + + ASSERT_EQ(SPV_SUCCESS, Link(binaries, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), std::string()); + + EXPECT_EQ(0x00000600u, linked_binary[1]); +} + +} // namespace +} // namespace spvtools diff --git a/test/link/entry_points_test.cpp b/test/link/entry_points_test.cpp new file mode 100644 index 000000000..bac8e02ef --- /dev/null +++ b/test/link/entry_points_test.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { +namespace { + +using ::testing::HasSubstr; + +class EntryPoints : public spvtest::LinkerTest {}; + +TEST_F(EntryPoints, SameModelDifferentName) { + const std::string body1 = R"( +OpEntryPoint GLCompute %3 "foo" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +OpFunctionEnd +)"; + const std::string body2 = R"( +OpEntryPoint GLCompute %3 "bar" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +OpFunctionEnd +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body1, body2}, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), std::string()); +} + +TEST_F(EntryPoints, DifferentModelSameName) { + const std::string body1 = R"( +OpEntryPoint GLCompute %3 "foo" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +OpFunctionEnd +)"; + const std::string body2 = R"( +OpEntryPoint Vertex %3 "foo" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +OpFunctionEnd +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body1, body2}, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), std::string()); +} + +TEST_F(EntryPoints, SameModelAndName) { + const std::string body1 = R"( +OpEntryPoint GLCompute %3 "foo" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +OpFunctionEnd +)"; + const std::string body2 = R"( +OpEntryPoint GLCompute %3 "foo" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +OpFunctionEnd +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INTERNAL, + AssembleAndLink({body1, body2}, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("The entry point \"foo\", with execution model " + "GLCompute, was already defined.")); +} + +} // namespace +} // namespace spvtools diff --git a/test/link/global_values_amount_test.cpp b/test/link/global_values_amount_test.cpp new file mode 
100644 index 000000000..2c4ee1f03 --- /dev/null +++ b/test/link/global_values_amount_test.cpp @@ -0,0 +1,153 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { +namespace { + +using ::testing::HasSubstr; + +class EntryPointsAmountTest : public spvtest::LinkerTest { + public: + EntryPointsAmountTest() { binaries.reserve(0xFFFF); } + + void SetUp() override { + binaries.push_back({SpvMagicNumber, + SpvVersion, + SPV_GENERATOR_CODEPLAY, + 10u, // NOTE: Bound + 0u, // NOTE: Schema; reserved + + 3u << SpvWordCountShift | SpvOpTypeFloat, + 1u, // NOTE: Result ID + 32u, // NOTE: Width + + 4u << SpvWordCountShift | SpvOpTypePointer, + 2u, // NOTE: Result ID + SpvStorageClassInput, + 1u, // NOTE: Type ID + + 2u << SpvWordCountShift | SpvOpTypeVoid, + 3u, // NOTE: Result ID + + 3u << SpvWordCountShift | SpvOpTypeFunction, + 4u, // NOTE: Result ID + 3u, // NOTE: Return type + + 5u << SpvWordCountShift | SpvOpFunction, + 3u, // NOTE: Result type + 5u, // NOTE: Result ID + SpvFunctionControlMaskNone, + 4u, // NOTE: Function type + + 2u << SpvWordCountShift | SpvOpLabel, + 6u, // NOTE: Result ID + + 4u << SpvWordCountShift | SpvOpVariable, + 2u, // NOTE: Type ID + 7u, // NOTE: Result ID + SpvStorageClassFunction, + + 4u << SpvWordCountShift | SpvOpVariable, + 2u, // NOTE: Type ID + 8u, // NOTE: Result ID + SpvStorageClassFunction, + 
+ 4u << SpvWordCountShift | SpvOpVariable, + 2u, // NOTE: Type ID + 9u, // NOTE: Result ID + SpvStorageClassFunction, + + 1u << SpvWordCountShift | SpvOpReturn, + + 1u << SpvWordCountShift | SpvOpFunctionEnd}); + for (size_t i = 0u; i < 2u; ++i) { + spvtest::Binary binary = { + SpvMagicNumber, + SpvVersion, + SPV_GENERATOR_CODEPLAY, + 103u, // NOTE: Bound + 0u, // NOTE: Schema; reserved + + 3u << SpvWordCountShift | SpvOpTypeFloat, + 1u, // NOTE: Result ID + 32u, // NOTE: Width + + 4u << SpvWordCountShift | SpvOpTypePointer, + 2u, // NOTE: Result ID + SpvStorageClassInput, + 1u // NOTE: Type ID + }; + + for (uint32_t j = 0u; j < 0xFFFFu / 2u; ++j) { + binary.push_back(4u << SpvWordCountShift | SpvOpVariable); + binary.push_back(2u); // NOTE: Type ID + binary.push_back(j + 3u); // NOTE: Result ID + binary.push_back(SpvStorageClassInput); + } + binaries.push_back(binary); + } + } + void TearDown() override { binaries.clear(); } + + spvtest::Binaries binaries; +}; + +TEST_F(EntryPointsAmountTest, UnderLimit) { + spvtest::Binary linked_binary; + + EXPECT_EQ(SPV_SUCCESS, Link(binaries, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), std::string()); +} + +TEST_F(EntryPointsAmountTest, OverLimit) { + binaries.push_back({SpvMagicNumber, + SpvVersion, + SPV_GENERATOR_CODEPLAY, + 5u, // NOTE: Bound + 0u, // NOTE: Schema; reserved + + 3u << SpvWordCountShift | SpvOpTypeFloat, + 1u, // NOTE: Result ID + 32u, // NOTE: Width + + 4u << SpvWordCountShift | SpvOpTypePointer, + 2u, // NOTE: Result ID + SpvStorageClassInput, + 1u, // NOTE: Type ID + + 4u << SpvWordCountShift | SpvOpVariable, + 2u, // NOTE: Type ID + 3u, // NOTE: Result ID + SpvStorageClassInput, + + 4u << SpvWordCountShift | SpvOpVariable, + 2u, // NOTE: Type ID + 4u, // NOTE: Result ID + SpvStorageClassInput}); + + spvtest::Binary linked_binary; + + EXPECT_EQ(SPV_ERROR_INTERNAL, Link(binaries, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("The limit of global values, 65535, was exceeded; " + 
"65536 global values were found.")); +} + +} // namespace +} // namespace spvtools diff --git a/test/link/ids_limit_test.cpp b/test/link/ids_limit_test.cpp new file mode 100644 index 000000000..6d7815a24 --- /dev/null +++ b/test/link/ids_limit_test.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { +namespace { + +using ::testing::HasSubstr; +using IdsLimit = spvtest::LinkerTest; + +TEST_F(IdsLimit, UnderLimit) { + spvtest::Binaries binaries = { + { + SpvMagicNumber, SpvVersion, SPV_GENERATOR_CODEPLAY, + 0x2FFFFFu, // NOTE: Bound + 0u, // NOTE: Schema; reserved + }, + { + SpvMagicNumber, SpvVersion, SPV_GENERATOR_CODEPLAY, + 0x100000u, // NOTE: Bound + 0u, // NOTE: Schema; reserved + }}; + spvtest::Binary linked_binary; + + ASSERT_EQ(SPV_SUCCESS, Link(binaries, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), std::string()); + EXPECT_EQ(0x3FFFFEu, linked_binary[3]); +} + +TEST_F(IdsLimit, OverLimit) { + spvtest::Binaries binaries = { + { + SpvMagicNumber, SpvVersion, SPV_GENERATOR_CODEPLAY, + 0x2FFFFFu, // NOTE: Bound + 0u, // NOTE: Schema; reserved + }, + { + SpvMagicNumber, SpvVersion, SPV_GENERATOR_CODEPLAY, + 0x100000u, // NOTE: Bound + 0u, // NOTE: Schema; reserved + }, + { + SpvMagicNumber, SpvVersion, SPV_GENERATOR_CODEPLAY, + 3u, // NOTE: Bound + 0u, // NOTE: Schema; reserved + }}; + + 
spvtest::Binary linked_binary; + + EXPECT_EQ(SPV_ERROR_INVALID_ID, Link(binaries, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("The limit of IDs, 4194303, was exceeded: 4194304 is " + "the current ID bound.")); +} + +} // namespace +} // namespace spvtools diff --git a/test/link/linker_fixture.h b/test/link/linker_fixture.h new file mode 100644 index 000000000..303f1bfd5 --- /dev/null +++ b/test/link/linker_fixture.h @@ -0,0 +1,125 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef TEST_LINK_LINKER_FIXTURE_H_ +#define TEST_LINK_LINKER_FIXTURE_H_ + +#include +#include +#include + +#include "source/spirv_constant.h" +#include "spirv-tools/linker.hpp" +#include "test/unit_spirv.h" + +namespace spvtest { + +using Binary = std::vector; +using Binaries = std::vector; + +class LinkerTest : public ::testing::Test { + public: + LinkerTest() + : context_(SPV_ENV_UNIVERSAL_1_2), + tools_(SPV_ENV_UNIVERSAL_1_2), + assemble_options_(spvtools::SpirvTools::kDefaultAssembleOption), + disassemble_options_(spvtools::SpirvTools::kDefaultDisassembleOption) { + const auto consumer = [this](spv_message_level_t level, const char*, + const spv_position_t& position, + const char* message) { + if (!error_message_.empty()) error_message_ += "\n"; + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + error_message_ += "ERROR"; + break; + case SPV_MSG_WARNING: + error_message_ += "WARNING"; + break; + case SPV_MSG_INFO: + error_message_ += "INFO"; + break; + case SPV_MSG_DEBUG: + error_message_ += "DEBUG"; + break; + } + error_message_ += ": " + std::to_string(position.index) + ": " + message; + }; + context_.SetMessageConsumer(consumer); + tools_.SetMessageConsumer(consumer); + } + + void TearDown() override { error_message_.clear(); } + + // Assembles each of the given strings into SPIR-V binaries before linking + // them together. SPV_ERROR_INVALID_TEXT is returned if the assembling failed + // for any of the input strings, and SPV_ERROR_INVALID_POINTER if + // |linked_binary| is a null pointer. 
+ spv_result_t AssembleAndLink( + const std::vector& bodies, spvtest::Binary* linked_binary, + spvtools::LinkerOptions options = spvtools::LinkerOptions()) { + if (!linked_binary) return SPV_ERROR_INVALID_POINTER; + + spvtest::Binaries binaries(bodies.size()); + for (size_t i = 0u; i < bodies.size(); ++i) + if (!tools_.Assemble(bodies[i], binaries.data() + i, assemble_options_)) + return SPV_ERROR_INVALID_TEXT; + + return spvtools::Link(context_, binaries, linked_binary, options); + } + + // Links the given SPIR-V binaries together; SPV_ERROR_INVALID_POINTER is + // returned if |linked_binary| is a null pointer. + spv_result_t Link( + const spvtest::Binaries& binaries, spvtest::Binary* linked_binary, + spvtools::LinkerOptions options = spvtools::LinkerOptions()) { + if (!linked_binary) return SPV_ERROR_INVALID_POINTER; + return spvtools::Link(context_, binaries, linked_binary, options); + } + + // Disassembles |binary| and outputs the result in |text|. If |text| is a + // null pointer, SPV_ERROR_INVALID_POINTER is returned. + spv_result_t Disassemble(const spvtest::Binary& binary, std::string* text) { + if (!text) return SPV_ERROR_INVALID_POINTER; + return tools_.Disassemble(binary, text, disassemble_options_) + ? SPV_SUCCESS + : SPV_ERROR_INVALID_BINARY; + } + + // Sets the options for the assembler. + void SetAssembleOptions(uint32_t assemble_options) { + assemble_options_ = assemble_options; + } + + // Sets the options used by the disassembler. + void SetDisassembleOptions(uint32_t disassemble_options) { + disassemble_options_ = disassemble_options; + } + + // Returns the accumulated error messages for the test. + std::string GetErrorMessage() const { return error_message_; } + + private: + spvtools::Context context_; + spvtools::SpirvTools + tools_; // An instance for calling SPIRV-Tools functionalities. 
+ uint32_t assemble_options_; + uint32_t disassemble_options_; + std::string error_message_; +}; + +} // namespace spvtest + +#endif // TEST_LINK_LINKER_FIXTURE_H_ diff --git a/test/link/matching_imports_to_exports_test.cpp b/test/link/matching_imports_to_exports_test.cpp new file mode 100644 index 000000000..59e62d51b --- /dev/null +++ b/test/link/matching_imports_to_exports_test.cpp @@ -0,0 +1,403 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { +namespace { + +using ::testing::HasSubstr; +using MatchingImportsToExports = spvtest::LinkerTest; + +TEST_F(MatchingImportsToExports, Default) { + const std::string body1 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +%3 = OpVariable %2 Input +)"; + const std::string body2 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeFloat 32 +%3 = OpConstant %2 42 +%1 = OpVariable %2 Uniform %3 +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + + const std::string expected_res = + R"(OpModuleProcessed "Linked by SPIR-V Tools Linker" +%1 = OpTypeFloat 32 +%2 = OpVariable %1 Input +%3 = OpConstant %1 42 +%4 = OpVariable %1 Uniform %3 +)"; + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + EXPECT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); +} + +TEST_F(MatchingImportsToExports, NotALibraryExtraExports) { + const std::string body = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body}, &linked_binary)) + << GetErrorMessage(); + + const std::string expected_res = + R"(OpModuleProcessed "Linked by SPIR-V Tools Linker" +%1 = OpTypeFloat 32 +%2 = OpVariable %1 Uniform +)"; + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + EXPECT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); +} + +TEST_F(MatchingImportsToExports, LibraryExtraExports) { + const std::string body = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes 
"foo" Export +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +)"; + + spvtest::Binary linked_binary; + LinkerOptions options; + options.SetCreateLibrary(true); + EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body}, &linked_binary, options)) + << GetErrorMessage(); + + const std::string expected_res = R"(OpCapability Linkage +OpModuleProcessed "Linked by SPIR-V Tools Linker" +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +)"; + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + EXPECT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); +} + +TEST_F(MatchingImportsToExports, UnresolvedImports) { + const std::string body1 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +)"; + const std::string body2 = R"()"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("Unresolved external reference to \"foo\".")); +} + +TEST_F(MatchingImportsToExports, TypeMismatch) { + const std::string body1 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +%3 = OpVariable %2 Input +)"; + const std::string body2 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 42 +%1 = OpVariable %2 Uniform %3 +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Type mismatch on symbol \"foo\" between imported " + "variable/function %1 and exported variable/function %4")); +} + +TEST_F(MatchingImportsToExports, MultipleDefinitions) { + const std::string body1 = R"( +OpCapability Linkage 
+OpDecorate %1 LinkageAttributes "foo" Import +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +%3 = OpVariable %2 Input +)"; + const std::string body2 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeFloat 32 +%3 = OpConstant %2 42 +%1 = OpVariable %2 Uniform %3 +)"; + const std::string body3 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeFloat 32 +%3 = OpConstant %2 -1 +%1 = OpVariable %2 Uniform %3 +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2, body3}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("Too many external references, 2, were found " + "for \"foo\".")); +} + +TEST_F(MatchingImportsToExports, SameNameDifferentTypes) { + const std::string body1 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +%3 = OpVariable %2 Input +)"; + const std::string body2 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 42 +%1 = OpVariable %2 Uniform %3 +)"; + const std::string body3 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeFloat 32 +%3 = OpConstant %2 12 +%1 = OpVariable %2 Uniform %3 +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2, body3}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("Too many external references, 2, were found " + "for \"foo\".")); +} + +TEST_F(MatchingImportsToExports, DecorationMismatch) { + const std::string body1 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +OpDecorate %2 Constant +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +%3 = OpVariable %2 Input +)"; + const std::string body2 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = 
OpTypeFloat 32 +%3 = OpConstant %2 42 +%1 = OpVariable %2 Uniform %3 +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Type mismatch on symbol \"foo\" between imported " + "variable/function %1 and exported variable/function %4")); +} + +TEST_F(MatchingImportsToExports, + FuncParamAttrDifferButStillMatchExportToImport) { + const std::string body1 = R"( +OpCapability Kernel +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +OpDecorate %2 FuncParamAttr Zext +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%5 = OpTypeFunction %3 %4 +%1 = OpFunction %3 None %5 +%2 = OpFunctionParameter %4 +OpFunctionEnd +)"; + const std::string body2 = R"( +OpCapability Kernel +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +OpDecorate %2 FuncParamAttr Sext +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%5 = OpTypeFunction %3 %4 +%1 = OpFunction %3 None %5 +%2 = OpFunctionParameter %4 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + + const std::string expected_res = R"(OpCapability Kernel +OpModuleProcessed "Linked by SPIR-V Tools Linker" +OpDecorate %1 FuncParamAttr Sext +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeFunction %2 %3 +%5 = OpFunction %2 None %4 +%1 = OpFunctionParameter %3 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + EXPECT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); +} + +TEST_F(MatchingImportsToExports, FunctionCtrl) { + const std::string body1 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%4 = OpTypeFloat 32 +%5 = OpVariable %4 Uniform +%1 = 
OpFunction %2 None %3 +OpFunctionEnd +)"; + const std::string body2 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 Inline %3 +%4 = OpLabel +OpReturn +OpFunctionEnd +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + + const std::string expected_res = + R"(OpModuleProcessed "Linked by SPIR-V Tools Linker" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeFloat 32 +%4 = OpVariable %3 Uniform +%5 = OpFunction %1 Inline %2 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + EXPECT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); +} + +TEST_F(MatchingImportsToExports, UseExportedFuncParamAttr) { + const std::string body1 = R"( +OpCapability Kernel +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +OpDecorate %2 FuncParamAttr Zext +%2 = OpDecorationGroup +OpGroupDecorate %2 %3 %4 +%5 = OpTypeVoid +%6 = OpTypeInt 32 0 +%7 = OpTypeFunction %5 %6 +%1 = OpFunction %5 None %7 +%3 = OpFunctionParameter %6 +OpFunctionEnd +%8 = OpFunction %5 None %7 +%4 = OpFunctionParameter %6 +OpFunctionEnd +)"; + const std::string body2 = R"( +OpCapability Kernel +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Export +OpDecorate %2 FuncParamAttr Sext +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%5 = OpTypeFunction %3 %4 +%1 = OpFunction %3 None %5 +%2 = OpFunctionParameter %4 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_SUCCESS, AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + + const std::string expected_res = R"(OpCapability Kernel +OpModuleProcessed "Linked by SPIR-V Tools Linker" +OpDecorate %1 FuncParamAttr Zext +%1 = OpDecorationGroup +OpGroupDecorate %1 %2 
+OpDecorate %3 FuncParamAttr Sext +%4 = OpTypeVoid +%5 = OpTypeInt 32 0 +%6 = OpTypeFunction %4 %5 +%7 = OpFunction %4 None %6 +%2 = OpFunctionParameter %5 +OpFunctionEnd +%8 = OpFunction %4 None %6 +%3 = OpFunctionParameter %5 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + EXPECT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); +} + +} // namespace +} // namespace spvtools diff --git a/test/link/memory_model_test.cpp b/test/link/memory_model_test.cpp new file mode 100644 index 000000000..2add5046c --- /dev/null +++ b/test/link/memory_model_test.cpp @@ -0,0 +1,74 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { +namespace { + +using ::testing::HasSubstr; +using MemoryModel = spvtest::LinkerTest; + +TEST_F(MemoryModel, Default) { + const std::string body1 = R"( +OpMemoryModel Logical Simple +)"; + const std::string body2 = R"( +OpMemoryModel Logical Simple +)"; + + spvtest::Binary linked_binary; + ASSERT_EQ(SPV_SUCCESS, AssembleAndLink({body1, body2}, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), std::string()); + + EXPECT_EQ(SpvAddressingModelLogical, linked_binary[6]); + EXPECT_EQ(SpvMemoryModelSimple, linked_binary[7]); +} + +TEST_F(MemoryModel, AddressingMismatch) { + const std::string body1 = R"( +OpMemoryModel Logical Simple +)"; + const std::string body2 = R"( +OpMemoryModel Physical32 Simple +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INTERNAL, + AssembleAndLink({body1, body2}, &linked_binary)); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Conflicting addressing models: Logical vs Physical32.")); +} + +TEST_F(MemoryModel, MemoryMismatch) { + const std::string body1 = R"( +OpMemoryModel Logical Simple +)"; + const std::string body2 = R"( +OpMemoryModel Logical GLSL450 +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INTERNAL, + AssembleAndLink({body1, body2}, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("Conflicting memory models: Simple vs GLSL450.")); +} + +} // namespace +} // namespace spvtools diff --git a/test/link/partial_linkage_test.cpp b/test/link/partial_linkage_test.cpp new file mode 100644 index 000000000..c43b06e55 --- /dev/null +++ b/test/link/partial_linkage_test.cpp @@ -0,0 +1,89 @@ +// Copyright (c) 2018 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { +namespace { + +using ::testing::HasSubstr; +using PartialLinkage = spvtest::LinkerTest; + +TEST_F(PartialLinkage, Allowed) { + const std::string body1 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +OpDecorate %2 LinkageAttributes "bar" Import +%3 = OpTypeFloat 32 +%1 = OpVariable %3 Uniform +%2 = OpVariable %3 Uniform +)"; + const std::string body2 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "bar" Export +%2 = OpTypeFloat 32 +%3 = OpConstant %2 3.1415 +%1 = OpVariable %2 Uniform %3 +)"; + + spvtest::Binary linked_binary; + LinkerOptions linker_options; + linker_options.SetAllowPartialLinkage(true); + ASSERT_EQ(SPV_SUCCESS, + AssembleAndLink({body1, body2}, &linked_binary, linker_options)); + + const std::string expected_res = R"(OpCapability Linkage +OpModuleProcessed "Linked by SPIR-V Tools Linker" +OpDecorate %1 LinkageAttributes "foo" Import +%2 = OpTypeFloat 32 +%1 = OpVariable %2 Uniform +%3 = OpConstant %2 3.1415 +%4 = OpVariable %2 Uniform %3 +)"; + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + ASSERT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); +} + +TEST_F(PartialLinkage, Disallowed) { + const std::string body1 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "foo" Import +OpDecorate %2 LinkageAttributes "bar" Import +%3 = OpTypeFloat 32 +%1 = OpVariable %3 Uniform +%2 = 
OpVariable %3 Uniform +)"; + const std::string body2 = R"( +OpCapability Linkage +OpDecorate %1 LinkageAttributes "bar" Export +%2 = OpTypeFloat 32 +%3 = OpConstant %2 3.1415 +%1 = OpVariable %2 Uniform %3 +)"; + + spvtest::Binary linked_binary; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("Unresolved external reference to \"foo\".")); +} + +} // namespace +} // namespace spvtools diff --git a/test/link/unique_ids_test.cpp b/test/link/unique_ids_test.cpp new file mode 100644 index 000000000..55c70ea67 --- /dev/null +++ b/test/link/unique_ids_test.cpp @@ -0,0 +1,142 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/link/linker_fixture.h" + +namespace spvtools { +namespace { + +using UniqueIds = spvtest::LinkerTest; + +TEST_F(UniqueIds, UniquelyMerged) { + std::vector bodies(2); + bodies[0] = + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "OpSource ESSL 310\n" + "OpName %main \"main\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv1\"\n" + "OpName %gv2 \"gv2\"\n" + "OpName %lv1 \"lv1\"\n" + "OpName %lv2 \"lv2\"\n" + "OpName %lv1_0 \"lv1\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + bodies[1] = + // clang-format off + "OpCapability 
Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpSource ESSL 310\n" + "OpName %main \"main2\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv12\"\n" + "OpName %gv2 \"gv22\"\n" + "OpName %lv1 \"lv12\"\n" + "OpName %lv2 \"lv22\"\n" + "OpName %lv1_0 \"lv12\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + + spvtest::Binary linked_binary; + LinkerOptions options; + options.SetVerifyIds(true); + spv_result_t res = AssembleAndLink(bodies, &linked_binary, options); + EXPECT_EQ(SPV_SUCCESS, res); +} + +} // namespace +} // namespace spvtools diff --git a/test/log_test.cpp b/test/log_test.cpp new file mode 100644 index 000000000..ec66aa1ec --- /dev/null +++ b/test/log_test.cpp 
@@ -0,0 +1,53 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/log.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace spvtools { +namespace { + +using ::testing::MatchesRegex; + +TEST(Log, Unimplemented) { + int invocation = 0; + auto consumer = [&invocation](spv_message_level_t level, const char* source, + const spv_position_t&, const char* message) { + ++invocation; + EXPECT_EQ(SPV_MSG_INTERNAL_ERROR, level); + EXPECT_THAT(source, MatchesRegex(".*log_test.cpp$")); + EXPECT_STREQ("unimplemented: the-ultimite-feature", message); + }; + + SPIRV_UNIMPLEMENTED(consumer, "the-ultimite-feature"); + EXPECT_EQ(1, invocation); +} + +TEST(Log, Unreachable) { + int invocation = 0; + auto consumer = [&invocation](spv_message_level_t level, const char* source, + const spv_position_t&, const char* message) { + ++invocation; + EXPECT_EQ(SPV_MSG_INTERNAL_ERROR, level); + EXPECT_THAT(source, MatchesRegex(".*log_test.cpp$")); + EXPECT_STREQ("unreachable", message); + }; + + SPIRV_UNREACHABLE(consumer); + EXPECT_EQ(1, invocation); +} + +} // namespace +} // namespace spvtools diff --git a/test/move_to_front_test.cpp b/test/move_to_front_test.cpp new file mode 100644 index 000000000..c95d38656 --- /dev/null +++ b/test/move_to_front_test.cpp @@ -0,0 +1,828 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/comp/move_to_front.h" + +namespace spvtools { +namespace comp { +namespace { + +// Class used to test the inner workings of MoveToFront. +class MoveToFrontTester : public MoveToFront { + public: + // Inserts the value in the internal tree data structure. For testing only. + void TestInsert(uint32_t val) { InsertNode(CreateNode(val, val)); } + + // Removes the value from the internal tree data structure. For testing only. + void TestRemove(uint32_t val) { + const auto it = value_to_node_.find(val); + assert(it != value_to_node_.end()); + RemoveNode(it->second); + } + + // Prints the internal tree data structure to |out|. For testing only. + void PrintTree(std::ostream& out, bool print_timestamp = false) const { + if (root_) PrintTreeInternal(out, root_, 1, print_timestamp); + } + + // Returns node handle corresponding to the value. The value may not be in the + // tree. + uint32_t GetNodeHandle(uint32_t value) const { + const auto it = value_to_node_.find(value); + if (it == value_to_node_.end()) return 0; + + return it->second; + } + + // Returns total node count (both those in the tree and removed, + // but not the NIL singleton). 
+ size_t GetTotalNodeCount() const { + assert(nodes_.size()); + return nodes_.size() - 1; + } + + uint32_t GetLastAccessedValue() const { return last_accessed_value_; } + + private: + // Prints the internal tree data structure for debug purposes in the following + // format: + // 10H3S4----5H1S1-----D2 + // 15H2S2----12H1S1----D3 + // Right links are horizontal, left links step down one line. + // 5H1S1 is read as value 5, height 1, size 1. Optionally node label can also + // contain timestamp (5H1S1T15). D3 stands for depth 3. + void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth, + bool print_timestamp) const; +}; + +void MoveToFrontTester::PrintTreeInternal(std::ostream& out, uint32_t node, + size_t depth, + bool print_timestamp) const { + if (!node) { + out << "D" << depth - 1 << std::endl; + return; + } + + const size_t kTextFieldWvaluethWithoutTimestamp = 10; + const size_t kTextFieldWvaluethWithTimestamp = 14; + const size_t text_field_wvalueth = print_timestamp + ? 
kTextFieldWvaluethWithTimestamp + : kTextFieldWvaluethWithoutTimestamp; + + std::stringstream label; + label << ValueOf(node) << "H" << HeightOf(node) << "S" << SizeOf(node); + if (print_timestamp) label << "T" << TimestampOf(node); + const size_t label_length = label.str().length(); + if (label_length < text_field_wvalueth) + label << std::string(text_field_wvalueth - label_length, '-'); + + out << label.str(); + + PrintTreeInternal(out, RightOf(node), depth + 1, print_timestamp); + + if (LeftOf(node)) { + out << std::string(depth * text_field_wvalueth, ' '); + PrintTreeInternal(out, LeftOf(node), depth + 1, print_timestamp); + } +} + +void CheckTree(const MoveToFrontTester& mtf, const std::string& expected, + bool print_timestamp = false) { + std::stringstream ss; + mtf.PrintTree(ss, print_timestamp); + EXPECT_EQ(expected, ss.str()); +} + +TEST(MoveToFront, EmptyTree) { + MoveToFrontTester mtf; + CheckTree(mtf, std::string()); +} + +TEST(MoveToFront, InsertLeftRotation) { + MoveToFrontTester mtf; + + mtf.TestInsert(30); + mtf.TestInsert(20); + + CheckTree(mtf, std::string(R"( +30H2S2----20H1S1----D2 +)") + .substr(1)); + + mtf.TestInsert(10); + CheckTree(mtf, std::string(R"( +20H2S3----10H1S1----D2 + 30H1S1----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, InsertRightRotation) { + MoveToFrontTester mtf; + + mtf.TestInsert(10); + mtf.TestInsert(20); + + CheckTree(mtf, std::string(R"( +10H2S2----D1 + 20H1S1----D2 +)") + .substr(1)); + + mtf.TestInsert(30); + CheckTree(mtf, std::string(R"( +20H2S3----10H1S1----D2 + 30H1S1----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, InsertRightLeftRotation) { + MoveToFrontTester mtf; + + mtf.TestInsert(30); + mtf.TestInsert(20); + + CheckTree(mtf, std::string(R"( +30H2S2----20H1S1----D2 +)") + .substr(1)); + + mtf.TestInsert(25); + CheckTree(mtf, std::string(R"( +25H2S3----20H1S1----D2 + 30H1S1----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, InsertLeftRightRotation) { + MoveToFrontTester mtf; + + mtf.TestInsert(10); + 
mtf.TestInsert(20); + + CheckTree(mtf, std::string(R"( +10H2S2----D1 + 20H1S1----D2 +)") + .substr(1)); + + mtf.TestInsert(15); + CheckTree(mtf, std::string(R"( +15H2S3----10H1S1----D2 + 20H1S1----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, RemoveSingleton) { + MoveToFrontTester mtf; + + mtf.TestInsert(10); + CheckTree(mtf, std::string(R"( +10H1S1----D1 +)") + .substr(1)); + + mtf.TestRemove(10); + CheckTree(mtf, ""); +} + +TEST(MoveToFront, RemoveRootWithScapegoat) { + MoveToFrontTester mtf; + + mtf.TestInsert(10); + mtf.TestInsert(5); + mtf.TestInsert(15); + CheckTree(mtf, std::string(R"( +10H2S3----5H1S1-----D2 + 15H1S1----D2 +)") + .substr(1)); + + mtf.TestRemove(10); + CheckTree(mtf, std::string(R"( +15H2S2----5H1S1-----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, RemoveRightRotation) { + MoveToFrontTester mtf; + + mtf.TestInsert(10); + mtf.TestInsert(5); + mtf.TestInsert(15); + mtf.TestInsert(20); + CheckTree(mtf, std::string(R"( +10H3S4----5H1S1-----D2 + 15H2S2----D2 + 20H1S1----D3 +)") + .substr(1)); + + mtf.TestRemove(5); + + CheckTree(mtf, std::string(R"( +15H2S3----10H1S1----D2 + 20H1S1----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, RemoveLeftRotation) { + MoveToFrontTester mtf; + + mtf.TestInsert(10); + mtf.TestInsert(15); + mtf.TestInsert(5); + mtf.TestInsert(1); + CheckTree(mtf, std::string(R"( +10H3S4----5H2S2-----1H1S1-----D3 + 15H1S1----D2 +)") + .substr(1)); + + mtf.TestRemove(15); + + CheckTree(mtf, std::string(R"( +5H2S3-----1H1S1-----D2 + 10H1S1----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, RemoveLeftRightRotation) { + MoveToFrontTester mtf; + + mtf.TestInsert(10); + mtf.TestInsert(15); + mtf.TestInsert(5); + mtf.TestInsert(12); + CheckTree(mtf, std::string(R"( +10H3S4----5H1S1-----D2 + 15H2S2----12H1S1----D3 +)") + .substr(1)); + + mtf.TestRemove(5); + + CheckTree(mtf, std::string(R"( +12H2S3----10H1S1----D2 + 15H1S1----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, RemoveRightLeftRotation) { + MoveToFrontTester mtf; + + 
mtf.TestInsert(10); + mtf.TestInsert(15); + mtf.TestInsert(5); + mtf.TestInsert(8); + CheckTree(mtf, std::string(R"( +10H3S4----5H2S2-----D2 + 8H1S1-----D3 + 15H1S1----D2 +)") + .substr(1)); + + mtf.TestRemove(15); + + CheckTree(mtf, std::string(R"( +8H2S3-----5H1S1-----D2 + 10H1S1----D2 +)") + .substr(1)); +} + +TEST(MoveToFront, MultipleOperations) { + MoveToFrontTester mtf; + std::vector vals = {5, 11, 12, 16, 15, 6, 14, 2, + 7, 10, 4, 8, 9, 3, 1, 13}; + + for (uint32_t i : vals) { + mtf.TestInsert(i); + } + + CheckTree(mtf, std::string(R"( +11H5S16---5H4S10----3H3S4-----2H2S2-----1H1S1-----D5 + 4H1S1-----D4 + 7H3S5-----6H1S1-----D4 + 9H2S3-----8H1S1-----D5 + 10H1S1----D5 + 15H3S5----13H2S3----12H1S1----D4 + 14H1S1----D4 + 16H1S1----D3 +)") + .substr(1)); + + mtf.TestRemove(11); + + CheckTree(mtf, std::string(R"( +10H5S15---5H4S9-----3H3S4-----2H2S2-----1H1S1-----D5 + 4H1S1-----D4 + 7H3S4-----6H1S1-----D4 + 9H2S2-----8H1S1-----D5 + 15H3S5----13H2S3----12H1S1----D4 + 14H1S1----D4 + 16H1S1----D3 +)") + .substr(1)); + + mtf.TestInsert(11); + + CheckTree(mtf, std::string(R"( +10H5S16---5H4S9-----3H3S4-----2H2S2-----1H1S1-----D5 + 4H1S1-----D4 + 7H3S4-----6H1S1-----D4 + 9H2S2-----8H1S1-----D5 + 13H3S6----12H2S2----11H1S1----D4 + 15H2S3----14H1S1----D4 + 16H1S1----D4 +)") + .substr(1)); + + mtf.TestRemove(5); + + CheckTree(mtf, std::string(R"( +10H5S15---6H4S8-----3H3S4-----2H2S2-----1H1S1-----D5 + 4H1S1-----D4 + 8H2S3-----7H1S1-----D4 + 9H1S1-----D4 + 13H3S6----12H2S2----11H1S1----D4 + 15H2S3----14H1S1----D4 + 16H1S1----D4 +)") + .substr(1)); + + mtf.TestInsert(5); + + CheckTree(mtf, std::string(R"( +10H5S16---6H4S9-----3H3S5-----2H2S2-----1H1S1-----D5 + 4H2S2-----D4 + 5H1S1-----D5 + 8H2S3-----7H1S1-----D4 + 9H1S1-----D4 + 13H3S6----12H2S2----11H1S1----D4 + 15H2S3----14H1S1----D4 + 16H1S1----D4 +)") + .substr(1)); + + mtf.TestRemove(2); + mtf.TestRemove(1); + mtf.TestRemove(4); + mtf.TestRemove(3); + mtf.TestRemove(6); + mtf.TestRemove(5); + mtf.TestRemove(7); + 
mtf.TestRemove(9); + + CheckTree(mtf, std::string(R"( +13H4S8----10H3S4----8H1S1-----D3 + 12H2S2----11H1S1----D4 + 15H2S3----14H1S1----D3 + 16H1S1----D3 +)") + .substr(1)); +} + +TEST(MoveToFront, BiggerScaleTreeTest) { + MoveToFrontTester mtf; + std::set all_vals; + + const uint32_t kMagic1 = 2654435761; + const uint32_t kMagic2 = 10000; + + for (uint32_t i = 1; i < 1000; ++i) { + const uint32_t val = (i * kMagic1) % kMagic2; + if (!all_vals.count(val)) { + mtf.TestInsert(val); + all_vals.insert(val); + } + } + + for (uint32_t i = 1; i < 1000; ++i) { + const uint32_t val = (i * kMagic1) % kMagic2; + if (val % 2 == 0) { + mtf.TestRemove(val); + all_vals.erase(val); + } + } + + for (uint32_t i = 1000; i < 2000; ++i) { + const uint32_t val = (i * kMagic1) % kMagic2; + if (!all_vals.count(val)) { + mtf.TestInsert(val); + all_vals.insert(val); + } + } + + for (uint32_t i = 1; i < 2000; ++i) { + const uint32_t val = (i * kMagic1) % kMagic2; + if (val > 50) { + mtf.TestRemove(val); + all_vals.erase(val); + } + } + + EXPECT_EQ(all_vals, std::set({2, 4, 11, 13, 24, 33, 35, 37, 46})); + + CheckTree(mtf, std::string(R"( +33H4S9----11H3S5----2H2S2-----D3 + 4H1S1-----D4 + 13H2S2----D3 + 24H1S1----D4 + 37H2S3----35H1S1----D3 + 46H1S1----D3 +)") + .substr(1)); +} + +TEST(MoveToFront, RankFromValue) { + MoveToFrontTester mtf; + + uint32_t rank = 0; + EXPECT_FALSE(mtf.RankFromValue(1, &rank)); + + EXPECT_TRUE(mtf.Insert(1)); + EXPECT_TRUE(mtf.Insert(2)); + EXPECT_TRUE(mtf.Insert(3)); + EXPECT_FALSE(mtf.Insert(2)); + CheckTree(mtf, + std::string(R"( +2H2S3T2-------1H1S1T1-------D2 + 3H1S1T3-------D2 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_FALSE(mtf.RankFromValue(4, &rank)); + + EXPECT_TRUE(mtf.RankFromValue(1, &rank)); + EXPECT_EQ(3u, rank); + + CheckTree(mtf, + std::string(R"( +3H2S3T3-------2H1S1T2-------D2 + 1H1S1T4-------D2 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_TRUE(mtf.RankFromValue(1, &rank)); + EXPECT_EQ(1u, rank); + + 
EXPECT_TRUE(mtf.RankFromValue(3, &rank)); + EXPECT_EQ(2u, rank); + + EXPECT_TRUE(mtf.RankFromValue(2, &rank)); + EXPECT_EQ(3u, rank); + + EXPECT_TRUE(mtf.Insert(40)); + + EXPECT_TRUE(mtf.RankFromValue(1, &rank)); + EXPECT_EQ(4u, rank); + + EXPECT_TRUE(mtf.Insert(50)); + + EXPECT_TRUE(mtf.RankFromValue(1, &rank)); + EXPECT_EQ(2u, rank); + + CheckTree(mtf, + std::string(R"( +2H3S5T6-------3H1S1T5-------D2 + 50H2S3T9------40H1S1T7------D3 + 1H1S1T10------D3 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_TRUE(mtf.RankFromValue(50, &rank)); + EXPECT_EQ(2u, rank); + + EXPECT_EQ(5u, mtf.GetSize()); + CheckTree(mtf, + std::string(R"( +2H3S5T6-------3H1S1T5-------D2 + 1H2S3T10------40H1S1T7------D3 + 50H1S1T11-----D3 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_FALSE(mtf.RankFromValue(0, &rank)); + EXPECT_FALSE(mtf.RankFromValue(20, &rank)); +} + +TEST(MoveToFront, ValueFromRank) { + MoveToFrontTester mtf; + + uint32_t value = 0; + EXPECT_FALSE(mtf.ValueFromRank(0, &value)); + EXPECT_FALSE(mtf.ValueFromRank(1, &value)); + + EXPECT_TRUE(mtf.Insert(1)); + EXPECT_EQ(1u, mtf.GetLastAccessedValue()); + EXPECT_TRUE(mtf.Insert(2)); + EXPECT_EQ(2u, mtf.GetLastAccessedValue()); + EXPECT_TRUE(mtf.Insert(3)); + EXPECT_EQ(3u, mtf.GetLastAccessedValue()); + + EXPECT_TRUE(mtf.ValueFromRank(3, &value)); + EXPECT_EQ(1u, value); + EXPECT_EQ(1u, mtf.GetLastAccessedValue()); + + EXPECT_TRUE(mtf.ValueFromRank(1, &value)); + EXPECT_EQ(1u, value); + EXPECT_EQ(1u, mtf.GetLastAccessedValue()); + + CheckTree(mtf, + std::string(R"( +3H2S3T3-------2H1S1T2-------D2 + 1H1S1T4-------D2 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_TRUE(mtf.ValueFromRank(2, &value)); + EXPECT_EQ(3u, value); + + EXPECT_EQ(3u, mtf.GetSize()); + + CheckTree(mtf, + std::string(R"( +1H2S3T4-------2H1S1T2-------D2 + 3H1S1T5-------D2 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_TRUE(mtf.ValueFromRank(3, &value)); + EXPECT_EQ(2u, value); + + CheckTree(mtf, + 
std::string(R"( +3H2S3T5-------1H1S1T4-------D2 + 2H1S1T6-------D2 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_TRUE(mtf.Insert(10)); + CheckTree(mtf, + std::string(R"( +3H3S4T5-------1H1S1T4-------D2 + 2H2S2T6-------D2 + 10H1S1T7------D3 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_TRUE(mtf.ValueFromRank(1, &value)); + EXPECT_EQ(10u, value); +} + +TEST(MoveToFront, Remove) { + MoveToFrontTester mtf; + + EXPECT_FALSE(mtf.Remove(1)); + EXPECT_EQ(0u, mtf.GetTotalNodeCount()); + + EXPECT_TRUE(mtf.Insert(1)); + EXPECT_TRUE(mtf.Insert(2)); + EXPECT_TRUE(mtf.Insert(3)); + + CheckTree(mtf, + std::string(R"( +2H2S3T2-------1H1S1T1-------D2 + 3H1S1T3-------D2 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_EQ(1u, mtf.GetNodeHandle(1)); + EXPECT_EQ(3u, mtf.GetTotalNodeCount()); + EXPECT_TRUE(mtf.Remove(1)); + EXPECT_EQ(3u, mtf.GetTotalNodeCount()); + + CheckTree(mtf, + std::string(R"( +2H2S2T2-------D1 + 3H1S1T3-------D2 +)") + .substr(1), + /* print_timestamp = */ true); + + uint32_t value = 0; + EXPECT_TRUE(mtf.ValueFromRank(2, &value)); + EXPECT_EQ(2u, value); + + CheckTree(mtf, + std::string(R"( +3H2S2T3-------D1 + 2H1S1T4-------D2 +)") + .substr(1), + /* print_timestamp = */ true); + + EXPECT_TRUE(mtf.Insert(1)); + EXPECT_EQ(1u, mtf.GetNodeHandle(1)); + EXPECT_EQ(3u, mtf.GetTotalNodeCount()); +} + +TEST(MoveToFront, LargerScale) { + MoveToFrontTester mtf; + uint32_t value = 0; + uint32_t rank = 0; + + for (uint32_t i = 1; i < 1000; ++i) { + ASSERT_TRUE(mtf.Insert(i)); + ASSERT_EQ(i, mtf.GetSize()); + + ASSERT_TRUE(mtf.RankFromValue(i, &rank)); + ASSERT_EQ(1u, rank); + + ASSERT_TRUE(mtf.ValueFromRank(1, &value)); + ASSERT_EQ(i, value); + } + + ASSERT_TRUE(mtf.ValueFromRank(999, &value)); + ASSERT_EQ(1u, value); + + ASSERT_TRUE(mtf.ValueFromRank(999, &value)); + ASSERT_EQ(2u, value); + + ASSERT_TRUE(mtf.ValueFromRank(999, &value)); + ASSERT_EQ(3u, value); + + ASSERT_TRUE(mtf.ValueFromRank(999, &value)); + ASSERT_EQ(4u, 
value); + + ASSERT_TRUE(mtf.ValueFromRank(999, &value)); + ASSERT_EQ(5u, value); + + ASSERT_TRUE(mtf.ValueFromRank(999, &value)); + ASSERT_EQ(6u, value); + + ASSERT_TRUE(mtf.ValueFromRank(101, &value)); + ASSERT_EQ(905u, value); + + ASSERT_TRUE(mtf.ValueFromRank(101, &value)); + ASSERT_EQ(906u, value); + + ASSERT_TRUE(mtf.ValueFromRank(101, &value)); + ASSERT_EQ(907u, value); + + ASSERT_TRUE(mtf.ValueFromRank(201, &value)); + ASSERT_EQ(805u, value); + + ASSERT_TRUE(mtf.ValueFromRank(201, &value)); + ASSERT_EQ(806u, value); + + ASSERT_TRUE(mtf.ValueFromRank(201, &value)); + ASSERT_EQ(807u, value); + + ASSERT_TRUE(mtf.ValueFromRank(301, &value)); + ASSERT_EQ(705u, value); + + ASSERT_TRUE(mtf.ValueFromRank(301, &value)); + ASSERT_EQ(706u, value); + + ASSERT_TRUE(mtf.ValueFromRank(301, &value)); + ASSERT_EQ(707u, value); + + ASSERT_TRUE(mtf.RankFromValue(605, &rank)); + ASSERT_EQ(401u, rank); + + ASSERT_TRUE(mtf.RankFromValue(606, &rank)); + ASSERT_EQ(401u, rank); + + ASSERT_TRUE(mtf.RankFromValue(607, &rank)); + ASSERT_EQ(401u, rank); + + ASSERT_TRUE(mtf.ValueFromRank(1, &value)); + ASSERT_EQ(607u, value); + + ASSERT_TRUE(mtf.ValueFromRank(2, &value)); + ASSERT_EQ(606u, value); + + ASSERT_TRUE(mtf.ValueFromRank(3, &value)); + ASSERT_EQ(605u, value); + + ASSERT_TRUE(mtf.ValueFromRank(4, &value)); + ASSERT_EQ(707u, value); + + ASSERT_TRUE(mtf.ValueFromRank(5, &value)); + ASSERT_EQ(706u, value); + + ASSERT_TRUE(mtf.ValueFromRank(6, &value)); + ASSERT_EQ(705u, value); + + ASSERT_TRUE(mtf.ValueFromRank(7, &value)); + ASSERT_EQ(807u, value); + + ASSERT_TRUE(mtf.ValueFromRank(8, &value)); + ASSERT_EQ(806u, value); + + ASSERT_TRUE(mtf.ValueFromRank(9, &value)); + ASSERT_EQ(805u, value); + + ASSERT_TRUE(mtf.ValueFromRank(10, &value)); + ASSERT_EQ(907u, value); + + ASSERT_TRUE(mtf.ValueFromRank(11, &value)); + ASSERT_EQ(906u, value); + + ASSERT_TRUE(mtf.ValueFromRank(12, &value)); + ASSERT_EQ(905u, value); + + ASSERT_TRUE(mtf.ValueFromRank(13, &value)); + ASSERT_EQ(6u, value); 
+ + ASSERT_TRUE(mtf.ValueFromRank(14, &value)); + ASSERT_EQ(5u, value); + + ASSERT_TRUE(mtf.ValueFromRank(15, &value)); + ASSERT_EQ(4u, value); + + ASSERT_TRUE(mtf.ValueFromRank(16, &value)); + ASSERT_EQ(3u, value); + + ASSERT_TRUE(mtf.ValueFromRank(17, &value)); + ASSERT_EQ(2u, value); + + ASSERT_TRUE(mtf.ValueFromRank(18, &value)); + ASSERT_EQ(1u, value); + + ASSERT_TRUE(mtf.ValueFromRank(19, &value)); + ASSERT_EQ(999u, value); + + ASSERT_TRUE(mtf.ValueFromRank(20, &value)); + ASSERT_EQ(998u, value); + + ASSERT_TRUE(mtf.ValueFromRank(21, &value)); + ASSERT_EQ(997u, value); + + ASSERT_TRUE(mtf.RankFromValue(997, &rank)); + ASSERT_EQ(1u, rank); + + ASSERT_TRUE(mtf.RankFromValue(998, &rank)); + ASSERT_EQ(2u, rank); + + ASSERT_TRUE(mtf.RankFromValue(996, &rank)); + ASSERT_EQ(22u, rank); + + ASSERT_TRUE(mtf.Remove(995)); + + ASSERT_TRUE(mtf.RankFromValue(994, &rank)); + ASSERT_EQ(23u, rank); + + for (uint32_t i = 10; i < 1000; ++i) { + if (i != 995) { + ASSERT_TRUE(mtf.Remove(i)); + } else { + ASSERT_FALSE(mtf.Remove(i)); + } + } + + CheckTree(mtf, + std::string(R"( +6H4S9T1029----8H2S3T8-------7H1S1T7-------D3 + 9H1S1T9-------D3 + 2H3S5T1033----4H2S3T1031----5H1S1T1030----D4 + 3H1S1T1032----D4 + 1H1S1T1034----D3 +)") + .substr(1), + /* print_timestamp = */ true); + + ASSERT_TRUE(mtf.Insert(1000)); + ASSERT_TRUE(mtf.ValueFromRank(1, &value)); + ASSERT_EQ(1000u, value); +} + +} // namespace +} // namespace comp +} // namespace spvtools diff --git a/test/name_mapper_test.cpp b/test/name_mapper_test.cpp new file mode 100644 index 000000000..9a9ee8aa0 --- /dev/null +++ b/test/name_mapper_test.cpp @@ -0,0 +1,347 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "source/name_mapper.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::ScopedContext; +using ::testing::Eq; + +TEST(TrivialNameTest, Samples) { + auto mapper = GetTrivialNameMapper(); + EXPECT_EQ(mapper(1), "1"); + EXPECT_EQ(mapper(1999), "1999"); + EXPECT_EQ(mapper(1024), "1024"); +} + +// A test case for the name mappers that actually look at an assembled module. +struct NameIdCase { + std::string assembly; // Input assembly text + uint32_t id; + std::string expected_name; +}; + +using FriendlyNameTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(FriendlyNameTest, SingleMapping) { + ScopedContext context(SPV_ENV_UNIVERSAL_1_1); + auto words = CompileSuccessfully(GetParam().assembly, SPV_ENV_UNIVERSAL_1_1); + auto friendly_mapper = + FriendlyNameMapper(context.context, words.data(), words.size()); + NameMapper mapper = friendly_mapper.GetNameMapper(); + EXPECT_THAT(mapper(GetParam().id), Eq(GetParam().expected_name)) + << GetParam().assembly << std::endl + << " for id " << GetParam().id; +} + +INSTANTIATE_TEST_CASE_P(ScalarType, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeVoid", 1, "void"}, + {"%1 = OpTypeBool", 1, "bool"}, + {"%1 = OpTypeInt 8 0", 1, "uchar"}, + {"%1 = OpTypeInt 8 1", 1, "char"}, + {"%1 = OpTypeInt 16 0", 1, "ushort"}, + {"%1 = OpTypeInt 16 1", 1, "short"}, + {"%1 = OpTypeInt 32 0", 1, "uint"}, + {"%1 = OpTypeInt 32 1", 1, "int"}, + {"%1 = OpTypeInt 64 
0", 1, "ulong"}, + {"%1 = OpTypeInt 64 1", 1, "long"}, + {"%1 = OpTypeInt 1 0", 1, "u1"}, + {"%1 = OpTypeInt 1 1", 1, "i1"}, + {"%1 = OpTypeInt 33 0", 1, "u33"}, + {"%1 = OpTypeInt 33 1", 1, "i33"}, + + {"%1 = OpTypeFloat 16", 1, "half"}, + {"%1 = OpTypeFloat 32", 1, "float"}, + {"%1 = OpTypeFloat 64", 1, "double"}, + {"%1 = OpTypeFloat 10", 1, "fp10"}, + {"%1 = OpTypeFloat 55", 1, "fp55"}, + }), ); + +INSTANTIATE_TEST_CASE_P( + VectorType, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeBool %2 = OpTypeVector %1 1", 2, "v1bool"}, + {"%1 = OpTypeBool %2 = OpTypeVector %1 2", 2, "v2bool"}, + {"%1 = OpTypeBool %2 = OpTypeVector %1 3", 2, "v3bool"}, + {"%1 = OpTypeBool %2 = OpTypeVector %1 4", 2, "v4bool"}, + + {"%1 = OpTypeInt 8 0 %2 = OpTypeVector %1 2", 2, "v2uchar"}, + {"%1 = OpTypeInt 16 1 %2 = OpTypeVector %1 3", 2, "v3short"}, + {"%1 = OpTypeInt 32 0 %2 = OpTypeVector %1 4", 2, "v4uint"}, + {"%1 = OpTypeInt 64 1 %2 = OpTypeVector %1 3", 2, "v3long"}, + {"%1 = OpTypeInt 20 0 %2 = OpTypeVector %1 4", 2, "v4u20"}, + {"%1 = OpTypeInt 21 1 %2 = OpTypeVector %1 3", 2, "v3i21"}, + + {"%1 = OpTypeFloat 32 %2 = OpTypeVector %1 2", 2, "v2float"}, + // OpName overrides the element name. 
+ {"OpName %1 \"time\" %1 = OpTypeFloat 32 %2 = OpTypeVector %1 2", 2, + "v2time"}, + }), ); + +INSTANTIATE_TEST_CASE_P( + MatrixType, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeBool %2 = OpTypeVector %1 2 %3 = OpTypeMatrix %2 2", 3, + "mat2v2bool"}, + {"%1 = OpTypeFloat 32 %2 = OpTypeVector %1 2 %3 = OpTypeMatrix %2 3", 3, + "mat3v2float"}, + {"%1 = OpTypeFloat 32 %2 = OpTypeVector %1 2 %3 = OpTypeMatrix %2 4", 3, + "mat4v2float"}, + {"OpName %1 \"time\" %1 = OpTypeFloat 32 %2 = OpTypeVector %1 2 %3 = " + "OpTypeMatrix %2 4", + 3, "mat4v2time"}, + {"OpName %2 \"lat_long\" %1 = OpTypeFloat 32 %2 = OpTypeVector %1 2 %3 " + "= OpTypeMatrix %2 4", + 3, "mat4lat_long"}, + }), ); + +INSTANTIATE_TEST_CASE_P( + OpName, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"OpName %1 \"abcdefg\"", 1, "abcdefg"}, + {"OpName %1 \"Hello world!\"", 1, "Hello_world_"}, + {"OpName %1 \"0123456789\"", 1, "0123456789"}, + {"OpName %1 \"_\"", 1, "_"}, + // An empty string is not valid for SPIR-V assembly IDs. + {"OpName %1 \"\"", 1, "_"}, + // Test uniqueness when presented with things mapping to "_" + {"OpName %1 \"\" OpName %2 \"\"", 1, "_"}, + {"OpName %1 \"\" OpName %2 \"\"", 2, "__0"}, + {"OpName %1 \"\" OpName %2 \"\" OpName %3 \"_\"", 3, "__1"}, + // Test uniqueness of names that are forced to be + // numbers. + {"OpName %1 \"2\" OpName %2 \"2\"", 1, "2"}, + {"OpName %1 \"2\" OpName %2 \"2\"", 2, "2_0"}, + // Test uniqueness in the face of forward references + // for Ids that don't already have friendly names. + // In particular, the first OpDecorate assigns the name, and + // the second one can't override it. + {"OpDecorate %1 Volatile OpDecorate %1 Restrict", 1, "1"}, + // But a forced name can override the name that + // would have been assigned via the OpDecorate + // forward reference. + {"OpName %1 \"mememe\" OpDecorate %1 Volatile OpDecorate %1 Restrict", + 1, "mememe"}, + // OpName can override other inferences. 
We assume valid instruction + // ordering, where OpName precedes type definitions. + {"OpName %1 \"myfloat\" %1 = OpTypeFloat 32", 1, "myfloat"}, + }), ); + +INSTANTIATE_TEST_CASE_P( + UniquenessHeuristic, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeVoid %2 = OpTypeVoid %3 = OpTypeVoid", 1, "void"}, + {"%1 = OpTypeVoid %2 = OpTypeVoid %3 = OpTypeVoid", 2, "void_0"}, + {"%1 = OpTypeVoid %2 = OpTypeVoid %3 = OpTypeVoid", 3, "void_1"}, + }), ); + +INSTANTIATE_TEST_CASE_P(Arrays, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"OpName %2 \"FortyTwo\" %1 = OpTypeFloat 32 " + "%2 = OpConstant %1 42 %3 = OpTypeArray %1 %2", + 3, "_arr_float_FortyTwo"}, + {"%1 = OpTypeInt 32 0 " + "%2 = OpTypeRuntimeArray %1", + 2, "_runtimearr_uint"}, + }), ); + +INSTANTIATE_TEST_CASE_P(Structs, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeBool " + "%2 = OpTypeStruct %1 %1 %1", + 2, "_struct_2"}, + {"%1 = OpTypeBool " + "%2 = OpTypeStruct %1 %1 %1 " + "%3 = OpTypeStruct %2 %2", + 3, "_struct_3"}, + }), ); + +INSTANTIATE_TEST_CASE_P( + Pointer, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeFloat 32 %2 = OpTypePointer Workgroup %1", 2, + "_ptr_Workgroup_float"}, + {"%1 = OpTypeBool %2 = OpTypePointer Private %1", 2, + "_ptr_Private_bool"}, + // OpTypeForwardPointer doesn't force generation of the name for its + // target type. 
+ {"%1 = OpTypeBool OpTypeForwardPointer %2 Private %2 = OpTypePointer " + "Private %1", + 2, "_ptr_Private_bool"}, + }), ); + +INSTANTIATE_TEST_CASE_P(ExoticTypes, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeEvent", 1, "Event"}, + {"%1 = OpTypeDeviceEvent", 1, "DeviceEvent"}, + {"%1 = OpTypeReserveId", 1, "ReserveId"}, + {"%1 = OpTypeQueue", 1, "Queue"}, + {"%1 = OpTypeOpaque \"hello world!\"", 1, + "Opaque_hello_world_"}, + {"%1 = OpTypePipe ReadOnly", 1, "PipeReadOnly"}, + {"%1 = OpTypePipe WriteOnly", 1, "PipeWriteOnly"}, + {"%1 = OpTypePipe ReadWrite", 1, "PipeReadWrite"}, + {"%1 = OpTypePipeStorage", 1, "PipeStorage"}, + {"%1 = OpTypeNamedBarrier", 1, "NamedBarrier"}, + }), ); + +// Makes a test case for a BuiltIn variable declaration. +NameIdCase BuiltInCase(std::string assembly_name, std::string expected) { + return NameIdCase{std::string("OpDecorate %1 BuiltIn ") + assembly_name + + " %1 = OpVariable %2 Input", + 1, expected}; +} + +// Makes a test case for a BuiltIn variable declaration. In this overload, +// the expected result is the same as the assembly name. +NameIdCase BuiltInCase(std::string assembly_name) { + return BuiltInCase(assembly_name, assembly_name); +} + +// Makes a test case for a BuiltIn variable declaration. In this overload, +// the expected result is the same as the assembly name, but with a "gl_" +// prefix. 
+NameIdCase BuiltInGLCase(std::string assembly_name) { + return BuiltInCase(assembly_name, std::string("gl_") + assembly_name); +} + +INSTANTIATE_TEST_CASE_P( + BuiltIns, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + BuiltInGLCase("Position"), + BuiltInGLCase("PointSize"), + BuiltInGLCase("ClipDistance"), + BuiltInGLCase("CullDistance"), + BuiltInCase("VertexId", "gl_VertexID"), + BuiltInCase("InstanceId", "gl_InstanceID"), + BuiltInCase("PrimitiveId", "gl_PrimitiveID"), + BuiltInCase("InvocationId", "gl_InvocationID"), + BuiltInGLCase("Layer"), + BuiltInGLCase("ViewportIndex"), + BuiltInGLCase("TessLevelOuter"), + BuiltInGLCase("TessLevelInner"), + BuiltInGLCase("TessCoord"), + BuiltInGLCase("PatchVertices"), + BuiltInGLCase("FragCoord"), + BuiltInGLCase("PointCoord"), + BuiltInGLCase("FrontFacing"), + BuiltInCase("SampleId", "gl_SampleID"), + BuiltInGLCase("SamplePosition"), + BuiltInGLCase("SampleMask"), + BuiltInGLCase("FragDepth"), + BuiltInGLCase("HelperInvocation"), + BuiltInCase("NumWorkgroups", "gl_NumWorkGroups"), + BuiltInCase("WorkgroupSize", "gl_WorkGroupSize"), + BuiltInCase("WorkgroupId", "gl_WorkGroupID"), + BuiltInCase("LocalInvocationId", "gl_LocalInvocationID"), + BuiltInCase("GlobalInvocationId", "gl_GlobalInvocationID"), + BuiltInGLCase("LocalInvocationIndex"), + BuiltInCase("WorkDim"), + BuiltInCase("GlobalSize"), + BuiltInCase("EnqueuedWorkgroupSize"), + BuiltInCase("GlobalOffset"), + BuiltInCase("GlobalLinearId"), + BuiltInCase("SubgroupSize"), + BuiltInCase("SubgroupMaxSize"), + BuiltInCase("NumSubgroups"), + BuiltInCase("NumEnqueuedSubgroups"), + BuiltInCase("SubgroupId"), + BuiltInCase("SubgroupLocalInvocationId"), + BuiltInGLCase("VertexIndex"), + BuiltInGLCase("InstanceIndex"), + BuiltInCase("SubgroupEqMaskKHR"), + BuiltInCase("SubgroupGeMaskKHR"), + BuiltInCase("SubgroupGtMaskKHR"), + BuiltInCase("SubgroupLeMaskKHR"), + BuiltInCase("SubgroupLtMaskKHR"), + }), ); + +INSTANTIATE_TEST_CASE_P(DebugNameOverridesBuiltin, 
FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"OpName %1 \"foo\" OpDecorate %1 BuiltIn WorkDim " + "%1 = OpVariable %2 Input", + 1, "foo"}}), ); + +INSTANTIATE_TEST_CASE_P( + SimpleIntegralConstants, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeInt 32 0 %2 = OpConstant %1 0", 2, "uint_0"}, + {"%1 = OpTypeInt 32 0 %2 = OpConstant %1 1", 2, "uint_1"}, + {"%1 = OpTypeInt 32 0 %2 = OpConstant %1 2", 2, "uint_2"}, + {"%1 = OpTypeInt 32 0 %2 = OpConstant %1 9", 2, "uint_9"}, + {"%1 = OpTypeInt 32 0 %2 = OpConstant %1 42", 2, "uint_42"}, + {"%1 = OpTypeInt 32 1 %2 = OpConstant %1 0", 2, "int_0"}, + {"%1 = OpTypeInt 32 1 %2 = OpConstant %1 1", 2, "int_1"}, + {"%1 = OpTypeInt 32 1 %2 = OpConstant %1 2", 2, "int_2"}, + {"%1 = OpTypeInt 32 1 %2 = OpConstant %1 9", 2, "int_9"}, + {"%1 = OpTypeInt 32 1 %2 = OpConstant %1 42", 2, "int_42"}, + {"%1 = OpTypeInt 32 1 %2 = OpConstant %1 -42", 2, "int_n42"}, + // Exotic bit widths + {"%1 = OpTypeInt 33 0 %2 = OpConstant %1 0", 2, "u33_0"}, + {"%1 = OpTypeInt 33 1 %2 = OpConstant %1 10", 2, "i33_10"}, + {"%1 = OpTypeInt 33 1 %2 = OpConstant %1 -19", 2, "i33_n19"}, + }), ); + +INSTANTIATE_TEST_CASE_P( + SimpleFloatConstants, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ff4p+16", 2, + "half_0x1_ff4p_16"}, + {"%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.d2cp-10", 2, + "half_n0x1_d2cpn10"}, + // 32-bit floats + {"%1 = OpTypeFloat 32\n%2 = OpConstant %1 -3.125", 2, "float_n3_125"}, + {"%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.8p+128", 2, + "float_0x1_8p_128"}, // NaN + {"%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1.0002p+128", 2, + "float_n0x1_0002p_128"}, // NaN + {"%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1p+128", 2, + "float_0x1p_128"}, // Inf + {"%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1p+128", 2, + "float_n0x1p_128"}, // -Inf + // 64-bit floats + {"%1 = OpTypeFloat 64\n%2 = OpConstant %1 -3.125", 2, "double_n3_125"}, + {"%1 = OpTypeFloat 
64\n%2 = OpConstant %1 0x1.ffffffffffffap-1023", 2, + "double_0x1_ffffffffffffapn1023"}, // small normal + {"%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.ffffffffffffap-1023", 2, + "double_n0x1_ffffffffffffapn1023"}, + {"%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1.8p+1024", 2, + "double_0x1_8p_1024"}, // NaN + {"%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.0002p+1024", 2, + "double_n0x1_0002p_1024"}, // NaN + {"%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1p+1024", 2, + "double_0x1p_1024"}, // Inf + {"%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1p+1024", 2, + "double_n0x1p_1024"}, // -Inf + }), ); + +INSTANTIATE_TEST_CASE_P( + BooleanConstants, FriendlyNameTest, + ::testing::ValuesIn(std::vector{ + {"%1 = OpTypeBool\n%2 = OpConstantTrue %1", 2, "true"}, + {"%1 = OpTypeBool\n%2 = OpConstantFalse %1", 2, "false"}, + }), ); + +} // namespace +} // namespace spvtools diff --git a/test/named_id_test.cpp b/test/named_id_test.cpp new file mode 100644 index 000000000..4ba54adc3 --- /dev/null +++ b/test/named_id_test.cpp @@ -0,0 +1,87 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using NamedIdTest = spvtest::TextToBinaryTest; + +TEST_F(NamedIdTest, Default) { + const std::string input = R"( + OpCapability Shader + OpMemoryModel Logical Simple + OpEntryPoint Vertex %main "foo" + %void = OpTypeVoid +%fnMain = OpTypeFunction %void + %main = OpFunction %void None %fnMain +%lbMain = OpLabel + OpReturn + OpFunctionEnd)"; + const std::string output = + "OpCapability Shader\n" + "OpMemoryModel Logical Simple\n" + "OpEntryPoint Vertex %1 \"foo\"\n" + "%2 = OpTypeVoid\n" + "%3 = OpTypeFunction %2\n" + "%1 = OpFunction %2 None %3\n" + "%4 = OpLabel\n" + "OpReturn\n" + "OpFunctionEnd\n"; + EXPECT_EQ(output, EncodeAndDecodeSuccessfully(input)); +} + +struct IdCheckCase { + std::string id; + bool valid; +}; + +using IdValidityTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(IdValidityTest, IdTypes) { + const std::string input = GetParam().id + " = OpTypeVoid"; + SetText(input); + if (GetParam().valid) { + CompileSuccessfully(input); + } else { + CompileFailure(input); + } +} + +INSTANTIATE_TEST_CASE_P( + ValidAndInvalidIds, IdValidityTest, + ::testing::ValuesIn(std::vector( + {{"%1", true}, {"%2abc", true}, {"%3Def", true}, + {"%4GHI", true}, {"%5_j_k", true}, {"%6J_M", true}, + {"%n", true}, {"%O", true}, {"%p7", true}, + {"%Q8", true}, {"%R_S", true}, {"%T_10_U", true}, + {"%V_11", true}, {"%W_X_13", true}, {"%_A", true}, + {"%_", true}, {"%__", true}, {"%A_", true}, + {"%_A_", true}, + + {"%@", false}, {"%!", false}, {"%ABC!", false}, + {"%__A__@", false}, {"%%", false}, {"%-", false}, + {"%foo_@_bar", false}, {"%", false}, + + {"5", false}, {"32", false}, {"foo", false}, + {"a%bar", false}})), ); + +} // namespace +} // namespace spvtools diff --git a/test/opcode_make_test.cpp b/test/opcode_make_test.cpp new file mode 100644 index 000000000..6481ef326 --- /dev/null +++ b/test/opcode_make_test.cpp @@ 
-0,0 +1,44 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +// A sampling of word counts. Covers extreme points well, and all bit +// positions, and some combinations of bit positions. +const uint16_t kSampleWordCounts[] = { + 0, 1, 2, 3, 4, 8, 16, 32, 64, 127, 128, + 256, 511, 512, 1024, 2048, 4096, 8192, 16384, 32768, 0xfffe, 0xffff}; + +// A sampling of opcode values. Covers the lower values well, a few samples +// around the number of core instructions (as of this writing), and some +// higher values. +const uint16_t kSampleOpcodes[] = {0, 1, 2, 3, 4, 100, + 300, 305, 1023, 0xfffe, 0xffff}; + +TEST(OpcodeMake, Samples) { + for (auto wordCount : kSampleWordCounts) { + for (auto opcode : kSampleOpcodes) { + uint32_t word = 0; + word |= uint32_t(opcode); + word |= uint32_t(wordCount) << 16; + EXPECT_EQ(word, spvOpcodeMake(wordCount, SpvOp(opcode))); + } + } +} + +} // namespace +} // namespace spvtools diff --git a/test/opcode_require_capabilities_test.cpp b/test/opcode_require_capabilities_test.cpp new file mode 100644 index 000000000..32bf1dc08 --- /dev/null +++ b/test/opcode_require_capabilities_test.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/unit_spirv.h" + +#include "source/enum_set.h" + +namespace spvtools { +namespace { + +using spvtest::ElementsIn; + +// Capabilities required by an Opcode. +struct ExpectedOpCodeCapabilities { + SpvOp opcode; + CapabilitySet capabilities; +}; + +using OpcodeTableCapabilitiesTest = + ::testing::TestWithParam; + +TEST_P(OpcodeTableCapabilitiesTest, TableEntryMatchesExpectedCapabilities) { + auto env = SPV_ENV_UNIVERSAL_1_1; + spv_opcode_table opcodeTable; + ASSERT_EQ(SPV_SUCCESS, spvOpcodeTableGet(&opcodeTable, env)); + spv_opcode_desc entry; + ASSERT_EQ(SPV_SUCCESS, spvOpcodeTableValueLookup(env, opcodeTable, + GetParam().opcode, &entry)); + EXPECT_EQ( + ElementsIn(GetParam().capabilities), + ElementsIn(CapabilitySet(entry->numCapabilities, entry->capabilities))); +} + +INSTANTIATE_TEST_CASE_P( + TableRowTest, OpcodeTableCapabilitiesTest, + // Spot-check a few opcodes. 
+ ::testing::Values( + ExpectedOpCodeCapabilities{ + SpvOpImageQuerySize, + CapabilitySet{SpvCapabilityKernel, SpvCapabilityImageQuery}}, + ExpectedOpCodeCapabilities{ + SpvOpImageQuerySizeLod, + CapabilitySet{SpvCapabilityKernel, SpvCapabilityImageQuery}}, + ExpectedOpCodeCapabilities{ + SpvOpImageQueryLevels, + CapabilitySet{SpvCapabilityKernel, SpvCapabilityImageQuery}}, + ExpectedOpCodeCapabilities{ + SpvOpImageQuerySamples, + CapabilitySet{SpvCapabilityKernel, SpvCapabilityImageQuery}}, + ExpectedOpCodeCapabilities{SpvOpImageSparseSampleImplicitLod, + CapabilitySet{SpvCapabilitySparseResidency}}, + ExpectedOpCodeCapabilities{SpvOpCopyMemorySized, + CapabilitySet{SpvCapabilityAddresses}}, + ExpectedOpCodeCapabilities{SpvOpArrayLength, + CapabilitySet{SpvCapabilityShader}}, + ExpectedOpCodeCapabilities{SpvOpFunction, CapabilitySet()}, + ExpectedOpCodeCapabilities{SpvOpConvertFToS, CapabilitySet()}, + ExpectedOpCodeCapabilities{SpvOpEmitStreamVertex, + CapabilitySet{SpvCapabilityGeometryStreams}}, + ExpectedOpCodeCapabilities{SpvOpTypeNamedBarrier, + CapabilitySet{SpvCapabilityNamedBarrier}}, + ExpectedOpCodeCapabilities{ + SpvOpGetKernelMaxNumSubgroups, + CapabilitySet{SpvCapabilitySubgroupDispatch}}), ); + +} // namespace +} // namespace spvtools diff --git a/test/opcode_split_test.cpp b/test/opcode_split_test.cpp new file mode 100644 index 000000000..43fedb385 --- /dev/null +++ b/test/opcode_split_test.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +TEST(OpcodeSplit, Default) { + uint32_t word = spvOpcodeMake(42, (SpvOp)23); + uint16_t wordCount = 0; + uint16_t opcode; + spvOpcodeSplit(word, &wordCount, &opcode); + ASSERT_EQ(42, wordCount); + ASSERT_EQ(23, opcode); +} + +} // namespace +} // namespace spvtools diff --git a/test/opcode_table_get_test.cpp b/test/opcode_table_get_test.cpp new file mode 100644 index 000000000..6f80ad7d8 --- /dev/null +++ b/test/opcode_table_get_test.cpp @@ -0,0 +1,39 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gmock/gmock.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using GetTargetOpcodeTableGetTest = ::testing::TestWithParam; +using ::testing::ValuesIn; + +TEST_P(GetTargetOpcodeTableGetTest, SanityCheck) { + spv_opcode_table table; + ASSERT_EQ(SPV_SUCCESS, spvOpcodeTableGet(&table, GetParam())); + ASSERT_NE(0u, table->count); + ASSERT_NE(nullptr, table->entries); +} + +TEST_P(GetTargetOpcodeTableGetTest, InvalidPointerTable) { + ASSERT_EQ(SPV_ERROR_INVALID_POINTER, spvOpcodeTableGet(nullptr, GetParam())); +} + +INSTANTIATE_TEST_CASE_P(OpcodeTableGet, GetTargetOpcodeTableGetTest, + ValuesIn(spvtest::AllTargetEnvironments())); + +} // namespace +} // namespace spvtools diff --git a/test/operand-class-test-coverage.csv b/test/operand-class-test-coverage.csv new file mode 100644 index 000000000..116fdfeee --- /dev/null +++ b/test/operand-class-test-coverage.csv @@ -0,0 +1,43 @@ +Operand class,Example instruction,Notes,example unit test,negative-enum coverage location +" OperandNone,",UNUSED,not in grammar,,not enum +" OperandId,",many,ID,too many to count,not enum +" OperandOptionalId,","Source, Variable",OPTIONAL_ID,OpSourceAcceptsOptionalFileId,not enum +" OperandOptionalImage,",ImageFetch,,ImageOperandsTest,"TEST_F(ImageOperandsTest, WrongOperand)" +" OperandVariableIds,",ExtInst,,,not enum +" OperandOptionalLiteral,",ExecutionMode,,AnyExecutionMode,not enum +" OperandOptionalLiteralString,",Source,,OpSourceAcceptsOptionalSourceText,not enum +" OperandVariableLiterals,",Decorate,,OpDecorateSimpleTest,not enum +" OperandVariableIdLiteral,",GroupMemberDecorate,,GroupMemberDecorate*,not enum +" OperandVariableLiteralId,",Switch,,Switch*,not enum +" OperandLiteralNumber,","Source, Switch, ...",,Switch*,not enum +" OperandLiteralString,",SourceContinued,,OpSourceContinued,not enum +" OperandSource,",Source,,OpSource,not enum +" OperandExecutionModel,",EntryPoint,,OpEntryPointTest,"TEST_F(OpEntryPointTest, WrongModel)" +" 
OperandAddressing,",OpMemoryModel,,OpMemoryModelTest,"TEST_F(OpMemoryModelTest, WrongModel)" +" OperandMemory,",OpMemoryModel,,OpMemoryModelTest,"TEST_F(OpMemoryModelTest, WrongModel)" +" OperandExecutionMode,",OpExecutionMode,,OpExecutionModeTest,"TEST_F(OpExecutionModeTest, WrongMode)" +" OperandStorage,","TypePointer, TypeForwardPointer, Variable",,StorageClassTest,"TEST_F(OpTypeForwardPointerTest, WrongClass)" +" OperandDimensionality,",TypeImage,,DimTest/AnyDim,"TEST_F(DimTest, WrongDim)" +" OperandSamplerAddressingMode,",ConstantSampler,,SamplerAddressingModeTest,"TEST_F(SamplerAddressingModeTest, WrongMode)" +" OperandSamplerFilterMode,",ConstantSampler,,AnySamplerFilterMode,"TEST_F(SamplerFilterModeTest, WrongMode)" +" OperandSamplerImageFormat,",TypeImage,SAMPLER_IMAGE_FORMAT,ImageFormatTest,"TEST_F(ImageFormatTest, WrongFormat)" +" OperandImageChannelOrder,",UNUSED,returned as result value only,, +" OperandImageChannelDataType,",UNUSED,returned as result value only,, +" OperandImageOperands,",UNUSED,used to make a spec section,,see OperandOptionalImage +" OperandFPFastMath,",OpDecorate,,CombinedFPFastMathMask,"TEST_F(OpDecorateEnumTest, WrongFPFastMathMode)" +" OperandFPRoundingMode,",OpDecorate,,,"TEST_F(OpDecorateEnumTest, WrongFPRoundingMode)" +" OperandLinkageType,",OpDecorate,,OpDecorateLinkageTest,"TEST_F(OpDecorateLinkageTest, WrongType)" +" OperandAccessQualifier,",OpTypePipe,,AnyAccessQualifier,"TEST_F(OpTypePipeTest, WrongAccessQualifier)" +" OperandFuncParamAttr,",OpDecorate,,TextToBinaryDecorateFuncParamAttr,"TEST_F(OpDecorateEnumTest, WrongFuncParamAttr)" +" OperandDecoration,",OpDecorate,,AnyAccessQualifier,"TEST_F(OpTypePipeTest, WrongAccessQualifier)" +" OperandBuiltIn,",OpDecorate,,TextToBinaryDecorateBuiltIn,"TEST_F(OpDecorateEnumTest, WrongBuiltIn)" +" OperandSelect,",SelectionMerge,,TextToBinarySelectionMerge,"TEST_F(OpSelectionMergeTest, WrongSelectionControl)" +" 
OperandLoop,",LoopMerge,,CombinedLoopControlMask,"TEST_F(OpLoopMergeTest, WrongLoopControl)" +" OperandFunction,",Function,,AnySingleFunctionControlMask,"TEST_F(OpFunctionControlTest, WrongFunctionControl)" +" OperandMemorySemantics,",OpMemoryBarrier,"it's an ID, not in grammar",OpMemoryBarrier*,not enum +" OperandMemoryAccess,",UNUSED,"should be on opstore, but hacked in opcode.cpp",,not enum +" OperandScope,",MemoryBarrier,"it's an ID, not in grammar",OpMemoryBarrier*,not enum +" OperandGroupOperation,",GroupIAdd,,GroupOperationTest,"TEST_F(GroupOperationTest, WrongGroupOperation)" +" OperandKernelEnqueueFlags,",OpEnqueueKernel,"it's an ID, not in grammar",should not have one,not enum +" OperandKernelProfilingInfo,",OpCaptureEventProfilingInfo,"it's an ID, not in grammar",should not have one,not enum +" OperandCapability,",Capability,,OpCapabilityTest,"TEST_F(TextToBinaryCapability, BadInvalidCapability)" diff --git a/test/operand_capabilities_test.cpp b/test/operand_capabilities_test.cpp new file mode 100644 index 000000000..e58bc9d2b --- /dev/null +++ b/test/operand_capabilities_test.cpp @@ -0,0 +1,736 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test capability dependencies for enums. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "source/enum_set.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::ElementsIn; +using ::testing::Combine; +using ::testing::Eq; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +// A test case for mapping an enum to a capability mask. +struct EnumCapabilityCase { + spv_operand_type_t type; + uint32_t value; + CapabilitySet expected_capabilities; +}; + +// Test fixture for testing EnumCapabilityCases. +using EnumCapabilityTest = + TestWithParam>; + +TEST_P(EnumCapabilityTest, Sample) { + const auto env = std::get<0>(GetParam()); + const auto context = spvContextCreate(env); + const AssemblyGrammar grammar(context); + spv_operand_desc entry; + + ASSERT_EQ(SPV_SUCCESS, + grammar.lookupOperand(std::get<1>(GetParam()).type, + std::get<1>(GetParam()).value, &entry)); + const auto cap_set = grammar.filterCapsAgainstTargetEnv( + entry->capabilities, entry->numCapabilities); + + EXPECT_THAT(ElementsIn(cap_set), + Eq(ElementsIn(std::get<1>(GetParam()).expected_capabilities))) + << " capability value " << std::get<1>(GetParam()).value; + spvContextDestroy(context); +} + +#define CASE0(TYPE, VALUE) \ + { \ + SPV_OPERAND_TYPE_##TYPE, uint32_t(Spv##VALUE), {} \ + } +#define CASE1(TYPE, VALUE, CAP) \ + { \ + SPV_OPERAND_TYPE_##TYPE, uint32_t(Spv##VALUE), CapabilitySet { \ + SpvCapability##CAP \ + } \ + } +#define CASE2(TYPE, VALUE, CAP1, CAP2) \ + { \ + SPV_OPERAND_TYPE_##TYPE, uint32_t(Spv##VALUE), CapabilitySet { \ + SpvCapability##CAP1, SpvCapability##CAP2 \ + } \ + } +#define CASE3(TYPE, VALUE, CAP1, CAP2, CAP3) \ + { \ + SPV_OPERAND_TYPE_##TYPE, uint32_t(Spv##VALUE), CapabilitySet { \ + SpvCapability##CAP1, SpvCapability##CAP2, SpvCapability##CAP3 \ + } \ + } +#define CASE5(TYPE, VALUE, CAP1, CAP2, CAP3, CAP4, CAP5) \ + { \ + SPV_OPERAND_TYPE_##TYPE, uint32_t(Spv##VALUE), CapabilitySet { \ + SpvCapability##CAP1, SpvCapability##CAP2, 
SpvCapability##CAP3, \ + SpvCapability##CAP4, SpvCapability##CAP5 \ + } \ + } + +// See SPIR-V Section 3.3 Execution Model +INSTANTIATE_TEST_CASE_P( + ExecutionModel, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(EXECUTION_MODEL, ExecutionModelVertex, Shader), + CASE1(EXECUTION_MODEL, ExecutionModelTessellationControl, + Tessellation), + CASE1(EXECUTION_MODEL, ExecutionModelTessellationEvaluation, + Tessellation), + CASE1(EXECUTION_MODEL, ExecutionModelGeometry, Geometry), + CASE1(EXECUTION_MODEL, ExecutionModelFragment, Shader), + CASE1(EXECUTION_MODEL, ExecutionModelGLCompute, Shader), + CASE1(EXECUTION_MODEL, ExecutionModelKernel, Kernel), + })), ); + +// See SPIR-V Section 3.4 Addressing Model +INSTANTIATE_TEST_CASE_P( + AddressingModel, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(ADDRESSING_MODEL, AddressingModelLogical), + CASE1(ADDRESSING_MODEL, AddressingModelPhysical32, Addresses), + CASE1(ADDRESSING_MODEL, AddressingModelPhysical64, Addresses), + })), ); + +// See SPIR-V Section 3.5 Memory Model +INSTANTIATE_TEST_CASE_P( + MemoryModel, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(MEMORY_MODEL, MemoryModelSimple, Shader), + CASE1(MEMORY_MODEL, MemoryModelGLSL450, Shader), + CASE1(MEMORY_MODEL, MemoryModelOpenCL, Kernel), + })), ); + +// See SPIR-V Section 3.6 Execution Mode +INSTANTIATE_TEST_CASE_P( + ExecutionMode, EnumCapabilityTest, + Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(EXECUTION_MODE, ExecutionModeInvocations, Geometry), + CASE1(EXECUTION_MODE, ExecutionModeSpacingEqual, Tessellation), + CASE1(EXECUTION_MODE, ExecutionModeSpacingFractionalEven, + Tessellation), + CASE1(EXECUTION_MODE, ExecutionModeSpacingFractionalOdd, + Tessellation), + CASE1(EXECUTION_MODE, 
ExecutionModeVertexOrderCw, Tessellation), + CASE1(EXECUTION_MODE, ExecutionModeVertexOrderCcw, Tessellation), + CASE1(EXECUTION_MODE, ExecutionModePixelCenterInteger, Shader), + CASE1(EXECUTION_MODE, ExecutionModeOriginUpperLeft, Shader), + CASE1(EXECUTION_MODE, ExecutionModeOriginLowerLeft, Shader), + CASE1(EXECUTION_MODE, ExecutionModeEarlyFragmentTests, Shader), + CASE1(EXECUTION_MODE, ExecutionModePointMode, Tessellation), + CASE1(EXECUTION_MODE, ExecutionModeXfb, TransformFeedback), + CASE1(EXECUTION_MODE, ExecutionModeDepthReplacing, Shader), + CASE1(EXECUTION_MODE, ExecutionModeDepthGreater, Shader), + CASE1(EXECUTION_MODE, ExecutionModeDepthLess, Shader), + CASE1(EXECUTION_MODE, ExecutionModeDepthUnchanged, Shader), + CASE0(EXECUTION_MODE, ExecutionModeLocalSize), + CASE1(EXECUTION_MODE, ExecutionModeLocalSizeHint, Kernel), + CASE1(EXECUTION_MODE, ExecutionModeInputPoints, Geometry), + CASE1(EXECUTION_MODE, ExecutionModeInputLines, Geometry), + CASE1(EXECUTION_MODE, ExecutionModeInputLinesAdjacency, Geometry), + CASE2(EXECUTION_MODE, ExecutionModeTriangles, Geometry, + Tessellation), + CASE1(EXECUTION_MODE, ExecutionModeInputTrianglesAdjacency, + Geometry), + CASE1(EXECUTION_MODE, ExecutionModeQuads, Tessellation), + CASE1(EXECUTION_MODE, ExecutionModeIsolines, Tessellation), + CASE3(EXECUTION_MODE, ExecutionModeOutputVertices, Geometry, + Tessellation, MeshShadingNV), + CASE2(EXECUTION_MODE, ExecutionModeOutputPoints, Geometry, + MeshShadingNV), + CASE1(EXECUTION_MODE, ExecutionModeOutputLineStrip, Geometry), + CASE1(EXECUTION_MODE, ExecutionModeOutputTriangleStrip, Geometry), + CASE1(EXECUTION_MODE, ExecutionModeVecTypeHint, Kernel), + CASE1(EXECUTION_MODE, ExecutionModeContractionOff, Kernel), + })), ); + +INSTANTIATE_TEST_CASE_P( + ExecutionModeV11, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(EXECUTION_MODE, ExecutionModeInitializer, Kernel), + CASE1(EXECUTION_MODE, ExecutionModeFinalizer, Kernel), + 
CASE1(EXECUTION_MODE, ExecutionModeSubgroupSize, + SubgroupDispatch), + CASE1(EXECUTION_MODE, ExecutionModeSubgroupsPerWorkgroup, + SubgroupDispatch)})), ); + +// See SPIR-V Section 3.7 Storage Class +INSTANTIATE_TEST_CASE_P( + StorageClass, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(STORAGE_CLASS, StorageClassUniformConstant), + CASE1(STORAGE_CLASS, StorageClassUniform, Shader), + CASE1(STORAGE_CLASS, StorageClassOutput, Shader), + CASE0(STORAGE_CLASS, StorageClassWorkgroup), + CASE0(STORAGE_CLASS, StorageClassCrossWorkgroup), + CASE1(STORAGE_CLASS, StorageClassPrivate, Shader), + CASE0(STORAGE_CLASS, StorageClassFunction), + CASE1(STORAGE_CLASS, StorageClassGeneric, + GenericPointer), // Bug 14287 + CASE1(STORAGE_CLASS, StorageClassPushConstant, Shader), + CASE1(STORAGE_CLASS, StorageClassAtomicCounter, AtomicStorage), + CASE0(STORAGE_CLASS, StorageClassImage), + })), ); + +// See SPIR-V Section 3.8 Dim +INSTANTIATE_TEST_CASE_P( + Dim, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE2(DIMENSIONALITY, Dim1D, Sampled1D, Image1D), + CASE3(DIMENSIONALITY, Dim2D, Kernel, Shader, ImageMSArray), + CASE0(DIMENSIONALITY, Dim3D), + CASE2(DIMENSIONALITY, DimCube, Shader, ImageCubeArray), + CASE2(DIMENSIONALITY, DimRect, SampledRect, ImageRect), + CASE2(DIMENSIONALITY, DimBuffer, SampledBuffer, ImageBuffer), + CASE1(DIMENSIONALITY, DimSubpassData, InputAttachment), + })), ); + +// See SPIR-V Section 3.9 Sampler Addressing Mode +INSTANTIATE_TEST_CASE_P( + SamplerAddressingMode, EnumCapabilityTest, + Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(SAMPLER_ADDRESSING_MODE, SamplerAddressingModeNone, Kernel), + CASE1(SAMPLER_ADDRESSING_MODE, SamplerAddressingModeClampToEdge, + Kernel), + CASE1(SAMPLER_ADDRESSING_MODE, SamplerAddressingModeClamp, Kernel), + CASE1(SAMPLER_ADDRESSING_MODE, 
SamplerAddressingModeRepeat, Kernel), + CASE1(SAMPLER_ADDRESSING_MODE, SamplerAddressingModeRepeatMirrored, + Kernel), + })), ); + +// See SPIR-V Section 3.10 Sampler Filter Mode +INSTANTIATE_TEST_CASE_P( + SamplerFilterMode, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(SAMPLER_FILTER_MODE, SamplerFilterModeNearest, Kernel), + CASE1(SAMPLER_FILTER_MODE, SamplerFilterModeLinear, Kernel), + })), ); + +// See SPIR-V Section 3.11 Image Format +INSTANTIATE_TEST_CASE_P( + ImageFormat, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + // clang-format off + CASE0(SAMPLER_IMAGE_FORMAT, ImageFormatUnknown), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba32f, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba16f, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR32f, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba8, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba8Snorm, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg32f, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg16f, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR11fG11fB10f, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR16f, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba16, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgb10A2, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg16, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg8, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR16, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR8, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba16Snorm, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg16Snorm, StorageImageExtendedFormats), + 
CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg8Snorm, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR16Snorm, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR8Snorm, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba32i, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba16i, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba8i, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR32i, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg32i, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg16i, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg8i, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR16i, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR8i, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba32ui, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba16ui, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba8ui, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgba8ui, Shader), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRgb10a2ui, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg32ui, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg16ui, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatRg8ui, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR16ui, StorageImageExtendedFormats), + CASE1(SAMPLER_IMAGE_FORMAT, ImageFormatR8ui, StorageImageExtendedFormats), + // clang-format on + })), ); + +// See SPIR-V Section 3.12 Image Channel Order +INSTANTIATE_TEST_CASE_P( + ImageChannelOrder, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderR, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderA, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderRG, Kernel), + 
CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderRA, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderRGB, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderRGBA, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderBGRA, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderARGB, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderIntensity, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderLuminance, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderRx, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderRGx, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderRGBx, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderDepth, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderDepthStencil, + Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrdersRGB, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrdersRGBx, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrdersRGBA, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrdersBGRA, Kernel), + CASE1(IMAGE_CHANNEL_ORDER, ImageChannelOrderABGR, Kernel), + })), ); + +// See SPIR-V Section 3.13 Image Channel Data Type +INSTANTIATE_TEST_CASE_P( + ImageChannelDataType, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + // clang-format off + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeSnormInt8, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeSnormInt16, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnormInt8, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnormInt16, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnormShort565, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnormShort555, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnormInt101010, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeSignedInt8, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeSignedInt16, Kernel), + 
CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeSignedInt32, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnsignedInt8, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnsignedInt16, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnsignedInt32, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeHalfFloat, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeFloat, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnormInt24, Kernel), + CASE1(IMAGE_CHANNEL_DATA_TYPE, ImageChannelDataTypeUnormInt101010_2, Kernel), + // clang-format on + })), ); + +// See SPIR-V Section 3.14 Image Operands +INSTANTIATE_TEST_CASE_P( + ImageOperands, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + // clang-format off + CASE0(OPTIONAL_IMAGE, ImageOperandsMaskNone), + CASE1(OPTIONAL_IMAGE, ImageOperandsBiasMask, Shader), + CASE0(OPTIONAL_IMAGE, ImageOperandsLodMask), + CASE0(OPTIONAL_IMAGE, ImageOperandsGradMask), + CASE0(OPTIONAL_IMAGE, ImageOperandsConstOffsetMask), + CASE1(OPTIONAL_IMAGE, ImageOperandsOffsetMask, ImageGatherExtended), + CASE1(OPTIONAL_IMAGE, ImageOperandsConstOffsetsMask, ImageGatherExtended), + CASE0(OPTIONAL_IMAGE, ImageOperandsSampleMask), + CASE1(OPTIONAL_IMAGE, ImageOperandsMinLodMask, MinLod), + // clang-format on + })), ); + +// See SPIR-V Section 3.15 FP Fast Math Mode +INSTANTIATE_TEST_CASE_P( + FPFastMathMode, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(FP_FAST_MATH_MODE, FPFastMathModeMaskNone), + CASE1(FP_FAST_MATH_MODE, FPFastMathModeNotNaNMask, Kernel), + CASE1(FP_FAST_MATH_MODE, FPFastMathModeNotInfMask, Kernel), + CASE1(FP_FAST_MATH_MODE, FPFastMathModeNSZMask, Kernel), + CASE1(FP_FAST_MATH_MODE, FPFastMathModeAllowRecipMask, Kernel), + CASE1(FP_FAST_MATH_MODE, FPFastMathModeFastMask, Kernel), + })), ); + +// See SPIR-V Section 3.17 Linkage 
Type +INSTANTIATE_TEST_CASE_P( + LinkageType, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(LINKAGE_TYPE, LinkageTypeExport, Linkage), + CASE1(LINKAGE_TYPE, LinkageTypeImport, Linkage), + })), ); + +// See SPIR-V Section 3.18 Access Qualifier +INSTANTIATE_TEST_CASE_P( + AccessQualifier, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(ACCESS_QUALIFIER, AccessQualifierReadOnly, Kernel), + CASE1(ACCESS_QUALIFIER, AccessQualifierWriteOnly, Kernel), + CASE1(ACCESS_QUALIFIER, AccessQualifierReadWrite, Kernel), + })), ); + +// See SPIR-V Section 3.19 Function Parameter Attribute +INSTANTIATE_TEST_CASE_P( + FunctionParameterAttribute, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + // clang-format off + CASE1(FUNCTION_PARAMETER_ATTRIBUTE, FunctionParameterAttributeZext, Kernel), + CASE1(FUNCTION_PARAMETER_ATTRIBUTE, FunctionParameterAttributeSext, Kernel), + CASE1(FUNCTION_PARAMETER_ATTRIBUTE, FunctionParameterAttributeByVal, Kernel), + CASE1(FUNCTION_PARAMETER_ATTRIBUTE, FunctionParameterAttributeSret, Kernel), + CASE1(FUNCTION_PARAMETER_ATTRIBUTE, FunctionParameterAttributeNoAlias, Kernel), + CASE1(FUNCTION_PARAMETER_ATTRIBUTE, FunctionParameterAttributeNoCapture, Kernel), + CASE1(FUNCTION_PARAMETER_ATTRIBUTE, FunctionParameterAttributeNoWrite, Kernel), + CASE1(FUNCTION_PARAMETER_ATTRIBUTE, FunctionParameterAttributeNoReadWrite, Kernel), + // clang-format on + })), ); + +// See SPIR-V Section 3.20 Decoration +INSTANTIATE_TEST_CASE_P( + Decoration, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(DECORATION, DecorationRelaxedPrecision, Shader), + // DecorationSpecId handled below. 
+ CASE1(DECORATION, DecorationBlock, Shader), + CASE1(DECORATION, DecorationBufferBlock, Shader), + CASE1(DECORATION, DecorationRowMajor, Matrix), + CASE1(DECORATION, DecorationColMajor, Matrix), + CASE1(DECORATION, DecorationArrayStride, Shader), + CASE1(DECORATION, DecorationMatrixStride, Matrix), // Bug 15234 + CASE1(DECORATION, DecorationGLSLShared, Shader), + CASE1(DECORATION, DecorationGLSLPacked, Shader), + CASE1(DECORATION, DecorationCPacked, Kernel), + CASE0(DECORATION, DecorationBuiltIn), // Bug 15248 + // Value 12 placeholder + CASE1(DECORATION, DecorationNoPerspective, Shader), + CASE1(DECORATION, DecorationFlat, Shader), + CASE1(DECORATION, DecorationPatch, Tessellation), + CASE1(DECORATION, DecorationCentroid, Shader), + CASE1(DECORATION, DecorationSample, + SampleRateShading), // Bug 15234 + CASE1(DECORATION, DecorationInvariant, Shader), + CASE0(DECORATION, DecorationRestrict), + CASE0(DECORATION, DecorationAliased), + CASE0(DECORATION, DecorationVolatile), + CASE1(DECORATION, DecorationConstant, Kernel), + CASE0(DECORATION, DecorationCoherent), + CASE0(DECORATION, DecorationNonWritable), + CASE0(DECORATION, DecorationNonReadable), + CASE1(DECORATION, DecorationUniform, Shader), + // Value 27 is an intentional gap in the spec numbering. 
+ CASE1(DECORATION, DecorationSaturatedConversion, Kernel), + CASE1(DECORATION, DecorationStream, GeometryStreams), + CASE1(DECORATION, DecorationLocation, Shader), + CASE1(DECORATION, DecorationComponent, Shader), + CASE1(DECORATION, DecorationIndex, Shader), + CASE1(DECORATION, DecorationBinding, Shader), + CASE1(DECORATION, DecorationDescriptorSet, Shader), + CASE1(DECORATION, DecorationOffset, Shader), // Bug 15268 + CASE1(DECORATION, DecorationXfbBuffer, TransformFeedback), + CASE1(DECORATION, DecorationXfbStride, TransformFeedback), + CASE1(DECORATION, DecorationFuncParamAttr, Kernel), + CASE1(DECORATION, DecorationFPFastMathMode, Kernel), + CASE1(DECORATION, DecorationLinkageAttributes, Linkage), + CASE1(DECORATION, DecorationNoContraction, Shader), + CASE1(DECORATION, DecorationInputAttachmentIndex, + InputAttachment), + CASE1(DECORATION, DecorationAlignment, Kernel), + })), ); + +#if 0 +// SpecId has different requirements in v1.0 and v1.1: +INSTANTIATE_TEST_CASE_P(DecorationSpecIdV10, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0), + ValuesIn(std::vector{CASE1( + DECORATION, DecorationSpecId, Shader)})), ); +#endif + +INSTANTIATE_TEST_CASE_P( + DecorationV11, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE2(DECORATION, DecorationSpecId, Shader, Kernel), + CASE1(DECORATION, DecorationMaxByteOffset, Addresses)})), ); + +// See SPIR-V Section 3.21 BuiltIn +INSTANTIATE_TEST_CASE_P( + BuiltIn, EnumCapabilityTest, + Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + // clang-format off + CASE1(BUILT_IN, BuiltInPosition, Shader), + CASE1(BUILT_IN, BuiltInPointSize, Shader), + // 2 is an intentional gap in the spec numbering. 
+ CASE1(BUILT_IN, BuiltInClipDistance, ClipDistance), // Bug 1407, 15234 + CASE1(BUILT_IN, BuiltInCullDistance, CullDistance), // Bug 1407, 15234 + CASE1(BUILT_IN, BuiltInVertexId, Shader), + CASE1(BUILT_IN, BuiltInInstanceId, Shader), + CASE3(BUILT_IN, BuiltInPrimitiveId, Geometry, Tessellation, + RayTracingNV), + CASE2(BUILT_IN, BuiltInInvocationId, Geometry, Tessellation), + CASE1(BUILT_IN, BuiltInLayer, Geometry), + CASE1(BUILT_IN, BuiltInViewportIndex, MultiViewport), // Bug 15234 + CASE1(BUILT_IN, BuiltInTessLevelOuter, Tessellation), + CASE1(BUILT_IN, BuiltInTessLevelInner, Tessellation), + CASE1(BUILT_IN, BuiltInTessCoord, Tessellation), + CASE1(BUILT_IN, BuiltInPatchVertices, Tessellation), + CASE1(BUILT_IN, BuiltInFragCoord, Shader), + CASE1(BUILT_IN, BuiltInPointCoord, Shader), + CASE1(BUILT_IN, BuiltInFrontFacing, Shader), + CASE1(BUILT_IN, BuiltInSampleId, SampleRateShading), // Bug 15234 + CASE1(BUILT_IN, BuiltInSamplePosition, SampleRateShading), // Bug 15234 + CASE1(BUILT_IN, BuiltInSampleMask, Shader), // Bug 15234, Issue 182 + // Value 21 intentionally missing + CASE1(BUILT_IN, BuiltInFragDepth, Shader), + CASE1(BUILT_IN, BuiltInHelperInvocation, Shader), + CASE0(BUILT_IN, BuiltInNumWorkgroups), + CASE0(BUILT_IN, BuiltInWorkgroupSize), + CASE0(BUILT_IN, BuiltInWorkgroupId), + CASE0(BUILT_IN, BuiltInLocalInvocationId), + CASE0(BUILT_IN, BuiltInGlobalInvocationId), + CASE0(BUILT_IN, BuiltInLocalInvocationIndex), + CASE1(BUILT_IN, BuiltInWorkDim, Kernel), + CASE1(BUILT_IN, BuiltInGlobalSize, Kernel), + CASE1(BUILT_IN, BuiltInEnqueuedWorkgroupSize, Kernel), + CASE1(BUILT_IN, BuiltInGlobalOffset, Kernel), + CASE1(BUILT_IN, BuiltInGlobalLinearId, Kernel), + // Value 35 intentionally missing + CASE2(BUILT_IN, BuiltInSubgroupSize, Kernel, SubgroupBallotKHR), + CASE1(BUILT_IN, BuiltInSubgroupMaxSize, Kernel), + CASE1(BUILT_IN, BuiltInNumSubgroups, Kernel), + CASE1(BUILT_IN, BuiltInNumEnqueuedSubgroups, Kernel), + CASE1(BUILT_IN, BuiltInSubgroupId, Kernel), 
+ CASE2(BUILT_IN, BuiltInSubgroupLocalInvocationId, Kernel, SubgroupBallotKHR), + CASE1(BUILT_IN, BuiltInVertexIndex, Shader), + CASE1(BUILT_IN, BuiltInInstanceIndex, Shader), + // clang-format on + })), ); + +// See SPIR-V Section 3.22 Selection Control +INSTANTIATE_TEST_CASE_P( + SelectionControl, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(SELECTION_CONTROL, SelectionControlMaskNone), + CASE0(SELECTION_CONTROL, SelectionControlFlattenMask), + CASE0(SELECTION_CONTROL, SelectionControlDontFlattenMask), + })), ); + +// See SPIR-V Section 3.23 Loop Control +INSTANTIATE_TEST_CASE_P( + LoopControl, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(LOOP_CONTROL, LoopControlMaskNone), + CASE0(LOOP_CONTROL, LoopControlUnrollMask), + CASE0(LOOP_CONTROL, LoopControlDontUnrollMask), + })), ); + +INSTANTIATE_TEST_CASE_P( + LoopControlV11, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(LOOP_CONTROL, LoopControlDependencyInfiniteMask), + CASE0(LOOP_CONTROL, LoopControlDependencyLengthMask), + })), ); + +// See SPIR-V Section 3.24 Function Control +INSTANTIATE_TEST_CASE_P( + FunctionControl, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(FUNCTION_CONTROL, FunctionControlMaskNone), + CASE0(FUNCTION_CONTROL, FunctionControlInlineMask), + CASE0(FUNCTION_CONTROL, FunctionControlDontInlineMask), + CASE0(FUNCTION_CONTROL, FunctionControlPureMask), + CASE0(FUNCTION_CONTROL, FunctionControlConstMask), + })), ); + +// See SPIR-V Section 3.25 Memory Semantics +INSTANTIATE_TEST_CASE_P( + MemorySemantics, EnumCapabilityTest, + Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(MEMORY_SEMANTICS_ID, MemorySemanticsMaskNone), + CASE0(MEMORY_SEMANTICS_ID, MemorySemanticsAcquireMask), + 
CASE0(MEMORY_SEMANTICS_ID, MemorySemanticsReleaseMask), + CASE0(MEMORY_SEMANTICS_ID, MemorySemanticsAcquireReleaseMask), + CASE0(MEMORY_SEMANTICS_ID, + MemorySemanticsSequentiallyConsistentMask), + CASE1(MEMORY_SEMANTICS_ID, MemorySemanticsUniformMemoryMask, + Shader), + CASE0(MEMORY_SEMANTICS_ID, MemorySemanticsSubgroupMemoryMask), + CASE0(MEMORY_SEMANTICS_ID, MemorySemanticsWorkgroupMemoryMask), + CASE0(MEMORY_SEMANTICS_ID, MemorySemanticsCrossWorkgroupMemoryMask), + CASE1(MEMORY_SEMANTICS_ID, MemorySemanticsAtomicCounterMemoryMask, + AtomicStorage), // Bug 15234 + CASE0(MEMORY_SEMANTICS_ID, MemorySemanticsImageMemoryMask), + })), ); + +// See SPIR-V Section 3.26 Memory Access +INSTANTIATE_TEST_CASE_P( + MemoryAccess, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(OPTIONAL_MEMORY_ACCESS, MemoryAccessMaskNone), + CASE0(OPTIONAL_MEMORY_ACCESS, MemoryAccessVolatileMask), + CASE0(OPTIONAL_MEMORY_ACCESS, MemoryAccessAlignedMask), + CASE0(OPTIONAL_MEMORY_ACCESS, MemoryAccessNontemporalMask), + })), ); + +// See SPIR-V Section 3.27 Scope +INSTANTIATE_TEST_CASE_P( + Scope, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3), + ValuesIn(std::vector{ + CASE0(SCOPE_ID, ScopeCrossDevice), + CASE0(SCOPE_ID, ScopeDevice), + CASE0(SCOPE_ID, ScopeWorkgroup), + CASE0(SCOPE_ID, ScopeSubgroup), + CASE0(SCOPE_ID, ScopeInvocation), + CASE1(SCOPE_ID, ScopeQueueFamilyKHR, VulkanMemoryModelKHR), + })), ); + +// See SPIR-V Section 3.28 Group Operation +INSTANTIATE_TEST_CASE_P( + GroupOperation, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE3(GROUP_OPERATION, GroupOperationReduce, Kernel, + GroupNonUniformArithmetic, GroupNonUniformBallot), + CASE3(GROUP_OPERATION, GroupOperationInclusiveScan, Kernel, + GroupNonUniformArithmetic, GroupNonUniformBallot), + 
CASE3(GROUP_OPERATION, GroupOperationExclusiveScan, Kernel, + GroupNonUniformArithmetic, GroupNonUniformBallot), + })), ); + +// See SPIR-V Section 3.29 Kernel Enqueue Flags +INSTANTIATE_TEST_CASE_P( + KernelEnqueueFlags, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(KERNEL_ENQ_FLAGS, KernelEnqueueFlagsNoWait, Kernel), + CASE1(KERNEL_ENQ_FLAGS, KernelEnqueueFlagsWaitKernel, Kernel), + CASE1(KERNEL_ENQ_FLAGS, KernelEnqueueFlagsWaitWorkGroup, + Kernel), + })), ); + +// See SPIR-V Section 3.30 Kernel Profiling Info +INSTANTIATE_TEST_CASE_P( + KernelProfilingInfo, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE0(KERNEL_PROFILING_INFO, KernelProfilingInfoMaskNone), + CASE1(KERNEL_PROFILING_INFO, KernelProfilingInfoCmdExecTimeMask, + Kernel), + })), ); + +// See SPIR-V Section 3.31 Capability +INSTANTIATE_TEST_CASE_P( + CapabilityDependsOn, EnumCapabilityTest, + Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + // clang-format off + CASE0(CAPABILITY, CapabilityMatrix), + CASE1(CAPABILITY, CapabilityShader, Matrix), + CASE1(CAPABILITY, CapabilityGeometry, Shader), + CASE1(CAPABILITY, CapabilityTessellation, Shader), + CASE0(CAPABILITY, CapabilityAddresses), + CASE0(CAPABILITY, CapabilityLinkage), + CASE0(CAPABILITY, CapabilityKernel), + CASE1(CAPABILITY, CapabilityVector16, Kernel), + CASE1(CAPABILITY, CapabilityFloat16Buffer, Kernel), + CASE0(CAPABILITY, CapabilityFloat16), // Bug 15234 + CASE0(CAPABILITY, CapabilityFloat64), + CASE0(CAPABILITY, CapabilityInt64), + CASE1(CAPABILITY, CapabilityInt64Atomics, Int64), + CASE1(CAPABILITY, CapabilityImageBasic, Kernel), + CASE1(CAPABILITY, CapabilityImageReadWrite, ImageBasic), + CASE1(CAPABILITY, CapabilityImageMipmap, ImageBasic), + // Value 16 intentionally missing. 
+ CASE1(CAPABILITY, CapabilityPipes, Kernel), + CASE0(CAPABILITY, CapabilityGroups), + CASE1(CAPABILITY, CapabilityDeviceEnqueue, Kernel), + CASE1(CAPABILITY, CapabilityLiteralSampler, Kernel), + CASE1(CAPABILITY, CapabilityAtomicStorage, Shader), + CASE0(CAPABILITY, CapabilityInt16), + CASE1(CAPABILITY, CapabilityTessellationPointSize, Tessellation), + CASE1(CAPABILITY, CapabilityGeometryPointSize, Geometry), + CASE1(CAPABILITY, CapabilityImageGatherExtended, Shader), + // Value 26 intentionally missing. + CASE1(CAPABILITY, CapabilityStorageImageMultisample, Shader), + CASE1(CAPABILITY, CapabilityUniformBufferArrayDynamicIndexing, Shader), + CASE1(CAPABILITY, CapabilitySampledImageArrayDynamicIndexing, Shader), + CASE1(CAPABILITY, CapabilityStorageBufferArrayDynamicIndexing, Shader), + CASE1(CAPABILITY, CapabilityStorageImageArrayDynamicIndexing, Shader), + CASE1(CAPABILITY, CapabilityClipDistance, Shader), + CASE1(CAPABILITY, CapabilityCullDistance, Shader), + CASE1(CAPABILITY, CapabilityImageCubeArray, SampledCubeArray), + CASE1(CAPABILITY, CapabilitySampleRateShading, Shader), + CASE1(CAPABILITY, CapabilityImageRect, SampledRect), + CASE1(CAPABILITY, CapabilitySampledRect, Shader), + CASE1(CAPABILITY, CapabilityGenericPointer, Addresses), + CASE0(CAPABILITY, CapabilityInt8), + CASE1(CAPABILITY, CapabilityInputAttachment, Shader), + CASE1(CAPABILITY, CapabilitySparseResidency, Shader), + CASE1(CAPABILITY, CapabilityMinLod, Shader), + CASE1(CAPABILITY, CapabilityImage1D, Sampled1D), + CASE1(CAPABILITY, CapabilitySampledCubeArray, Shader), + CASE1(CAPABILITY, CapabilityImageBuffer, SampledBuffer), + CASE1(CAPABILITY, CapabilityImageMSArray, Shader), + CASE1(CAPABILITY, CapabilityStorageImageExtendedFormats, Shader), + CASE1(CAPABILITY, CapabilityImageQuery, Shader), + CASE1(CAPABILITY, CapabilityDerivativeControl, Shader), + CASE1(CAPABILITY, CapabilityInterpolationFunction, Shader), + CASE1(CAPABILITY, CapabilityTransformFeedback, Shader), + CASE1(CAPABILITY, 
CapabilityGeometryStreams, Geometry), + CASE1(CAPABILITY, CapabilityStorageImageReadWithoutFormat, Shader), + CASE1(CAPABILITY, CapabilityStorageImageWriteWithoutFormat, Shader), + CASE1(CAPABILITY, CapabilityMultiViewport, Geometry), + // clang-format on + })), ); + +INSTANTIATE_TEST_CASE_P( + CapabilityDependsOnV11, EnumCapabilityTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector{ + CASE1(CAPABILITY, CapabilitySubgroupDispatch, DeviceEnqueue), + CASE1(CAPABILITY, CapabilityNamedBarrier, Kernel), + CASE1(CAPABILITY, CapabilityPipeStorage, Pipes), + })), ); + +#undef CASE0 +#undef CASE1 +#undef CASE2 + +} // namespace +} // namespace spvtools diff --git a/test/operand_pattern_test.cpp b/test/operand_pattern_test.cpp new file mode 100644 index 000000000..d2d92a01d --- /dev/null +++ b/test/operand_pattern_test.cpp @@ -0,0 +1,270 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/operand.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using ::testing::Eq; + +TEST(OperandPattern, InitiallyEmpty) { + spv_operand_pattern_t empty; + EXPECT_THAT(empty, Eq(spv_operand_pattern_t{})); + EXPECT_EQ(0u, empty.size()); + EXPECT_TRUE(empty.empty()); +} + +TEST(OperandPattern, PushBacksAreOnTheRight) { + spv_operand_pattern_t pattern; + + pattern.push_back(SPV_OPERAND_TYPE_ID); + EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_ID})); + EXPECT_EQ(1u, pattern.size()); + EXPECT_TRUE(!pattern.empty()); + EXPECT_EQ(SPV_OPERAND_TYPE_ID, pattern.back()); + + pattern.push_back(SPV_OPERAND_TYPE_NONE); + EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_ID, + SPV_OPERAND_TYPE_NONE})); + EXPECT_EQ(2u, pattern.size()); + EXPECT_TRUE(!pattern.empty()); + EXPECT_EQ(SPV_OPERAND_TYPE_NONE, pattern.back()); +} + +TEST(OperandPattern, PopBacksAreOnTheRight) { + spv_operand_pattern_t pattern{SPV_OPERAND_TYPE_ID, + SPV_OPERAND_TYPE_LITERAL_INTEGER}; + + pattern.pop_back(); + EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_ID})); + + pattern.pop_back(); + EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{})); +} + +// A test case for typed mask expansion +struct MaskExpansionCase { + spv_operand_type_t type; + uint32_t mask; + spv_operand_pattern_t initial; + spv_operand_pattern_t expected; +}; + +using MaskExpansionTest = ::testing::TestWithParam; + +TEST_P(MaskExpansionTest, Sample) { + spv_operand_table operandTable = nullptr; + auto env = SPV_ENV_UNIVERSAL_1_0; + ASSERT_EQ(SPV_SUCCESS, spvOperandTableGet(&operandTable, env)); + + spv_operand_pattern_t pattern(GetParam().initial); + spvPushOperandTypesForMask(env, operandTable, GetParam().type, + GetParam().mask, &pattern); + EXPECT_THAT(pattern, Eq(GetParam().expected)); +} + +// These macros let us write non-trivial examples without too much text. 
+#define PREFIX0 SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_NONE +#define PREFIX1 \ + SPV_OPERAND_TYPE_STORAGE_CLASS, SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE, \ + SPV_OPERAND_TYPE_ID +INSTANTIATE_TEST_CASE_P( + OperandPattern, MaskExpansionTest, + ::testing::ValuesIn(std::vector{ + // No bits means no change. + {SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, 0, {PREFIX0}, {PREFIX0}}, + // Unknown bits means no change. Use all bits that aren't in the + // grammar. + // The last mask enum is 0x20 + {SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, + 0xffffffc0, + {PREFIX1}, + {PREFIX1}}, + // Volatile has no operands. + {SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, + SpvMemoryAccessVolatileMask, + {PREFIX0}, + {PREFIX0}}, + // Aligned has one literal number operand. + {SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, + SpvMemoryAccessAlignedMask, + {PREFIX1}, + {PREFIX1, SPV_OPERAND_TYPE_LITERAL_INTEGER}}, + // Volatile with Aligned still has just one literal number operand. + {SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, + SpvMemoryAccessVolatileMask | SpvMemoryAccessAlignedMask, + {PREFIX1}, + {PREFIX1, SPV_OPERAND_TYPE_LITERAL_INTEGER}}, + }), ); +#undef PREFIX0 +#undef PREFIX1 + +// Returns a vector of all operand types that can be used in a pattern. 
+std::vector allOperandTypes() { + std::vector result; + for (int i = 0; i < SPV_OPERAND_TYPE_NUM_OPERAND_TYPES; i++) { + result.push_back(spv_operand_type_t(i)); + } + return result; +} + +using MatchableOperandExpansionTest = + ::testing::TestWithParam; + +TEST_P(MatchableOperandExpansionTest, MatchableOperandsDontExpand) { + const spv_operand_type_t type = GetParam(); + if (!spvOperandIsVariable(type)) { + spv_operand_pattern_t pattern; + const bool did_expand = spvExpandOperandSequenceOnce(type, &pattern); + EXPECT_FALSE(did_expand); + EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{})); + } +} + +INSTANTIATE_TEST_CASE_P(MatchableOperandExpansion, + MatchableOperandExpansionTest, + ::testing::ValuesIn(allOperandTypes()), ); + +using VariableOperandExpansionTest = + ::testing::TestWithParam; + +TEST_P(VariableOperandExpansionTest, NonMatchableOperandsExpand) { + const spv_operand_type_t type = GetParam(); + if (spvOperandIsVariable(type)) { + spv_operand_pattern_t pattern; + const bool did_expand = spvExpandOperandSequenceOnce(type, &pattern); + EXPECT_TRUE(did_expand); + EXPECT_FALSE(pattern.empty()); + // For the existing rules, the first expansion of a zero-or-more operand + // type yields a matchable operand type. This isn't strictly necessary. + EXPECT_FALSE(spvOperandIsVariable(pattern.back())); + } +} + +INSTANTIATE_TEST_CASE_P(NonMatchableOperandExpansion, + VariableOperandExpansionTest, + ::testing::ValuesIn(allOperandTypes()), ); + +TEST(AlternatePatternFollowingImmediate, Empty) { + EXPECT_THAT(spvAlternatePatternFollowingImmediate({}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV})); +} + +TEST(AlternatePatternFollowingImmediate, SingleElement) { + // Spot-check a random selection of types. 
+ EXPECT_THAT(spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV})); + EXPECT_THAT( + spvAlternatePatternFollowingImmediate({SPV_OPERAND_TYPE_CAPABILITY}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV})); + EXPECT_THAT( + spvAlternatePatternFollowingImmediate({SPV_OPERAND_TYPE_LOOP_CONTROL}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV})); + EXPECT_THAT(spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV})); + EXPECT_THAT(spvAlternatePatternFollowingImmediate({SPV_OPERAND_TYPE_ID}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV})); +} + +TEST(AlternatePatternFollowingImmediate, SingleResultId) { + EXPECT_THAT( + spvAlternatePatternFollowingImmediate({SPV_OPERAND_TYPE_RESULT_ID}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_RESULT_ID})); +} + +TEST(AlternatePatternFollowingImmediate, MultipleNonResultIds) { + EXPECT_THAT( + spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER, + SPV_OPERAND_TYPE_CAPABILITY, SPV_OPERAND_TYPE_LOOP_CONTROL, + SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER, SPV_OPERAND_TYPE_ID}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV})); +} + +TEST(AlternatePatternFollowingImmediate, ResultIdFront) { + EXPECT_THAT(spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_RESULT_ID, + SPV_OPERAND_TYPE_OPTIONAL_CIV})); + EXPECT_THAT( + spvAlternatePatternFollowingImmediate({SPV_OPERAND_TYPE_RESULT_ID, + SPV_OPERAND_TYPE_FP_ROUNDING_MODE, + SPV_OPERAND_TYPE_ID}), + Eq(spv_operand_pattern_t{ + SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_RESULT_ID, + SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV})); + EXPECT_THAT( + 
spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_DIMENSIONALITY, + SPV_OPERAND_TYPE_LINKAGE_TYPE, + SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE, + SPV_OPERAND_TYPE_FP_ROUNDING_MODE, SPV_OPERAND_TYPE_ID, + SPV_OPERAND_TYPE_VARIABLE_ID}), + Eq(spv_operand_pattern_t{ + SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_RESULT_ID, + SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV})); +} + +TEST(AlternatePatternFollowingImmediate, ResultIdMiddle) { + EXPECT_THAT(spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_FP_ROUNDING_MODE, + SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_RESULT_ID, + SPV_OPERAND_TYPE_OPTIONAL_CIV})); + EXPECT_THAT( + spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_DIMENSIONALITY, SPV_OPERAND_TYPE_LINKAGE_TYPE, + SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE, + SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_FP_ROUNDING_MODE, + SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_VARIABLE_ID}), + Eq(spv_operand_pattern_t{ + SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_RESULT_ID, + SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_OPTIONAL_CIV})); +} + +TEST(AlternatePatternFollowingImmediate, ResultIdBack) { + EXPECT_THAT(spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_RESULT_ID})); + EXPECT_THAT(spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_FP_ROUNDING_MODE, SPV_OPERAND_TYPE_ID, + SPV_OPERAND_TYPE_RESULT_ID}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_RESULT_ID})); + EXPECT_THAT( + spvAlternatePatternFollowingImmediate( + {SPV_OPERAND_TYPE_DIMENSIONALITY, SPV_OPERAND_TYPE_LINKAGE_TYPE, + 
SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE, + SPV_OPERAND_TYPE_FP_ROUNDING_MODE, SPV_OPERAND_TYPE_ID, + SPV_OPERAND_TYPE_VARIABLE_ID, SPV_OPERAND_TYPE_RESULT_ID}), + Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV, + SPV_OPERAND_TYPE_RESULT_ID})); +} + +} // namespace +} // namespace spvtools diff --git a/test/operand_test.cpp b/test/operand_test.cpp new file mode 100644 index 000000000..08522c323 --- /dev/null +++ b/test/operand_test.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using GetTargetTest = ::testing::TestWithParam; +using ::testing::ValuesIn; + +TEST_P(GetTargetTest, Default) { + spv_operand_table table; + ASSERT_EQ(SPV_SUCCESS, spvOperandTableGet(&table, GetParam())); + ASSERT_NE(0u, table->count); + ASSERT_NE(nullptr, table->types); +} + +TEST_P(GetTargetTest, InvalidPointerTable) { + ASSERT_EQ(SPV_ERROR_INVALID_POINTER, spvOperandTableGet(nullptr, GetParam())); +} + +INSTANTIATE_TEST_CASE_P(OperandTableGet, GetTargetTest, + ValuesIn(std::vector{ + SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0}), ); + +TEST(OperandString, AllAreDefinedExceptVariable) { + // None has no string, so don't test it. + EXPECT_EQ(0u, SPV_OPERAND_TYPE_NONE); + // Start testing at enum with value 1, skipping None. 
+ for (int i = 1; i < int(SPV_OPERAND_TYPE_FIRST_VARIABLE_TYPE); i++) { + EXPECT_NE(nullptr, spvOperandTypeStr(static_cast(i))) + << " Operand type " << i; + } +} + +TEST(OperandIsConcreteMask, Sample) { + // Check a few operand types preceding the concrete mask types. + EXPECT_FALSE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_NONE)); + EXPECT_FALSE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_ID)); + EXPECT_FALSE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_LITERAL_INTEGER)); + EXPECT_FALSE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_CAPABILITY)); + + // Check all the concrete mask operand types. + EXPECT_TRUE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_IMAGE)); + EXPECT_TRUE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_FP_FAST_MATH_MODE)); + EXPECT_TRUE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_SELECTION_CONTROL)); + EXPECT_TRUE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_LOOP_CONTROL)); + EXPECT_TRUE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_FUNCTION_CONTROL)); + EXPECT_TRUE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_MEMORY_ACCESS)); + + // Check a few operand types after the concrete mask types, including the + // optional forms for Image and MemoryAccess. + EXPECT_FALSE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_OPTIONAL_ID)); + EXPECT_FALSE(spvOperandIsConcreteMask(SPV_OPERAND_TYPE_OPTIONAL_IMAGE)); + EXPECT_FALSE( + spvOperandIsConcreteMask(SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)); +} + +} // namespace +} // namespace spvtools diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt new file mode 100644 index 000000000..631538515 --- /dev/null +++ b/test/opt/CMakeLists.txt @@ -0,0 +1,92 @@ +# Copyright (c) 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(dominator_tree) +add_subdirectory(loop_optimizations) + +add_spvtools_unittest(TARGET opt + SRCS aggressive_dead_code_elim_test.cpp + assembly_builder_test.cpp + block_merge_test.cpp + ccp_test.cpp + cfg_cleanup_test.cpp + code_sink_test.cpp + combine_access_chains_test.cpp + common_uniform_elim_test.cpp + compact_ids_test.cpp + constant_manager_test.cpp + copy_prop_array_test.cpp + dead_branch_elim_test.cpp + dead_insert_elim_test.cpp + dead_variable_elim_test.cpp + decoration_manager_test.cpp + def_use_test.cpp + eliminate_dead_const_test.cpp + eliminate_dead_functions_test.cpp + feature_manager_test.cpp + flatten_decoration_test.cpp + fold_spec_const_op_composite_test.cpp + fold_test.cpp + freeze_spec_const_test.cpp + function_test.cpp + if_conversion_test.cpp + inline_opaque_test.cpp + inline_test.cpp + insert_extract_elim_test.cpp + inst_bindless_check_test.cpp + instruction_list_test.cpp + instruction_test.cpp + ir_builder.cpp + ir_context_test.cpp + ir_loader_test.cpp + iterator_test.cpp + line_debug_info_test.cpp + local_access_chain_convert_test.cpp + local_redundancy_elimination_test.cpp + local_single_block_elim.cpp + local_single_store_elim_test.cpp + local_ssa_elim_test.cpp + module_test.cpp + module_utils.h + optimizer_test.cpp + pass_manager_test.cpp + pass_merge_return_test.cpp + pass_remove_duplicates_test.cpp + pass_utils.cpp + private_to_local_test.cpp + process_lines_test.cpp + propagator_test.cpp + reduce_load_size_test.cpp + redundancy_elimination_test.cpp + register_liveness.cpp + replace_invalid_opc_test.cpp + 
scalar_analysis.cpp + scalar_replacement_test.cpp + set_spec_const_default_value_test.cpp + simplification_test.cpp + strength_reduction_test.cpp + strip_debug_info_test.cpp + strip_reflect_info_test.cpp + struct_cfg_analysis_test.cpp + type_manager_test.cpp + types_test.cpp + unify_const_test.cpp + upgrade_memory_model_test.cpp + utils_test.cpp pass_utils.cpp + value_table_test.cpp + vector_dce_test.cpp + workaround1209_test.cpp + LIBS SPIRV-Tools-opt + PCH_FILE pch_test_opt +) diff --git a/test/opt/aggressive_dead_code_elim_test.cpp b/test/opt/aggressive_dead_code_elim_test.cpp new file mode 100644 index 000000000..e58a76f71 --- /dev/null +++ b/test/opt/aggressive_dead_code_elim_test.cpp @@ -0,0 +1,6225 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using AggressiveDCETest = PassTest<::testing::Test>; + +TEST_F(AggressiveDCETest, EliminateExtendedInst) { + // #version 140 + // + // in vec4 BaseColor; + // in vec4 Dead; + // + // void main() + // { + // vec4 v = BaseColor; + // vec4 dv = sqrt(Dead); + // gl_FragColor = v; + // } + + const std::string predefs1 = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +)"; + + const std::string names_before = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string names_after = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string predefs2 = + R"(%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %9 +%15 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%dv = OpVariable %_ptr_Function_v4float Function +%16 = OpLoad %v4float %BaseColor +OpStore %v %16 +%17 = OpLoad %v4float %Dead +%18 = OpExtInst %v4float %1 Sqrt %17 +OpStore %dv %18 +%19 = OpLoad %v4float %v +OpStore %gl_FragColor %19 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + 
R"(%main = OpFunction %void None %9 +%15 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%16 = OpLoad %v4float %BaseColor +OpStore %v %16 +%19 = OpLoad %v4float %v +OpStore %gl_FragColor %19 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs1 + names_before + predefs2 + func_before, + predefs1 + names_after + predefs2 + func_after, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateFrexp) { + // Note: SPIR-V hand-edited to utilize Frexp + // + // #version 450 + // + // in vec4 BaseColor; + // in vec4 Dead; + // out vec4 Color; + // out ivec4 iv2; + // + // void main() + // { + // vec4 v = BaseColor; + // vec4 dv = frexp(Dead, iv2); + // Color = v; + // } + + const std::string predefs1 = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %iv2 %Color +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +)"; + + const std::string names_before = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %iv2 "iv2" +OpName %ResType "ResType" +OpName %Color "Color" +)"; + + const std::string names_after = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Dead "Dead" +OpName %iv2 "iv2" +OpName %Color "Color" +)"; + + const std::string predefs2_before = + R"(%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%int = OpTypeInt 32 1 +%v4int = OpTypeVector %int 4 +%_ptr_Output_v4int = OpTypePointer Output %v4int +%iv2 = OpVariable %_ptr_Output_v4int Output +%ResType = OpTypeStruct %v4float %v4int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%Color = OpVariable %_ptr_Output_v4float Output +)"; 
+ + const std::string predefs2_after = + R"(%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%int = OpTypeInt 32 1 +%v4int = OpTypeVector %int 4 +%_ptr_Output_v4int = OpTypePointer Output %v4int +%iv2 = OpVariable %_ptr_Output_v4int Output +%_ptr_Output_v4float = OpTypePointer Output %v4float +%Color = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %11 +%20 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%dv = OpVariable %_ptr_Function_v4float Function +%21 = OpLoad %v4float %BaseColor +OpStore %v %21 +%22 = OpLoad %v4float %Dead +%23 = OpExtInst %v4float %1 Frexp %22 %iv2 +OpStore %dv %23 +%24 = OpLoad %v4float %v +OpStore %Color %24 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %11 +%20 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%21 = OpLoad %v4float %BaseColor +OpStore %v %21 +%22 = OpLoad %v4float %Dead +%23 = OpExtInst %v4float %1 Frexp %22 %iv2 +%24 = OpLoad %v4float %v +OpStore %Color %24 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs1 + names_before + predefs2_before + func_before, + predefs1 + names_after + predefs2_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, EliminateDecorate) { + // Note: The SPIR-V was hand-edited to add the OpDecorate + // + // #version 140 + // + // in vec4 BaseColor; + // in vec4 Dead; + // + // void main() + // { + // vec4 v = BaseColor; + // vec4 dv = Dead * 0.5; + // gl_FragColor = v; + // } + + const std::string predefs1 = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor 
+OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +)"; + + const std::string names_before = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %8 RelaxedPrecision +)"; + + const std::string names_after = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string predefs2_before = + R"(%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%float_0_5 = OpConstant %float 0.5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string predefs2_after = + R"(%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %10 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%dv = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +%19 = OpLoad %v4float %Dead +%8 = OpVectorTimesScalar %v4float %19 %float_0_5 +OpStore %dv %8 +%20 = OpLoad %v4float %v +OpStore %gl_FragColor %20 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %10 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = 
OpLoad %v4float %BaseColor +OpStore %v %18 +%20 = OpLoad %v4float %v +OpStore %gl_FragColor %20 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs1 + names_before + predefs2_before + func_before, + predefs1 + names_after + predefs2_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, Simple) { + // #version 140 + // + // in vec4 BaseColor; + // in vec4 Dead; + // + // void main() + // { + // vec4 v = BaseColor; + // vec4 dv = Dead; + // gl_FragColor = v; + // } + + const std::string predefs1 = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +)"; + + const std::string names_before = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string names_after = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string predefs2 = + R"(%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %9 +%15 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%dv = OpVariable %_ptr_Function_v4float Function +%16 = OpLoad %v4float %BaseColor +OpStore %v %16 +%17 = OpLoad %v4float %Dead +OpStore %dv %17 +%18 = OpLoad %v4float %v +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction 
%void None %9 +%15 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%16 = OpLoad %v4float %BaseColor +OpStore %v %16 +%18 = OpLoad %v4float %v +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs1 + names_before + predefs2 + func_before, + predefs1 + names_after + predefs2 + func_after, true, true); +} + +TEST_F(AggressiveDCETest, OptWhitelistExtension) { + // #version 140 + // + // in vec4 BaseColor; + // in vec4 Dead; + // + // void main() + // { + // vec4 v = BaseColor; + // vec4 dv = Dead; + // gl_FragColor = v; + // } + + const std::string predefs1 = + R"(OpCapability Shader +OpExtension "SPV_AMD_gpu_shader_int16" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +)"; + + const std::string names_before = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string names_after = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string predefs2 = + R"(%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %9 +%15 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%dv = OpVariable %_ptr_Function_v4float Function +%16 = OpLoad %v4float %BaseColor +OpStore %v %16 +%17 = OpLoad %v4float %Dead +OpStore %dv %17 +%18 = OpLoad 
%v4float %v +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %9 +%15 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%16 = OpLoad %v4float %BaseColor +OpStore %v %16 +%18 = OpLoad %v4float %v +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs1 + names_before + predefs2 + func_before, + predefs1 + names_after + predefs2 + func_after, true, true); +} + +TEST_F(AggressiveDCETest, NoOptBlacklistExtension) { + // #version 140 + // + // in vec4 BaseColor; + // in vec4 Dead; + // + // void main() + // { + // vec4 v = BaseColor; + // vec4 dv = Dead; + // gl_FragColor = v; + // } + + const std::string assembly = + R"(OpCapability Shader +OpExtension "SPV_KHR_variable_pointers" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %9 +%15 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%dv = OpVariable %_ptr_Function_v4float Function +%16 = OpLoad %v4float %BaseColor +OpStore %v %16 +%17 = OpLoad %v4float %Dead +OpStore %dv %17 +%18 = OpLoad %v4float %v +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, ElimWithCall) 
{ + // This demonstrates that "dead" function calls are not eliminated. + // Also demonstrates that DCE will happen in presence of function call. + // #version 140 + // in vec4 i1; + // in vec4 i2; + // + // void nothing(vec4 v) + // { + // } + // + // void main() + // { + // vec4 v1 = i1; + // vec4 v2 = i2; + // nothing(v1); + // gl_FragColor = vec4(0.0); + // } + + const std::string defs_before = + R"( OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %i1 %i2 %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %nothing_vf4_ "nothing(vf4;" +OpName %v "v" +OpName %v1 "v1" +OpName %i1 "i1" +OpName %v2 "v2" +OpName %i2 "i2" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%16 = OpTypeFunction %void %_ptr_Function_v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%i1 = OpVariable %_ptr_Input_v4float Input +%i2 = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +%20 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +)"; + + const std::string defs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %i1 %i2 %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %nothing_vf4_ "nothing(vf4;" +OpName %v "v" +OpName %v1 "v1" +OpName %i1 "i1" +OpName %i2 "i2" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%16 = 
OpTypeFunction %void %_ptr_Function_v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%i1 = OpVariable %_ptr_Input_v4float Input +%i2 = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +%20 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %12 +%21 = OpLabel +%v1 = OpVariable %_ptr_Function_v4float Function +%v2 = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %i1 +OpStore %v1 %22 +%23 = OpLoad %v4float %i2 +OpStore %v2 %23 +%24 = OpLoad %v4float %v1 +OpStore %param %24 +%25 = OpFunctionCall %void %nothing_vf4_ %param +OpStore %gl_FragColor %20 +OpReturn +OpFunctionEnd +%nothing_vf4_ = OpFunction %void None %16 +%v = OpFunctionParameter %_ptr_Function_v4float +%26 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %12 +%21 = OpLabel +%v1 = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %i1 +OpStore %v1 %22 +%24 = OpLoad %v4float %v1 +OpStore %param %24 +%25 = OpFunctionCall %void %nothing_vf4_ %param +OpStore %gl_FragColor %20 +OpReturn +OpFunctionEnd +%nothing_vf4_ = OpFunction %void None %16 +%v = OpFunctionParameter %_ptr_Function_v4float +%26 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(defs_before + func_before, + defs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, NoParamElim) { + // This demonstrates that unused parameters are not eliminated, but + // dead uses of them are. 
+ // #version 140 + // + // in vec4 BaseColor; + // + // vec4 foo(vec4 v1, vec4 v2) + // { + // vec4 t = -v1; + // return v2; + // } + // + // void main() + // { + // vec4 dead; + // gl_FragColor = foo(dead, BaseColor); + // } + + const std::string defs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_vf4_ "foo(vf4;vf4;" +OpName %v1 "v1" +OpName %v2 "v2" +OpName %t "t" +OpName %gl_FragColor "gl_FragColor" +OpName %dead "dead" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %param_0 "param" +%void = OpTypeVoid +%13 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%17 = OpTypeFunction %v4float %_ptr_Function_v4float %_ptr_Function_v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%main = OpFunction %void None %13 +%20 = OpLabel +%dead = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%param_0 = OpVariable %_ptr_Function_v4float Function +%21 = OpLoad %v4float %dead +OpStore %param %21 +%22 = OpLoad %v4float %BaseColor +OpStore %param_0 %22 +%23 = OpFunctionCall %v4float %foo_vf4_vf4_ %param %param_0 +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string defs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_vf4_ "foo(vf4;vf4;" +OpName %v1 "v1" +OpName %v2 "v2" +OpName %gl_FragColor "gl_FragColor" +OpName %dead 
"dead" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %param_0 "param" +%void = OpTypeVoid +%13 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%17 = OpTypeFunction %v4float %_ptr_Function_v4float %_ptr_Function_v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%main = OpFunction %void None %13 +%20 = OpLabel +%dead = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%param_0 = OpVariable %_ptr_Function_v4float Function +%21 = OpLoad %v4float %dead +OpStore %param %21 +%22 = OpLoad %v4float %BaseColor +OpStore %param_0 %22 +%23 = OpFunctionCall %v4float %foo_vf4_vf4_ %param %param_0 +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string func_before = + R"(%foo_vf4_vf4_ = OpFunction %v4float None %17 +%v1 = OpFunctionParameter %_ptr_Function_v4float +%v2 = OpFunctionParameter %_ptr_Function_v4float +%24 = OpLabel +%t = OpVariable %_ptr_Function_v4float Function +%25 = OpLoad %v4float %v1 +%26 = OpFNegate %v4float %25 +OpStore %t %26 +%27 = OpLoad %v4float %v2 +OpReturnValue %27 +OpFunctionEnd +)"; + + const std::string func_after = + R"(%foo_vf4_vf4_ = OpFunction %v4float None %17 +%v1 = OpFunctionParameter %_ptr_Function_v4float +%v2 = OpFunctionParameter %_ptr_Function_v4float +%24 = OpLabel +%27 = OpLoad %v4float %v2 +OpReturnValue %27 +OpFunctionEnd +)"; + + SinglePassRunAndCheck(defs_before + func_before, + defs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, ElimOpaque) { + // SPIR-V not representable from GLSL; not generatable from HLSL + // for the moment. 
+ + const std::string defs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %outColor %texCoords +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpMemberName %S_t 2 "smp" +OpName %outColor "outColor" +OpName %sampler15 "sampler15" +OpName %s0 "s0" +OpName %texCoords "texCoords" +OpDecorate %sampler15 DescriptorSet 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outColor = OpVariable %_ptr_Output_v4float Output +%14 = OpTypeImage %float 2D 0 0 0 1 Unknown +%15 = OpTypeSampledImage %14 +%S_t = OpTypeStruct %v2float %v2float %15 +%_ptr_Function_S_t = OpTypePointer Function %S_t +%17 = OpTypeFunction %void %_ptr_Function_S_t +%_ptr_UniformConstant_15 = OpTypePointer UniformConstant %15 +%_ptr_Function_15 = OpTypePointer Function %15 +%sampler15 = OpVariable %_ptr_UniformConstant_15 UniformConstant +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%texCoords = OpVariable %_ptr_Input_v2float Input +)"; + + const std::string defs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %outColor %texCoords +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %outColor "outColor" +OpName %sampler15 "sampler15" +OpName %texCoords "texCoords" +OpDecorate %sampler15 DescriptorSet 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outColor = 
OpVariable %_ptr_Output_v4float Output +%14 = OpTypeImage %float 2D 0 0 0 1 Unknown +%15 = OpTypeSampledImage %14 +%_ptr_UniformConstant_15 = OpTypePointer UniformConstant %15 +%sampler15 = OpVariable %_ptr_UniformConstant_15 UniformConstant +%_ptr_Input_v2float = OpTypePointer Input %v2float +%texCoords = OpVariable %_ptr_Input_v2float Input +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %9 +%25 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%26 = OpLoad %v2float %texCoords +%27 = OpLoad %S_t %s0 +%28 = OpCompositeInsert %S_t %26 %27 0 +%29 = OpLoad %15 %sampler15 +%30 = OpCompositeInsert %S_t %29 %28 2 +OpStore %s0 %30 +%31 = OpImageSampleImplicitLod %v4float %29 %26 +OpStore %outColor %31 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %9 +%25 = OpLabel +%26 = OpLoad %v2float %texCoords +%29 = OpLoad %15 %sampler15 +%31 = OpImageSampleImplicitLod %v4float %29 %26 +OpStore %outColor %31 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(defs_before + func_before, + defs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, NoParamStoreElim) { + // Should not eliminate stores to params + // + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void foo(in vec4 v1, out vec4 v2) + // { + // v2 = -v1; + // } + // + // void main() + // { + // foo(BaseColor, OutColor); + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %foo_vf4_vf4_ "foo(vf4;vf4;" +OpName %v1 "v1" +OpName %v2 "v2" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpName %param "param" +OpName %param_0 "param" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid 
+%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%15 = OpTypeFunction %void %_ptr_Function_v4float %_ptr_Function_v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %11 +%18 = OpLabel +%param = OpVariable %_ptr_Function_v4float Function +%param_0 = OpVariable %_ptr_Function_v4float Function +%19 = OpLoad %v4float %BaseColor +OpStore %param %19 +%20 = OpFunctionCall %void %foo_vf4_vf4_ %param %param_0 +%21 = OpLoad %v4float %param_0 +OpStore %OutColor %21 +OpReturn +OpFunctionEnd +%foo_vf4_vf4_ = OpFunction %void None %15 +%v1 = OpFunctionParameter %_ptr_Function_v4float +%v2 = OpFunctionParameter %_ptr_Function_v4float +%22 = OpLabel +%23 = OpLoad %v4float %v1 +%24 = OpFNegate %v4float %23 +OpStore %v2 %24 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, PrivateStoreElimInEntryNoCalls) { + // Eliminate stores to private in entry point with no calls + // Note: Not legal GLSL + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 1) in vec4 Dead; + // layout(location = 0) out vec4 OutColor; + // + // private vec4 dv; + // + // void main() + // { + // vec4 v = BaseColor; + // dv = Dead; + // OutColor = v; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %Dead Location 1 +OpDecorate %OutColor Location 0 
+%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Private_v4float = OpTypePointer Private %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%dv = OpVariable %_ptr_Private_v4float Private +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Dead "Dead" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %Dead Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string main_before = + R"(%main = OpFunction %void None %9 +%16 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%17 = OpLoad %v4float %BaseColor +OpStore %v %17 +%18 = OpLoad %v4float %Dead +OpStore %dv %18 +%19 = OpLoad %v4float %v +%20 = OpFNegate %v4float %19 +OpStore %OutColor %20 +OpReturn +OpFunctionEnd +)"; + + const std::string main_after = + R"(%main = OpFunction %void None %9 +%16 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%17 = OpLoad %v4float %BaseColor +OpStore %v %17 +%19 = OpLoad %v4float %v +%20 = OpFNegate %v4float 
%19 +OpStore %OutColor %20 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + main_before, predefs_after + main_after, true, true); +} + +TEST_F(AggressiveDCETest, NoPrivateStoreElimIfLoad) { + // Should not eliminate stores to private when there is a load + // Note: Not legal GLSL + // + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // private vec4 pv; + // + // void main() + // { + // pv = BaseColor; + // OutColor = -pv; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %pv "pv" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Private_v4float = OpTypePointer Private %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%pv = OpVariable %_ptr_Private_v4float Private +%main = OpFunction %void None %7 +%13 = OpLabel +%14 = OpLoad %v4float %BaseColor +OpStore %pv %14 +%15 = OpLoad %v4float %pv +%16 = OpFNegate %v4float %15 +OpStore %OutColor %16 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoPrivateStoreElimWithCall) { + // Should not eliminate stores to private when function contains call + // Note: Not legal GLSL + // + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // private vec4 v1; + // + // void foo() + // { + // OutColor = -v1; + // } + // + // void 
main() + // { + // v1 = BaseColor; + // foo(); + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %OutColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %foo_ "foo(" +OpName %OutColor "OutColor" +OpName %v1 "v1" +OpName %BaseColor "BaseColor" +OpDecorate %OutColor Location 0 +OpDecorate %BaseColor Location 0 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Private_v4float = OpTypePointer Private %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%v1 = OpVariable %_ptr_Private_v4float Private +%BaseColor = OpVariable %_ptr_Input_v4float Input +%main = OpFunction %void None %8 +%14 = OpLabel +%15 = OpLoad %v4float %BaseColor +OpStore %v1 %15 +%16 = OpFunctionCall %void %foo_ +OpReturn +OpFunctionEnd +%foo_ = OpFunction %void None %8 +%17 = OpLabel +%18 = OpLoad %v4float %v1 +%19 = OpFNegate %v4float %18 +OpStore %OutColor %19 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoPrivateStoreElimInNonEntry) { + // Should not eliminate stores to private when function is not entry point + // Note: Not legal GLSL + // + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // private vec4 v1; + // + // void foo() + // { + // v1 = BaseColor; + // } + // + // void main() + // { + // foo(); + // OutColor = -v1; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %foo_ "foo(" 
+OpName %v1 "v1" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Private_v4float = OpTypePointer Private %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%v1 = OpVariable %_ptr_Private_v4float Private +%OutColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %8 +%14 = OpLabel +%15 = OpFunctionCall %void %foo_ +%16 = OpLoad %v4float %v1 +%17 = OpFNegate %v4float %16 +OpStore %OutColor %17 +OpReturn +OpFunctionEnd +%foo_ = OpFunction %void None %8 +%18 = OpLabel +%19 = OpLoad %v4float %BaseColor +OpStore %v1 %19 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, WorkgroupStoreElimInEntryNoCalls) { + // Eliminate stores to workgroup in entry point with no calls + // Note: Not legal GLSL + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 1) in vec4 Dead; + // layout(location = 0) out vec4 OutColor; + // + // workgroup vec4 dv; + // + // void main() + // { + // vec4 v = BaseColor; + // dv = Dead; + // OutColor = -v; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %dv "dv" +OpName %Dead "Dead" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %Dead Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function 
%v4float +%_ptr_Workgroup_v4float = OpTypePointer Workgroup %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%dv = OpVariable %_ptr_Workgroup_v4float Workgroup +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Dead %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Dead "Dead" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %Dead Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%Dead = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string main_before = + R"(%main = OpFunction %void None %9 +%16 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%17 = OpLoad %v4float %BaseColor +OpStore %v %17 +%18 = OpLoad %v4float %Dead +OpStore %dv %18 +%19 = OpLoad %v4float %v +%20 = OpFNegate %v4float %19 +OpStore %OutColor %20 +OpReturn +OpFunctionEnd +)"; + + const std::string main_after = + R"(%main = OpFunction %void None %9 +%16 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%17 = OpLoad %v4float %BaseColor +OpStore %v %17 +%19 = OpLoad %v4float %v +%20 = OpFNegate %v4float %19 +OpStore %OutColor %20 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + main_before, predefs_after + main_after, 
true, true); +} + +TEST_F(AggressiveDCETest, EliminateDeadIfThenElse) { + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // float d; + // if (BaseColor.x == 0) + // d = BaseColor.y; + // else + // d = BaseColor.z; + // OutColor = vec4(1.0,1.0,1.0,1.0); + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %d "d" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_1 = OpConstant %float 1 +%21 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input 
%v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_1 = OpConstant %float 1 +%21 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %7 +%22 = OpLabel +%d = OpVariable %_ptr_Function_float Function +%23 = OpAccessChain %_ptr_Input_float %BaseColor %uint_0 +%24 = OpLoad %float %23 +%25 = OpFOrdEqual %bool %24 %float_0 +OpSelectionMerge %26 None +OpBranchConditional %25 %27 %28 +%27 = OpLabel +%29 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 +%30 = OpLoad %float %29 +OpStore %d %30 +OpBranch %26 +%28 = OpLabel +%31 = OpAccessChain %_ptr_Input_float %BaseColor %uint_2 +%32 = OpLoad %float %31 +OpStore %d %32 +OpBranch %26 +%26 = OpLabel +OpStore %OutColor %21 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %7 +%22 = OpLabel +OpBranch %26 +%26 = OpLabel +OpStore %OutColor %21 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + func_before, predefs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, EliminateDeadIfThen) { + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // float d; + // if (BaseColor.x == 0) + // d = BaseColor.y; + // OutColor = vec4(1.0,1.0,1.0,1.0); + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %d "d" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector 
%float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_1 = OpConstant %float 1 +%20 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_1 = OpConstant %float 1 +%20 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %7 +%21 = OpLabel +%d = OpVariable %_ptr_Function_float Function +%22 = OpAccessChain %_ptr_Input_float %BaseColor %uint_0 +%23 = OpLoad %float %22 +%24 = OpFOrdEqual %bool %23 %float_0 +OpSelectionMerge %25 None +OpBranchConditional %24 %26 %25 +%26 = OpLabel +%27 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 +%28 = OpLoad %float %27 +OpStore %d %28 +OpBranch %25 +%25 = OpLabel +OpStore %OutColor %20 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %7 +%21 = OpLabel +OpBranch %25 
+%25 = OpLabel +OpStore %OutColor %20 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + func_before, predefs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, EliminateDeadSwitch) { + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 1) in flat int x; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // float d; + // switch (x) { + // case 0: + // d = BaseColor.y; + // } + // OutColor = vec4(1.0,1.0,1.0,1.0); + // } + const std::string before = + R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %x %BaseColor %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %x "x" + OpName %d "d" + OpName %BaseColor "BaseColor" + OpName %OutColor "OutColor" + OpDecorate %x Flat + OpDecorate %x Location 1 + OpDecorate %BaseColor Location 0 + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %x = OpVariable %_ptr_Input_int Input + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %BaseColor = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 +%_ptr_Input_float = OpTypePointer Input %float +%_ptr_Output_v4float = OpTypePointer Output %v4float + %OutColor = OpVariable %_ptr_Output_v4float Output + %float_1 = OpConstant %float 1 + %27 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 + %main = OpFunction %void None %3 + %5 = OpLabel + %d = OpVariable %_ptr_Function_float Function + %9 = OpLoad %int %x + OpSelectionMerge %11 None + OpSwitch %9 %11 0 %10 + %10 = OpLabel + %21 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 + %22 = OpLoad %float %21 + OpStore %d %22 + OpBranch 
%11 + %11 = OpLabel + OpStore %OutColor %27 + OpReturn + OpFunctionEnd)"; + + const std::string after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %x %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %x "x" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %x Flat +OpDecorate %x Location 1 +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%x = OpVariable %_ptr_Input_int Input +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_1 = OpConstant %float 1 +%27 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%main = OpFunction %void None %3 +%5 = OpLabel +OpBranch %11 +%11 = OpLabel +OpStore %OutColor %27 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +TEST_F(AggressiveDCETest, EliminateDeadIfThenElseNested) { + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // float d; + // if (BaseColor.x == 0) + // if (BaseColor.y == 0) + // d = 0.0; + // else + // d = 0.25; + // else + // if (BaseColor.y == 0) + // d = 0.5; + // else + // d = 0.75; + // OutColor = vec4(1.0,1.0,1.0,1.0); + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main 
"main" +OpName %BaseColor "BaseColor" +OpName %d "d" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%uint_1 = OpConstant %uint 1 +%_ptr_Function_float = OpTypePointer Function %float +%float_0_25 = OpConstant %float 0.25 +%float_0_5 = OpConstant %float 0.5 +%float_0_75 = OpConstant %float 0.75 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_1 = OpConstant %float 1 +%23 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_1 = OpConstant %float 1 +%23 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %7 +%24 = OpLabel +%d = OpVariable %_ptr_Function_float Function +%25 = OpAccessChain %_ptr_Input_float %BaseColor %uint_0 +%26 = OpLoad %float %25 +%27 = OpFOrdEqual %bool %26 %float_0 
+OpSelectionMerge %28 None +OpBranchConditional %27 %29 %30 +%29 = OpLabel +%31 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 +%32 = OpLoad %float %31 +%33 = OpFOrdEqual %bool %32 %float_0 +OpSelectionMerge %34 None +OpBranchConditional %33 %35 %36 +%35 = OpLabel +OpStore %d %float_0 +OpBranch %34 +%36 = OpLabel +OpStore %d %float_0_25 +OpBranch %34 +%34 = OpLabel +OpBranch %28 +%30 = OpLabel +%37 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 +%38 = OpLoad %float %37 +%39 = OpFOrdEqual %bool %38 %float_0 +OpSelectionMerge %40 None +OpBranchConditional %39 %41 %42 +%41 = OpLabel +OpStore %d %float_0_5 +OpBranch %40 +%42 = OpLabel +OpStore %d %float_0_75 +OpBranch %40 +%40 = OpLabel +OpBranch %28 +%28 = OpLabel +OpStore %OutColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %7 +%24 = OpLabel +OpBranch %28 +%28 = OpLabel +OpStore %OutColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + func_before, predefs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateLiveIfThenElse) { + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // float t; + // if (BaseColor.x == 0) + // t = BaseColor.y; + // else + // t = BaseColor.z; + // OutColor = vec4(t); + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %t "t" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input 
+%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%20 = OpLabel +%t = OpVariable %_ptr_Function_float Function +%21 = OpAccessChain %_ptr_Input_float %BaseColor %uint_0 +%22 = OpLoad %float %21 +%23 = OpFOrdEqual %bool %22 %float_0 +OpSelectionMerge %24 None +OpBranchConditional %23 %25 %26 +%25 = OpLabel +%27 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 +%28 = OpLoad %float %27 +OpStore %t %28 +OpBranch %24 +%26 = OpLabel +%29 = OpAccessChain %_ptr_Input_float %BaseColor %uint_2 +%30 = OpLoad %float %29 +OpStore %t %30 +OpBranch %24 +%24 = OpLabel +%31 = OpLoad %float %t +%32 = OpCompositeConstruct %v4float %31 %31 %31 %31 +OpStore %OutColor %32 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateLiveIfThenElseNested) { + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // float t; + // if (BaseColor.x == 0) + // if (BaseColor.y == 0) + // t = 0.0; + // else + // t = 0.25; + // else + // if (BaseColor.y == 0) + // t = 0.5; + // else + // t = 0.75; + // OutColor = vec4(t); + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %t "t" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = 
OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%uint_1 = OpConstant %uint 1 +%_ptr_Function_float = OpTypePointer Function %float +%float_0_25 = OpConstant %float 0.25 +%float_0_5 = OpConstant %float 0.5 +%float_0_75 = OpConstant %float 0.75 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%22 = OpLabel +%t = OpVariable %_ptr_Function_float Function +%23 = OpAccessChain %_ptr_Input_float %BaseColor %uint_0 +%24 = OpLoad %float %23 +%25 = OpFOrdEqual %bool %24 %float_0 +OpSelectionMerge %26 None +OpBranchConditional %25 %27 %28 +%27 = OpLabel +%29 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 +%30 = OpLoad %float %29 +%31 = OpFOrdEqual %bool %30 %float_0 +OpSelectionMerge %32 None +OpBranchConditional %31 %33 %34 +%33 = OpLabel +OpStore %t %float_0 +OpBranch %32 +%34 = OpLabel +OpStore %t %float_0_25 +OpBranch %32 +%32 = OpLabel +OpBranch %26 +%28 = OpLabel +%35 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 +%36 = OpLoad %float %35 +%37 = OpFOrdEqual %bool %36 %float_0 +OpSelectionMerge %38 None +OpBranchConditional %37 %39 %40 +%39 = OpLabel +OpStore %t %float_0_5 +OpBranch %38 +%40 = OpLabel +OpStore %t %float_0_75 +OpBranch %38 +%38 = OpLabel +OpBranch %26 +%26 = OpLabel +%41 = OpLoad %float %t +%42 = OpCompositeConstruct %v4float %41 %41 %41 %41 +OpStore %OutColor %42 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateIfWithPhi) { + // Note: Assembly hand-optimized from GLSL + // + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // 
float t; + // if (BaseColor.x == 0) + // t = 0.0; + // else + // t = 1.0; + // OutColor = vec4(t); + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %6 +%17 = OpLabel +%18 = OpAccessChain %_ptr_Input_float %BaseColor %uint_0 +%19 = OpLoad %float %18 +%20 = OpFOrdEqual %bool %19 %float_0 +OpSelectionMerge %21 None +OpBranchConditional %20 %22 %23 +%22 = OpLabel +OpBranch %21 +%23 = OpLabel +OpBranch %21 +%21 = OpLabel +%24 = OpPhi %float %float_0 %22 %float_1 %23 +%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %OutColor %25 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateIfBreak) { + // Note: Assembly optimized from GLSL + // + // #version 450 + // + // layout(location=0) in vec4 InColor; + // layout(location=0) out vec4 OutColor; + // + // void main() + // { + // float f = 0.0; + // for (;;) { + // f += 2.0; + // if (f > 20.0) + // break; + // } + // + // OutColor = InColor / f; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical 
GLSL450 +OpEntryPoint Fragment %main "main" %OutColor %InColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %f "f" +OpName %OutColor "OutColor" +OpName %InColor "InColor" +OpDecorate %OutColor Location 0 +OpDecorate %InColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%float_2 = OpConstant %float 2 +%float_20 = OpConstant %float 20 +%bool = OpTypeBool +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%InColor = OpVariable %_ptr_Input_v4float Input +%main = OpFunction %void None %7 +%17 = OpLabel +%f = OpVariable %_ptr_Function_float Function +OpStore %f %float_0 +OpBranch %18 +%18 = OpLabel +OpLoopMerge %19 %20 None +OpBranch %21 +%21 = OpLabel +%22 = OpLoad %float %f +%23 = OpFAdd %float %22 %float_2 +OpStore %f %23 +%24 = OpLoad %float %f +%25 = OpFOrdGreaterThan %bool %24 %float_20 +OpSelectionMerge %26 None +OpBranchConditional %25 %27 %26 +%27 = OpLabel +OpBranch %19 +%26 = OpLabel +OpBranch %20 +%20 = OpLabel +OpBranch %18 +%19 = OpLabel +%28 = OpLoad %v4float %InColor +%29 = OpLoad %float %f +%30 = OpCompositeConstruct %v4float %29 %29 %29 %29 +%31 = OpFDiv %v4float %28 %30 +OpStore %OutColor %31 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateIfBreak2) { + // Do not eliminate break as conditional branch with merge instruction + // Note: SPIR-V edited to add merge instruction before break. 
+ // + // #version 430 + // + // layout(std430) buffer U_t + // { + // float g_F[10]; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for (int i=0; i<10; i++) + // s += g_F[i]; + // o = s; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %s "s" +OpName %i "i" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %10 +%25 = OpLabel +%s = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %s %float_0 +OpStore %i %int_0 +OpBranch %26 +%26 = OpLabel +OpLoopMerge %27 %28 None +OpBranch %29 +%29 = OpLabel +%30 = OpLoad %int %i +%31 = OpSLessThan %bool %30 %int_10 +OpSelectionMerge %32 None +OpBranchConditional %31 %32 %27 +%32 = OpLabel +%33 = OpLoad %int %i +%34 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %33 +%35 = OpLoad %float %34 
+%36 = OpLoad %float %s +%37 = OpFAdd %float %36 %35 +OpStore %s %37 +OpBranch %28 +%28 = OpLabel +%38 = OpLoad %int %i +%39 = OpIAdd %int %38 %int_1 +OpStore %i %39 +OpBranch %26 +%27 = OpLabel +%40 = OpLoad %float %s +OpStore %o %40 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, EliminateEntireUselessLoop) { + // #version 140 + // in vec4 BaseColor; + // + // layout(std140) uniform U_t + // { + // int g_I ; + // } ; + // + // void main() + // { + // vec4 v = BaseColor; + // float df = 0.0; + // int i = 0; + // while (i < g_I) { + // df = df * 0.5; + // i = i + 1; + // } + // gl_FragColor = v; + // } + + const std::string predefs1 = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +)"; + + const std::string names_before = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %df "df" +OpName %i "i" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_I" +OpName %_ "" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string names_after = + R"(OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +)"; + + const std::string predefs2_before = + R"(OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%U_t = OpTypeStruct %int +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = 
OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_int = OpTypePointer Uniform %int +%bool = OpTypeBool +%float_0_5 = OpConstant %float 0.5 +%int_1 = OpConstant %int 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string predefs2_after = + R"(%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %11 +%27 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%df = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%28 = OpLoad %v4float %BaseColor +OpStore %v %28 +OpStore %df %float_0 +OpStore %i %int_0 +OpBranch %29 +%29 = OpLabel +OpLoopMerge %30 %31 None +OpBranch %32 +%32 = OpLabel +%33 = OpLoad %int %i +%34 = OpAccessChain %_ptr_Uniform_int %_ %int_0 +%35 = OpLoad %int %34 +%36 = OpSLessThan %bool %33 %35 +OpBranchConditional %36 %37 %30 +%37 = OpLabel +%38 = OpLoad %float %df +%39 = OpFMul %float %38 %float_0_5 +OpStore %df %39 +%40 = OpLoad %int %i +%41 = OpIAdd %int %40 %int_1 +OpStore %i %41 +OpBranch %31 +%31 = OpLabel +OpBranch %29 +%30 = OpLabel +%42 = OpLoad %v4float %v +OpStore %gl_FragColor %42 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %11 +%27 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%28 = OpLoad %v4float %BaseColor +OpStore %v %28 +OpBranch %29 +%29 = OpLabel +OpBranch %30 +%30 = OpLabel +%42 = OpLoad %v4float %v +OpStore %gl_FragColor %42 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs1 + names_before + predefs2_before + func_before, + 
predefs1 + names_after + predefs2_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateBusyLoop) { + // Note: SPIR-V edited to replace AtomicAdd(i,0) with AtomicLoad(i) + // + // #version 450 + // + // layout(std430) buffer I_t + // { + // int g_I; + // int g_I2; + // }; + // + // layout(location = 0) out int o; + // + // void main(void) + // { + // while (atomicAdd(g_I, 0) == 0) {} + // o = g_I2; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %I_t "I_t" +OpMemberName %I_t 0 "g_I" +OpMemberName %I_t 1 "g_I2" +OpName %_ "" +OpName %o "o" +OpMemberDecorate %I_t 0 Offset 0 +OpMemberDecorate %I_t 1 Offset 4 +OpDecorate %I_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%I_t = OpTypeStruct %int %int +%_ptr_Uniform_I_t = OpTypePointer Uniform %I_t +%_ = OpVariable %_ptr_Uniform_I_t Uniform +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%_ptr_Uniform_int = OpTypePointer Uniform %int +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%_ptr_Output_int = OpTypePointer Output %int +%o = OpVariable %_ptr_Output_int Output +%main = OpFunction %void None %7 +%18 = OpLabel +OpBranch %19 +%19 = OpLabel +OpLoopMerge %20 %21 None +OpBranch %22 +%22 = OpLabel +%23 = OpAccessChain %_ptr_Uniform_int %_ %int_0 +%24 = OpAtomicLoad %int %23 %uint_1 %uint_0 +%25 = OpIEqual %bool %24 %int_0 +OpBranchConditional %25 %26 %20 +%26 = OpLabel +OpBranch %21 +%21 = OpLabel +OpBranch %19 +%20 = OpLabel +%27 = OpAccessChain %_ptr_Uniform_int %_ %int_1 +%28 = OpLoad %int %27 +OpStore %o %28 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, 
NoEliminateLiveLoop) { + // Note: SPIR-V optimized + // + // #version 430 + // + // layout(std430) buffer U_t + // { + // float g_F[10]; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for (int i=0; i<10; i++) + // s += g_F[i]; + // o = s; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %8 +%21 = OpLabel +OpBranch %22 +%22 = OpLabel +%23 = OpPhi %float %float_0 %21 %24 %25 +%26 = OpPhi %int %int_0 %21 %27 %25 +OpLoopMerge %28 %25 None +OpBranch %29 +%29 = OpLabel +%30 = OpSLessThan %bool %26 %int_10 +OpBranchConditional %30 %31 %28 +%31 = OpLabel +%32 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %26 +%33 = OpLoad %float %32 +%24 = OpFAdd %float %23 %33 +OpBranch %25 +%25 = OpLabel +%27 = OpIAdd %int %26 %int_1 +OpBranch %22 +%28 = OpLabel +OpStore %o %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, 
true, true); +} + +TEST_F(AggressiveDCETest, EliminateEntireFunctionBody) { + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // float d; + // if (BaseColor.x == 0) + // d = BaseColor.y; + // else + // d = BaseColor.z; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %d "d" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable 
%_ptr_Output_v4float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %7 +%20 = OpLabel +%d = OpVariable %_ptr_Function_float Function +%21 = OpAccessChain %_ptr_Input_float %BaseColor %uint_0 +%22 = OpLoad %float %21 +%23 = OpFOrdEqual %bool %22 %float_0 +OpSelectionMerge %24 None +OpBranchConditional %23 %25 %26 +%25 = OpLabel +%27 = OpAccessChain %_ptr_Input_float %BaseColor %uint_1 +%28 = OpLoad %float %27 +OpStore %d %28 +OpBranch %24 +%26 = OpLabel +%29 = OpAccessChain %_ptr_Input_float %BaseColor %uint_2 +%30 = OpLoad %float %29 +OpStore %d %30 +OpBranch %24 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %7 +%20 = OpLabel +OpBranch %24 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + func_before, predefs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, EliminateUselessInnerLoop) { + // #version 430 + // + // layout(std430) buffer U_t + // { + // float g_F[10]; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for (int i=0; i<10; i++) { + // for (int j=0; j<10; j++) { + // } + // s += g_F[i]; + // } + // o = s; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %s "s" +OpName %i "i" +OpName %j "j" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = 
OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_1 = OpConstant %int 1 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %s "s" +OpName %i "i" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_1 = OpConstant %int 1 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %11 +%26 = OpLabel +%s = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%j = OpVariable %_ptr_Function_int Function 
+OpStore %s %float_0 +OpStore %i %int_0 +OpBranch %27 +%27 = OpLabel +OpLoopMerge %28 %29 None +OpBranch %30 +%30 = OpLabel +%31 = OpLoad %int %i +%32 = OpSLessThan %bool %31 %int_10 +OpBranchConditional %32 %33 %28 +%33 = OpLabel +OpStore %j %int_0 +OpBranch %34 +%34 = OpLabel +OpLoopMerge %35 %36 None +OpBranch %37 +%37 = OpLabel +%38 = OpLoad %int %j +%39 = OpSLessThan %bool %38 %int_10 +OpBranchConditional %39 %40 %35 +%40 = OpLabel +OpBranch %36 +%36 = OpLabel +%41 = OpLoad %int %j +%42 = OpIAdd %int %41 %int_1 +OpStore %j %42 +OpBranch %34 +%35 = OpLabel +%43 = OpLoad %int %i +%44 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %43 +%45 = OpLoad %float %44 +%46 = OpLoad %float %s +%47 = OpFAdd %float %46 %45 +OpStore %s %47 +OpBranch %29 +%29 = OpLabel +%48 = OpLoad %int %i +%49 = OpIAdd %int %48 %int_1 +OpStore %i %49 +OpBranch %27 +%28 = OpLabel +%50 = OpLoad %float %s +OpStore %o %50 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %11 +%26 = OpLabel +%s = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %s %float_0 +OpStore %i %int_0 +OpBranch %27 +%27 = OpLabel +OpLoopMerge %28 %29 None +OpBranch %30 +%30 = OpLabel +%31 = OpLoad %int %i +%32 = OpSLessThan %bool %31 %int_10 +OpBranchConditional %32 %33 %28 +%33 = OpLabel +OpBranch %34 +%34 = OpLabel +OpBranch %35 +%35 = OpLabel +%43 = OpLoad %int %i +%44 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %43 +%45 = OpLoad %float %44 +%46 = OpLoad %float %s +%47 = OpFAdd %float %46 %45 +OpStore %s %47 +OpBranch %29 +%29 = OpLabel +%48 = OpLoad %int %i +%49 = OpIAdd %int %48 %int_1 +OpStore %i %49 +OpBranch %27 +%28 = OpLabel +%50 = OpLoad %float %s +OpStore %o %50 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + func_before, predefs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, EliminateUselessNestedLoopWithIf) { + // #version 430 + // + // layout(std430) buffer U_t + // { 
+ // float g_F[10][10]; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for (int i=0; i<10; i++) { + // for (int j=0; j<10; j++) { + // float t = g_F[i][j]; + // if (t > 0.0) + // s += t; + // } + // } + // o = 0.0; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %s "s" +OpName %i "i" +OpName %j "j" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpDecorate %_arr__arr_float_uint_10_uint_10 ArrayStride 40 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10 +%U_t = OpTypeStruct %_arr__arr_float_uint_10_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %o "o" +OpDecorate %o Location 0 +%void = 
OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %12 +%27 = OpLabel +%s = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%j = OpVariable %_ptr_Function_int Function +OpStore %s %float_0 +OpStore %i %int_0 +OpBranch %28 +%28 = OpLabel +OpLoopMerge %29 %30 None +OpBranch %31 +%31 = OpLabel +%32 = OpLoad %int %i +%33 = OpSLessThan %bool %32 %int_10 +OpBranchConditional %33 %34 %29 +%34 = OpLabel +OpStore %j %int_0 +OpBranch %35 +%35 = OpLabel +OpLoopMerge %36 %37 None +OpBranch %38 +%38 = OpLabel +%39 = OpLoad %int %j +%40 = OpSLessThan %bool %39 %int_10 +OpBranchConditional %40 %41 %36 +%41 = OpLabel +%42 = OpLoad %int %i +%43 = OpLoad %int %j +%44 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %42 %43 +%45 = OpLoad %float %44 +%46 = OpFOrdGreaterThan %bool %45 %float_0 +OpSelectionMerge %47 None +OpBranchConditional %46 %48 %47 +%48 = OpLabel +%49 = OpLoad %float %s +%50 = OpFAdd %float %49 %45 +OpStore %s %50 +OpBranch %47 +%47 = OpLabel +OpBranch %37 +%37 = OpLabel +%51 = OpLoad %int %j +%52 = OpIAdd %int %51 %int_1 +OpStore %j %52 +OpBranch %35 +%36 = OpLabel +OpBranch %30 +%30 = OpLabel +%53 = OpLoad %int %i +%54 = OpIAdd %int %53 %int_1 +OpStore %i %54 +OpBranch %28 +%29 = OpLabel +OpStore %o %float_0 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %12 +%27 = OpLabel +OpBranch %28 +%28 = OpLabel +OpBranch %29 +%29 = OpLabel +OpStore %o %float_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + func_before, predefs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, EliminateEmptyIfBeforeContinue) { + // #version 430 + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for 
(int i=0; i<10; i++) { + // s += 1.0; + // if (i > s) {} + // } + // o = s; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %3 +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" +OpSourceExtension "GL_GOOGLE_include_directive" +OpName %main "main" +OpDecorate %3 Location 0 +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%3 = OpVariable %_ptr_Output_float Output +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %3 +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" +OpSourceExtension "GL_GOOGLE_include_directive" +OpName %main "main" +OpDecorate %3 Location 0 +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%3 = OpVariable %_ptr_Output_float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %5 +%16 = OpLabel +OpBranch %17 +%17 = OpLabel +%18 = OpPhi %float %float_0 %16 %19 %20 +%21 = OpPhi %int %int_0 %16 %22 %20 +OpLoopMerge %23 %20 None +OpBranch %24 +%24 = OpLabel +%25 = OpSLessThan %bool %21 %int_10 +OpBranchConditional %25 %26 %23 +%26 = OpLabel +%19 = OpFAdd %float %18 %float_1 +%27 = 
OpConvertFToS %int %19 +%28 = OpSGreaterThan %bool %21 %27 +OpSelectionMerge %20 None +OpBranchConditional %28 %29 %20 +%29 = OpLabel +OpBranch %20 +%20 = OpLabel +%22 = OpIAdd %int %21 %int_1 +OpBranch %17 +%23 = OpLabel +OpStore %3 %18 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %5 +%16 = OpLabel +OpBranch %17 +%17 = OpLabel +%18 = OpPhi %float %float_0 %16 %19 %20 +%21 = OpPhi %int %int_0 %16 %22 %20 +OpLoopMerge %23 %20 None +OpBranch %24 +%24 = OpLabel +%25 = OpSLessThan %bool %21 %int_10 +OpBranchConditional %25 %26 %23 +%26 = OpLabel +%19 = OpFAdd %float %18 %float_1 +OpBranch %20 +%20 = OpLabel +%22 = OpIAdd %int %21 %int_1 +OpBranch %17 +%23 = OpLabel +OpStore %3 %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + func_before, predefs_after + func_after, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateLiveNestedLoopWithIf) { + // Note: SPIR-V optimized + // + // #version 430 + // + // layout(std430) buffer U_t + // { + // float g_F[10][10]; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for (int i=0; i<10; i++) { + // for (int j=0; j<10; j++) { + // float t = g_F[i][j]; + // if (t > 0.0) + // s += t; + // } + // } + // o = s; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %s "s" +OpName %i "i" +OpName %j "j" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpDecorate %_arr__arr_float_uint_10_uint_10 ArrayStride 40 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float 
= OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10 +%U_t = OpTypeStruct %_arr__arr_float_uint_10_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %12 +%27 = OpLabel +%s = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%j = OpVariable %_ptr_Function_int Function +OpStore %s %float_0 +OpStore %i %int_0 +OpBranch %28 +%28 = OpLabel +OpLoopMerge %29 %30 None +OpBranch %31 +%31 = OpLabel +%32 = OpLoad %int %i +%33 = OpSLessThan %bool %32 %int_10 +OpBranchConditional %33 %34 %29 +%34 = OpLabel +OpStore %j %int_0 +OpBranch %35 +%35 = OpLabel +OpLoopMerge %36 %37 None +OpBranch %38 +%38 = OpLabel +%39 = OpLoad %int %j +%40 = OpSLessThan %bool %39 %int_10 +OpBranchConditional %40 %41 %36 +%41 = OpLabel +%42 = OpLoad %int %i +%43 = OpLoad %int %j +%44 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %42 %43 +%45 = OpLoad %float %44 +%46 = OpFOrdGreaterThan %bool %45 %float_0 +OpSelectionMerge %47 None +OpBranchConditional %46 %48 %47 +%48 = OpLabel +%49 = OpLoad %float %s +%50 = OpFAdd %float %49 %45 +OpStore %s %50 +OpBranch %47 +%47 = OpLabel +OpBranch %37 +%37 = OpLabel +%51 = OpLoad %int %j +%52 = OpIAdd %int %51 %int_1 +OpStore %j %52 +OpBranch %35 +%36 = OpLabel +OpBranch %30 +%30 = OpLabel +%53 = OpLoad %int %i +%54 = OpIAdd %int %53 %int_1 +OpStore %i %54 +OpBranch %28 +%29 = OpLabel +%55 = OpLoad %float %s +OpStore %o %55 +OpReturn 
+OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateIfContinue) { + // Do not eliminate continue embedded in if construct + // + // #version 430 + // + // layout(std430) buffer U_t + // { + // float g_F[10]; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for (int i=0; i<10; i++) { + // if (i % 2 == 0) continue; + // s += g_F[i]; + // } + // o = s; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %s "s" +OpName %i "i" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_2 = OpConstant %int 2 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %10 +%26 = OpLabel +%s = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %s %float_0 +OpStore %i %int_0 +OpBranch %27 +%27 = OpLabel +OpLoopMerge %28 %29 None 
+OpBranch %30 +%30 = OpLabel +%31 = OpLoad %int %i +%32 = OpSLessThan %bool %31 %int_10 +OpBranchConditional %32 %33 %28 +%33 = OpLabel +%34 = OpLoad %int %i +%35 = OpSMod %int %34 %int_2 +%36 = OpIEqual %bool %35 %int_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %37 +%38 = OpLabel +OpBranch %29 +%37 = OpLabel +%39 = OpLoad %int %i +%40 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %39 +%41 = OpLoad %float %40 +%42 = OpLoad %float %s +%43 = OpFAdd %float %42 %41 +OpStore %s %43 +OpBranch %29 +%29 = OpLabel +%44 = OpLoad %int %i +%45 = OpIAdd %int %44 %int_1 +OpStore %i %45 +OpBranch %27 +%28 = OpLabel +%46 = OpLoad %float %s +OpStore %o %46 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateIfContinue2) { + // Do not eliminate continue not embedded in if construct + // + // #version 430 + // + // layout(std430) buffer U_t + // { + // float g_F[10]; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for (int i=0; i<10; i++) { + // if (i % 2 == 0) continue; + // s += g_F[i]; + // } + // o = s; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %s "s" +OpName %i "i" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_2 
= OpConstant %int 2 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %10 +%26 = OpLabel +%s = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %s %float_0 +OpStore %i %int_0 +OpBranch %27 +%27 = OpLabel +OpLoopMerge %28 %29 None +OpBranch %30 +%30 = OpLabel +%31 = OpLoad %int %i +%32 = OpSLessThan %bool %31 %int_10 +OpBranchConditional %32 %33 %28 +%33 = OpLabel +%34 = OpLoad %int %i +%35 = OpSMod %int %34 %int_2 +%36 = OpIEqual %bool %35 %int_0 +OpBranchConditional %36 %29 %37 +%37 = OpLabel +%38 = OpLoad %int %i +%39 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %38 +%40 = OpLoad %float %39 +%41 = OpLoad %float %s +%42 = OpFAdd %float %41 %40 +OpStore %s %42 +OpBranch %29 +%29 = OpLabel +%43 = OpLoad %int %i +%44 = OpIAdd %int %43 %int_1 +OpStore %i %44 +OpBranch %27 +%28 = OpLabel +%45 = OpLoad %float %s +OpStore %o %45 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +TEST_F(AggressiveDCETest, NoEliminateIfContinue3) { + // Do not eliminate continue as conditional branch with merge instruction + // Note: SPIR-V edited to add merge instruction before continue. 
+ // + // #version 430 + // + // layout(std430) buffer U_t + // { + // float g_F[10]; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // float s = 0.0; + // for (int i=0; i<10; i++) { + // if (i % 2 == 0) continue; + // s += g_F[i]; + // } + // o = s; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %s "s" +OpName %i "i" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_2 = OpConstant %int 2 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %10 +%26 = OpLabel +%s = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %s %float_0 +OpStore %i %int_0 +OpBranch %27 +%27 = OpLabel +OpLoopMerge %28 %29 None +OpBranch %30 +%30 = OpLabel +%31 = OpLoad %int %i +%32 = OpSLessThan %bool %31 %int_10 +OpBranchConditional %32 %33 %28 +%33 = OpLabel +%34 = OpLoad %int %i +%35 = OpSMod %int %34 %int_2 +%36 = 
OpIEqual %bool %35 %int_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %29 %37 +%37 = OpLabel +%38 = OpLoad %int %i +%39 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %38 +%40 = OpLoad %float %39 +%41 = OpLoad %float %s +%42 = OpFAdd %float %41 %40 +OpStore %s %42 +OpBranch %29 +%29 = OpLabel +%43 = OpLoad %int %i +%44 = OpIAdd %int %43 %int_1 +OpStore %i %44 +OpBranch %27 +%28 = OpLabel +%45 = OpLoad %float %s +OpStore %o %45 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +// This is not valid input and ADCE does not support variable pointers and only +// supports shaders. +TEST_F(AggressiveDCETest, PointerVariable) { + // ADCE is able to handle code that contains a load whose base address + // comes from a load and not an OpVariable. I want to see an instruction + // removed to be sure that ADCE is not exiting early. + + const std::string before = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpMemberDecorate %_struct_3 0 Offset 0 +OpDecorate %_runtimearr__struct_3 ArrayStride 16 +OpMemberDecorate %_struct_5 0 Offset 0 +OpDecorate %_struct_5 BufferBlock +OpMemberDecorate %_struct_6 0 Offset 0 +OpDecorate %_struct_6 BufferBlock +OpDecorate %2 Location 0 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +OpDecorate %8 DescriptorSet 0 +OpDecorate %8 Binding 1 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_struct_3 = OpTypeStruct %v4float +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 +%_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 +%_struct_6 = OpTypeStruct %int +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 
+%_ptr_Function__ptr_Uniform__struct_5 = OpTypePointer Function %_ptr_Uniform__struct_5 +%_ptr_Function__ptr_Uniform__struct_6 = OpTypePointer Function %_ptr_Uniform__struct_6 +%int_0 = OpConstant %int 0 +%uint_0 = OpConstant %uint 0 +%2 = OpVariable %_ptr_Output_v4float Output +%7 = OpVariable %_ptr_Uniform__struct_5 Uniform +%8 = OpVariable %_ptr_Uniform__struct_6 Uniform +%1 = OpFunction %void None %10 +%23 = OpLabel +%24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function +OpStore %24 %7 +%26 = OpLoad %_ptr_Uniform__struct_5 %24 +%27 = OpAccessChain %_ptr_Uniform_v4float %26 %int_0 %uint_0 %int_0 +%28 = OpLoad %v4float %27 +%29 = OpCopyObject %v4float %28 +OpStore %2 %28 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpMemberDecorate %_struct_3 0 Offset 0 +OpDecorate %_runtimearr__struct_3 ArrayStride 16 +OpMemberDecorate %_struct_5 0 Offset 0 +OpDecorate %_struct_5 BufferBlock +OpDecorate %2 Location 0 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_struct_3 = OpTypeStruct %v4float +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 +%_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 +%_ptr_Function__ptr_Uniform__struct_5 = OpTypePointer Function %_ptr_Uniform__struct_5 +%int_0 = OpConstant %int 0 +%uint_0 = OpConstant %uint 0 +%2 = OpVariable %_ptr_Output_v4float Output +%7 = OpVariable %_ptr_Uniform__struct_5 Uniform +%1 = OpFunction %void None %10 +%23 = OpLabel +%24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function +OpStore %24 %7 +%25 = OpLoad %_ptr_Uniform__struct_5 
%24 +%26 = OpAccessChain %_ptr_Uniform_v4float %25 %int_0 %uint_0 %int_0 +%27 = OpLoad %v4float %26 +OpStore %2 %27 +OpReturn +OpFunctionEnd +)"; + + // The input is not valid and ADCE only supports shaders, but not variable + // pointers. Workaround this by enabling relaxed logical pointers in the + // validator. + ValidatorOptions()->relax_logical_pointer = true; + SinglePassRunAndCheck(before, after, true, true); +} + +// %dead is unused. Make sure we remove it along with its name. +TEST_F(AggressiveDCETest, RemoveUnreferenced) { + const std::string before = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %dead "dead" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%dead = OpVariable %_ptr_Private_float Private +%main = OpFunction %void None %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%main = OpFunction %void None %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +// Delete %dead because it is unreferenced. Then %initializer becomes +// unreferenced, so remove it as well. 
+TEST_F(AggressiveDCETest, RemoveUnreferencedWithInit1) { + // Neither %dead nor its initializer %initializer is referenced from main, so + // the expected output drops both variables, their OpName entries, and the + // float/pointer types only they used. + const std::string before = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %dead "dead" +OpName %initializer "initializer" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%initializer = OpVariable %_ptr_Private_float Private +%dead = OpVariable %_ptr_Private_float Private %initializer +%main = OpFunction %void None %6 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%main = OpFunction %void None %6 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck<AggressiveDCEPass>(before, after, /* skip_nop = */ true, + /* do_validation = */ true); +} + +// Keep %live because it is used, and keep its initializer as well.
+TEST_F(AggressiveDCETest, KeepReferenced) { + // %live is loaded in main and is initialized from %initializer, so the module + // is expected to come back unchanged: the input doubles as the expected text. + const std::string before = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %output +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %live "live" +OpName %initializer "initializer" +OpName %output "output" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%initializer = OpVariable %_ptr_Private_float Private +%live = OpVariable %_ptr_Private_float Private %initializer +%_ptr_Output_float = OpTypePointer Output %float +%output = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %6 +%9 = OpLabel +%10 = OpLoad %float %live +OpStore %output %10 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck<AggressiveDCEPass>(before, before, /* skip_nop = */ true, + /* do_validation = */ true); +} + +// This tests that the decorations associated with a variable are removed when +// the variable is removed.
+TEST_F(AggressiveDCETest, RemoveVariableAndDecorations) { + const std::string before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 450 +OpName %main "main" +OpName %B "B" +OpMemberName %B 0 "a" +OpName %Bdat "Bdat" +OpMemberDecorate %B 0 Offset 0 +OpDecorate %B BufferBlock +OpDecorate %Bdat DescriptorSet 0 +OpDecorate %Bdat Binding 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%B = OpTypeStruct %uint +%_ptr_Uniform_B = OpTypePointer Uniform %B +%Bdat = OpVariable %_ptr_Uniform_B Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%main = OpFunction %void None %6 +%13 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 450 +OpName %main "main" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%main = OpFunction %void None %6 +%13 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +TEST_F(AggressiveDCETest, DeadNestedSwitch) { + const std::string text = R"( +; CHECK: OpLabel +; CHECK: OpBranch [[block:%\w+]] +; CHECK-NOT: OpSwitch +; CHECK-NEXT: [[block]] = OpLabel +; CHECK: OpBranch [[block:%\w+]] +; CHECK-NOT: OpSwitch +; CHECK-NEXT: [[block]] = OpLabel +; CHECK-NEXT: OpStore +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" %x +OpExecutionMode %func OriginUpperLeft +OpName %func "func" +%void = OpTypeVoid +%1 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_ptr_Output = OpTypePointer Output %uint +%uint_ptr_Input = OpTypePointer Input %uint +%x = OpVariable %uint_ptr_Output Output +%a = OpVariable 
%uint_ptr_Input Input +%func = OpFunction %void None %1 +%entry = OpLabel +OpBranch %header +%header = OpLabel +%ld = OpLoad %uint %a +OpLoopMerge %merge %continue None +OpBranch %postheader +%postheader = OpLabel +; This switch doesn't require an OpSelectionMerge and is nested in the dead loop. +OpSwitch %ld %merge 0 %extra 1 %continue +%extra = OpLabel +OpBranch %continue +%continue = OpLabel +OpBranch %header +%merge = OpLabel +OpStore %x %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(AggressiveDCETest, LiveNestedSwitch) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" %3 %10 +OpExecutionMode %func OriginUpperLeft +OpName %func "func" +%void = OpTypeVoid +%1 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%_ptr_Input_uint = OpTypePointer Input %uint +%3 = OpVariable %_ptr_Output_uint Output +%10 = OpVariable %_ptr_Input_uint Input +%func = OpFunction %void None %1 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +%13 = OpLoad %uint %10 +OpLoopMerge %14 %15 None +OpBranch %16 +%16 = OpLabel +OpSwitch %13 %14 0 %17 1 %15 +%17 = OpLabel +OpStore %3 %uint_1 +OpBranch %15 +%15 = OpLabel +OpBranch %12 +%14 = OpLabel +OpStore %3 %uint_0 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(text, text, false, true); +} + +TEST_F(AggressiveDCETest, BasicDeleteDeadFunction) { + // The function Dead should be removed because it is never called. 
+ const std::vector common_code = { + // clang-format off + "OpCapability Shader", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\"", + "OpName %main \"main\"", + "OpName %Live \"Live\"", + "%void = OpTypeVoid", + "%7 = OpTypeFunction %void", + "%main = OpFunction %void None %7", + "%15 = OpLabel", + "%16 = OpFunctionCall %void %Live", + "%17 = OpFunctionCall %void %Live", + "OpReturn", + "OpFunctionEnd", + "%Live = OpFunction %void None %7", + "%20 = OpLabel", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + const std::vector dead_function = { + // clang-format off + "%Dead = OpFunction %void None %7", + "%19 = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + JoinAllInsts(Concat(common_code, dead_function)), + JoinAllInsts(common_code), /* skip_nop = */ true); +} + +TEST_F(AggressiveDCETest, BasicKeepLiveFunction) { + // Everything is reachable from an entry point, so no functions should be + // deleted. 
+ const std::vector text = { + // clang-format off + "OpCapability Shader", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\"", + "OpName %main \"main\"", + "OpName %Live1 \"Live1\"", + "OpName %Live2 \"Live2\"", + "%void = OpTypeVoid", + "%7 = OpTypeFunction %void", + "%main = OpFunction %void None %7", + "%15 = OpLabel", + "%16 = OpFunctionCall %void %Live2", + "%17 = OpFunctionCall %void %Live1", + "OpReturn", + "OpFunctionEnd", + "%Live1 = OpFunction %void None %7", + "%19 = OpLabel", + "OpReturn", + "OpFunctionEnd", + "%Live2 = OpFunction %void None %7", + "%20 = OpLabel", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + std::string assembly = JoinAllInsts(text); + auto result = SinglePassRunAndDisassemble( + assembly, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(assembly, std::get<0>(result)); +} + +TEST_F(AggressiveDCETest, BasicRemoveDecorationsAndNames) { + // We want to remove the names and decorations associated with results that + // are removed. This test will check for that. 
+ const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpName %main "main" + OpName %Dead "Dead" + OpName %x "x" + OpName %y "y" + OpName %z "z" + OpDecorate %x RelaxedPrecision + OpDecorate %y RelaxedPrecision + OpDecorate %z RelaxedPrecision + OpDecorate %6 RelaxedPrecision + OpDecorate %7 RelaxedPrecision + OpDecorate %8 RelaxedPrecision + %void = OpTypeVoid + %10 = OpTypeFunction %void + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float + %float_1 = OpConstant %float 1 + %main = OpFunction %void None %10 + %14 = OpLabel + OpReturn + OpFunctionEnd + %Dead = OpFunction %void None %10 + %15 = OpLabel + %x = OpVariable %_ptr_Function_float Function + %y = OpVariable %_ptr_Function_float Function + %z = OpVariable %_ptr_Function_float Function + OpStore %x %float_1 + OpStore %y %float_1 + %6 = OpLoad %float %x + %7 = OpLoad %float %y + %8 = OpFAdd %float %6 %7 + OpStore %z %8 + OpReturn + OpFunctionEnd)"; + + const std::string expected_output = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpName %main "main" +%void = OpTypeVoid +%10 = OpTypeFunction %void +%main = OpFunction %void None %10 +%14 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(text, expected_output, + /* skip_nop = */ true); +} + +TEST_F(AggressiveDCETest, BasicAllDeadConstants) { + const std::string text = R"( + ; CHECK-NOT: OpConstant + OpCapability Shader + OpCapability Float64 + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpName %main "main" + %void = OpTypeVoid + %4 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %false = OpConstantFalse %bool + %int = OpTypeInt 32 1 + %9 = OpConstant %int 1 + %uint = OpTypeInt 32 0 + %11 = OpConstant %uint 2 + %float = OpTypeFloat 32 + %13 = 
OpConstant %float 3.1415 + %double = OpTypeFloat 64 + %15 = OpConstant %double 3.14159265358979 + %main = OpFunction %void None %4 + %16 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(AggressiveDCETest, BasicNoneDeadConstants) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\" %btv %bfv %iv %uv %fv %dv", + "OpName %main \"main\"", + "OpName %btv \"btv\"", + "OpName %bfv \"bfv\"", + "OpName %iv \"iv\"", + "OpName %uv \"uv\"", + "OpName %fv \"fv\"", + "OpName %dv \"dv\"", + "%void = OpTypeVoid", + "%10 = OpTypeFunction %void", + "%bool = OpTypeBool", + "%_ptr_Output_bool = OpTypePointer Output %bool", + "%true = OpConstantTrue %bool", + "%false = OpConstantFalse %bool", + "%int = OpTypeInt 32 1", + "%_ptr_Output_int = OpTypePointer Output %int", + "%int_1 = OpConstant %int 1", + "%uint = OpTypeInt 32 0", + "%_ptr_Output_uint = OpTypePointer Output %uint", + "%uint_2 = OpConstant %uint 2", + "%float = OpTypeFloat 32", + "%_ptr_Output_float = OpTypePointer Output %float", + "%float_3_1415 = OpConstant %float 3.1415", + "%double = OpTypeFloat 64", + "%_ptr_Output_double = OpTypePointer Output %double", + "%double_3_14159265358979 = OpConstant %double 3.14159265358979", + "%btv = OpVariable %_ptr_Output_bool Output", + "%bfv = OpVariable %_ptr_Output_bool Output", + "%iv = OpVariable %_ptr_Output_int Output", + "%uv = OpVariable %_ptr_Output_uint Output", + "%fv = OpVariable %_ptr_Output_float Output", + "%dv = OpVariable %_ptr_Output_double Output", + "%main = OpFunction %void None %10", + "%27 = OpLabel", + "OpStore %btv %true", + "OpStore %bfv %false", + "OpStore %iv %int_1", + "OpStore %uv %uint_2", + "OpStore %fv %float_3_1415", + "OpStore %dv %double_3_14159265358979", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + // All constants are 
used, so none of them should be eliminated. + SinglePassRunAndCheck<AggressiveDCEPass>( + JoinAllInsts(text), JoinAllInsts(text), /* skip_nop = */ true); +} + +// One parameterized case for AggressiveEliminateDeadConstantTest. +struct AggressiveEliminateDeadConstantTestCase { + // Type declarations and constants that should be kept. + std::vector<std::string> used_consts; + // Instructions that refer to constants. They are added to create uses for + // some constants so that those are not treated as dead constants. + std::vector<std::string> main_insts; + // Dead constants that should be removed. + std::vector<std::string> dead_consts; + // Expectations (CHECK lines) for the pass output. + std::vector<std::string> checks; +}; + +// All types that are potentially required in +// AggressiveEliminateDeadConstantTest. +const std::vector<std::string> CommonTypes = { + // clang-format off + // scalar types + "%bool = OpTypeBool", + "%uint = OpTypeInt 32 0", + "%int = OpTypeInt 32 1", + "%float = OpTypeFloat 32", + "%double = OpTypeFloat 64", + // vector types + "%v2bool = OpTypeVector %bool 2", + "%v2uint = OpTypeVector %uint 2", + "%v2int = OpTypeVector %int 2", + "%v3int = OpTypeVector %int 3", + "%v4int = OpTypeVector %int 4", + "%v2float = OpTypeVector %float 2", + "%v3float = OpTypeVector %float 3", + "%v2double = OpTypeVector %double 2", + // pointer types for the Output variables declared by the test cases + "%_pf_bool = OpTypePointer Output %bool", + "%_pf_uint = OpTypePointer Output %uint", + "%_pf_int = OpTypePointer Output %int", + "%_pf_float = OpTypePointer Output %float", + "%_pf_double = OpTypePointer Output %double", + "%_pf_v2int = OpTypePointer Output %v2int", + "%_pf_v3int = OpTypePointer Output %v3int", + "%_pf_v2float = OpTypePointer Output %v2float", + "%_pf_v3float = OpTypePointer Output %v3float", + "%_pf_v2double = OpTypePointer Output %v2double", + // struct types + "%inner_struct = OpTypeStruct %bool %int %float %double", + "%outer_struct = OpTypeStruct %inner_struct %int %double", + "%flat_struct = OpTypeStruct %bool %int %float %double", + // clang-format on +}; + +// Fixture parameterized on AggressiveEliminateDeadConstantTestCase. +using AggressiveEliminateDeadConstantTest = + PassTest<::testing::TestWithParam<AggressiveEliminateDeadConstantTestCase>>; +
+TEST_P(AggressiveEliminateDeadConstantTest, Custom) { + // Assembles a module from CommonTypes plus the case's used constants and main + // instructions, then appends the dead constants and prepends the CHECK lines + // before handing the text to SinglePassRunAndMatch. + auto& tc = GetParam(); + AssemblyBuilder builder; + builder.AppendTypesConstantsGlobals(CommonTypes) + .AppendTypesConstantsGlobals(tc.used_consts) + .AppendInMain(tc.main_insts); + // NOTE(review): `expected` is never read in this test; confirm GetCode() has + // no needed side effect and drop it. + const std::string expected = builder.GetCode(); + builder.AppendTypesConstantsGlobals(tc.dead_consts); + builder.PrependPreamble(tc.checks); + const std::string assembly_with_dead_const = builder.GetCode(); + + // Do not enable validation, because the input code inherited from the base + // tests (ported from other passes) is invalid. + SinglePassRunAndMatch<AggressiveDCEPass>(assembly_with_dead_const, false); +} + +INSTANTIATE_TEST_CASE_P( + ScalarTypeConstants, AggressiveEliminateDeadConstantTest, + ::testing::ValuesIn(std::vector<AggressiveEliminateDeadConstantTestCase>({ + // clang-format off + // Scalar type constants, one dead constant and one used constant. + { + /* .used_consts = */ + { + "%used_const_int = OpConstant %int 1", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Output", + "OpStore %int_var %used_const_int", + }, + /* .dead_consts = */ + { + "%dead_const_int = OpConstant %int 1", + }, + /* .checks = */ + { + "; CHECK: [[const:%\\w+]] = OpConstant %int 1", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[const]]", + }, + }, + { + /* .used_consts = */ + { + "%used_const_uint = OpConstant %uint 1", + }, + /* .main_insts = */ + { + "%uint_var = OpVariable %_pf_uint Output", + "OpStore %uint_var %used_const_uint", + }, + /* .dead_consts = */ + { + "%dead_const_uint = OpConstant %uint 1", + }, + /* .checks = */ + { + "; CHECK: [[const:%\\w+]] = OpConstant %uint 1", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[const]]", + }, + }, + { + /* .used_consts = */ + { + "%used_const_float = OpConstant %float 3.1415", + }, + /* .main_insts = */ + { + "%float_var = OpVariable %_pf_float Output", + "OpStore %float_var %used_const_float", + }, + /* .dead_consts = */ + { + "%dead_const_float = OpConstant %float 3.1415", + }, + /* .checks = */ + { + "; CHECK:
[[const:%\\w+]] = OpConstant %float 3.1415", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[const]]", + }, + }, + { + /* .used_consts = */ + { + "%used_const_double = OpConstant %double 3.14", + }, + /* .main_insts = */ + { + "%double_var = OpVariable %_pf_double Output", + "OpStore %double_var %used_const_double", + }, + /* .dead_consts = */ + { + "%dead_const_double = OpConstant %double 3.14", + }, + /* .checks = */ + { + "; CHECK: [[const:%\\w+]] = OpConstant %double 3.14", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[const]]", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + VectorTypeConstants, AggressiveEliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Tests eliminating dead constant type ivec2. One dead constant vector + // and one used constant vector, each built from its own group of + // scalar constants. + { + /* .used_consts = */ + { + "%used_int_x = OpConstant %int 1", + "%used_int_y = OpConstant %int 2", + "%used_v2int = OpConstantComposite %v2int %used_int_x %used_int_y", + }, + /* .main_insts = */ + { + "%v2int_var = OpVariable %_pf_v2int Output", + "OpStore %v2int_var %used_v2int", + }, + /* .dead_consts = */ + { + "%dead_int_x = OpConstant %int 1", + "%dead_int_y = OpConstant %int 2", + "%dead_v2int = OpConstantComposite %v2int %dead_int_x %dead_int_y", + }, + /* .checks = */ + { + "; CHECK: [[constx:%\\w+]] = OpConstant %int 1", + "; CHECK: [[consty:%\\w+]] = OpConstant %int 2", + "; CHECK: [[const:%\\w+]] = OpConstantComposite %v2int [[constx]] [[consty]]", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[const]]", + }, + }, + // Tests eliminating dead constant ivec3. One dead constant vector and + // one used constant vector. But both built from a same group of + // scalar constants. 
+ { + /* .used_consts = */ + { + "%used_int_x = OpConstant %int 1", + "%used_int_y = OpConstant %int 2", + "%used_int_z = OpConstant %int 3", + "%used_v3int = OpConstantComposite %v3int %used_int_x %used_int_y %used_int_z", + }, + /* .main_insts = */ + { + "%v3int_var = OpVariable %_pf_v3int Output", + "OpStore %v3int_var %used_v3int", + }, + /* .dead_consts = */ + { + "%dead_v3int = OpConstantComposite %v3int %used_int_x %used_int_y %used_int_z", + }, + /* .checks = */ + { + "; CHECK: [[constx:%\\w+]] = OpConstant %int 1", + "; CHECK: [[consty:%\\w+]] = OpConstant %int 2", + "; CHECK: [[constz:%\\w+]] = OpConstant %int 3", + "; CHECK: [[const:%\\w+]] = OpConstantComposite %v3int [[constx]] [[consty]] [[constz]]", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[const]]", + }, + }, + // Tests eliminating dead constant vec2. One dead constant vector and + // one used constant vector. Each built from its own group of scalar + // constants. + { + /* .used_consts = */ + { + "%used_float_x = OpConstant %float 3.1415", + "%used_float_y = OpConstant %float 4.13", + "%used_v2float = OpConstantComposite %v2float %used_float_x %used_float_y", + }, + /* .main_insts = */ + { + "%v2float_var = OpVariable %_pf_v2float Output", + "OpStore %v2float_var %used_v2float", + }, + /* .dead_consts = */ + { + "%dead_float_x = OpConstant %float 3.1415", + "%dead_float_y = OpConstant %float 4.13", + "%dead_v2float = OpConstantComposite %v2float %dead_float_x %dead_float_y", + }, + /* .checks = */ + { + "; CHECK: [[constx:%\\w+]] = OpConstant %float 3.1415", + "; CHECK: [[consty:%\\w+]] = OpConstant %float 4.13", + "; CHECK: [[const:%\\w+]] = OpConstantComposite %v2float [[constx]] [[consty]]", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[const]]", + }, + }, + // Tests eliminating dead constant vec3. One dead constant vector and + // one used constant vector. Both built from a same group of scalar + // constants. 
+ { + /* .used_consts = */ + { + "%used_float_x = OpConstant %float 3.1415", + "%used_float_y = OpConstant %float 4.25", + "%used_float_z = OpConstant %float 4.75", + "%used_v3float = OpConstantComposite %v3float %used_float_x %used_float_y %used_float_z", + }, + /* .main_insts = */ + { + "%v3float_var = OpVariable %_pf_v3float Output", + "OpStore %v3float_var %used_v3float", + }, + /* .dead_consts = */ + { + "%dead_v3float = OpConstantComposite %v3float %used_float_x %used_float_y %used_float_z", + }, + /* .checks = */ + { + "; CHECK: [[constx:%\\w+]] = OpConstant %float 3.1415", + "; CHECK: [[consty:%\\w+]] = OpConstant %float 4.25", + "; CHECK: [[constz:%\\w+]] = OpConstant %float 4.75", + "; CHECK: [[const:%\\w+]] = OpConstantComposite %v3float [[constx]] [[consty]] [[constz]]", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[const]]", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + StructTypeConstants, AggressiveEliminateDeadConstantTest, + ::testing::ValuesIn(std::vector<AggressiveEliminateDeadConstantTestCase>({ + // clang-format off + // A dead constant of a plain struct type. All of its components are dead + // constants too. + { + /* .used_consts = */ {}, + /* .main_insts = */ {}, + /* .dead_consts = */ + { + "%dead_bool = OpConstantTrue %bool", + "%dead_int = OpConstant %int 1", + "%dead_float = OpConstant %float 2.5", + "%dead_double = OpConstant %double 3.14159265358979", + "%dead_struct = OpConstantComposite %flat_struct %dead_bool %dead_int %dead_float %dead_double", + }, + /* .checks = */ + { + "; CHECK-NOT: OpConstant", + }, + }, + // A dead constant of a plain struct type. Some of its components are dead + // constants while others are not.
+ { + /* .used_consts = */ + { + "%used_int = OpConstant %int 1", + "%used_double = OpConstant %double 3.14159265358979", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Output", + "OpStore %int_var %used_int", + "%double_var = OpVariable %_pf_double Output", + "OpStore %double_var %used_double", + }, + /* .dead_consts = */ + { + "%dead_bool = OpConstantTrue %bool", + "%dead_float = OpConstant %float 2.5", + "%dead_struct = OpConstantComposite %flat_struct %dead_bool %used_int %dead_float %used_double", + }, + /* .checks = */ + { + "; CHECK: [[int:%\\w+]] = OpConstant %int 1", + "; CHECK: [[double:%\\w+]] = OpConstant %double 3.14159265358979", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[int]]", + "; CHECK: OpStore {{%\\w+}} [[double]]", + }, + }, + // A nesting struct type dead constants. All components of both outer + // and inner structs are dead and should be removed after dead constant + // elimination. + { + /* .used_consts = */ {}, + /* .main_insts = */ {}, + /* .dead_consts = */ + { + "%dead_bool = OpConstantTrue %bool", + "%dead_int = OpConstant %int 1", + "%dead_float = OpConstant %float 2.5", + "%dead_double = OpConstant %double 3.1415926535", + "%dead_inner_struct = OpConstantComposite %inner_struct %dead_bool %dead_int %dead_float %dead_double", + "%dead_int2 = OpConstant %int 2", + "%dead_double2 = OpConstant %double 1.428571428514", + "%dead_outer_struct = OpConstantComposite %outer_struct %dead_inner_struct %dead_int2 %dead_double2", + }, + /* .checks = */ + { + "; CHECK-NOT: OpConstant", + }, + }, + // A nesting struct type dead constants. Some of its components are + // dead constants while others are not. 
+ { + /* .used_consts = */ + { + "%used_int = OpConstant %int 1", + "%used_double = OpConstant %double 3.14159265358979", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Output", + "OpStore %int_var %used_int", + "%double_var = OpVariable %_pf_double Output", + "OpStore %double_var %used_double", + }, + /* .dead_consts = */ + { + "%dead_bool = OpConstantTrue %bool", + "%dead_float = OpConstant %float 2.5", + "%dead_inner_struct = OpConstantComposite %inner_struct %dead_bool %used_int %dead_float %used_double", + "%dead_int = OpConstant %int 2", + "%dead_outer_struct = OpConstantComposite %outer_struct %dead_inner_struct %dead_int %used_double", + }, + /* .checks = */ + { + "; CHECK: [[int:%\\w+]] = OpConstant %int 1", + "; CHECK: [[double:%\\w+]] = OpConstant %double 3.14159265358979", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[int]]", + "; CHECK: OpStore {{%\\w+}} [[double]]", + }, + }, + // A nesting struct case. The inner struct is used while the outer struct is not + { + /* .used_const = */ + { + "%used_bool = OpConstantTrue %bool", + "%used_int = OpConstant %int 1", + "%used_float = OpConstant %float 1.23", + "%used_double = OpConstant %double 1.2345678901234", + "%used_inner_struct = OpConstantComposite %inner_struct %used_bool %used_int %used_float %used_double", + }, + /* .main_insts = */ + { + "%bool_var = OpVariable %_pf_bool Output", + "%bool_from_inner_struct = OpCompositeExtract %bool %used_inner_struct 0", + "OpStore %bool_var %bool_from_inner_struct", + }, + /* .dead_consts = */ + { + "%dead_int = OpConstant %int 2", + "%dead_outer_struct = OpConstantComposite %outer_struct %used_inner_struct %dead_int %used_double" + }, + /* .checks = */ + { + "; CHECK: [[bool:%\\w+]] = OpConstantTrue", + "; CHECK: [[int:%\\w+]] = OpConstant %int 1", + "; CHECK: [[float:%\\w+]] = OpConstant %float 1.23", + "; CHECK: [[double:%\\w+]] = OpConstant %double 1.2345678901234", + "; CHECK: [[struct:%\\w+]] = OpConstantComposite 
%inner_struct [[bool]] [[int]] [[float]] [[double]]", + "; CHECK-NOT: OpConstant", + "; CHECK: OpCompositeExtract %bool [[struct]]", + } + }, + // A nesting struct case. The outer struct is used, so the inner struct should not + // be removed even though it is not used directly. + { + /* .used_consts = */ + { + "%used_bool = OpConstantTrue %bool", + "%used_int = OpConstant %int 1", + "%used_float = OpConstant %float 1.23", + "%used_double = OpConstant %double 1.2345678901234", + "%used_inner_struct = OpConstantComposite %inner_struct %used_bool %used_int %used_float %used_double", + "%used_outer_struct = OpConstantComposite %outer_struct %used_inner_struct %used_int %used_double" + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Output", + "%int_from_outer_struct = OpCompositeExtract %int %used_outer_struct 1", + "OpStore %int_var %int_from_outer_struct", + }, + /* .dead_consts = */ {}, + /* .checks = */ + { + "; CHECK: [[bool:%\\w+]] = OpConstantTrue %bool", + "; CHECK: [[int:%\\w+]] = OpConstant %int 1", + "; CHECK: [[float:%\\w+]] = OpConstant %float 1.23", + "; CHECK: [[double:%\\w+]] = OpConstant %double 1.2345678901234", + "; CHECK: [[inner_struct:%\\w+]] = OpConstantComposite %inner_struct [[bool]] [[int]] [[float]] [[double]]", + "; CHECK: [[outer_struct:%\\w+]] = OpConstantComposite %outer_struct [[inner_struct]] [[int]] [[double]]", + "; CHECK: OpCompositeExtract %int [[outer_struct]]", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + ScalarTypeSpecConstants, AggressiveEliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // All scalar type spec constants. 
+ { + /* .used_consts = */ + { + "%used_bool = OpSpecConstantTrue %bool", + "%used_uint = OpSpecConstant %uint 2", + "%used_int = OpSpecConstant %int 2", + "%used_float = OpSpecConstant %float 2.5", + "%used_double = OpSpecConstant %double 1.428571428514", + }, + /* .main_insts = */ + { + "%bool_var = OpVariable %_pf_bool Output", + "%uint_var = OpVariable %_pf_uint Output", + "%int_var = OpVariable %_pf_int Output", + "%float_var = OpVariable %_pf_float Output", + "%double_var = OpVariable %_pf_double Output", + "OpStore %bool_var %used_bool", + "OpStore %uint_var %used_uint", + "OpStore %int_var %used_int", + "OpStore %float_var %used_float", + "OpStore %double_var %used_double", + }, + /* .dead_consts = */ + { + "%dead_bool = OpSpecConstantTrue %bool", + "%dead_uint = OpSpecConstant %uint 2", + "%dead_int = OpSpecConstant %int 2", + "%dead_float = OpSpecConstant %float 2.5", + "%dead_double = OpSpecConstant %double 1.428571428514", + }, + /* .checks = */ + { + "; CHECK: [[bool:%\\w+]] = OpSpecConstantTrue %bool", + "; CHECK: [[uint:%\\w+]] = OpSpecConstant %uint 2", + "; CHECK: [[int:%\\w+]] = OpSpecConstant %int 2", + "; CHECK: [[float:%\\w+]] = OpSpecConstant %float 2.5", + "; CHECK: [[double:%\\w+]] = OpSpecConstant %double 1.428571428514", + "; CHECK-NOT: OpSpecConstant", + "; CHECK: OpStore {{%\\w+}} [[bool]]", + "; CHECK: OpStore {{%\\w+}} [[uint]]", + "; CHECK: OpStore {{%\\w+}} [[int]]", + "; CHECK: OpStore {{%\\w+}} [[float]]", + "; CHECK: OpStore {{%\\w+}} [[double]]", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + VectorTypeSpecConstants, AggressiveEliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Bool vector type spec constants. One vector has all component dead, + // another vector has one dead boolean and one used boolean. 
+ { + /* .used_consts = */ + { + "%used_bool = OpSpecConstantTrue %bool", + }, + /* .main_insts = */ + { + "%bool_var = OpVariable %_pf_bool Output", + "OpStore %bool_var %used_bool", + }, + /* .dead_consts = */ + { + "%dead_bool = OpSpecConstantFalse %bool", + "%dead_bool_vec1 = OpSpecConstantComposite %v2bool %dead_bool %dead_bool", + "%dead_bool_vec2 = OpSpecConstantComposite %v2bool %dead_bool %used_bool", + }, + /* .checks = */ + { + "; CHECK: [[bool:%\\w+]] = OpSpecConstantTrue %bool", + "; CHECK-NOT: OpSpecConstant", + "; CHECK: OpStore {{%\\w+}} [[bool]]", + }, + }, + + // Uint vector type spec constants. One vector has all components dead, + // another vector has one dead unsigned integer and one used unsigned + // integer. + { + /* .used_consts = */ + { + "%used_uint = OpSpecConstant %uint 3", + }, + /* .main_insts = */ + { + "%uint_var = OpVariable %_pf_uint Output", + "OpStore %uint_var %used_uint", + }, + /* .dead_consts = */ + { + "%dead_uint = OpSpecConstant %uint 1", + "%dead_uint_vec1 = OpSpecConstantComposite %v2uint %dead_uint %dead_uint", + "%dead_uint_vec2 = OpSpecConstantComposite %v2uint %dead_uint %used_uint", + }, + /* .checks = */ + { + "; CHECK: [[uint:%\\w+]] = OpSpecConstant %uint 3", + "; CHECK-NOT: OpSpecConstant", + "; CHECK: OpStore {{%\\w+}} [[uint]]", + }, + }, + + // Int vector type spec constants. One vector has all components dead, + // another vector has one dead integer and one used integer. 
+ { + /* .used_consts = */ + { + "%used_int = OpSpecConstant %int 3", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Output", + "OpStore %int_var %used_int", + }, + /* .dead_consts = */ + { + "%dead_int = OpSpecConstant %int 1", + "%dead_int_vec1 = OpSpecConstantComposite %v2int %dead_int %dead_int", + "%dead_int_vec2 = OpSpecConstantComposite %v2int %dead_int %used_int", + }, + /* .checks = */ + { + "; CHECK: [[int:%\\w+]] = OpSpecConstant %int 3", + "; CHECK-NOT: OpSpecConstant", + "; CHECK: OpStore {{%\\w+}} [[int]]", + }, + }, + + // Int vector type spec constants built with both spec constants and + // front-end constants. + { + /* .used_consts = */ + { + "%used_spec_int = OpSpecConstant %int 3", + "%used_front_end_int = OpConstant %int 3", + }, + /* .main_insts = */ + { + "%int_var1 = OpVariable %_pf_int Output", + "OpStore %int_var1 %used_spec_int", + "%int_var2 = OpVariable %_pf_int Output", + "OpStore %int_var2 %used_front_end_int", + }, + /* .dead_consts = */ + { + "%dead_spec_int = OpSpecConstant %int 1", + "%dead_front_end_int = OpConstant %int 1", + // Dead front-end and dead spec constants + "%dead_int_vec1 = OpSpecConstantComposite %v2int %dead_spec_int %dead_front_end_int", + // Used front-end and dead spec constants + "%dead_int_vec2 = OpSpecConstantComposite %v2int %dead_spec_int %used_front_end_int", + // Dead front-end and used spec constants + "%dead_int_vec3 = OpSpecConstantComposite %v2int %dead_front_end_int %used_spec_int", + }, + /* .checks = */ + { + "; CHECK: [[int1:%\\w+]] = OpSpecConstant %int 3", + "; CHECK: [[int2:%\\w+]] = OpConstant %int 3", + "; CHECK-NOT: OpSpecConstant", + "; CHECK-NOT: OpConstant", + "; CHECK: OpStore {{%\\w+}} [[int1]]", + "; CHECK: OpStore {{%\\w+}} [[int2]]", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + SpecConstantOp, AggressiveEliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Cast operations: uint <-> int <-> bool + { + /* 
.used_consts = */ {}, + /* .main_insts = */ {}, + /* .dead_consts = */ + { + // Assistant constants, only used in dead spec constant + // operations. + "%signed_zero = OpConstant %int 0", + "%signed_zero_vec = OpConstantComposite %v2int %signed_zero %signed_zero", + "%unsigned_zero = OpConstant %uint 0", + "%unsigned_zero_vec = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", + "%signed_one = OpConstant %int 1", + "%signed_one_vec = OpConstantComposite %v2int %signed_one %signed_one", + "%unsigned_one = OpConstant %uint 1", + "%unsigned_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one", + + // Spec constants that support casting to each other. + "%dead_bool = OpSpecConstantTrue %bool", + "%dead_uint = OpSpecConstant %uint 1", + "%dead_int = OpSpecConstant %int 2", + "%dead_bool_vec = OpSpecConstantComposite %v2bool %dead_bool %dead_bool", + "%dead_uint_vec = OpSpecConstantComposite %v2uint %dead_uint %dead_uint", + "%dead_int_vec = OpSpecConstantComposite %v2int %dead_int %dead_int", + + // Scalar cast to boolean spec constant. + "%int_to_bool = OpSpecConstantOp %bool INotEqual %dead_int %signed_zero", + "%uint_to_bool = OpSpecConstantOp %bool INotEqual %dead_uint %unsigned_zero", + + // Vector cast to boolean spec constant. + "%int_to_bool_vec = OpSpecConstantOp %v2bool INotEqual %dead_int_vec %signed_zero_vec", + "%uint_to_bool_vec = OpSpecConstantOp %v2bool INotEqual %dead_uint_vec %unsigned_zero_vec", + + // Scalar cast to int spec constant. + "%bool_to_int = OpSpecConstantOp %int Select %dead_bool %signed_one %signed_zero", + "%uint_to_int = OpSpecConstantOp %uint IAdd %dead_uint %unsigned_zero", + + // Vector cast to int spec constant. + "%bool_to_int_vec = OpSpecConstantOp %v2int Select %dead_bool_vec %signed_one_vec %signed_zero_vec", + "%uint_to_int_vec = OpSpecConstantOp %v2uint IAdd %dead_uint_vec %unsigned_zero_vec", + + // Scalar cast to uint spec constant. 
+ "%bool_to_uint = OpSpecConstantOp %uint Select %dead_bool %unsigned_one %unsigned_zero", + "%int_to_uint_vec = OpSpecConstantOp %uint IAdd %dead_int %signed_zero", + + // Vector cast to uint spec constant. + "%bool_to_uint_vec = OpSpecConstantOp %v2uint Select %dead_bool_vec %unsigned_one_vec %unsigned_zero_vec", + "%int_to_uint = OpSpecConstantOp %v2uint IAdd %dead_int_vec %signed_zero_vec", + }, + /* .checks = */ + { + "; CHECK-NOT: OpConstant", + "; CHECK-NOT: OpSpecConstant", + }, + }, + + // Add, sub, mul, div, rem. + { + /* .used_consts = */ {}, + /* .main_insts = */ {}, + /* .dead_consts = */ + { + "%dead_spec_int_a = OpSpecConstant %int 1", + "%dead_spec_int_a_vec = OpSpecConstantComposite %v2int %dead_spec_int_a %dead_spec_int_a", + + "%dead_spec_int_b = OpSpecConstant %int 2", + "%dead_spec_int_b_vec = OpSpecConstantComposite %v2int %dead_spec_int_b %dead_spec_int_b", + + "%dead_const_int_c = OpConstant %int 3", + "%dead_const_int_c_vec = OpConstantComposite %v2int %dead_const_int_c %dead_const_int_c", + + // Add + "%add_a_b = OpSpecConstantOp %int IAdd %dead_spec_int_a %dead_spec_int_b", + "%add_a_b_vec = OpSpecConstantOp %v2int IAdd %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Sub + "%sub_a_b = OpSpecConstantOp %int ISub %dead_spec_int_a %dead_spec_int_b", + "%sub_a_b_vec = OpSpecConstantOp %v2int ISub %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Mul + "%mul_a_b = OpSpecConstantOp %int IMul %dead_spec_int_a %dead_spec_int_b", + "%mul_a_b_vec = OpSpecConstantOp %v2int IMul %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Div + "%div_a_b = OpSpecConstantOp %int SDiv %dead_spec_int_a %dead_spec_int_b", + "%div_a_b_vec = OpSpecConstantOp %v2int SDiv %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Bitwise Xor + "%xor_a_b = OpSpecConstantOp %int BitwiseXor %dead_spec_int_a %dead_spec_int_b", + "%xor_a_b_vec = OpSpecConstantOp %v2int BitwiseXor %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Scalar Comparison + "%less_a_b = 
OpSpecConstantOp %bool SLessThan %dead_spec_int_a %dead_spec_int_b", + }, + /* .checks = */ + { + "; CHECK-NOT: OpConstant", + "; CHECK-NOT: OpSpecConstant", + }, + }, + + // Vectors without used swizzles should be removed. + { + /* .used_consts = */ + { + "%used_int = OpConstant %int 3", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Output", + "OpStore %int_var %used_int", + }, + /* .dead_consts = */ + { + "%dead_int = OpConstant %int 3", + + "%dead_spec_int_a = OpSpecConstant %int 1", + "%vec_a = OpSpecConstantComposite %v4int %dead_spec_int_a %dead_spec_int_a %dead_int %dead_int", + + "%dead_spec_int_b = OpSpecConstant %int 2", + "%vec_b = OpSpecConstantComposite %v4int %dead_spec_int_b %dead_spec_int_b %used_int %used_int", + + // Extract scalar + "%a_x = OpSpecConstantOp %int CompositeExtract %vec_a 0", + "%b_x = OpSpecConstantOp %int CompositeExtract %vec_b 0", + + // Extract vector + "%a_xy = OpSpecConstantOp %v2int VectorShuffle %vec_a %vec_a 0 1", + "%b_xy = OpSpecConstantOp %v2int VectorShuffle %vec_b %vec_b 0 1", + }, + /* .checks = */ + { + "; CHECK: [[int:%\\w+]] = OpConstant %int 3", + "; CHECK-NOT: OpConstant", + "; CHECK-NOT: OpSpecConstant", + "; CHECK: OpStore {{%\\w+}} [[int]]", + }, + }, + // Vectors with used swizzles should not be removed. 
+ { + /* .used_consts = */ + { + "%used_int = OpConstant %int 3", + "%used_spec_int_a = OpSpecConstant %int 1", + "%used_spec_int_b = OpSpecConstant %int 2", + // Create vectors + "%vec_a = OpSpecConstantComposite %v4int %used_spec_int_a %used_spec_int_a %used_int %used_int", + "%vec_b = OpSpecConstantComposite %v4int %used_spec_int_b %used_spec_int_b %used_int %used_int", + // Extract vector + "%a_xy = OpSpecConstantOp %v2int VectorShuffle %vec_a %vec_a 0 1", + "%b_xy = OpSpecConstantOp %v2int VectorShuffle %vec_b %vec_b 0 1", + }, + /* .main_insts = */ + { + "%v2int_var_a = OpVariable %_pf_v2int Output", + "%v2int_var_b = OpVariable %_pf_v2int Output", + "OpStore %v2int_var_a %a_xy", + "OpStore %v2int_var_b %b_xy", + }, + /* .dead_consts = */ {}, + /* .checks = */ + { + "; CHECK: [[int:%\\w+]] = OpConstant %int 3", + "; CHECK: [[a:%\\w+]] = OpSpecConstant %int 1", + "; CHECK: [[b:%\\w+]] = OpSpecConstant %int 2", + "; CHECK: [[veca:%\\w+]] = OpSpecConstantComposite %v4int [[a]] [[a]] [[int]] [[int]]", + "; CHECK: [[vecb:%\\w+]] = OpSpecConstantComposite %v4int [[b]] [[b]] [[int]] [[int]]", + "; CHECK: [[exa:%\\w+]] = OpSpecConstantOp %v2int VectorShuffle [[veca]] [[veca]] 0 1", + "; CHECK: [[exb:%\\w+]] = OpSpecConstantOp %v2int VectorShuffle [[vecb]] [[vecb]] 0 1", + "; CHECK-NOT: OpConstant", + "; CHECK-NOT: OpSpecConstant", + "; CHECK: OpStore {{%\\w+}} [[exa]]", + "; CHECK: OpStore {{%\\w+}} [[exb]]", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + LongDefUseChain, AggressiveEliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Long Def-Use chain with binary operations. 
+ { + /* .used_consts = */ + { + "%array_size = OpConstant %int 4", + "%type_arr_int_4 = OpTypeArray %int %array_size", + "%used_int_0 = OpConstant %int 100", + "%used_int_1 = OpConstant %int 1", + "%used_int_2 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_1", + "%used_int_3 = OpSpecConstantOp %int ISub %used_int_0 %used_int_2", + "%used_int_4 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_3", + "%used_int_5 = OpSpecConstantOp %int ISub %used_int_0 %used_int_4", + "%used_int_6 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_5", + "%used_int_7 = OpSpecConstantOp %int ISub %used_int_0 %used_int_6", + "%used_int_8 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_7", + "%used_int_9 = OpSpecConstantOp %int ISub %used_int_0 %used_int_8", + "%used_int_10 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_9", + "%used_int_11 = OpSpecConstantOp %int ISub %used_int_0 %used_int_10", + "%used_int_12 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_11", + "%used_int_13 = OpSpecConstantOp %int ISub %used_int_0 %used_int_12", + "%used_int_14 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_13", + "%used_int_15 = OpSpecConstantOp %int ISub %used_int_0 %used_int_14", + "%used_int_16 = OpSpecConstantOp %int ISub %used_int_0 %used_int_15", + "%used_int_17 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_16", + "%used_int_18 = OpSpecConstantOp %int ISub %used_int_0 %used_int_17", + "%used_int_19 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_18", + "%used_int_20 = OpSpecConstantOp %int ISub %used_int_0 %used_int_19", + "%used_vec_a = OpSpecConstantComposite %v2int %used_int_18 %used_int_19", + "%used_vec_b = OpSpecConstantOp %v2int IMul %used_vec_a %used_vec_a", + "%used_int_21 = OpSpecConstantOp %int CompositeExtract %used_vec_b 0", + "%used_array = OpConstantComposite %type_arr_int_4 %used_int_20 %used_int_20 %used_int_21 %used_int_21", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Output", + "%used_array_2 = OpCompositeExtract %int 
%used_array 2", + "OpStore %int_var %used_array_2", + }, + /* .dead_consts = */ + { + "%dead_int_1 = OpConstant %int 2", + "%dead_int_2 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_1", + "%dead_int_3 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_2", + "%dead_int_4 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_3", + "%dead_int_5 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_4", + "%dead_int_6 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_5", + "%dead_int_7 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_6", + "%dead_int_8 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_7", + "%dead_int_9 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_8", + "%dead_int_10 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_9", + "%dead_int_11 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_10", + "%dead_int_12 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_11", + "%dead_int_13 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_12", + "%dead_int_14 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_13", + "%dead_int_15 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_14", + "%dead_int_16 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_15", + "%dead_int_17 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_16", + "%dead_int_18 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_17", + "%dead_int_19 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_18", + "%dead_int_20 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_19", + "%dead_vec_a = OpSpecConstantComposite %v2int %dead_int_18 %dead_int_19", + "%dead_vec_b = OpSpecConstantOp %v2int IMul %dead_vec_a %dead_vec_a", + "%dead_int_21 = OpSpecConstantOp %int CompositeExtract %dead_vec_b 0", + "%dead_array = OpConstantComposite %type_arr_int_4 %dead_int_20 %used_int_20 %dead_int_19 %used_int_19", + }, + /* .checks = */ + { + "; CHECK: OpConstant %int 4", + "; CHECK: [[array:%\\w+]] = OpConstantComposite %type_arr_int_4 %used_int_20 %used_int_20 %used_int_21 %used_int_21", + "; CHECK-NOT: 
OpConstant", + "; CHECK-NOT: OpSpecConstant", + "; CHECK: OpStore {{%\\w+}} [[array]]", + }, + }, + // Long Def-Use chain with swizzle + // clang-format on + }))); + +TEST_F(AggressiveDCETest, DeadDecorationGroup) { + // The decoration group should be eliminated because the target of group + // decorate is dead. + const std::string text = R"( +; CHECK-NOT: OpDecorat +; CHECK-NOT: OpGroupDecorate +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %1 Restrict +OpDecorate %1 Aliased +%1 = OpDecorationGroup +OpGroupDecorate %1 %var +%void = OpTypeVoid +%func = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%uint_ptr = OpTypePointer Function %uint +%main = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %uint_ptr Function +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(AggressiveDCETest, DeadDecorationGroupAndValidDecorationMgr) { + // The decoration group should be eliminated because the target of group + // decorate is dead. + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %1 Restrict +OpDecorate %1 Aliased +%1 = OpDecorationGroup +OpGroupDecorate %1 %var +%void = OpTypeVoid +%func = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%uint_ptr = OpTypePointer Function %uint +%main = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %uint_ptr Function +OpReturn +OpFunctionEnd + )"; + + auto pass = MakeUnique(); + auto consumer = [](spv_message_level_t, const char*, const spv_position_t&, + const char* message) { + std::cerr << message << std::endl; + }; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_1, consumer, text); + + // Build the decoration manager before the pass. 
+ context->get_decoration_mgr(); + + const auto status = pass->Run(context.get()); + EXPECT_EQ(status, Pass::Status::SuccessWithChange); +} + +TEST_F(AggressiveDCETest, ParitallyDeadDecorationGroup) { + const std::string text = R"( +; CHECK: OpDecorate [[grp:%\w+]] Restrict +; CHECK: OpDecorate [[grp]] Aliased +; CHECK: [[grp]] = OpDecorationGroup +; CHECK: OpGroupDecorate [[grp]] [[output:%\w+]] +; CHECK: [[output]] = OpVariable {{%\w+}} Output +; CHECK-NOT: OpVariable {{%\w+}} Function +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %output +OpExecutionMode %main OriginUpperLeft +OpDecorate %1 Restrict +OpDecorate %1 Aliased +%1 = OpDecorationGroup +OpGroupDecorate %1 %var %output +%void = OpTypeVoid +%func = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%uint_ptr_Function = OpTypePointer Function %uint +%uint_ptr_Output = OpTypePointer Output %uint +%uint_0 = OpConstant %uint 0 +%output = OpVariable %uint_ptr_Output Output +%main = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %uint_ptr_Function Function +OpStore %output %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(AggressiveDCETest, ParitallyDeadDecorationGroupDifferentGroupDecorate) { + const std::string text = R"( +; CHECK: OpDecorate [[grp:%\w+]] Restrict +; CHECK: OpDecorate [[grp]] Aliased +; CHECK: [[grp]] = OpDecorationGroup +; CHECK: OpGroupDecorate [[grp]] [[output:%\w+]] +; CHECK-NOT: OpGroupDecorate +; CHECK: [[output]] = OpVariable {{%\w+}} Output +; CHECK-NOT: OpVariable {{%\w+}} Function +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %output +OpExecutionMode %main OriginUpperLeft +OpDecorate %1 Restrict +OpDecorate %1 Aliased +%1 = OpDecorationGroup +OpGroupDecorate %1 %output +OpGroupDecorate %1 %var +%void = OpTypeVoid +%func = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%uint_ptr_Function = OpTypePointer Function %uint +%uint_ptr_Output = OpTypePointer 
Output %uint +%uint_0 = OpConstant %uint 0 +%output = OpVariable %uint_ptr_Output Output +%main = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %uint_ptr_Function Function +OpStore %output %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(AggressiveDCETest, DeadGroupMemberDecorate) { + const std::string text = R"( +; CHECK-NOT: OpDec +; CHECK-NOT: OpGroup +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %1 Offset 0 +OpDecorate %1 Uniform +%1 = OpDecorationGroup +OpGroupMemberDecorate %1 %var 0 +%void = OpTypeVoid +%func = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%struct = OpTypeStruct %uint %uint +%struct_ptr = OpTypePointer Function %struct +%main = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct_ptr Function +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(AggressiveDCETest, PartiallyDeadGroupMemberDecorate) { + const std::string text = R"( +; CHECK: OpDecorate [[grp:%\w+]] Offset 0 +; CHECK: OpDecorate [[grp]] RelaxedPrecision +; CHECK: [[grp]] = OpDecorationGroup +; CHECK: OpGroupMemberDecorate [[grp]] [[output:%\w+]] 1 +; CHECK: [[output]] = OpTypeStruct +; CHECK-NOT: OpTypeStruct +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %output +OpExecutionMode %main OriginUpperLeft +OpDecorate %1 Offset 0 +OpDecorate %1 RelaxedPrecision +%1 = OpDecorationGroup +OpGroupMemberDecorate %1 %var_struct 0 %output_struct 1 +%void = OpTypeVoid +%func = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%var_struct = OpTypeStruct %uint %uint +%output_struct = OpTypeStruct %uint %uint +%struct_ptr_Function = OpTypePointer Function %var_struct +%struct_ptr_Output = OpTypePointer Output %output_struct +%uint_ptr_Output = OpTypePointer Output %uint +%output = OpVariable %struct_ptr_Output Output +%uint_0 = OpConstant %uint 0 +%main = OpFunction 
%void None %func +%2 = OpLabel +%var = OpVariable %struct_ptr_Function Function +%3 = OpAccessChain %uint_ptr_Output %output %uint_0 +OpStore %3 %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(AggressiveDCETest, + PartiallyDeadGroupMemberDecorateDifferentGroupDecorate) { + const std::string text = R"( +; CHECK: OpDecorate [[grp:%\w+]] Offset 0 +; CHECK: OpDecorate [[grp]] RelaxedPrecision +; CHECK: [[grp]] = OpDecorationGroup +; CHECK: OpGroupMemberDecorate [[grp]] [[output:%\w+]] 1 +; CHECK-NOT: OpGroupMemberDecorate +; CHECK: [[output]] = OpTypeStruct +; CHECK-NOT: OpTypeStruct +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %output +OpExecutionMode %main OriginUpperLeft +OpDecorate %1 Offset 0 +OpDecorate %1 RelaxedPrecision +%1 = OpDecorationGroup +OpGroupMemberDecorate %1 %var_struct 0 +OpGroupMemberDecorate %1 %output_struct 1 +%void = OpTypeVoid +%func = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%var_struct = OpTypeStruct %uint %uint +%output_struct = OpTypeStruct %uint %uint +%struct_ptr_Function = OpTypePointer Function %var_struct +%struct_ptr_Output = OpTypePointer Output %output_struct +%uint_ptr_Output = OpTypePointer Output %uint +%output = OpVariable %struct_ptr_Output Output +%uint_0 = OpConstant %uint 0 +%main = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct_ptr_Function Function +%3 = OpAccessChain %uint_ptr_Output %output %uint_0 +OpStore %3 %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +// Test for #1404 +TEST_F(AggressiveDCETest, DontRemoveWorkgroupSize) { + const std::string text = R"( +; CHECK: OpDecorate [[wgs:%\w+]] BuiltIn WorkgroupSize +; CHECK: [[wgs]] = OpSpecConstantComposite +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %func "func" +OpExecutionMode %func LocalSize 1 1 1 +OpDecorate %1 BuiltIn WorkgroupSize +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%functy = 
OpTypeFunction %void +%v3int = OpTypeVector %int 3 +%2 = OpSpecConstant %int 1 +%1 = OpSpecConstantComposite %v3int %2 %2 %2 +%func = OpFunction %void None %functy +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +// Test for #1214 +TEST_F(AggressiveDCETest, LoopHeaderIsAlsoAnotherLoopMerge) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" %2 +OpExecutionMode %1 OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%uint_0 = OpConstant %uint 0 +%9 = OpTypeFunction %void +%1 = OpFunction %void None %9 +%10 = OpLabel +OpBranch %11 +%11 = OpLabel +OpLoopMerge %12 %13 None +OpBranchConditional %true %14 %13 +%14 = OpLabel +OpStore %2 %uint_0 +OpLoopMerge %15 %16 None +OpBranchConditional %true %15 %16 +%16 = OpLabel +OpBranch %14 +%15 = OpLabel +OpBranchConditional %true %12 %13 +%13 = OpLabel +OpBranch %11 +%12 = OpLabel +%17 = OpPhi %uint %uint_0 %15 %uint_0 %18 +OpStore %2 %17 +OpLoopMerge %19 %18 None +OpBranchConditional %true %19 %18 +%18 = OpLabel +OpBranch %12 +%19 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(AggressiveDCETest, BreaksDontVisitPhis) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" %var +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%int = OpTypeInt 32 0 +%int_ptr_Output = OpTypePointer Output %int +%var = OpVariable %int_ptr_Output Output +%int0 = OpConstant %int 0 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpBranch %outer_header +%outer_header = OpLabel +OpLoopMerge %outer_merge %outer_continue None +OpBranchConditional %true %inner_header %outer_continue 
+%inner_header = OpLabel +%phi = OpPhi %int %int0 %outer_header %int0 %inner_continue +OpStore %var %phi +OpLoopMerge %inner_merge %inner_continue None +OpBranchConditional %true %inner_merge %inner_continue +%inner_continue = OpLabel +OpBranch %inner_header +%inner_merge = OpLabel +OpBranch %outer_continue +%outer_continue = OpLabel +%p = OpPhi %int %int0 %outer_header %int0 %inner_merge +OpStore %var %p +OpBranch %outer_header +%outer_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, + std::get<1>(SinglePassRunAndDisassemble( + text, false, true))); +} + +// Test for #1212 +TEST_F(AggressiveDCETest, ConstStoreInnerLoop) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "main" %2 +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%_ptr_Output_float = OpTypePointer Output %float +%2 = OpVariable %_ptr_Output_float Output +%float_3 = OpConstant %float 3 +%1 = OpFunction %void None %4 +%13 = OpLabel +OpBranch %14 +%14 = OpLabel +OpLoopMerge %15 %16 None +OpBranchConditional %true %17 %15 +%17 = OpLabel +OpStore %2 %float_3 +OpLoopMerge %18 %17 None +OpBranchConditional %true %18 %17 +%18 = OpLabel +OpBranch %15 +%16 = OpLabel +OpBranch %14 +%15 = OpLabel +OpBranch %20 +%20 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(text, text, true, true); +} + +// Test for #1212 +TEST_F(AggressiveDCETest, InnerLoopCopy) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "main" %2 %3 +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%_ptr_Output_float = OpTypePointer Output %float +%_ptr_Input_float = OpTypePointer Input %float +%2 = OpVariable %_ptr_Output_float Output +%3 = OpVariable 
%_ptr_Input_float Input +%1 = OpFunction %void None %5 +%14 = OpLabel +OpBranch %15 +%15 = OpLabel +OpLoopMerge %16 %17 None +OpBranchConditional %true %18 %16 +%18 = OpLabel +%19 = OpLoad %float %3 +OpStore %2 %19 +OpLoopMerge %20 %18 None +OpBranchConditional %true %20 %18 +%20 = OpLabel +OpBranch %16 +%17 = OpLabel +OpBranch %15 +%16 = OpLabel +OpBranch %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(AggressiveDCETest, AtomicAdd) { + const std::string text = R"(OpCapability SampledBuffer +OpCapability StorageImageExtendedFormats +OpCapability ImageBuffer +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %2 "min" %gl_GlobalInvocationID +OpExecutionMode %2 LocalSize 64 1 1 +OpSource HLSL 600 +OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId +OpDecorate %4 DescriptorSet 4 +OpDecorate %4 Binding 70 +%uint = OpTypeInt 32 0 +%6 = OpTypeImage %uint Buffer 0 0 0 2 R32ui +%_ptr_UniformConstant_6 = OpTypePointer UniformConstant %6 +%_ptr_Private_6 = OpTypePointer Private %6 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%v3uint = OpTypeVector %uint 3 +%_ptr_Input_v3uint = OpTypePointer Input %v3uint +%_ptr_Image_uint = OpTypePointer Image %uint +%4 = OpVariable %_ptr_UniformConstant_6 UniformConstant +%16 = OpVariable %_ptr_Private_6 Private +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input +%2 = OpFunction %void None %10 +%17 = OpLabel +%18 = OpLoad %6 %4 +OpStore %16 %18 +%19 = OpImageTexelPointer %_ptr_Image_uint %16 %uint_0 %uint_0 +%20 = OpAtomicIAdd %uint %19 %uint_1 %uint_0 %uint_1 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(AggressiveDCETest, 
SafelyRemoveDecorateString) { + const std::string preamble = R"(OpCapability Shader +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +)"; + + const std::string body_before = + R"(OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "FOOBAR" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%2 = OpVariable %_ptr_StorageBuffer_uint StorageBuffer +%1 = OpFunction %void None %4 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string body_after = R"(%void = OpTypeVoid +%4 = OpTypeFunction %void +%1 = OpFunction %void None %4 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(preamble + body_before, + preamble + body_after, true, true); +} + +TEST_F(AggressiveDCETest, CopyMemoryToGlobal) { + // |local| is loaded in an OpCopyMemory instruction. So the store must be + // kept alive. 
+ const std::string test = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local "local" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local = OpVariable %_ptr_Function_v4float Function +OpStore %local %12 +OpCopyMemory %global %local +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, true, true); +} + +TEST_F(AggressiveDCETest, CopyMemoryToLocal) { + // Make sure the store to |local2| using OpCopyMemory is kept and keeps + // |local1| alive. 
+ const std::string test = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local1 "local1" +OpName %local2 "local2" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local1 = OpVariable %_ptr_Function_v4float Function +%local2 = OpVariable %_ptr_Function_v4float Function +OpStore %local1 %12 +OpCopyMemory %local2 %local1 +OpCopyMemory %global %local2 +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, true, true); +} + +TEST_F(AggressiveDCETest, RemoveCopyMemoryToLocal) { + // Test that we remove function scope variables that are stored to using + // OpCopyMemory, but are never loaded. We can remove both |local1| and + // |local2|. 
+ const std::string test = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local1 "local1" +OpName %local2 "local2" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local1 = OpVariable %_ptr_Function_v4float Function +%local2 = OpVariable %_ptr_Function_v4float Function +OpStore %local1 %12 +OpCopyMemory %local2 %local1 +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + const std::string result = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, result, true, true); +} + +TEST_F(AggressiveDCETest, RemoveCopyMemoryToLocal2) { + // We are able to remove "local2" because it is not loaded, but have to keep + // the stores to "local1". 
+ const std::string test = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local1 "local1" +OpName %local2 "local2" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local1 = OpVariable %_ptr_Function_v4float Function +%local2 = OpVariable %_ptr_Function_v4float Function +OpStore %local1 %12 +OpCopyMemory %local2 %local1 +OpCopyMemory %global %local1 +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + const std::string result = + R"(OpCapability Geometry +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" %global +OpExecutionMode %main Triangles +OpExecutionMode %main Invocations 1 +OpExecutionMode %main OutputTriangleStrip +OpExecutionMode %main OutputVertices 5 +OpSource GLSL 440 +OpName %main "main" +OpName %local1 "local1" +OpName %global "global" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%12 = OpConstantNull %v4float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%global = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %7 +%19 = OpLabel +%local1 = OpVariable %_ptr_Function_v4float Function +OpStore %local1 %12 +OpCopyMemory %global %local1 +OpEndPrimitive +OpReturn +OpFunctionEnd +)"; + + 
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, result, true, true); +} + +TEST_F(AggressiveDCETest, StructuredIfWithConditionalExit) { + // The selection headed by %12 has a conditional branch from %16 straight to + // its merge block %15. The pass must leave this module unchanged. + const std::string test = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" +OpSourceExtension "GL_GOOGLE_include_directive" +OpName %main "main" +OpName %a "a" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Uniform_int = OpTypePointer Uniform %int +%int_0 = OpConstant %int 0 +%bool = OpTypeBool +%int_100 = OpConstant %int 100 +%int_1 = OpConstant %int 1 +%a = OpVariable %_ptr_Uniform_int Uniform +%main = OpFunction %void None %5 +%12 = OpLabel +%13 = OpLoad %int %a +%14 = OpSGreaterThan %bool %13 %int_0 +OpSelectionMerge %15 None +OpBranchConditional %14 %16 %15 +%16 = OpLabel +%17 = OpLoad %int %a +%18 = OpSLessThan %bool %17 %int_100 +OpBranchConditional %18 %19 %15 +%19 = OpLabel +OpStore %a %int_1 +OpBranch %15 +%15 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, true, true); +} + +TEST_F(AggressiveDCETest, CountingLoopNotEliminated) { + // #version 310 es + // + // precision highp float; + // precision highp int; + // + // layout(location = 0) out vec4 _GLF_color; + // + // void main() + // { + // float data[1]; + // for (int c = 0; c < 1; c++) { + // if (true) { + // do { + // for (int i = 0; i < 1; i++) { + // data[i] = 1.0; + // } + // } while (false); + // } + // } + // _GLF_color = vec4(data[0], 0.0, 0.0, 1.0); + // } + const std::string test = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %_GLF_color +OpExecutionMode %main OriginUpperLeft +OpSource ESSL 310 +OpName %main "main" +OpName %c "c" +OpName %i "i" +OpName %data "data" +OpName %_GLF_color "_GLF_color" +OpDecorate %_GLF_color Location 0 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%bool = OpTypeBool +%float = OpTypeFloat 32 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%_arr_float_uint_1 = OpTypeArray %float %uint_1 +%_ptr_Function__arr_float_uint_1 = OpTypePointer Function %_arr_float_uint_1 +%float_1 = OpConstant %float 1 +%_ptr_Function_float = OpTypePointer Function %float +%false = OpConstantFalse %bool +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_GLF_color = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +%main = OpFunction %void None %8 +%26 = OpLabel +%c = OpVariable %_ptr_Function_int Function +%i = OpVariable %_ptr_Function_int Function +%data = OpVariable %_ptr_Function__arr_float_uint_1 Function +OpStore %c %int_0 +OpBranch %27 +%27 = OpLabel +OpLoopMerge %28 %29 None +OpBranch %30 +%30 = OpLabel +%31 = OpLoad %int %c +%32 = OpSLessThan %bool %31 %int_1 +OpBranchConditional %32 %33 %28 +%33 = OpLabel +OpBranch %34 +%34 = OpLabel +OpBranch %35 +%35 = OpLabel +OpLoopMerge %36 %37 None +OpBranch %38 +%38 = OpLabel +OpStore %i %int_0 +OpBranch %39 +%39 = OpLabel +OpLoopMerge %40 %41 None +OpBranch %42 +%42 = OpLabel +%43 = OpLoad %int %i +%44 = OpSLessThan %bool %43 %int_1 +OpSelectionMerge %45 None +OpBranchConditional %44 %46 %40 +%46 = OpLabel +%47 = OpLoad %int %i +%48 = OpAccessChain %_ptr_Function_float %data %47 +OpStore %48 %float_1 +OpBranch %41 +%41 = OpLabel +%49 = OpLoad %int %i +%50 = OpIAdd %int %49 %int_1 +OpStore %i %50 +OpBranch %39 +%40 = OpLabel +OpBranch %37 +%37 = OpLabel +OpBranchConditional %false %35 %36 +%36 = 
OpLabel +OpBranch %45 +%45 = OpLabel +OpBranch %29 +%29 = OpLabel +%51 = OpLoad %int %c +%52 = OpIAdd %int %51 %int_1 +OpStore %c %52 +OpBranch %27 +%28 = OpLabel +%53 = OpAccessChain %_ptr_Function_float %data %int_0 +%54 = OpLoad %float %53 +%55 = OpCompositeConstruct %v4float %54 %float_0 %float_0 %float_1 +OpStore %_GLF_color %55 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, true, true); +} + +TEST_F(AggressiveDCETest, EliminateLoopWithUnreachable) { + // #version 430 + // + // layout(std430) buffer U_t + // { + // float g_F[10]; + // float g_S; + // }; + // + // layout(location = 0)out float o; + // + // void main(void) + // { + // // Useless loop + // for (int i = 0; i<10; i++) { + // if (g_F[i] == 0.0) + // break; + // else + // break; + // // Unreachable merge block created here. + // // Need to edit SPIR-V to change to OpUnreachable + // } + // o = g_S; + // } + + const std::string before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %i "i" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpMemberName %U_t 1 "g_S" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 40 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%float = OpTypeFloat 32 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t 
Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%float_0 = OpConstant %float 0 +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %9 +%23 = OpLabel +%i = OpVariable %_ptr_Function_int Function +OpStore %i %int_0 +OpBranch %24 +%24 = OpLabel +OpLoopMerge %25 %26 None +OpBranch %27 +%27 = OpLabel +%28 = OpLoad %int %i +%29 = OpSLessThan %bool %28 %int_10 +OpBranchConditional %29 %30 %25 +%30 = OpLabel +%31 = OpLoad %int %i +%32 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %31 +%33 = OpLoad %float %32 +%34 = OpFOrdEqual %bool %33 %float_0 +OpSelectionMerge %35 None +OpBranchConditional %34 %36 %37 +%36 = OpLabel +OpBranch %25 +%37 = OpLabel +OpBranch %25 +%35 = OpLabel +OpUnreachable +%26 = OpLabel +%38 = OpLoad %int %i +%39 = OpIAdd %int %38 %int_1 +OpStore %i %39 +OpBranch %24 +%25 = OpLabel +%40 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%41 = OpLoad %float %40 +OpStore %o %41 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %o +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpMemberName %U_t 1 "g_S" +OpName %_ "" +OpName %o "o" +OpDecorate %_arr_float_uint_10 ArrayStride 4 +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 40 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpDecorate %o Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%float = OpTypeFloat 32 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%U_t = OpTypeStruct %_arr_float_uint_10 %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant 
%int 1 +%_ptr_Output_float = OpTypePointer Output %float +%o = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %9 +%23 = OpLabel +OpBranch %24 +%24 = OpLabel +OpBranch %25 +%25 = OpLabel +%40 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%41 = OpLoad %float %40 +OpStore %o %41 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +TEST_F(AggressiveDCETest, DeadHlslCounterBufferGOOGLE) { + // %6 is never used in a function; it is only named by decorations. Expect + // it and the OpDecorateId on %5 to be removed, keeping %5 and its store. + const std::string test = + R"( +; CHECK-NOT: OpDecorateId +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NOT: OpVariable +; CHECK: [[ac:%\w+]] = OpAccessChain {{%\w+}} [[var]] +; CHECK: OpStore [[ac]] + OpCapability Shader + OpExtension "SPV_GOOGLE_hlsl_functionality1" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 32 1 1 + OpSource HLSL 600 + OpDecorate %_runtimearr_v2float ArrayStride 8 + OpMemberDecorate %_struct_3 0 Offset 0 + OpDecorate %_struct_3 BufferBlock + OpMemberDecorate %_struct_4 0 Offset 0 + OpDecorate %_struct_4 BufferBlock + OpDecorateId %5 HlslCounterBufferGOOGLE %6 + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + OpDecorate %6 DescriptorSet 0 + OpDecorate %6 Binding 1 + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 +%_runtimearr_v2float = OpTypeRuntimeArray %v2float + %_struct_3 = OpTypeStruct %_runtimearr_v2float +%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3 + %int = OpTypeInt 32 1 + %_struct_4 = OpTypeStruct %int +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %void = OpTypeVoid + %13 = OpTypeFunction %void + %19 = OpConstantNull %v2float + %int_0 = OpConstant %int 0 +%_ptr_Uniform_v2float = OpTypePointer Uniform %v2float + %5 = OpVariable %_ptr_Uniform__struct_3 Uniform + %6 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 =
OpFunction %void None %13 + %22 = OpLabel + %23 = OpAccessChain %_ptr_Uniform_v2float %5 %int_0 %int_0 + OpStore %23 %19 + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(test, true); +} + +TEST_F(AggressiveDCETest, Dead) { + // Nothing in the function uses %PSInput, so the + // OpMemberDecorateStringGOOGLE decorations on it are expected to be removed. + const std::string test = + R"( +; CHECK: OpCapability +; CHECK-NOT: OpMemberDecorateStringGOOGLE +; CHECK: OpFunctionEnd + OpCapability Shader + OpExtension "SPV_GOOGLE_hlsl_functionality1" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %VSMain "VSMain" + OpSource HLSL 500 + OpName %VSMain "VSMain" + OpName %PSInput "PSInput" + OpMemberName %PSInput 0 "Pos" + OpMemberName %PSInput 1 "uv" + OpMemberDecorateStringGOOGLE %PSInput 0 HlslSemanticGOOGLE "SV_POSITION" + OpMemberDecorateStringGOOGLE %PSInput 1 HlslSemanticGOOGLE "TEX_COORD" + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%PSInput = OpTypeStruct %v4float %v2float + %VSMain = OpFunction %void None %5 + %9 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(test, true); +} +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// Check that logical addressing required +// Check that function calls inhibit optimization +// Others? + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/assembly_builder.h b/test/opt/assembly_builder.h new file mode 100644 index 000000000..1673c092b --- /dev/null +++ b/test/opt/assembly_builder.h @@ -0,0 +1,266 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TEST_OPT_ASSEMBLY_BUILDER_H_ +#define TEST_OPT_ASSEMBLY_BUILDER_H_ + +#include <algorithm> +#include <cstdint> +#include <sstream> +#include <string> +#include <unordered_set> +#include <vector> + +namespace spvtools { +namespace opt { + +// A simple SPIR-V assembly code builder for test uses. It builds an SPIR-V +// assembly module from vectors of assembly strings. It allows users to add +// instructions to the main function and the type-constants-globals section +// directly. It relies on OpName instructions and friendly-name disassembling +// to keep the ID names unchanged after assembling. +// +// An assembly module is divided into several sections, matching with the +// SPIR-V Logical Layout: +// Global Preamble: +// OpCapability instructions; +// OpExtension instructions and OpExtInstImport instructions; +// OpMemoryModel instruction; +// OpEntryPoint and OpExecutionMode instruction; +// OpString, OpSourceExtension, OpSource and OpSourceContinued instructions. +// Names: +// OpName instructions. +// Annotations: +// OpDecorate, OpMemberDecorate, OpGroupDecorate, OpGroupMemberDecorate and +// OpDecorationGroup. +// Types, Constants and Global variables: +// Types, constants and global variables declaration instructions. +// Main Function: +// Main function instructions. +// Main Function Postamble: +// The return and function end instructions. +// +// The assembly code is built by concatenating all the strings in the above +// sections. +// +// Users define the contents in section <Types, Constants and Global variables> +// and
<Main Function>. The <Names> section is to hold the names for IDs to +// keep them unchanged before and after assembling. All defined IDs to be added +// to this code builder will be assigned with a global name through OpName +// instruction. The name is extracted from the definition instruction. +// E.g. adding instruction: %var_a = OpConstant %int 2, will also add an +// instruction: OpName %var_a, "var_a". +// +// Note that the name must not be used on more than one defined IDs and +// friendly-name disassembling must be enabled so that OpName instructions will +// be respected. +class AssemblyBuilder { + // The base ID value for spec constants. + static const uint32_t SPEC_ID_BASE = 200; + + public: + // Initialize a minimal SPIR-V assembly code as the template. The minimal + // module contains an empty main function and some predefined names for the + // main function. + AssemblyBuilder() + : spec_id_counter_(SPEC_ID_BASE), + global_preamble_({ + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + // clang-format on + }), + names_(), + annotations_(), + types_consts_globals_(), + main_func_(), + main_func_postamble_({ + "OpReturn", + "OpFunctionEnd", + }) { + AppendTypesConstantsGlobals({ + "%void = OpTypeVoid", + "%main_func_type = OpTypeFunction %void", + }); + AppendInMain({ + "%main = OpFunction %void None %main_func_type", + "%main_func_entry_block = OpLabel", + }); + } + + // Appends OpName instructions to this builder. Instruction strings that do + // not start with 'OpName ' will be skipped. Returns a reference to this + // assembly builder.
+ AssemblyBuilder& AppendNames(const std::vector<std::string>& vec_asm_code) { + for (const auto& inst_str : vec_asm_code) { + if (inst_str.find("OpName ") == 0) { + names_.push_back(inst_str); + } + } + return *this; + } + + // Appends instructions to the types-constants-globals section and returns + // the reference of this assembly builder. IDs defined in the given code will + // be added to the Names section and then be registered with OpName + // instruction. A SpecId decoration is added for every spec constant defined + // with OpSpecConstant, OpSpecConstantTrue or OpSpecConstantFalse. + AssemblyBuilder& AppendTypesConstantsGlobals( + const std::vector<std::string>& vec_asm_code) { + AddNamesForResultIDsIn(vec_asm_code); + // Scalar and boolean spec constants get a SpecId decoration. + for (const auto& inst_str : vec_asm_code) { + if (inst_str.find("= OpSpecConstant ") != std::string::npos || + inst_str.find("= OpSpecConstantTrue ") != std::string::npos || + inst_str.find("= OpSpecConstantFalse ") != std::string::npos) { + AddSpecIDFor(GetResultIDName(inst_str)); + } + } + types_consts_globals_.insert(types_consts_globals_.end(), + vec_asm_code.begin(), vec_asm_code.end()); + return *this; + } + + // Appends instructions to the main function block, which is already labelled + // with "main_func_entry_block". Returns the reference of this assembly + // builder. IDs defined in the given code will be added to the Names section + // and then be registered with OpName instruction. + AssemblyBuilder& AppendInMain(const std::vector<std::string>& vec_asm_code) { + AddNamesForResultIDsIn(vec_asm_code); + main_func_.insert(main_func_.end(), vec_asm_code.begin(), + vec_asm_code.end()); + return *this; + } + + // Appends annotation instructions to the annotation section, and returns the + // reference of this assembly builder.
+ AssemblyBuilder& AppendAnnotations( + const std::vector<std::string>& vec_annotations) { + annotations_.insert(annotations_.end(), vec_annotations.begin(), + vec_annotations.end()); + return *this; + } + + // Adds lines to the user preamble, which GetCode() emits ahead of every + // other section of the module. Useful for EFFCEE checks. + AssemblyBuilder& PrependPreamble(const std::vector<std::string>& preamble) { + preamble_.insert(preamble_.end(), preamble.begin(), preamble.end()); + return *this; + } + + // Get the SPIR-V assembly code as string. + std::string GetCode() const { + std::ostringstream ss; + for (const auto& line : preamble_) { + ss << line << std::endl; + } + for (const auto& line : global_preamble_) { + ss << line << std::endl; + } + for (const auto& line : names_) { + ss << line << std::endl; + } + for (const auto& line : annotations_) { + ss << line << std::endl; + } + for (const auto& line : types_consts_globals_) { + ss << line << std::endl; + } + for (const auto& line : main_func_) { + ss << line << std::endl; + } + for (const auto& line : main_func_postamble_) { + ss << line << std::endl; + } + return ss.str(); + } + + private: + // Adds a given name to the Name section with OpName. If the given name has + // been added before, does nothing. + void AddOpNameIfNotExist(const std::string& id_name) { + if (!used_names_.count(id_name)) { + std::stringstream opname_inst; + opname_inst << "OpName " + << "%" << id_name << " \"" << id_name << "\""; + names_.emplace_back(opname_inst.str()); + used_names_.insert(id_name); + } + } + + // Adds the names in a vector of assembly code strings to the Names section. + // If a '=' sign is found in an instruction, this instruction will be treated + // as an ID defining instruction. The ID name used in the instruction will be + // extracted and added to the Names section.
+ void AddNamesForResultIDsIn(const std::vector<std::string>& vec_asm_code) { + for (const auto& line : vec_asm_code) { + std::string name = GetResultIDName(line); + if (!name.empty()) { + AddOpNameIfNotExist(name); + } + } + } + + // Adds an OpDecorate SpecId instruction for the given ID name. + void AddSpecIDFor(const std::string& id_name) { + std::stringstream decorate_inst; + decorate_inst << "OpDecorate " + << "%" << id_name << " SpecId " << spec_id_counter_; + spec_id_counter_ += 1; + annotations_.emplace_back(decorate_inst.str()); + } + + // Extracts the ID name from a SPIR-V assembly instruction string. If the + // instruction is an ID-defining instruction (has result ID), returns the + // name of the result ID in string. If the instruction does not have result + // ID, returns an empty string. The name is the text before the first '=' + // with every ' ' and '%' character removed. + std::string GetResultIDName(const std::string& inst_str) { + std::string name; + const size_t assign_sign = inst_str.find('='); + if (assign_sign != std::string::npos) { + name = inst_str.substr(0, assign_sign); + name.erase(std::remove_if(name.begin(), name.end(), + [](char c) { return c == ' ' || c == '%'; }), + name.end()); + } + return name; + } + + uint32_t spec_id_counter_; + // User-defined preamble. + std::vector<std::string> preamble_; + // The vector that contains common preambles shared across all test SPIR-V + // code. + std::vector<std::string> global_preamble_; + // The vector that contains OpName instructions. + std::vector<std::string> names_; + // The vector that contains annotation instructions. + std::vector<std::string> annotations_; + // The vector that contains the code to declare types, constants and global + // variables (aka. the Types-Constants-Globals section). + std::vector<std::string> types_consts_globals_; + // The vector that contains the code in main function's entry block. + std::vector<std::string> main_func_; + // The vector that contains the postamble of main function body. + std::vector<std::string> main_func_postamble_; + // All of the defined variable names.
+ std::unordered_set<std::string> used_names_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // TEST_OPT_ASSEMBLY_BUILDER_H_ diff --git a/test/opt/assembly_builder_test.cpp b/test/opt/assembly_builder_test.cpp new file mode 100644 index 000000000..55fbbe904 --- /dev/null +++ b/test/opt/assembly_builder_test.cpp @@ -0,0 +1,283 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/opt/assembly_builder.h" + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using AssemblyBuilderTest = PassTest<::testing::Test>; + +TEST_F(AssemblyBuilderTest, MinimalShader) { + AssemblyBuilder builder; + std::vector expected = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %void \"void\"", + "OpName %main_func_type \"main_func_type\"", + "OpName %main \"main\"", + "OpName %main_func_entry_block \"main_func_entry_block\"", + "%void = OpTypeVoid", + "%main_func_type = OpTypeFunction %void", + "%main = OpFunction %void None %main_func_type", +"%main_func_entry_block = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + SinglePassRunAndCheck(builder.GetCode(), JoinAllInsts(expected), + /* skip_nop = */ false); +} + +TEST_F(AssemblyBuilderTest, ShaderWithConstants) { +
AssemblyBuilder builder; + builder + .AppendTypesConstantsGlobals({ + // clang-format off + "%bool = OpTypeBool", + "%_PF_bool = OpTypePointer Function %bool", + "%bt = OpConstantTrue %bool", + "%bf = OpConstantFalse %bool", + "%int = OpTypeInt 32 1", + "%_PF_int = OpTypePointer Function %int", + "%si = OpConstant %int 1", + "%uint = OpTypeInt 32 0", + "%_PF_uint = OpTypePointer Function %uint", + "%ui = OpConstant %uint 2", + "%float = OpTypeFloat 32", + "%_PF_float = OpTypePointer Function %float", + "%f = OpConstant %float 3.1415", + "%double = OpTypeFloat 64", + "%_PF_double = OpTypePointer Function %double", + "%d = OpConstant %double 3.14159265358979", + // clang-format on + }) + .AppendInMain({ + // clang-format off + "%btv = OpVariable %_PF_bool Function", + "%bfv = OpVariable %_PF_bool Function", + "%iv = OpVariable %_PF_int Function", + "%uv = OpVariable %_PF_uint Function", + "%fv = OpVariable %_PF_float Function", + "%dv = OpVariable %_PF_double Function", + "OpStore %btv %bt", + "OpStore %bfv %bf", + "OpStore %iv %si", + "OpStore %uv %ui", + "OpStore %fv %f", + "OpStore %dv %d", + // clang-format on + }); + + std::vector expected = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %void \"void\"", + "OpName %main_func_type \"main_func_type\"", + "OpName %main \"main\"", + "OpName %main_func_entry_block \"main_func_entry_block\"", + "OpName %bool \"bool\"", + "OpName %_PF_bool \"_PF_bool\"", + "OpName %bt \"bt\"", + "OpName %bf \"bf\"", + "OpName %int \"int\"", + "OpName %_PF_int \"_PF_int\"", + "OpName %si \"si\"", + "OpName %uint \"uint\"", + "OpName %_PF_uint \"_PF_uint\"", + "OpName %ui \"ui\"", + "OpName %float \"float\"", + "OpName %_PF_float \"_PF_float\"", + "OpName %f \"f\"", + "OpName %double \"double\"", + "OpName %_PF_double \"_PF_double\"", + "OpName %d \"d\"", + "OpName %btv \"btv\"", + 
"OpName %bfv \"bfv\"", + "OpName %iv \"iv\"", + "OpName %uv \"uv\"", + "OpName %fv \"fv\"", + "OpName %dv \"dv\"", + "%void = OpTypeVoid", +"%main_func_type = OpTypeFunction %void", + "%bool = OpTypeBool", + "%_PF_bool = OpTypePointer Function %bool", + "%bt = OpConstantTrue %bool", + "%bf = OpConstantFalse %bool", + "%int = OpTypeInt 32 1", + "%_PF_int = OpTypePointer Function %int", + "%si = OpConstant %int 1", + "%uint = OpTypeInt 32 0", + "%_PF_uint = OpTypePointer Function %uint", + "%ui = OpConstant %uint 2", + "%float = OpTypeFloat 32", + "%_PF_float = OpTypePointer Function %float", + "%f = OpConstant %float 3.1415", + "%double = OpTypeFloat 64", + "%_PF_double = OpTypePointer Function %double", + "%d = OpConstant %double 3.14159265358979", + "%main = OpFunction %void None %main_func_type", +"%main_func_entry_block = OpLabel", + "%btv = OpVariable %_PF_bool Function", + "%bfv = OpVariable %_PF_bool Function", + "%iv = OpVariable %_PF_int Function", + "%uv = OpVariable %_PF_uint Function", + "%fv = OpVariable %_PF_float Function", + "%dv = OpVariable %_PF_double Function", + "OpStore %btv %bt", + "OpStore %bfv %bf", + "OpStore %iv %si", + "OpStore %uv %ui", + "OpStore %fv %f", + "OpStore %dv %d", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck(builder.GetCode(), JoinAllInsts(expected), + /* skip_nop = */ false); +} + +TEST_F(AssemblyBuilderTest, SpecConstants) { + AssemblyBuilder builder; + builder.AppendTypesConstantsGlobals({ + "%bool = OpTypeBool", + "%uint = OpTypeInt 32 0", + "%int = OpTypeInt 32 1", + "%float = OpTypeFloat 32", + "%double = OpTypeFloat 64", + "%v2int = OpTypeVector %int 2", + + "%spec_true = OpSpecConstantTrue %bool", + "%spec_false = OpSpecConstantFalse %bool", + "%spec_uint = OpSpecConstant %uint 1", + "%spec_int = OpSpecConstant %int 1", + "%spec_float = OpSpecConstant %float 1.25", + "%spec_double = OpSpecConstant %double 1.2345678", + + // Spec constants defined below should not have SpecID. 
+ "%spec_add_op = OpSpecConstantOp %int IAdd %spec_int %spec_int", + "%spec_vec = OpSpecConstantComposite %v2int %spec_int %spec_int", + "%spec_vec_x = OpSpecConstantOp %int CompositeExtract %spec_vec 0", + }); + std::vector expected = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %void \"void\"", + "OpName %main_func_type \"main_func_type\"", + "OpName %main \"main\"", + "OpName %main_func_entry_block \"main_func_entry_block\"", + "OpName %bool \"bool\"", + "OpName %uint \"uint\"", + "OpName %int \"int\"", + "OpName %float \"float\"", + "OpName %double \"double\"", + "OpName %v2int \"v2int\"", + "OpName %spec_true \"spec_true\"", + "OpName %spec_false \"spec_false\"", + "OpName %spec_uint \"spec_uint\"", + "OpName %spec_int \"spec_int\"", + "OpName %spec_float \"spec_float\"", + "OpName %spec_double \"spec_double\"", + "OpName %spec_add_op \"spec_add_op\"", + "OpName %spec_vec \"spec_vec\"", + "OpName %spec_vec_x \"spec_vec_x\"", + "OpDecorate %spec_true SpecId 200", + "OpDecorate %spec_false SpecId 201", + "OpDecorate %spec_uint SpecId 202", + "OpDecorate %spec_int SpecId 203", + "OpDecorate %spec_float SpecId 204", + "OpDecorate %spec_double SpecId 205", + "%void = OpTypeVoid", + "%main_func_type = OpTypeFunction %void", + "%bool = OpTypeBool", + "%uint = OpTypeInt 32 0", + "%int = OpTypeInt 32 1", + "%float = OpTypeFloat 32", + "%double = OpTypeFloat 64", + "%v2int = OpTypeVector %int 2", + "%spec_true = OpSpecConstantTrue %bool", + "%spec_false = OpSpecConstantFalse %bool", + "%spec_uint = OpSpecConstant %uint 1", + "%spec_int = OpSpecConstant %int 1", + "%spec_float = OpSpecConstant %float 1.25", + "%spec_double = OpSpecConstant %double 1.2345678", + "%spec_add_op = OpSpecConstantOp %int IAdd %spec_int %spec_int", + "%spec_vec = OpSpecConstantComposite %v2int %spec_int %spec_int", + "%spec_vec_x = 
OpSpecConstantOp %int CompositeExtract %spec_vec 0", + "%main = OpFunction %void None %main_func_type", +"%main_func_entry_block = OpLabel", + "OpReturn", + "OpFunctionEnd", + + // clang-format on + }; + + SinglePassRunAndCheck(builder.GetCode(), JoinAllInsts(expected), + /* skip_nop = */ false); +} + +TEST_F(AssemblyBuilderTest, AppendNames) { + AssemblyBuilder builder; + builder.AppendNames({ + "OpName %void \"another_name_for_void\"", + "I am an invalid OpName instruction and should not be added", + "OpName %main \"another name for main\"", + }); + std::vector expected = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %void \"void\"", + "OpName %main_func_type \"main_func_type\"", + "OpName %main \"main\"", + "OpName %main_func_entry_block \"main_func_entry_block\"", + "OpName %void \"another_name_for_void\"", + "OpName %main \"another name for main\"", + "%void = OpTypeVoid", + "%main_func_type = OpTypeFunction %void", + "%main = OpFunction %void None %main_func_type", +"%main_func_entry_block = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + SinglePassRunAndCheck(builder.GetCode(), JoinAllInsts(expected), + /* skip_nop = */ false); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/block_merge_test.cpp b/test/opt/block_merge_test.cpp new file mode 100644 index 000000000..654e88019 --- /dev/null +++ b/test/opt/block_merge_test.cpp @@ -0,0 +1,751 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using BlockMergeTest = PassTest<::testing::Test>; + +TEST_F(BlockMergeTest, Simple) { + // Note: SPIR-V hand edited to insert block boundary + // between two statements in main. + // + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %15 +%15 = OpLabel +%16 = OpLoad %v4float %v +OpStore %gl_FragColor %16 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = 
OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +%16 = OpLoad %v4float %v +OpStore %gl_FragColor %16 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, true, + true); +} + +TEST_F(BlockMergeTest, EmptyBlock) { + // Note: SPIR-V hand edited to insert empty block + // after two statements in main. + // + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %15 +%15 = OpLabel +%16 = OpLoad %v4float %v +OpStore %gl_FragColor %16 +OpBranch %17 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +%16 = OpLoad %v4float %v +OpStore %gl_FragColor %16 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, true, + true); +} + +TEST_F(BlockMergeTest, NestedInControlFlow) { + // 
Note: SPIR-V hand edited to insert block boundary + // between OpFMul and OpStore in then-part. + // + // #version 140 + // in vec4 BaseColor; + // + // layout(std140) uniform U_t + // { + // bool g_B ; + // } ; + // + // void main() + // { + // vec4 v = BaseColor; + // if (g_B) + // vec4 v = v * 0.25; + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_B" +OpName %_ "" +OpName %v_0 "v" +OpName %gl_FragColor "gl_FragColor" +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%U_t = OpTypeStruct %uint +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_0_25 = OpConstant %float 0.25 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%24 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%v_0 = OpVariable %_ptr_Function_v4float Function +%25 = OpLoad %v4float %BaseColor +OpStore %v %25 +%26 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%27 = OpLoad %uint %26 +%28 = OpINotEqual %bool %27 %uint_0 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +%31 = OpLoad 
%v4float %v +%32 = OpVectorTimesScalar %v4float %31 %float_0_25 +OpBranch %33 +%33 = OpLabel +OpStore %v_0 %32 +OpBranch %29 +%29 = OpLabel +%34 = OpLoad %v4float %v +OpStore %gl_FragColor %34 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%24 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%v_0 = OpVariable %_ptr_Function_v4float Function +%25 = OpLoad %v4float %BaseColor +OpStore %v %25 +%26 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%27 = OpLoad %uint %26 +%28 = OpINotEqual %bool %27 %uint_0 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +%31 = OpLoad %v4float %v +%32 = OpVectorTimesScalar %v4float %31 %float_0_25 +OpStore %v_0 %32 +OpBranch %29 +%29 = OpLabel +%34 = OpLoad %v4float %v +OpStore %gl_FragColor %34 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, true, + true); +} + +TEST_F(BlockMergeTest, PhiInSuccessorOfMergedBlock) { + const std::string text = R"( +; CHECK: OpSelectionMerge [[merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[then:%\w+]] [[else:%\w+]] +; CHECK: [[then]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[else]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpPhi {{%\w+}} %true [[then]] %false [[else]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpSelectionMerge %merge None +OpBranchConditional %true %then %else +%then = OpLabel +OpBranch %then_next +%then_next = OpLabel +OpBranch %merge +%else = OpLabel +OpBranch %merge +%merge = OpLabel +%phi = OpPhi %bool %true %then_next %false %else +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + 
+TEST_F(BlockMergeTest, UpdateMergeInstruction) { + const std::string text = R"( +; CHECK: OpSelectionMerge [[merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[then:%\w+]] [[else:%\w+]] +; CHECK: [[then]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[else]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpSelectionMerge %real_merge None +OpBranchConditional %true %then %else +%then = OpLabel +OpBranch %merge +%else = OpLabel +OpBranch %merge +%merge = OpLabel +OpBranch %real_merge +%real_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, TwoMergeBlocksCannotBeMerged) { + const std::string text = R"( +; CHECK: OpSelectionMerge [[outer_merge:%\w+]] None +; CHECK: OpSelectionMerge [[inner_merge:%\w+]] None +; CHECK: [[inner_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[outer_merge]] +; CHECK: [[outer_merge]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpSelectionMerge %outer_merge None +OpBranchConditional %true %then %else +%then = OpLabel +OpBranch %inner_header +%else = OpLabel +OpBranch %inner_header +%inner_header = OpLabel +OpSelectionMerge %inner_merge None +OpBranchConditional %true %inner_then %inner_else +%inner_then = OpLabel +OpBranch %inner_merge +%inner_else = OpLabel +OpBranch %inner_merge +%inner_merge 
= OpLabel +OpBranch %outer_merge +%outer_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, MergeContinue) { + const std::string text = R"( +; CHECK: OpBranch [[header:%\w+]] +; CHECK: [[header]] = OpLabel +; CHECK-NEXT: OpLogicalAnd +; CHECK-NEXT: OpLoopMerge {{%\w+}} [[header]] None +; CHECK-NEXT: OpBranch [[header]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpBranch %header +%header = OpLabel +OpLoopMerge %merge %continue None +OpBranch %continue +%continue = OpLabel +%op = OpLogicalAnd %bool %true %false +OpBranch %header +%merge = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, TwoHeadersCannotBeMerged) { + const std::string text = R"( +; CHECK: OpBranch [[loop_header:%\w+]] +; CHECK: [[loop_header]] = OpLabel +; CHECK-NEXT: OpLoopMerge +; CHECK-NEXT: OpBranch [[if_header:%\w+]] +; CHECK: [[if_header]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpBranch %header +%header = OpLabel +OpLoopMerge %merge %continue None +OpBranch %inner_header +%inner_header = OpLabel +OpSelectionMerge %continue None +OpBranchConditional %true %then %continue +%then = OpLabel +OpBranch %continue +%continue = OpLabel +OpBranchConditional %false %merge %header +%merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, 
RemoveStructuredDeclaration) { + // Note: SPIR-V hand edited to remove dead branch and add block + // before continue block + // + // #version 140 + // in vec4 BaseColor; + // + // void main() + // { + // while (true) { + // break; + // } + // gl_FragColor = BaseColor; + // } + + const std::string assembly = + R"( +; CHECK: OpLabel +; CHECK: [[header:%\w+]] = OpLabel +; CHECK-NOT: OpLoopMerge +; CHECK: OpReturn +; CHECK: [[continue:%\w+]] = OpLabel +; CHECK-NEXT: OpBranch [[header]] +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %gl_FragColor "gl_FragColor" +OpName %BaseColor "BaseColor" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%main = OpFunction %void None %6 +%13 = OpLabel +OpBranch %14 +%14 = OpLabel +OpLoopMerge %15 %16 None +OpBranch %17 +%17 = OpLabel +OpBranch %15 +%18 = OpLabel +OpBranch %16 +%16 = OpLabel +OpBranch %14 +%15 = OpLabel +%19 = OpLoad %v4float %BaseColor +OpStore %gl_FragColor %19 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(assembly, true); +} + +TEST_F(BlockMergeTest, DontMergeKill) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None +; CHECK-NEXT: OpBranch [[ret:%\w+]] +; CHECK: [[ret:%\w+]] = OpLabel +; CHECK-NEXT: OpKill +; CHECK-DAG: [[cont]] = OpLabel +; CHECK-DAG: [[merge]] = OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%functy = OpTypeFunction %void
+%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %3 %4 None +OpBranch %5 +%5 = OpLabel +OpKill +%4 = OpLabel +OpBranch %2 +%3 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, DontMergeUnreachable) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None +; CHECK-NEXT: OpBranch [[ret:%\w+]] +; CHECK: [[ret:%\w+]] = OpLabel +; CHECK-NEXT: OpUnreachable +; CHECK-DAG: [[cont]] = OpLabel +; CHECK-DAG: [[merge]] = OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %3 %4 None +OpBranch %5 +%5 = OpLabel +OpUnreachable +%4 = OpLabel +OpBranch %2 +%3 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, DontMergeReturn) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None +; CHECK-NEXT: OpBranch [[ret:%\w+]] +; CHECK: [[ret:%\w+]] = OpLabel +; CHECK-NEXT: OpReturn +; CHECK-DAG: [[cont]] = OpLabel +; CHECK-DAG: [[merge]] = OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %3 %4 None +OpBranch %5 +%5 = OpLabel +OpReturn +%4 = OpLabel +OpBranch %2 +%3 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, DontMergeSwitch) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None +; CHECK-NEXT: OpBranch [[ret:%\w+]] +; CHECK: [[ret:%\w+]] = OpLabel +; CHECK-NEXT: OpSwitch 
+; CHECK-DAG: [[cont]] = OpLabel +; CHECK-DAG: [[merge]] = OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %3 %4 None +OpBranch %5 +%5 = OpLabel +OpSwitch %int_0 %6 +%6 = OpLabel +OpReturn +%4 = OpLabel +OpBranch %2 +%3 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, DontMergeReturnValue) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None +; CHECK-NEXT: OpBranch [[ret:%\w+]] +; CHECK: [[ret:%\w+]] = OpLabel +; CHECK-NEXT: OpReturn +; CHECK-DAG: [[cont]] = OpLabel +; CHECK-DAG: [[merge]] = OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%functy = OpTypeFunction %void +%otherfuncty = OpTypeFunction %bool +%true = OpConstantTrue %bool +%func = OpFunction %void None %functy +%1 = OpLabel +%2 = OpFunctionCall %bool %3 +OpReturn +OpFunctionEnd +%3 = OpFunction %bool None %otherfuncty +%4 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %6 %7 None +OpBranch %8 +%8 = OpLabel +OpReturnValue %true +%7 = OpLabel +OpBranch %5 +%6 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(BlockMergeTest, MergeHeaders) { + // Merge two headers when the second is the merge block of the first. 
+ const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpBranch [[header:%\w+]] +; CHECK-NEXT: [[header]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[merge:%\w+]] +; CHECK: [[merge]] = OpLabel +; CHECK: OpReturn +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%functy = OpTypeFunction %void +%otherfuncty = OpTypeFunction %bool +%true = OpConstantTrue %bool +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %5 +%5 = OpLabel +OpLoopMerge %8 %7 None +OpBranch %8 +%7 = OpLabel +OpBranch %5 +%8 = OpLabel +OpSelectionMerge %m None +OpBranchConditional %true %a %m +%a = OpLabel +OpBranch %m +%m = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// More complex control flow +// Others? + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/ccp_test.cpp b/test/opt/ccp_test.cpp new file mode 100644 index 000000000..20d883b0a --- /dev/null +++ b/test/opt/ccp_test.cpp @@ -0,0 +1,901 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/ccp_pass.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using CCPTest = PassTest<::testing::Test>; + +TEST_F(CCPTest, PropagateThroughPhis) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %x %outparm + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %x "x" + OpName %outparm "outparm" + OpDecorate %x Flat + OpDecorate %x Location 0 + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %bool = OpTypeBool +%_ptr_Function_int = OpTypePointer Function %int + %int_4 = OpConstant %int 4 + %int_3 = OpConstant %int 3 + %int_1 = OpConstant %int 1 +%_ptr_Input_int = OpTypePointer Input %int + %x = OpVariable %_ptr_Input_int Input +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %3 + %4 = OpLabel + %5 = OpLoad %int %x + %9 = OpIAdd %int %int_1 %int_3 + %6 = OpSGreaterThan %bool %5 %int_3 + OpSelectionMerge %25 None + OpBranchConditional %6 %22 %23 + %22 = OpLabel + +; CHECK: OpCopyObject %int %int_4 + %7 = OpCopyObject %int %9 + + OpBranch %25 + %23 = OpLabel + %8 = OpCopyObject %int %int_4 + OpBranch %25 + %25 = OpLabel + +; %int_4 should have propagated to both OpPhi operands. +; CHECK: OpPhi %int %int_4 {{%\d+}} %int_4 {{%\d+}} + %35 = OpPhi %int %7 %22 %8 %23 + +; This function always returns 4. DCE should get rid of everything else. 
+; CHECK: OpStore %outparm %int_4 + OpStore %outparm %35 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, SimplifyConditionals) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %outparm + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %outparm "outparm" + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %bool = OpTypeBool +%_ptr_Function_int = OpTypePointer Function %int + %int_4 = OpConstant %int 4 + %int_3 = OpConstant %int 3 + %int_1 = OpConstant %int 1 +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %3 + %4 = OpLabel + %9 = OpIAdd %int %int_4 %int_3 + %6 = OpSGreaterThan %bool %9 %int_3 + OpSelectionMerge %25 None +; CHECK: OpBranchConditional %true [[bb_taken:%\d+]] [[bb_not_taken:%\d+]] + OpBranchConditional %6 %22 %23 +; CHECK: [[bb_taken]] = OpLabel + %22 = OpLabel +; CHECK: OpCopyObject %int %int_7 + %7 = OpCopyObject %int %9 + OpBranch %25 +; CHECK: [[bb_not_taken]] = OpLabel + %23 = OpLabel +; CHECK: [[id_not_evaluated:%\d+]] = OpCopyObject %int %int_4 + %8 = OpCopyObject %int %int_4 + OpBranch %25 + %25 = OpLabel + +; %int_7 should have propagated to the first OpPhi operand. But the else branch +; is not executable (conditional is always true), so no values should be +; propagated there and the value of the OpPhi should always be %int_7. +; CHECK: OpPhi %int %int_7 [[bb_taken]] [[id_not_evaluated]] [[bb_not_taken]] + %35 = OpPhi %int %7 %22 %8 %23 + +; Only the true path of the conditional is ever executed. The output of this +; function is always %int_7.
+; CHECK: OpStore %outparm %int_7 + OpStore %outparm %35 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, SimplifySwitches) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %outparm + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %outparm "outparm" + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_23 = OpConstant %int 23 + %int_42 = OpConstant %int 42 + %int_14 = OpConstant %int 14 + %int_15 = OpConstant %int 15 + %int_4 = OpConstant %int 4 +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %6 + %15 = OpLabel + OpSelectionMerge %17 None + OpSwitch %int_23 %17 10 %18 13 %19 23 %20 + %18 = OpLabel + OpBranch %17 + %19 = OpLabel + OpBranch %17 + %20 = OpLabel + OpBranch %17 + %17 = OpLabel + %24 = OpPhi %int %int_23 %15 %int_42 %18 %int_14 %19 %int_15 %20 + +; The switch will always jump to label %20, which carries the value %int_15. +; CHECK: OpIAdd %int %int_15 %int_4 + %22 = OpIAdd %int %24 %int_4 + +; Consequently, the return value will always be %int_19. 
+; CHECK: OpStore %outparm %int_19 + OpStore %outparm %22 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, SimplifySwitchesDefaultBranch) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %outparm + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %outparm "outparm" + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_42 = OpConstant %int 42 + %int_4 = OpConstant %int 4 + %int_1 = OpConstant %int 1 +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %6 + %13 = OpLabel + %15 = OpIAdd %int %int_42 %int_4 + OpSelectionMerge %16 None + +; CHECK: OpSwitch %int_46 {{%\d+}} 10 {{%\d+}} + OpSwitch %15 %17 10 %18 + %18 = OpLabel + OpBranch %16 + %17 = OpLabel + OpBranch %16 + %16 = OpLabel + %22 = OpPhi %int %int_42 %18 %int_1 %17 + +; The switch will always jump to the default label %17. This carries the value +; %int_1. +; CHECK: OpIAdd %int %int_1 %int_4 + %20 = OpIAdd %int %22 %int_4 + +; Resulting in a return value of %int_5. 
+; CHECK: OpStore %outparm %int_5 + OpStore %outparm %20 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, SimplifyIntVector) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %v "v" + OpName %OutColor "OutColor" + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %v4int = OpTypeVector %int 4 +%_ptr_Function_v4int = OpTypePointer Function %v4int + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %int_3 = OpConstant %int 3 + %int_4 = OpConstant %int 4 + %14 = OpConstantComposite %v4int %int_1 %int_2 %int_3 %int_4 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Output_v4int = OpTypePointer Output %v4int + %OutColor = OpVariable %_ptr_Output_v4int Output + %main = OpFunction %void None %3 + %5 = OpLabel + %v = OpVariable %_ptr_Function_v4int Function + OpStore %v %14 + %18 = OpAccessChain %_ptr_Function_int %v %uint_0 + %19 = OpLoad %int %18 + +; The constant folder does not see through access chains. To get this, the +; vector would have to be scalarized. 
+; CHECK: [[result_id:%\d+]] = OpIAdd %int {{%\d+}} %int_1 + %20 = OpIAdd %int %19 %int_1 + %21 = OpAccessChain %_ptr_Function_int %v %uint_0 + +; CHECK: OpStore {{%\d+}} [[result_id]] + OpStore %21 %20 + %24 = OpLoad %v4int %v + OpStore %OutColor %24 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, BadSimplifyFloatVector) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %v "v" + OpName %OutColor "OutColor" + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %float_1 = OpConstant %float 1 + %float_2 = OpConstant %float 2 + %float_3 = OpConstant %float 3 + %float_4 = OpConstant %float 4 + %14 = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Output_v4float = OpTypePointer Output %v4float + %OutColor = OpVariable %_ptr_Output_v4float Output + %main = OpFunction %void None %3 + %5 = OpLabel + %v = OpVariable %_ptr_Function_v4float Function + OpStore %v %14 + %18 = OpAccessChain %_ptr_Function_float %v %uint_0 + %19 = OpLoad %float %18 + +; NOTE: This test should start failing once floating point folding is +; implemented (https://github.com/KhronosGroup/SPIRV-Tools/issues/943). +; This should be checking that we are adding %float_1 + %float_1. +; CHECK: [[result_id:%\d+]] = OpFAdd %float {{%\d+}} %float_1 + %20 = OpFAdd %float %19 %float_1 + %21 = OpAccessChain %_ptr_Function_float %v %uint_0 + +; This should be checkint that we are storing %float_2 instead of result_it. 
+; CHECK: OpStore {{%\d+}} [[result_id]] + OpStore %21 %20 + %24 = OpLoad %v4float %v + OpStore %OutColor %24 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, NoLoadStorePropagation) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %outparm + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %x "x" + OpName %outparm "outparm" + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_23 = OpConstant %int 23 +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %3 + %5 = OpLabel + %x = OpVariable %_ptr_Function_int Function + OpStore %x %int_23 + +; int_23 should not propagate into this load. +; CHECK: [[load_id:%\d+]] = OpLoad %int %x + %12 = OpLoad %int %x + +; Nor into this copy operation. +; CHECK: [[copy_id:%\d+]] = OpCopyObject %int [[load_id]] + %13 = OpCopyObject %int %12 + +; Likewise here. 
+; CHECK: OpStore %outparm [[copy_id]] + OpStore %outparm %13 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, HandleAbortInstructions) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource HLSL 500 + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %bool = OpTypeBool +; CHECK: %true = OpConstantTrue %bool + %int_3 = OpConstant %int 3 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %4 = OpLabel + %9 = OpIAdd %int %int_3 %int_1 + %6 = OpSGreaterThan %bool %9 %int_3 + OpSelectionMerge %23 None +; CHECK: OpBranchConditional %true {{%\d+}} {{%\d+}} + OpBranchConditional %6 %22 %23 + %22 = OpLabel + OpKill + %23 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, SSAWebCycles) { + // Test reduced from https://github.com/KhronosGroup/SPIRV-Tools/issues/1159 + // When there is a cycle in the SSA def-use web, the propagator was getting + // into an infinite loop. SSA edges for Phi instructions should not be + // added to the edges to simulate. 
+ const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_4 = OpConstant %int 4 + %bool = OpTypeBool + %int_1 = OpConstant %int 1 +%_ptr_Output_int = OpTypePointer Output %int + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %11 + %11 = OpLabel + %29 = OpPhi %int %int_0 %5 %22 %14 + %30 = OpPhi %int %int_0 %5 %25 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %bool %30 %int_4 +; CHECK: OpBranchConditional %true {{%\d+}} {{%\d+}} + OpBranchConditional %19 %12 %13 + %12 = OpLabel +; CHECK: OpIAdd %int %int_0 %int_0 + %22 = OpIAdd %int %29 %30 + OpBranch %14 + %14 = OpLabel +; CHECK: OpPhi %int %int_0 {{%\d+}} + %25 = OpPhi %int %30 %12 + OpBranch %11 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, LoopInductionVariables) { + // Test reduced from https://github.com/KhronosGroup/SPIRV-Tools/issues/1143 + // We are failing to properly consider the induction variable for this loop + // as Varying. 
+ const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpName %main "main" + %void = OpTypeVoid + %5 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %5 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + +; This Phi should not have all constant arguments: +; CHECK: [[phi_id:%\d+]] = OpPhi %int %int_0 {{%\d+}} {{%\d+}} {{%\d+}} + %22 = OpPhi %int %int_0 %12 %21 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + +; The Phi should never be considered to have the value %int_0. +; CHECK: [[branch_selector:%\d+]] = OpSLessThan %bool [[phi_id]] %int_10 + %18 = OpSLessThan %bool %22 %int_10 + +; This conditional was wrongly converted into an always-true jump due to the +; bad meet evaluation of %22. +; CHECK: OpBranchConditional [[branch_selector]] {{%\d+}} {{%\d+}} + OpBranchConditional %18 %19 %14 + %19 = OpLabel + OpBranch %15 + %15 = OpLabel +; CHECK: OpIAdd %int [[phi_id]] %int_1 + %21 = OpIAdd %int %22 %int_1 + OpBranch %13 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +TEST_F(CCPTest, HandleCompositeWithUndef) { + // Check to make sure that CCP does not crash when given a "constant" struct + // with an undef. If at a later time CCP is enhanced to optimize this case, + // it is not wrong. 
+ const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource HLSL 500 + OpName %main "main" + %void = OpTypeVoid + %4 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %bool = OpTypeBool + %_struct_7 = OpTypeStruct %int %int + %int_1 = OpConstant %int 1 + %9 = OpUndef %int + %10 = OpConstantComposite %_struct_7 %int_1 %9 + %main = OpFunction %void None %4 + %11 = OpLabel + %12 = OpCompositeExtract %int %10 0 + %13 = OpCopyObject %int %12 + OpReturn + OpFunctionEnd + )"; + + auto res = SinglePassRunToBinary(spv_asm, true); + EXPECT_EQ(std::get<1>(res), Pass::Status::SuccessWithoutChange); +} + +TEST_F(CCPTest, SkipSpecConstantInstrucitons) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource HLSL 500 + OpName %main "main" + %void = OpTypeVoid + %4 = OpTypeFunction %void + %bool = OpTypeBool + %10 = OpSpecConstantFalse %bool + %main = OpFunction %void None %4 + %11 = OpLabel + %12 = OpBranchConditional %10 %l1 %l2 + %l1 = OpLabel + OpReturn + %l2 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto res = SinglePassRunToBinary(spv_asm, true); + EXPECT_EQ(std::get<1>(res), Pass::Status::SuccessWithoutChange); +} + +TEST_F(CCPTest, UpdateSubsequentPhisToVarying) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" %in +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%int = OpTypeInt 32 1 +%false = OpConstantFalse %bool +%int0 = OpConstant %int 0 +%int1 = OpConstant %int 1 +%int6 = OpConstant %int 6 +%int_ptr_Input = OpTypePointer Input %int +%in = OpVariable %int_ptr_Input Input +%undef = OpUndef %int +%functy = OpTypeFunction %void +%func = OpFunction %void None 
%functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +%outer_phi = OpPhi %int %int0 %1 %outer_add %15 +%cond1 = OpSLessThanEqual %bool %outer_phi %int6 +OpLoopMerge %3 %15 None +OpBranchConditional %cond1 %4 %3 +%4 = OpLabel +%ld = OpLoad %int %in +%cond2 = OpSGreaterThanEqual %bool %int1 %ld +OpSelectionMerge %10 None +OpBranchConditional %cond2 %8 %9 +%8 = OpLabel +OpBranch %10 +%9 = OpLabel +OpBranch %10 +%10 = OpLabel +%extra_phi = OpPhi %int %outer_phi %8 %outer_phi %9 +OpBranch %11 +%11 = OpLabel +%inner_phi = OpPhi %int %int0 %10 %inner_add %13 +%cond3 = OpSLessThanEqual %bool %inner_phi %int6 +OpLoopMerge %14 %13 None +OpBranchConditional %cond3 %12 %14 +%12 = OpLabel +OpBranch %13 +%13 = OpLabel +%inner_add = OpIAdd %int %inner_phi %int1 +OpBranch %11 +%14 = OpLabel +OpBranch %15 +%15 = OpLabel +%outer_add = OpIAdd %int %extra_phi %int1 +OpBranch %2 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + auto res = SinglePassRunToBinary(text, true); + EXPECT_EQ(std::get<1>(res), Pass::Status::SuccessWithoutChange); +} + +TEST_F(CCPTest, UndefInPhi) { + const std::string text = R"( +; CHECK: [[uint1:%\w+]] = OpConstant {{%\w+}} 1 +; CHECK: [[phi:%\w+]] = OpPhi +; CHECK: OpIAdd {{%\w+}} [[phi]] [[uint1]] + OpCapability Kernel + OpCapability Linkage + OpMemoryModel Logical OpenCL + OpDecorate %1 LinkageAttributes "func" Export + %void = OpTypeVoid + %bool = OpTypeBool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %7 = OpUndef %uint + %8 = OpTypeFunction %void %bool + %1 = OpFunction %void None %8 + %9 = OpFunctionParameter %bool + %10 = OpLabel + OpBranchConditional %9 %11 %12 + %11 = OpLabel + OpBranch %13 + %12 = OpLabel + OpBranch %14 + %14 = OpLabel + OpBranchConditional %9 %13 %15 + %15 = OpLabel + OpBranch %13 + %13 = OpLabel + %16 = OpPhi %uint %uint_0 %11 %7 %14 %uint_1 %15 + %17 = OpIAdd %uint %16 %uint_1 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +// Just test to make sure the constant 
fold rules are being used. Will rely on +// the folding test for specific testing of specific rules. +TEST_F(CCPTest, UseConstantFoldingRules) { + const std::string text = R"( +; CHECK: [[float1:%\w+]] = OpConstant {{%\w+}} 1 +; CHECK: OpReturnValue [[float1]] + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %1 LinkageAttributes "func" Export + %void = OpTypeVoid + %bool = OpTypeBool + %float = OpTypeFloat 32 + %float_0 = OpConstant %float 0 + %float_1 = OpConstant %float 1 + %8 = OpTypeFunction %float + %1 = OpFunction %float None %8 + %10 = OpLabel + %17 = OpFAdd %float %float_0 %float_1 + OpReturnValue %17 + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +// Test for #1300. Previously value for %5 would not settle during simulation. +TEST_F(CCPTest, SettlePhiLatticeValue) { + const std::string text = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranchConditional %true %2 %3 +%3 = OpLabel +OpBranch %2 +%2 = OpLabel +%5 = OpPhi %bool %true %1 %false %3 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunToBinary(text, true); +} + +TEST_F(CCPTest, NullBranchCondition) { + const std::string text = R"( +; CHECK: [[int1:%\w+]] = OpConstant {{%\w+}} 1 +; CHECK: [[int2:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: OpIAdd {{%\w+}} [[int1]] [[int2]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%int = OpTypeInt 32 1 +%null = OpConstantNull %bool +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy 
+%1 = OpLabel +OpSelectionMerge %2 None +OpBranchConditional %null %2 %3 +%3 = OpLabel +OpBranch %2 +%2 = OpLabel +%phi = OpPhi %int %int_1 %1 %int_2 %3 +%add = OpIAdd %int %int_1 %phi +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CCPTest, UndefBranchCondition) { + const std::string text = R"( +; CHECK: [[int1:%\w+]] = OpConstant {{%\w+}} 1 +; CHECK: [[phi:%\w+]] = OpPhi +; CHECK: OpIAdd {{%\w+}} [[int1]] [[phi]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%int = OpTypeInt 32 1 +%undef = OpUndef %bool +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpSelectionMerge %2 None +OpBranchConditional %undef %2 %3 +%3 = OpLabel +OpBranch %2 +%2 = OpLabel +%phi = OpPhi %int %int_1 %1 %int_2 %3 +%add = OpIAdd %int %int_1 %phi +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CCPTest, NullSwitchCondition) { + const std::string text = R"( +; CHECK: [[int1:%\w+]] = OpConstant {{%\w+}} 1 +; CHECK: [[int2:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: OpIAdd {{%\w+}} [[int1]] [[int2]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 1 +%null = OpConstantNull %int +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpSelectionMerge %2 None +OpSwitch %null %2 0 %3 +%3 = OpLabel +OpBranch %2 +%2 = OpLabel +%phi = OpPhi %int %int_1 %1 %int_2 %3 +%add = OpIAdd %int %int_1 %phi +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CCPTest, UndefSwitchCondition) { + const std::string text = R"( +; CHECK: [[int1:%\w+]] = OpConstant {{%\w+}} 1 +; CHECK: [[phi:%\w+]] = OpPhi +; 
CHECK: OpIAdd {{%\w+}} [[int1]] [[phi]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 1 +%undef = OpUndef %int +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpSelectionMerge %2 None +OpSwitch %undef %2 0 %3 +%3 = OpLabel +OpBranch %2 +%2 = OpLabel +%phi = OpPhi %int %int_1 %1 %int_2 %3 +%add = OpIAdd %int %int_1 %phi +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +// Test for #1361. +TEST_F(CCPTest, CompositeConstructOfGlobalValue) { + const std::string text = R"( +; CHECK: [[phi:%\w+]] = OpPhi +; CHECK-NEXT: OpCompositeExtract {{%\w+}} [[phi]] 0 +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" %in +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 1 +%bool = OpTypeBool +%functy = OpTypeFunction %void +%ptr_int_Input = OpTypePointer Input %int +%in = OpVariable %ptr_int_Input Input +%struct = OpTypeStruct %ptr_int_Input %ptr_int_Input +%struct_null = OpConstantNull %struct +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +%phi = OpPhi %struct %struct_null %1 %5 %4 +%extract = OpCompositeExtract %ptr_int_Input %phi 0 +OpLoopMerge %3 %4 None +OpBranch %4 +%4 = OpLabel +%5 = OpCompositeConstruct %struct %in %in +OpBranch %2 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/cfg_cleanup_test.cpp b/test/opt/cfg_cleanup_test.cpp new file mode 100644 index 000000000..3498f00bb --- /dev/null +++ b/test/opt/cfg_cleanup_test.cpp @@ -0,0 +1,456 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using CFGCleanupTest = PassTest<::testing::Test>; + +TEST_F(CFGCleanupTest, RemoveUnreachableBlocks) { + const std::string declarations = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %inf %outf4 +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %inf "inf" +OpName %outf4 "outf4" +OpDecorate %inf Location 0 +OpDecorate %outf4 Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float +%inf = OpVariable %_ptr_Input_float Input +%float_2 = OpConstant %float 2 +%bool = OpTypeBool +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outf4 = OpVariable %_ptr_Output_v4float Output +%float_n0_5 = OpConstant %float -0.5 +)"; + + const std::string body_before = R"(%main = OpFunction %void None %6 +%14 = OpLabel +OpBranch %18 +%19 = OpLabel +%20 = OpLoad %float %inf +%21 = OpCompositeConstruct %v4float %20 %20 %20 %20 +OpStore %outf4 %21 +OpBranch %17 +%18 = OpLabel +%22 = OpLoad %float %inf +%23 = OpFAdd %float %22 %float_n0_5 +%24 = OpCompositeConstruct %v4float %23 %23 %23 %23 +OpStore %outf4 %24 +OpBranch %17 +%17 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string body_after = R"(%main = OpFunction %void None %6 +%14 = OpLabel +OpBranch %15 +%15 = OpLabel +%20 = OpLoad %float %inf 
+%21 = OpFAdd %float %20 %float_n0_5 +%22 = OpCompositeConstruct %v4float %21 %21 %21 %21 +OpStore %outf4 %22 +OpBranch %19 +%19 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(declarations + body_before, + declarations + body_after, true, true); +} + +TEST_F(CFGCleanupTest, RemoveDecorations) { + const std::string before = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpName %main "main" + OpName %x "x" + OpName %dead "dead" + OpDecorate %x RelaxedPrecision + OpDecorate %dead RelaxedPrecision + %void = OpTypeVoid + %6 = OpTypeFunction %void + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float + %float_2 = OpConstant %float 2 + %float_4 = OpConstant %float 4 + + %main = OpFunction %void None %6 + %14 = OpLabel + %x = OpVariable %_ptr_Function_float Function + OpBranch %18 + %19 = OpLabel + %dead = OpVariable %_ptr_Function_float Function + OpStore %dead %float_2 + OpBranch %17 + %18 = OpLabel + OpStore %x %float_4 + OpBranch %17 + %17 = OpLabel + OpReturn + OpFunctionEnd +)"; + + const std::string after = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpName %main "main" +OpName %x "x" +OpDecorate %x RelaxedPrecision +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_2 = OpConstant %float 2 +%float_4 = OpConstant %float 4 +%main = OpFunction %void None %6 +%11 = OpLabel +%x = OpVariable %_ptr_Function_float Function +OpBranch %12 +%12 = OpLabel +OpStore %x %float_4 +OpBranch %14 +%14 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, true, true); +} + +TEST_F(CFGCleanupTest, UpdatePhis) { + const std::string before = R"( + OpCapability Shader + %1 = OpExtInstImport 
"GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %y %outparm + OpExecutionMode %main OriginUpperLeft + OpName %main "main" + OpName %y "y" + OpName %outparm "outparm" + OpDecorate %y Flat + OpDecorate %y Location 0 + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_int = OpTypePointer Input %int + %y = OpVariable %_ptr_Input_int Input + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_42 = OpConstant %int 42 + %int_23 = OpConstant %int 23 + %int_5 = OpConstant %int 5 +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %3 + %5 = OpLabel + %11 = OpLoad %int %y + OpBranch %21 + %16 = OpLabel + %20 = OpIAdd %int %11 %int_42 + OpBranch %17 + %21 = OpLabel + %24 = OpISub %int %11 %int_23 + OpBranch %17 + %17 = OpLabel + %31 = OpPhi %int %20 %16 %24 %21 + %27 = OpIAdd %int %31 %int_5 + OpStore %outparm %27 + OpReturn + OpFunctionEnd +)"; + + const std::string after = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %y %outparm +OpExecutionMode %main OriginUpperLeft +OpName %main "main" +OpName %y "y" +OpName %outparm "outparm" +OpDecorate %y Flat +OpDecorate %y Location 0 +OpDecorate %outparm Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_int = OpTypePointer Input %int +%y = OpVariable %_ptr_Input_int Input +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_42 = OpConstant %int 42 +%int_23 = OpConstant %int 23 +%int_5 = OpConstant %int 5 +%_ptr_Output_int = OpTypePointer Output %int +%outparm = OpVariable %_ptr_Output_int Output +%main = OpFunction %void None %6 +%16 = OpLabel +%17 = OpLoad %int %y +OpBranch %18 +%18 = OpLabel +%22 = OpISub %int %17 %int_23 
+OpBranch %21 +%21 = OpLabel +%23 = OpPhi %int %22 %18 +%24 = OpIAdd %int %23 %int_5 +OpStore %outparm %24 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, true, true); +} + +TEST_F(CFGCleanupTest, RemoveNamedLabels) { + const std::string before = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 430 + OpName %main "main" + OpName %dead "dead" + %void = OpTypeVoid + %5 = OpTypeFunction %void + %main = OpFunction %void None %5 + %6 = OpLabel + OpReturn + %dead = OpLabel + OpReturn + OpFunctionEnd)"; + + const std::string after = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 430 +OpName %main "main" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%main = OpFunction %void None %5 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, true, true); +} + +TEST_F(CFGCleanupTest, RemovePhiArgsFromFarBlocks) { + const std::string before = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %y %outparm + OpExecutionMode %main OriginUpperLeft + OpName %main "main" + OpName %y "y" + OpName %outparm "outparm" + OpDecorate %y Flat + OpDecorate %y Location 0 + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_int = OpTypePointer Input %int + %y = OpVariable %_ptr_Input_int Input + %int_42 = OpConstant %int 42 +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %int_14 = OpConstant %int 14 + %int_15 = OpConstant %int 15 + %int_5 = OpConstant %int 5 + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %40 + %41 = OpLabel + %11 = OpLoad %int %y + OpBranch %40 + %40 = OpLabel + %12 = OpLoad %int %y + 
OpSelectionMerge %16 None + OpSwitch %12 %16 10 %13 13 %14 18 %15 + %13 = OpLabel + OpBranch %16 + %14 = OpLabel + OpStore %outparm %int_14 + OpBranch %16 + %15 = OpLabel + OpStore %outparm %int_15 + OpBranch %16 + %16 = OpLabel + %30 = OpPhi %int %11 %40 %int_42 %13 %11 %14 %11 %15 + %28 = OpIAdd %int %30 %int_5 + OpStore %outparm %28 + OpReturn + OpFunctionEnd)"; + + const std::string after = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %y %outparm +OpExecutionMode %main OriginUpperLeft +OpName %main "main" +OpName %y "y" +OpName %outparm "outparm" +OpDecorate %y Flat +OpDecorate %y Location 0 +OpDecorate %outparm Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_int = OpTypePointer Input %int +%y = OpVariable %_ptr_Input_int Input +%int_42 = OpConstant %int 42 +%_ptr_Output_int = OpTypePointer Output %int +%outparm = OpVariable %_ptr_Output_int Output +%int_14 = OpConstant %int 14 +%int_15 = OpConstant %int 15 +%int_5 = OpConstant %int 5 +%26 = OpUndef %int +%main = OpFunction %void None %6 +%15 = OpLabel +OpBranch %16 +%16 = OpLabel +%19 = OpLoad %int %y +OpSelectionMerge %20 None +OpSwitch %19 %20 10 %21 13 %22 18 %23 +%21 = OpLabel +OpBranch %20 +%22 = OpLabel +OpStore %outparm %int_14 +OpBranch %20 +%23 = OpLabel +OpStore %outparm %int_15 +OpBranch %20 +%20 = OpLabel +%24 = OpPhi %int %26 %16 %int_42 %21 %26 %22 %26 %23 +%25 = OpIAdd %int %24 %int_5 +OpStore %outparm %25 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, true, true); +} + +TEST_F(CFGCleanupTest, RemovePhiConstantArgs) { + const std::string before = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %y %outparm + OpExecutionMode %main OriginUpperLeft + OpName %main "main" + OpName %y "y" + OpName %outparm "outparm" + 
OpDecorate %y Flat + OpDecorate %y Location 0 + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %y = OpVariable %_ptr_Input_int Input + %int_10 = OpConstant %int 10 + %bool = OpTypeBool +%_ptr_Function_int = OpTypePointer Function %int + %int_23 = OpConstant %int 23 + %int_5 = OpConstant %int 5 +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %24 = OpUndef %int + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %14 + %40 = OpLabel + %9 = OpLoad %int %y + %12 = OpSGreaterThan %bool %9 %int_10 + OpSelectionMerge %14 None + OpBranchConditional %12 %13 %14 + %13 = OpLabel + OpBranch %14 + %14 = OpLabel + %25 = OpPhi %int %24 %5 %int_23 %13 + %20 = OpIAdd %int %25 %int_5 + OpStore %outparm %20 + OpReturn + OpFunctionEnd)"; + + const std::string after = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %y %outparm +OpExecutionMode %main OriginUpperLeft +OpName %main "main" +OpName %y "y" +OpName %outparm "outparm" +OpDecorate %y Flat +OpDecorate %y Location 0 +OpDecorate %outparm Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%y = OpVariable %_ptr_Input_int Input +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%_ptr_Function_int = OpTypePointer Function %int +%int_23 = OpConstant %int 23 +%int_5 = OpConstant %int 5 +%_ptr_Output_int = OpTypePointer Output %int +%outparm = OpVariable %_ptr_Output_int Output +%15 = OpUndef %int +%main = OpFunction %void None %6 +%16 = OpLabel +OpBranch %17 +%17 = OpLabel +%22 = OpPhi %int %15 %16 +%23 = OpIAdd %int %22 %int_5 +OpStore %outparm %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, true, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git 
a/test/opt/code_sink_test.cpp b/test/opt/code_sink_test.cpp new file mode 100644 index 000000000..9b86c660a --- /dev/null +++ b/test/opt/code_sink_test.cpp @@ -0,0 +1,533 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using CodeSinkTest = PassTest<::testing::Test>; + +TEST_F(CodeSinkTest, MoveToNextBlock) { + const std::string text = R"( +;CHECK: OpFunction +;CHECK: OpLabel +;CHECK: OpLabel +;CHECK: [[ac:%\w+]] = OpAccessChain +;CHECK: [[ld:%\w+]] = OpLoad %uint [[ac]] +;CHECK: OpCopyObject %uint [[ld]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %9 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %10 = OpTypeFunction %void + %1 = OpFunction %void None %10 + %11 = OpLabel + %12 = OpAccessChain %_ptr_Uniform_uint %9 %uint_0 + %13 = OpLoad %uint %12 + OpBranch %14 + %14 = OpLabel + %15 = OpCopyObject %uint %13 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + 
+TEST_F(CodeSinkTest, MovePastSelection) { + const std::string text = R"( +;CHECK: OpFunction +;CHECK: OpLabel +;CHECK: OpSelectionMerge [[merge_bb:%\w+]] +;CHECK: [[merge_bb]] = OpLabel +;CHECK: [[ac:%\w+]] = OpAccessChain +;CHECK: [[ld:%\w+]] = OpLoad %uint [[ac]] +;CHECK: OpCopyObject %uint [[ld]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpSelectionMerge %16 None + OpBranchConditional %true %17 %16 + %17 = OpLabel + OpBranch %16 + %16 = OpLabel + %18 = OpCopyObject %uint %15 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CodeSinkTest, MoveIntoSelection) { + const std::string text = R"( +;CHECK: OpFunction +;CHECK: OpLabel +;CHECK: OpSelectionMerge [[merge_bb:%\w+]] +;CHECK-NEXT: OpBranchConditional %true [[bb:%\w+]] [[merge_bb]] +;CHECK: [[bb]] = OpLabel +;CHECK-NEXT: [[ac:%\w+]] = OpAccessChain +;CHECK-NEXT: [[ld:%\w+]] = OpLoad %uint [[ac]] +;CHECK-NEXT: OpCopyObject %uint [[ld]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 
Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpSelectionMerge %16 None + OpBranchConditional %true %17 %16 + %17 = OpLabel + %18 = OpCopyObject %uint %15 + OpBranch %16 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CodeSinkTest, LeaveBeforeSelection) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpSelectionMerge %16 None + OpBranchConditional %true %17 %20 + %20 = OpLabel + OpBranch %16 + %17 = OpLabel + %18 = OpCopyObject %uint %15 + OpBranch %16 + %16 = OpLabel + %19 = OpCopyObject %uint %15 + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, LeaveAloneUseInSameBlock) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform 
%_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + %cond = OpIEqual %bool %15 %uint_0 + OpSelectionMerge %16 None + OpBranchConditional %cond %17 %16 + %17 = OpLabel + OpBranch %16 + %16 = OpLabel + %19 = OpCopyObject %uint %15 + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, DontMoveIntoLoop) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpBranch %17 + %17 = OpLabel + OpLoopMerge %merge %cont None + OpBranch %cont + %cont = OpLabel + %cond = OpIEqual %bool %15 %uint_0 + OpBranchConditional %cond %merge %17 + %merge = OpLabel + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, DontMoveIntoLoop2) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 
+ %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpSelectionMerge %16 None + OpBranchConditional %true %17 %16 + %17 = OpLabel + OpLoopMerge %merge %cont None + OpBranch %cont + %cont = OpLabel + %cond = OpIEqual %bool %15 %uint_0 + OpBranchConditional %cond %merge %17 + %merge = OpLabel + OpBranch %16 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, DontMoveSelectionUsedInBothSides) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpSelectionMerge %16 None + OpBranchConditional %true %17 %20 + %20 = OpLabel + %19 = OpCopyObject %uint %15 + OpBranch %16 + %17 = OpLabel + %18 = OpCopyObject %uint %15 + OpBranch %16 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); 
+ EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, DontMoveBecauseOfStore) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpStore %14 %15 + OpSelectionMerge %16 None + OpBranchConditional %true %17 %20 + %20 = OpLabel + OpBranch %16 + %17 = OpLabel + %18 = OpCopyObject %uint %15 + OpBranch %16 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, MoveReadOnlyLoadWithSync) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%mem_semantics = OpConstant %uint 0x42 ; Uniform memeory arquire +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpMemoryBarrier %uint_4 
%mem_semantics + OpSelectionMerge %16 None + OpBranchConditional %true %17 %20 + %20 = OpLabel + OpBranch %16 + %17 = OpLabel + %18 = OpCopyObject %uint %15 + OpBranch %16 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, DontMoveBecauseOfSync) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpDecorate %_arr_uint_uint_4 BufferBlock + OpMemberDecorate %_arr_uint_uint_4 0 Offset 0 + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%mem_semantics = OpConstant %uint 0x42 ; Uniform memeory arquire +%_arr_uint_uint_4 = OpTypeStruct %uint +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + OpMemoryBarrier %uint_4 %mem_semantics + OpSelectionMerge %16 None + OpBranchConditional %true %17 %20 + %20 = OpLabel + OpBranch %16 + %17 = OpLabel + %18 = OpCopyObject %uint %15 + OpBranch %16 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, DontMoveBecauseOfAtomicWithSync) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpDecorate %_arr_uint_uint_4 BufferBlock + OpMemberDecorate %_arr_uint_uint_4 0 Offset 0 + %void = OpTypeVoid + %bool = OpTypeBool 
+ %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%mem_semantics = OpConstant %uint 0x42 ; Uniform memeory arquire +%_arr_uint_uint_4 = OpTypeStruct %uint +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + %al = OpAtomicLoad %uint %14 %uint_4 %mem_semantics + OpSelectionMerge %16 None + OpBranchConditional %true %17 %20 + %20 = OpLabel + OpBranch %16 + %17 = OpLabel + %18 = OpCopyObject %uint %15 + OpBranch %16 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CodeSinkTest, MoveWithAtomicWithoutSync) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpDecorate %_arr_uint_uint_4 BufferBlock + OpMemberDecorate %_arr_uint_uint_4 0 Offset 0 + %void = OpTypeVoid + %bool = OpTypeBool + %true = OpConstantTrue %bool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeStruct %uint +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%_ptr_Uniform__arr_uint_uint_4 = OpTypePointer Uniform %_arr_uint_uint_4 + %11 = OpVariable %_ptr_Uniform__arr_uint_uint_4 Uniform + %12 = OpTypeFunction %void + %1 = OpFunction %void None %12 + %13 = OpLabel + %14 = OpAccessChain %_ptr_Uniform_uint %11 %uint_0 + %15 = OpLoad %uint %14 + %al = OpAtomicLoad %uint %14 %uint_4 %uint_0 + OpSelectionMerge %16 None + OpBranchConditional %true %17 %20 + %20 = OpLabel + OpBranch %16 + %17 = OpLabel + %18 = OpCopyObject %uint %15 + 
OpBranch %16 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ true); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/combine_access_chains_test.cpp b/test/opt/combine_access_chains_test.cpp new file mode 100644 index 000000000..aed14c9de --- /dev/null +++ b/test/opt/combine_access_chains_test.cpp @@ -0,0 +1,773 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using CombineAccessChainsTest = PassTest<::testing::Test>; + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint %var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromInBoundsAccessChainConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main 
"main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsAccessChain %ptr_Workgroup_uint %var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainCombineConstant) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int2]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint %var %uint_1 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainNonConstant) { + const 
std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld1:%\w+]] = OpLoad +; CHECK: [[ld2:%\w+]] = OpLoad +; CHECK: [[add:%\w+]] = OpIAdd [[int]] [[ld1]] [[ld2]] +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[add]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%local_var = OpVariable %ptr_Function_uint Function +%ld1 = OpLoad %uint %local_var +%gep = OpAccessChain %ptr_Workgroup_uint %var %ld1 +%ld2 = OpLoad %uint %local_var +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromAccessChainExtraIndices) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main 
OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%uint_array_4_array_4_array_4 = OpTypeArray %uint_array_4_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%ptr_Workgroup_uint_array_4_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_2 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessChainCombineElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int6]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint 
+%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessChainOnlyElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + PtrAccessChainFromPtrAccessCombineNonElementIndex) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; 
CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int3]] [[int3]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %uint_3 %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + AccessChainFromPtrAccessChainOnlyElementOperand) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int3]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = 
OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, AccessChainFromPtrAccessChainAppend) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[int3:%\w+]] = OpConstant [[int]] 3 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpPtrAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] [[int3]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 %uint_2 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_3 +OpReturn 
+OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, AccessChainFromAccessChainAppend) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int1:%\w+]] = OpConstant [[int]] 1 +; CHECK: [[int2:%\w+]] = OpConstant [[int]] 2 +; CHECK: [[ptr_int:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: OpAccessChain [[ptr_int]] [[var]] [[int1]] [[int2]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%uint_array_4_array_4 = OpTypeArray %uint_array_4 %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%ptr_Workgroup_uint_array_4_array_4 = OpTypePointer Workgroup %uint_array_4_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%ptr_gep = OpAccessChain %ptr_Workgroup_uint_array_4 %var %uint_1 +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, NonConstantStructSlide) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[ld]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" 
+OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%struct = OpTypeStruct %uint %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%wg_var = OpVariable %ptr_Workgroup_struct Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_struct %wg_var %ld +%gep = OpAccessChain %ptr_Workgroup_uint %ptr_gep %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, DontCombineNonConstantStructSlide) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: [[gep:%\w+]] = OpAccessChain +; CHECK: OpPtrAccessChain {{%\w+}} [[gep]] [[ld]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%struct = OpTypeStruct %uint %uint +%struct_array_4 = OpTypeArray %struct %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%ptr_Workgroup_struct_array_4 = OpTypePointer Workgroup %struct_array_4 +%wg_var = OpVariable %ptr_Workgroup_struct_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%gep = OpAccessChain %ptr_Workgroup_struct %wg_var %uint_0 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld %uint_0 
+OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, CombineNonConstantStructSlideElement) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK: [[add:%\w+]] = OpIAdd {{%\w+}} [[ld]] [[ld]] +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[add]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_4 = OpConstant %uint 4 +%struct = OpTypeStruct %uint %uint +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Function_uint = OpTypePointer Function %uint +%ptr_Workgroup_struct = OpTypePointer Workgroup %struct +%wg_var = OpVariable %ptr_Workgroup_struct Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +%func_var = OpVariable %ptr_Function_uint Function +%ld = OpLoad %uint %func_var +%gep = OpPtrAccessChain %ptr_Workgroup_struct %wg_var %ld +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint %gep %ld %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, PtrAccessChainFromInBoundsPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint 
Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, InBoundsPtrAccessChainFromPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpInBoundsPtrAccessChain 
%ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, + InBoundsPtrAccessChainFromInBoundsPtrAccessChain) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[int4:%\w+]] = OpConstant [[int]] 4 +; CHECK: [[array:%\w+]] = OpTypeArray [[int]] [[int4]] +; CHECK: [[ptr_array:%\w+]] = OpTypePointer Workgroup [[array]] +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Workgroup +; CHECK: [[int6:%\w+]] = OpConstant [[int]] 6 +; CHECK: OpInBoundsPtrAccessChain [[ptr_array]] [[var]] [[int6]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_array_4 = OpTypeArray %uint %uint_4 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%ptr_Workgroup_uint_array_4 = OpTypePointer Workgroup %uint_array_4 +%var = OpVariable %ptr_Workgroup_uint_array_4 Workgroup +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +%gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %var %uint_3 +%ptr_gep = OpInBoundsPtrAccessChain %ptr_Workgroup_uint_array_4 %gep %uint_3 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexAccessChains) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NOT: OpConstant +; CHECK: [[gep:%\w+]] = OpAccessChain {{%\w+}} [[var]] +; CHECK: OpAccessChain {{%\w+}} [[var]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable 
%ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpAccessChain %ptr_Workgroup_uint %var +%gep2 = OpAccessChain %ptr_Workgroup_uint %gep1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexPtrAccessChains) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: [[gep:%\w+]] = OpPtrAccessChain {{%\w+}} [[var]] [[int0]] +; CHECK: OpCopyObject {{%\w+}} [[gep]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpPtrAccessChain %ptr_Workgroup_uint %var %uint_0 +%gep2 = OpAccessChain %ptr_Workgroup_uint %gep1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, NoIndexPtrAccessChains2) { + const std::string text = R"( +; CHECK: [[int0:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: OpPtrAccessChain {{%\w+}} [[var]] [[int0]] +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpAccessChain %ptr_Workgroup_uint %var +%gep2 = OpPtrAccessChain 
%ptr_Workgroup_uint %gep1 %uint_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(CombineAccessChainsTest, CombineMixedSign) { + const std::string text = R"( +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2 +; CHECK: OpInBoundsPtrAccessChain {{%\w+}} [[var]] [[uint2]] +OpCapability Shader +OpCapability VariablePointers +OpCapability Addresses +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%int = OpTypeInt 32 1 +%uint_1 = OpConstant %uint 1 +%int_1 = OpConstant %int 1 +%ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%var = OpVariable %ptr_Workgroup_uint Workgroup +%void_func = OpTypeFunction %void +%func = OpFunction %void None %void_func +%1 = OpLabel +%gep1 = OpInBoundsPtrAccessChain %ptr_Workgroup_uint %var %uint_1 +%gep2 = OpInBoundsPtrAccessChain %ptr_Workgroup_uint %gep1 %int_1 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/common_uniform_elim_test.cpp b/test/opt/common_uniform_elim_test.cpp new file mode 100644 index 000000000..9e3943961 --- /dev/null +++ b/test/opt/common_uniform_elim_test.cpp @@ -0,0 +1,1394 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using CommonUniformElimTest = PassTest<::testing::Test>; + +TEST_F(CommonUniformElimTest, Basic1) { + // Note: This test exemplifies the following: + // - Common uniform (%_) load floated to nearest non-controlled block + // - Common extract (g_F) floated to non-controlled block + // - Non-common extract (g_F2) not floated, but common uniform load shared + // + // #version 140 + // in vec4 BaseColor; + // in float fi; + // + // layout(std140) uniform U_t + // { + // float g_F; + // float g_F2; + // } ; + // + // void main() + // { + // vec4 v = BaseColor; + // if (fi > 0) { + // v = v * g_F; + // } + // else { + // float f2 = g_F2 - g_F; + // v = v * f2; + // } + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %fi "fi" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpMemberName %U_t 1 "g_F2" +OpName %_ "" +OpName %f2 "f2" +OpName %gl_FragColor "gl_FragColor" +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 4 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%U_t = OpTypeStruct %float %float +%_ptr_Uniform_U_t = OpTypePointer Uniform 
%U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Function_float = OpTypePointer Function %float +%int_1 = OpConstant %int 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%26 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f2 = OpVariable %_ptr_Function_float Function +%27 = OpLoad %v4float %BaseColor +OpStore %v %27 +%28 = OpLoad %float %fi +%29 = OpFOrdGreaterThan %bool %28 %float_0 +OpSelectionMerge %30 None +OpBranchConditional %29 %31 %32 +%31 = OpLabel +%33 = OpLoad %v4float %v +%34 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%35 = OpLoad %float %34 +%36 = OpVectorTimesScalar %v4float %33 %35 +OpStore %v %36 +OpBranch %30 +%32 = OpLabel +%37 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%38 = OpLoad %float %37 +%39 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%40 = OpLoad %float %39 +%41 = OpFSub %float %38 %40 +OpStore %f2 %41 +%42 = OpLoad %v4float %v +%43 = OpLoad %float %f2 +%44 = OpVectorTimesScalar %v4float %42 %43 +OpStore %v %44 +OpBranch %30 +%30 = OpLabel +%45 = OpLoad %v4float %v +OpStore %gl_FragColor %45 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %11 +%26 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f2 = OpVariable %_ptr_Function_float Function +%52 = OpLoad %U_t %_ +%53 = OpCompositeExtract %float %52 0 +%27 = OpLoad %v4float %BaseColor +OpStore %v %27 +%28 = OpLoad %float %fi +%29 = OpFOrdGreaterThan %bool %28 %float_0 +OpSelectionMerge %30 None +OpBranchConditional %29 %31 %32 +%31 = OpLabel +%33 = OpLoad %v4float %v +%36 = OpVectorTimesScalar %v4float %33 %53 +OpStore %v %36 +OpBranch %30 +%32 = OpLabel +%49 = OpCompositeExtract %float %52 1 +%41 = OpFSub %float %49 %53 +OpStore %f2 %41 +%42 = OpLoad %v4float %v 
+%43 = OpLoad %float %f2 +%44 = OpVectorTimesScalar %v4float %42 %43 +OpStore %v %44 +OpBranch %30 +%30 = OpLabel +%45 = OpLoad %v4float %v +OpStore %gl_FragColor %45 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(CommonUniformElimTest, Basic2) { + // Note: This test exemplifies the following: + // - Common uniform (%_) load floated to nearest non-controlled block + // - Common extract (g_F) floated to non-controlled block + // - Non-common extract (g_F2) not floated, but common uniform load shared + // + // #version 140 + // in vec4 BaseColor; + // in float fi; + // in float fi2; + // + // layout(std140) uniform U_t + // { + // float g_F; + // float g_F2; + // } ; + // + // void main() + // { + // float f = fi; + // if (f < 0) + // f = -f; + // if (fi2 > 0) { + // f = f * g_F; + // } + // else { + // f = g_F2 - g_F; + // } + // gl_FragColor = f * BaseColor; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %fi %fi2 %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %f "f" +OpName %fi "fi" +OpName %fi2 "fi2" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpMemberName %U_t 1 "g_F2" +OpName %_ "" +OpName %gl_FragColor "gl_FragColor" +OpName %BaseColor "BaseColor" +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 4 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%fi2 = OpVariable %_ptr_Input_float Input +%U_t = OpTypeStruct %float %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = 
OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%25 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%26 = OpLoad %float %fi +OpStore %f %26 +%27 = OpLoad %float %f +%28 = OpFOrdLessThan %bool %27 %float_0 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +%31 = OpLoad %float %f +%32 = OpFNegate %float %31 +OpStore %f %32 +OpBranch %29 +%29 = OpLabel +%33 = OpLoad %float %fi2 +%34 = OpFOrdGreaterThan %bool %33 %float_0 +OpSelectionMerge %35 None +OpBranchConditional %34 %36 %37 +%36 = OpLabel +%38 = OpLoad %float %f +%39 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%40 = OpLoad %float %39 +%41 = OpFMul %float %38 %40 +OpStore %f %41 +OpBranch %35 +%37 = OpLabel +%42 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%43 = OpLoad %float %42 +%44 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%45 = OpLoad %float %44 +%46 = OpFSub %float %43 %45 +OpStore %f %46 +OpBranch %35 +%35 = OpLabel +%47 = OpLoad %v4float %BaseColor +%48 = OpLoad %float %f +%49 = OpVectorTimesScalar %v4float %47 %48 +OpStore %gl_FragColor %49 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %11 +%25 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%26 = OpLoad %float %fi +OpStore %f %26 +%27 = OpLoad %float %f +%28 = OpFOrdLessThan %bool %27 %float_0 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +%31 = OpLoad %float %f +%32 = OpFNegate %float %31 +OpStore %f %32 +OpBranch %29 +%29 = OpLabel +%56 = OpLoad %U_t %_ +%57 = OpCompositeExtract %float %56 0 +%33 = OpLoad %float %fi2 +%34 = 
OpFOrdGreaterThan %bool %33 %float_0 +OpSelectionMerge %35 None +OpBranchConditional %34 %36 %37 +%36 = OpLabel +%38 = OpLoad %float %f +%41 = OpFMul %float %38 %57 +OpStore %f %41 +OpBranch %35 +%37 = OpLabel +%53 = OpCompositeExtract %float %56 1 +%46 = OpFSub %float %53 %57 +OpStore %f %46 +OpBranch %35 +%35 = OpLabel +%47 = OpLoad %v4float %BaseColor +%48 = OpLoad %float %f +%49 = OpVectorTimesScalar %v4float %47 %48 +OpStore %gl_FragColor %49 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(CommonUniformElimTest, Basic3) { + // Note: This test exemplifies the following: + // - Existing common uniform (%_) load kept in place and shared + // + // #version 140 + // in vec4 BaseColor; + // in float fi; + // + // layout(std140) uniform U_t + // { + // bool g_B; + // float g_F; + // } ; + // + // void main() + // { + // vec4 v = BaseColor; + // if (g_B) + // v = v * g_F; + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor %fi +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_B" +OpMemberName %U_t 1 "g_F" +OpName %_ "" +OpName %gl_FragColor "gl_FragColor" +OpName %fi "fi" +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 4 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%U_t = OpTypeStruct %uint %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 
1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%int_1 = OpConstant %int 1 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%26 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%27 = OpLoad %v4float %BaseColor +OpStore %v %27 +%28 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%29 = OpLoad %uint %28 +%30 = OpINotEqual %bool %29 %uint_0 +OpSelectionMerge %31 None +OpBranchConditional %30 %32 %31 +%32 = OpLabel +%33 = OpLoad %v4float %v +%34 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%35 = OpLoad %float %34 +%36 = OpVectorTimesScalar %v4float %33 %35 +OpStore %v %36 +OpBranch %31 +%31 = OpLabel +%37 = OpLoad %v4float %v +OpStore %gl_FragColor %37 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%26 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%27 = OpLoad %v4float %BaseColor +OpStore %v %27 +%38 = OpLoad %U_t %_ +%39 = OpCompositeExtract %uint %38 0 +%30 = OpINotEqual %bool %39 %uint_0 +OpSelectionMerge %31 None +OpBranchConditional %30 %32 %31 +%32 = OpLabel +%33 = OpLoad %v4float %v +%41 = OpCompositeExtract %float %38 1 +%36 = OpVectorTimesScalar %v4float %33 %41 +OpStore %v %36 +OpBranch %31 +%31 = OpLabel +%37 = OpLoad %v4float %v +OpStore %gl_FragColor %37 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(CommonUniformElimTest, Loop) { + // Note: This test exemplifies the following: + // - Common extract (g_F) shared between two loops + // #version 140 + // in vec4 BC; + // in vec4 BC2; + // + // layout(std140) uniform U_t + // { + // float g_F; + // } ; + // 
+ // void main() + // { + // vec4 v = BC; + // for (int i = 0; i < 4; i++) + // v[i] = v[i] / g_F; + // vec4 v2 = BC2; + // for (int i = 0; i < 4; i++) + // v2[i] = v2[i] * g_F; + // gl_FragColor = v + v2; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BC %BC2 %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BC "BC" +OpName %i "i" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpName %_ "" +OpName %v2 "v2" +OpName %BC2 "BC2" +OpName %i_0 "i" +OpName %gl_FragColor "gl_FragColor" +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%13 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BC = OpVariable %_ptr_Input_v4float Input +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_4 = OpConstant %int 4 +%bool = OpTypeBool +%_ptr_Function_float = OpTypePointer Function %float +%U_t = OpTypeStruct %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%_ptr_Uniform_float = OpTypePointer Uniform %float +%int_1 = OpConstant %int 1 +%BC2 = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %13 +%28 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%i = OpVariable %_ptr_Function_int Function +%v2 = OpVariable %_ptr_Function_v4float Function +%i_0 = OpVariable %_ptr_Function_int Function +%29 = OpLoad %v4float %BC +OpStore %v %29 +OpStore %i %int_0 +OpBranch %30 +%30 = OpLabel +OpLoopMerge %31 %32 None +OpBranch %33 +%33 = 
OpLabel +%34 = OpLoad %int %i +%35 = OpSLessThan %bool %34 %int_4 +OpBranchConditional %35 %36 %31 +%36 = OpLabel +%37 = OpLoad %int %i +%38 = OpLoad %int %i +%39 = OpAccessChain %_ptr_Function_float %v %38 +%40 = OpLoad %float %39 +%41 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%42 = OpLoad %float %41 +%43 = OpFDiv %float %40 %42 +%44 = OpAccessChain %_ptr_Function_float %v %37 +OpStore %44 %43 +OpBranch %32 +%32 = OpLabel +%45 = OpLoad %int %i +%46 = OpIAdd %int %45 %int_1 +OpStore %i %46 +OpBranch %30 +%31 = OpLabel +%47 = OpLoad %v4float %BC2 +OpStore %v2 %47 +OpStore %i_0 %int_0 +OpBranch %48 +%48 = OpLabel +OpLoopMerge %49 %50 None +OpBranch %51 +%51 = OpLabel +%52 = OpLoad %int %i_0 +%53 = OpSLessThan %bool %52 %int_4 +OpBranchConditional %53 %54 %49 +%54 = OpLabel +%55 = OpLoad %int %i_0 +%56 = OpLoad %int %i_0 +%57 = OpAccessChain %_ptr_Function_float %v2 %56 +%58 = OpLoad %float %57 +%59 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%60 = OpLoad %float %59 +%61 = OpFMul %float %58 %60 +%62 = OpAccessChain %_ptr_Function_float %v2 %55 +OpStore %62 %61 +OpBranch %50 +%50 = OpLabel +%63 = OpLoad %int %i_0 +%64 = OpIAdd %int %63 %int_1 +OpStore %i_0 %64 +OpBranch %48 +%49 = OpLabel +%65 = OpLoad %v4float %v +%66 = OpLoad %v4float %v2 +%67 = OpFAdd %v4float %65 %66 +OpStore %gl_FragColor %67 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %13 +%28 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%i = OpVariable %_ptr_Function_int Function +%v2 = OpVariable %_ptr_Function_v4float Function +%i_0 = OpVariable %_ptr_Function_int Function +%72 = OpLoad %U_t %_ +%73 = OpCompositeExtract %float %72 0 +%29 = OpLoad %v4float %BC +OpStore %v %29 +OpStore %i %int_0 +OpBranch %30 +%30 = OpLabel +OpLoopMerge %31 %32 None +OpBranch %33 +%33 = OpLabel +%34 = OpLoad %int %i +%35 = OpSLessThan %bool %34 %int_4 +OpBranchConditional %35 %36 %31 +%36 = OpLabel +%37 = OpLoad %int %i +%38 = OpLoad %int %i +%39 = 
OpAccessChain %_ptr_Function_float %v %38 +%40 = OpLoad %float %39 +%43 = OpFDiv %float %40 %73 +%44 = OpAccessChain %_ptr_Function_float %v %37 +OpStore %44 %43 +OpBranch %32 +%32 = OpLabel +%45 = OpLoad %int %i +%46 = OpIAdd %int %45 %int_1 +OpStore %i %46 +OpBranch %30 +%31 = OpLabel +%47 = OpLoad %v4float %BC2 +OpStore %v2 %47 +OpStore %i_0 %int_0 +OpBranch %48 +%48 = OpLabel +OpLoopMerge %49 %50 None +OpBranch %51 +%51 = OpLabel +%52 = OpLoad %int %i_0 +%53 = OpSLessThan %bool %52 %int_4 +OpBranchConditional %53 %54 %49 +%54 = OpLabel +%55 = OpLoad %int %i_0 +%56 = OpLoad %int %i_0 +%57 = OpAccessChain %_ptr_Function_float %v2 %56 +%58 = OpLoad %float %57 +%61 = OpFMul %float %58 %73 +%62 = OpAccessChain %_ptr_Function_float %v2 %55 +OpStore %62 %61 +OpBranch %50 +%50 = OpLabel +%63 = OpLoad %int %i_0 +%64 = OpIAdd %int %63 %int_1 +OpStore %i_0 %64 +OpBranch %48 +%49 = OpLabel +%65 = OpLoad %v4float %v +%66 = OpLoad %v4float %v2 +%67 = OpFAdd %v4float %65 %66 +OpStore %gl_FragColor %67 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(CommonUniformElimTest, Volatile1) { + // Note: This test exemplifies the following: + // - Same test as Basic1 with the exception that + // the Load of g_F in else-branch is volatile + // - Common uniform (%_) load floated to nearest non-controlled block + // + // #version 140 + // in vec4 BaseColor; + // in float fi; + // + // layout(std140) uniform U_t + // { + // float g_F; + // float g_F2; + // } ; + // + // void main() + // { + // vec4 v = BaseColor; + // if (fi > 0) { + // v = v * g_F; + // } + // else { + // float f2 = g_F2 - g_F; + // v = v * f2; + // } + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName 
%v "v" +OpName %BaseColor "BaseColor" +OpName %fi "fi" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpMemberName %U_t 1 "g_F2" +OpName %_ "" +OpName %f2 "f2" +OpName %gl_FragColor "gl_FragColor" +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 4 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%U_t = OpTypeStruct %float %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Function_float = OpTypePointer Function %float +%int_1 = OpConstant %int 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%26 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f2 = OpVariable %_ptr_Function_float Function +%27 = OpLoad %v4float %BaseColor +OpStore %v %27 +%28 = OpLoad %float %fi +%29 = OpFOrdGreaterThan %bool %28 %float_0 +OpSelectionMerge %30 None +OpBranchConditional %29 %31 %32 +%31 = OpLabel +%33 = OpLoad %v4float %v +%34 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%35 = OpLoad %float %34 +%36 = OpVectorTimesScalar %v4float %33 %35 +OpStore %v %36 +OpBranch %30 +%32 = OpLabel +%37 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%38 = OpLoad %float %37 +%39 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%40 = OpLoad %float %39 Volatile +%41 = OpFSub %float %38 %40 +OpStore %f2 %41 +%42 = OpLoad %v4float %v +%43 = OpLoad %float %f2 +%44 = 
OpVectorTimesScalar %v4float %42 %43 +OpStore %v %44 +OpBranch %30 +%30 = OpLabel +%45 = OpLoad %v4float %v +OpStore %gl_FragColor %45 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %11 +%26 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f2 = OpVariable %_ptr_Function_float Function +%50 = OpLoad %U_t %_ +%27 = OpLoad %v4float %BaseColor +OpStore %v %27 +%28 = OpLoad %float %fi +%29 = OpFOrdGreaterThan %bool %28 %float_0 +OpSelectionMerge %30 None +OpBranchConditional %29 %31 %32 +%31 = OpLabel +%33 = OpLoad %v4float %v +%47 = OpCompositeExtract %float %50 0 +%36 = OpVectorTimesScalar %v4float %33 %47 +OpStore %v %36 +OpBranch %30 +%32 = OpLabel +%49 = OpCompositeExtract %float %50 1 +%39 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%40 = OpLoad %float %39 Volatile +%41 = OpFSub %float %49 %40 +OpStore %f2 %41 +%42 = OpLoad %v4float %v +%43 = OpLoad %float %f2 +%44 = OpVectorTimesScalar %v4float %42 %43 +OpStore %v %44 +OpBranch %30 +%30 = OpLabel +%45 = OpLoad %v4float %v +OpStore %gl_FragColor %45 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(CommonUniformElimTest, Volatile2) { + // Note: This test exemplifies the following: + // - Same test as Basic1 with the exception that + // U_t is Volatile. 
+ // - No optimizations are applied + // + // #version 430 + // in vec4 BaseColor; + // in float fi; + // + // layout(std430) volatile buffer U_t + // { + // float g_F; + // float g_F2; + // }; + // + // + // void main(void) + // { + // vec4 v = BaseColor; + // if (fi > 0) { + // v = v * g_F; + // } else { + // float f2 = g_F2 - g_F; + // v = v * f2; + // } + // } + + const std::string text = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %fi "fi" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpMemberName %U_t 1 "g_F2" +OpName %_ "" +OpName %f2 "f2" +OpDecorate %BaseColor Location 0 +OpDecorate %fi Location 0 +OpMemberDecorate %U_t 0 Volatile +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Volatile +OpMemberDecorate %U_t 1 Offset 4 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%U_t = OpTypeStruct %float %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Function_float = OpTypePointer Function %float +%int_1 = OpConstant %int 1 +%main = OpFunction %void None %3 +%5 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f2 = OpVariable %_ptr_Function_float Function +%12 = OpLoad %v4float %BaseColor +OpStore %v %12 +%15 = OpLoad %float %fi +%18 = 
OpFOrdGreaterThan %bool %15 %float_0 +OpSelectionMerge %20 None +OpBranchConditional %18 %19 %31 +%19 = OpLabel +%21 = OpLoad %v4float %v +%28 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%29 = OpLoad %float %28 +%30 = OpVectorTimesScalar %v4float %21 %29 +OpStore %v %30 +OpBranch %20 +%31 = OpLabel +%35 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%36 = OpLoad %float %35 +%37 = OpAccessChain %_ptr_Uniform_float %_ %int_0 +%38 = OpLoad %float %37 +%39 = OpFSub %float %36 %38 +OpStore %f2 %39 +%40 = OpLoad %v4float %v +%41 = OpLoad %float %f2 +%42 = OpVectorTimesScalar %v4float %40 %41 +OpStore %v %42 +OpBranch %20 +%20 = OpLabel +OpReturn +OpFunctionEnd +)"; + + Pass::Status res = std::get<1>( + SinglePassRunAndDisassemble(text, true, false)); + EXPECT_EQ(res, Pass::Status::SuccessWithoutChange); +} + +TEST_F(CommonUniformElimTest, Volatile3) { + // Note: This test exemplifies the following: + // - Same test as Volatile2 with the exception that + // the nested struct S is volatile + // - No optimizations are applied + // + // #version 430 + // in vec4 BaseColor; + // in float fi; + // + // struct S { + // volatile float a; + // }; + // + // layout(std430) buffer U_t + // { + // S g_F; + // S g_F2; + // }; + // + // + // void main(void) + // { + // vec4 v = BaseColor; + // if (fi > 0) { + // v = v * g_F.a; + // } else { + // float f2 = g_F2.a - g_F.a; + // v = v * f2; + // } + // } + + const std::string text = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %fi "fi" +OpName %S "S" +OpMemberName %S 0 "a" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_F" +OpMemberName %U_t 1 "g_F2" +OpName %_ "" +OpName %f2 "f2" +OpDecorate %BaseColor Location 0 +OpDecorate %fi Location 0 +OpMemberDecorate %S 0 Offset 0 +OpMemberDecorate %S 0 Volatile 
+OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 4 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%S = OpTypeStruct %float +%U_t = OpTypeStruct %S %S +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Function_float = OpTypePointer Function %float +%int_1 = OpConstant %int 1 +%main = OpFunction %void None %3 +%5 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f2 = OpVariable %_ptr_Function_float Function +%12 = OpLoad %v4float %BaseColor +OpStore %v %12 +%15 = OpLoad %float %fi +%18 = OpFOrdGreaterThan %bool %15 %float_0 +OpSelectionMerge %20 None +OpBranchConditional %18 %19 %32 +%19 = OpLabel +%21 = OpLoad %v4float %v +%29 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %int_0 +%30 = OpLoad %float %29 +%31 = OpVectorTimesScalar %v4float %21 %30 +OpStore %v %31 +OpBranch %20 +%32 = OpLabel +%36 = OpAccessChain %_ptr_Uniform_float %_ %int_1 %int_0 +%37 = OpLoad %float %36 +%38 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %int_0 +%39 = OpLoad %float %38 +%40 = OpFSub %float %37 %39 +OpStore %f2 %40 +%41 = OpLoad %v4float %v +%42 = OpLoad %float %f2 +%43 = OpVectorTimesScalar %v4float %41 %42 +OpStore %v %43 +OpBranch %20 +%20 = OpLabel +OpReturn +OpFunctionEnd +)"; + + Pass::Status res = std::get<1>( + SinglePassRunAndDisassemble(text, true, false)); + EXPECT_EQ(res, Pass::Status::SuccessWithoutChange); +} + +TEST_F(CommonUniformElimTest, IteratorDanglingPointer) { + // Note: 
This test exemplifies the following: + // - Existing common uniform (%_) load kept in place and shared + // + // #version 140 + // in vec4 BaseColor; + // in float fi; + // + // layout(std140) uniform U_t + // { + // bool g_B; + // float g_F; + // } ; + // + // uniform float alpha; + // uniform bool alpha_B; + // + // void main() + // { + // vec4 v = BaseColor; + // if (g_B) { + // v = v * g_F; + // if (alpha_B) + // v = v * alpha; + // else + // v = v * fi; + // } + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor %fi +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_B" +OpMemberName %U_t 1 "g_F" +OpName %alpha "alpha" +OpName %alpha_B "alpha_B" +OpName %_ "" +OpName %gl_FragColor "gl_FragColor" +OpName %fi "fi" +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 4 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%U_t = OpTypeStruct %uint %float +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%int_1 = OpConstant %int 1 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%alpha = OpVariable 
%_ptr_Uniform_float Uniform +%alpha_B = OpVariable %_ptr_Uniform_uint Uniform +)"; + + const std::string before = + R"(%main = OpFunction %void None %12 +%26 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%27 = OpLoad %v4float %BaseColor +OpStore %v %27 +%28 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%29 = OpLoad %uint %28 +%30 = OpINotEqual %bool %29 %uint_0 +OpSelectionMerge %31 None +OpBranchConditional %30 %31 %32 +%32 = OpLabel +%47 = OpLoad %v4float %v +OpStore %gl_FragColor %47 +OpReturn +%31 = OpLabel +%33 = OpAccessChain %_ptr_Uniform_float %_ %int_1 +%34 = OpLoad %float %33 +%35 = OpLoad %v4float %v +%36 = OpVectorTimesScalar %v4float %35 %34 +OpStore %v %36 +%37 = OpLoad %uint %alpha_B +%38 = OpIEqual %bool %37 %uint_0 +OpSelectionMerge %43 None +OpBranchConditional %38 %43 %39 +%39 = OpLabel +%40 = OpLoad %float %alpha +%41 = OpLoad %v4float %v +%42 = OpVectorTimesScalar %v4float %41 %40 +OpStore %v %42 +OpBranch %50 +%50 = OpLabel +%51 = OpLoad %v4float %v +OpStore %gl_FragColor %51 +OpReturn +%43 = OpLabel +%44 = OpLoad %float %fi +%45 = OpLoad %v4float %v +%46 = OpVectorTimesScalar %v4float %45 %44 +OpStore %v %46 +OpBranch %60 +%60 = OpLabel +%61 = OpLoad %v4float %v +OpStore %gl_FragColor %61 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %12 +%28 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%29 = OpLoad %v4float %BaseColor +OpStore %v %29 +%54 = OpLoad %U_t %_ +%55 = OpCompositeExtract %uint %54 0 +%32 = OpINotEqual %bool %55 %uint_0 +OpSelectionMerge %33 None +OpBranchConditional %32 %33 %34 +%34 = OpLabel +%35 = OpLoad %v4float %v +OpStore %gl_FragColor %35 +OpReturn +%33 = OpLabel +%58 = OpLoad %float %alpha +%57 = OpCompositeExtract %float %54 1 +%38 = OpLoad %v4float %v +%39 = OpVectorTimesScalar %v4float %38 %57 +OpStore %v %39 +%40 = OpLoad %uint %alpha_B +%41 = OpIEqual %bool %40 %uint_0 +OpSelectionMerge %42 None +OpBranchConditional %41 %42 %43 +%43 = OpLabel 
+%45 = OpLoad %v4float %v +%46 = OpVectorTimesScalar %v4float %45 %58 +OpStore %v %46 +OpBranch %47 +%47 = OpLabel +%48 = OpLoad %v4float %v +OpStore %gl_FragColor %48 +OpReturn +%42 = OpLabel +%49 = OpLoad %float %fi +%50 = OpLoad %v4float %v +%51 = OpVectorTimesScalar %v4float %50 %49 +OpStore %v %51 +OpBranch %52 +%52 = OpLabel +%53 = OpLoad %v4float %v +OpStore %gl_FragColor %53 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(CommonUniformElimTest, MixedConstantAndNonConstantIndexes) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Uniform +; CHECK: %501 = OpLabel +; CHECK: [[ld:%\w+]] = OpLoad +; CHECK-NOT: OpCompositeExtract {{%\w+}} {{%\w+}} 0 2 484 +; CHECK: OpAccessChain {{%\w+}} [[var]] %int_0 %int_2 [[ld]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "ringeffectLayer_px" %gl_FragCoord %178 %182 + OpExecutionMode %4 OriginUpperLeft + OpSource HLSL 500 + OpDecorate %_arr_v4float_uint_10 ArrayStride 16 + OpMemberDecorate %_struct_20 0 Offset 0 + OpMemberDecorate %_struct_20 1 Offset 16 + OpMemberDecorate %_struct_20 2 Offset 32 + OpMemberDecorate %_struct_21 0 Offset 0 + OpDecorate %_struct_21 Block + OpDecorate %23 DescriptorSet 0 + OpDecorate %gl_FragCoord BuiltIn FragCoord + OpDecorate %178 Location 0 + OpDecorate %182 Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float + %uint = OpTypeInt 32 0 + %uint_10 = OpConstant %uint 10 +%_arr_v4float_uint_10 = OpTypeArray %v4float %uint_10 + %_struct_20 = OpTypeStruct %v4float %v4float %_arr_v4float_uint_10 + %_struct_21 = OpTypeStruct %_struct_20 +%_ptr_Uniform__struct_21 = OpTypePointer Uniform %_struct_21 + %23 = OpVariable %_ptr_Uniform__struct_21 Uniform + %int = OpTypeInt 32 1 
+ %int_0 = OpConstant %int 0 +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_ptr_Uniform_float = OpTypePointer Uniform %float + %uint_3 = OpConstant %uint 3 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %float_0 = OpConstant %float 0 + %43 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%_ptr_Function_int = OpTypePointer Function %int + %int_5 = OpConstant %int 5 + %bool = OpTypeBool + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %uint_5 = OpConstant %uint 5 +%_arr_v2float_uint_5 = OpTypeArray %v2float %uint_5 +%_ptr_Function__arr_v2float_uint_5 = OpTypePointer Function %_arr_v2float_uint_5 + %82 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_82 = OpTypePointer UniformConstant %82 + %86 = OpTypeSampler +%_ptr_UniformConstant_86 = OpTypePointer UniformConstant %86 + %90 = OpTypeSampledImage %82 + %v3float = OpTypeVector %float 3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input + %178 = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float + %182 = OpVariable %_ptr_Output_v4float Output + %4 = OpFunction %void None %3 + %5 = OpLabel + %483 = OpVariable %_ptr_Function_v4float Function + %484 = OpVariable %_ptr_Function_int Function + %486 = OpVariable %_ptr_Function__arr_v2float_uint_5 Function + %179 = OpLoad %v4float %178 + %493 = OpAccessChain %_ptr_Uniform_float %23 %int_0 %int_0 %uint_3 + %494 = OpLoad %float %493 + OpStore %483 %43 + OpStore %484 %int_0 + OpBranch %495 + %495 = OpLabel + OpLoopMerge %496 %497 None + OpBranch %498 + %498 = OpLabel + %499 = OpLoad %int %484 + %500 = OpSLessThan %bool %499 %int_5 + OpBranchConditional %500 %501 %496 + %501 = OpLabel + %504 = OpVectorShuffle %v2float %179 %179 0 1 + %505 = OpLoad %int %484 + %506 = OpAccessChain %_ptr_Uniform_v4float %23 %int_0 %int_2 %505 + %507 = OpLoad %v4float %506 + %508 = OpVectorShuffle %v2float %507 %507 0 1 + %509 = OpFAdd 
%v2float %504 %508 + %512 = OpAccessChain %_ptr_Uniform_v4float %23 %int_0 %int_1 + %513 = OpLoad %v4float %512 + %514 = OpVectorShuffle %v2float %513 %513 0 1 + %517 = OpVectorShuffle %v2float %513 %513 2 3 + %518 = OpExtInst %v2float %1 FClamp %509 %514 %517 + %519 = OpAccessChain %_ptr_Function_v2float %486 %505 + OpStore %519 %518 + OpBranch %497 + %497 = OpLabel + %520 = OpLoad %int %484 + %521 = OpIAdd %int %520 %int_1 + OpStore %484 %521 + OpBranch %495 + %496 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(text, true); +} + +TEST_F(CommonUniformElimTest, LoadPlacedAfterPhi) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable {{%\w+}} Uniform +; CHECK: OpSelectionMerge [[merge:%\w+]] +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpPhi +; CHECK-NEXT: OpLoad {{%\w+}} [[var]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %_struct_3 0 Offset 0 + OpDecorate %_struct_3 Block + OpDecorate %4 DescriptorSet 0 + OpDecorate %4 Binding 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %bool = OpTypeBool + %false = OpConstantFalse %bool + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %_struct_3 = OpTypeStruct %v2uint +%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3 + %4 = OpVariable %_ptr_Uniform__struct_3 Uniform + %uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint + %uint_2 = OpConstant %uint 2 + %2 = OpFunction %void None %6 + %15 = OpLabel + OpSelectionMerge %16 None + OpBranchConditional %false %17 %16 + %17 = OpLabel + OpBranch %16 + %16 = OpLabel + %18 = OpPhi %bool %false %15 %false %17 + OpSelectionMerge %19 None + OpBranchConditional %false %20 %21 + %20 = OpLabel + %22 = OpAccessChain %_ptr_Uniform_uint %4 %uint_0 %uint_0 + %23 = OpLoad %uint %22 + 
OpBranch %19 + %21 = OpLabel + OpBranch %19 + %19 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(text, true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// Disqualifying cases: extensions, decorations, non-logical addressing, +// non-structured control flow +// Others? + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/compact_ids_test.cpp b/test/opt/compact_ids_test.cpp new file mode 100644 index 000000000..b1e4b2cbb --- /dev/null +++ b/test/opt/compact_ids_test.cpp @@ -0,0 +1,279 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "spirv-tools/libspirv.hpp" +#include "spirv-tools/optimizer.hpp" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using CompactIdsTest = PassTest<::testing::Test>; + +TEST_F(CompactIdsTest, PassOff) { + const std::string before = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +%99 = OpTypeInt 32 0 +%10 = OpTypeVector %99 2 +%20 = OpConstant %99 2 +%30 = OpTypeArray %99 %20 +)"; + + const std::string after = before; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(before, after, false, false); +} + +TEST_F(CompactIdsTest, PassOn) { + const std::string before = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %3 "simple_kernel" +%99 = OpTypeInt 32 0 +%10 = OpTypeVector %99 2 +%20 = OpConstant %99 2 +%30 = OpTypeArray %99 %20 +%40 = OpTypeVoid +%50 = OpTypeFunction %40 + %3 = OpFunction %40 None %50 +%70 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %1 "simple_kernel" +%2 = OpTypeInt 32 0 +%3 = OpTypeVector %2 2 +%4 = OpConstant %2 2 +%5 = OpTypeArray %2 %4 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%1 = OpFunction %6 None %7 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(before, after, false, false); +} + +TEST(CompactIds, InstructionResultIsUpdated) { + // For 
https://github.com/KhronosGroup/SPIRV-Tools/issues/827 + // In that bug, the compact Ids pass was directly updating the result Id + // word for an OpFunction instruction, but not updating the cached + // result_id_ in that Instruction object. + // + // This test is a bit cheesy. We don't expose internal interfaces enough + // to see the inconsistency. So reproduce the original scenario, with + // compact ids followed by a pass that trips up on the inconsistency. + + const std::string input(R"(OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %100 "main" +%200 = OpTypeVoid +%300 = OpTypeFunction %200 +%100 = OpFunction %200 None %300 +%400 = OpLabel +OpReturn +OpFunctionEnd +)"); + + std::vector binary; + const spv_target_env env = SPV_ENV_UNIVERSAL_1_0; + spvtools::SpirvTools tools(env); + auto assembled = tools.Assemble( + input, &binary, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_TRUE(assembled); + + spvtools::Optimizer optimizer(env); + optimizer.RegisterPass(CreateCompactIdsPass()); + // The exhaustive inliner will use the result_id + optimizer.RegisterPass(CreateInlineExhaustivePass()); + + // This should not crash! 
+ optimizer.Run(binary.data(), binary.size(), &binary); + + std::string disassembly; + tools.Disassemble(binary, &disassembly, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + const std::string expected(R"(OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturn +OpFunctionEnd +)"); + + EXPECT_THAT(disassembly, ::testing::Eq(expected)); +} + +TEST(CompactIds, HeaderIsUpdated) { + const std::string input(R"(OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %100 "main" +%200 = OpTypeVoid +%300 = OpTypeFunction %200 +%100 = OpFunction %200 None %300 +%400 = OpLabel +OpReturn +OpFunctionEnd +)"); + + std::vector binary; + const spv_target_env env = SPV_ENV_UNIVERSAL_1_0; + spvtools::SpirvTools tools(env); + auto assembled = tools.Assemble( + input, &binary, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_TRUE(assembled); + + spvtools::Optimizer optimizer(env); + optimizer.RegisterPass(CreateCompactIdsPass()); + // The exhaustive inliner will use the result_id + optimizer.RegisterPass(CreateInlineExhaustivePass()); + + // This should not crash! + optimizer.Run(binary.data(), binary.size(), &binary); + + std::string disassembly; + tools.Disassemble(binary, &disassembly, SPV_BINARY_TO_TEXT_OPTION_NONE); + + const std::string expected(R"(; SPIR-V +; Version: 1.0 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 5 +; Schema: 0 +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturn +OpFunctionEnd +)"); + + EXPECT_THAT(disassembly, ::testing::Eq(expected)); +} + +// Test context consistency check after invalidating +// CFG and others by compact IDs Pass. 
+// Uses a GLSL shader with named labels for variety +TEST(CompactIds, ConsistentCheck) { + const std::string input(R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_A %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %main "main" +OpName %in_var_A "in.var.A" +OpName %out_var_SV_TARGET "out.var.SV_TARGET" +OpDecorate %in_var_A Location 0 +OpDecorate %out_var_SV_TARGET Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%in_var_A = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %3 +%5 = OpLabel +%12 = OpLoad %v4float %in_var_A +%23 = OpVectorShuffle %v4float %12 %12 0 0 0 1 +OpStore %out_var_SV_TARGET %23 +OpReturn +OpFunctionEnd +)"); + + spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_1); + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(context, nullptr); + + CompactIdsPass compact_id_pass; + context->BuildInvalidAnalyses(compact_id_pass.GetPreservedAnalyses()); + const auto status = compact_id_pass.Run(context.get()); + ASSERT_NE(status, Pass::Status::Failure); + EXPECT_TRUE(context->IsConsistent()); + + // Test output just in case + std::vector binary; + context->module()->ToBinary(&binary, false); + std::string disassembly; + tools.Disassemble(binary, &disassembly, + SpirvTools::kDefaultDisassembleOption); + + const std::string expected(R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_A %out_var_SV_TARGET +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %main "main" +OpName %in_var_A "in.var.A" +OpName %out_var_SV_TARGET "out.var.SV_TARGET" +OpDecorate %in_var_A Location 0 
+OpDecorate %out_var_SV_TARGET Location 0 +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%in_var_A = OpVariable %_ptr_Input_v4float Input +%out_var_SV_TARGET = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %5 +%10 = OpLabel +%11 = OpLoad %v4float %in_var_A +%12 = OpVectorShuffle %v4float %11 %11 0 0 0 1 +OpStore %out_var_SV_TARGET %12 +OpReturn +OpFunctionEnd +)"); + + EXPECT_THAT(disassembly, ::testing::Eq(expected)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/constant_manager_test.cpp b/test/opt/constant_manager_test.cpp new file mode 100644 index 000000000..57dea6512 --- /dev/null +++ b/test/opt/constant_manager_test.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/constants.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +using ConstantManagerTest = ::testing::Test; + +TEST_F(ConstantManagerTest, GetDefiningInstruction) { + const std::string text = R"( +%int = OpTypeInt 32 0 +%1 = OpTypeStruct %int +%2 = OpTypeStruct %int + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(context, nullptr); + + Type* struct_type_1 = context->get_type_mgr()->GetType(1); + StructConstant struct_const_1(struct_type_1->AsStruct()); + Instruction* const_inst_1 = + context->get_constant_mgr()->GetDefiningInstruction(&struct_const_1, 1); + EXPECT_EQ(const_inst_1->type_id(), 1); + + Type* struct_type_2 = context->get_type_mgr()->GetType(2); + StructConstant struct_const_2(struct_type_2->AsStruct()); + Instruction* const_inst_2 = + context->get_constant_mgr()->GetDefiningInstruction(&struct_const_2, 2); + EXPECT_EQ(const_inst_2->type_id(), 2); +} + +TEST_F(ConstantManagerTest, GetDefiningInstruction2) { + const std::string text = R"( +%int = OpTypeInt 32 0 +%1 = OpTypeStruct %int +%2 = OpTypeStruct %int +%3 = OpConstantNull %1 +%4 = OpConstantNull %2 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(context, nullptr); + + Type* struct_type_1 = context->get_type_mgr()->GetType(1); + NullConstant struct_const_1(struct_type_1->AsStruct()); + Instruction* const_inst_1 = + context->get_constant_mgr()->GetDefiningInstruction(&struct_const_1, 1); + EXPECT_EQ(const_inst_1->type_id(), 1); + EXPECT_EQ(const_inst_1->result_id(), 3); + + Type* struct_type_2 = context->get_type_mgr()->GetType(2); + NullConstant struct_const_2(struct_type_2->AsStruct()); + 
Instruction* const_inst_2 = + context->get_constant_mgr()->GetDefiningInstruction(&struct_const_2, 2); + EXPECT_EQ(const_inst_2->type_id(), 2); + EXPECT_EQ(const_inst_2->result_id(), 4); +} + +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/test/opt/copy_prop_array_test.cpp b/test/opt/copy_prop_array_test.cpp new file mode 100644 index 000000000..504ae67e1 --- /dev/null +++ b/test/opt/copy_prop_array_test.cpp @@ -0,0 +1,1576 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using CopyPropArrayPassTest = PassTest<::testing::Test>; + +TEST_F(CopyPropArrayPassTest, BasicPropagateArray) { + const std::string before = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +; CHECK: OpFunction +; 
CHECK: OpLabel +; CHECK: OpVariable +; CHECK: OpAccessChain +; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +; CHECK: [[element_ptr:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[new_address]] %24 +; CHECK: [[load:%\w+]] = OpLoad %v4float [[element_ptr]] +; CHECK: OpStore %out_var_SV_Target [[load]] +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%28 = OpCompositeExtract %v4float %26 1 +%29 = OpCompositeExtract %v4float %26 2 +%30 = OpCompositeExtract %v4float %26 3 +%31 = OpCompositeExtract %v4float %26 4 +%32 = OpCompositeExtract %v4float %26 5 +%33 = OpCompositeExtract %v4float %26 6 +%34 = OpCompositeExtract %v4float %26 7 +%35 = OpCompositeConstruct %_arr_v4float_uint_8_0 %27 %28 %29 %30 %31 %32 %33 %34 +OpStore %23 %35 +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +OpStore %out_var_SV_Target %37 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(before, false); +} + +TEST_F(CopyPropArrayPassTest, BasicPropagateArrayWithName) { + const std::string before = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %local "local" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate 
%type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK: OpVariable +; CHECK: OpAccessChain +; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +; CHECK: [[element_ptr:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[new_address]] %24 +; CHECK: [[load:%\w+]] = OpLoad %v4float [[element_ptr]] +; CHECK: OpStore %out_var_SV_Target [[load]] +%main = OpFunction %void None %13 +%22 = OpLabel +%local = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%28 = OpCompositeExtract %v4float %26 1 +%29 = OpCompositeExtract %v4float %26 2 +%30 = OpCompositeExtract %v4float %26 3 +%31 = 
OpCompositeExtract %v4float %26 4 +%32 = OpCompositeExtract %v4float %26 5 +%33 = OpCompositeExtract %v4float %26 6 +%34 = OpCompositeExtract %v4float %26 7 +%35 = OpCompositeConstruct %_arr_v4float_uint_8_0 %27 %28 %29 %30 %31 %32 %33 %34 +OpStore %local %35 +%36 = OpAccessChain %_ptr_Function_v4float %local %24 +%37 = OpLoad %v4float %36 +OpStore %out_var_SV_Target %37 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(before, false); +} + +// Propagate 2d array. This test identifies a copy through multiple levels. +// Also has to traverse multiple OpAccessChains. +TEST_F(CopyPropArrayPassTest, Propagate2DArray) { + const std::string text = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_2 ArrayStride 16 +OpDecorate %_arr__arr_v4float_uint_2_uint_2 ArrayStride 32 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_2 = OpConstant %uint 2 +%_arr_v4float_uint_2 = OpTypeArray %v4float %uint_2 +%_arr__arr_v4float_uint_2_uint_2 = OpTypeArray %_arr_v4float_uint_2 %uint_2 +%type_MyCBuffer = OpTypeStruct %_arr__arr_v4float_uint_2_uint_2 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid
+%14 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_2_0 = OpTypeArray %v4float %uint_2 +%_arr__arr_v4float_uint_2_0_uint_2 = OpTypeArray %_arr_v4float_uint_2_0 %uint_2 +%_ptr_Function__arr__arr_v4float_uint_2_0_uint_2 = OpTypePointer Function %_arr__arr_v4float_uint_2_0_uint_2 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 = OpTypePointer Uniform %_arr__arr_v4float_uint_2_uint_2 +%_ptr_Function__arr_v4float_uint_2_0 = OpTypePointer Function %_arr_v4float_uint_2_0 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK: OpVariable +; CHECK: OpVariable +; CHECK: OpAccessChain +; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %int_0 +%main = OpFunction %void None %14 +%25 = OpLabel +%26 = OpVariable %_ptr_Function__arr_v4float_uint_2_0 Function +%27 = OpVariable %_ptr_Function__arr__arr_v4float_uint_2_0_uint_2 Function +%28 = OpLoad %int %in_var_INDEX +%29 = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %int_0 +%30 = OpLoad %_arr__arr_v4float_uint_2_uint_2 %29 +%31 = OpCompositeExtract %_arr_v4float_uint_2 %30 0 +%32 = OpCompositeExtract %v4float %31 0 +%33 = OpCompositeExtract %v4float %31 1 +%34 = OpCompositeConstruct %_arr_v4float_uint_2_0 %32 %33 +%35 = OpCompositeExtract %_arr_v4float_uint_2 %30 1 +%36 = OpCompositeExtract %v4float %35 0 +%37 = OpCompositeExtract %v4float %35 1 +%38 = OpCompositeConstruct %_arr_v4float_uint_2_0 %36 %37 +%39 = OpCompositeConstruct %_arr__arr_v4float_uint_2_0_uint_2 %34 %38 +; CHECK: OpStore +OpStore %27 %39 +%40 = OpAccessChain %_ptr_Function__arr_v4float_uint_2_0 %27 %28 +%42 = OpAccessChain 
%_ptr_Function_v4float %40 %28 +%43 = OpLoad %v4float %42 +; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_2 [[new_address]] %28 +; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[ac1]] %28 +; CHECK: [[load:%\w+]] = OpLoad %v4float [[ac2]] +; CHECK: OpStore %out_var_SV_Target [[load]] +OpStore %out_var_SV_Target %43 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(text, false); +} + +// Propagate 2d array. This test identifies a copy through multiple levels. +// Also has to traverse multiple OpAccessChains. +TEST_F(CopyPropArrayPassTest, Propagate2DArrayWithMultiLevelExtract) { + const std::string text = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_2 ArrayStride 16 +OpDecorate %_arr__arr_v4float_uint_2_uint_2 ArrayStride 32 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_2 = OpConstant %uint 2 +%_arr_v4float_uint_2 = OpTypeArray %v4float %uint_2 +%_arr__arr_v4float_uint_2_uint_2 = OpTypeArray %_arr_v4float_uint_2 %uint_2 +%type_MyCBuffer = OpTypeStruct %_arr__arr_v4float_uint_2_uint_2 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid
+%14 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_2_0 = OpTypeArray %v4float %uint_2 +%_arr__arr_v4float_uint_2_0_uint_2 = OpTypeArray %_arr_v4float_uint_2_0 %uint_2 +%_ptr_Function__arr__arr_v4float_uint_2_0_uint_2 = OpTypePointer Function %_arr__arr_v4float_uint_2_0_uint_2 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 = OpTypePointer Uniform %_arr__arr_v4float_uint_2_uint_2 +%_ptr_Function__arr_v4float_uint_2_0 = OpTypePointer Function %_arr_v4float_uint_2_0 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK: OpVariable +; CHECK: OpVariable +; CHECK: OpAccessChain +; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %int_0 +%main = OpFunction %void None %14 +%25 = OpLabel +%26 = OpVariable %_ptr_Function__arr_v4float_uint_2_0 Function +%27 = OpVariable %_ptr_Function__arr__arr_v4float_uint_2_0_uint_2 Function +%28 = OpLoad %int %in_var_INDEX +%29 = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %int_0 +%30 = OpLoad %_arr__arr_v4float_uint_2_uint_2 %29 +%32 = OpCompositeExtract %v4float %30 0 0 +%33 = OpCompositeExtract %v4float %30 0 1 +%34 = OpCompositeConstruct %_arr_v4float_uint_2_0 %32 %33 +%36 = OpCompositeExtract %v4float %30 1 0 +%37 = OpCompositeExtract %v4float %30 1 1 +%38 = OpCompositeConstruct %_arr_v4float_uint_2_0 %36 %37 +%39 = OpCompositeConstruct %_arr__arr_v4float_uint_2_0_uint_2 %34 %38 +; CHECK: OpStore +OpStore %27 %39 +%40 = OpAccessChain %_ptr_Function__arr_v4float_uint_2_0 %27 %28 +%42 = OpAccessChain %_ptr_Function_v4float %40 %28 +%43 = OpLoad %v4float %42 +; CHECK: [[ac1:%\w+]] = OpAccessChain 
%_ptr_Uniform__arr_v4float_uint_2 [[new_address]] %28 +; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[ac1]] %28 +; CHECK: [[load:%\w+]] = OpLoad %v4float [[ac2]] +; CHECK: OpStore %out_var_SV_Target [[load]] +OpStore %out_var_SV_Target %43 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(text, false); +} + +// Test decomposing an object when we need to "rewrite" a store. +TEST_F(CopyPropArrayPassTest, DecomposeObjectForArrayStore) { + const std::string text = + R"( OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target + OpExecutionMode %main OriginUpperLeft + OpSource HLSL 600 + OpName %type_MyCBuffer "type.MyCBuffer" + OpMemberName %type_MyCBuffer 0 "Data" + OpName %MyCBuffer "MyCBuffer" + OpName %main "main" + OpName %in_var_INDEX "in.var.INDEX" + OpName %out_var_SV_Target "out.var.SV_Target" + OpDecorate %_arr_v4float_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v4float_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %type_MyCBuffer 0 Offset 0 + OpDecorate %type_MyCBuffer Block + OpDecorate %in_var_INDEX Flat + OpDecorate %in_var_INDEX Location 0 + OpDecorate %out_var_SV_Target Location 0 + OpDecorate %MyCBuffer DescriptorSet 0 + OpDecorate %MyCBuffer Binding 0 + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_v4float_uint_2 = OpTypeArray %v4float %uint_2 +%_arr__arr_v4float_uint_2_uint_2 = OpTypeArray %_arr_v4float_uint_2 %uint_2 +%type_MyCBuffer = OpTypeStruct %_arr__arr_v4float_uint_2_uint_2 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer + %void = OpTypeVoid + %14 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output 
%v4float +%_arr_v4float_uint_2_0 = OpTypeArray %v4float %uint_2 +%_arr__arr_v4float_uint_2_0_uint_2 = OpTypeArray %_arr_v4float_uint_2_0 %uint_2 +%_ptr_Function__arr__arr_v4float_uint_2_0_uint_2 = OpTypePointer Function %_arr__arr_v4float_uint_2_0_uint_2 + %int_0 = OpConstant %int 0 +%_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 = OpTypePointer Uniform %_arr__arr_v4float_uint_2_uint_2 +%_ptr_Function__arr_v4float_uint_2_0 = OpTypePointer Function %_arr_v4float_uint_2_0 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output + %main = OpFunction %void None %14 + %25 = OpLabel + %26 = OpVariable %_ptr_Function__arr_v4float_uint_2_0 Function + %27 = OpVariable %_ptr_Function__arr__arr_v4float_uint_2_0_uint_2 Function + %28 = OpLoad %int %in_var_INDEX + %29 = OpAccessChain %_ptr_Uniform__arr__arr_v4float_uint_2_uint_2 %MyCBuffer %int_0 + %30 = OpLoad %_arr__arr_v4float_uint_2_uint_2 %29 + %31 = OpCompositeExtract %_arr_v4float_uint_2 %30 0 + %32 = OpCompositeExtract %v4float %31 0 + %33 = OpCompositeExtract %v4float %31 1 + %34 = OpCompositeConstruct %_arr_v4float_uint_2_0 %32 %33 + %35 = OpCompositeExtract %_arr_v4float_uint_2 %30 1 + %36 = OpCompositeExtract %v4float %35 0 + %37 = OpCompositeExtract %v4float %35 1 + %38 = OpCompositeConstruct %_arr_v4float_uint_2_0 %36 %37 + %39 = OpCompositeConstruct %_arr__arr_v4float_uint_2_0_uint_2 %34 %38 + OpStore %27 %39 +; CHECK: [[access_chain:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_2 + %40 = OpAccessChain %_ptr_Function__arr_v4float_uint_2_0 %27 %28 +; CHECK: [[load:%\w+]] = OpLoad %_arr_v4float_uint_2 [[access_chain]] + %41 = OpLoad %_arr_v4float_uint_2_0 %40 +; CHECK: [[extract1:%\w+]] = OpCompositeExtract %v4float [[load]] 0 +; CHECK: [[extract2:%\w+]] = OpCompositeExtract %v4float [[load]] 1 +; CHECK: [[construct:%\w+]] = 
OpCompositeConstruct %_arr_v4float_uint_2_0 [[extract1]] [[extract2]] +; CHECK: OpStore %26 [[construct]] + OpStore %26 %41 + %42 = OpAccessChain %_ptr_Function_v4float %26 %28 + %43 = OpLoad %v4float %42 + OpStore %out_var_SV_Target %43 + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(text, false); +} + +// Test decomposing an object when we need to "rewrite" a store. +TEST_F(CopyPropArrayPassTest, DecomposeObjectForStructStore) { + const std::string text = + R"( OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target + OpExecutionMode %main OriginUpperLeft + OpSource HLSL 600 + OpName %type_MyCBuffer "type.MyCBuffer" + OpMemberName %type_MyCBuffer 0 "Data" + OpName %MyCBuffer "MyCBuffer" + OpName %main "main" + OpName %in_var_INDEX "in.var.INDEX" + OpName %out_var_SV_Target "out.var.SV_Target" + OpMemberDecorate %type_MyCBuffer 0 Offset 0 + OpDecorate %type_MyCBuffer Block + OpDecorate %in_var_INDEX Flat + OpDecorate %in_var_INDEX Location 0 + OpDecorate %out_var_SV_Target Location 0 + OpDecorate %MyCBuffer DescriptorSet 0 + OpDecorate %MyCBuffer Binding 0 +; CHECK: OpDecorate [[decorated_type:%\w+]] GLSLPacked + OpDecorate %struct GLSLPacked + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +; CHECK: [[decorated_type]] = OpTypeStruct +%struct = OpTypeStruct %float %uint +%_arr_struct_uint_2 = OpTypeArray %struct %uint_2 +%type_MyCBuffer = OpTypeStruct %_arr_struct_uint_2 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer + %void = OpTypeVoid + %14 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +; CHECK: [[struct:%\w+]] = OpTypeStruct 
%float %uint +%struct_0 = OpTypeStruct %float %uint +%_arr_struct_0_uint_2 = OpTypeArray %struct_0 %uint_2 +%_ptr_Function__arr_struct_0_uint_2 = OpTypePointer Function %_arr_struct_0_uint_2 + %int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_struct_uint_2 = OpTypePointer Uniform %_arr_struct_uint_2 +; CHECK: [[decorated_ptr:%\w+]] = OpTypePointer Uniform [[decorated_type]] +%_ptr_Function_struct_0 = OpTypePointer Function %struct_0 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output + %main = OpFunction %void None %14 + %25 = OpLabel + %26 = OpVariable %_ptr_Function_struct_0 Function + %27 = OpVariable %_ptr_Function__arr_struct_0_uint_2 Function + %28 = OpLoad %int %in_var_INDEX + %29 = OpAccessChain %_ptr_Uniform__arr_struct_uint_2 %MyCBuffer %int_0 + %30 = OpLoad %_arr_struct_uint_2 %29 + %31 = OpCompositeExtract %struct %30 0 + %32 = OpCompositeExtract %v4float %31 0 + %33 = OpCompositeExtract %v4float %31 1 + %34 = OpCompositeConstruct %struct_0 %32 %33 + %35 = OpCompositeExtract %struct %30 1 + %36 = OpCompositeExtract %float %35 0 + %37 = OpCompositeExtract %uint %35 1 + %38 = OpCompositeConstruct %struct_0 %36 %37 + %39 = OpCompositeConstruct %_arr_struct_0_uint_2 %34 %38 + OpStore %27 %39 +; CHECK: [[access_chain:%\w+]] = OpAccessChain [[decorated_ptr]] + %40 = OpAccessChain %_ptr_Function_struct_0 %27 %28 +; CHECK: [[load:%\w+]] = OpLoad [[decorated_type]] [[access_chain]] + %41 = OpLoad %struct_0 %40 +; CHECK: [[extract1:%\w+]] = OpCompositeExtract %float [[load]] 0 +; CHECK: [[extract2:%\w+]] = OpCompositeExtract %uint [[load]] 1 +; CHECK: [[construct:%\w+]] = OpCompositeConstruct [[struct]] [[extract1]] [[extract2]] +; CHECK: OpStore %26 [[construct]] + OpStore %26 %41 + %42 = OpAccessChain %_ptr_Function_v4float %26 %28 + %43 = OpLoad %v4float %42 + OpStore %out_var_SV_Target 
%43 + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(text, false); +} + +TEST_F(CopyPropArrayPassTest, CopyViaInserts) { + const std::string before = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +; CHECK: 
OpFunction +; CHECK: OpLabel +; CHECK: OpVariable +; CHECK: OpAccessChain +; CHECK: [[new_address:%\w+]] = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +; CHECK: [[element_ptr:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[new_address]] %24 +; CHECK: [[load:%\w+]] = OpLoad %v4float [[element_ptr]] +; CHECK: OpStore %out_var_SV_Target [[load]] +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%undef = OpUndef %_arr_v4float_uint_8_0 +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%i0 = OpCompositeInsert %_arr_v4float_uint_8_0 %27 %undef 0 +%28 = OpCompositeExtract %v4float %26 1 +%i1 = OpCompositeInsert %_arr_v4float_uint_8_0 %28 %i0 1 +%29 = OpCompositeExtract %v4float %26 2 +%i2 = OpCompositeInsert %_arr_v4float_uint_8_0 %29 %i1 2 +%30 = OpCompositeExtract %v4float %26 3 +%i3 = OpCompositeInsert %_arr_v4float_uint_8_0 %30 %i2 3 +%31 = OpCompositeExtract %v4float %26 4 +%i4 = OpCompositeInsert %_arr_v4float_uint_8_0 %31 %i3 4 +%32 = OpCompositeExtract %v4float %26 5 +%i5 = OpCompositeInsert %_arr_v4float_uint_8_0 %32 %i4 5 +%33 = OpCompositeExtract %v4float %26 6 +%i6 = OpCompositeInsert %_arr_v4float_uint_8_0 %33 %i5 6 +%34 = OpCompositeExtract %v4float %26 7 +%i7 = OpCompositeInsert %_arr_v4float_uint_8_0 %34 %i6 7 +OpStore %23 %i7 +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +OpStore %out_var_SV_Target %37 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(before, false); +} + +TEST_F(CopyPropArrayPassTest, IsomorphicTypes1) { + const std::string before = + R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[s1:%\w+]] = 
OpTypeStruct [[int]] +; CHECK: [[s2:%\w+]] = OpTypeStruct [[s1]] +; CHECK: [[a1:%\w+]] = OpTypeArray [[s2]] +; CHECK: [[s3:%\w+]] = OpTypeStruct [[a1]] +; CHECK: [[p_s3:%\w+]] = OpTypePointer Uniform [[s3]] +; CHECK: [[global_var:%\w+]] = OpVariable [[p_s3]] Uniform +; CHECK: [[p_a1:%\w+]] = OpTypePointer Uniform [[a1]] +; CHECK: [[p_s2:%\w+]] = OpTypePointer Uniform [[s2]] +; CHECK: [[ac1:%\w+]] = OpAccessChain [[p_a1]] [[global_var]] %uint_0 +; CHECK: [[ac2:%\w+]] = OpAccessChain [[p_s2]] [[ac1]] %uint_0 +; CHECK: [[ld:%\w+]] = OpLoad [[s2]] [[ac2]] +; CHECK: [[ex:%\w+]] = OpCompositeExtract [[s1]] [[ld]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "PS_main" + OpExecutionMode %2 OriginUpperLeft + OpSource HLSL 600 + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 101 + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 + %s1 = OpTypeStruct %uint + %s2 = OpTypeStruct %s1 +%a1 = OpTypeArray %s2 %uint_1 + %s3 = OpTypeStruct %a1 + %s1_1 = OpTypeStruct %uint +%_ptr_Uniform_uint = OpTypePointer Uniform %uint + %void = OpTypeVoid + %13 = OpTypeFunction %void + %uint_0 = OpConstant %uint 0 + %s1_0 = OpTypeStruct %uint + %s2_0 = OpTypeStruct %s1_0 +%a1_0 = OpTypeArray %s2_0 %uint_1 + %s3_0 = OpTypeStruct %a1_0 +%p_s3 = OpTypePointer Uniform %s3 +%p_s3_0 = OpTypePointer Function %s3_0 + %3 = OpVariable %p_s3 Uniform +%p_a1_0 = OpTypePointer Function %a1_0 +%p_s2_0 = OpTypePointer Function %s2_0 + %2 = OpFunction %void None %13 + %20 = OpLabel + %21 = OpVariable %p_a1_0 Function + %22 = OpLoad %s3 %3 + %23 = OpCompositeExtract %a1 %22 0 + %24 = OpCompositeExtract %s2 %23 0 + %25 = OpCompositeExtract %s1 %24 0 + %26 = OpCompositeExtract %uint %25 0 + %27 = OpCompositeConstruct %s1_0 %26 + %32 = OpCompositeConstruct %s2_0 %27 + %28 = OpCompositeConstruct %a1_0 %32 + OpStore %21 %28 + %29 = OpAccessChain %p_s2_0 %21 %uint_0 + %30 = OpLoad %s2 %29 + %31 = OpCompositeExtract %s1 %30 0 + OpReturn 
+ OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(before, false); +} + +TEST_F(CopyPropArrayPassTest, IsomorphicTypes2) { + const std::string before = + R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[s1:%\w+]] = OpTypeStruct [[int]] +; CHECK: [[s2:%\w+]] = OpTypeStruct [[s1]] +; CHECK: [[a1:%\w+]] = OpTypeArray [[s2]] +; CHECK: [[s3:%\w+]] = OpTypeStruct [[a1]] +; CHECK: [[p_s3:%\w+]] = OpTypePointer Uniform [[s3]] +; CHECK: [[global_var:%\w+]] = OpVariable [[p_s3]] Uniform +; CHECK: [[p_s2:%\w+]] = OpTypePointer Uniform [[s2]] +; CHECK: [[p_s1:%\w+]] = OpTypePointer Uniform [[s1]] +; CHECK: [[ac1:%\w+]] = OpAccessChain [[p_s2]] [[global_var]] %uint_0 %uint_0 +; CHECK: [[ac2:%\w+]] = OpAccessChain [[p_s1]] [[ac1]] %uint_0 +; CHECK: [[ld:%\w+]] = OpLoad [[s1]] [[ac2]] +; CHECK: [[ex:%\w+]] = OpCompositeExtract [[int]] [[ld]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "PS_main" + OpExecutionMode %2 OriginUpperLeft + OpSource HLSL 600 + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 101 + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 + %_struct_6 = OpTypeStruct %uint + %_struct_7 = OpTypeStruct %_struct_6 +%_arr__struct_7_uint_1 = OpTypeArray %_struct_7 %uint_1 + %_struct_9 = OpTypeStruct %_arr__struct_7_uint_1 + %_struct_10 = OpTypeStruct %uint +%_ptr_Uniform_uint = OpTypePointer Uniform %uint + %void = OpTypeVoid + %13 = OpTypeFunction %void + %uint_0 = OpConstant %uint 0 + %_struct_15 = OpTypeStruct %uint +%_arr__struct_15_uint_1 = OpTypeArray %_struct_15 %uint_1 +%_ptr_Uniform__struct_9 = OpTypePointer Uniform %_struct_9 +%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15 + %3 = OpVariable %_ptr_Uniform__struct_9 Uniform +%_ptr_Function__arr__struct_15_uint_1 = OpTypePointer Function 
%_arr__struct_15_uint_1 + %2 = OpFunction %void None %13 + %20 = OpLabel + %21 = OpVariable %_ptr_Function__arr__struct_15_uint_1 Function + %22 = OpLoad %_struct_9 %3 + %23 = OpCompositeExtract %_arr__struct_7_uint_1 %22 0 + %24 = OpCompositeExtract %_struct_7 %23 0 + %25 = OpCompositeExtract %_struct_6 %24 0 + %26 = OpCompositeExtract %uint %25 0 + %27 = OpCompositeConstruct %_struct_15 %26 + %28 = OpCompositeConstruct %_arr__struct_15_uint_1 %27 + OpStore %21 %28 + %29 = OpAccessChain %_ptr_Function__struct_15 %21 %uint_0 + %30 = OpLoad %_struct_15 %29 + %31 = OpCompositeExtract %uint %30 0 + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(before, false); +} + +TEST_F(CopyPropArrayPassTest, IsomorphicTypes3) { + const std::string before = + R"( +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[s1:%\w+]] = OpTypeStruct [[int]] +; CHECK: [[s2:%\w+]] = OpTypeStruct [[s1]] +; CHECK: [[a1:%\w+]] = OpTypeArray [[s2]] +; CHECK: [[s3:%\w+]] = OpTypeStruct [[a1]] +; CHECK: [[s1_1:%\w+]] = OpTypeStruct [[int]] +; CHECK: [[p_s3:%\w+]] = OpTypePointer Uniform [[s3]] +; CHECK: [[p_s1_1:%\w+]] = OpTypePointer Function [[s1_1]] +; CHECK: [[global_var:%\w+]] = OpVariable [[p_s3]] Uniform +; CHECK: [[p_s2:%\w+]] = OpTypePointer Uniform [[s2]] +; CHECK: [[p_s1:%\w+]] = OpTypePointer Uniform [[s1]] +; CHECK: [[var:%\w+]] = OpVariable [[p_s1_1]] Function +; CHECK: [[ac1:%\w+]] = OpAccessChain [[p_s2]] [[global_var]] %uint_0 %uint_0 +; CHECK: [[ac2:%\w+]] = OpAccessChain [[p_s1]] [[ac1]] %uint_0 +; CHECK: [[ld:%\w+]] = OpLoad [[s1]] [[ac2]] +; CHECK: [[ex:%\w+]] = OpCompositeExtract [[int]] [[ld]] +; CHECK: [[copy:%\w+]] = OpCompositeConstruct [[s1_1]] [[ex]] +; CHECK: OpStore [[var]] [[copy]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint 
Fragment %2 "PS_main" + OpExecutionMode %2 OriginUpperLeft + OpSource HLSL 600 + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 101 + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 + %_struct_6 = OpTypeStruct %uint + %_struct_7 = OpTypeStruct %_struct_6 +%_arr__struct_7_uint_1 = OpTypeArray %_struct_7 %uint_1 + %_struct_9 = OpTypeStruct %_arr__struct_7_uint_1 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint + %void = OpTypeVoid + %13 = OpTypeFunction %void + %uint_0 = OpConstant %uint 0 + %_struct_15 = OpTypeStruct %uint + %_struct_10 = OpTypeStruct %uint +%_arr__struct_15_uint_1 = OpTypeArray %_struct_15 %uint_1 +%_ptr_Uniform__struct_9 = OpTypePointer Uniform %_struct_9 +%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15 + %3 = OpVariable %_ptr_Uniform__struct_9 Uniform +%_ptr_Function__arr__struct_15_uint_1 = OpTypePointer Function %_arr__struct_15_uint_1 + %2 = OpFunction %void None %13 + %20 = OpLabel + %21 = OpVariable %_ptr_Function__arr__struct_15_uint_1 Function + %var = OpVariable %_ptr_Function__struct_15 Function + %22 = OpLoad %_struct_9 %3 + %23 = OpCompositeExtract %_arr__struct_7_uint_1 %22 0 + %24 = OpCompositeExtract %_struct_7 %23 0 + %25 = OpCompositeExtract %_struct_6 %24 0 + %26 = OpCompositeExtract %uint %25 0 + %27 = OpCompositeConstruct %_struct_15 %26 + %28 = OpCompositeConstruct %_arr__struct_15_uint_1 %27 + OpStore %21 %28 + %29 = OpAccessChain %_ptr_Function__struct_15 %21 %uint_0 + %30 = OpLoad %_struct_15 %29 + OpStore %var %30 + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(before, false); +} + +TEST_F(CopyPropArrayPassTest, BadMergingTwoObjects) { + // The second element in the |OpCompositeConstruct| is an OpFNegate result, + // not a value copied from the object.
+ const std::string text = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpName %type_ConstBuf "type.ConstBuf" +OpMemberName %type_ConstBuf 0 "TexSizeU" +OpMemberName %type_ConstBuf 1 "TexSizeV" +OpName %ConstBuf "ConstBuf" +OpName %main "main" +OpMemberDecorate %type_ConstBuf 0 Offset 0 +OpMemberDecorate %type_ConstBuf 1 Offset 8 +OpDecorate %type_ConstBuf Block +OpDecorate %ConstBuf DescriptorSet 0 +OpDecorate %ConstBuf Binding 2 +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%type_ConstBuf = OpTypeStruct %v2float %v2float +%_ptr_Uniform_type_ConstBuf = OpTypePointer Uniform %type_ConstBuf +%void = OpTypeVoid +%9 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%int_0 = OpConstant %uint 0 +%uint_2 = OpConstant %uint 2 +%_arr_v2float_uint_2 = OpTypeArray %v2float %uint_2 +%_ptr_Function__arr_v2float_uint_2 = OpTypePointer Function %_arr_v2float_uint_2 +%_ptr_Uniform_v2float = OpTypePointer Uniform %v2float +%ConstBuf = OpVariable %_ptr_Uniform_type_ConstBuf Uniform +%main = OpFunction %void None %9 +%24 = OpLabel +%25 = OpVariable %_ptr_Function__arr_v2float_uint_2 Function +%27 = OpAccessChain %_ptr_Uniform_v2float %ConstBuf %int_0 +%28 = OpLoad %v2float %27 +%29 = OpAccessChain %_ptr_Uniform_v2float %ConstBuf %int_0 +%30 = OpLoad %v2float %29 +%31 = OpFNegate %v2float %30 +%37 = OpCompositeConstruct %_arr_v2float_uint_2 %28 %31 +OpStore %25 %37 +OpReturn +OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CopyPropArrayPassTest, SecondElementNotContained) { + // The second element in the |OpCompositeConstruct| is loaded from a + // different variable (%ConstBuf2). Make sure no change happens.
+ const std::string text = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpName %type_ConstBuf "type.ConstBuf" +OpMemberName %type_ConstBuf 0 "TexSizeU" +OpMemberName %type_ConstBuf 1 "TexSizeV" +OpName %ConstBuf "ConstBuf" +OpName %main "main" +OpMemberDecorate %type_ConstBuf 0 Offset 0 +OpMemberDecorate %type_ConstBuf 1 Offset 8 +OpDecorate %type_ConstBuf Block +OpDecorate %ConstBuf DescriptorSet 0 +OpDecorate %ConstBuf Binding 2 +OpDecorate %ConstBuf2 DescriptorSet 1 +OpDecorate %ConstBuf2 Binding 2 +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%type_ConstBuf = OpTypeStruct %v2float %v2float +%_ptr_Uniform_type_ConstBuf = OpTypePointer Uniform %type_ConstBuf +%void = OpTypeVoid +%9 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%int_0 = OpConstant %uint 0 +%int_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%_arr_v2float_uint_2 = OpTypeArray %v2float %uint_2 +%_ptr_Function__arr_v2float_uint_2 = OpTypePointer Function %_arr_v2float_uint_2 +%_ptr_Uniform_v2float = OpTypePointer Uniform %v2float +%ConstBuf = OpVariable %_ptr_Uniform_type_ConstBuf Uniform +%ConstBuf2 = OpVariable %_ptr_Uniform_type_ConstBuf Uniform +%main = OpFunction %void None %9 +%24 = OpLabel +%25 = OpVariable %_ptr_Function__arr_v2float_uint_2 Function +%27 = OpAccessChain %_ptr_Uniform_v2float %ConstBuf %int_0 +%28 = OpLoad %v2float %27 +%29 = OpAccessChain %_ptr_Uniform_v2float %ConstBuf2 %int_1 +%30 = OpLoad %v2float %29 +%37 = OpCompositeConstruct %_arr_v2float_uint_2 %28 %30 +OpStore %25 %37 +OpReturn +OpFunctionEnd +)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} +// This test will place a load before the store. We cannot propagate in this +// case. 
+TEST_F(CopyPropArrayPassTest, LoadBeforeStore) { + const std::string text = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%38 = OpAccessChain %_ptr_Function_v4float %23 %24 +%39 = OpLoad %v4float %36 +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain 
%_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%28 = OpCompositeExtract %v4float %26 1 +%29 = OpCompositeExtract %v4float %26 2 +%30 = OpCompositeExtract %v4float %26 3 +%31 = OpCompositeExtract %v4float %26 4 +%32 = OpCompositeExtract %v4float %26 5 +%33 = OpCompositeExtract %v4float %26 6 +%34 = OpCompositeExtract %v4float %26 7 +%35 = OpCompositeConstruct %_arr_v4float_uint_8_0 %27 %28 %29 %30 %31 %32 %33 %34 +OpStore %23 %35 +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +OpStore %out_var_SV_Target %37 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +// This test will place a load where it is not dominated by the store. We +// cannot propagate in this case. 
+TEST_F(CopyPropArrayPassTest, LoadNotDominated) { + const std::string text = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +OpSelectionMerge %merge None +OpBranchConditional %true %if %else +%if = OpLabel +%24 = OpLoad %int 
%in_var_INDEX +%25 = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%28 = OpCompositeExtract %v4float %26 1 +%29 = OpCompositeExtract %v4float %26 2 +%30 = OpCompositeExtract %v4float %26 3 +%31 = OpCompositeExtract %v4float %26 4 +%32 = OpCompositeExtract %v4float %26 5 +%33 = OpCompositeExtract %v4float %26 6 +%34 = OpCompositeExtract %v4float %26 7 +%35 = OpCompositeConstruct %_arr_v4float_uint_8_0 %27 %28 %29 %30 %31 %32 %33 %34 +OpStore %23 %35 +%38 = OpAccessChain %_ptr_Function_v4float %23 %24 +%39 = OpLoad %v4float %36 +OpBranch %merge +%else = OpLabel +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +OpBranch %merge +%merge = OpLabel +%phi = OpPhi %out_var_SV_Target %39 %if %37 %else +OpStore %out_var_SV_Target %phi +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +// This test has a partial store to the variable. We cannot propagate in this +// case. 
+TEST_F(CopyPropArrayPassTest, PartialStore) { + const std::string text = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%f0 = OpConstant %float 0 +%v4const = OpConstantComposite %v4float %f0 %f0 %f0 %f0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain 
%_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%28 = OpCompositeExtract %v4float %26 1 +%29 = OpCompositeExtract %v4float %26 2 +%30 = OpCompositeExtract %v4float %26 3 +%31 = OpCompositeExtract %v4float %26 4 +%32 = OpCompositeExtract %v4float %26 5 +%33 = OpCompositeExtract %v4float %26 6 +%34 = OpCompositeExtract %v4float %26 7 +%35 = OpCompositeConstruct %_arr_v4float_uint_8_0 %27 %28 %29 %30 %31 %32 %33 %34 +OpStore %23 %35 +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +%39 = OpStore %36 %v4const +OpStore %out_var_SV_Target %37 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +// This test does not have a proper copy of an object. We cannot propagate in +// this case. 
+TEST_F(CopyPropArrayPassTest, NotACopy) { + const std::string text = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%f0 = OpConstant %float 0 +%v4const = OpConstantComposite %v4float %f0 %f0 %f0 %f0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain 
%_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%28 = OpCompositeExtract %v4float %26 0 +%29 = OpCompositeExtract %v4float %26 2 +%30 = OpCompositeExtract %v4float %26 3 +%31 = OpCompositeExtract %v4float %26 4 +%32 = OpCompositeExtract %v4float %26 5 +%33 = OpCompositeExtract %v4float %26 6 +%34 = OpCompositeExtract %v4float %26 7 +%35 = OpCompositeConstruct %_arr_v4float_uint_8_0 %27 %28 %29 %30 %31 %32 %33 %34 +OpStore %23 %35 +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +OpStore %out_var_SV_Target %37 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CopyPropArrayPassTest, BadCopyViaInserts1) { + const std::string text = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = 
OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%undef = OpUndef %_arr_v4float_uint_8_0 +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%i0 = OpCompositeInsert %_arr_v4float_uint_8_0 %27 %undef 0 +%28 = OpCompositeExtract %v4float %26 1 +%i1 = OpCompositeInsert %_arr_v4float_uint_8_0 %28 %i0 1 +%29 = OpCompositeExtract %v4float %26 2 +%i2 = OpCompositeInsert %_arr_v4float_uint_8_0 %29 %i1 3 +%30 = OpCompositeExtract %v4float %26 3 +%i3 = OpCompositeInsert %_arr_v4float_uint_8_0 %30 %i2 3 +%31 = OpCompositeExtract %v4float %26 4 +%i4 = OpCompositeInsert %_arr_v4float_uint_8_0 %31 %i3 4 +%32 = OpCompositeExtract %v4float %26 5 +%i5 = OpCompositeInsert %_arr_v4float_uint_8_0 %32 %i4 5 +%33 = OpCompositeExtract %v4float %26 6 +%i6 = OpCompositeInsert %_arr_v4float_uint_8_0 %33 %i5 6 +%34 = OpCompositeExtract %v4float %26 7 +%i7 = OpCompositeInsert %_arr_v4float_uint_8_0 %34 %i6 7 +OpStore %23 %i7 +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +OpStore %out_var_SV_Target %37 +OpReturn 
+OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CopyPropArrayPassTest, BadCopyViaInserts2) { + const std::string text = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName %type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform 
+%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%undef = OpUndef %_arr_v4float_uint_8_0 +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%27 = OpCompositeExtract %v4float %26 0 +%i0 = OpCompositeInsert %_arr_v4float_uint_8_0 %27 %undef 0 +%28 = OpCompositeExtract %v4float %26 1 +%i1 = OpCompositeInsert %_arr_v4float_uint_8_0 %28 %i0 1 +%29 = OpCompositeExtract %v4float %26 3 +%i2 = OpCompositeInsert %_arr_v4float_uint_8_0 %29 %i1 2 +%30 = OpCompositeExtract %v4float %26 3 +%i3 = OpCompositeInsert %_arr_v4float_uint_8_0 %30 %i2 3 +%31 = OpCompositeExtract %v4float %26 4 +%i4 = OpCompositeInsert %_arr_v4float_uint_8_0 %31 %i3 4 +%32 = OpCompositeExtract %v4float %26 5 +%i5 = OpCompositeInsert %_arr_v4float_uint_8_0 %32 %i4 5 +%33 = OpCompositeExtract %v4float %26 6 +%i6 = OpCompositeInsert %_arr_v4float_uint_8_0 %33 %i5 6 +%34 = OpCompositeExtract %v4float %26 7 +%i7 = OpCompositeInsert %_arr_v4float_uint_8_0 %34 %i6 7 +OpStore %23 %i7 +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +OpStore %out_var_SV_Target %37 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CopyPropArrayPassTest, BadCopyViaInserts3) { + const std::string text = + R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %in_var_INDEX %out_var_SV_Target +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 600 +OpName 
%type_MyCBuffer "type.MyCBuffer" +OpMemberName %type_MyCBuffer 0 "Data" +OpName %MyCBuffer "MyCBuffer" +OpName %main "main" +OpName %in_var_INDEX "in.var.INDEX" +OpName %out_var_SV_Target "out.var.SV_Target" +OpDecorate %_arr_v4float_uint_8 ArrayStride 16 +OpMemberDecorate %type_MyCBuffer 0 Offset 0 +OpDecorate %type_MyCBuffer Block +OpDecorate %in_var_INDEX Flat +OpDecorate %in_var_INDEX Location 0 +OpDecorate %out_var_SV_Target Location 0 +OpDecorate %MyCBuffer DescriptorSet 0 +OpDecorate %MyCBuffer Binding 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_8 = OpConstant %uint 8 +%_arr_v4float_uint_8 = OpTypeArray %v4float %uint_8 +%type_MyCBuffer = OpTypeStruct %_arr_v4float_uint_8 +%_ptr_Uniform_type_MyCBuffer = OpTypePointer Uniform %type_MyCBuffer +%void = OpTypeVoid +%13 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v4float_uint_8_0 = OpTypeArray %v4float %uint_8 +%_ptr_Function__arr_v4float_uint_8_0 = OpTypePointer Function %_arr_v4float_uint_8_0 +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_v4float_uint_8 = OpTypePointer Uniform %_arr_v4float_uint_8 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%MyCBuffer = OpVariable %_ptr_Uniform_type_MyCBuffer Uniform +%in_var_INDEX = OpVariable %_ptr_Input_int Input +%out_var_SV_Target = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %13 +%22 = OpLabel +%23 = OpVariable %_ptr_Function__arr_v4float_uint_8_0 Function +%undef = OpUndef %_arr_v4float_uint_8_0 +%24 = OpLoad %int %in_var_INDEX +%25 = OpAccessChain %_ptr_Uniform__arr_v4float_uint_8 %MyCBuffer %int_0 +%26 = OpLoad %_arr_v4float_uint_8 %25 +%28 = OpCompositeExtract %v4float %26 1 +%i1 = OpCompositeInsert %_arr_v4float_uint_8_0 %28 %undef 1 +%29 = OpCompositeExtract %v4float %26 2 +%i2 = OpCompositeInsert %_arr_v4float_uint_8_0 %29 %i1 2 +%30 = OpCompositeExtract %v4float %26 3 
+%i3 = OpCompositeInsert %_arr_v4float_uint_8_0 %30 %i2 3 +%31 = OpCompositeExtract %v4float %26 4 +%i4 = OpCompositeInsert %_arr_v4float_uint_8_0 %31 %i3 4 +%32 = OpCompositeExtract %v4float %26 5 +%i5 = OpCompositeInsert %_arr_v4float_uint_8_0 %32 %i4 5 +%33 = OpCompositeExtract %v4float %26 6 +%i6 = OpCompositeInsert %_arr_v4float_uint_8_0 %33 %i5 6 +%34 = OpCompositeExtract %v4float %26 7 +%i7 = OpCompositeInsert %_arr_v4float_uint_8_0 %34 %i6 7 +OpStore %23 %i7 +%36 = OpAccessChain %_ptr_Function_v4float %23 %24 +%37 = OpLoad %v4float %36 +OpStore %out_var_SV_Target %37 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(CopyPropArrayPassTest, AtomicAdd) { + const std::string before = R"(OpCapability SampledBuffer +OpCapability StorageImageExtendedFormats +OpCapability ImageBuffer +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %2 "min" %gl_GlobalInvocationID +OpExecutionMode %2 LocalSize 64 1 1 +OpSource HLSL 600 +OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId +OpDecorate %4 DescriptorSet 4 +OpDecorate %4 Binding 70 +%uint = OpTypeInt 32 0 +%6 = OpTypeImage %uint Buffer 0 0 0 2 R32ui +%_ptr_UniformConstant_6 = OpTypePointer UniformConstant %6 +%_ptr_Function_6 = OpTypePointer Function %6 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%v3uint = OpTypeVector %uint 3 +%_ptr_Input_v3uint = OpTypePointer Input %v3uint +%_ptr_Image_uint = OpTypePointer Image %uint +%4 = OpVariable %_ptr_UniformConstant_6 UniformConstant +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input +%2 
= OpFunction %void None %10 +%17 = OpLabel +%16 = OpVariable %_ptr_Function_6 Function +%18 = OpLoad %6 %4 +OpStore %16 %18 +%19 = OpImageTexelPointer %_ptr_Image_uint %16 %uint_0 %uint_0 +%20 = OpAtomicIAdd %uint %19 %uint_1 %uint_0 %uint_1 +OpReturn +OpFunctionEnd +)"; + + const std::string after = R"(OpCapability SampledBuffer +OpCapability StorageImageExtendedFormats +OpCapability ImageBuffer +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %2 "min" %gl_GlobalInvocationID +OpExecutionMode %2 LocalSize 64 1 1 +OpSource HLSL 600 +OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId +OpDecorate %4 DescriptorSet 4 +OpDecorate %4 Binding 70 +%uint = OpTypeInt 32 0 +%6 = OpTypeImage %uint Buffer 0 0 0 2 R32ui +%_ptr_UniformConstant_6 = OpTypePointer UniformConstant %6 +%_ptr_Function_6 = OpTypePointer Function %6 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%v3uint = OpTypeVector %uint 3 +%_ptr_Input_v3uint = OpTypePointer Input %v3uint +%_ptr_Image_uint = OpTypePointer Image %uint +%4 = OpVariable %_ptr_UniformConstant_6 UniformConstant +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input +%2 = OpFunction %void None %10 +%17 = OpLabel +%16 = OpVariable %_ptr_Function_6 Function +%18 = OpLoad %6 %4 +OpStore %16 %18 +%19 = OpImageTexelPointer %_ptr_Image_uint %4 %uint_0 %uint_0 +%20 = OpAtomicIAdd %uint %19 %uint_1 %uint_0 %uint_1 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dead_branch_elim_test.cpp b/test/opt/dead_branch_elim_test.cpp new file mode 100644 index 000000000..66e3bddf7 --- /dev/null +++ b/test/opt/dead_branch_elim_test.cpp @@ -0,0 +1,2878 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using DeadBranchElimTest = PassTest<::testing::Test>; + +TEST_F(DeadBranchElimTest, IfThenElseTrue) { + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v; + // if (true) + // v = vec4(0.0,0.0,0.0,0.0); + // else + // v = vec4(1.0,1.0,1.0,1.0); + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %gl_FragColor "gl_FragColor" +OpName %BaseColor "BaseColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%float_0 = OpConstant %float 0 +%14 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_1 = OpConstant %float 1 +%16 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input 
+)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +OpSelectionMerge %20 None +OpBranchConditional %true %21 %22 +%21 = OpLabel +OpStore %v %14 +OpBranch %20 +%22 = OpLabel +OpStore %v %16 +OpBranch %20 +%20 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +OpBranch %21 +%21 = OpLabel +OpStore %v %14 +OpBranch %20 +%20 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, IfThenElseFalse) { + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v; + // if (false) + // v = vec4(0.0,0.0,0.0,0.0); + // else + // v = vec4(1.0,1.0,1.0,1.0); + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %gl_FragColor "gl_FragColor" +OpName %BaseColor "BaseColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%bool = OpTypeBool +%false = OpConstantFalse %bool +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%float_0 = OpConstant %float 0 +%14 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_1 = OpConstant %float 1 +%16 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input 
+)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +OpSelectionMerge %20 None +OpBranchConditional %false %21 %22 +%21 = OpLabel +OpStore %v %14 +OpBranch %20 +%22 = OpLabel +OpStore %v %16 +OpBranch %20 +%20 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +OpBranch %22 +%22 = OpLabel +OpStore %v %16 +OpBranch %20 +%20 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, IfThenTrue) { + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // if (true) + // v = v * vec4(0.5,0.5,0.5,0.5); + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%bool = OpTypeBool +%true = OpConstantTrue %bool +%float_0_5 = OpConstant %float 0.5 +%15 = OpConstantComposite %v4float %float_0_5 %float_0_5 %float_0_5 %float_0_5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%17 = OpLabel +%v = OpVariable 
%_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +OpSelectionMerge %19 None +OpBranchConditional %true %20 %19 +%20 = OpLabel +%21 = OpLoad %v4float %v +%22 = OpFMul %v4float %21 %15 +OpStore %v %22 +OpBranch %19 +%19 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +OpBranch %20 +%20 = OpLabel +%21 = OpLoad %v4float %v +%22 = OpFMul %v4float %21 %15 +OpStore %v %22 +OpBranch %19 +%19 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, IfThenFalse) { + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // if (false) + // v = v * vec4(0.5,0.5,0.5,0.5); + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%bool = OpTypeBool +%false = OpConstantFalse %bool +%float_0_5 = OpConstant %float 0.5 +%15 = OpConstantComposite %v4float %float_0_5 %float_0_5 %float_0_5 %float_0_5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction 
%void None %7 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +OpSelectionMerge %19 None +OpBranchConditional %false %20 %19 +%20 = OpLabel +%21 = OpLoad %v4float %v +%22 = OpFMul %v4float %21 %15 +OpStore %v %22 +OpBranch %19 +%19 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +OpBranch %19 +%19 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, IfThenElsePhiTrue) { + // Test handling of phi in merge block after dead branch elimination. + // Note: The SPIR-V has had store/load elimination and phi insertion + // + // #version 140 + // + // void main() + // { + // vec4 v; + // if (true) + // v = vec4(0.0,0.0,0.0,0.0); + // else + // v = vec4(1.0,1.0,1.0,1.0); + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%float_0 = OpConstant %float 0 +%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_1 = OpConstant %float 1 +%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = 
OpTypePointer Input %v4float +)"; + + const std::string before = + R"(%main = OpFunction %void None %5 +%17 = OpLabel +OpSelectionMerge %18 None +OpBranchConditional %true %19 %20 +%19 = OpLabel +OpBranch %18 +%20 = OpLabel +OpBranch %18 +%18 = OpLabel +%21 = OpPhi %v4float %12 %19 %14 %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %5 +%17 = OpLabel +OpBranch %19 +%19 = OpLabel +OpBranch %18 +%18 = OpLabel +OpStore %gl_FragColor %12 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, IfThenElsePhiFalse) { + // Test handling of phi in merge block after dead branch elimination. + // Note: The SPIR-V has had store/load elimination and phi insertion + // + // #version 140 + // + // void main() + // { + // vec4 v; + // if (false) + // v = vec4(0.0,0.0,0.0,0.0); + // else + // v = vec4(1.0,1.0,1.0,1.0); + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%bool = OpTypeBool +%false = OpConstantFalse %bool +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%float_0 = OpConstant %float 0 +%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_1 = OpConstant %float 1 +%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +)"; + + const std::string before = + R"(%main = OpFunction %void None %5 +%17 = OpLabel +OpSelectionMerge %18 None 
+OpBranchConditional %false %19 %20 +%19 = OpLabel +OpBranch %18 +%20 = OpLabel +OpBranch %18 +%18 = OpLabel +%21 = OpPhi %v4float %12 %19 %14 %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %5 +%17 = OpLabel +OpBranch %20 +%20 = OpLabel +OpBranch %18 +%18 = OpLabel +OpStore %gl_FragColor %14 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, CompoundIfThenElseFalse) { + // #version 140 + // + // layout(std140) uniform U_t + // { + // bool g_B ; + // } ; + // + // void main() + // { + // vec4 v; + // if (false) { + // if (g_B) + // v = vec4(0.0,0.0,0.0,0.0); + // else + // v = vec4(1.0,1.0,1.0,1.0); + // } else { + // if (g_B) + // v = vec4(1.0,1.0,1.0,1.0); + // else + // v = vec4(0.0,0.0,0.0,0.0); + // } + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_B" +OpName %_ "" +OpName %v "v" +OpName %gl_FragColor "gl_FragColor" +OpMemberDecorate %U_t 0 Offset 0 +OpDecorate %U_t Block +OpDecorate %_ DescriptorSet 0 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%bool = OpTypeBool +%false = OpConstantFalse %bool +%uint = OpTypeInt 32 0 +%U_t = OpTypeStruct %uint +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%uint_0 = OpConstant %uint 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%float_0 = OpConstant %float 0 +%21 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_1 = OpConstant %float 1 +%23 = 
OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%25 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +OpSelectionMerge %26 None +OpBranchConditional %false %27 %28 +%27 = OpLabel +%29 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%30 = OpLoad %uint %29 +%31 = OpINotEqual %bool %30 %uint_0 +OpSelectionMerge %32 None +OpBranchConditional %31 %33 %34 +%33 = OpLabel +OpStore %v %21 +OpBranch %32 +%34 = OpLabel +OpStore %v %23 +OpBranch %32 +%32 = OpLabel +OpBranch %26 +%28 = OpLabel +%35 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%36 = OpLoad %uint %35 +%37 = OpINotEqual %bool %36 %uint_0 +OpSelectionMerge %38 None +OpBranchConditional %37 %39 %40 +%39 = OpLabel +OpStore %v %23 +OpBranch %38 +%40 = OpLabel +OpStore %v %21 +OpBranch %38 +%38 = OpLabel +OpBranch %26 +%26 = OpLabel +%41 = OpLoad %v4float %v +OpStore %gl_FragColor %41 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%25 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +OpBranch %28 +%28 = OpLabel +%35 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%36 = OpLoad %uint %35 +%37 = OpINotEqual %bool %36 %uint_0 +OpSelectionMerge %38 None +OpBranchConditional %37 %39 %40 +%40 = OpLabel +OpStore %v %21 +OpBranch %38 +%39 = OpLabel +OpStore %v %23 +OpBranch %38 +%38 = OpLabel +OpBranch %26 +%26 = OpLabel +%41 = OpLoad %v4float %v +OpStore %gl_FragColor %41 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, PreventOrphanMerge) { + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource 
GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%bool = OpTypeBool +%true = OpConstantTrue %bool +%float_0_5 = OpConstant %float 0.5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%16 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%17 = OpLoad %v4float %BaseColor +OpStore %v %17 +OpSelectionMerge %18 None +OpBranchConditional %true %19 %20 +%19 = OpLabel +OpKill +%20 = OpLabel +%21 = OpLoad %v4float %v +%22 = OpVectorTimesScalar %v4float %21 %float_0_5 +OpStore %v %22 +OpBranch %18 +%18 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%16 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%17 = OpLoad %v4float %BaseColor +OpStore %v %17 +OpBranch %19 +%19 = OpLabel +OpKill +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, HandleOrphanMerge) { + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_ "foo(" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %gl_FragColor Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%9 = OpTypeFunction %v4float +%bool = OpTypeBool +%true = OpConstantTrue %bool +%float_0 = OpConstant 
%float 0 +%13 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_1 = OpConstant %float 1 +%15 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %6 +%17 = OpLabel +%18 = OpFunctionCall %v4float %foo_ +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%foo_ = OpFunction %v4float None %9 +%19 = OpLabel +OpSelectionMerge %20 None +OpBranchConditional %true %21 %22 +%21 = OpLabel +OpReturnValue %13 +%22 = OpLabel +OpReturnValue %15 +%20 = OpLabel +%23 = OpUndef %v4float +OpReturnValue %23 +OpFunctionEnd +)"; + + const std::string after = + R"(%foo_ = OpFunction %v4float None %9 +%19 = OpLabel +OpBranch %21 +%21 = OpLabel +OpReturnValue %13 +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, KeepContinueTargetWhenKillAfterMerge) { + // #version 450 + // void main() { + // bool c; + // bool d; + // while(c) { + // if(d) { + // continue; + // } + // if(false) { + // continue; + // } + // discard; + // } + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %c "c" +OpName %d "d" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool +%false = OpConstantFalse %bool +)"; + + const std::string before = + R"(%main = OpFunction %void None %6 +%10 = OpLabel +%c = OpVariable %_ptr_Function_bool Function +%d = OpVariable %_ptr_Function_bool Function +OpBranch %11 +%11 = OpLabel +OpLoopMerge %12 %13 None +OpBranch %14 +%14 = OpLabel +%15 = OpLoad %bool %c +OpBranchConditional %15 %16 %12 +%16 = OpLabel +%17 = OpLoad %bool %d 
+OpSelectionMerge %18 None +OpBranchConditional %17 %19 %18 +%19 = OpLabel +OpBranch %13 +%18 = OpLabel +OpSelectionMerge %20 None +OpBranchConditional %false %21 %20 +%21 = OpLabel +OpBranch %13 +%20 = OpLabel +OpKill +%13 = OpLabel +OpBranch %11 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %6 +%10 = OpLabel +%c = OpVariable %_ptr_Function_bool Function +%d = OpVariable %_ptr_Function_bool Function +OpBranch %11 +%11 = OpLabel +OpLoopMerge %12 %13 None +OpBranch %14 +%14 = OpLabel +%15 = OpLoad %bool %c +OpBranchConditional %15 %16 %12 +%16 = OpLabel +%17 = OpLoad %bool %d +OpSelectionMerge %18 None +OpBranchConditional %17 %19 %18 +%19 = OpLabel +OpBranch %13 +%18 = OpLabel +OpBranch %20 +%20 = OpLabel +OpKill +%13 = OpLabel +OpBranch %11 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, DecorateDeleted) { + // Note: SPIR-V hand-edited to add decoration + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // if (false) + // v = v * vec4(0.5,0.5,0.5,0.5); + // gl_FragColor = v; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %22 RelaxedPrecision +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%bool = OpTypeBool +%false = OpConstantFalse %bool +%float_0_5 = OpConstant %float 0.5 +%15 = OpConstantComposite %v4float 
%float_0_5 %float_0_5 %float_0_5 %float_0_5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%bool = OpTypeBool +%false = OpConstantFalse %bool +%float_0_5 = OpConstant %float 0.5 +%16 = OpConstantComposite %v4float %float_0_5 %float_0_5 %float_0_5 %float_0_5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +OpSelectionMerge %19 None +OpBranchConditional %false %20 %19 +%20 = OpLabel +%21 = OpLoad %v4float %v +%22 = OpFMul %v4float %21 %15 +OpStore %v %22 +OpBranch %19 +%19 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%18 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%19 = OpLoad %v4float %BaseColor +OpStore %v %19 +OpBranch %20 +%20 = OpLabel +%23 = OpLoad %v4float %v +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs_before + before, + predefs_after + after, true, true); +} + +TEST_F(DeadBranchElimTest, LoopInDeadBranch) { + // #version 450 + // + // layout(location = 0) in 
vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // vec4 v = BaseColor; + // if (false) + // for (int i=0; i<3; i++) + // v = v * 0.5; + // OutColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %i "i" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%bool = OpTypeBool +%false = OpConstantFalse %bool +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_3 = OpConstant %int 3 +%float_0_5 = OpConstant %float 0.5 +%int_1 = OpConstant %int 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%22 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%i = OpVariable %_ptr_Function_int Function +%23 = OpLoad %v4float %BaseColor +OpStore %v %23 +OpSelectionMerge %24 None +OpBranchConditional %false %25 %24 +%25 = OpLabel +OpStore %i %int_0 +OpBranch %26 +%26 = OpLabel +OpLoopMerge %27 %28 None +OpBranch %29 +%29 = OpLabel +%30 = OpLoad %int %i +%31 = OpSLessThan %bool %30 %int_3 +OpBranchConditional %31 %32 %27 +%32 = OpLabel +%33 = OpLoad %v4float %v +%34 = OpVectorTimesScalar %v4float %33 %float_0_5 +OpStore %v %34 +OpBranch %28 +%28 = OpLabel +%35 = OpLoad %int %i +%36 = OpIAdd %int %35 %int_1 +OpStore %i %36 +OpBranch %26 +%27 = OpLabel +OpBranch %24 +%24 = 
OpLabel +%37 = OpLoad %v4float %v +OpStore %OutColor %37 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%22 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%i = OpVariable %_ptr_Function_int Function +%23 = OpLoad %v4float %BaseColor +OpStore %v %23 +OpBranch %24 +%24 = OpLabel +%37 = OpLoad %v4float %v +OpStore %OutColor %37 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, SwitchLiveCase) { + // #version 450 + // + // layout (location=0) in vec4 BaseColor; + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // switch (1) { + // case 0: + // OutColor = vec4(0.0,0.0,0.0,0.0); + // break; + // case 1: + // OutColor = vec4(0.125,0.125,0.125,0.125); + // break; + // case 2: + // OutColor = vec4(0.25,0.25,0.25,0.25); + // break; + // default: + // OutColor = vec4(1.0,1.0,1.0,1.0); + // } + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %OutColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %OutColor "OutColor" +OpName %BaseColor "BaseColor" +OpDecorate %OutColor Location 0 +OpDecorate %BaseColor Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +%13 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_0_125 = OpConstant %float 0.125 +%15 = OpConstantComposite %v4float %float_0_125 %float_0_125 %float_0_125 %float_0_125 +%float_0_25 = OpConstant %float 0.25 +%17 = OpConstantComposite %v4float %float_0_25 %float_0_25 %float_0_25 %float_0_25 +%float_1 = 
OpConstant %float 1 +%19 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +)"; + + const std::string before = + R"(%main = OpFunction %void None %6 +%21 = OpLabel +OpSelectionMerge %22 None +OpSwitch %int_1 %23 0 %24 1 %25 2 %26 +%23 = OpLabel +OpStore %OutColor %19 +OpBranch %22 +%24 = OpLabel +OpStore %OutColor %13 +OpBranch %22 +%25 = OpLabel +OpStore %OutColor %15 +OpBranch %22 +%26 = OpLabel +OpStore %OutColor %17 +OpBranch %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %6 +%21 = OpLabel +OpBranch %25 +%25 = OpLabel +OpStore %OutColor %15 +OpBranch %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, SwitchLiveDefault) { + // #version 450 + // + // layout (location=0) in vec4 BaseColor; + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // switch (7) { + // case 0: + // OutColor = vec4(0.0,0.0,0.0,0.0); + // break; + // case 1: + // OutColor = vec4(0.125,0.125,0.125,0.125); + // break; + // case 2: + // OutColor = vec4(0.25,0.25,0.25,0.25); + // break; + // default: + // OutColor = vec4(1.0,1.0,1.0,1.0); + // } + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %OutColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %OutColor "OutColor" +OpName %BaseColor "BaseColor" +OpDecorate %OutColor Location 0 +OpDecorate %BaseColor Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%int_7 = OpConstant %int 7 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float 
Output +%float_0 = OpConstant %float 0 +%13 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_0_125 = OpConstant %float 0.125 +%15 = OpConstantComposite %v4float %float_0_125 %float_0_125 %float_0_125 %float_0_125 +%float_0_25 = OpConstant %float 0.25 +%17 = OpConstantComposite %v4float %float_0_25 %float_0_25 %float_0_25 %float_0_25 +%float_1 = OpConstant %float 1 +%19 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +)"; + + const std::string before = + R"(%main = OpFunction %void None %6 +%21 = OpLabel +OpSelectionMerge %22 None +OpSwitch %int_7 %23 0 %24 1 %25 2 %26 +%23 = OpLabel +OpStore %OutColor %19 +OpBranch %22 +%24 = OpLabel +OpStore %OutColor %13 +OpBranch %22 +%25 = OpLabel +OpStore %OutColor %15 +OpBranch %22 +%26 = OpLabel +OpStore %OutColor %17 +OpBranch %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %6 +%21 = OpLabel +OpBranch %23 +%23 = OpLabel +OpStore %OutColor %19 +OpBranch %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, SwitchLiveCaseBreakFromLoop) { + // This sample does not directly translate to GLSL/HLSL as + // direct breaks from a loop cannot be made from a switch. + // This construct is currently formed by inlining a function + // containing early returns from the cases of a switch. The + // function is wrapped in a one-trip loop and returns are + // translated to branches to the loop's merge block. 
+ + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %OutColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %oc "oc" +OpName %OutColor "OutColor" +OpName %BaseColor "BaseColor" +OpDecorate %OutColor Location 0 +OpDecorate %BaseColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%float_0 = OpConstant %float 0 +%17 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_0_125 = OpConstant %float 0.125 +%19 = OpConstantComposite %v4float %float_0_125 %float_0_125 %float_0_125 %float_0_125 +%float_0_25 = OpConstant %float 0.25 +%21 = OpConstantComposite %v4float %float_0_25 %float_0_25 %float_0_25 %float_0_25 +%float_1 = OpConstant %float 1 +%23 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%26 = OpLabel +%oc = OpVariable %_ptr_Function_v4float Function +OpBranch %27 +%27 = OpLabel +OpLoopMerge %28 %29 None +OpBranch %30 +%30 = OpLabel +OpSelectionMerge %31 None +OpSwitch %int_1 %31 0 %32 1 %33 2 %34 +%32 = OpLabel +OpStore %oc %17 +OpBranch %28 +%33 = OpLabel +OpStore %oc %19 +OpBranch %28 +%34 = OpLabel +OpStore %oc %21 +OpBranch %28 +%31 = OpLabel +OpStore %oc %23 +OpBranch %28 +%29 = OpLabel +OpBranchConditional %false %27 %28 +%28 = OpLabel +%35 = OpLoad %v4float %oc +OpStore %OutColor %35 +OpReturn +OpFunctionEnd +)"; + + const 
std::string after = + R"(%main = OpFunction %void None %7 +%26 = OpLabel +%oc = OpVariable %_ptr_Function_v4float Function +OpBranch %27 +%27 = OpLabel +OpLoopMerge %28 %29 None +OpBranch %30 +%30 = OpLabel +OpBranch %33 +%33 = OpLabel +OpStore %oc %19 +OpBranch %28 +%29 = OpLabel +OpBranch %27 +%28 = OpLabel +%35 = OpLoad %v4float %oc +OpStore %OutColor %35 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(DeadBranchElimTest, LeaveContinueBackedge) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] None +; CHECK: [[continue]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} {{%\w+}} [[merge]] +; CHECK-NEXT: [[merge]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%bool = OpTypeBool +%false = OpConstantFalse %bool +%void = OpTypeVoid +%funcTy = OpTypeFunction %void +%func = OpFunction %void None %funcTy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %3 %4 None +OpBranch %4 +%4 = OpLabel +; Be careful we don't remove the backedge to %2 despite never taking it. 
+OpBranchConditional %false %2 %3 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} +TEST_F(DeadBranchElimTest, LeaveContinueBackedgeExtraBlock) { + const std::string text = R"( +; CHECK: OpBranch [[header:%\w+]] +; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] None +; CHECK-NEXT: OpBranch [[continue]] +; CHECK-NEXT: [[continue]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[extra:%\w+]] [[merge]] +; CHECK-NEXT: [[extra]] = OpLabel +; CHECK-NEXT: OpBranch [[header]] +; CHECK-NEXT: [[merge]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%bool = OpTypeBool +%false = OpConstantFalse %bool +%void = OpTypeVoid +%funcTy = OpTypeFunction %void +%func = OpFunction %void None %funcTy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %3 %4 None +OpBranch %4 +%4 = OpLabel +; Be careful we don't remove the backedge to %2 despite never taking it. +OpBranchConditional %false %5 %3 +; This block remains live despite being unreachable. 
+%5 = OpLabel +OpBranch %2 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, RemovePhiWithUnreachableContinue) { + const std::string text = R"( +; CHECK: [[entry:%\w+]] = OpLabel +; CHECK-NEXT: OpBranch [[header:%\w+]] +; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] None +; CHECK-NEXT: OpBranch [[ret:%\w+]] +; CHECK-NEXT: [[ret]] = OpLabel +; CHECK-NEXT: OpReturn +; CHECK: [[continue]] = OpLabel +; CHECK-NEXT: OpBranch [[header]] +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpUnreachable +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%bool = OpTypeBool +%false = OpConstantFalse %bool +%true = OpConstantTrue %bool +%void = OpTypeVoid +%funcTy = OpTypeFunction %void +%func = OpFunction %void None %funcTy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +%phi = OpPhi %bool %false %1 %true %continue +OpLoopMerge %merge %continue None +OpBranch %3 +%3 = OpLabel +OpReturn +%continue = OpLabel +OpBranch %2 +%merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, UnreachableLoopMergeAndContinueTargets) { + const std::string text = R"( +; CHECK: [[undef:%\w+]] = OpUndef %bool +; CHECK: OpSelectionMerge [[header:%\w+]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[if_lab:%\w+]] [[else_lab:%\w+]] +; CHECK: OpPhi %bool %false [[if_lab]] %false [[else_lab]] [[undef]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK-NEXT: OpBranch [[ret:%\w+]] +; CHECK-NEXT: [[ret]] = OpLabel +; CHECK-NEXT: OpReturn +; CHECK: [[continue]] = OpLabel +; CHECK-NEXT: OpBranch [[header]] +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpUnreachable +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%bool = OpTypeBool +%false = OpConstantFalse %bool 
+%true = OpConstantTrue %bool +%void = OpTypeVoid +%funcTy = OpTypeFunction %void +%func = OpFunction %void None %funcTy +%1 = OpLabel +%c = OpUndef %bool +OpSelectionMerge %2 None +OpBranchConditional %c %if %else +%if = OpLabel +OpBranch %2 +%else = OpLabel +OpBranch %2 +%2 = OpLabel +%phi = OpPhi %bool %false %if %false %else %true %continue +OpLoopMerge %merge %continue None +OpBranch %3 +%3 = OpLabel +OpReturn +%continue = OpLabel +OpBranch %2 +%merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} +TEST_F(DeadBranchElimTest, EarlyReconvergence) { + const std::string text = R"( +; CHECK-NOT: OpBranchConditional +; CHECK: [[logical:%\w+]] = OpLogicalOr +; CHECK-NOT: OpPhi +; CHECK: OpLogicalAnd {{%\w+}} {{%\w+}} [[logical]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%false = OpConstantFalse %bool +%true = OpConstantTrue %bool +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpSelectionMerge %2 None +OpBranchConditional %false %3 %4 +%3 = OpLabel +%12 = OpLogicalNot %bool %true +OpBranch %2 +%4 = OpLabel +OpSelectionMerge %14 None +OpBranchConditional %false %5 %6 +%5 = OpLabel +%10 = OpLogicalAnd %bool %true %false +OpBranch %7 +%6 = OpLabel +%11 = OpLogicalOr %bool %true %false +OpBranch %7 +%7 = OpLabel +; This phi is in a block preceeding the merge %14! 
+%8 = OpPhi %bool %10 %5 %11 %6 +OpBranch %14 +%14 = OpLabel +OpBranch %2 +%2 = OpLabel +%9 = OpPhi %bool %12 %3 %8 %14 +%13 = OpLogicalAnd %bool %true %9 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksFloating) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%1 = OpTypeFunction %void +%func = OpFunction %void None %1 +%2 = OpLabel +OpReturn +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksFloatingJoin) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: OpFunctionParameter +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%false = OpConstantFalse %bool +%true = OpConstantTrue %bool +%1 = OpTypeFunction %void %bool +%func = OpFunction %void None %1 +%bool_param = OpFunctionParameter %bool +%2 = OpLabel +OpReturn +%3 = OpLabel +OpSelectionMerge %6 None +OpBranchConditional %bool_param %4 %5 +%4 = OpLabel +OpBranch %6 +%5 = OpLabel +OpBranch %6 +%6 = OpLabel +%7 = OpPhi %bool %true %4 %false %6 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksDeadPhi) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: OpFunctionParameter +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpLogicalNot %bool %true +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Kernel +OpCapability 
Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%false = OpConstantFalse %bool +%true = OpConstantTrue %bool +%1 = OpTypeFunction %void %bool +%func = OpFunction %void None %1 +%bool_param = OpFunctionParameter %bool +%2 = OpLabel +OpBranch %3 +%4 = OpLabel +OpBranch %3 +%3 = OpLabel +%5 = OpPhi %bool %true %2 %false %4 +%6 = OpLogicalNot %bool %5 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, RemoveUnreachableBlocksPartiallyDeadPhi) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: [[param:%\w+]] = OpFunctionParameter +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpBranchConditional [[param]] [[merge:%\w+]] [[br:%\w+]] +; CHECK-NEXT: [[merge]] = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool %true %2 %false [[br]] +; CHECK-NEXT: OpLogicalNot %bool [[phi]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: [[br]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] +; CHECK-NEXT: OpFunctionEnd +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%false = OpConstantFalse %bool +%true = OpConstantTrue %bool +%1 = OpTypeFunction %void %bool +%func = OpFunction %void None %1 +%bool_param = OpFunctionParameter %bool +%2 = OpLabel +OpBranchConditional %bool_param %3 %7 +%7 = OpLabel +OpBranch %3 +%4 = OpLabel +OpBranch %3 +%3 = OpLabel +%5 = OpPhi %bool %true %2 %false %7 %false %4 +%6 = OpLogicalNot %bool %5 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, LiveHeaderDeadPhi) { + const std::string text = R"( +; CHECK: OpLabel +; CHECK-NOT: OpBranchConditional +; CHECK-NOT: OpPhi +; CHECK: OpLogicalNot %bool %false +OpCapability Kernel +OpCapability Linkage +OpMemoryModel 
Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %true %2 %3 +%2 = OpLabel +OpBranch %3 +%3 = OpLabel +%5 = OpPhi %bool %true %3 %false %2 +%6 = OpLogicalNot %bool %5 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, ExtraBackedgeBlocksLive) { + const std::string text = R"( +; CHECK: [[entry:%\w+]] = OpLabel +; CHECK-NOT: OpSelectionMerge +; CHECK: OpBranch [[header:%\w+]] +; CHECK-NEXT: [[header]] = OpLabel +; CHECK-NEXT: OpPhi %bool %true [[entry]] %false [[backedge:%\w+]] +; CHECK-NEXT: OpLoopMerge +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%func_ty = OpTypeFunction %void %bool +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %bool +%entry = OpLabel +OpSelectionMerge %if_merge None +; This dead branch is included to ensure the pass does work. +OpBranchConditional %false %if_merge %loop_header +%loop_header = OpLabel +; Both incoming edges are live, so the phi should be untouched. 
+%phi = OpPhi %bool %true %entry %false %backedge +OpLoopMerge %loop_merge %continue None +OpBranchConditional %param %loop_merge %continue +%continue = OpLabel +OpBranch %backedge +%backedge = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpBranch %if_merge +%if_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, ExtraBackedgeBlocksUnreachable) { + const std::string text = R"( +; CHECK: [[entry:%\w+]] = OpLabel +; CHECK-NEXT: OpBranch [[header:%\w+]] +; CHECK-NEXT: [[header]] = OpLabel +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] None +; CHECK-NEXT: OpBranch [[merge]] +; CHECK-NEXT: [[merge]] = OpLabel +; CHECK-NEXT: OpReturn +; CHECK-NEXT: [[continue]] = OpLabel +; CHECK-NEXT: OpBranch [[header]] +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%func_ty = OpTypeFunction %void %bool +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %bool +%entry = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +; Since the continue is unreachable, %backedge will be removed. The phi will +; instead require an edge from %continue. 
+%phi = OpPhi %bool %true %entry %false %backedge +OpLoopMerge %merge %continue None +OpBranch %merge +%continue = OpLabel +OpBranch %backedge +%backedge = OpLabel +OpBranch %loop_header +%merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, NoUnnecessaryChanges) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef = OpUndef %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %4 %5 None +OpBranch %6 +%6 = OpLabel +OpReturn +%5 = OpLabel +OpBranch %2 +%4 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + auto result = SinglePassRunToBinary(text, true); + EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange); +} + +TEST_F(DeadBranchElimTest, ExtraBackedgePartiallyDead) { + const std::string text = R"( +; CHECK: OpLabel +; CHECK: [[header:%\w+]] = OpLabel +; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] None +; CHECK: [[merge]] = OpLabel +; CHECK: [[continue]] = OpLabel +; CHECK: OpBranch [[extra:%\w+]] +; CHECK: [[extra]] = OpLabel +; CHECK-NOT: OpSelectionMerge +; CHECK-NEXT: OpBranch [[else:%\w+]] +; CHECK-NEXT: [[else]] = OpLabel +; CHECK-NEXT: OpLogicalOr +; CHECK-NEXT: OpBranch [[backedge:%\w+]] +; CHECK-NEXT: [[backedge:%\w+]] = OpLabel +; CHECK-NEXT: OpBranch [[header]] +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpName %func "func" +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%func_ty = OpTypeFunction %void %bool +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %bool +%entry = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %continue None +OpBranchConditional %param 
%loop_merge %continue +%continue = OpLabel +OpBranch %extra +%extra = OpLabel +OpSelectionMerge %backedge None +OpBranchConditional %false %then %else +%then = OpLabel +%and = OpLogicalAnd %bool %true %false +OpBranch %backedge +%else = OpLabel +%or = OpLogicalOr %bool %true %false +OpBranch %backedge +%backedge = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, UnreachableContinuePhiInMerge) { + const std::string text = R"( +; CHECK: [[entry:%\w+]] = OpLabel +; CHECK-NEXT: OpBranch [[header:%\w+]] +; CHECK-NEXT: [[header]] = OpLabel +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] None +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: [[fadd:%\w+]] = OpFAdd +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpBranch [[merge]] +; CHECK-NEXT: [[continue]] = OpLabel +; CHECK-NEXT: OpBranch [[header]] +; CHECK-NEXT: [[merge]] = OpLabel +; CHECK-NEXT: OpStore {{%\w+}} [[fadd]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %o + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %o "o" + OpName %S "S" + OpMemberName %S 0 "a" + OpName %U_t "U_t" + OpMemberName %U_t 0 "g_F" + OpMemberName %U_t 1 "g_F2" + OpDecorate %o Location 0 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %U_t 0 Volatile + OpMemberDecorate %U_t 0 Offset 0 + OpMemberDecorate %U_t 1 Offset 4 + OpDecorate %U_t BufferBlock + %void = OpTypeVoid + %7 = OpTypeFunction %void + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float + %float_0 = OpConstant %float 0 + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + 
%int_10 = OpConstant %int 10 + %bool = OpTypeBool + %true = OpConstantTrue %bool + %float_1 = OpConstant %float 1 + %float_5 = OpConstant %float 5 + %int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float + %o = OpVariable %_ptr_Output_float Output + %S = OpTypeStruct %float + %U_t = OpTypeStruct %S %S +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t + %main = OpFunction %void None %7 + %22 = OpLabel + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %float %float_0 %22 %25 %26 + %27 = OpPhi %int %int_0 %22 %28 %26 + OpLoopMerge %29 %26 None + OpBranch %40 + %40 = OpLabel + %25 = OpFAdd %float %24 %float_1 + OpSelectionMerge %30 None + OpBranchConditional %true %31 %30 + %31 = OpLabel + OpBranch %29 + %30 = OpLabel + OpBranch %26 + %26 = OpLabel + %28 = OpIAdd %int %27 %int_1 + %32 = OpSLessThan %bool %27 %int_10 +; continue block branches to the header or another non-dead block. + OpBranchConditional %32 %23 %29 + %29 = OpLabel + %33 = OpPhi %float %24 %26 %25 %31 + OpStore %o %33 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, NonStructuredIf) { + const std::string text = R"( +; CHECK-NOT: OpBranchConditional +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpDecorate %func LinkageAttributes "func" Export +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpBranchConditional %true %then %else +%then = OpLabel +OpBranch %final +%else = OpLabel +OpBranch %final +%final = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, ReorderBlocks) { + const std::string text = R"( +; CHECK: OpLabel +; CHECK: OpBranch [[label:%\w+]] +; CHECK: [[label:%\w+]] = OpLabel +; CHECK-NEXT: OpLogicalNot +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK: [[label]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Shader 
+OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %true %2 %3 +%3 = OpLabel +OpReturn +%2 = OpLabel +%not = OpLogicalNot %bool %true +OpBranch %3 +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, ReorderBlocksMultiple) { + // Checks are not important. The validation post optimization is the + // important part. + const std::string text = R"( +; CHECK: OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %true %2 %3 +%3 = OpLabel +OpReturn +%2 = OpLabel +OpBranch %4 +%4 = OpLabel +OpBranch %3 +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, ReorderBlocksMultiple2) { + // Checks are not important. The validation post optimization is the + // important part. 
+ const std::string text = R"( +; CHECK: OpLabel +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %true %2 %3 +%3 = OpLabel +OpBranch %5 +%5 = OpLabel +OpReturn +%2 = OpLabel +OpBranch %4 +%4 = OpLabel +OpBranch %3 +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithEarlyExit1) { + // Checks that if a selection merge construct contains a conditional branch + // to the merge node, then the OpSelectionMerge instruction is positioned + // correctly. + const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpBranch [[taken_branch:%\w+]] +; CHECK-NEXT: [[taken_branch]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[merge:%\w+]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[merge]] {{%\w+}} +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpSelectionMerge %outer_merge None +OpBranchConditional %true %bb1 %bb3 +%bb1 = OpLabel +OpBranchConditional %undef_bool %outer_merge %bb2 +%bb2 = OpLabel +OpBranch %outer_merge +%bb3 = OpLabel +OpBranch %outer_merge +%outer_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithEarlyExit2) { + // Checks that if a selection merge construct contains a conditional branch + // to the merge node, then the 
OpSelectionMerge instruction is positioned + // correctly. + const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK-NEXT: [[bb1]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[inner_merge:%\w+]] +; CHECK: [[inner_merge]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[outer_merge:%\w+]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[outer_merge]] {{%\w+}} +; CHECK: [[outer_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpSelectionMerge %outer_merge None +OpBranchConditional %true %bb1 %bb5 +%bb1 = OpLabel +OpSelectionMerge %inner_merge None +OpBranchConditional %undef_bool %bb2 %bb3 +%bb2 = OpLabel +OpBranch %inner_merge +%bb3 = OpLabel +OpBranch %inner_merge +%inner_merge = OpLabel +OpBranchConditional %undef_bool %outer_merge %bb4 +%bb4 = OpLabel +OpBranch %outer_merge +%bb5 = OpLabel +OpBranch %outer_merge +%outer_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithConditionalExit) { + // Checks that if a selection merge construct contains a conditional branch + // to the merge node, then we keep the OpSelectionMerge on that branch. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%undef_int = OpUndef %uint +)"; + + const std::string body = + R"( +; CHECK: OpLoopMerge [[loop_merge:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[sel_merge:%\w+]] None +; CHECK-NEXT: OpSwitch {{%\w+}} [[sel_merge]] 1 [[bb3:%\w+]] +; CHECK: [[bb3]] = OpLabel +; CHECK-NEXT: OpBranch [[sel_merge]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpSwitch %undef_int %sel_merge 1 %bb3 +%bb3 = OpLabel +OpBranch %sel_merge +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithExitToLoop) { + // Checks that if a selection merge construct contains a conditional branch + // to a loop surrounding the selection merge, then we do not keep the + // OpSelectionMerge instruction. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"( +; CHECK: OpLoopMerge [[loop_merge:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[bb3:%\w+]] [[loop_merge]] +; CHECK: [[bb3]] = OpLabel +; CHECK-NEXT: OpBranch [[sel_merge:%\w+]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpBranchConditional %undef_bool %bb3 %loop_merge +%bb3 = OpLabel +OpBranch %sel_merge +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithExitToLoopContinue) { + // Checks that if a selection merge construct contains a conditional branch + // to continue of a loop surrounding the selection merge, then we do not keep + // the OpSelectionMerge instruction. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"(; +; CHECK: OpLabel +; CHECK: [[loop_header:%\w+]] = OpLabel +; CHECK: OpLoopMerge [[loop_merge:%\w+]] [[loop_cont:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[bb3:%\w+]] [[loop_cont]] +; CHECK: [[bb3]] = OpLabel +; CHECK-NEXT: OpBranch [[sel_merge:%\w+]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_cont]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_header]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpBranchConditional %undef_bool %bb3 %cont +%bb3 = OpLabel +OpBranch %sel_merge +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithExitToLoop2) { + // Same as |SelectionMergeWithExitToLoop|, except the branch goes to the loop + // merge or the selection merge. In this case, we do not need an + // OpSelectionMerge either. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"( +; CHECK: OpLoopMerge [[loop_merge:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[sel_merge:%\w+]] [[loop_merge]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpBranchConditional %undef_bool %sel_merge %loop_merge +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithExitToLoopContinue2) { + // Same as |SelectionMergeWithExitToLoopContinue|, except the branch goes to + // the loop continue or the selection merge. In this case, we do not need an + // OpSelectionMerge either. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"( +; CHECK: OpLabel +; CHECK: [[loop_header:%\w+]] = OpLabel +; CHECK: OpLoopMerge [[loop_merge:%\w+]] [[loop_cont:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[sel_merge:%\w+]] [[loop_cont]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_cont]] = OpLabel +; CHECK: OpBranch [[loop_header]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpBranchConditional %undef_bool %sel_merge %cont +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithExitToLoop3) { + // Checks that if a selection merge construct contains a conditional branch + // to the merge of a surrounding loop, the selection merge, and another block + // inside the selection merge, then we must keep the OpSelectionMerge + // instruction on that branch. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%undef_int = OpUndef %uint +)"; + + const std::string body = + R"( +; CHECK: OpLoopMerge [[loop_merge:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[sel_merge:%\w+]] None +; CHECK-NEXT: OpSwitch {{%\w+}} [[sel_merge]] 0 [[loop_merge]] 1 [[bb3:%\w+]] +; CHECK: [[bb3]] = OpLabel +; CHECK-NEXT: OpBranch [[sel_merge]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpSwitch %undef_int %sel_merge 0 %loop_merge 1 %bb3 +%bb3 = OpLabel +OpBranch %sel_merge +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithExitToLoopContinue3) { + // Checks that if a selection merge construct contains a conditional branch + // to the continue of a surrounding loop, the selection merge, and another + // block inside the selection merge, then we must keep the OpSelectionMerge + // instruction on that branch. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%undef_int = OpUndef %uint +)"; + + const std::string body = + R"( +; CHECK: OpLabel +; CHECK: [[loop_header:%\w+]] = OpLabel +; CHECK: OpLoopMerge [[loop_merge:%\w+]] [[loop_continue:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[sel_merge:%\w+]] None +; CHECK-NEXT: OpSwitch {{%\w+}} [[sel_merge]] 0 [[loop_continue]] 1 [[bb3:%\w+]] +; CHECK: [[bb3]] = OpLabel +; CHECK-NEXT: OpBranch [[sel_merge]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_continue]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_header]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpSwitch %undef_int %sel_merge 0 %cont 1 %bb3 +%bb3 = OpLabel +OpBranch %sel_merge +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithExitToLoop4) { + // Same as |SelectionMergeWithExitToLoop|, except the branch in the selection + // construct is an |OpSwitch| instead of an |OpBranchConditional|. The + // OpSelectionMerge instruction is not needed in this case either. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%undef_int = OpUndef %uint +)"; + + const std::string body = + R"( +; CHECK: OpLoopMerge [[loop_merge:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpSwitch {{%\w+}} [[bb3:%\w+]] 0 [[loop_merge]] 1 [[bb3:%\w+]] +; CHECK: [[bb3]] = OpLabel +; CHECK-NEXT: OpBranch [[sel_merge:%\w+]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpSwitch %undef_int %bb3 0 %loop_merge 1 %bb3 +%bb3 = OpLabel +OpBranch %sel_merge +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithExitToLoopContinue4) { + // Same as |SelectionMergeWithExitToLoopContinue|, except the branch in the + // selection construct is an |OpSwitch| instead of an |OpBranchConditional|. + // The OpSelectionMerge instruction is not needed in this case either. 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%undef_int = OpUndef %uint +)"; + + const std::string body = + R"( +; CHECK: OpLoopMerge [[loop_merge:%\w+]] [[loop_cont:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpSwitch {{%\w+}} [[bb3:%\w+]] 0 [[loop_cont]] 1 [[bb3:%\w+]] +; CHECK: [[bb3]] = OpLabel +; CHECK-NEXT: OpBranch [[sel_merge:%\w+]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %sel_merge None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpSwitch %undef_int %bb3 0 %cont 1 %bb3 +%bb3 = OpLabel +OpBranch %sel_merge +%bb4 = OpLabel +OpBranch %sel_merge +%sel_merge = OpLabel +OpBranch %loop_merge +%cont = OpLabel +OpBranch %loop_header +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeSameAsLoopContinue) { + // Checks that when the selection merge block is also the continue target of + // the enclosing loop, a conditional branch to it from inside the selection + // still gets an OpSelectionMerge (it is moved to that branch, not dropped). 
+ const std::string predefs = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +%void = OpTypeVoid +%func_type = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%undef_bool = OpUndef %bool +)"; + + const std::string body = + R"( +; CHECK: OpLabel +; CHECK: [[loop_header:%\w+]] = OpLabel +; CHECK: OpLoopMerge [[loop_merge:%\w+]] [[loop_cont:%\w+]] +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpBranch [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[loop_cont]] +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[bb3:%\w+]] [[loop_cont]] +; CHECK: [[bb3]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_cont]] +; CHECK: [[loop_cont]] = OpLabel +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[loop_header]] [[loop_merge]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +%main = OpFunction %void None %func_type +%entry_bb = OpLabel +OpBranch %loop_header +%loop_header = OpLabel +OpLoopMerge %loop_merge %cont None +OpBranch %bb1 +%bb1 = OpLabel +OpSelectionMerge %cont None +OpBranchConditional %true %bb2 %bb4 +%bb2 = OpLabel +OpBranchConditional %undef_bool %bb3 %cont +%bb3 = OpLabel +OpBranch %cont +%bb4 = OpLabel +OpBranch %cont +%cont = OpLabel +OpBranchConditional %undef_bool %loop_header %loop_merge +%loop_merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(predefs + body, true); +} + +TEST_F(DeadBranchElimTest, SelectionMergeWithNestedLoop) { + const std::string body = + R"( +; CHECK: OpSelectionMerge [[merge1:%\w+]] +; CHECK: [[merge1]] = OpLabel +; CHECK-NEXT: OpBranch [[preheader:%\w+]] +; CHECK: [[preheader]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpBranch [[header:%\w+]] +; CHECK: [[header]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpLoopMerge [[merge2:%\w+]] +; CHECK: [[merge2]] = OpLabel +; 
CHECK-NEXT: OpUnreachable + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource ESSL 310 + OpName %main "main" + OpName %h "h" + OpName %i "i" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %_ptr_Function_bool = OpTypePointer Function %bool + %true = OpConstantTrue %bool + %int = OpTypeInt 32 1 + %_ptr_Function_int = OpTypePointer Function %int + %int_1 = OpConstant %int 1 + %int_0 = OpConstant %int 0 + %27 = OpUndef %bool + %main = OpFunction %void None %3 + %5 = OpLabel + %h = OpVariable %_ptr_Function_bool Function + %i = OpVariable %_ptr_Function_int Function + OpSelectionMerge %11 None + OpBranchConditional %27 %10 %11 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %true %13 %14 + %13 = OpLabel + OpStore %i %int_1 + OpBranch %19 + %19 = OpLabel + OpLoopMerge %21 %22 None + OpBranch %23 + %23 = OpLabel + %26 = OpSGreaterThan %bool %int_1 %int_0 + OpBranchConditional %true %20 %21 + %20 = OpLabel + OpBranch %22 + %22 = OpLabel + OpBranch %19 + %21 = OpLabel + OpBranch %14 + %14 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(body, true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// More complex control flow +// Others? + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dead_insert_elim_test.cpp b/test/opt/dead_insert_elim_test.cpp new file mode 100644 index 000000000..8ae6894d8 --- /dev/null +++ b/test/opt/dead_insert_elim_test.cpp @@ -0,0 +1,571 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using DeadInsertElimTest = PassTest<::testing::Test>; + +TEST_F(DeadInsertElimTest, InsertAfterInsertElim) { + // With two insertions to the same offset, the first is dead. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in float In0; + // layout (location=1) in float In1; + // layout (location=2) in vec2 In2; + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec2 v = In2; + // v.x = In0 + In1; // dead + // v.x = 0.0; + // OutColor = v.xyxy; + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In2 %In0 %In1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In2 "In2" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %OutColor "OutColor" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_n" +OpName %_ "" +OpDecorate %In2 Location 2 +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %OutColor Location 0 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +%void = OpTypeVoid +%11 = OpTypeFunction 
%void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%In2 = OpVariable %_ptr_Input_v2float Input +%_ptr_Input_float = OpTypePointer Input %float +%In0 = OpVariable %_ptr_Input_float Input +%In1 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%int = OpTypeInt 32 1 +%_Globals_ = OpTypeStruct %uint %int +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In2 %In0 %In1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In2 "In2" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %OutColor "OutColor" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_n" +OpName %_ "" +OpDecorate %In2 Location 2 +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %OutColor Location 0 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%In2 = OpVariable %_ptr_Input_v2float Input +%_ptr_Input_float = OpTypePointer Input %float +%In0 = OpVariable %_ptr_Input_float Input +%In1 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function 
%float +%float_0 = OpConstant %float 0 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%int = OpTypeInt 32 1 +%_Globals_ = OpTypeStruct %uint %int +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%25 = OpLabel +%26 = OpLoad %v2float %In2 +%27 = OpLoad %float %In0 +%28 = OpLoad %float %In1 +%29 = OpFAdd %float %27 %28 +%35 = OpCompositeInsert %v2float %29 %26 0 +%37 = OpCompositeInsert %v2float %float_0 %35 0 +%33 = OpVectorShuffle %v4float %37 %37 0 1 0 1 +OpStore %OutColor %33 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%23 = OpLabel +%24 = OpLoad %v2float %In2 +%29 = OpCompositeInsert %v2float %float_0 %24 0 +%30 = OpVectorShuffle %v4float %29 %29 0 1 0 1 +OpStore %OutColor %30 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(DeadInsertElimTest, DeadInsertInChainWithPhi) { + // Dead insert eliminated with phi in insertion chain. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. 
+ // + // #version 450 + // + // layout (location=0) in vec4 In0; + // layout (location=1) in float In1; + // layout (location=2) in float In2; + // layout (location=0) out vec4 OutColor; + // + // layout(std140, binding = 0 ) uniform _Globals_ + // { + // bool g_b; + // }; + // + // void main() + // { + // vec4 v = In0; + // v.z = In1 + In2; + // if (g_b) v.w = 1.0; + // OutColor = vec4(v.x,v.y,0.0,v.w); + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%_Globals_ = OpTypeStruct %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output 
+%float_0 = OpConstant %float 0 +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%_Globals_ = OpTypeStruct %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%31 = OpLabel +%32 = OpLoad %v4float %In0 +%33 = OpLoad %float %In1 +%34 = OpLoad %float %In2 +%35 = OpFAdd %float %33 %34 +%51 = OpCompositeInsert %v4float %35 %32 2 +%37 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%38 = OpLoad %uint %37 +%39 = OpINotEqual %bool %38 
%uint_0 +OpSelectionMerge %40 None +OpBranchConditional %39 %41 %40 +%41 = OpLabel +%53 = OpCompositeInsert %v4float %float_1 %51 3 +OpBranch %40 +%40 = OpLabel +%60 = OpPhi %v4float %51 %31 %53 %41 +%55 = OpCompositeExtract %float %60 0 +%57 = OpCompositeExtract %float %60 1 +%59 = OpCompositeExtract %float %60 3 +%49 = OpCompositeConstruct %v4float %55 %57 %float_0 %59 +OpStore %OutColor %49 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%27 = OpLabel +%28 = OpLoad %v4float %In0 +%33 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%34 = OpLoad %uint %33 +%35 = OpINotEqual %bool %34 %uint_0 +OpSelectionMerge %36 None +OpBranchConditional %35 %37 %36 +%37 = OpLabel +%38 = OpCompositeInsert %v4float %float_1 %28 3 +OpBranch %36 +%36 = OpLabel +%39 = OpPhi %v4float %28 %27 %38 %37 +%40 = OpCompositeExtract %float %39 0 +%41 = OpCompositeExtract %float %39 1 +%42 = OpCompositeExtract %float %39 3 +%43 = OpCompositeConstruct %v4float %40 %41 %float_0 %42 +OpStore %OutColor %43 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(DeadInsertElimTest, DeadInsertTwoPasses) { + // Dead insert which requires two passes to eliminate + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. 
+ // + // #version 450 + // + // layout (location=0) in vec4 In0; + // layout (location=1) in float In1; + // layout (location=2) in float In2; + // layout (location=0) out vec4 OutColor; + // + // layout(std140, binding = 0 ) uniform _Globals_ + // { + // bool g_b; + // bool g_b2; + // }; + // + // void main() + // { + // vec4 v1, v2; + // v1 = In0; + // v1.y = In1 + In2; // dead, second pass + // if (g_b) v1.x = 1.0; + // v2.x = v1.x; + // v2.y = v1.y; // dead, first pass + // if (g_b2) v2.x = 0.0; + // OutColor = vec4(v2.x,v2.x,0.0,1.0); + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_b2" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_Globals_ = OpTypeStruct %uint %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool 
+%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%float_0 = OpConstant %float 0 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%27 = OpUndef %v4float +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_b2" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_Globals_ = OpTypeStruct %uint %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%float_0 = OpConstant %float 0 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%27 = OpUndef %v4float +)"; + + const std::string before = + R"(%main = 
OpFunction %void None %10 +%28 = OpLabel +%29 = OpLoad %v4float %In0 +%30 = OpLoad %float %In1 +%31 = OpLoad %float %In2 +%32 = OpFAdd %float %30 %31 +%33 = OpCompositeInsert %v4float %32 %29 1 +%34 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%35 = OpLoad %uint %34 +%36 = OpINotEqual %bool %35 %uint_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %37 +%38 = OpLabel +%39 = OpCompositeInsert %v4float %float_1 %33 0 +OpBranch %37 +%37 = OpLabel +%40 = OpPhi %v4float %33 %28 %39 %38 +%41 = OpCompositeExtract %float %40 0 +%42 = OpCompositeInsert %v4float %41 %27 0 +%43 = OpCompositeExtract %float %40 1 +%44 = OpCompositeInsert %v4float %43 %42 1 +%45 = OpAccessChain %_ptr_Uniform_uint %_ %int_1 +%46 = OpLoad %uint %45 +%47 = OpINotEqual %bool %46 %uint_0 +OpSelectionMerge %48 None +OpBranchConditional %47 %49 %48 +%49 = OpLabel +%50 = OpCompositeInsert %v4float %float_0 %44 0 +OpBranch %48 +%48 = OpLabel +%51 = OpPhi %v4float %44 %37 %50 %49 +%52 = OpCompositeExtract %float %51 0 +%53 = OpCompositeExtract %float %51 0 +%54 = OpCompositeConstruct %v4float %52 %53 %float_0 %float_1 +OpStore %OutColor %54 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%28 = OpLabel +%29 = OpLoad %v4float %In0 +%34 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%35 = OpLoad %uint %34 +%36 = OpINotEqual %bool %35 %uint_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %37 +%38 = OpLabel +%39 = OpCompositeInsert %v4float %float_1 %29 0 +OpBranch %37 +%37 = OpLabel +%40 = OpPhi %v4float %29 %28 %39 %38 +%41 = OpCompositeExtract %float %40 0 +%42 = OpCompositeInsert %v4float %41 %27 0 +%45 = OpAccessChain %_ptr_Uniform_uint %_ %int_1 +%46 = OpLoad %uint %45 +%47 = OpINotEqual %bool %46 %uint_0 +OpSelectionMerge %48 None +OpBranchConditional %47 %49 %48 +%49 = OpLabel +%50 = OpCompositeInsert %v4float %float_0 %42 0 +OpBranch %48 +%48 = OpLabel +%51 = OpPhi %v4float %42 %37 %50 %49 +%52 = OpCompositeExtract %float %51 0 
+%53 = OpCompositeExtract %float %51 0 +%54 = OpCompositeConstruct %v4float %52 %53 %float_0 %float_1 +OpStore %OutColor %54 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dead_variable_elim_test.cpp b/test/opt/dead_variable_elim_test.cpp new file mode 100644 index 000000000..fca13a8e2 --- /dev/null +++ b/test/opt/dead_variable_elim_test.cpp @@ -0,0 +1,298 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using DeadVariableElimTest = PassTest<::testing::Test>; + +// %dead is unused. Make sure we remove it along with its name. 
+TEST_F(DeadVariableElimTest, RemoveUnreferenced) { + const std::string before = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %dead "dead" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%dead = OpVariable %_ptr_Private_float Private +%main = OpFunction %void None %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%main = OpFunction %void None %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +// Since %dead is exported, make sure we keep it. It could be referenced +// somewhere else. 
+TEST_F(DeadVariableElimTest, KeepExported) { + const std::string before = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %dead "dead" +OpDecorate %dead LinkageAttributes "dead" Export +%void = OpTypeVoid +%5 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%dead = OpVariable %_ptr_Private_float Private +%main = OpFunction %void None %5 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, before, true, true); +} + +// Delete %dead because it is unreferenced. Then %initializer becomes +// unreferenced, so remove it as well. +TEST_F(DeadVariableElimTest, RemoveUnreferencedWithInit1) { + const std::string before = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %dead "dead" +OpName %initializer "initializer" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%initializer = OpVariable %_ptr_Private_float Private +%dead = OpVariable %_ptr_Private_float Private %initializer +%main = OpFunction %void None %6 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%main = OpFunction %void None %6 +%9 = OpLabel 
+OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +// Delete %dead because it is unreferenced. In this case, the initializer has +// another reference, and should not be removed. +TEST_F(DeadVariableElimTest, RemoveUnreferencedWithInit2) { + const std::string before = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %dead "dead" +OpName %initializer "initializer" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%initializer = OpVariable %_ptr_Private_float Private +%dead = OpVariable %_ptr_Private_float Private %initializer +%main = OpFunction %void None %6 +%9 = OpLabel +%10 = OpLoad %float %initializer +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %initializer "initializer" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%initializer = OpVariable %_ptr_Private_float Private +%main = OpFunction %void None %6 +%9 = OpLabel +%10 = OpLoad %float %initializer +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +// Keep %live because it is used, and its initializer. 
+TEST_F(DeadVariableElimTest, KeepReferenced) { + const std::string before = + R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 150 +OpName %main "main" +OpName %live "live" +OpName %initializer "initializer" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float +%initializer = OpVariable %_ptr_Private_float Private +%live = OpVariable %_ptr_Private_float Private %initializer +%main = OpFunction %void None %6 +%9 = OpLabel +%10 = OpLoad %float %live +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, before, true, true); +} + +// Tests that the decorations associated with a variable are removed when the +// variable is removed. +TEST_F(DeadVariableElimTest, RemoveVariableAndDecorations) { + const std::string before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 450 +OpName %main "main" +OpName %B "B" +OpMemberName %B 0 "a" +OpName %Bdat "Bdat" +OpMemberDecorate %B 0 Offset 0 +OpDecorate %B BufferBlock +OpDecorate %Bdat DescriptorSet 0 +OpDecorate %Bdat Binding 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%B = OpTypeStruct %uint +%_ptr_Uniform_B = OpTypePointer Uniform %B +%Bdat = OpVariable %_ptr_Uniform_B Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%main = OpFunction %void None %6 +%13 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 450 +OpName %main "main" +OpName %B "B" 
+OpMemberName %B 0 "a" +OpMemberDecorate %B 0 Offset 0 +OpDecorate %B BufferBlock +%void = OpTypeVoid +%6 = OpTypeFunction %void +%uint = OpTypeInt 32 0 +%B = OpTypeStruct %uint +%_ptr_Uniform_B = OpTypePointer Uniform %B +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%main = OpFunction %void None %6 +%13 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, true, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/decoration_manager_test.cpp b/test/opt/decoration_manager_test.cpp new file mode 100644 index 000000000..f85ff6aa6 --- /dev/null +++ b/test/opt/decoration_manager_test.cpp @@ -0,0 +1,1283 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/build_module.h" +#include "source/opt/decoration_manager.h" +#include "source/opt/ir_context.h" +#include "source/spirv_constant.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +using spvtest::MakeVector; + +class DecorationManagerTest : public ::testing::Test { + public: + DecorationManagerTest() + : tools_(SPV_ENV_UNIVERSAL_1_2), + context_(), + consumer_([this](spv_message_level_t level, const char*, + const spv_position_t& position, const char* message) { + if (!error_message_.empty()) error_message_ += "\n"; + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + error_message_ += "ERROR"; + break; + case SPV_MSG_WARNING: + error_message_ += "WARNING"; + break; + case SPV_MSG_INFO: + error_message_ += "INFO"; + break; + case SPV_MSG_DEBUG: + error_message_ += "DEBUG"; + break; + } + error_message_ += + ": " + std::to_string(position.index) + ": " + message; + }), + disassemble_options_(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER), + error_message_() { + tools_.SetMessageConsumer(consumer_); + } + + void TearDown() override { error_message_.clear(); } + + DecorationManager* GetDecorationManager(const std::string& text) { + context_ = BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text); + if (context_.get()) + return context_->get_decoration_mgr(); + else + return nullptr; + } + + // Disassembles |binary| and outputs the result in |text|. If |text| is a + // null pointer, SPV_ERROR_INVALID_POINTER is returned. + spv_result_t Disassemble(const std::vector& binary, + std::string* text) { + if (!text) return SPV_ERROR_INVALID_POINTER; + return tools_.Disassemble(binary, text, disassemble_options_) + ? SPV_SUCCESS + : SPV_ERROR_INVALID_BINARY; + } + + // Returns the accumulated error messages for the test. 
+ std::string GetErrorMessage() const { return error_message_; } + + std::string ToText(const std::vector& inst) { + std::vector binary = {SpvMagicNumber, 0x10200, 0u, 2u, 0u}; + for (const Instruction* i : inst) + i->ToBinaryWithoutAttachedDebugInsts(&binary); + std::string text; + Disassemble(binary, &text); + return text; + } + + std::string ModuleToText() { + std::vector binary; + context_->module()->ToBinary(&binary, false); + std::string text; + Disassemble(binary, &text); + return text; + } + + spvtools::MessageConsumer GetConsumer() { return consumer_; } + + private: + // An instance for calling SPIRV-Tools functionalities. + spvtools::SpirvTools tools_; + std::unique_ptr context_; + spvtools::MessageConsumer consumer_; + uint32_t disassemble_options_; + std::string error_message_; +}; + +TEST_F(DecorationManagerTest, + ComparingDecorationsWithDiffOpcodesDecorateDecorateId) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // This parameter can be interpreted both as { SpvDecorationConstant } + // and also as a list of IDs: { 22 } + const std::vector param{SpvDecorationConstant}; + // OpDecorate %1 Constant + Instruction inst1( + &ir_context, SpvOpDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_DECORATION, param}}); + // OpDecorateId %1 %22 ; 'Constant' is decoration number 22 + Instruction inst2( + &ir_context, SpvOpDecorateId, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_ID, param}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, + ComparingDecorationsWithDiffOpcodesDecorateDecorateString) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // This parameter can be interpreted both as { SpvDecorationConstant } + // and also as a null-terminated string with a single character with value 22. 
+ const std::vector param{SpvDecorationConstant}; + // OpDecorate %1 Constant + Instruction inst1( + &ir_context, SpvOpDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_DECORATION, param}}); + // OpDecorateStringGOOGLE %1 !22 + Instruction inst2( + &ir_context, SpvOpDecorateStringGOOGLE, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_LITERAL_STRING, param}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateParam) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // OpDecorate %1 Constant + Instruction inst1(&ir_context, SpvOpDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + // OpDecorate %1 Restrict + Instruction inst2(&ir_context, SpvOpDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationRestrict}}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateIdParam) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // OpDecorateId %1 %555 + Instruction inst1( + &ir_context, SpvOpDecorateId, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_ID, {555}}}); + // OpDecorateId %1 %666 + Instruction inst2( + &ir_context, SpvOpDecorateId, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_ID, {666}}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, ComparingDecorationsWithDiffDecorateStringParam) { + IRContext
ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // OpDecorateStringGOOGLE %1 "Hello!" + Instruction inst1(&ir_context, SpvOpDecorateStringGOOGLE, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_LITERAL_STRING, MakeVector("Hello!")}}); + // OpDecorateStringGOOGLE %1 "Hellx" + Instruction inst2(&ir_context, SpvOpDecorateStringGOOGLE, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_LITERAL_STRING, MakeVector("Hellx")}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, ComparingSameDecorationsOnDiffTargetAllowed) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // OpDecorate %1 Constant + Instruction inst1(&ir_context, SpvOpDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + // OpDecorate %2 Constant + Instruction inst2(&ir_context, SpvOpDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {2u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, ComparingSameDecorationIdsOnDiffTargetAllowed) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + Instruction inst1( + &ir_context, SpvOpDecorateId, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, {SPV_OPERAND_TYPE_DECORATION, {44}}}); + Instruction inst2( + &ir_context, SpvOpDecorateId, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {2u}}, {SPV_OPERAND_TYPE_DECORATION, {44}}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, + ComparingSameDecorationStringsOnDiffTargetAllowed) { + IRContext
ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + Instruction inst1(&ir_context, SpvOpDecorateStringGOOGLE, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_LITERAL_STRING, MakeVector("hello")}}); + Instruction inst2(&ir_context, SpvOpDecorateStringGOOGLE, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {2u}}, + {SPV_OPERAND_TYPE_LITERAL_STRING, MakeVector("hello")}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, ComparingSameDecorationsOnDiffTargetDisallowed) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // OpDecorate %1 Constant + Instruction inst1(&ir_context, SpvOpDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + // OpDecorate %2 Constant + Instruction inst2(&ir_context, SpvOpDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {2u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->AreDecorationsTheSame(&inst1, &inst2, false)); +} + +TEST_F(DecorationManagerTest, ComparingMemberDecorationsOnSameTypeDiffMember) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // OpMemberDecorate %1 0 Constant + Instruction inst1(&ir_context, SpvOpMemberDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {0u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + // OpMemberDecorate %1 1 Constant + Instruction inst2(&ir_context, SpvOpMemberDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {1u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + 
EXPECT_FALSE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, + ComparingSameMemberDecorationsOnDiffTargetAllowed) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // OpMemberDecorate %1 0 Constant + Instruction inst1(&ir_context, SpvOpMemberDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {0u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + // OpMemberDecorate %2 0 Constant + Instruction inst2(&ir_context, SpvOpMemberDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {2u}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {0u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->AreDecorationsTheSame(&inst1, &inst2, true)); +} + +TEST_F(DecorationManagerTest, + ComparingSameMemberDecorationsOnDiffTargetDisallowed) { + IRContext ir_context(SPV_ENV_UNIVERSAL_1_2, GetConsumer()); + // OpMemberDecorate %1 0 Constant + Instruction inst1(&ir_context, SpvOpMemberDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {1u}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {0u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + // OpMemberDecorate %2 0 Constant + Instruction inst2(&ir_context, SpvOpMemberDecorate, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {2u}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {0u}}, + {SPV_OPERAND_TYPE_DECORATION, {SpvDecorationConstant}}}); + DecorationManager* decoManager = ir_context.get_decoration_mgr(); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->AreDecorationsTheSame(&inst1, &inst2, false)); +} + +TEST_F(DecorationManagerTest, RemoveDecorationFromVariable) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 %3 +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Uniform 
+%3 = OpVariable %4 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + decoManager->RemoveDecorationsFrom(1u); + auto decorations = decoManager->GetDecorationsFor(1u, false); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + decorations = decoManager->GetDecorationsFor(3u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + const std::string expected_decorations = R"(OpDecorate %2 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %3 +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Uniform +%3 = OpVariable %4 Uniform +)"; + EXPECT_THAT(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, RemoveDecorationStringFromVariable) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "hello world" +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 %3 +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Uniform +%3 = OpVariable %4 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + decoManager->RemoveDecorationsFrom(1u); + auto decorations = decoManager->GetDecorationsFor(1u, false); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + decorations = decoManager->GetDecorationsFor(3u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + const std::string expected_decorations = R"(OpDecorate %2 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" 
+OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %3 +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Uniform +%3 = OpVariable %4 Uniform +)"; + EXPECT_THAT(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, RemoveDecorationFromDecorationGroup) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 %3 +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Uniform +%3 = OpVariable %4 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + decoManager->RemoveDecorationsFrom(2u); + auto decorations = decoManager->GetDecorationsFor(2u, false); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + decorations = decoManager->GetDecorationsFor(1u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + const std::string expected_decorations = R"(OpDecorate %1 Constant +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + decorations = decoManager->GetDecorationsFor(3u, false); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_THAT(ToText(decorations), ""); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +%2 = OpDecorationGroup +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Uniform +%3 = OpVariable %4 Uniform +)"; + EXPECT_THAT(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, + RemoveDecorationFromDecorationGroupKeepDeadDecorations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 +%3 = OpTypeInt 32 0 +%1 = OpVariable %3 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + 
EXPECT_THAT(GetErrorMessage(), ""); + decoManager->RemoveDecorationsFrom(1u); + auto decorations = decoManager->GetDecorationsFor(1u, false); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + decorations = decoManager->GetDecorationsFor(2u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + const std::string expected_decorations = R"(OpDecorate %2 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %2 Restrict +%2 = OpDecorationGroup +%3 = OpTypeInt 32 0 +%1 = OpVariable %3 Uniform +)"; + EXPECT_THAT(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, RemoveAllDecorationsAppliedByGroup) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 +OpDecorate %3 BuiltIn VertexId +%3 = OpDecorationGroup +OpGroupDecorate %3 %1 +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Input +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + decoManager->RemoveDecorationsFrom(1u, [](const Instruction& inst) { + return inst.opcode() == SpvOpDecorate && + inst.GetSingleWordInOperand(0u) == 3u; + }); + auto decorations = decoManager->GetDecorationsFor(1u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + std::string expected_decorations = R"(OpDecorate %1 Constant +OpDecorate %2 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + decorations = decoManager->GetDecorationsFor(2u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + expected_decorations = R"(OpDecorate %2 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant 
+OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 +OpDecorate %3 BuiltIn VertexId +%3 = OpDecorationGroup +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Input +)"; + EXPECT_THAT(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, RemoveSomeDecorationsAppliedByGroup) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 +OpDecorate %3 BuiltIn VertexId +OpDecorate %3 Invariant +%3 = OpDecorationGroup +OpGroupDecorate %3 %1 +%uint = OpTypeInt 32 0 +%1 = OpVariable %uint Input +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + decoManager->RemoveDecorationsFrom(1u, [](const Instruction& inst) { + return inst.opcode() == SpvOpDecorate && + inst.GetSingleWordInOperand(0u) == 3u && + inst.GetSingleWordInOperand(1u) == SpvDecorationBuiltIn; + }); + auto decorations = decoManager->GetDecorationsFor(1u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + std::string expected_decorations = R"(OpDecorate %1 Constant +OpDecorate %1 Invariant +OpDecorate %2 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + decorations = decoManager->GetDecorationsFor(2u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + expected_decorations = R"(OpDecorate %2 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 +OpDecorate %3 BuiltIn VertexId +OpDecorate %3 Invariant +%3 = OpDecorationGroup +OpDecorate %1 Invariant +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Input +)"; + EXPECT_THAT(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, RemoveDecorationDecorate) { + const std::string spirv = R"( 
+OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %1 Restrict +%2 = OpTypeInt 32 0 +%1 = OpVariable %2 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + auto decorations = decoManager->GetDecorationsFor(1u, false); + decoManager->RemoveDecoration(decorations.front()); + decorations = decoManager->GetDecorationsFor(1u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + const std::string expected_decorations = R"(OpDecorate %1 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); +} + +TEST_F(DecorationManagerTest, RemoveDecorationStringDecorate) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "foobar" +OpDecorate %1 Restrict +%2 = OpTypeInt 32 0 +%1 = OpVariable %2 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + auto decorations = decoManager->GetDecorationsFor(1u, false); + decoManager->RemoveDecoration(decorations.front()); + decorations = decoManager->GetDecorationsFor(1u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + const std::string expected_decorations = R"(OpDecorate %1 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); +} + +TEST_F(DecorationManagerTest, CloneDecorations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 +OpDecorate %3 BuiltIn VertexId +OpDecorate %3 Invariant +%3 = OpDecorationGroup +OpGroupDecorate %3 %1 +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Input +%5 = OpVariable %4 Input +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + 
EXPECT_THAT(GetErrorMessage(), ""); + + // Check cloning OpDecorate including group decorations. + auto decorations = decoManager->GetDecorationsFor(5u, false); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + + decoManager->CloneDecorations(1u, 5u); + decorations = decoManager->GetDecorationsFor(5u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + std::string expected_decorations = R"(OpDecorate %5 Constant +OpDecorate %2 Restrict +OpDecorate %3 BuiltIn VertexId +OpDecorate %3 Invariant +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + + // Check that bookkeeping for ID 2 remains the same. + decorations = decoManager->GetDecorationsFor(2u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + expected_decorations = R"(OpDecorate %2 Restrict +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %2 %1 %5 +OpDecorate %3 BuiltIn VertexId +OpDecorate %3 Invariant +%3 = OpDecorationGroup +OpGroupDecorate %3 %1 %5 +OpDecorate %5 Constant +%4 = OpTypeInt 32 0 +%1 = OpVariable %4 Input +%5 = OpVariable %4 Input +)"; + EXPECT_THAT(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, CloneDecorationsStringAndId) { + const std::string spirv = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "blah" +OpDecorateId %1 HlslCounterBufferGOOGLE %2 +OpDecorate %1 Aliased +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Uniform %3 +%1 = OpVariable %4 Uniform +%2 = OpVariable %4 Uniform +%5 = OpVariable %4 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + + // Check cloning OpDecorate including group 
decorations. + auto decorations = decoManager->GetDecorationsFor(5u, false); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + + decoManager->CloneDecorations(1u, 5u); + decorations = decoManager->GetDecorationsFor(5u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + std::string expected_decorations = + R"(OpDecorateStringGOOGLE %5 HlslSemanticGOOGLE "blah" +OpDecorateId %5 HlslCounterBufferGOOGLE %2 +OpDecorate %5 Aliased +)"; + EXPECT_THAT(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "blah" +OpDecorateId %1 HlslCounterBufferGOOGLE %2 +OpDecorate %1 Aliased +OpDecorateStringGOOGLE %5 HlslSemanticGOOGLE "blah" +OpDecorateId %5 HlslCounterBufferGOOGLE %2 +OpDecorate %5 Aliased +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Uniform %3 +%1 = OpVariable %4 Uniform +%2 = OpVariable %4 Uniform +%5 = OpVariable %4 Uniform +)"; + EXPECT_THAT(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, CloneSomeDecorations) { + const std::string spirv = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorate %1 RelaxedPrecision +OpDecorate %1 Restrict +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Function %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +%1 = OpVariable %3 Function +%8 = OpUndef %2 +OpReturn +OpFunctionEnd +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_EQ(GetErrorMessage(), ""); + + // Check cloning OpDecorate including group decorations. 
+ auto decorations = decoManager->GetDecorationsFor(8u, false); + EXPECT_EQ(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + + decoManager->CloneDecorations(1u, 8u, {SpvDecorationRelaxedPrecision}); + decorations = decoManager->GetDecorationsFor(8u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + std::string expected_decorations = + R"(OpDecorate %8 RelaxedPrecision +)"; + EXPECT_EQ(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorate %1 RelaxedPrecision +OpDecorate %1 Restrict +OpDecorate %8 RelaxedPrecision +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Function %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +%1 = OpVariable %3 Function +%8 = OpUndef %2 +OpReturn +OpFunctionEnd +)"; + EXPECT_EQ(ModuleToText(), expected_binary); +} + +// Test cloning decoration for an id that is decorated via a group decoration. +TEST_F(DecorationManagerTest, CloneSomeGroupDecorations) { + const std::string spirv = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 RelaxedPrecision +OpDecorate %1 Restrict +%1 = OpDecorationGroup +OpGroupDecorate %1 %2 +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Function %3 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpFunction %5 None %6 +%8 = OpLabel +%2 = OpVariable %4 Function +%9 = OpUndef %3 +OpReturn +OpFunctionEnd +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_EQ(GetErrorMessage(), ""); + + // Check cloning OpDecorate including group decorations. 
+ auto decorations = decoManager->GetDecorationsFor(9u, false); + EXPECT_EQ(GetErrorMessage(), ""); + EXPECT_TRUE(decorations.empty()); + + decoManager->CloneDecorations(2u, 9u, {SpvDecorationRelaxedPrecision}); + decorations = decoManager->GetDecorationsFor(9u, false); + EXPECT_THAT(GetErrorMessage(), ""); + + std::string expected_decorations = + R"(OpDecorate %9 RelaxedPrecision +)"; + EXPECT_EQ(ToText(decorations), expected_decorations); + + const std::string expected_binary = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 RelaxedPrecision +OpDecorate %1 Restrict +%1 = OpDecorationGroup +OpGroupDecorate %1 %2 +OpDecorate %9 RelaxedPrecision +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Function %3 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpFunction %5 None %6 +%8 = OpLabel +%2 = OpVariable %4 Function +%9 = OpUndef %3 +OpReturn +OpFunctionEnd +)"; + EXPECT_EQ(ModuleToText(), expected_binary); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsWithoutGroupsTrue) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Restrict +OpDecorate %2 Constant +OpDecorate %2 Restrict +OpDecorate %1 Constant +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsWithoutGroupsFalse) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Restrict +OpDecorate %2 Constant +OpDecorate %2 Restrict +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + 
+TEST_F(DecorationManagerTest, HaveTheSameDecorationsIdWithoutGroupsTrue) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorateId %1 AlignmentId %nine +OpDecorateId %3 MaxByteOffsetId %nine +OpDecorateId %3 AlignmentId %nine +OpDecorateId %1 MaxByteOffsetId %nine +%u32 = OpTypeInt 32 0 +%nine = OpConstant %u32 9 +%1 = OpVariable %u32 Uniform +%3 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->HaveTheSameDecorations(1u, 3u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsIdWithoutGroupsFalse) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorateId %1 AlignmentId %nine +OpDecorateId %2 MaxByteOffsetId %nine +OpDecorateId %2 AlignmentId %nine +%u32 = OpTypeInt 32 0 +%nine = OpConstant %u32 9 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsStringWithoutGroupsTrue) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "world" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "world" +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, 
HaveTheSameDecorationsStringWithoutGroupsFalse) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "world" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "hello" +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsWithGroupsTrue) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Restrict +OpDecorate %2 Constant +OpDecorate %1 Constant +OpDecorate %3 Restrict +%3 = OpDecorationGroup +OpGroupDecorate %3 %2 +OpDecorate %4 Invariant +%4 = OpDecorationGroup +OpGroupDecorate %4 %1 %2 +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsWithGroupsFalse) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Restrict +OpDecorate %2 Constant +OpDecorate %1 Constant +OpDecorate %4 Invariant +%4 = OpDecorationGroup +OpGroupDecorate %4 %1 %2 +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsDuplicateDecorations) { + const std::string spirv = 
R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %2 Constant +OpDecorate %2 Constant +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsDifferentVariations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Location 0 +OpDecorate %2 Location 1 +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, + HaveTheSameDecorationsDuplicateMemberDecorations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpMemberDecorate %1 0 Location 0 +OpMemberDecorate %2 0 Location 0 +OpMemberDecorate %2 0 Location 0 +%u32 = OpTypeInt 32 0 +%1 = OpTypeStruct %u32 %u32 +%2 = OpTypeStruct %u32 %u32 +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, + HaveTheSameDecorationsDifferentMemberSameDecoration) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpMemberDecorate %1 0 Location 0 +OpMemberDecorate %2 1 Location 0 +%u32 = OpTypeInt 32 0 +%1 = OpTypeStruct %u32 %u32 +%2 = OpTypeStruct %u32 %u32 +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, 
HaveTheSameDecorationsDifferentMemberVariations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpMemberDecorate %1 0 Location 0 +OpMemberDecorate %2 0 Location 1 +%u32 = OpTypeInt 32 0 +%1 = OpTypeStruct %u32 %u32 +%2 = OpTypeStruct %u32 %u32 +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsDuplicateIdDecorations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorateId %1 AlignmentId %2 +OpDecorateId %3 AlignmentId %2 +OpDecorateId %3 AlignmentId %2 +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%3 = OpVariable %u32 Uniform +%2 = OpSpecConstant %u32 0 +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->HaveTheSameDecorations(1u, 3u)); +} + +TEST_F(DecorationManagerTest, + HaveTheSameDecorationsDuplicateStringDecorations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "hello" +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_TRUE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsDifferentIdVariations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorateId %1 AlignmentId %2 +OpDecorateId %3 AlignmentId %4 +%u32 = 
OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%3 = OpVariable %u32 Uniform +%2 = OpSpecConstant %u32 0 +%4 = OpSpecConstant %u32 0 +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsDifferentStringVariations) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "world" +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsLeftSymmetry) { + // Left being a subset of right is not enough. + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %1 Constant +OpDecorate %2 Constant +OpDecorate %2 Restrict +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationsRightSymmetry) { + // Right being a subset of left is not enough. 
+ const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %1 Restrict +OpDecorate %2 Constant +OpDecorate %2 Constant +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationIdsLeftSymmetry) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorateId %1 AlignmentId %nine +OpDecorateId %1 AlignmentId %nine +OpDecorateId %2 AlignmentId %nine +OpDecorateId %2 MaxByteOffsetId %nine +%u32 = OpTypeInt 32 0 +%nine = OpConstant %u32 9 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationIdsRightSymmetry) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorateId %1 AlignmentId %nine +OpDecorateId %1 MaxByteOffsetId %nine +OpDecorateId %2 AlignmentId %nine +OpDecorateId %2 AlignmentId %nine +%u32 = OpTypeInt 32 0 +%nine = OpConstant %u32 9 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationStringsLeftSymmetry) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "hello" 
+OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "world" +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +TEST_F(DecorationManagerTest, HaveTheSameDecorationStringsRightSymmetry) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpExtension "SPV_GOOGLE_decorate_string" +OpMemoryModel Logical GLSL450 +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE "world" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "hello" +OpDecorateStringGOOGLE %2 HlslSemanticGOOGLE "hello" +%u32 = OpTypeInt 32 0 +%1 = OpVariable %u32 Uniform +%2 = OpVariable %u32 Uniform +)"; + DecorationManager* decoManager = GetDecorationManager(spirv); + EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_FALSE(decoManager->HaveTheSameDecorations(1u, 2u)); +} + +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/test/opt/def_use_test.cpp b/test/opt/def_use_test.cpp new file mode 100644 index 000000000..3b856ce7f --- /dev/null +++ b/test/opt/def_use_test.cpp @@ -0,0 +1,1719 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "spirv-tools/libspirv.hpp" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +using ::testing::Contains; +using ::testing::UnorderedElementsAre; +using ::testing::UnorderedElementsAreArray; + +// Returns the number of uses of |id|. +uint32_t NumUses(const std::unique_ptr& context, uint32_t id) { + uint32_t count = 0; + context->get_def_use_mgr()->ForEachUse( + id, [&count](Instruction*, uint32_t) { ++count; }); + return count; +} + +// Returns the opcode of each use of |id|. +// +// If |id| is used multiple times in a single instruction, that instruction's +// opcode will appear a corresponding number of times. +std::vector GetUseOpcodes(const std::unique_ptr& context, + uint32_t id) { + std::vector opcodes; + context->get_def_use_mgr()->ForEachUse( + id, [&opcodes](Instruction* user, uint32_t) { + opcodes.push_back(user->opcode()); + }); + return opcodes; +} + +// Disassembles the given |inst| and returns the disassembly. +std::string DisassembleInst(Instruction* inst) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_1); + + std::vector binary; + // We need this to generate the necessary header in the binary. + tools.Assemble("", &binary); + inst->ToBinaryWithoutAttachedDebugInsts(&binary); + + std::string text; + // We'll need to check the underlying id numbers. + // So turn off friendly names for ids. 
+ tools.Disassemble(binary, &text, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + while (!text.empty() && text.back() == '\n') text.pop_back(); + return text; +} + +// A struct for holding expected id defs and uses. +struct InstDefUse { + using IdInstPair = std::pair; + using IdInstsPair = std::pair>; + + // Ids and their corresponding def instructions. + std::vector defs; + // Ids and their corresponding use instructions. + std::vector uses; +}; + +// Checks that the |actual_defs| and |actual_uses| are in accord with +// |expected_defs_uses|. +void CheckDef(const InstDefUse& expected_defs_uses, + const DefUseManager::IdToDefMap& actual_defs) { + // Check defs. + ASSERT_EQ(expected_defs_uses.defs.size(), actual_defs.size()); + for (uint32_t i = 0; i < expected_defs_uses.defs.size(); ++i) { + const auto id = expected_defs_uses.defs[i].first; + const auto expected_def = expected_defs_uses.defs[i].second; + ASSERT_EQ(1u, actual_defs.count(id)) << "expected to def id [" << id << "]"; + auto def = actual_defs.at(id); + if (def->opcode() != SpvOpConstant) { + // Constants don't disassemble properly without a full context. + EXPECT_EQ(expected_def, DisassembleInst(actual_defs.at(id))); + } + } +} + +using UserMap = std::unordered_map>; + +// Creates a mapping of all definitions to their users (except OpConstant). +// +// OpConstants are skipped because they cannot be disassembled in isolation. +UserMap BuildAllUsers(const DefUseManager* mgr, uint32_t idBound) { + UserMap userMap; + for (uint32_t id = 0; id != idBound; ++id) { + if (mgr->GetDef(id)) { + mgr->ForEachUser(id, [id, &userMap](Instruction* user) { + if (user->opcode() != SpvOpConstant) { + userMap[id].push_back(user); + } + }); + } + } + return userMap; +} + +// Constants don't disassemble properly without a full context, so skip them as +// checks. +void CheckUse(const InstDefUse& expected_defs_uses, const DefUseManager* mgr, + uint32_t idBound) { + UserMap actual_uses = BuildAllUsers(mgr, idBound); + // Check uses. 
+ ASSERT_EQ(expected_defs_uses.uses.size(), actual_uses.size()); + for (uint32_t i = 0; i < expected_defs_uses.uses.size(); ++i) { + const auto id = expected_defs_uses.uses[i].first; + const auto& expected_uses = expected_defs_uses.uses[i].second; + + ASSERT_EQ(1u, actual_uses.count(id)) << "expected to use id [" << id << "]"; + const auto& uses = actual_uses.at(id); + + ASSERT_EQ(expected_uses.size(), uses.size()) + << "id [" << id << "] # uses: expected: " << expected_uses.size() + << " actual: " << uses.size(); + + std::vector actual_uses_disassembled; + for (const auto actual_use : uses) { + actual_uses_disassembled.emplace_back(DisassembleInst(actual_use)); + } + EXPECT_THAT(actual_uses_disassembled, + UnorderedElementsAreArray(expected_uses)); + } +} + +// The following test case mimics how LLVM handles induction variables. +// But, yeah, it's not very readable. However, we only care about the id +// defs and uses. So, no need to make sure this is valid OpPhi construct. +const char kOpPhiTestFunction[] = + " %1 = OpTypeVoid " + " %6 = OpTypeInt 32 0 " + "%10 = OpTypeFloat 32 " + "%16 = OpTypeBool " + " %3 = OpTypeFunction %1 " + " %8 = OpConstant %6 0 " + "%18 = OpConstant %6 1 " + "%12 = OpConstant %10 1.0 " + " %2 = OpFunction %1 None %3 " + " %4 = OpLabel " + " OpBranch %5 " + + " %5 = OpLabel " + " %7 = OpPhi %6 %8 %4 %9 %5 " + "%11 = OpPhi %10 %12 %4 %13 %5 " + " %9 = OpIAdd %6 %7 %8 " + "%13 = OpFAdd %10 %11 %12 " + "%17 = OpSLessThan %16 %7 %18 " + " OpLoopMerge %19 %5 None " + " OpBranchConditional %17 %5 %19 " + + "%19 = OpLabel " + " OpReturn " + " OpFunctionEnd"; + +struct ParseDefUseCase { + const char* text; + InstDefUse du; +}; + +using ParseDefUseTest = ::testing::TestWithParam; + +TEST_P(ParseDefUseTest, Case) { + const auto& tc = GetParam(); + + // Build module. 
+ const std::vector text = {tc.text}; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text), + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Analyze def and use. + DefUseManager manager(context->module()); + + CheckDef(tc.du, manager.id_to_defs()); + CheckUse(tc.du, &manager, context->module()->IdBound()); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TestCase, ParseDefUseTest, + ::testing::ValuesIn(std::vector{ + {"", {{}, {}}}, // no instruction + {"OpMemoryModel Logical GLSL450", {{}, {}}}, // no def and use + { // single def, no use + "%1 = OpString \"wow\"", + { + {{1, "%1 = OpString \"wow\""}}, // defs + {} // uses + } + }, + { // multiple def, no use + "%1 = OpString \"hello\" " + "%2 = OpString \"world\" " + "%3 = OpTypeVoid", + { + { // defs + {1, "%1 = OpString \"hello\""}, + {2, "%2 = OpString \"world\""}, + {3, "%3 = OpTypeVoid"}, + }, + {} // uses + } + }, + { // multiple def, multiple use + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 3 " + "%3 = OpTypeMatrix %2 3", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpTypeVector %1 3"}, + {3, "%3 = OpTypeMatrix %2 3"}, + }, + { // uses + {1, {"%2 = OpTypeVector %1 3"}}, + {2, {"%3 = OpTypeMatrix %2 3"}}, + } + } + }, + { // multiple use of the same id + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 2 " + "%3 = OpTypeVector %1 3 " + "%4 = OpTypeVector %1 4", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpTypeVector %1 2"}, + {3, "%3 = OpTypeVector %1 3"}, + {4, "%4 = OpTypeVector %1 4"}, + }, + { // uses + {1, + { + "%2 = OpTypeVector %1 2", + "%3 = OpTypeVector %1 3", + "%4 = OpTypeVector %1 4", + } + }, + } + } + }, + { // labels + "%1 = OpTypeVoid " + "%2 = OpTypeBool " + "%3 = OpTypeFunction %1 " + "%4 = OpConstantTrue %2 " + "%5 = OpFunction %1 None %3 " + + "%6 = OpLabel " + "OpBranchConditional %4 %7 %8 " + + "%7 = OpLabel " + "OpBranch %7 " + + "%8 = OpLabel " + "OpReturn " + + "OpFunctionEnd", + { + 
{ // defs + {1, "%1 = OpTypeVoid"}, + {2, "%2 = OpTypeBool"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpConstantTrue %2"}, + {5, "%5 = OpFunction %1 None %3"}, + {6, "%6 = OpLabel"}, + {7, "%7 = OpLabel"}, + {8, "%8 = OpLabel"}, + }, + { // uses + {1, { + "%3 = OpTypeFunction %1", + "%5 = OpFunction %1 None %3", + } + }, + {2, {"%4 = OpConstantTrue %2"}}, + {3, {"%5 = OpFunction %1 None %3"}}, + {4, {"OpBranchConditional %4 %7 %8"}}, + {7, + { + "OpBranchConditional %4 %7 %8", + "OpBranch %7", + } + }, + {8, {"OpBranchConditional %4 %7 %8"}}, + } + } + }, + { // cross function + "%1 = OpTypeBool " + "%3 = OpTypeFunction %1 " + "%2 = OpFunction %1 None %3 " + + "%4 = OpLabel " + "%5 = OpVariable %1 Function " + "%6 = OpFunctionCall %1 %2 %5 " + "OpReturnValue %6 " + + "OpFunctionEnd", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpFunction %1 None %3"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpLabel"}, + {5, "%5 = OpVariable %1 Function"}, + {6, "%6 = OpFunctionCall %1 %2 %5"}, + }, + { // uses + {1, + { + "%2 = OpFunction %1 None %3", + "%3 = OpTypeFunction %1", + "%5 = OpVariable %1 Function", + "%6 = OpFunctionCall %1 %2 %5", + } + }, + {2, {"%6 = OpFunctionCall %1 %2 %5"}}, + {3, {"%2 = OpFunction %1 None %3"}}, + {5, {"%6 = OpFunctionCall %1 %2 %5"}}, + {6, {"OpReturnValue %6"}}, + } + } + }, + { // selection merge and loop merge + "%1 = OpTypeVoid " + "%3 = OpTypeFunction %1 " + "%10 = OpTypeBool " + "%8 = OpConstantTrue %10 " + "%2 = OpFunction %1 None %3 " + + "%4 = OpLabel " + "OpLoopMerge %5 %4 None " + "OpBranch %6 " + + "%5 = OpLabel " + "OpReturn " + + "%6 = OpLabel " + "OpSelectionMerge %7 None " + "OpBranchConditional %8 %9 %7 " + + "%7 = OpLabel " + "OpReturn " + + "%9 = OpLabel " + "OpReturn " + + "OpFunctionEnd", + { + { // defs + {1, "%1 = OpTypeVoid"}, + {2, "%2 = OpFunction %1 None %3"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpLabel"}, + {5, "%5 = OpLabel"}, + {6, "%6 = OpLabel"}, + {7, "%7 = OpLabel"}, + {8, "%8 = 
OpConstantTrue %10"}, + {9, "%9 = OpLabel"}, + {10, "%10 = OpTypeBool"}, + }, + { // uses + {1, + { + "%2 = OpFunction %1 None %3", + "%3 = OpTypeFunction %1", + } + }, + {3, {"%2 = OpFunction %1 None %3"}}, + {4, {"OpLoopMerge %5 %4 None"}}, + {5, {"OpLoopMerge %5 %4 None"}}, + {6, {"OpBranch %6"}}, + {7, + { + "OpSelectionMerge %7 None", + "OpBranchConditional %8 %9 %7", + } + }, + {8, {"OpBranchConditional %8 %9 %7"}}, + {9, {"OpBranchConditional %8 %9 %7"}}, + {10, {"%8 = OpConstantTrue %10"}}, + } + } + }, + { // Forward reference + "OpDecorate %1 Block " + "OpTypeForwardPointer %2 Input " + "%3 = OpTypeInt 32 0 " + "%1 = OpTypeStruct %3 " + "%2 = OpTypePointer Input %3", + { + { // defs + {1, "%1 = OpTypeStruct %3"}, + {2, "%2 = OpTypePointer Input %3"}, + {3, "%3 = OpTypeInt 32 0"}, + }, + { // uses + {1, {"OpDecorate %1 Block"}}, + {2, {"OpTypeForwardPointer %2 Input"}}, + {3, + { + "%1 = OpTypeStruct %3", + "%2 = OpTypePointer Input %3", + } + } + }, + }, + }, + { // OpPhi + kOpPhiTestFunction, + { + { // defs + {1, "%1 = OpTypeVoid"}, + {2, "%2 = OpFunction %1 None %3"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpLabel"}, + {5, "%5 = OpLabel"}, + {6, "%6 = OpTypeInt 32 0"}, + {7, "%7 = OpPhi %6 %8 %4 %9 %5"}, + {8, "%8 = OpConstant %6 0"}, + {9, "%9 = OpIAdd %6 %7 %8"}, + {10, "%10 = OpTypeFloat 32"}, + {11, "%11 = OpPhi %10 %12 %4 %13 %5"}, + {12, "%12 = OpConstant %10 1.0"}, + {13, "%13 = OpFAdd %10 %11 %12"}, + {16, "%16 = OpTypeBool"}, + {17, "%17 = OpSLessThan %16 %7 %18"}, + {18, "%18 = OpConstant %6 1"}, + {19, "%19 = OpLabel"}, + }, + { // uses + {1, + { + "%2 = OpFunction %1 None %3", + "%3 = OpTypeFunction %1", + } + }, + {3, {"%2 = OpFunction %1 None %3"}}, + {4, + { + "%7 = OpPhi %6 %8 %4 %9 %5", + "%11 = OpPhi %10 %12 %4 %13 %5", + } + }, + {5, + { + "OpBranch %5", + "%7 = OpPhi %6 %8 %4 %9 %5", + "%11 = OpPhi %10 %12 %4 %13 %5", + "OpLoopMerge %19 %5 None", + "OpBranchConditional %17 %5 %19", + } + }, + {6, + { + // Can't check constants 
properly + // "%8 = OpConstant %6 0", + // "%18 = OpConstant %6 1", + "%7 = OpPhi %6 %8 %4 %9 %5", + "%9 = OpIAdd %6 %7 %8", + } + }, + {7, + { + "%9 = OpIAdd %6 %7 %8", + "%17 = OpSLessThan %16 %7 %18", + } + }, + {8, + { + "%7 = OpPhi %6 %8 %4 %9 %5", + "%9 = OpIAdd %6 %7 %8", + } + }, + {9, {"%7 = OpPhi %6 %8 %4 %9 %5"}}, + {10, + { + // "%12 = OpConstant %10 1.0", + "%11 = OpPhi %10 %12 %4 %13 %5", + "%13 = OpFAdd %10 %11 %12", + } + }, + {11, {"%13 = OpFAdd %10 %11 %12"}}, + {12, + { + "%11 = OpPhi %10 %12 %4 %13 %5", + "%13 = OpFAdd %10 %11 %12", + } + }, + {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}}, + {16, {"%17 = OpSLessThan %16 %7 %18"}}, + {17, {"OpBranchConditional %17 %5 %19"}}, + {18, {"%17 = OpSLessThan %16 %7 %18"}}, + {19, + { + "OpLoopMerge %19 %5 None", + "OpBranchConditional %17 %5 %19", + } + }, + }, + }, + }, + { // OpPhi defining and referencing the same id. + "%1 = OpTypeBool " + "%3 = OpTypeFunction %1 " + "%2 = OpConstantTrue %1 " + "%4 = OpFunction %1 None %3 " + "%6 = OpLabel " + " OpBranch %7 " + "%7 = OpLabel " + "%8 = OpPhi %1 %8 %7 %2 %6 " // both defines and uses %8 + " OpBranch %7 " + " OpFunctionEnd", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpConstantTrue %1"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpFunction %1 None %3"}, + {6, "%6 = OpLabel"}, + {7, "%7 = OpLabel"}, + {8, "%8 = OpPhi %1 %8 %7 %2 %6"}, + }, + { // uses + {1, + { + "%2 = OpConstantTrue %1", + "%3 = OpTypeFunction %1", + "%4 = OpFunction %1 None %3", + "%8 = OpPhi %1 %8 %7 %2 %6", + } + }, + {2, {"%8 = OpPhi %1 %8 %7 %2 %6"}}, + {3, {"%4 = OpFunction %1 None %3"}}, + {6, {"%8 = OpPhi %1 %8 %7 %2 %6"}}, + {7, + { + "OpBranch %7", + "%8 = OpPhi %1 %8 %7 %2 %6", + "OpBranch %7", + } + }, + {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}}, + }, + }, + }, + }) +); +// clang-format on + +struct ReplaceUseCase { + const char* before; + std::vector> candidates; + const char* after; + InstDefUse du; +}; + +using ReplaceUseTest = ::testing::TestWithParam; + +// 
Disassembles the given |module| and returns the disassembly. +std::string DisassembleModule(Module* module) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_1); + + std::vector binary; + module->ToBinary(&binary, /* skip_nop = */ false); + + std::string text; + // We'll need to check the underlying id numbers. + // So turn off friendly names for ids. + tools.Disassemble(binary, &text, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + while (!text.empty() && text.back() == '\n') text.pop_back(); + return text; +} + +TEST_P(ReplaceUseTest, Case) { + const auto& tc = GetParam(); + + // Build module. + const std::vector text = {tc.before}; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text), + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Force a re-build of def-use manager. + context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); + + // Do the substitution. + for (const auto& candidate : tc.candidates) { + context->ReplaceAllUsesWith(candidate.first, candidate.second); + } + + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.du, context->get_def_use_mgr(), context->module()->IdBound()); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TestCase, ReplaceUseTest, + ::testing::ValuesIn(std::vector{ + { // no use, no replace request + "", {}, "", {}, + }, + { // replace one use + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 3 " + "%3 = OpTypeInt 32 0 ", + {{1, 3}}, + "%1 = OpTypeBool\n" + "%2 = OpTypeVector %3 3\n" + "%3 = OpTypeInt 32 0", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpTypeVector %3 3"}, + {3, "%3 = OpTypeInt 32 0"}, + }, + { // uses + {3, {"%2 = OpTypeVector %3 3"}}, + }, + }, + }, + { // replace and then replace back + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 3 " + "%3 = OpTypeInt 32 0", + {{1, 3}, {3, 1}}, + "%1 = OpTypeBool\n" + "%2 = 
OpTypeVector %1 3\n" + "%3 = OpTypeInt 32 0", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpTypeVector %1 3"}, + {3, "%3 = OpTypeInt 32 0"}, + }, + { // uses + {1, {"%2 = OpTypeVector %1 3"}}, + }, + }, + }, + { // replace with the same id + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 3", + {{1, 1}, {2, 2}, {3, 3}}, + "%1 = OpTypeBool\n" + "%2 = OpTypeVector %1 3", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpTypeVector %1 3"}, + }, + { // uses + {1, {"%2 = OpTypeVector %1 3"}}, + }, + }, + }, + { // replace in sequence + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 3 " + "%3 = OpTypeInt 32 0 " + "%4 = OpTypeInt 32 1 ", + {{1, 3}, {3, 4}}, + "%1 = OpTypeBool\n" + "%2 = OpTypeVector %4 3\n" + "%3 = OpTypeInt 32 0\n" + "%4 = OpTypeInt 32 1", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpTypeVector %4 3"}, + {3, "%3 = OpTypeInt 32 0"}, + {4, "%4 = OpTypeInt 32 1"}, + }, + { // uses + {4, {"%2 = OpTypeVector %4 3"}}, + }, + }, + }, + { // replace multiple uses + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 2 " + "%3 = OpTypeVector %1 3 " + "%4 = OpTypeVector %1 4 " + "%5 = OpTypeMatrix %2 2 " + "%6 = OpTypeMatrix %3 3 " + "%7 = OpTypeMatrix %4 4 " + "%8 = OpTypeInt 32 0 " + "%9 = OpTypeInt 32 1 " + "%10 = OpTypeInt 64 0", + {{1, 8}, {2, 9}, {4, 10}}, + "%1 = OpTypeBool\n" + "%2 = OpTypeVector %8 2\n" + "%3 = OpTypeVector %8 3\n" + "%4 = OpTypeVector %8 4\n" + "%5 = OpTypeMatrix %9 2\n" + "%6 = OpTypeMatrix %3 3\n" + "%7 = OpTypeMatrix %10 4\n" + "%8 = OpTypeInt 32 0\n" + "%9 = OpTypeInt 32 1\n" + "%10 = OpTypeInt 64 0", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpTypeVector %8 2"}, + {3, "%3 = OpTypeVector %8 3"}, + {4, "%4 = OpTypeVector %8 4"}, + {5, "%5 = OpTypeMatrix %9 2"}, + {6, "%6 = OpTypeMatrix %3 3"}, + {7, "%7 = OpTypeMatrix %10 4"}, + {8, "%8 = OpTypeInt 32 0"}, + {9, "%9 = OpTypeInt 32 1"}, + {10, "%10 = OpTypeInt 64 0"}, + }, + { // uses + {8, + { + "%2 = OpTypeVector %8 2", + "%3 = OpTypeVector %8 3", + "%4 = 
OpTypeVector %8 4", + } + }, + {9, {"%5 = OpTypeMatrix %9 2"}}, + {3, {"%6 = OpTypeMatrix %3 3"}}, + {10, {"%7 = OpTypeMatrix %10 4"}}, + }, + }, + }, + { // OpPhi. + kOpPhiTestFunction, + // replace one id used by OpPhi, replace one id generated by OpPhi + {{9, 13}, {11, 9}}, + "%1 = OpTypeVoid\n" + "%6 = OpTypeInt 32 0\n" + "%10 = OpTypeFloat 32\n" + "%16 = OpTypeBool\n" + "%3 = OpTypeFunction %1\n" + "%8 = OpConstant %6 0\n" + "%18 = OpConstant %6 1\n" + "%12 = OpConstant %10 1\n" + "%2 = OpFunction %1 None %3\n" + "%4 = OpLabel\n" + "OpBranch %5\n" + + "%5 = OpLabel\n" + "%7 = OpPhi %6 %8 %4 %13 %5\n" // %9 -> %13 + "%11 = OpPhi %10 %12 %4 %13 %5\n" + "%9 = OpIAdd %6 %7 %8\n" + "%13 = OpFAdd %10 %9 %12\n" // %11 -> %9 + "%17 = OpSLessThan %16 %7 %18\n" + "OpLoopMerge %19 %5 None\n" + "OpBranchConditional %17 %5 %19\n" + + "%19 = OpLabel\n" + "OpReturn\n" + "OpFunctionEnd", + { + { // defs. + {1, "%1 = OpTypeVoid"}, + {2, "%2 = OpFunction %1 None %3"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpLabel"}, + {5, "%5 = OpLabel"}, + {6, "%6 = OpTypeInt 32 0"}, + {7, "%7 = OpPhi %6 %8 %4 %13 %5"}, + {8, "%8 = OpConstant %6 0"}, + {9, "%9 = OpIAdd %6 %7 %8"}, + {10, "%10 = OpTypeFloat 32"}, + {11, "%11 = OpPhi %10 %12 %4 %13 %5"}, + {12, "%12 = OpConstant %10 1.0"}, + {13, "%13 = OpFAdd %10 %9 %12"}, + {16, "%16 = OpTypeBool"}, + {17, "%17 = OpSLessThan %16 %7 %18"}, + {18, "%18 = OpConstant %6 1"}, + {19, "%19 = OpLabel"}, + }, + { // uses + {1, + { + "%2 = OpFunction %1 None %3", + "%3 = OpTypeFunction %1", + } + }, + {3, {"%2 = OpFunction %1 None %3"}}, + {4, + { + "%7 = OpPhi %6 %8 %4 %13 %5", + "%11 = OpPhi %10 %12 %4 %13 %5", + } + }, + {5, + { + "OpBranch %5", + "%7 = OpPhi %6 %8 %4 %13 %5", + "%11 = OpPhi %10 %12 %4 %13 %5", + "OpLoopMerge %19 %5 None", + "OpBranchConditional %17 %5 %19", + } + }, + {6, + { + // Can't properly check constants + // "%8 = OpConstant %6 0", + // "%18 = OpConstant %6 1", + "%7 = OpPhi %6 %8 %4 %13 %5", + "%9 = OpIAdd %6 %7 %8" 
+ } + }, + {7, + { + "%9 = OpIAdd %6 %7 %8", + "%17 = OpSLessThan %16 %7 %18", + } + }, + {8, + { + "%7 = OpPhi %6 %8 %4 %13 %5", + "%9 = OpIAdd %6 %7 %8", + } + }, + {9, {"%13 = OpFAdd %10 %9 %12"}}, // uses of %9 changed from %7 to %13 + {10, + { + "%11 = OpPhi %10 %12 %4 %13 %5", + // "%12 = OpConstant %10 1", + "%13 = OpFAdd %10 %9 %12" + } + }, + // no more uses of %11 + {12, + { + "%11 = OpPhi %10 %12 %4 %13 %5", + "%13 = OpFAdd %10 %9 %12" + } + }, + {13, { + "%7 = OpPhi %6 %8 %4 %13 %5", + "%11 = OpPhi %10 %12 %4 %13 %5", + } + }, + {16, {"%17 = OpSLessThan %16 %7 %18"}}, + {17, {"OpBranchConditional %17 %5 %19"}}, + {18, {"%17 = OpSLessThan %16 %7 %18"}}, + {19, + { + "OpLoopMerge %19 %5 None", + "OpBranchConditional %17 %5 %19", + } + }, + }, + }, + }, + { // OpPhi defining and referencing the same id. + "%1 = OpTypeBool " + "%3 = OpTypeFunction %1 " + "%2 = OpConstantTrue %1 " + + "%4 = OpFunction %3 None %1 " + "%6 = OpLabel " + " OpBranch %7 " + "%7 = OpLabel " + "%8 = OpPhi %1 %8 %7 %2 %6 " // both defines and uses %8 + " OpBranch %7 " + " OpFunctionEnd", + {{8, 2}}, + "%1 = OpTypeBool\n" + "%3 = OpTypeFunction %1\n" + "%2 = OpConstantTrue %1\n" + + "%4 = OpFunction %3 None %1\n" + "%6 = OpLabel\n" + "OpBranch %7\n" + "%7 = OpLabel\n" + "%8 = OpPhi %1 %2 %7 %2 %6\n" // use of %8 changed to %2 + "OpBranch %7\n" + "OpFunctionEnd", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpConstantTrue %1"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpFunction %3 None %1"}, + {6, "%6 = OpLabel"}, + {7, "%7 = OpLabel"}, + {8, "%8 = OpPhi %1 %2 %7 %2 %6"}, + }, + { // uses + {1, + { + "%2 = OpConstantTrue %1", + "%3 = OpTypeFunction %1", + "%4 = OpFunction %3 None %1", + "%8 = OpPhi %1 %2 %7 %2 %6", + } + }, + {2, + { + // Only checking users + "%8 = OpPhi %1 %2 %7 %2 %6", + } + }, + {3, {"%4 = OpFunction %3 None %1"}}, + {6, {"%8 = OpPhi %1 %2 %7 %2 %6"}}, + {7, + { + "OpBranch %7", + "%8 = OpPhi %1 %2 %7 %2 %6", + "OpBranch %7", + } + }, + // {8, {"%8 = 
OpPhi %1 %8 %7 %2 %6"}}, + }, + }, + }, + }) +); +// clang-format on + +struct KillDefCase { + const char* before; + std::vector ids_to_kill; + const char* after; + InstDefUse du; +}; + +using KillDefTest = ::testing::TestWithParam; + +TEST_P(KillDefTest, Case) { + const auto& tc = GetParam(); + + // Build module. + const std::vector text = {tc.before}; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text), + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Analyze def and use. + DefUseManager manager(context->module()); + + // Do the substitution. + for (const auto id : tc.ids_to_kill) context->KillDef(id); + + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.du, context->get_def_use_mgr(), context->module()->IdBound()); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TestCase, KillDefTest, + ::testing::ValuesIn(std::vector{ + { // no def, no use, no kill + "", {}, "", {} + }, + { // kill nothing + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 2 " + "%3 = OpTypeVector %1 3 ", + {}, + "%1 = OpTypeBool\n" + "%2 = OpTypeVector %1 2\n" + "%3 = OpTypeVector %1 3", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpTypeVector %1 2"}, + {3, "%3 = OpTypeVector %1 3"}, + }, + { // uses + {1, + { + "%2 = OpTypeVector %1 2", + "%3 = OpTypeVector %1 3", + } + }, + }, + }, + }, + { // kill id used, kill id not used, kill id not defined + "%1 = OpTypeBool " + "%2 = OpTypeVector %1 2 " + "%3 = OpTypeVector %1 3 " + "%4 = OpTypeVector %1 4 " + "%5 = OpTypeMatrix %3 3 " + "%6 = OpTypeMatrix %2 3", + {1, 3, 5, 10}, // ids to kill + "%2 = OpTypeVector %1 2\n" + "%4 = OpTypeVector %1 4\n" + "%6 = OpTypeMatrix %2 3", + { + { // defs + {2, "%2 = OpTypeVector %1 2"}, + {4, "%4 = OpTypeVector %1 4"}, + {6, "%6 = OpTypeMatrix %2 3"}, + }, + { // uses. %1 and %3 are both killed, so no uses + // recorded for them anymore. 
+ {2, {"%6 = OpTypeMatrix %2 3"}}, + } + }, + }, + { // OpPhi. + kOpPhiTestFunction, + {9, 11}, // kill one id used by OpPhi, kill one id generated by OpPhi + "%1 = OpTypeVoid\n" + "%6 = OpTypeInt 32 0\n" + "%10 = OpTypeFloat 32\n" + "%16 = OpTypeBool\n" + "%3 = OpTypeFunction %1\n" + "%8 = OpConstant %6 0\n" + "%18 = OpConstant %6 1\n" + "%12 = OpConstant %10 1\n" + "%2 = OpFunction %1 None %3\n" + "%4 = OpLabel\n" + "OpBranch %5\n" + + "%5 = OpLabel\n" + "%7 = OpPhi %6 %8 %4 %9 %5\n" + "%13 = OpFAdd %10 %11 %12\n" + "%17 = OpSLessThan %16 %7 %18\n" + "OpLoopMerge %19 %5 None\n" + "OpBranchConditional %17 %5 %19\n" + + "%19 = OpLabel\n" + "OpReturn\n" + "OpFunctionEnd", + { + { // defs. %9 & %11 are killed. + {1, "%1 = OpTypeVoid"}, + {2, "%2 = OpFunction %1 None %3"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpLabel"}, + {5, "%5 = OpLabel"}, + {6, "%6 = OpTypeInt 32 0"}, + {7, "%7 = OpPhi %6 %8 %4 %9 %5"}, + {8, "%8 = OpConstant %6 0"}, + {10, "%10 = OpTypeFloat 32"}, + {12, "%12 = OpConstant %10 1.0"}, + {13, "%13 = OpFAdd %10 %11 %12"}, + {16, "%16 = OpTypeBool"}, + {17, "%17 = OpSLessThan %16 %7 %18"}, + {18, "%18 = OpConstant %6 1"}, + {19, "%19 = OpLabel"}, + }, + { // uses + {1, + { + "%2 = OpFunction %1 None %3", + "%3 = OpTypeFunction %1", + } + }, + {3, {"%2 = OpFunction %1 None %3"}}, + {4, + { + "%7 = OpPhi %6 %8 %4 %9 %5", + // "%11 = OpPhi %10 %12 %4 %13 %5", + } + }, + {5, + { + "OpBranch %5", + "%7 = OpPhi %6 %8 %4 %9 %5", + // "%11 = OpPhi %10 %12 %4 %13 %5", + "OpLoopMerge %19 %5 None", + "OpBranchConditional %17 %5 %19", + } + }, + {6, + { + // Can't properly check constants + // "%8 = OpConstant %6 0", + // "%18 = OpConstant %6 1", + "%7 = OpPhi %6 %8 %4 %9 %5", + // "%9 = OpIAdd %6 %7 %8" + } + }, + {7, {"%17 = OpSLessThan %16 %7 %18"}}, + {8, + { + "%7 = OpPhi %6 %8 %4 %9 %5", + // "%9 = OpIAdd %6 %7 %8", + } + }, + // {9, {"%7 = OpPhi %6 %8 %4 %13 %5"}}, + {10, + { + // "%11 = OpPhi %10 %12 %4 %13 %5", + // "%12 = OpConstant %10 1", + 
"%13 = OpFAdd %10 %11 %12" + } + }, + // {11, {"%13 = OpFAdd %10 %11 %12"}}, + {12, + { + // "%11 = OpPhi %10 %12 %4 %13 %5", + "%13 = OpFAdd %10 %11 %12" + } + }, + // {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}}, + {16, {"%17 = OpSLessThan %16 %7 %18"}}, + {17, {"OpBranchConditional %17 %5 %19"}}, + {18, {"%17 = OpSLessThan %16 %7 %18"}}, + {19, + { + "OpLoopMerge %19 %5 None", + "OpBranchConditional %17 %5 %19", + } + }, + }, + }, + }, + { // OpPhi defining and referencing the same id. + "%1 = OpTypeBool " + "%3 = OpTypeFunction %1 " + "%2 = OpConstantTrue %1 " + "%4 = OpFunction %3 None %1 " + "%6 = OpLabel " + " OpBranch %7 " + "%7 = OpLabel " + "%8 = OpPhi %1 %8 %7 %2 %6 " // both defines and uses %8 + " OpBranch %7 " + " OpFunctionEnd", + {8}, + "%1 = OpTypeBool\n" + "%3 = OpTypeFunction %1\n" + "%2 = OpConstantTrue %1\n" + + "%4 = OpFunction %3 None %1\n" + "%6 = OpLabel\n" + "OpBranch %7\n" + "%7 = OpLabel\n" + "OpBranch %7\n" + "OpFunctionEnd", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpConstantTrue %1"}, + {3, "%3 = OpTypeFunction %1"}, + {4, "%4 = OpFunction %3 None %1"}, + {6, "%6 = OpLabel"}, + {7, "%7 = OpLabel"}, + // {8, "%8 = OpPhi %1 %8 %7 %2 %6"}, + }, + { // uses + {1, + { + "%2 = OpConstantTrue %1", + "%3 = OpTypeFunction %1", + "%4 = OpFunction %3 None %1", + // "%8 = OpPhi %1 %8 %7 %2 %6", + } + }, + // {2, {"%8 = OpPhi %1 %8 %7 %2 %6"}}, + {3, {"%4 = OpFunction %3 None %1"}}, + // {6, {"%8 = OpPhi %1 %8 %7 %2 %6"}}, + {7, + { + "OpBranch %7", + // "%8 = OpPhi %1 %8 %7 %2 %6", + "OpBranch %7", + } + }, + // {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}}, + }, + }, + }, + }) +); +// clang-format on + +TEST(DefUseTest, OpSwitch) { + // Because disassembler has basic type check for OpSwitch's selector, we + // cannot use the DisassembleInst() in the above. Thus, this special spotcheck + // test case. 
+ + const char original_text[] = + // int64 f(int64 v) { + // switch (v) { + // case 1: break; + // case -4294967296: break; + // case 9223372036854775807: break; + // default: break; + // } + // return v; + // } + " %1 = OpTypeInt 64 1 " + " %3 = OpTypePointer Input %1 " + " %2 = OpFunction %1 None %3 " // %3 is int64(int64)* + " %4 = OpFunctionParameter %1 " + " %5 = OpLabel " + " %6 = OpLoad %1 %4 " // selector value + " OpSelectionMerge %7 None " + " OpSwitch %6 %8 " + " 1 %9 " // 1 + " -4294967296 %10 " // -2^32 + " 9223372036854775807 %11 " // 2^63-1 + " %8 = OpLabel " // default + " OpBranch %7 " + " %9 = OpLabel " + " OpBranch %7 " + "%10 = OpLabel " + " OpBranch %7 " + "%11 = OpLabel " + " OpBranch %7 " + " %7 = OpLabel " + " OpReturnValue %6 " + " OpFunctionEnd"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Force a re-build of def-use manager. + context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); + + // Do a bunch of replacements. + context->ReplaceAllUsesWith(11, 7); // to existing id + context->ReplaceAllUsesWith(10, 11); // to existing id + context->ReplaceAllUsesWith(9, 10); // to existing id + + // clang-format off + const char modified_text[] = + "%1 = OpTypeInt 64 1\n" + "%3 = OpTypePointer Input %1\n" + "%2 = OpFunction %1 None %3\n" // %3 is int64(int64)* + "%4 = OpFunctionParameter %1\n" + "%5 = OpLabel\n" + "%6 = OpLoad %1 %4\n" // selector value + "OpSelectionMerge %7 None\n" + "OpSwitch %6 %8 1 %10 -4294967296 %11 9223372036854775807 %7\n" // changed!
+ "%8 = OpLabel\n" // default + "OpBranch %7\n" + "%9 = OpLabel\n" + "OpBranch %7\n" + "%10 = OpLabel\n" + "OpBranch %7\n" + "%11 = OpLabel\n" + "OpBranch %7\n" + "%7 = OpLabel\n" + "OpReturnValue %6\n" + "OpFunctionEnd"; + // clang-format on + + EXPECT_EQ(modified_text, DisassembleModule(context->module())); + + InstDefUse def_uses = {}; + def_uses.defs = { + {1, "%1 = OpTypeInt 64 1"}, + {2, "%2 = OpFunction %1 None %3"}, + {3, "%3 = OpTypePointer Input %1"}, + {4, "%4 = OpFunctionParameter %1"}, + {5, "%5 = OpLabel"}, + {6, "%6 = OpLoad %1 %4"}, + {7, "%7 = OpLabel"}, + {8, "%8 = OpLabel"}, + {9, "%9 = OpLabel"}, + {10, "%10 = OpLabel"}, + {11, "%11 = OpLabel"}, + }; + CheckDef(def_uses, context->get_def_use_mgr()->id_to_defs()); + + { + EXPECT_EQ(2u, NumUses(context, 6)); + std::vector opcodes = GetUseOpcodes(context, 6u); + EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSwitch, SpvOpReturnValue)); + } + { + EXPECT_EQ(6u, NumUses(context, 7)); + std::vector opcodes = GetUseOpcodes(context, 7u); + // OpSwitch is now a user of %7. + EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSelectionMerge, SpvOpBranch, + SpvOpBranch, SpvOpBranch, + SpvOpBranch, SpvOpSwitch)); + } + // Check all ids only used by OpSwitch after replacement. + for (const auto id : {8u, 10u, 11u}) { + EXPECT_EQ(1u, NumUses(context, id)); + EXPECT_EQ(SpvOpSwitch, GetUseOpcodes(context, id).back()); + } +} + +// Test case for analyzing individual instructions. +struct AnalyzeInstDefUseTestCase { + const char* module_text; + InstDefUse expected_define_use; +}; + +using AnalyzeInstDefUseTest = + ::testing::TestWithParam; + +// Test the analyzing result for individual instructions. +TEST_P(AnalyzeInstDefUseTest, Case) { + auto tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.module_text); + ASSERT_NE(nullptr, context); + + // Analyze the instructions. 
+ DefUseManager manager(context->module()); + + CheckDef(tc.expected_define_use, manager.id_to_defs()); + CheckUse(tc.expected_define_use, &manager, context->module()->IdBound()); + // CheckUse(tc.expected_define_use, manager.id_to_uses()); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TestCase, AnalyzeInstDefUseTest, + ::testing::ValuesIn(std::vector{ + { // A type declaring instruction. + "%1 = OpTypeInt 32 1", + { + // defs + {{1, "%1 = OpTypeInt 32 1"}}, + {}, // no uses + }, + }, + { // A type declaring instruction and a constant value. + "%1 = OpTypeBool " + "%2 = OpConstantTrue %1", + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpConstantTrue %1"}, + }, + { // uses + {1, {"%2 = OpConstantTrue %1"}}, + }, + }, + }, + })); +// clang-format on + +using AnalyzeInstDefUse = ::testing::Test; + +TEST(AnalyzeInstDefUse, UseWithNoResultId) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + + // Analyze the instructions. + DefUseManager manager(context.module()); + + Instruction label(&context, SpvOpLabel, 0, 2, {}); + manager.AnalyzeInstDefUse(&label); + + Instruction branch(&context, SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {2}}}); + manager.AnalyzeInstDefUse(&branch); + context.module()->SetIdBound(3); + + InstDefUse expected = { + // defs + { + {2, "%2 = OpLabel"}, + }, + // uses + {{2, {"OpBranch %2"}}}, + }; + + CheckDef(expected, manager.id_to_defs()); + CheckUse(expected, &manager, context.module()->IdBound()); +} + +TEST(AnalyzeInstDefUse, AddNewInstruction) { + const std::string input = "%1 = OpTypeBool"; + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input); + ASSERT_NE(nullptr, context); + + // Analyze the instructions. 
+ DefUseManager manager(context->module()); + + Instruction newInst(context.get(), SpvOpConstantTrue, 1, 2, {}); + manager.AnalyzeInstDefUse(&newInst); + + InstDefUse expected = { + { + // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpConstantTrue %1"}, + }, + { + // uses + {1, {"%2 = OpConstantTrue %1"}}, + }, + }; + + CheckDef(expected, manager.id_to_defs()); + CheckUse(expected, &manager, context->module()->IdBound()); +} + +struct KillInstTestCase { + const char* before; + std::unordered_set indices_for_inst_to_kill; + const char* after; + InstDefUse expected_define_use; +}; + +using KillInstTest = ::testing::TestWithParam; + +TEST_P(KillInstTest, Case) { + auto tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Force a re-build of the def-use manager. + context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); + + // KillInst + context->module()->ForEachInst([&tc, &context](Instruction* inst) { + if (tc.indices_for_inst_to_kill.count(inst->result_id())) { + context->KillInst(inst); + } + }); + + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.expected_define_use, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.expected_define_use, context->get_def_use_mgr(), + context->module()->IdBound()); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TestCase, KillInstTest, + ::testing::ValuesIn(std::vector{ + // Kill id defining instructions. 
+ { + "%3 = OpTypeVoid " + "%1 = OpTypeFunction %3 " + "%2 = OpFunction %1 None %3 " + "%4 = OpLabel " + " OpBranch %5 " + "%5 = OpLabel " + " OpBranch %6 " + "%6 = OpLabel " + " OpBranch %4 " + "%7 = OpLabel " + " OpReturn " + " OpFunctionEnd", + {3, 5, 7}, + "%1 = OpTypeFunction %3\n" + "%2 = OpFunction %1 None %3\n" + "%4 = OpLabel\n" + "OpBranch %5\n" + "OpNop\n" + "OpBranch %6\n" + "%6 = OpLabel\n" + "OpBranch %4\n" + "OpNop\n" + "OpReturn\n" + "OpFunctionEnd", + { + // defs + { + {1, "%1 = OpTypeFunction %3"}, + {2, "%2 = OpFunction %1 None %3"}, + {4, "%4 = OpLabel"}, + {6, "%6 = OpLabel"}, + }, + // uses + { + {1, {"%2 = OpFunction %1 None %3"}}, + {4, {"OpBranch %4"}}, + {6, {"OpBranch %6"}}, + } + } + }, + // Kill instructions that do not have result ids. + { + "%3 = OpTypeVoid " + "%1 = OpTypeFunction %3 " + "%2 = OpFunction %1 None %3 " + "%4 = OpLabel " + " OpBranch %5 " + "%5 = OpLabel " + " OpBranch %6 " + "%6 = OpLabel " + " OpBranch %4 " + "%7 = OpLabel " + " OpReturn " + " OpFunctionEnd", + {2, 4}, + "%3 = OpTypeVoid\n" + "%1 = OpTypeFunction %3\n" + "OpNop\n" + "OpNop\n" + "OpBranch %5\n" + "%5 = OpLabel\n" + "OpBranch %6\n" + "%6 = OpLabel\n" + "OpBranch %4\n" + "%7 = OpLabel\n" + "OpReturn\n" + "OpFunctionEnd", + { + // defs + { + {1, "%1 = OpTypeFunction %3"}, + {3, "%3 = OpTypeVoid"}, + {5, "%5 = OpLabel"}, + {6, "%6 = OpLabel"}, + {7, "%7 = OpLabel"}, + }, + // uses + { + {3, {"%1 = OpTypeFunction %3"}}, + {5, {"OpBranch %5"}}, + {6, {"OpBranch %6"}}, + } + } + }, + })); +// clang-format on + +struct GetAnnotationsTestCase { + const char* code; + uint32_t id; + std::vector annotations; +}; + +using GetAnnotationsTest = ::testing::TestWithParam; + +TEST_P(GetAnnotationsTest, Case) { + const GetAnnotationsTestCase& tc = GetParam(); + + // Build module. 
+ std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.code); + ASSERT_NE(nullptr, context); + + // Get annotations + DefUseManager manager(context->module()); + auto insts = manager.GetAnnotations(tc.id); + + // Check + ASSERT_EQ(tc.annotations.size(), insts.size()) + << "wrong number of annotation instructions"; + auto inst_iter = insts.begin(); + for (const std::string& expected_anno_inst : tc.annotations) { + EXPECT_EQ(expected_anno_inst, DisassembleInst(*inst_iter)) + << "annotation instruction mismatch"; + inst_iter++; + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TestCase, GetAnnotationsTest, + ::testing::ValuesIn(std::vector{ + // empty + {"", 0, {}}, + // basic + { + // code + "OpDecorate %1 Block " + "OpDecorate %1 RelaxedPrecision " + "%3 = OpTypeInt 32 0 " + "%1 = OpTypeStruct %3", + // id + 1, + // annotations + { + "OpDecorate %1 Block", + "OpDecorate %1 RelaxedPrecision", + }, + }, + // with debug instructions + { + // code + "OpName %1 \"struct_type\" " + "OpName %3 \"int_type\" " + "OpDecorate %1 Block " + "OpDecorate %1 RelaxedPrecision " + "%3 = OpTypeInt 32 0 " + "%1 = OpTypeStruct %3", + // id + 1, + // annotations + { + "OpDecorate %1 Block", + "OpDecorate %1 RelaxedPrecision", + }, + }, + // no annotations + { + // code + "OpName %1 \"struct_type\" " + "OpName %3 \"int_type\" " + "OpDecorate %1 Block " + "OpDecorate %1 RelaxedPrecision " + "%3 = OpTypeInt 32 0 " + "%1 = OpTypeStruct %3", + // id + 3, + // annotations + {}, + }, + // decoration group + { + // code + "OpDecorate %1 Block " + "OpDecorate %1 RelaxedPrecision " + "%1 = OpDecorationGroup " + "OpGroupDecorate %1 %2 %3 " + "%4 = OpTypeInt 32 0 " + "%2 = OpTypeStruct %4 " + "%3 = OpTypeStruct %4 %4", + // id + 3, + // annotations + { + "OpGroupDecorate %1 %2 %3", + }, + }, + // member decorate + { + // code + "OpMemberDecorate %1 0 RelaxedPrecision " + "%2 = OpTypeInt 32 0 " + "%1 = OpTypeStruct %2 %2", + // id + 1, + // annotations + { +
"OpMemberDecorate %1 0 RelaxedPrecision", + }, + }, + })); + +using UpdateUsesTest = PassTest<::testing::Test>; + +TEST_F(UpdateUsesTest, KeepOldUses) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %main \"main\"", + "%void = OpTypeVoid", + "%4 = OpTypeFunction %void", + "%uint = OpTypeInt 32 0", + "%uint_5 = OpConstant %uint 5", + "%25 = OpConstant %uint 25", + "%main = OpFunction %void None %4", + "%8 = OpLabel", + "%9 = OpIMul %uint %uint_5 %uint_5", + "%10 = OpIMul %uint %9 %uint_5", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text), + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* def = def_use_mgr->GetDef(9); + Instruction* use = def_use_mgr->GetDef(10); + def->SetOpcode(SpvOpCopyObject); + def->SetInOperands({{SPV_OPERAND_TYPE_ID, {25}}}); + context->UpdateDefUse(def); + + auto users = def_use_mgr->id_to_users(); + UserEntry entry = {def, use}; + EXPECT_THAT(users, Contains(entry)); +} +// clang-format on + +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/CMakeLists.txt b/test/opt/dominator_tree/CMakeLists.txt new file mode 100644 index 000000000..813d628a0 --- /dev/null +++ b/test/opt/dominator_tree/CMakeLists.txt @@ -0,0 +1,31 @@ +# Copyright (c) 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +add_spvtools_unittest(TARGET dominator_analysis + SRCS ../function_utils.h + common_dominators.cpp + generated.cpp + nested_ifs.cpp + nested_ifs_post.cpp + nested_loops.cpp + nested_loops_with_unreachables.cpp + post.cpp + simple.cpp + switch_case_fallthrough.cpp + unreachable_for.cpp + unreachable_for_post.cpp + LIBS SPIRV-Tools-opt + PCH_FILE pch_test_opt_dom +) diff --git a/test/opt/dominator_tree/common_dominators.cpp b/test/opt/dominator_tree/common_dominators.cpp new file mode 100644 index 000000000..dfa03e986 --- /dev/null +++ b/test/opt/dominator_tree/common_dominators.cpp @@ -0,0 +1,151 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace { + +using CommonDominatorsTest = ::testing::Test; +// Shared module for every test below; nothing branches to block %11. +const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %3 %4 None +OpBranch %5 +%5 = OpLabel +OpBranchConditional %true %3 %4 +%4 = OpLabel +OpBranch %2 +%3 = OpLabel +OpSelectionMerge %6 None +OpBranchConditional %true %7 %8 +%7 = OpLabel +OpBranch %6 +%8 = OpLabel +OpBranch %9 +%9 = OpLabel +OpBranch %6 +%6 = OpLabel +OpBranch %10 +%11 = OpLabel +OpBranch %10 +%10 = OpLabel +OpReturn +OpFunctionEnd +)"; +// Returns the basic block that holds the instruction defining |id|. +BasicBlock* GetBlock(uint32_t id, std::unique_ptr& context) { + return context->get_instr_block(context->get_def_use_mgr()->GetDef(id)); +} +// ASSERT (not EXPECT) on |context|: each test dereferences it right away. +TEST(CommonDominatorsTest, SameBlock) { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); + + for (auto& block : *context->module()->begin()) { + EXPECT_EQ(&block, analysis->CommonDominator(&block, &block)); + } +} + +TEST(CommonDominatorsTest, ParentAndChild) { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); + + EXPECT_EQ( + GetBlock(1u, context), + analysis->CommonDominator(GetBlock(1u, context), GetBlock(2u, context))); + EXPECT_EQ( + GetBlock(2u, context), + analysis->CommonDominator(GetBlock(2u,
context), GetBlock(5u, context))); + EXPECT_EQ( + GetBlock(1u, context), + analysis->CommonDominator(GetBlock(1u, context), GetBlock(5u, context))); +} + +TEST(CommonDominatorsTest, BranchSplit) { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); + + EXPECT_EQ( + GetBlock(3u, context), + analysis->CommonDominator(GetBlock(7u, context), GetBlock(8u, context))); + EXPECT_EQ( + GetBlock(3u, context), + analysis->CommonDominator(GetBlock(7u, context), GetBlock(9u, context))); +} + +TEST(CommonDominatorsTest, LoopContinueAndMerge) { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); + + EXPECT_EQ( + GetBlock(5u, context), + analysis->CommonDominator(GetBlock(3u, context), GetBlock(4u, context))); +} +// Pairs that include block %11 are expected to have no common dominator. +TEST(CommonDominatorsTest, NoCommonDominator) { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + DominatorAnalysis* analysis = + context->GetDominatorAnalysis(&*context->module()->begin()); + + EXPECT_EQ(nullptr, analysis->CommonDominator(GetBlock(10u, context), + GetBlock(11u, context))); + EXPECT_EQ(nullptr, analysis->CommonDominator(GetBlock(11u, context), + GetBlock(6u, context))); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/generated.cpp b/test/opt/dominator_tree/generated.cpp new file mode 100644 index 000000000..43b723e93 --- /dev/null +++ b/test/opt/dominator_tree/generated.cpp @@ -0,0 +1,900 @@ +// Copyright (c) 2017 Google Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/iterator.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +// Check that x dominates y, and +// if x != y then +// x strictly dominates y and +// y does not dominate x and +// y does not strictly dominate x +// if x == y then +// x does not strictly dominate itself +void check_dominance(const DominatorAnalysisBase& dom_tree, const Function* fn, + uint32_t x, uint32_t y) { + SCOPED_TRACE("Check dominance properties for Basic Block " + + std::to_string(x) + " and " + std::to_string(y)); + EXPECT_TRUE(dom_tree.Dominates(spvtest::GetBasicBlock(fn, x), + spvtest::GetBasicBlock(fn, y))); + EXPECT_TRUE(dom_tree.Dominates(x, y)); + if (x == y) { + EXPECT_FALSE(dom_tree.StrictlyDominates(x, x)); + } else { + EXPECT_TRUE(dom_tree.StrictlyDominates(x, y)); + EXPECT_FALSE(dom_tree.Dominates(y, x)); + EXPECT_FALSE(dom_tree.StrictlyDominates(y, x)); + } +} + +// Check that x does not dominate y and vice versa +void check_no_dominance(const DominatorAnalysisBase& dom_tree, + const Function* fn, uint32_t x,
uint32_t y) { + SCOPED_TRACE("Check no domination for Basic Block " + std::to_string(x) + + " and " + std::to_string(y)); + EXPECT_FALSE(dom_tree.Dominates(spvtest::GetBasicBlock(fn, x), + spvtest::GetBasicBlock(fn, y))); + EXPECT_FALSE(dom_tree.Dominates(x, y)); + EXPECT_FALSE(dom_tree.StrictlyDominates(spvtest::GetBasicBlock(fn, x), + spvtest::GetBasicBlock(fn, y))); + EXPECT_FALSE(dom_tree.StrictlyDominates(x, y)); + + EXPECT_FALSE(dom_tree.Dominates(spvtest::GetBasicBlock(fn, y), + spvtest::GetBasicBlock(fn, x))); + EXPECT_FALSE(dom_tree.Dominates(y, x)); + EXPECT_FALSE(dom_tree.StrictlyDominates(spvtest::GetBasicBlock(fn, y), + spvtest::GetBasicBlock(fn, x))); + EXPECT_FALSE(dom_tree.StrictlyDominates(y, x)); +} + +TEST_F(PassClassTest, DominatorSimpleCFG) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpTypeBool + %5 = OpTypeInt 32 0 + %6 = OpConstant %5 0 + %7 = OpConstantFalse %4 + %8 = OpConstantTrue %4 + %9 = OpConstant %5 1 + %1 = OpFunction %2 None %3 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpSwitch %6 %12 1 %13 + %12 = OpLabel + OpBranch %14 + %13 = OpLabel + OpBranch %14 + %14 = OpLabel + OpBranchConditional %8 %11 %15 + %15 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* fn = spvtest::GetFunction(module, 1); + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); + EXPECT_EQ(entry, fn->entry().get()) + << "The entry node is not the expected one"; + + // Test normal dominator tree + { + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the 
actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); + EXPECT_TRUE( + dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); + + // (strict) dominance checks + for (uint32_t id : {10, 11, 12, 13, 14, 15}) + check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 10, 11); + check_dominance(dom_tree, fn, 10, 12); + check_dominance(dom_tree, fn, 10, 13); + check_dominance(dom_tree, fn, 10, 14); + check_dominance(dom_tree, fn, 10, 15); + + check_dominance(dom_tree, fn, 11, 12); + check_dominance(dom_tree, fn, 11, 13); + check_dominance(dom_tree, fn, 11, 14); + check_dominance(dom_tree, fn, 11, 15); + + check_dominance(dom_tree, fn, 14, 15); + + check_no_dominance(dom_tree, fn, 12, 13); + check_no_dominance(dom_tree, fn, 12, 14); + check_no_dominance(dom_tree, fn, 13, 14); + + // check with some invalid inputs + EXPECT_FALSE(dom_tree.Dominates(nullptr, entry)); + EXPECT_FALSE(dom_tree.Dominates(entry, nullptr)); + EXPECT_FALSE(dom_tree.Dominates(static_cast(nullptr), + static_cast(nullptr))); + EXPECT_FALSE(dom_tree.Dominates(10, 1)); + EXPECT_FALSE(dom_tree.Dominates(1, 10)); + EXPECT_FALSE(dom_tree.Dominates(1, 1)); + + EXPECT_FALSE(dom_tree.StrictlyDominates(nullptr, entry)); + EXPECT_FALSE(dom_tree.StrictlyDominates(entry, nullptr)); + EXPECT_FALSE(dom_tree.StrictlyDominates(nullptr, nullptr)); + EXPECT_FALSE(dom_tree.StrictlyDominates(10, 1)); + EXPECT_FALSE(dom_tree.StrictlyDominates(1, 10)); + EXPECT_FALSE(dom_tree.StrictlyDominates(1, 1)); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_entry_block()), nullptr); + EXPECT_EQ(dom_tree.ImmediateDominator(entry), cfg.pseudo_entry_block()); + EXPECT_EQ(dom_tree.ImmediateDominator(nullptr), nullptr); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 10)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + spvtest::GetBasicBlock(fn, 11)); + 
EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 13)), + spvtest::GetBasicBlock(fn, 11)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 14)), + spvtest::GetBasicBlock(fn, 11)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 15)), + spvtest::GetBasicBlock(fn, 14)); + } + + // Test post dominator tree + { + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); + EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 15)); + + // (strict) dominance checks + for (uint32_t id : {10, 11, 12, 13, 14, 15}) + check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 14, 10); + check_dominance(dom_tree, fn, 14, 11); + check_dominance(dom_tree, fn, 14, 12); + check_dominance(dom_tree, fn, 14, 13); + + check_dominance(dom_tree, fn, 15, 10); + check_dominance(dom_tree, fn, 15, 11); + check_dominance(dom_tree, fn, 15, 12); + check_dominance(dom_tree, fn, 15, 13); + check_dominance(dom_tree, fn, 15, 14); + + check_no_dominance(dom_tree, fn, 13, 12); + check_no_dominance(dom_tree, fn, 12, 11); + check_no_dominance(dom_tree, fn, 13, 11); + + // check with some invalid inputs + EXPECT_FALSE(dom_tree.Dominates(nullptr, entry)); + EXPECT_FALSE(dom_tree.Dominates(entry, nullptr)); + EXPECT_FALSE(dom_tree.Dominates(static_cast(nullptr), + static_cast(nullptr))); + EXPECT_FALSE(dom_tree.Dominates(10, 1)); + EXPECT_FALSE(dom_tree.Dominates(1, 10)); + EXPECT_FALSE(dom_tree.Dominates(1, 1)); + + EXPECT_FALSE(dom_tree.StrictlyDominates(nullptr, entry)); + EXPECT_FALSE(dom_tree.StrictlyDominates(entry, nullptr)); + EXPECT_FALSE(dom_tree.StrictlyDominates(nullptr, nullptr)); + EXPECT_FALSE(dom_tree.StrictlyDominates(10, 1)); + EXPECT_FALSE(dom_tree.StrictlyDominates(1, 10)); + EXPECT_FALSE(dom_tree.StrictlyDominates(1, 1)); + + 
EXPECT_EQ(dom_tree.ImmediateDominator(nullptr), nullptr); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 14)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + spvtest::GetBasicBlock(fn, 14)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 13)), + spvtest::GetBasicBlock(fn, 14)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 14)), + spvtest::GetBasicBlock(fn, 15)); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 15)), + cfg.pseudo_exit_block()); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_exit_block()), nullptr); + } +} + +TEST_F(PassClassTest, DominatorIrreducibleCFG) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpTypeBool + %5 = OpTypeInt 32 0 + %6 = OpConstantFalse %4 + %7 = OpConstantTrue %4 + %1 = OpFunction %2 None %3 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranchConditional %7 %10 %11 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpBranchConditional %7 %10 %12 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* fn = spvtest::GetFunction(module, 1); + + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 8); + EXPECT_EQ(entry, fn->entry().get()) + << "The entry node is not the expected one"; + + // Check normal dominator tree + { + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, 
cfg.pseudo_entry_block()); + EXPECT_TRUE( + dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); + + // (strict) dominance checks + for (uint32_t id : {8, 9, 10, 11, 12}) + check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 8, 9); + check_dominance(dom_tree, fn, 8, 10); + check_dominance(dom_tree, fn, 8, 11); + check_dominance(dom_tree, fn, 8, 12); + + check_dominance(dom_tree, fn, 9, 10); + check_dominance(dom_tree, fn, 9, 11); + check_dominance(dom_tree, fn, 9, 12); + + check_dominance(dom_tree, fn, 11, 12); + + check_no_dominance(dom_tree, fn, 10, 11); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_entry_block()), nullptr); + EXPECT_EQ(dom_tree.ImmediateDominator(entry), cfg.pseudo_entry_block()); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 9)), + spvtest::GetBasicBlock(fn, 8)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 10)), + spvtest::GetBasicBlock(fn, 9)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 9)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + spvtest::GetBasicBlock(fn, 11)); + } + + // Check post dominator tree + { + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); + EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 12)); + + // (strict) dominance checks + for (uint32_t id : {8, 9, 10, 11, 12}) + check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 12, 8); + check_dominance(dom_tree, fn, 12, 10); + check_dominance(dom_tree, fn, 12, 11); + check_dominance(dom_tree, fn, 12, 12); + + check_dominance(dom_tree, fn, 11, 8); + check_dominance(dom_tree, fn, 11, 9); + check_dominance(dom_tree, fn, 11, 10); + + check_dominance(dom_tree, fn, 9, 8); + + 
EXPECT_EQ(dom_tree.ImmediateDominator(entry), + spvtest::GetBasicBlock(fn, 9)); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 9)), + spvtest::GetBasicBlock(fn, 11)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 10)), + spvtest::GetBasicBlock(fn, 11)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 12)); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + cfg.pseudo_exit_block()); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_exit_block()), nullptr); + } +} + +TEST_F(PassClassTest, DominatorLoopToSelf) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpTypeBool + %5 = OpTypeInt 32 0 + %6 = OpConstant %5 0 + %7 = OpConstantFalse %4 + %8 = OpConstantTrue %4 + %9 = OpConstant %5 1 + %1 = OpFunction %2 None %3 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpSwitch %6 %12 1 %11 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* fn = spvtest::GetFunction(module, 1); + + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); + EXPECT_EQ(entry, fn->entry().get()) + << "The entry node is not the expected one"; + + // Check normal dominator tree + { + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); + EXPECT_TRUE( + dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); + + // (strict) dominance checks + for 
(uint32_t id : {10, 11, 12}) check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 10, 11); + check_dominance(dom_tree, fn, 10, 12); + check_dominance(dom_tree, fn, 11, 12); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_entry_block()), nullptr); + EXPECT_EQ(dom_tree.ImmediateDominator(entry), cfg.pseudo_entry_block()); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 10)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + spvtest::GetBasicBlock(fn, 11)); + + std::array node_order = {{10, 11, 12}}; + { + // Test dominator tree iteration order. + DominatorTree::iterator node_it = dom_tree.GetDomTree().begin(); + DominatorTree::iterator node_end = dom_tree.GetDomTree().end(); + for (uint32_t id : node_order) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + { + // Same as above, but with const iterators. + DominatorTree::const_iterator node_it = dom_tree.GetDomTree().cbegin(); + DominatorTree::const_iterator node_end = dom_tree.GetDomTree().cend(); + for (uint32_t id : node_order) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + { + // Test dominator tree iteration order. + DominatorTree::post_iterator node_it = dom_tree.GetDomTree().post_begin(); + DominatorTree::post_iterator node_end = dom_tree.GetDomTree().post_end(); + for (uint32_t id : make_range(node_order.rbegin(), node_order.rend())) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + { + // Same as above, but with const iterators. 
+ DominatorTree::const_post_iterator node_it = + dom_tree.GetDomTree().post_cbegin(); + DominatorTree::const_post_iterator node_end = + dom_tree.GetDomTree().post_cend(); + for (uint32_t id : make_range(node_order.rbegin(), node_order.rend())) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + } + + // Check post dominator tree + { + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); + EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 12)); + + // (strict) dominance checks + for (uint32_t id : {10, 11, 12}) check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 12, 10); + check_dominance(dom_tree, fn, 12, 11); + check_dominance(dom_tree, fn, 12, 12); + + EXPECT_EQ(dom_tree.ImmediateDominator(entry), + spvtest::GetBasicBlock(fn, 11)); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 12)); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_exit_block()), nullptr); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + cfg.pseudo_exit_block()); + + std::array node_order = {{12, 11, 10}}; + { + // Test dominator tree iteration order. + DominatorTree::iterator node_it = tree.begin(); + DominatorTree::iterator node_end = tree.end(); + for (uint32_t id : node_order) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + { + // Same as above, but with const iterators. 
+ DominatorTree::const_iterator node_it = tree.cbegin(); + DominatorTree::const_iterator node_end = tree.cend(); + for (uint32_t id : node_order) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + { + // Test dominator tree iteration order. + DominatorTree::post_iterator node_it = dom_tree.GetDomTree().post_begin(); + DominatorTree::post_iterator node_end = dom_tree.GetDomTree().post_end(); + for (uint32_t id : make_range(node_order.rbegin(), node_order.rend())) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + { + // Same as above, but with const iterators. + DominatorTree::const_post_iterator node_it = + dom_tree.GetDomTree().post_cbegin(); + DominatorTree::const_post_iterator node_end = + dom_tree.GetDomTree().post_cend(); + for (uint32_t id : make_range(node_order.rbegin(), node_order.rend())) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + } +} + +TEST_F(PassClassTest, DominatorUnreachableInLoop) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpTypeBool + %5 = OpTypeInt 32 0 + %6 = OpConstant %5 0 + %7 = OpConstantFalse %4 + %8 = OpConstantTrue %4 + %9 = OpConstant %5 1 + %1 = OpFunction %2 None %3 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpSwitch %6 %12 1 %13 + %12 = OpLabel + OpBranch %14 + %13 = OpLabel + OpUnreachable + %14 = OpLabel + OpBranchConditional %8 %11 %15 + %15 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + 
const Function* fn = spvtest::GetFunction(module, 1); + + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); + EXPECT_EQ(entry, fn->entry().get()) + << "The entry node is not the expected one"; + + // Check normal dominator tree + { + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); + EXPECT_TRUE( + dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); + + // (strict) dominance checks + for (uint32_t id : {10, 11, 12, 13, 14, 15}) + check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 10, 11); + check_dominance(dom_tree, fn, 10, 13); + check_dominance(dom_tree, fn, 10, 12); + check_dominance(dom_tree, fn, 10, 14); + check_dominance(dom_tree, fn, 10, 15); + + check_dominance(dom_tree, fn, 11, 12); + check_dominance(dom_tree, fn, 11, 13); + check_dominance(dom_tree, fn, 11, 14); + check_dominance(dom_tree, fn, 11, 15); + + check_dominance(dom_tree, fn, 12, 14); + check_dominance(dom_tree, fn, 12, 15); + + check_dominance(dom_tree, fn, 14, 15); + + check_no_dominance(dom_tree, fn, 13, 12); + check_no_dominance(dom_tree, fn, 13, 14); + check_no_dominance(dom_tree, fn, 13, 15); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_entry_block()), nullptr); + EXPECT_EQ(dom_tree.ImmediateDominator(entry), cfg.pseudo_entry_block()); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 10)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + spvtest::GetBasicBlock(fn, 11)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 13)), + spvtest::GetBasicBlock(fn, 11)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 14)), + spvtest::GetBasicBlock(fn, 12)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 15)), + 
spvtest::GetBasicBlock(fn, 14)); + } + + // Check post dominator tree. + { + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // (strict) dominance checks. + for (uint32_t id : {10, 11, 12, 13, 14, 15}) + check_dominance(dom_tree, fn, id, id); + + check_no_dominance(dom_tree, fn, 15, 10); + check_no_dominance(dom_tree, fn, 15, 11); + check_no_dominance(dom_tree, fn, 15, 12); + check_no_dominance(dom_tree, fn, 15, 13); + check_no_dominance(dom_tree, fn, 15, 14); + + check_dominance(dom_tree, fn, 14, 12); + + check_no_dominance(dom_tree, fn, 13, 10); + check_no_dominance(dom_tree, fn, 13, 11); + check_no_dominance(dom_tree, fn, 13, 12); + check_no_dominance(dom_tree, fn, 13, 14); + check_no_dominance(dom_tree, fn, 13, 15); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 10)), + spvtest::GetBasicBlock(fn, 11)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + spvtest::GetBasicBlock(fn, 14)); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_exit_block()), nullptr); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 15)), + cfg.pseudo_exit_block()); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 13)), + cfg.pseudo_exit_block()); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 14)), + cfg.pseudo_exit_block()); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + cfg.pseudo_exit_block()); + } +} + +TEST_F(PassClassTest, DominatorInfinitLoop) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpTypeBool + %5 = OpTypeInt 32 0 + %6 = OpConstant %5 0 + %7 = OpConstantFalse %4 + %8 = OpConstantTrue %4 + %9 = OpConstant %5 1 + %1 = OpFunction %2 None %3 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + OpSwitch %6 %12 1 %13 + %12 = OpLabel + OpReturn 
+ %13 = OpLabel + OpBranch %13 + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* fn = spvtest::GetFunction(module, 1); + + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); + EXPECT_EQ(entry, fn->entry().get()) + << "The entry node is not the expected one"; + // Check normal dominator tree + { + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); + EXPECT_TRUE( + dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); + + // (strict) dominance checks + for (uint32_t id : {10, 11, 12, 13}) check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 10, 11); + check_dominance(dom_tree, fn, 10, 12); + check_dominance(dom_tree, fn, 10, 13); + + check_dominance(dom_tree, fn, 11, 12); + check_dominance(dom_tree, fn, 11, 13); + + check_no_dominance(dom_tree, fn, 13, 12); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_entry_block()), nullptr); + EXPECT_EQ(dom_tree.ImmediateDominator(entry), cfg.pseudo_entry_block()); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 10)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + spvtest::GetBasicBlock(fn, 11)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 13)), + spvtest::GetBasicBlock(fn, 11)); + } + + // Check post dominator tree + { + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, 
cfg.pseudo_exit_block()); + EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 12)); + + // (strict) dominance checks + for (uint32_t id : {10, 11, 12}) check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 12, 11); + check_dominance(dom_tree, fn, 12, 10); + + // 13 should be completely out of tree as it's unreachable from exit nodes + check_no_dominance(dom_tree, fn, 12, 13); + check_no_dominance(dom_tree, fn, 11, 13); + check_no_dominance(dom_tree, fn, 10, 13); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_exit_block()), nullptr); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), + cfg.pseudo_exit_block()); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 10)), + spvtest::GetBasicBlock(fn, 11)); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 11)), + spvtest::GetBasicBlock(fn, 12)); + } +} + +TEST_F(PassClassTest, DominatorUnreachableFromEntry) { + const std::string text = R"( + OpCapability Addresses + OpCapability Addresses + OpCapability Kernel + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpTypeBool + %5 = OpTypeInt 32 0 + %6 = OpConstantFalse %4 + %7 = OpConstantTrue %4 + %1 = OpFunction %2 None %3 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpReturn + %10 = OpLabel + OpBranch %9 + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* fn = spvtest::GetFunction(module, 1); + + const BasicBlock* entry = spvtest::GetBasicBlock(fn, 8); + EXPECT_EQ(entry, fn->entry().get()) + << "The entry node is not the expected one"; + + // Check dominator tree + { + DominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + 
dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_entry_block()); + EXPECT_TRUE( + dom_tree.Dominates(cfg.pseudo_entry_block()->id(), entry->id())); + + // (strict) dominance checks + for (uint32_t id : {8, 9}) check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 8, 9); + + check_no_dominance(dom_tree, fn, 10, 8); + check_no_dominance(dom_tree, fn, 10, 9); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_entry_block()), nullptr); + EXPECT_EQ(dom_tree.ImmediateDominator(entry), cfg.pseudo_entry_block()); + + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 9)), + spvtest::GetBasicBlock(fn, 8)); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 10)), + nullptr); + } + + // Check post dominator tree + { + PostDominatorAnalysis dom_tree; + const CFG& cfg = *context->cfg(); + dom_tree.InitializeTree(cfg, fn); + + // Inspect the actual tree + DominatorTree& tree = dom_tree.GetDomTree(); + EXPECT_EQ(tree.GetRoot()->bb_, cfg.pseudo_exit_block()); + EXPECT_TRUE(dom_tree.Dominates(cfg.pseudo_exit_block()->id(), 9)); + + // (strict) dominance checks + for (uint32_t id : {8, 9, 10}) check_dominance(dom_tree, fn, id, id); + + check_dominance(dom_tree, fn, 9, 8); + check_dominance(dom_tree, fn, 9, 10); + + EXPECT_EQ(dom_tree.ImmediateDominator(entry), + spvtest::GetBasicBlock(fn, 9)); + + EXPECT_EQ(dom_tree.ImmediateDominator(cfg.pseudo_exit_block()), nullptr); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 9)), + cfg.pseudo_exit_block()); + EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 10)), + spvtest::GetBasicBlock(fn, 9)); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/nested_ifs.cpp b/test/opt/dominator_tree/nested_ifs.cpp new file mode 100644 index 000000000..0552b7580 --- /dev/null +++ 
b/test/opt/dominator_tree/nested_ifs.cpp @@ -0,0 +1,153 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 v; +void main(){ + if (true) { + if (true) { + v = vec4(1,1,1,1); + } else { + v = vec4(2,2,2,2); + } + } else { + if (true) { + v = vec4(3,3,3,3); + } else { + v = vec4(4,4,4,4); + } + } +} +*/ +TEST_F(PassClassTest, UnreachableNestedIfs) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %15 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 330 + OpName %4 "main" + OpName %15 "v" + OpDecorate %15 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %12 = OpTypeFloat 32 + %13 = OpTypeVector %12 4 + %14 = OpTypePointer Output %13 + %15 = OpVariable %14 Output + %16 = OpConstant %12 1 + %17 = OpConstantComposite %13 %16 %16 %16 %16 + %19 = OpConstant 
%12 2 + %20 = OpConstantComposite %13 %19 %19 %19 %19 + %24 = OpConstant %12 3 + %25 = OpConstantComposite %13 %24 %24 %24 %24 + %27 = OpConstant %12 4 + %28 = OpConstantComposite %13 %27 %27 %27 %27 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %7 %8 %21 + %8 = OpLabel + OpSelectionMerge %11 None + OpBranchConditional %7 %10 %18 + %10 = OpLabel + OpStore %15 %17 + OpBranch %11 + %18 = OpLabel + OpStore %15 %20 + OpBranch %11 + %11 = OpLabel + OpBranch %9 + %21 = OpLabel + OpSelectionMerge %23 None + OpBranchConditional %7 %22 %26 + %22 = OpLabel + OpStore %15 %25 + OpBranch %23 + %26 = OpLabel + OpStore %15 %28 + OpBranch %23 + %23 = OpLabel + OpBranch %9 + %9 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + const Function* f = spvtest::GetFunction(module, 4); + + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); + + EXPECT_TRUE(analysis->Dominates(5, 8)); + EXPECT_TRUE(analysis->Dominates(5, 9)); + EXPECT_TRUE(analysis->Dominates(5, 21)); + EXPECT_TRUE(analysis->Dominates(5, 18)); + EXPECT_TRUE(analysis->Dominates(5, 10)); + EXPECT_TRUE(analysis->Dominates(5, 11)); + EXPECT_TRUE(analysis->Dominates(5, 23)); + EXPECT_TRUE(analysis->Dominates(5, 22)); + EXPECT_TRUE(analysis->Dominates(5, 26)); + EXPECT_TRUE(analysis->Dominates(8, 18)); + EXPECT_TRUE(analysis->Dominates(8, 10)); + EXPECT_TRUE(analysis->Dominates(8, 11)); + EXPECT_TRUE(analysis->Dominates(21, 23)); + EXPECT_TRUE(analysis->Dominates(21, 22)); + EXPECT_TRUE(analysis->Dominates(21, 26)); + + EXPECT_TRUE(analysis->StrictlyDominates(5, 8)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 9)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 21)); + 
EXPECT_TRUE(analysis->StrictlyDominates(8, 18)); + EXPECT_TRUE(analysis->StrictlyDominates(8, 10)); + EXPECT_TRUE(analysis->StrictlyDominates(8, 11)); + EXPECT_TRUE(analysis->StrictlyDominates(21, 23)); + EXPECT_TRUE(analysis->StrictlyDominates(21, 22)); + EXPECT_TRUE(analysis->StrictlyDominates(21, 26)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/nested_ifs_post.cpp b/test/opt/dominator_tree/nested_ifs_post.cpp new file mode 100644 index 000000000..ad759df86 --- /dev/null +++ b/test/opt/dominator_tree/nested_ifs_post.cpp @@ -0,0 +1,156 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 v; +void main(){ + if (true) { + if (true) { + v = vec4(1,1,1,1); + } else { + v = vec4(2,2,2,2); + } + } else { + if (true) { + v = vec4(3,3,3,3); + } else { + v = vec4(4,4,4,4); + } + } +} +*/ +TEST_F(PassClassTest, UnreachableNestedIfs) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %15 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 330 + OpName %4 "main" + OpName %15 "v" + OpDecorate %15 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeBool + %7 = OpConstantTrue %6 + %12 = OpTypeFloat 32 + %13 = OpTypeVector %12 4 + %14 = OpTypePointer Output %13 + %15 = OpVariable %14 Output + %16 = OpConstant %12 1 + %17 = OpConstantComposite %13 %16 %16 %16 %16 + %19 = OpConstant %12 2 + %20 = OpConstantComposite %13 %19 %19 %19 %19 + %24 = OpConstant %12 3 + %25 = OpConstantComposite %13 %24 %24 %24 %24 + %27 = OpConstant %12 4 + %28 = OpConstantComposite %13 %27 %27 %27 %27 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %7 %8 %21 + %8 = OpLabel + OpSelectionMerge %11 None + OpBranchConditional %7 %10 %18 + %10 = OpLabel + OpStore %15 %17 + OpBranch %11 + %18 = OpLabel + OpStore %15 %20 + OpBranch %11 + %11 = OpLabel + OpBranch %9 + %21 = OpLabel + OpSelectionMerge %23 None + OpBranchConditional %7 %22 %26 + %22 = OpLabel + OpStore %15 %25 + OpBranch %23 + %26 = OpLabel + OpStore %15 
%28 + OpBranch %23 + %23 = OpLabel + OpBranch %9 + %9 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + const Function* f = spvtest::GetFunction(module, 4); + + PostDominatorAnalysis* analysis = context->GetPostDominatorAnalysis(f); + + EXPECT_TRUE(analysis->Dominates(5, 5)); + EXPECT_TRUE(analysis->Dominates(8, 8)); + EXPECT_TRUE(analysis->Dominates(9, 9)); + EXPECT_TRUE(analysis->Dominates(10, 10)); + EXPECT_TRUE(analysis->Dominates(11, 11)); + EXPECT_TRUE(analysis->Dominates(18, 18)); + EXPECT_TRUE(analysis->Dominates(21, 21)); + EXPECT_TRUE(analysis->Dominates(22, 22)); + EXPECT_TRUE(analysis->Dominates(23, 23)); + EXPECT_TRUE(analysis->Dominates(26, 26)); + EXPECT_TRUE(analysis->Dominates(9, 5)); + EXPECT_TRUE(analysis->Dominates(9, 11)); + EXPECT_TRUE(analysis->Dominates(9, 23)); + EXPECT_TRUE(analysis->Dominates(11, 10)); + EXPECT_TRUE(analysis->Dominates(11, 18)); + EXPECT_TRUE(analysis->Dominates(11, 8)); + EXPECT_TRUE(analysis->Dominates(23, 22)); + EXPECT_TRUE(analysis->Dominates(23, 26)); + EXPECT_TRUE(analysis->Dominates(23, 21)); + + EXPECT_TRUE(analysis->StrictlyDominates(9, 5)); + EXPECT_TRUE(analysis->StrictlyDominates(9, 11)); + EXPECT_TRUE(analysis->StrictlyDominates(9, 23)); + EXPECT_TRUE(analysis->StrictlyDominates(11, 10)); + EXPECT_TRUE(analysis->StrictlyDominates(11, 18)); + EXPECT_TRUE(analysis->StrictlyDominates(11, 8)); + EXPECT_TRUE(analysis->StrictlyDominates(23, 22)); + EXPECT_TRUE(analysis->StrictlyDominates(23, 26)); + EXPECT_TRUE(analysis->StrictlyDominates(23, 21)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/nested_loops.cpp b/test/opt/dominator_tree/nested_loops.cpp new file mode 100644 index 
000000000..7d03937b1 --- /dev/null +++ b/test/opt/dominator_tree/nested_loops.cpp @@ -0,0 +1,433 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Generated from the following GLSL +#version 440 core +layout(location = 0) out vec4 v; +layout(location = 1) in vec4 in_val; +void main() { + for (int i = 0; i < in_val.x; ++i) { + for (int j = 0; j < in_val.y; j++) { + } + } + for (int i = 0; i < in_val.x; ++i) { + for (int j = 0; j < in_val.y; j++) { + } + if (in_val.z == in_val.w) { + break; + } + } + int i = 0; + while (i < in_val.x) { + ++i; + for (int j = 0; j < 1; j++) { + for (int k = 0; k < 1; k++) { + } + } + } + i = 0; + while (i < in_val.x) { + ++i; + if (in_val.z == in_val.w) { + continue; + } + for (int j = 0; j < 1; j++) { + for (int k = 0; k < 1; k++) { + } + if (in_val.z == in_val.w) { + break; + } + } + } + v = vec4(1,1,1,1); +} +*/ +TEST_F(PassClassTest, BasicVisitFromEntryPoint) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + 
OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %20 %163 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %20 "in_val" + OpName %28 "j" + OpName %45 "i" + OpName %56 "j" + OpName %81 "i" + OpName %94 "j" + OpName %102 "k" + OpName %134 "j" + OpName %142 "k" + OpName %163 "v" + OpDecorate %20 Location 1 + OpDecorate %163 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpTypeFloat 32 + %18 = OpTypeVector %16 4 + %19 = OpTypePointer Input %18 + %20 = OpVariable %19 Input + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 0 + %23 = OpTypePointer Input %16 + %26 = OpTypeBool + %36 = OpConstant %21 1 + %41 = OpConstant %6 1 + %69 = OpConstant %21 2 + %72 = OpConstant %21 3 + %162 = OpTypePointer Output %18 + %163 = OpVariable %162 Output + %164 = OpConstant %16 1 + %165 = OpConstantComposite %18 %164 %164 %164 %164 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %28 = OpVariable %7 Function + %45 = OpVariable %7 Function + %56 = OpVariable %7 Function + %81 = OpVariable %7 Function + %94 = OpVariable %7 Function + %102 = OpVariable %7 Function + %134 = OpVariable %7 Function + %142 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %15 = OpLoad %6 %8 + %17 = OpConvertSToF %16 %15 + %24 = OpAccessChain %23 %20 %22 + %25 = OpLoad %16 %24 + %27 = OpFOrdLessThan %26 %17 %25 + OpBranchConditional %27 %11 %12 + %11 = OpLabel + OpStore %28 %9 + OpBranch %29 + %29 = OpLabel + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %34 = OpLoad %6 %28 + %35 = OpConvertSToF %16 %34 + %37 = OpAccessChain %23 %20 %36 + %38 = OpLoad %16 %37 + %39 = OpFOrdLessThan %26 %35 %38 + OpBranchConditional %39 %30 %31 + %30 = OpLabel + OpBranch %32 + %32 = OpLabel + %40 = OpLoad %6 %28 + %42 = OpIAdd %6 %40 %41 + OpStore %28 %42 + 
OpBranch %29 + %31 = OpLabel + OpBranch %13 + %13 = OpLabel + %43 = OpLoad %6 %8 + %44 = OpIAdd %6 %43 %41 + OpStore %8 %44 + OpBranch %10 + %12 = OpLabel + OpStore %45 %9 + OpBranch %46 + %46 = OpLabel + OpLoopMerge %48 %49 None + OpBranch %50 + %50 = OpLabel + %51 = OpLoad %6 %45 + %52 = OpConvertSToF %16 %51 + %53 = OpAccessChain %23 %20 %22 + %54 = OpLoad %16 %53 + %55 = OpFOrdLessThan %26 %52 %54 + OpBranchConditional %55 %47 %48 + %47 = OpLabel + OpStore %56 %9 + OpBranch %57 + %57 = OpLabel + OpLoopMerge %59 %60 None + OpBranch %61 + %61 = OpLabel + %62 = OpLoad %6 %56 + %63 = OpConvertSToF %16 %62 + %64 = OpAccessChain %23 %20 %36 + %65 = OpLoad %16 %64 + %66 = OpFOrdLessThan %26 %63 %65 + OpBranchConditional %66 %58 %59 + %58 = OpLabel + OpBranch %60 + %60 = OpLabel + %67 = OpLoad %6 %56 + %68 = OpIAdd %6 %67 %41 + OpStore %56 %68 + OpBranch %57 + %59 = OpLabel + %70 = OpAccessChain %23 %20 %69 + %71 = OpLoad %16 %70 + %73 = OpAccessChain %23 %20 %72 + %74 = OpLoad %16 %73 + %75 = OpFOrdEqual %26 %71 %74 + OpSelectionMerge %77 None + OpBranchConditional %75 %76 %77 + %76 = OpLabel + OpBranch %48 + %77 = OpLabel + OpBranch %49 + %49 = OpLabel + %79 = OpLoad %6 %45 + %80 = OpIAdd %6 %79 %41 + OpStore %45 %80 + OpBranch %46 + %48 = OpLabel + OpStore %81 %9 + OpBranch %82 + %82 = OpLabel + OpLoopMerge %84 %85 None + OpBranch %86 + %86 = OpLabel + %87 = OpLoad %6 %81 + %88 = OpConvertSToF %16 %87 + %89 = OpAccessChain %23 %20 %22 + %90 = OpLoad %16 %89 + %91 = OpFOrdLessThan %26 %88 %90 + OpBranchConditional %91 %83 %84 + %83 = OpLabel + %92 = OpLoad %6 %81 + %93 = OpIAdd %6 %92 %41 + OpStore %81 %93 + OpStore %94 %9 + OpBranch %95 + %95 = OpLabel + OpLoopMerge %97 %98 None + OpBranch %99 + %99 = OpLabel + %100 = OpLoad %6 %94 + %101 = OpSLessThan %26 %100 %41 + OpBranchConditional %101 %96 %97 + %96 = OpLabel + OpStore %102 %9 + OpBranch %103 + %103 = OpLabel + OpLoopMerge %105 %106 None + OpBranch %107 + %107 = OpLabel + %108 = OpLoad %6 %102 + %109 = 
OpSLessThan %26 %108 %41 + OpBranchConditional %109 %104 %105 + %104 = OpLabel + OpBranch %106 + %106 = OpLabel + %110 = OpLoad %6 %102 + %111 = OpIAdd %6 %110 %41 + OpStore %102 %111 + OpBranch %103 + %105 = OpLabel + OpBranch %98 + %98 = OpLabel + %112 = OpLoad %6 %94 + %113 = OpIAdd %6 %112 %41 + OpStore %94 %113 + OpBranch %95 + %97 = OpLabel + OpBranch %85 + %85 = OpLabel + OpBranch %82 + %84 = OpLabel + OpStore %81 %9 + OpBranch %114 + %114 = OpLabel + OpLoopMerge %116 %117 None + OpBranch %118 + %118 = OpLabel + %119 = OpLoad %6 %81 + %120 = OpConvertSToF %16 %119 + %121 = OpAccessChain %23 %20 %22 + %122 = OpLoad %16 %121 + %123 = OpFOrdLessThan %26 %120 %122 + OpBranchConditional %123 %115 %116 + %115 = OpLabel + %124 = OpLoad %6 %81 + %125 = OpIAdd %6 %124 %41 + OpStore %81 %125 + %126 = OpAccessChain %23 %20 %69 + %127 = OpLoad %16 %126 + %128 = OpAccessChain %23 %20 %72 + %129 = OpLoad %16 %128 + %130 = OpFOrdEqual %26 %127 %129 + OpSelectionMerge %132 None + OpBranchConditional %130 %131 %132 + %131 = OpLabel + OpBranch %117 + %132 = OpLabel + OpStore %134 %9 + OpBranch %135 + %135 = OpLabel + OpLoopMerge %137 %138 None + OpBranch %139 + %139 = OpLabel + %140 = OpLoad %6 %134 + %141 = OpSLessThan %26 %140 %41 + OpBranchConditional %141 %136 %137 + %136 = OpLabel + OpStore %142 %9 + OpBranch %143 + %143 = OpLabel + OpLoopMerge %145 %146 None + OpBranch %147 + %147 = OpLabel + %148 = OpLoad %6 %142 + %149 = OpSLessThan %26 %148 %41 + OpBranchConditional %149 %144 %145 + %144 = OpLabel + OpBranch %146 + %146 = OpLabel + %150 = OpLoad %6 %142 + %151 = OpIAdd %6 %150 %41 + OpStore %142 %151 + OpBranch %143 + %145 = OpLabel + %152 = OpAccessChain %23 %20 %69 + %153 = OpLoad %16 %152 + %154 = OpAccessChain %23 %20 %72 + %155 = OpLoad %16 %154 + %156 = OpFOrdEqual %26 %153 %155 + OpSelectionMerge %158 None + OpBranchConditional %156 %157 %158 + %157 = OpLabel + OpBranch %137 + %158 = OpLabel + OpBranch %138 + %138 = OpLabel + %160 = OpLoad %6 %134 + %161 = 
OpIAdd %6 %160 %41 + OpStore %134 %161 + OpBranch %135 + %137 = OpLabel + OpBranch %117 + %117 = OpLabel + OpBranch %114 + %116 = OpLabel + OpStore %163 %165 + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + const Function* f = spvtest::GetFunction(module, 4); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); + + EXPECT_TRUE(analysis->Dominates(5, 10)); + EXPECT_TRUE(analysis->Dominates(5, 46)); + EXPECT_TRUE(analysis->Dominates(5, 82)); + EXPECT_TRUE(analysis->Dominates(5, 114)); + EXPECT_TRUE(analysis->Dominates(5, 116)); + + EXPECT_TRUE(analysis->Dominates(10, 14)); + EXPECT_TRUE(analysis->Dominates(10, 11)); + EXPECT_TRUE(analysis->Dominates(10, 29)); + EXPECT_TRUE(analysis->Dominates(10, 33)); + EXPECT_TRUE(analysis->Dominates(10, 30)); + EXPECT_TRUE(analysis->Dominates(10, 32)); + EXPECT_TRUE(analysis->Dominates(10, 31)); + EXPECT_TRUE(analysis->Dominates(10, 13)); + EXPECT_TRUE(analysis->Dominates(10, 12)); + + EXPECT_TRUE(analysis->Dominates(12, 46)); + + EXPECT_TRUE(analysis->Dominates(46, 50)); + EXPECT_TRUE(analysis->Dominates(46, 47)); + EXPECT_TRUE(analysis->Dominates(46, 57)); + EXPECT_TRUE(analysis->Dominates(46, 61)); + EXPECT_TRUE(analysis->Dominates(46, 58)); + EXPECT_TRUE(analysis->Dominates(46, 60)); + EXPECT_TRUE(analysis->Dominates(46, 59)); + EXPECT_TRUE(analysis->Dominates(46, 77)); + EXPECT_TRUE(analysis->Dominates(46, 49)); + EXPECT_TRUE(analysis->Dominates(46, 76)); + EXPECT_TRUE(analysis->Dominates(46, 48)); + + EXPECT_TRUE(analysis->Dominates(48, 82)); + + EXPECT_TRUE(analysis->Dominates(82, 86)); + EXPECT_TRUE(analysis->Dominates(82, 83)); + EXPECT_TRUE(analysis->Dominates(82, 95)); + EXPECT_TRUE(analysis->Dominates(82, 99)); + 
EXPECT_TRUE(analysis->Dominates(82, 96)); + EXPECT_TRUE(analysis->Dominates(82, 103)); + EXPECT_TRUE(analysis->Dominates(82, 107)); + EXPECT_TRUE(analysis->Dominates(82, 104)); + EXPECT_TRUE(analysis->Dominates(82, 106)); + EXPECT_TRUE(analysis->Dominates(82, 105)); + EXPECT_TRUE(analysis->Dominates(82, 98)); + EXPECT_TRUE(analysis->Dominates(82, 97)); + EXPECT_TRUE(analysis->Dominates(82, 85)); + EXPECT_TRUE(analysis->Dominates(82, 84)); + + EXPECT_TRUE(analysis->Dominates(84, 114)); + + EXPECT_TRUE(analysis->Dominates(114, 118)); + EXPECT_TRUE(analysis->Dominates(114, 116)); + EXPECT_TRUE(analysis->Dominates(114, 115)); + EXPECT_TRUE(analysis->Dominates(114, 132)); + EXPECT_TRUE(analysis->Dominates(114, 135)); + EXPECT_TRUE(analysis->Dominates(114, 139)); + EXPECT_TRUE(analysis->Dominates(114, 136)); + EXPECT_TRUE(analysis->Dominates(114, 143)); + EXPECT_TRUE(analysis->Dominates(114, 147)); + EXPECT_TRUE(analysis->Dominates(114, 144)); + EXPECT_TRUE(analysis->Dominates(114, 146)); + EXPECT_TRUE(analysis->Dominates(114, 145)); + EXPECT_TRUE(analysis->Dominates(114, 158)); + EXPECT_TRUE(analysis->Dominates(114, 138)); + EXPECT_TRUE(analysis->Dominates(114, 137)); + EXPECT_TRUE(analysis->Dominates(114, 131)); + EXPECT_TRUE(analysis->Dominates(114, 117)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/nested_loops_with_unreachables.cpp b/test/opt/dominator_tree/nested_loops_with_unreachables.cpp new file mode 100644 index 000000000..e87e8ddab --- /dev/null +++ b/test/opt/dominator_tree/nested_loops_with_unreachables.cpp @@ -0,0 +1,848 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; + +using PassClassTest = PassTest<::testing::Test>; + +// Checks Dominates() and StrictlyDominates() for pairs of basic blocks of +// function %4, each block named by its OpLabel result id. Blocks %49 and %89 +// are never branched to in the assembly and appear in no expectation. +/* +Generated from the following GLSL +#version 440 core +layout(location = 0) out vec4 v; +layout(location = 1) in vec4 in_val; +void main() { + for (int i = 0; i < in_val.x; ++i) { + for (int j = 0; j < in_val.y; j++) { + } + } + for (int i = 0; i < in_val.x; ++i) { + for (int j = 0; j < in_val.y; j++) { + } + break; + } + int i = 0; + while (i < in_val.x) { + ++i; + for (int j = 0; j < 1; j++) { + for (int k = 0; k < 1; k++) { + } + break; + } + } + i = 0; + while (i < in_val.x) { + ++i; + continue; + for (int j = 0; j < 1; j++) { + for (int k = 0; k < 1; k++) { + } + break; + } + } + v = vec4(1,1,1,1); +} +*/ +TEST_F(PassClassTest, BasicVisitFromEntryPoint) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %20 %141 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %20 "in_val" + OpName %28 "j" + OpName %45 "i" + OpName %56 "j" + OpName %72 "i" + OpName %85 "j" + OpName %93 "k" + OpName %119 "j" + OpName %127 "k" + OpName %141 "v" + OpDecorate %20 Location 1 +
OpDecorate %141 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpTypeFloat 32 + %18 = OpTypeVector %16 4 + %19 = OpTypePointer Input %18 + %20 = OpVariable %19 Input + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 0 + %23 = OpTypePointer Input %16 + %26 = OpTypeBool + %36 = OpConstant %21 1 + %41 = OpConstant %6 1 + %140 = OpTypePointer Output %18 + %141 = OpVariable %140 Output + %142 = OpConstant %16 1 + %143 = OpConstantComposite %18 %142 %142 %142 %142 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %28 = OpVariable %7 Function + %45 = OpVariable %7 Function + %56 = OpVariable %7 Function + %72 = OpVariable %7 Function + %85 = OpVariable %7 Function + %93 = OpVariable %7 Function + %119 = OpVariable %7 Function + %127 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %15 = OpLoad %6 %8 + %17 = OpConvertSToF %16 %15 + %24 = OpAccessChain %23 %20 %22 + %25 = OpLoad %16 %24 + %27 = OpFOrdLessThan %26 %17 %25 + OpBranchConditional %27 %11 %12 + %11 = OpLabel + OpStore %28 %9 + OpBranch %29 + %29 = OpLabel + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %34 = OpLoad %6 %28 + %35 = OpConvertSToF %16 %34 + %37 = OpAccessChain %23 %20 %36 + %38 = OpLoad %16 %37 + %39 = OpFOrdLessThan %26 %35 %38 + OpBranchConditional %39 %30 %31 + %30 = OpLabel + OpBranch %32 + %32 = OpLabel + %40 = OpLoad %6 %28 + %42 = OpIAdd %6 %40 %41 + OpStore %28 %42 + OpBranch %29 + %31 = OpLabel + OpBranch %13 + %13 = OpLabel + %43 = OpLoad %6 %8 + %44 = OpIAdd %6 %43 %41 + OpStore %8 %44 + OpBranch %10 + %12 = OpLabel + OpStore %45 %9 + OpBranch %46 + %46 = OpLabel + OpLoopMerge %48 %49 None + OpBranch %50 + %50 = OpLabel + %51 = OpLoad %6 %45 + %52 = OpConvertSToF %16 %51 + %53 = OpAccessChain %23 %20 %22 + %54 = OpLoad %16 %53 + %55 = OpFOrdLessThan %26 %52 %54 + OpBranchConditional
%55 %47 %48 + %47 = OpLabel + OpStore %56 %9 + OpBranch %57 + %57 = OpLabel + OpLoopMerge %59 %60 None + OpBranch %61 + %61 = OpLabel + %62 = OpLoad %6 %56 + %63 = OpConvertSToF %16 %62 + %64 = OpAccessChain %23 %20 %36 + %65 = OpLoad %16 %64 + %66 = OpFOrdLessThan %26 %63 %65 + OpBranchConditional %66 %58 %59 + %58 = OpLabel + OpBranch %60 + %60 = OpLabel + %67 = OpLoad %6 %56 + %68 = OpIAdd %6 %67 %41 + OpStore %56 %68 + OpBranch %57 + %59 = OpLabel + OpBranch %48 + %49 = OpLabel + %70 = OpLoad %6 %45 + %71 = OpIAdd %6 %70 %41 + OpStore %45 %71 + OpBranch %46 + %48 = OpLabel + OpStore %72 %9 + OpBranch %73 + %73 = OpLabel + OpLoopMerge %75 %76 None + OpBranch %77 + %77 = OpLabel + %78 = OpLoad %6 %72 + %79 = OpConvertSToF %16 %78 + %80 = OpAccessChain %23 %20 %22 + %81 = OpLoad %16 %80 + %82 = OpFOrdLessThan %26 %79 %81 + OpBranchConditional %82 %74 %75 + %74 = OpLabel + %83 = OpLoad %6 %72 + %84 = OpIAdd %6 %83 %41 + OpStore %72 %84 + OpStore %85 %9 + OpBranch %86 + %86 = OpLabel + OpLoopMerge %88 %89 None + OpBranch %90 + %90 = OpLabel + %91 = OpLoad %6 %85 + %92 = OpSLessThan %26 %91 %41 + OpBranchConditional %92 %87 %88 + %87 = OpLabel + OpStore %93 %9 + OpBranch %94 + %94 = OpLabel + OpLoopMerge %96 %97 None + OpBranch %98 + %98 = OpLabel + %99 = OpLoad %6 %93 + %100 = OpSLessThan %26 %99 %41 + OpBranchConditional %100 %95 %96 + %95 = OpLabel + OpBranch %97 + %97 = OpLabel + %101 = OpLoad %6 %93 + %102 = OpIAdd %6 %101 %41 + OpStore %93 %102 + OpBranch %94 + %96 = OpLabel + OpBranch %88 + %89 = OpLabel + %104 = OpLoad %6 %85 + %105 = OpIAdd %6 %104 %41 + OpStore %85 %105 + OpBranch %86 + %88 = OpLabel + OpBranch %76 + %76 = OpLabel + OpBranch %73 + %75 = OpLabel + OpStore %72 %9 + OpBranch %106 + %106 = OpLabel + OpLoopMerge %108 %109 None + OpBranch %110 + %110 = OpLabel + %111 = OpLoad %6 %72 + %112 = OpConvertSToF %16 %111 + %113 = OpAccessChain %23 %20 %22 + %114 = OpLoad %16 %113 + %115 = OpFOrdLessThan %26 %112 %114 + OpBranchConditional %115 %107 %108
+ %107 = OpLabel + %116 = OpLoad %6 %72 + %117 = OpIAdd %6 %116 %41 + OpStore %72 %117 + OpBranch %109 + %109 = OpLabel + OpBranch %106 + %108 = OpLabel + OpStore %141 %143 + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + // Fatal checks: a failed assembly must stop the test here rather than + // dereference a null pointer in the statements that follow. + ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n" + << text << std::endl; + Module* module = context->module(); + ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + // %4 is the id given to "main" by OpName in the assembly above. + const Function* f = spvtest::GetFunction(module, 4); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); + + EXPECT_TRUE(analysis->Dominates(5, 10)); + EXPECT_TRUE(analysis->Dominates(5, 14)); + EXPECT_TRUE(analysis->Dominates(5, 11)); + EXPECT_TRUE(analysis->Dominates(5, 29)); + EXPECT_TRUE(analysis->Dominates(5, 33)); + EXPECT_TRUE(analysis->Dominates(5, 30)); + EXPECT_TRUE(analysis->Dominates(5, 32)); + EXPECT_TRUE(analysis->Dominates(5, 31)); + EXPECT_TRUE(analysis->Dominates(5, 13)); + EXPECT_TRUE(analysis->Dominates(5, 12)); + EXPECT_TRUE(analysis->Dominates(5, 46)); + EXPECT_TRUE(analysis->Dominates(5, 50)); + EXPECT_TRUE(analysis->Dominates(5, 47)); + EXPECT_TRUE(analysis->Dominates(5, 57)); + EXPECT_TRUE(analysis->Dominates(5, 61)); + EXPECT_TRUE(analysis->Dominates(5, 59)); + EXPECT_TRUE(analysis->Dominates(5, 58)); + EXPECT_TRUE(analysis->Dominates(5, 60)); + EXPECT_TRUE(analysis->Dominates(5, 48)); + EXPECT_TRUE(analysis->Dominates(5, 73)); + EXPECT_TRUE(analysis->Dominates(5, 77)); + EXPECT_TRUE(analysis->Dominates(5, 75)); + EXPECT_TRUE(analysis->Dominates(5, 106)); + EXPECT_TRUE(analysis->Dominates(5, 110)); + EXPECT_TRUE(analysis->Dominates(5, 107)); + EXPECT_TRUE(analysis->Dominates(5, 108)); + EXPECT_TRUE(analysis->Dominates(5, 109)); + EXPECT_TRUE(analysis->Dominates(5, 74)); + EXPECT_TRUE(analysis->Dominates(5, 86)); + EXPECT_TRUE(analysis->Dominates(5, 90)); + EXPECT_TRUE(analysis->Dominates(5, 87)); + EXPECT_TRUE(analysis->Dominates(5, 94)); +
EXPECT_TRUE(analysis->Dominates(5, 98)); + EXPECT_TRUE(analysis->Dominates(5, 95)); + EXPECT_TRUE(analysis->Dominates(5, 97)); + EXPECT_TRUE(analysis->Dominates(5, 96)); + EXPECT_TRUE(analysis->Dominates(5, 88)); + EXPECT_TRUE(analysis->Dominates(5, 76)); + + EXPECT_TRUE(analysis->Dominates(10, 14)); + EXPECT_TRUE(analysis->Dominates(10, 11)); + EXPECT_TRUE(analysis->Dominates(10, 29)); + EXPECT_TRUE(analysis->Dominates(10, 33)); + EXPECT_TRUE(analysis->Dominates(10, 30)); + EXPECT_TRUE(analysis->Dominates(10, 32)); + EXPECT_TRUE(analysis->Dominates(10, 31)); + EXPECT_TRUE(analysis->Dominates(10, 13)); + EXPECT_TRUE(analysis->Dominates(10, 12)); + EXPECT_TRUE(analysis->Dominates(10, 46)); + EXPECT_TRUE(analysis->Dominates(10, 50)); + EXPECT_TRUE(analysis->Dominates(10, 47)); + EXPECT_TRUE(analysis->Dominates(10, 57)); + EXPECT_TRUE(analysis->Dominates(10, 61)); + EXPECT_TRUE(analysis->Dominates(10, 59)); + EXPECT_TRUE(analysis->Dominates(10, 58)); + EXPECT_TRUE(analysis->Dominates(10, 60)); + EXPECT_TRUE(analysis->Dominates(10, 48)); + EXPECT_TRUE(analysis->Dominates(10, 73)); + EXPECT_TRUE(analysis->Dominates(10, 77)); + EXPECT_TRUE(analysis->Dominates(10, 75)); + EXPECT_TRUE(analysis->Dominates(10, 106)); + EXPECT_TRUE(analysis->Dominates(10, 110)); + EXPECT_TRUE(analysis->Dominates(10, 107)); + EXPECT_TRUE(analysis->Dominates(10, 108)); + EXPECT_TRUE(analysis->Dominates(10, 109)); + EXPECT_TRUE(analysis->Dominates(10, 74)); + EXPECT_TRUE(analysis->Dominates(10, 86)); + EXPECT_TRUE(analysis->Dominates(10, 90)); + EXPECT_TRUE(analysis->Dominates(10, 87)); + EXPECT_TRUE(analysis->Dominates(10, 94)); + EXPECT_TRUE(analysis->Dominates(10, 98)); + EXPECT_TRUE(analysis->Dominates(10, 95)); + EXPECT_TRUE(analysis->Dominates(10, 97)); + EXPECT_TRUE(analysis->Dominates(10, 96)); + EXPECT_TRUE(analysis->Dominates(10, 88)); + EXPECT_TRUE(analysis->Dominates(10, 76)); + + EXPECT_TRUE(analysis->Dominates(14, 11)); + EXPECT_TRUE(analysis->Dominates(14, 29)); +
EXPECT_TRUE(analysis->Dominates(14, 33)); + EXPECT_TRUE(analysis->Dominates(14, 30)); + EXPECT_TRUE(analysis->Dominates(14, 32)); + EXPECT_TRUE(analysis->Dominates(14, 31)); + + EXPECT_TRUE(analysis->Dominates(11, 29)); + EXPECT_TRUE(analysis->Dominates(11, 33)); + EXPECT_TRUE(analysis->Dominates(11, 30)); + EXPECT_TRUE(analysis->Dominates(11, 32)); + EXPECT_TRUE(analysis->Dominates(11, 31)); + + EXPECT_TRUE(analysis->Dominates(29, 33)); + EXPECT_TRUE(analysis->Dominates(29, 30)); + EXPECT_TRUE(analysis->Dominates(29, 32)); + EXPECT_TRUE(analysis->Dominates(29, 31)); + + EXPECT_TRUE(analysis->Dominates(33, 30)); + + EXPECT_TRUE(analysis->Dominates(12, 46)); + EXPECT_TRUE(analysis->Dominates(12, 50)); + EXPECT_TRUE(analysis->Dominates(12, 47)); + EXPECT_TRUE(analysis->Dominates(12, 57)); + EXPECT_TRUE(analysis->Dominates(12, 61)); + EXPECT_TRUE(analysis->Dominates(12, 59)); + EXPECT_TRUE(analysis->Dominates(12, 58)); + EXPECT_TRUE(analysis->Dominates(12, 60)); + EXPECT_TRUE(analysis->Dominates(12, 48)); + EXPECT_TRUE(analysis->Dominates(12, 73)); + EXPECT_TRUE(analysis->Dominates(12, 77)); + EXPECT_TRUE(analysis->Dominates(12, 75)); + EXPECT_TRUE(analysis->Dominates(12, 106)); + EXPECT_TRUE(analysis->Dominates(12, 110)); + EXPECT_TRUE(analysis->Dominates(12, 107)); + EXPECT_TRUE(analysis->Dominates(12, 108)); + EXPECT_TRUE(analysis->Dominates(12, 109)); + EXPECT_TRUE(analysis->Dominates(12, 74)); + EXPECT_TRUE(analysis->Dominates(12, 86)); + EXPECT_TRUE(analysis->Dominates(12, 90)); + EXPECT_TRUE(analysis->Dominates(12, 87)); + EXPECT_TRUE(analysis->Dominates(12, 94)); + EXPECT_TRUE(analysis->Dominates(12, 98)); + EXPECT_TRUE(analysis->Dominates(12, 95)); + EXPECT_TRUE(analysis->Dominates(12, 97)); + EXPECT_TRUE(analysis->Dominates(12, 96)); + EXPECT_TRUE(analysis->Dominates(12, 88)); + EXPECT_TRUE(analysis->Dominates(12, 76)); + + EXPECT_TRUE(analysis->Dominates(46, 50)); + EXPECT_TRUE(analysis->Dominates(46, 47)); + EXPECT_TRUE(analysis->Dominates(46, 57)); +
EXPECT_TRUE(analysis->Dominates(46, 61)); + EXPECT_TRUE(analysis->Dominates(46, 59)); + EXPECT_TRUE(analysis->Dominates(46, 58)); + EXPECT_TRUE(analysis->Dominates(46, 60)); + EXPECT_TRUE(analysis->Dominates(46, 48)); + EXPECT_TRUE(analysis->Dominates(46, 73)); + EXPECT_TRUE(analysis->Dominates(46, 77)); + EXPECT_TRUE(analysis->Dominates(46, 75)); + EXPECT_TRUE(analysis->Dominates(46, 106)); + EXPECT_TRUE(analysis->Dominates(46, 110)); + EXPECT_TRUE(analysis->Dominates(46, 107)); + EXPECT_TRUE(analysis->Dominates(46, 108)); + EXPECT_TRUE(analysis->Dominates(46, 109)); + EXPECT_TRUE(analysis->Dominates(46, 74)); + EXPECT_TRUE(analysis->Dominates(46, 86)); + EXPECT_TRUE(analysis->Dominates(46, 90)); + EXPECT_TRUE(analysis->Dominates(46, 87)); + EXPECT_TRUE(analysis->Dominates(46, 94)); + EXPECT_TRUE(analysis->Dominates(46, 98)); + EXPECT_TRUE(analysis->Dominates(46, 95)); + EXPECT_TRUE(analysis->Dominates(46, 97)); + EXPECT_TRUE(analysis->Dominates(46, 96)); + EXPECT_TRUE(analysis->Dominates(46, 88)); + EXPECT_TRUE(analysis->Dominates(46, 76)); + + EXPECT_TRUE(analysis->Dominates(50, 47)); + EXPECT_TRUE(analysis->Dominates(50, 57)); + EXPECT_TRUE(analysis->Dominates(50, 61)); + EXPECT_TRUE(analysis->Dominates(50, 59)); + EXPECT_TRUE(analysis->Dominates(50, 58)); + EXPECT_TRUE(analysis->Dominates(50, 60)); + + EXPECT_TRUE(analysis->Dominates(47, 57)); + EXPECT_TRUE(analysis->Dominates(47, 61)); + EXPECT_TRUE(analysis->Dominates(47, 59)); + EXPECT_TRUE(analysis->Dominates(47, 58)); + EXPECT_TRUE(analysis->Dominates(47, 60)); + + EXPECT_TRUE(analysis->Dominates(57, 61)); + EXPECT_TRUE(analysis->Dominates(57, 59)); + EXPECT_TRUE(analysis->Dominates(57, 58)); + EXPECT_TRUE(analysis->Dominates(57, 60)); + + EXPECT_TRUE(analysis->Dominates(61, 59)); + + EXPECT_TRUE(analysis->Dominates(48, 73)); + EXPECT_TRUE(analysis->Dominates(48, 77)); + EXPECT_TRUE(analysis->Dominates(48, 75)); + EXPECT_TRUE(analysis->Dominates(48, 106)); + EXPECT_TRUE(analysis->Dominates(48, 110)); +
EXPECT_TRUE(analysis->Dominates(48, 107)); + EXPECT_TRUE(analysis->Dominates(48, 108)); + EXPECT_TRUE(analysis->Dominates(48, 109)); + EXPECT_TRUE(analysis->Dominates(48, 74)); + EXPECT_TRUE(analysis->Dominates(48, 86)); + EXPECT_TRUE(analysis->Dominates(48, 90)); + EXPECT_TRUE(analysis->Dominates(48, 87)); + EXPECT_TRUE(analysis->Dominates(48, 94)); + EXPECT_TRUE(analysis->Dominates(48, 98)); + EXPECT_TRUE(analysis->Dominates(48, 95)); + EXPECT_TRUE(analysis->Dominates(48, 97)); + EXPECT_TRUE(analysis->Dominates(48, 96)); + EXPECT_TRUE(analysis->Dominates(48, 88)); + EXPECT_TRUE(analysis->Dominates(48, 76)); + + EXPECT_TRUE(analysis->Dominates(73, 77)); + EXPECT_TRUE(analysis->Dominates(73, 75)); + EXPECT_TRUE(analysis->Dominates(73, 106)); + EXPECT_TRUE(analysis->Dominates(73, 110)); + EXPECT_TRUE(analysis->Dominates(73, 107)); + EXPECT_TRUE(analysis->Dominates(73, 108)); + EXPECT_TRUE(analysis->Dominates(73, 109)); + EXPECT_TRUE(analysis->Dominates(73, 74)); + EXPECT_TRUE(analysis->Dominates(73, 86)); + EXPECT_TRUE(analysis->Dominates(73, 90)); + EXPECT_TRUE(analysis->Dominates(73, 87)); + EXPECT_TRUE(analysis->Dominates(73, 94)); + EXPECT_TRUE(analysis->Dominates(73, 98)); + E​XPECT_TRUE(analysis->Dominates(73, 95)); + EXPECT_TRUE(analysis->Dominates(73, 97)); + EXPECT_TRUE(analysis->Dominates(73, 96)); + EXPECT_TRUE(analysis->Dominates(73, 88)); + EXPECT_TRUE(analysis->Dominates(73, 76)); + + EXPECT_TRUE(analysis->Dominates(75, 106)); + EXPECT_TRUE(analysis->Dominates(75, 110)); + EXPECT_TRUE(analysis->Dominates(75, 107)); + EXPECT_TRUE(analysis->Dominates(75, 108)); + EXPECT_TRUE(analysis->Dominates(75, 109)); + + EXPECT_TRUE(analysis->Dominates(106, 110)); + EXPECT_TRUE(analysis->Dominates(106, 107)); + EXPECT_TRUE(analysis->Dominates(106, 108)); + EXPECT_TRUE(analysis->Dominates(106, 109)); + + EXPECT_TRUE(analysis->Dominates(110, 107)); + + EXPECT_TRUE(analysis->Dominates(77, 74)); + EXPECT_TRUE(analysis->Dominates(77, 86)); +
EXPECT_TRUE(analysis->Dominates(77, 90)); + EXPECT_TRUE(analysis->Dominates(77, 87)); + EXPECT_TRUE(analysis->Dominates(77, 94)); + EXPECT_TRUE(analysis->Dominates(77, 98)); + EXPECT_TRUE(analysis->Dominates(77, 95)); + EXPECT_TRUE(analysis->Dominates(77, 97)); + EXPECT_TRUE(analysis->Dominates(77, 96)); + EXPECT_TRUE(analysis->Dominates(77, 88)); + + EXPECT_TRUE(analysis->Dominates(74, 86)); + EXPECT_TRUE(analysis->Dominates(74, 90)); + EXPECT_TRUE(analysis->Dominates(74, 87)); + EXPECT_TRUE(analysis->Dominates(74, 94)); + EXPECT_TRUE(analysis->Dominates(74, 98)); + EXPECT_TRUE(analysis->Dominates(74, 95)); + EXPECT_TRUE(analysis->Dominates(74, 97)); + EXPECT_TRUE(analysis->Dominates(74, 96)); + EXPECT_TRUE(analysis->Dominates(74, 88)); + + EXPECT_TRUE(analysis->Dominates(86, 90)); + EXPECT_TRUE(analysis->Dominates(86, 87)); + EXPECT_TRUE(analysis->Dominates(86, 94)); + EXPECT_TRUE(analysis->Dominates(86, 98)); + EXPECT_TRUE(analysis->Dominates(86, 95)); + EXPECT_TRUE(analysis->Dominates(86, 97)); + EXPECT_TRUE(analysis->Dominates(86, 96)); + EXPECT_TRUE(analysis->Dominates(86, 88)); + + EXPECT_TRUE(analysis->Dominates(90, 87)); + EXPECT_TRUE(analysis->Dominates(90, 94)); + EXPECT_TRUE(analysis->Dominates(90, 98)); + EXPECT_TRUE(analysis->Dominates(90, 95)); + EXPECT_TRUE(analysis->Dominates(90, 97)); + EXPECT_TRUE(analysis->Dominates(90, 96)); + + EXPECT_TRUE(analysis->Dominates(87, 94)); + EXPECT_TRUE(analysis->Dominates(87, 98)); + EXPECT_TRUE(analysis->Dominates(87, 95)); + EXPECT_TRUE(analysis->Dominates(87, 97)); + EXPECT_TRUE(analysis->Dominates(87, 96)); + + EXPECT_TRUE(analysis->Dominates(94, 98)); + EXPECT_TRUE(analysis->Dominates(94, 95)); + EXPECT_TRUE(analysis->Dominates(94, 97)); + EXPECT_TRUE(analysis->Dominates(94, 96)); + + EXPECT_TRUE(analysis->Dominates(98, 95)); + + // Strict dominance: every pair below names two distinct blocks. + EXPECT_TRUE(analysis->StrictlyDominates(5, 10)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 14)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 11)); +
EXPECT_TRUE(analysis->StrictlyDominates(5, 29)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 33)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 30)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 32)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 31)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 13)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 12)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 46)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 50)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 47)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 57)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 61)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 59)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 58)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 60)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 48)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 73)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 77)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 75)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 106)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 110)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 107)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 108)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 109)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 74)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 86)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 88)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 76)); + + EXPECT_TRUE(analysis->StrictlyDominates(10, 14)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 11)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 29)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 33)); +
EXPECT_TRUE(analysis->StrictlyDominates(10, 30)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 32)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 31)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 13)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 12)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 46)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 50)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 47)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 57)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 61)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 59)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 58)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 60)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 48)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 73)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 77)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 75)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 106)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 110)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 107)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 108)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 109)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 74)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 86)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 88)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 76)); + + EXPECT_TRUE(analysis->StrictlyDominates(14, 11)); + EXPECT_TRUE(analysis->StrictlyDominates(14, 29)); + EXPECT_TRUE(analysis->StrictlyDominates(14, 33)); + EXPECT_TRUE(analysis->StrictlyDominates(14, 30)); + EXPECT_TRUE(analysis->StrictlyDominates(14, 32)); +
EXPECT_TRUE(analysis->StrictlyDominates(14, 31)); + + EXPECT_TRUE(analysis->StrictlyDominates(11, 29)); + EXPECT_TRUE(analysis->StrictlyDominates(11, 33)); + EXPECT_TRUE(analysis->StrictlyDominates(11, 30)); + EXPECT_TRUE(analysis->StrictlyDominates(11, 32)); + EXPECT_TRUE(analysis->StrictlyDominates(11, 31)); + + EXPECT_TRUE(analysis->StrictlyDominates(29, 33)); + EXPECT_TRUE(analysis->StrictlyDominates(29, 30)); + EXPECT_TRUE(analysis->StrictlyDominates(29, 32)); + EXPECT_TRUE(analysis->StrictlyDominates(29, 31)); + + EXPECT_TRUE(analysis->StrictlyDominates(33, 30)); + + EXPECT_TRUE(analysis->StrictlyDominates(12, 46)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 50)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 47)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 57)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 61)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 59)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 58)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 60)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 48)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 73)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 77)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 75)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 106)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 110)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 107)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 108)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 109)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 74)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 86)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 88)); +
EXPECT_TRUE(analysis->StrictlyDominates(12, 76)); + + EXPECT_TRUE(analysis->StrictlyDominates(46, 50)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 47)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 57)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 61)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 59)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 58)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 60)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 48)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 73)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 77)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 75)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 106)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 110)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 107)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 108)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 109)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 74)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 86)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 88)); + EXPECT_TRUE(analysis->StrictlyDominates(46, 76)); + + EXPECT_TRUE(analysis->StrictlyDominates(50, 47)); + EXPECT_TRUE(analysis->StrictlyDominates(50, 57)); + EXPECT_TRUE(analysis->StrictlyDominates(50, 61)); + EXPECT_TRUE(analysis->StrictlyDominates(50, 59)); + EXPECT_TRUE(analysis->StrictlyDominates(50, 58)); + EXPECT_TRUE(analysis->StrictlyDominates(50, 60)); + + EXPECT_TRUE(analysis->StrictlyDominates(47, 57)); + EXPECT_TRUE(analysis->StrictlyDominates(47, 61)); + EXPECT_TRUE(analysis->StrictlyDominates(47, 59)); + EXPECT_TRUE(analysis->StrictlyDominates(47, 58)); +
EXPECT_TRUE(analysis->StrictlyDominates(47, 60)); + + EXPECT_TRUE(analysis->StrictlyDominates(57, 61)); + EXPECT_TRUE(analysis->StrictlyDominates(57, 59)); + EXPECT_TRUE(analysis->StrictlyDominates(57, 58)); + EXPECT_TRUE(analysis->StrictlyDominates(57, 60)); + + EXPECT_TRUE(analysis->StrictlyDominates(61, 59)); + + EXPECT_TRUE(analysis->StrictlyDominates(48, 73)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 77)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 75)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 106)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 110)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 107)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 108)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 109)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 74)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 86)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 88)); + EXPECT_TRUE(analysis->StrictlyDominates(48, 76)); + + EXPECT_TRUE(analysis->StrictlyDominates(73, 77)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 75)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 106)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 110)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 107)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 108)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 109)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 74)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 86)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 98)); +
EXPECT_TRUE(analysis->StrictlyDominates(73, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 88)); + EXPECT_TRUE(analysis->StrictlyDominates(73, 76)); + + EXPECT_TRUE(analysis->StrictlyDominates(75, 106)); + EXPECT_TRUE(analysis->StrictlyDominates(75, 110)); + EXPECT_TRUE(analysis->StrictlyDominates(75, 107)); + EXPECT_TRUE(analysis->StrictlyDominates(75, 108)); + EXPECT_TRUE(analysis->StrictlyDominates(75, 109)); + + EXPECT_TRUE(analysis->StrictlyDominates(106, 110)); + EXPECT_TRUE(analysis->StrictlyDominates(106, 107)); + EXPECT_TRUE(analysis->StrictlyDominates(106, 108)); + EXPECT_TRUE(analysis->StrictlyDominates(106, 109)); + + EXPECT_TRUE(analysis->StrictlyDominates(110, 107)); + + EXPECT_TRUE(analysis->StrictlyDominates(77, 74)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 86)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(77, 88)); + + EXPECT_TRUE(analysis->StrictlyDominates(74, 86)); + EXPECT_TRUE(analysis->StrictlyDominates(74, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(74, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(74, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(74, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(74, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(74, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(74, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(74, 88)); + + EXPECT_TRUE(analysis->StrictlyDominates(86, 90)); + EXPECT_TRUE(analysis->StrictlyDominates(86, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(86, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(86,
98)); + EXPECT_TRUE(analysis->StrictlyDominates(86, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(86, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(86, 96)); + EXPECT_TRUE(analysis->StrictlyDominates(86, 88)); + + EXPECT_TRUE(analysis->StrictlyDominates(90, 87)); + EXPECT_TRUE(analysis->StrictlyDominates(90, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(90, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(90, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(90, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(90, 96)); + + EXPECT_TRUE(analysis->StrictlyDominates(87, 94)); + EXPECT_TRUE(analysis->StrictlyDominates(87, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(87, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(87, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(87, 96)); + + EXPECT_TRUE(analysis->StrictlyDominates(94, 98)); + EXPECT_TRUE(analysis->StrictlyDominates(94, 95)); + EXPECT_TRUE(analysis->StrictlyDominates(94, 97)); + EXPECT_TRUE(analysis->StrictlyDominates(94, 96)); + + EXPECT_TRUE(analysis->StrictlyDominates(98, 95)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/pch_test_opt_dom.cpp b/test/opt/dominator_tree/pch_test_opt_dom.cpp new file mode 100644 index 000000000..a28310e57 --- /dev/null +++ b/test/opt/dominator_tree/pch_test_opt_dom.cpp @@ -0,0 +1,15 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "pch_test_opt_dom.h" diff --git a/test/opt/dominator_tree/pch_test_opt_dom.h b/test/opt/dominator_tree/pch_test_opt_dom.h new file mode 100644 index 000000000..4e8106fbf --- /dev/null +++ b/test/opt/dominator_tree/pch_test_opt_dom.h @@ -0,0 +1,25 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" diff --git a/test/opt/dominator_tree/post.cpp b/test/opt/dominator_tree/post.cpp new file mode 100644 index 000000000..bb10fdef1 --- /dev/null +++ b/test/opt/dominator_tree/post.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Generated from the following GLSL +#version 440 core +layout(location = 0) out vec4 c; +layout(location = 1)in vec4 in_val; +void main(){ + if ( in_val.x < 10) { + int z = 0; + int i = 0; + for (i = 0; i < in_val.y; ++i) { + z += i; + } + c = vec4(i,i,i,i); + } else { + c = vec4(1,1,1,1); + } +} +*/ +TEST_F(PassClassTest, BasicVisitFromEntryPoint) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %9 %43 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %9 "in_val" + OpName %22 "z" + OpName %24 "i" + OpName %43 "c" + OpDecorate %9 Location 1 + OpDecorate %43 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypePointer Input %7 + %9 = OpVariable %8 Input + %10 = OpTypeInt 32 0 + %11 = OpConstant %10 0 + %12 = OpTypePointer Input %6 + %15 = OpConstant %6 10 + %16 = OpTypeBool + %20 = OpTypeInt 32 1 + %21 = OpTypePointer Function %20 + %23 = OpConstant %20 0 + %32 = OpConstant %10 1 + %40 = OpConstant %20 1 + %42 = OpTypePointer Output %7 + %43 = OpVariable %42 Output + %54 = OpConstant %6 1 + %55 = OpConstantComposite %7 %54 %54 %54 %54 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %22 = OpVariable %21 Function + %24 = OpVariable %21 Function + %13 = OpAccessChain %12 %9 %11 + %14 = OpLoad %6 %13 + %17 = OpFOrdLessThan %16 %14 %15 + 
OpSelectionMerge %19 None + OpBranchConditional %17 %18 %53 + %18 = OpLabel + OpStore %22 %23 + OpStore %24 %23 + OpStore %24 %23 + OpBranch %25 + %25 = OpLabel + OpLoopMerge %27 %28 None + OpBranch %29 + %29 = OpLabel + %30 = OpLoad %20 %24 + %31 = OpConvertSToF %6 %30 + %33 = OpAccessChain %12 %9 %32 + %34 = OpLoad %6 %33 + %35 = OpFOrdLessThan %16 %31 %34 + OpBranchConditional %35 %26 %27 + %26 = OpLabel + %36 = OpLoad %20 %24 + %37 = OpLoad %20 %22 + %38 = OpIAdd %20 %37 %36 + OpStore %22 %38 + OpBranch %28 + %28 = OpLabel + %39 = OpLoad %20 %24 + %41 = OpIAdd %20 %39 %40 + OpStore %24 %41 + OpBranch %25 + %27 = OpLabel + %44 = OpLoad %20 %24 + %45 = OpConvertSToF %6 %44 + %46 = OpLoad %20 %24 + %47 = OpConvertSToF %6 %46 + %48 = OpLoad %20 %24 + %49 = OpConvertSToF %6 %48 + %50 = OpLoad %20 %24 + %51 = OpConvertSToF %6 %50 + %52 = OpCompositeConstruct %7 %45 %47 %49 %51 + OpStore %43 %52 + OpBranch %19 + %53 = OpLabel + OpStore %43 %55 + OpBranch %19 + %19 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + const Function* f = spvtest::GetFunction(module, 4); + CFG cfg(module); + PostDominatorAnalysis* analysis = context->GetPostDominatorAnalysis(f); + + EXPECT_TRUE(analysis->Dominates(19, 18)); + EXPECT_TRUE(analysis->Dominates(19, 5)); + EXPECT_TRUE(analysis->Dominates(19, 53)); + EXPECT_TRUE(analysis->Dominates(19, 19)); + EXPECT_TRUE(analysis->Dominates(19, 25)); + EXPECT_TRUE(analysis->Dominates(19, 29)); + EXPECT_TRUE(analysis->Dominates(19, 27)); + EXPECT_TRUE(analysis->Dominates(19, 26)); + EXPECT_TRUE(analysis->Dominates(19, 28)); + + EXPECT_TRUE(analysis->Dominates(27, 18)); + EXPECT_TRUE(analysis->Dominates(27, 25)); + EXPECT_TRUE(analysis->Dominates(27, 29)); + 
EXPECT_TRUE(analysis->Dominates(27, 27)); + EXPECT_TRUE(analysis->Dominates(27, 26)); + EXPECT_TRUE(analysis->Dominates(27, 28)); + + EXPECT_FALSE(analysis->Dominates(27, 19)); + EXPECT_FALSE(analysis->Dominates(27, 5)); + EXPECT_FALSE(analysis->Dominates(27, 53)); + + EXPECT_FALSE(analysis->StrictlyDominates(19, 19)); + + EXPECT_TRUE(analysis->StrictlyDominates(19, 18)); + EXPECT_TRUE(analysis->StrictlyDominates(19, 5)); + EXPECT_TRUE(analysis->StrictlyDominates(19, 53)); + EXPECT_TRUE(analysis->StrictlyDominates(19, 25)); + EXPECT_TRUE(analysis->StrictlyDominates(19, 29)); + EXPECT_TRUE(analysis->StrictlyDominates(19, 27)); + EXPECT_TRUE(analysis->StrictlyDominates(19, 26)); + EXPECT_TRUE(analysis->StrictlyDominates(19, 28)); + + // These would be expected true for a normal, non post, dominator tree + EXPECT_FALSE(analysis->Dominates(5, 18)); + EXPECT_FALSE(analysis->Dominates(5, 53)); + EXPECT_FALSE(analysis->Dominates(5, 19)); + EXPECT_FALSE(analysis->Dominates(5, 25)); + EXPECT_FALSE(analysis->Dominates(5, 29)); + EXPECT_FALSE(analysis->Dominates(5, 27)); + EXPECT_FALSE(analysis->Dominates(5, 26)); + EXPECT_FALSE(analysis->Dominates(5, 28)); + + EXPECT_FALSE(analysis->StrictlyDominates(5, 18)); + EXPECT_FALSE(analysis->StrictlyDominates(5, 53)); + EXPECT_FALSE(analysis->StrictlyDominates(5, 19)); + EXPECT_FALSE(analysis->StrictlyDominates(5, 25)); + EXPECT_FALSE(analysis->StrictlyDominates(5, 29)); + EXPECT_FALSE(analysis->StrictlyDominates(5, 27)); + EXPECT_FALSE(analysis->StrictlyDominates(5, 26)); + EXPECT_FALSE(analysis->StrictlyDominates(5, 28)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/simple.cpp b/test/opt/dominator_tree/simple.cpp new file mode 100644 index 000000000..d11854d55 --- /dev/null +++ b/test/opt/dominator_tree/simple.cpp @@ -0,0 +1,177 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL +#version 440 core +layout(location = 0) out vec4 c; +layout(location = 1)in vec4 in_val; +void main(){ + if ( in_val.x < 10) { + int z = 0; + int i = 0; + for (i = 0; i < in_val.y; ++i) { + z += i; + } + c = vec4(i,i,i,i); + } else { + c = vec4(1,1,1,1); + } +} +*/ +TEST_F(PassClassTest, BasicVisitFromEntryPoint) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %9 %43 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %9 "in_val" + OpName %22 "z" + OpName %24 "i" + OpName %43 "c" + OpDecorate %9 Location 1 + OpDecorate %43 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypePointer Input %7 + %9 = OpVariable %8 Input + %10 = OpTypeInt 32 0 + %11 = OpConstant %10 0 + %12 = OpTypePointer Input %6 + %15 = OpConstant %6 10 + %16 = OpTypeBool + %20 = 
OpTypeInt 32 1 + %21 = OpTypePointer Function %20 + %23 = OpConstant %20 0 + %32 = OpConstant %10 1 + %40 = OpConstant %20 1 + %42 = OpTypePointer Output %7 + %43 = OpVariable %42 Output + %54 = OpConstant %6 1 + %55 = OpConstantComposite %7 %54 %54 %54 %54 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %22 = OpVariable %21 Function + %24 = OpVariable %21 Function + %13 = OpAccessChain %12 %9 %11 + %14 = OpLoad %6 %13 + %17 = OpFOrdLessThan %16 %14 %15 + OpSelectionMerge %19 None + OpBranchConditional %17 %18 %53 + %18 = OpLabel + OpStore %22 %23 + OpStore %24 %23 + OpStore %24 %23 + OpBranch %25 + %25 = OpLabel + OpLoopMerge %27 %28 None + OpBranch %29 + %29 = OpLabel + %30 = OpLoad %20 %24 + %31 = OpConvertSToF %6 %30 + %33 = OpAccessChain %12 %9 %32 + %34 = OpLoad %6 %33 + %35 = OpFOrdLessThan %16 %31 %34 + OpBranchConditional %35 %26 %27 + %26 = OpLabel + %36 = OpLoad %20 %24 + %37 = OpLoad %20 %22 + %38 = OpIAdd %20 %37 %36 + OpStore %22 %38 + OpBranch %28 + %28 = OpLabel + %39 = OpLoad %20 %24 + %41 = OpIAdd %20 %39 %40 + OpStore %24 %41 + OpBranch %25 + %27 = OpLabel + %44 = OpLoad %20 %24 + %45 = OpConvertSToF %6 %44 + %46 = OpLoad %20 %24 + %47 = OpConvertSToF %6 %46 + %48 = OpLoad %20 %24 + %49 = OpConvertSToF %6 %48 + %50 = OpLoad %20 %24 + %51 = OpConvertSToF %6 %50 + %52 = OpCompositeConstruct %7 %45 %47 %49 %51 + OpStore %43 %52 + OpBranch %19 + %53 = OpLabel + OpStore %43 %55 + OpBranch %19 + %19 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); + const CFG& cfg = *context->cfg(); + + DominatorTree& tree = analysis->GetDomTree(); + + EXPECT_EQ(tree.GetRoot()->bb_, 
cfg.pseudo_entry_block()); + EXPECT_TRUE(analysis->Dominates(5, 18)); + EXPECT_TRUE(analysis->Dominates(5, 53)); + EXPECT_TRUE(analysis->Dominates(5, 19)); + EXPECT_TRUE(analysis->Dominates(5, 25)); + EXPECT_TRUE(analysis->Dominates(5, 29)); + EXPECT_TRUE(analysis->Dominates(5, 27)); + EXPECT_TRUE(analysis->Dominates(5, 26)); + EXPECT_TRUE(analysis->Dominates(5, 28)); + + EXPECT_TRUE(analysis->StrictlyDominates(5, 18)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 53)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 19)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 25)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 29)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 27)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 26)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 28)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/switch_case_fallthrough.cpp b/test/opt/dominator_tree/switch_case_fallthrough.cpp new file mode 100644 index 000000000..d9dd7d161 --- /dev/null +++ b/test/opt/dominator_tree/switch_case_fallthrough.cpp @@ -0,0 +1,163 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Generated from the following GLSL +#version 440 core +layout(location = 0) out vec4 v; +layout(location = 1) in vec4 in_val; +void main() { + int i; + switch (int(in_val.x)) { + case 0: + i = 0; + case 1: + i = 1; + break; + case 2: + i = 2; + case 3: + i = 3; + case 4: + i = 4; + break; + default: + i = 0; + } + v = vec4(i, i, i, i); +} +*/ +TEST_F(PassClassTest, UnreachableNestedIfs) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %9 %35 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %9 "in_val" + OpName %25 "i" + OpName %35 "v" + OpDecorate %9 Location 1 + OpDecorate %35 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypePointer Input %7 + %9 = OpVariable %8 Input + %10 = OpTypeInt 32 0 + %11 = OpConstant %10 0 + %12 = OpTypePointer Input %6 + %15 = OpTypeInt 32 1 + %24 = OpTypePointer Function %15 + %26 = OpConstant %15 0 + %27 = OpConstant %15 1 + %29 = OpConstant %15 2 + %30 = OpConstant %15 3 + %31 = OpConstant %15 4 + %34 = OpTypePointer Output %7 + %35 = OpVariable %34 Output + %4 = OpFunction %2 None %3 + %5 = OpLabel + %25 = OpVariable %24 Function + %13 = OpAccessChain %12 %9 %11 + %14 = OpLoad %6 %13 + %16 = OpConvertFToS %15 %14 + OpSelectionMerge %23 None + OpSwitch %16 %22 0 %17 1 %18 2 %19 3 %20 4 %21 + %22 = OpLabel + OpStore %25 %26 + OpBranch %23 + %17 = OpLabel + OpStore %25 %26 + OpBranch %18 + %18 = OpLabel 
+ OpStore %25 %27 + OpBranch %23 + %19 = OpLabel + OpStore %25 %29 + OpBranch %20 + %20 = OpLabel + OpStore %25 %30 + OpBranch %21 + %21 = OpLabel + OpStore %25 %31 + OpBranch %23 + %23 = OpLabel + %36 = OpLoad %15 %25 + %37 = OpConvertSToF %6 %36 + %38 = OpLoad %15 %25 + %39 = OpConvertSToF %6 %38 + %40 = OpLoad %15 %25 + %41 = OpConvertSToF %6 %40 + %42 = OpLoad %15 %25 + %43 = OpConvertSToF %6 %42 + %44 = OpCompositeConstruct %7 %37 %39 %41 %43 + OpStore %35 %44 + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + const Function* f = spvtest::GetFunction(module, 4); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); + + EXPECT_TRUE(analysis->Dominates(5, 5)); + EXPECT_TRUE(analysis->Dominates(5, 17)); + EXPECT_TRUE(analysis->Dominates(5, 18)); + EXPECT_TRUE(analysis->Dominates(5, 19)); + EXPECT_TRUE(analysis->Dominates(5, 20)); + EXPECT_TRUE(analysis->Dominates(5, 21)); + EXPECT_TRUE(analysis->Dominates(5, 22)); + EXPECT_TRUE(analysis->Dominates(5, 23)); + + EXPECT_TRUE(analysis->StrictlyDominates(5, 17)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 18)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 19)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 20)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 21)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 22)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 23)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/unreachable_for.cpp b/test/opt/dominator_tree/unreachable_for.cpp new file mode 100644 index 000000000..469e5c142 --- /dev/null +++ b/test/opt/dominator_tree/unreachable_for.cpp @@ -0,0 +1,121 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Generated from the following GLSL +#version 440 core +void main() { + for (int i = 0; i < 1; i++) { + break; + } +} +*/ +TEST_F(PassClassTest, UnreachableNestedIfs) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 1 + %17 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %12 + %13 = OpLabel + %20 = OpLoad %6 %8 + %21 = OpIAdd %6 %20 %16 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + 
OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + const Function* f = spvtest::GetFunction(module, 4); + DominatorAnalysis* analysis = context->GetDominatorAnalysis(f); + EXPECT_TRUE(analysis->Dominates(5, 5)); + EXPECT_TRUE(analysis->Dominates(5, 10)); + EXPECT_TRUE(analysis->Dominates(5, 14)); + EXPECT_TRUE(analysis->Dominates(5, 11)); + EXPECT_TRUE(analysis->Dominates(5, 12)); + EXPECT_TRUE(analysis->Dominates(10, 10)); + EXPECT_TRUE(analysis->Dominates(10, 14)); + EXPECT_TRUE(analysis->Dominates(10, 11)); + EXPECT_TRUE(analysis->Dominates(10, 12)); + EXPECT_TRUE(analysis->Dominates(14, 14)); + EXPECT_TRUE(analysis->Dominates(14, 11)); + EXPECT_TRUE(analysis->Dominates(14, 12)); + EXPECT_TRUE(analysis->Dominates(11, 11)); + EXPECT_TRUE(analysis->Dominates(12, 12)); + + EXPECT_TRUE(analysis->StrictlyDominates(5, 10)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 14)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 11)); + EXPECT_TRUE(analysis->StrictlyDominates(5, 12)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 14)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 11)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 12)); + EXPECT_TRUE(analysis->StrictlyDominates(14, 11)); + EXPECT_TRUE(analysis->StrictlyDominates(14, 12)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/dominator_tree/unreachable_for_post.cpp b/test/opt/dominator_tree/unreachable_for_post.cpp new file mode 100644 index 000000000..8d3e37b4a --- /dev/null +++ b/test/opt/dominator_tree/unreachable_for_post.cpp @@ -0,0 +1,118 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/dominator_analysis.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Generated from the following GLSL +#version 440 core +void main() { + for (int i = 0; i < 1; i++) { + break; + } +} +*/ +TEST_F(PassClassTest, UnreachableNestedIfs) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 1 + %17 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %12 + %13 = OpLabel + %20 = OpLoad %6 %8 + %21 = OpIAdd %6 %20 %16 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + 
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + const Function* f = spvtest::GetFunction(module, 4); + + PostDominatorAnalysis* analysis = context->GetPostDominatorAnalysis(f); + + EXPECT_TRUE(analysis->Dominates(12, 12)); + EXPECT_TRUE(analysis->Dominates(12, 14)); + EXPECT_TRUE(analysis->Dominates(12, 11)); + EXPECT_TRUE(analysis->Dominates(12, 10)); + EXPECT_TRUE(analysis->Dominates(12, 5)); + EXPECT_TRUE(analysis->Dominates(14, 14)); + EXPECT_TRUE(analysis->Dominates(14, 10)); + EXPECT_TRUE(analysis->Dominates(14, 5)); + EXPECT_TRUE(analysis->Dominates(10, 10)); + EXPECT_TRUE(analysis->Dominates(10, 5)); + EXPECT_TRUE(analysis->Dominates(5, 5)); + + EXPECT_TRUE(analysis->StrictlyDominates(12, 14)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 11)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 10)); + EXPECT_TRUE(analysis->StrictlyDominates(12, 5)); + EXPECT_TRUE(analysis->StrictlyDominates(14, 10)); + EXPECT_TRUE(analysis->StrictlyDominates(14, 5)); + EXPECT_TRUE(analysis->StrictlyDominates(10, 5)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/eliminate_dead_const_test.cpp b/test/opt/eliminate_dead_const_test.cpp new file mode 100644 index 000000000..7fac866ce --- /dev/null +++ b/test/opt/eliminate_dead_const_test.cpp @@ -0,0 +1,847 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using EliminateDeadConstantBasicTest = PassTest<::testing::Test>; + +TEST_F(EliminateDeadConstantBasicTest, BasicAllDeadConstants) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %main \"main\"", + "%void = OpTypeVoid", + "%4 = OpTypeFunction %void", + "%bool = OpTypeBool", + "%true = OpConstantTrue %bool", + "%false = OpConstantFalse %bool", + "%int = OpTypeInt 32 1", + "%9 = OpConstant %int 1", + "%uint = OpTypeInt 32 0", + "%11 = OpConstant %uint 2", + "%float = OpTypeFloat 32", + "%13 = OpConstant %float 3.1415", + "%double = OpTypeFloat 64", + "%15 = OpConstant %double 3.14159265358979", + "%main = OpFunction %void None %4", + "%16 = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + // None of the above constants is ever used, so all of them should be + // eliminated. + const char* const_decl_opcodes[] = { + " OpConstantTrue ", + " OpConstantFalse ", + " OpConstant ", + }; + // Skip lines that have any one of const_decl_opcodes. 
+ const std::string expected_disassembly = + SelectiveJoin(text, [&const_decl_opcodes](const char* line) { + return std::any_of( + std::begin(const_decl_opcodes), std::end(const_decl_opcodes), + [&line](const char* const_decl_op) { + return std::string(line).find(const_decl_op) != std::string::npos; + }); + }); + + SinglePassRunAndCheck( + JoinAllInsts(text), expected_disassembly, /* skip_nop = */ true); +} + +TEST_F(EliminateDeadConstantBasicTest, BasicNoneDeadConstants) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %main \"main\"", + "OpName %btv \"btv\"", + "OpName %bfv \"bfv\"", + "OpName %iv \"iv\"", + "OpName %uv \"uv\"", + "OpName %fv \"fv\"", + "OpName %dv \"dv\"", + "%void = OpTypeVoid", + "%10 = OpTypeFunction %void", + "%bool = OpTypeBool", + "%_ptr_Function_bool = OpTypePointer Function %bool", + "%true = OpConstantTrue %bool", + "%false = OpConstantFalse %bool", + "%int = OpTypeInt 32 1", + "%_ptr_Function_int = OpTypePointer Function %int", + "%int_1 = OpConstant %int 1", + "%uint = OpTypeInt 32 0", + "%_ptr_Function_uint = OpTypePointer Function %uint", + "%uint_2 = OpConstant %uint 2", + "%float = OpTypeFloat 32", + "%_ptr_Function_float = OpTypePointer Function %float", + "%float_3_1415 = OpConstant %float 3.1415", + "%double = OpTypeFloat 64", + "%_ptr_Function_double = OpTypePointer Function %double", + "%double_3_14159265358979 = OpConstant %double 3.14159265358979", + "%main = OpFunction %void None %10", + "%27 = OpLabel", + "%btv = OpVariable %_ptr_Function_bool Function", + "%bfv = OpVariable %_ptr_Function_bool Function", + "%iv = OpVariable %_ptr_Function_int Function", + "%uv = OpVariable %_ptr_Function_uint Function", + "%fv = OpVariable %_ptr_Function_float Function", + "%dv = OpVariable %_ptr_Function_double Function", + "OpStore %btv %true", + "OpStore 
%bfv %false", + "OpStore %iv %int_1", + "OpStore %uv %uint_2", + "OpStore %fv %float_3_1415", + "OpStore %dv %double_3_14159265358979", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + // All constants are used, so none of them should be eliminated. + SinglePassRunAndCheck( + JoinAllInsts(text), JoinAllInsts(text), /* skip_nop = */ true); +} + +struct EliminateDeadConstantTestCase { + // Type declarations and constants that should be kept. + std::vector used_consts; + // Instructions that refer to constants, this is added to create uses for + // some constants so they won't be treated as dead constants. + std::vector main_insts; + // Dead constants that should be removed. + std::vector dead_consts; +}; + +// All types that are potentially required in EliminateDeadConstantTest. +const std::vector CommonTypes = { + // clang-format off + // scalar types + "%bool = OpTypeBool", + "%uint = OpTypeInt 32 0", + "%int = OpTypeInt 32 1", + "%float = OpTypeFloat 32", + "%double = OpTypeFloat 64", + // vector types + "%v2bool = OpTypeVector %bool 2", + "%v2uint = OpTypeVector %uint 2", + "%v2int = OpTypeVector %int 2", + "%v3int = OpTypeVector %int 3", + "%v4int = OpTypeVector %int 4", + "%v2float = OpTypeVector %float 2", + "%v3float = OpTypeVector %float 3", + "%v2double = OpTypeVector %double 2", + // variable pointer types + "%_pf_bool = OpTypePointer Function %bool", + "%_pf_uint = OpTypePointer Function %uint", + "%_pf_int = OpTypePointer Function %int", + "%_pf_float = OpTypePointer Function %float", + "%_pf_double = OpTypePointer Function %double", + "%_pf_v2int = OpTypePointer Function %v2int", + "%_pf_v3int = OpTypePointer Function %v3int", + "%_pf_v2float = OpTypePointer Function %v2float", + "%_pf_v3float = OpTypePointer Function %v3float", + "%_pf_v2double = OpTypePointer Function %v2double", + // struct types + "%inner_struct = OpTypeStruct %bool %int %float %double", + "%outer_struct = OpTypeStruct %inner_struct %int %double", + "%flat_struct = 
OpTypeStruct %bool %int %float %double", + // clang-format on +}; + +using EliminateDeadConstantTest = + PassTest<::testing::TestWithParam>; + +TEST_P(EliminateDeadConstantTest, Custom) { + auto& tc = GetParam(); + AssemblyBuilder builder; + builder.AppendTypesConstantsGlobals(CommonTypes) + .AppendTypesConstantsGlobals(tc.used_consts) + .AppendInMain(tc.main_insts); + const std::string expected = builder.GetCode(); + builder.AppendTypesConstantsGlobals(tc.dead_consts); + const std::string assembly_with_dead_const = builder.GetCode(); + SinglePassRunAndCheck( + assembly_with_dead_const, expected, /* skip_nop = */ true); +} + +INSTANTIATE_TEST_CASE_P( + ScalarTypeConstants, EliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Scalar type constants, one dead constant and one used constant. + { + /* .used_consts = */ + { + "%used_const_int = OpConstant %int 1", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %used_const_int", + }, + /* .dead_consts = */ + { + "%dead_const_int = OpConstant %int 1", + }, + }, + { + /* .used_consts = */ + { + "%used_const_uint = OpConstant %uint 1", + }, + /* .main_insts = */ + { + "%uint_var = OpVariable %_pf_uint Function", + "OpStore %uint_var %used_const_uint", + }, + /* .dead_consts = */ + { + "%dead_const_uint = OpConstant %uint 1", + }, + }, + { + /* .used_consts = */ + { + "%used_const_float = OpConstant %float 3.1415", + }, + /* .main_insts = */ + { + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %used_const_float", + }, + /* .dead_consts = */ + { + "%dead_const_float = OpConstant %float 3.1415", + }, + }, + { + /* .used_consts = */ + { + "%used_const_double = OpConstant %double 3.141592653", + }, + /* .main_insts = */ + { + "%double_var = OpVariable %_pf_double Function", + "OpStore %double_var %used_const_double", + }, + /* .dead_consts = */ + { + "%dead_const_double = OpConstant %double 3.141592653", + }, + }, + // 
clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + VectorTypeConstants, EliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Tests eliminating dead constant type ivec2. One dead constant vector + // and one used constant vector, each built from its own group of + // scalar constants. + { + /* .used_consts = */ + { + "%used_int_x = OpConstant %int 1", + "%used_int_y = OpConstant %int 2", + "%used_v2int = OpConstantComposite %v2int %used_int_x %used_int_y", + }, + /* .main_insts = */ + { + "%v2int_var = OpVariable %_pf_v2int Function", + "OpStore %v2int_var %used_v2int", + }, + /* .dead_consts = */ + { + "%dead_int_x = OpConstant %int 1", + "%dead_int_y = OpConstant %int 2", + "%dead_v2int = OpConstantComposite %v2int %dead_int_x %dead_int_y", + }, + }, + // Tests eliminating dead constant ivec3. One dead constant vector and + // one used constant vector. Both are built from the same group of + // scalar constants. + { + /* .used_consts = */ + { + "%used_int_x = OpConstant %int 1", + "%used_int_y = OpConstant %int 2", + "%used_int_z = OpConstant %int 3", + "%used_v3int = OpConstantComposite %v3int %used_int_x %used_int_y %used_int_z", + }, + /* .main_insts = */ + { + "%v3int_var = OpVariable %_pf_v3int Function", + "OpStore %v3int_var %used_v3int", + }, + /* .dead_consts = */ + { + "%dead_v3int = OpConstantComposite %v3int %used_int_x %used_int_y %used_int_z", + }, + }, + // Tests eliminating dead constant vec2. One dead constant vector and + // one used constant vector. Each built from its own group of scalar + // constants. 
+ { + /* .used_consts = */ + { + "%used_float_x = OpConstant %float 3.1415", + "%used_float_y = OpConstant %float 4.25", + "%used_v2float = OpConstantComposite %v2float %used_float_x %used_float_y", + }, + /* .main_insts = */ + { + "%v2float_var = OpVariable %_pf_v2float Function", + "OpStore %v2float_var %used_v2float", + }, + /* .dead_consts = */ + { + "%dead_float_x = OpConstant %float 3.1415", + "%dead_float_y = OpConstant %float 4.25", + "%dead_v2float = OpConstantComposite %v2float %dead_float_x %dead_float_y", + }, + }, + // Tests eliminating dead constant vec3. One dead constant vector and + // one used constant vector. Both are built from the same group of scalar + // constants. + { + /* .used_consts = */ + { + "%used_float_x = OpConstant %float 3.1415", + "%used_float_y = OpConstant %float 4.25", + "%used_float_z = OpConstant %float 4.75", + "%used_v3float = OpConstantComposite %v3float %used_float_x %used_float_y %used_float_z", + }, + /* .main_insts = */ + { + "%v3float_var = OpVariable %_pf_v3float Function", + "OpStore %v3float_var %used_v3float", + }, + /* .dead_consts = */ + { + "%dead_v3float = OpConstantComposite %v3float %used_float_x %used_float_y %used_float_z", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + StructTypeConstants, EliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // A plain struct type dead constants. All of its components are dead + // constants too. + { + /* .used_consts = */ {}, + /* .main_insts = */ {}, + /* .dead_consts = */ + { + "%dead_bool = OpConstantTrue %bool", + "%dead_int = OpConstant %int 1", + "%dead_float = OpConstant %float 2.5", + "%dead_double = OpConstant %double 3.14159265358979", + "%dead_struct = OpConstantComposite %flat_struct %dead_bool %dead_int %dead_float %dead_double", + }, + }, + // A plain struct type dead constants. Some of its components are dead + // constants while others are not. 
+ { + /* .used_consts = */ + { + "%used_int = OpConstant %int 1", + "%used_double = OpConstant %double 3.14159265358979", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %used_int", + "%double_var = OpVariable %_pf_double Function", + "OpStore %double_var %used_double", + }, + /* .dead_consts = */ + { + "%dead_bool = OpConstantTrue %bool", + "%dead_float = OpConstant %float 2.5", + "%dead_struct = OpConstantComposite %flat_struct %dead_bool %used_int %dead_float %used_double", + }, + }, + // A nesting struct type dead constants. All components of both outer + // and inner structs are dead and should be removed after dead constant + // elimination. + { + /* .used_consts = */ {}, + /* .main_insts = */ {}, + /* .dead_consts = */ + { + "%dead_bool = OpConstantTrue %bool", + "%dead_int = OpConstant %int 1", + "%dead_float = OpConstant %float 2.5", + "%dead_double = OpConstant %double 3.1415926535", + "%dead_inner_struct = OpConstantComposite %inner_struct %dead_bool %dead_int %dead_float %dead_double", + "%dead_int2 = OpConstant %int 2", + "%dead_double2 = OpConstant %double 1.428571428514", + "%dead_outer_struct = OpConstantComposite %outer_struct %dead_inner_struct %dead_int2 %dead_double2", + }, + }, + // A nesting struct type dead constants. Some of its components are + // dead constants while others are not. 
+ { + /* .used_consts = */ + { + "%used_int = OpConstant %int 1", + "%used_double = OpConstant %double 3.14159265358979", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %used_int", + "%double_var = OpVariable %_pf_double Function", + "OpStore %double_var %used_double", + }, + /* .dead_consts = */ + { + "%dead_bool = OpConstantTrue %bool", + "%dead_float = OpConstant %float 2.5", + "%dead_inner_struct = OpConstantComposite %inner_struct %dead_bool %used_int %dead_float %used_double", + "%dead_int = OpConstant %int 2", + "%dead_outer_struct = OpConstantComposite %outer_struct %dead_inner_struct %dead_int %used_double", + }, + }, + // A nesting struct case. The inner struct is used while the outer struct is not + { + /* .used_const = */ + { + "%used_bool = OpConstantTrue %bool", + "%used_int = OpConstant %int 1", + "%used_float = OpConstant %float 1.25", + "%used_double = OpConstant %double 1.23456789012345", + "%used_inner_struct = OpConstantComposite %inner_struct %used_bool %used_int %used_float %used_double", + }, + /* .main_insts = */ + { + "%bool_var = OpVariable %_pf_bool Function", + "%bool_from_inner_struct = OpCompositeExtract %bool %used_inner_struct 0", + "OpStore %bool_var %bool_from_inner_struct", + }, + /* .dead_consts = */ + { + "%dead_int = OpConstant %int 2", + "%dead_outer_struct = OpConstantComposite %outer_struct %used_inner_struct %dead_int %used_double" + }, + }, + // A nesting struct case. The outer struct is used, so the inner struct should not + // be removed even though it is not used anywhere. 
+ { + /* .used_const = */ + { + "%used_bool = OpConstantTrue %bool", + "%used_int = OpConstant %int 1", + "%used_float = OpConstant %float 1.25", + "%used_double = OpConstant %double 1.23456789012345", + "%used_inner_struct = OpConstantComposite %inner_struct %used_bool %used_int %used_float %used_double", + "%used_outer_struct = OpConstantComposite %outer_struct %used_inner_struct %used_int %used_double" + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Function", + "%int_from_outer_struct = OpCompositeExtract %int %used_outer_struct 1", + "OpStore %int_var %int_from_outer_struct", + }, + /* .dead_consts = */ {}, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + ScalarTypeSpecConstants, EliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // All scalar type spec constants. + { + /* .used_consts = */ + { + "%used_bool = OpSpecConstantTrue %bool", + "%used_uint = OpSpecConstant %uint 2", + "%used_int = OpSpecConstant %int 2", + "%used_float = OpSpecConstant %float 2.5", + "%used_double = OpSpecConstant %double 1.42857142851", + }, + /* .main_insts = */ + { + "%bool_var = OpVariable %_pf_bool Function", + "%uint_var = OpVariable %_pf_uint Function", + "%int_var = OpVariable %_pf_int Function", + "%float_var = OpVariable %_pf_float Function", + "%double_var = OpVariable %_pf_double Function", + "OpStore %bool_var %used_bool", "OpStore %uint_var %used_uint", + "OpStore %int_var %used_int", "OpStore %float_var %used_float", + "OpStore %double_var %used_double", + }, + /* .dead_consts = */ + { + "%dead_bool = OpSpecConstantTrue %bool", + "%dead_uint = OpSpecConstant %uint 2", + "%dead_int = OpSpecConstant %int 2", + "%dead_float = OpSpecConstant %float 2.5", + "%dead_double = OpSpecConstant %double 1.42857142851", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + VectorTypeSpecConstants, EliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Bool vector 
type spec constants. One vector has all components dead, + // another vector has one dead boolean and one used boolean. + { + /* .used_consts = */ + { + "%used_bool = OpSpecConstantTrue %bool", + }, + /* .main_insts = */ + { + "%bool_var = OpVariable %_pf_bool Function", + "OpStore %bool_var %used_bool", + }, + /* .dead_consts = */ + { + "%dead_bool = OpSpecConstantFalse %bool", + "%dead_bool_vec1 = OpSpecConstantComposite %v2bool %dead_bool %dead_bool", + "%dead_bool_vec2 = OpSpecConstantComposite %v2bool %dead_bool %used_bool", + }, + }, + + // Uint vector type spec constants. One vector has all components dead, + // another vector has one dead unsigned integer and one used unsigned + // integer. + { + /* .used_consts = */ + { + "%used_uint = OpSpecConstant %uint 3", + }, + /* .main_insts = */ + { + "%uint_var = OpVariable %_pf_uint Function", + "OpStore %uint_var %used_uint", + }, + /* .dead_consts = */ + { + "%dead_uint = OpSpecConstant %uint 1", + "%dead_uint_vec1 = OpSpecConstantComposite %v2uint %dead_uint %dead_uint", + "%dead_uint_vec2 = OpSpecConstantComposite %v2uint %dead_uint %used_uint", + }, + }, + + // Int vector type spec constants. One vector has all components dead, + // another vector has one dead integer and one used integer. + { + /* .used_consts = */ + { + "%used_int = OpSpecConstant %int 3", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %used_int", + }, + /* .dead_consts = */ + { + "%dead_int = OpSpecConstant %int 1", + "%dead_int_vec1 = OpSpecConstantComposite %v2int %dead_int %dead_int", + "%dead_int_vec2 = OpSpecConstantComposite %v2int %dead_int %used_int", + }, + }, + + // Int vector type spec constants built with both spec constants and + // front-end constants. 
+ { + /* .used_consts = */ + { + "%used_spec_int = OpSpecConstant %int 3", + "%used_front_end_int = OpConstant %int 3", + }, + /* .main_insts = */ + { + "%int_var1 = OpVariable %_pf_int Function", + "OpStore %int_var1 %used_spec_int", + "%int_var2 = OpVariable %_pf_int Function", + "OpStore %int_var2 %used_front_end_int", + }, + /* .dead_consts = */ + { + "%dead_spec_int = OpSpecConstant %int 1", + "%dead_front_end_int = OpConstant %int 1", + // Dead front-end and dead spec constants + "%dead_int_vec1 = OpSpecConstantComposite %v2int %dead_spec_int %dead_front_end_int", + // Used front-end and dead spec constants + "%dead_int_vec2 = OpSpecConstantComposite %v2int %dead_spec_int %used_front_end_int", + // Dead front-end and used spec constants + "%dead_int_vec3 = OpSpecConstantComposite %v2int %dead_front_end_int %used_spec_int", + }, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + SpecConstantOp, EliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Cast operations: uint <-> int <-> bool + { + /* .used_consts = */ {}, + /* .main_insts = */ {}, + /* .dead_consts = */ + { + // Assistant constants, only used in dead spec constant + // operations. + "%signed_zero = OpConstant %int 0", + "%signed_zero_vec = OpConstantComposite %v2int %signed_zero %signed_zero", + "%unsigned_zero = OpConstant %uint 0", + "%unsigned_zero_vec = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", + "%signed_one = OpConstant %int 1", + "%signed_one_vec = OpConstantComposite %v2int %signed_one %signed_one", + "%unsigned_one = OpConstant %uint 1", + "%unsigned_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one", + + // Spec constants that support casting to each other. 
+ "%dead_bool = OpSpecConstantTrue %bool", + "%dead_uint = OpSpecConstant %uint 1", + "%dead_int = OpSpecConstant %int 2", + "%dead_bool_vec = OpSpecConstantComposite %v2bool %dead_bool %dead_bool", + "%dead_uint_vec = OpSpecConstantComposite %v2uint %dead_uint %dead_uint", + "%dead_int_vec = OpSpecConstantComposite %v2int %dead_int %dead_int", + + // Scalar cast to boolean spec constant. + "%int_to_bool = OpSpecConstantOp %bool INotEqual %dead_int %signed_zero", + "%uint_to_bool = OpSpecConstantOp %bool INotEqual %dead_uint %unsigned_zero", + + // Vector cast to boolean spec constant. + "%int_to_bool_vec = OpSpecConstantOp %v2bool INotEqual %dead_int_vec %signed_zero_vec", + "%uint_to_bool_vec = OpSpecConstantOp %v2bool INotEqual %dead_uint_vec %unsigned_zero_vec", + + // Scalar cast to int spec constant. + "%bool_to_int = OpSpecConstantOp %int Select %dead_bool %signed_one %signed_zero", + "%uint_to_int = OpSpecConstantOp %uint IAdd %dead_uint %unsigned_zero", + + // Vector cast to int spec constant. + "%bool_to_int_vec = OpSpecConstantOp %v2int Select %dead_bool_vec %signed_one_vec %signed_zero_vec", + "%uint_to_int_vec = OpSpecConstantOp %v2uint IAdd %dead_uint_vec %unsigned_zero_vec", + + // Scalar cast to uint spec constant. + "%bool_to_uint = OpSpecConstantOp %uint Select %dead_bool %unsigned_one %unsigned_zero", + "%int_to_uint_vec = OpSpecConstantOp %uint IAdd %dead_int %signed_zero", + + // Vector cast to uint spec constant. + "%bool_to_uint_vec = OpSpecConstantOp %v2uint Select %dead_bool_vec %unsigned_one_vec %unsigned_zero_vec", + "%int_to_uint = OpSpecConstantOp %v2uint IAdd %dead_int_vec %signed_zero_vec", + }, + }, + + // Add, sub, mul, div, rem. 
+ { + /* .used_consts = */ {}, + /* .main_insts = */ {}, + /* .dead_consts = */ + { + "%dead_spec_int_a = OpSpecConstant %int 1", + "%dead_spec_int_a_vec = OpSpecConstantComposite %v2int %dead_spec_int_a %dead_spec_int_a", + + "%dead_spec_int_b = OpSpecConstant %int 2", + "%dead_spec_int_b_vec = OpSpecConstantComposite %v2int %dead_spec_int_b %dead_spec_int_b", + + "%dead_const_int_c = OpConstant %int 3", + "%dead_const_int_c_vec = OpConstantComposite %v2int %dead_const_int_c %dead_const_int_c", + + // Add + "%add_a_b = OpSpecConstantOp %int IAdd %dead_spec_int_a %dead_spec_int_b", + "%add_a_b_vec = OpSpecConstantOp %v2int IAdd %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Sub + "%sub_a_b = OpSpecConstantOp %int ISub %dead_spec_int_a %dead_spec_int_b", + "%sub_a_b_vec = OpSpecConstantOp %v2int ISub %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Mul + "%mul_a_b = OpSpecConstantOp %int IMul %dead_spec_int_a %dead_spec_int_b", + "%mul_a_b_vec = OpSpecConstantOp %v2int IMul %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Div + "%div_a_b = OpSpecConstantOp %int SDiv %dead_spec_int_a %dead_spec_int_b", + "%div_a_b_vec = OpSpecConstantOp %v2int SDiv %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Bitwise Xor + "%xor_a_b = OpSpecConstantOp %int BitwiseXor %dead_spec_int_a %dead_spec_int_b", + "%xor_a_b_vec = OpSpecConstantOp %v2int BitwiseXor %dead_spec_int_a_vec %dead_spec_int_b_vec", + + // Scalar Comparison + "%less_a_b = OpSpecConstantOp %bool SLessThan %dead_spec_int_a %dead_spec_int_b", + }, + }, + + // Vectors without used swizzles should be removed. 
+ { + /* .used_consts = */ + { + "%used_int = OpConstant %int 3", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %used_int", + }, + /* .dead_consts = */ + { + "%dead_int = OpConstant %int 3", + + "%dead_spec_int_a = OpSpecConstant %int 1", + "%vec_a = OpSpecConstantComposite %v4int %dead_spec_int_a %dead_spec_int_a %dead_int %dead_int", + + "%dead_spec_int_b = OpSpecConstant %int 2", + "%vec_b = OpSpecConstantComposite %v4int %dead_spec_int_b %dead_spec_int_b %used_int %used_int", + + // Extract scalar + "%a_x = OpSpecConstantOp %int CompositeExtract %vec_a 0", + "%b_x = OpSpecConstantOp %int CompositeExtract %vec_b 0", + + // Extract vector + "%a_xy = OpSpecConstantOp %v2int VectorShuffle %vec_a %vec_a 0 1", + "%b_xy = OpSpecConstantOp %v2int VectorShuffle %vec_b %vec_b 0 1", + }, + }, + // Vectors with used swizzles should not be removed. + { + /* .used_consts = */ + { + "%used_int = OpConstant %int 3", + "%used_spec_int_a = OpSpecConstant %int 1", + "%used_spec_int_b = OpSpecConstant %int 2", + // Create vectors + "%vec_a = OpSpecConstantComposite %v4int %used_spec_int_a %used_spec_int_a %used_int %used_int", + "%vec_b = OpSpecConstantComposite %v4int %used_spec_int_b %used_spec_int_b %used_int %used_int", + // Extract vector + "%a_xy = OpSpecConstantOp %v2int VectorShuffle %vec_a %vec_a 0 1", + "%b_xy = OpSpecConstantOp %v2int VectorShuffle %vec_b %vec_b 0 1", + }, + /* .main_insts = */ + { + "%v2int_var_a = OpVariable %_pf_v2int Function", + "%v2int_var_b = OpVariable %_pf_v2int Function", + "OpStore %v2int_var_a %a_xy", + "OpStore %v2int_var_b %b_xy", + }, + /* .dead_consts = */ {}, + }, + // clang-format on + }))); + +INSTANTIATE_TEST_CASE_P( + LongDefUseChain, EliminateDeadConstantTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // Long Def-Use chain with binary operations. 
+ { + /* .used_consts = */ + { + "%array_size = OpConstant %int 4", + "%type_arr_int_4 = OpTypeArray %int %array_size", + "%used_int_0 = OpConstant %int 100", + "%used_int_1 = OpConstant %int 1", + "%used_int_2 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_1", + "%used_int_3 = OpSpecConstantOp %int ISub %used_int_0 %used_int_2", + "%used_int_4 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_3", + "%used_int_5 = OpSpecConstantOp %int ISub %used_int_0 %used_int_4", + "%used_int_6 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_5", + "%used_int_7 = OpSpecConstantOp %int ISub %used_int_0 %used_int_6", + "%used_int_8 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_7", + "%used_int_9 = OpSpecConstantOp %int ISub %used_int_0 %used_int_8", + "%used_int_10 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_9", + "%used_int_11 = OpSpecConstantOp %int ISub %used_int_0 %used_int_10", + "%used_int_12 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_11", + "%used_int_13 = OpSpecConstantOp %int ISub %used_int_0 %used_int_12", + "%used_int_14 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_13", + "%used_int_15 = OpSpecConstantOp %int ISub %used_int_0 %used_int_14", + "%used_int_16 = OpSpecConstantOp %int ISub %used_int_0 %used_int_15", + "%used_int_17 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_16", + "%used_int_18 = OpSpecConstantOp %int ISub %used_int_0 %used_int_17", + "%used_int_19 = OpSpecConstantOp %int IAdd %used_int_0 %used_int_18", + "%used_int_20 = OpSpecConstantOp %int ISub %used_int_0 %used_int_19", + "%used_vec_a = OpSpecConstantComposite %v2int %used_int_18 %used_int_19", + "%used_vec_b = OpSpecConstantOp %v2int IMul %used_vec_a %used_vec_a", + "%used_int_21 = OpSpecConstantOp %int CompositeExtract %used_vec_b 0", + "%used_array = OpConstantComposite %type_arr_int_4 %used_int_20 %used_int_20 %used_int_21 %used_int_21", + }, + /* .main_insts = */ + { + "%int_var = OpVariable %_pf_int Function", + "%used_array_2 = OpCompositeExtract %int 
%used_array 2", + "OpStore %int_var %used_array_2", + }, + /* .dead_consts = */ + { + "%dead_int_1 = OpConstant %int 2", + "%dead_int_2 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_1", + "%dead_int_3 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_2", + "%dead_int_4 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_3", + "%dead_int_5 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_4", + "%dead_int_6 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_5", + "%dead_int_7 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_6", + "%dead_int_8 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_7", + "%dead_int_9 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_8", + "%dead_int_10 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_9", + "%dead_int_11 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_10", + "%dead_int_12 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_11", + "%dead_int_13 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_12", + "%dead_int_14 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_13", + "%dead_int_15 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_14", + "%dead_int_16 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_15", + "%dead_int_17 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_16", + "%dead_int_18 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_17", + "%dead_int_19 = OpSpecConstantOp %int IAdd %used_int_0 %dead_int_18", + "%dead_int_20 = OpSpecConstantOp %int ISub %used_int_0 %dead_int_19", + "%dead_vec_a = OpSpecConstantComposite %v2int %dead_int_18 %dead_int_19", + "%dead_vec_b = OpSpecConstantOp %v2int IMul %dead_vec_a %dead_vec_a", + "%dead_int_21 = OpSpecConstantOp %int CompositeExtract %dead_vec_b 0", + "%dead_array = OpConstantComposite %type_arr_int_4 %dead_int_20 %used_int_20 %dead_int_19 %used_int_19", + }, + }, + // Long Def-Use chain with swizzle + // clang-format on + }))); + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/eliminate_dead_functions_test.cpp 
b/test/opt/eliminate_dead_functions_test.cpp new file mode 100644 index 000000000..0a3d490a8 --- /dev/null +++ b/test/opt/eliminate_dead_functions_test.cpp @@ -0,0 +1,209 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::HasSubstr; +using EliminateDeadFunctionsBasicTest = PassTest<::testing::Test>; + +TEST_F(EliminateDeadFunctionsBasicTest, BasicDeleteDeadFunction) { + // The function Dead should be removed because it is never called. 
+ const std::vector common_code = { + // clang-format off + "OpCapability Shader", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\"", + "OpName %main \"main\"", + "OpName %Live \"Live\"", + "%void = OpTypeVoid", + "%7 = OpTypeFunction %void", + "%main = OpFunction %void None %7", + "%15 = OpLabel", + "%16 = OpFunctionCall %void %Live", + "%17 = OpFunctionCall %void %Live", + "OpReturn", + "OpFunctionEnd", + "%Live = OpFunction %void None %7", + "%20 = OpLabel", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + const std::vector dead_function = { + // clang-format off + "%Dead = OpFunction %void None %7", + "%19 = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + JoinAllInsts(Concat(common_code, dead_function)), + JoinAllInsts(common_code), /* skip_nop = */ true); +} + +TEST_F(EliminateDeadFunctionsBasicTest, BasicKeepLiveFunction) { + // Everything is reachable from an entry point, so no functions should be + // deleted. 
+ const std::vector text = { + // clang-format off + "OpCapability Shader", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\"", + "OpName %main \"main\"", + "OpName %Live1 \"Live1\"", + "OpName %Live2 \"Live2\"", + "%void = OpTypeVoid", + "%7 = OpTypeFunction %void", + "%main = OpFunction %void None %7", + "%15 = OpLabel", + "%16 = OpFunctionCall %void %Live2", + "%17 = OpFunctionCall %void %Live1", + "OpReturn", + "OpFunctionEnd", + "%Live1 = OpFunction %void None %7", + "%19 = OpLabel", + "OpReturn", + "OpFunctionEnd", + "%Live2 = OpFunction %void None %7", + "%20 = OpLabel", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + std::string assembly = JoinAllInsts(text); + auto result = SinglePassRunAndDisassemble( + assembly, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(assembly, std::get<0>(result)); +} + +TEST_F(EliminateDeadFunctionsBasicTest, BasicKeepExportFunctions) { + // All functions are reachable. In particular, ExportedFunc and Live are + // reachable because ExportedFunc is exported. Nothing should be removed. 
+ const std::vector text = { + // clang-format off + "OpCapability Shader", + "OpCapability Linkage", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\"", + "OpName %main \"main\"", + "OpName %ExportedFunc \"ExportedFunc\"", + "OpName %Live \"Live\"", + "OpDecorate %ExportedFunc LinkageAttributes \"ExportedFunc\" Export", + "%void = OpTypeVoid", + "%7 = OpTypeFunction %void", + "%main = OpFunction %void None %7", + "%15 = OpLabel", + "OpReturn", + "OpFunctionEnd", +"%ExportedFunc = OpFunction %void None %7", + "%19 = OpLabel", + "%16 = OpFunctionCall %void %Live", + "OpReturn", + "OpFunctionEnd", + "%Live = OpFunction %void None %7", + "%20 = OpLabel", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + std::string assembly = JoinAllInsts(text); + auto result = SinglePassRunAndDisassemble( + assembly, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); + EXPECT_EQ(assembly, std::get<0>(result)); +} + +TEST_F(EliminateDeadFunctionsBasicTest, BasicRemoveDecorationsAndNames) { + // We want to remove the names and decorations associated with results that + // are removed. This test will check for that. 
+ const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpName %main "main" + OpName %Dead "Dead" + OpName %x "x" + OpName %y "y" + OpName %z "z" + OpDecorate %x RelaxedPrecision + OpDecorate %y RelaxedPrecision + OpDecorate %z RelaxedPrecision + OpDecorate %6 RelaxedPrecision + OpDecorate %7 RelaxedPrecision + OpDecorate %8 RelaxedPrecision + %void = OpTypeVoid + %10 = OpTypeFunction %void + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float + %float_1 = OpConstant %float 1 + %main = OpFunction %void None %10 + %14 = OpLabel + OpReturn + OpFunctionEnd + %Dead = OpFunction %void None %10 + %15 = OpLabel + %x = OpVariable %_ptr_Function_float Function + %y = OpVariable %_ptr_Function_float Function + %z = OpVariable %_ptr_Function_float Function + OpStore %x %float_1 + OpStore %y %float_1 + %6 = OpLoad %float %x + %7 = OpLoad %float %y + %8 = OpFAdd %float %6 %7 + OpStore %z %8 + OpReturn + OpFunctionEnd)"; + + const std::string expected_output = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpName %main "main" +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_1 = OpConstant %float 1 +%main = OpFunction %void None %10 +%14 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(text, expected_output, + /* skip_nop = */ true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/feature_manager_test.cpp b/test/opt/feature_manager_test.cpp new file mode 100644 index 000000000..767376cf5 --- /dev/null +++ b/test/opt/feature_manager_test.cpp @@ -0,0 +1,142 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace { + +using FeatureManagerTest = ::testing::Test; + +TEST_F(FeatureManagerTest, MissingExtension) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ASSERT_NE(context, nullptr); + + EXPECT_FALSE(context->get_feature_mgr()->HasExtension( + Extension::kSPV_KHR_variable_pointers)); +} + +TEST_F(FeatureManagerTest, OneExtension) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpExtension "SPV_KHR_variable_pointers" + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ASSERT_NE(context, nullptr); + + EXPECT_TRUE(context->get_feature_mgr()->HasExtension( + Extension::kSPV_KHR_variable_pointers)); +} + +TEST_F(FeatureManagerTest, NotADifferentExtension) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpExtension "SPV_KHR_variable_pointers" + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ASSERT_NE(context, nullptr); + + EXPECT_FALSE(context->get_feature_mgr()->HasExtension( + Extension::kSPV_KHR_storage_buffer_storage_class)); +} + +TEST_F(FeatureManagerTest, TwoExtensions) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpExtension 
"SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_storage_buffer_storage_class" + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ASSERT_NE(context, nullptr); + + EXPECT_TRUE(context->get_feature_mgr()->HasExtension( + Extension::kSPV_KHR_variable_pointers)); + EXPECT_TRUE(context->get_feature_mgr()->HasExtension( + Extension::kSPV_KHR_storage_buffer_storage_class)); +} + +// Test capability checks. +TEST_F(FeatureManagerTest, ExplicitlyPresent1) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ASSERT_NE(context, nullptr); + + EXPECT_TRUE(context->get_feature_mgr()->HasCapability(SpvCapabilityShader)); + EXPECT_FALSE(context->get_feature_mgr()->HasCapability(SpvCapabilityKernel)); +} + +TEST_F(FeatureManagerTest, ExplicitlyPresent2) { + const std::string text = R"( +OpCapability Kernel +OpMemoryModel Logical GLSL450 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ASSERT_NE(context, nullptr); + + EXPECT_FALSE(context->get_feature_mgr()->HasCapability(SpvCapabilityShader)); + EXPECT_TRUE(context->get_feature_mgr()->HasCapability(SpvCapabilityKernel)); +} + +TEST_F(FeatureManagerTest, ImplicitlyPresent) { + const std::string text = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ASSERT_NE(context, nullptr); + + // Check multiple levels of indirection. Tessellation implies Shader, which + // implies Matrix. 
+ EXPECT_TRUE( + context->get_feature_mgr()->HasCapability(SpvCapabilityTessellation)); + EXPECT_TRUE(context->get_feature_mgr()->HasCapability(SpvCapabilityShader)); + EXPECT_TRUE(context->get_feature_mgr()->HasCapability(SpvCapabilityMatrix)); + EXPECT_FALSE(context->get_feature_mgr()->HasCapability(SpvCapabilityKernel)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/flatten_decoration_test.cpp b/test/opt/flatten_decoration_test.cpp new file mode 100644 index 000000000..fcf2341e7 --- /dev/null +++ b/test/opt/flatten_decoration_test.cpp @@ -0,0 +1,239 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +// Returns the initial part of the assembly text for a valid +// SPIR-V module, including instructions prior to decorations. 
+std::string PreambleAssembly() { + return + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %hue %saturation %value +OpExecutionMode %main OriginUpperLeft +OpName %main "main" +OpName %void_fn "void_fn" +OpName %hue "hue" +OpName %saturation "saturation" +OpName %value "value" +OpName %entry "entry" +OpName %Point "Point" +OpName %Camera "Camera" +)"; +} + +// Returns the assembly for the types, variables, and main function. +std::string TypesAndFunctionsAssembly() { + return + R"(%void = OpTypeVoid +%void_fn = OpTypeFunction %void +%float = OpTypeFloat 32 +%Point = OpTypeStruct %float %float %float +%Camera = OpTypeStruct %float %float +%_ptr_Input_float = OpTypePointer Input %float +%hue = OpVariable %_ptr_Input_float Input +%saturation = OpVariable %_ptr_Input_float Input +%value = OpVariable %_ptr_Input_float Input +%main = OpFunction %void None %void_fn +%entry = OpLabel +OpReturn +OpFunctionEnd +)"; +} + +struct FlattenDecorationCase { + // Names and decorations before the pass. + std::string input; + // Names and decorations after the pass. + std::string expected; +}; + +using FlattenDecorationTest = + PassTest<::testing::TestWithParam>; + +TEST_P(FlattenDecorationTest, TransformsDecorations) { + const auto before = + PreambleAssembly() + GetParam().input + TypesAndFunctionsAssembly(); + const auto after = + PreambleAssembly() + GetParam().expected + TypesAndFunctionsAssembly(); + + SinglePassRunAndCheck(before, after, false, true); +} + +INSTANTIATE_TEST_CASE_P(NoUses, FlattenDecorationTest, + ::testing::ValuesIn(std::vector{ + // No OpDecorationGroup + {"", ""}, + + // OpDecorationGroup without any uses, and + // no OpName. + {"%group = OpDecorationGroup\n", ""}, + + // OpDecorationGroup without any uses, and + // with OpName targeting it. Proves you must + // remove the names as well. 
+ {"OpName %group \"group\"\n" + "%group = OpDecorationGroup\n", + ""}, + + // OpDecorationGroup with decorations that + // target it, but no uses in OpGroupDecorate + // or OpGroupMemberDecorate instructions. + {"OpDecorate %group Flat\n" + "OpDecorate %group NoPerspective\n" + "%group = OpDecorationGroup\n", + ""}, + }), ); + +INSTANTIATE_TEST_CASE_P(OpGroupDecorate, FlattenDecorationTest, + ::testing::ValuesIn(std::vector{ + // One OpGroupDecorate + {"OpName %group \"group\"\n" + "OpDecorate %group Flat\n" + "OpDecorate %group NoPerspective\n" + "%group = OpDecorationGroup\n" + "OpGroupDecorate %group %hue %saturation\n", + "OpDecorate %hue Flat\n" + "OpDecorate %saturation Flat\n" + "OpDecorate %hue NoPerspective\n" + "OpDecorate %saturation NoPerspective\n"}, + // Multiple OpGroupDecorate + {"OpName %group \"group\"\n" + "OpDecorate %group Flat\n" + "OpDecorate %group NoPerspective\n" + "%group = OpDecorationGroup\n" + "OpGroupDecorate %group %hue %value\n" + "OpGroupDecorate %group %saturation\n", + "OpDecorate %hue Flat\n" + "OpDecorate %value Flat\n" + "OpDecorate %saturation Flat\n" + "OpDecorate %hue NoPerspective\n" + "OpDecorate %value NoPerspective\n" + "OpDecorate %saturation NoPerspective\n"}, + // Two group decorations, interleaved + {"OpName %group0 \"group0\"\n" + "OpName %group1 \"group1\"\n" + "OpDecorate %group0 Flat\n" + "OpDecorate %group1 NoPerspective\n" + "%group0 = OpDecorationGroup\n" + "%group1 = OpDecorationGroup\n" + "OpGroupDecorate %group0 %hue %value\n" + "OpGroupDecorate %group1 %saturation\n", + "OpDecorate %hue Flat\n" + "OpDecorate %value Flat\n" + "OpDecorate %saturation NoPerspective\n"}, + // Decoration with operands + {"OpName %group \"group\"\n" + "OpDecorate %group Location 42\n" + "%group = OpDecorationGroup\n" + "OpGroupDecorate %group %hue %saturation\n", + "OpDecorate %hue Location 42\n" + "OpDecorate %saturation Location 42\n"}, + }), ); + +INSTANTIATE_TEST_CASE_P(OpGroupMemberDecorate, FlattenDecorationTest, + 
::testing::ValuesIn(std::vector{ + // One OpGroupMemberDecorate + {"OpName %group \"group\"\n" + "OpDecorate %group Flat\n" + "OpDecorate %group Offset 16\n" + "%group = OpDecorationGroup\n" + "OpGroupMemberDecorate %group %Point 1\n", + "OpMemberDecorate %Point 1 Flat\n" + "OpMemberDecorate %Point 1 Offset 16\n"}, + // Multiple OpGroupMemberDecorate using the same + // decoration group. + {"OpName %group \"group\"\n" + "OpDecorate %group Flat\n" + "OpDecorate %group NoPerspective\n" + "OpDecorate %group Offset 8\n" + "%group = OpDecorationGroup\n" + "OpGroupMemberDecorate %group %Point 2\n" + "OpGroupMemberDecorate %group %Camera 1\n", + "OpMemberDecorate %Point 2 Flat\n" + "OpMemberDecorate %Camera 1 Flat\n" + "OpMemberDecorate %Point 2 NoPerspective\n" + "OpMemberDecorate %Camera 1 NoPerspective\n" + "OpMemberDecorate %Point 2 Offset 8\n" + "OpMemberDecorate %Camera 1 Offset 8\n"}, + // Two groups of member decorations, interleaved. + // Decoration is with and without operands. + {"OpName %group0 \"group0\"\n" + "OpName %group1 \"group1\"\n" + "OpDecorate %group0 Flat\n" + "OpDecorate %group0 Offset 8\n" + "OpDecorate %group1 NoPerspective\n" + "OpDecorate %group1 Offset 16\n" + "%group0 = OpDecorationGroup\n" + "%group1 = OpDecorationGroup\n" + "OpGroupMemberDecorate %group0 %Point 0\n" + "OpGroupMemberDecorate %group1 %Point 2\n", + "OpMemberDecorate %Point 0 Flat\n" + "OpMemberDecorate %Point 0 Offset 8\n" + "OpMemberDecorate %Point 2 NoPerspective\n" + "OpMemberDecorate %Point 2 Offset 16\n"}, + }), ); + +INSTANTIATE_TEST_CASE_P(UnrelatedDecorations, FlattenDecorationTest, + ::testing::ValuesIn(std::vector{ + // A non-group non-member decoration is untouched. + {"OpDecorate %hue Centroid\n" + "OpDecorate %saturation Flat\n", + "OpDecorate %hue Centroid\n" + "OpDecorate %saturation Flat\n"}, + // A non-group member decoration is untouched. 
+ {"OpMemberDecorate %Point 0 Offset 0\n" + "OpMemberDecorate %Point 1 Offset 4\n" + "OpMemberDecorate %Point 1 Flat\n", + "OpMemberDecorate %Point 0 Offset 0\n" + "OpMemberDecorate %Point 1 Offset 4\n" + "OpMemberDecorate %Point 1 Flat\n"}, + // A non-group non-member decoration survives any + // replacement of group decorations. + {"OpName %group \"group\"\n" + "OpDecorate %group Flat\n" + "OpDecorate %hue Centroid\n" + "OpDecorate %group NoPerspective\n" + "%group = OpDecorationGroup\n" + "OpGroupDecorate %group %hue %saturation\n", + "OpDecorate %hue Flat\n" + "OpDecorate %saturation Flat\n" + "OpDecorate %hue Centroid\n" + "OpDecorate %hue NoPerspective\n" + "OpDecorate %saturation NoPerspective\n"}, + // A non-group member decoration survives any + // replacement of group decorations. + {"OpDecorate %group Offset 0\n" + "OpDecorate %group Flat\n" + "OpMemberDecorate %Point 1 Offset 4\n" + "%group = OpDecorationGroup\n" + "OpGroupMemberDecorate %group %Point 0\n", + "OpMemberDecorate %Point 0 Offset 0\n" + "OpMemberDecorate %Point 0 Flat\n" + "OpMemberDecorate %Point 1 Offset 4\n"}, + }), ); + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp new file mode 100644 index 000000000..8ecfd5c78 --- /dev/null +++ b/test/opt/fold_spec_const_op_composite_test.cpp @@ -0,0 +1,1394 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using FoldSpecConstantOpAndCompositePassBasicTest = PassTest<::testing::Test>; + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, Empty) { + SinglePassRunAndCheck( + "", "", /* skip_nop = */ true); +} + +// A test of the basic functionality of FoldSpecConstantOpAndCompositePass. +// A spec constant defined with an integer addition operation should be folded +// to a normal constant with fixed value. +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, Basic) { + AssemblyBuilder builder; + builder.AppendTypesConstantsGlobals({ + // clang-format off + "%int = OpTypeInt 32 1", + "%frozen_spec_const_int = OpConstant %int 1", + "%const_int = OpConstant %int 2", + // Folding target: + "%spec_add = OpSpecConstantOp %int IAdd %frozen_spec_const_int %const_int", + // clang-format on + }); + + std::vector expected = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %void \"void\"", + "OpName %main_func_type \"main_func_type\"", + "OpName %main \"main\"", + "OpName %main_func_entry_block \"main_func_entry_block\"", + "OpName %int \"int\"", + "OpName %frozen_spec_const_int \"frozen_spec_const_int\"", + "OpName %const_int \"const_int\"", + "OpName %spec_add \"spec_add\"", + "%void = OpTypeVoid", + "%main_func_type = OpTypeFunction %void", + "%int = OpTypeInt 32 1", +"%frozen_spec_const_int = OpConstant %int 1", + "%const_int = OpConstant %int 2", + // The SpecConstantOp IAdd instruction should be replaced by OpConstant + // instruction: + "%spec_add = OpConstant %int 3", + "%main = OpFunction %void None %main_func_type", 
+"%main_func_entry_block = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true); +} + +// A test of skipping folding an instruction when the instruction result type +// has decorations. +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, + SkipWhenTypeHasDecorations) { + AssemblyBuilder builder; + builder + .AppendAnnotations({ + // clang-format off + "OpDecorate %int RelaxedPrecision", + // clang-format on + }) + .AppendTypesConstantsGlobals({ + // clang-format off + "%int = OpTypeInt 32 1", + "%frozen_spec_const_int = OpConstant %int 1", + "%const_int = OpConstant %int 2", + // The following spec constant should not be folded as the result type + // has relaxed precision decoration. + "%spec_add = OpSpecConstantOp %int IAdd %frozen_spec_const_int %const_int", + // clang-format on + }); + + SinglePassRunAndCheck( + builder.GetCode(), builder.GetCode(), /* skip_nop = */ true); +} + +// All types and some common constants that are potentially required in +// FoldSpecConstantOpAndCompositeTest. 
+std::vector CommonTypesAndConstants() { + return std::vector{ + // clang-format off + // scalar types + "%bool = OpTypeBool", + "%uint = OpTypeInt 32 0", + "%int = OpTypeInt 32 1", + "%float = OpTypeFloat 32", + "%double = OpTypeFloat 64", + // vector types + "%v2bool = OpTypeVector %bool 2", + "%v2uint = OpTypeVector %uint 2", + "%v2int = OpTypeVector %int 2", + "%v3int = OpTypeVector %int 3", + "%v4int = OpTypeVector %int 4", + "%v2float = OpTypeVector %float 2", + "%v2double = OpTypeVector %double 2", + // variable pointer types + "%_pf_bool = OpTypePointer Function %bool", + "%_pf_uint = OpTypePointer Function %uint", + "%_pf_int = OpTypePointer Function %int", + "%_pf_float = OpTypePointer Function %float", + "%_pf_double = OpTypePointer Function %double", + "%_pf_v2int = OpTypePointer Function %v2int", + "%_pf_v2float = OpTypePointer Function %v2float", + "%_pf_v2double = OpTypePointer Function %v2double", + // struct types + "%inner_struct = OpTypeStruct %bool %int %float", + "%outer_struct = OpTypeStruct %inner_struct %int", + "%flat_struct = OpTypeStruct %bool %int %float", + + // common constants + // scalar constants: + "%bool_true = OpConstantTrue %bool", + "%bool_false = OpConstantFalse %bool", + "%bool_null = OpConstantNull %bool", + "%signed_zero = OpConstant %int 0", + "%unsigned_zero = OpConstant %uint 0", + "%signed_one = OpConstant %int 1", + "%unsigned_one = OpConstant %uint 1", + "%signed_two = OpConstant %int 2", + "%unsigned_two = OpConstant %uint 2", + "%signed_three = OpConstant %int 3", + "%unsigned_three = OpConstant %uint 3", + "%signed_null = OpConstantNull %int", + "%unsigned_null = OpConstantNull %uint", + // vector constants: + "%bool_true_vec = OpConstantComposite %v2bool %bool_true %bool_true", + "%bool_false_vec = OpConstantComposite %v2bool %bool_false %bool_false", + "%bool_null_vec = OpConstantNull %v2bool", + "%signed_zero_vec = OpConstantComposite %v2int %signed_zero %signed_zero", + "%unsigned_zero_vec = OpConstantComposite 
%v2uint %unsigned_zero %unsigned_zero", + "%signed_one_vec = OpConstantComposite %v2int %signed_one %signed_one", + "%unsigned_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one", + "%signed_two_vec = OpConstantComposite %v2int %signed_two %signed_two", + "%unsigned_two_vec = OpConstantComposite %v2uint %unsigned_two %unsigned_two", + "%signed_three_vec = OpConstantComposite %v2int %signed_three %signed_three", + "%unsigned_three_vec = OpConstantComposite %v2uint %unsigned_three %unsigned_three", + "%signed_null_vec = OpConstantNull %v2int", + "%unsigned_null_vec = OpConstantNull %v2uint", + "%v4int_0_1_2_3 = OpConstantComposite %v4int %signed_zero %signed_one %signed_two %signed_three", + // clang-format on + }; +} + +// A helper function to strip OpName instructions from the given string of +// disassembly code. Returns the string with all OpName instructions stripped. +std::string StripOpNameInstructions(const std::string& str) { + std::stringstream ss(str); + std::ostringstream oss; + std::string inst_str; + while (std::getline(ss, inst_str, '\n')) { + if (inst_str.find("OpName %") == std::string::npos) { + oss << inst_str << '\n'; + } + } + return oss.str(); +} + +struct FoldSpecConstantOpAndCompositePassTestCase { + // Original constants with unfolded spec constants. + std::vector original; + // Expected constants after folding. 
+ std::vector expected; +}; + +using FoldSpecConstantOpAndCompositePassTest = PassTest< + ::testing::TestWithParam>; + +TEST_P(FoldSpecConstantOpAndCompositePassTest, ParamTestCase) { + AssemblyBuilder test_code_builder, expected_code_builder; + const auto& tc = GetParam(); + test_code_builder.AppendTypesConstantsGlobals(CommonTypesAndConstants()); + test_code_builder.AppendTypesConstantsGlobals(tc.original); + expected_code_builder.AppendTypesConstantsGlobals(CommonTypesAndConstants()); + expected_code_builder.AppendTypesConstantsGlobals(tc.expected); + const std::string original = test_code_builder.GetCode(); + const std::string expected = expected_code_builder.GetCode(); + + // Run the optimization and get the result code in disassembly. + std::string optimized; + auto status = Pass::Status::SuccessWithoutChange; + std::tie(optimized, status) = + SinglePassRunAndDisassemble( + original, /* skip_nop = */ true, /* do_validation = */ false); + + // Check the optimized code, but ignore the OpName instructions. + EXPECT_NE(Pass::Status::Failure, status); + EXPECT_EQ( + StripOpNameInstructions(expected) == StripOpNameInstructions(original), + status == Pass::Status::SuccessWithoutChange); + EXPECT_EQ(StripOpNameInstructions(expected), + StripOpNameInstructions(optimized)); +} + +// Tests that OpSpecConstantComposite opcodes are replaced with +// OpConstantComposite correctly. +INSTANTIATE_TEST_CASE_P( + Composite, FoldSpecConstantOpAndCompositePassTest, + ::testing::ValuesIn(std::vector< + FoldSpecConstantOpAndCompositePassTestCase>({ + // clang-format off + // normal vector + { + // original + { + "%spec_v2bool = OpSpecConstantComposite %v2bool %bool_true %bool_false", + "%spec_v2uint = OpSpecConstantComposite %v2uint %unsigned_one %unsigned_one", + "%spec_v2int_a = OpSpecConstantComposite %v2int %signed_one %signed_two", + // Spec constants whose value can not be fully resolved should + // not be processed. 
+ "%spec_int = OpSpecConstant %int 99", + "%spec_v2int_b = OpSpecConstantComposite %v2int %signed_one %spec_int", + }, + // expected + { + "%spec_v2bool = OpConstantComposite %v2bool %bool_true %bool_false", + "%spec_v2uint = OpConstantComposite %v2uint %unsigned_one %unsigned_one", + "%spec_v2int_a = OpConstantComposite %v2int %signed_one %signed_two", + "%spec_int = OpSpecConstant %int 99", + "%spec_v2int_b = OpSpecConstantComposite %v2int %signed_one %spec_int", + }, + }, + // vector with null constants + { + // original + { + "%null_bool = OpConstantNull %bool", + "%null_int = OpConstantNull %int", + "%spec_v2bool = OpSpecConstantComposite %v2bool %null_bool %null_bool", + "%spec_v3int = OpSpecConstantComposite %v3int %null_int %null_int %null_int", + "%spec_v4int = OpSpecConstantComposite %v4int %null_int %null_int %null_int %null_int", + }, + // expected + { + "%null_bool = OpConstantNull %bool", + "%null_int = OpConstantNull %int", + "%spec_v2bool = OpConstantComposite %v2bool %null_bool %null_bool", + "%spec_v3int = OpConstantComposite %v3int %null_int %null_int %null_int", + "%spec_v4int = OpConstantComposite %v4int %null_int %null_int %null_int %null_int", + }, + }, + // flat struct + { + // original + { + "%float_1 = OpConstant %float 1", + "%flat_1 = OpSpecConstantComposite %flat_struct %bool_true %signed_null %float_1", + // following struct should not be folded as the value of + // %spec_float is not determined. 
+ "%spec_float = OpSpecConstant %float 1", + "%flat_2 = OpSpecConstantComposite %flat_struct %bool_true %signed_one %spec_float", + }, + // expected + { + "%float_1 = OpConstant %float 1", + "%flat_1 = OpConstantComposite %flat_struct %bool_true %signed_null %float_1", + "%spec_float = OpSpecConstant %float 1", + "%flat_2 = OpSpecConstantComposite %flat_struct %bool_true %signed_one %spec_float", + } + }, + // nested struct + { + // original + { + "%float_1 = OpConstant %float 1", + "%inner_1 = OpSpecConstantComposite %inner_struct %bool_true %signed_null %float_1", + "%outer_1 = OpSpecConstantComposite %outer_struct %inner_1 %signed_one", + // following structs should not be folded as the value of + // %spec_float is not determined. + "%spec_float = OpSpecConstant %float 1", + "%inner_2 = OpSpecConstantComposite %inner_struct %bool_true %signed_null %spec_float", + "%outer_2 = OpSpecConstantComposite %outer_struct %inner_2 %signed_one", + }, + // expected + { + "%float_1 = OpConstant %float 1", + "%inner_1 = OpConstantComposite %inner_struct %bool_true %signed_null %float_1", + "%outer_1 = OpConstantComposite %outer_struct %inner_1 %signed_one", + "%spec_float = OpSpecConstant %float 1", + "%inner_2 = OpSpecConstantComposite %inner_struct %bool_true %signed_null %spec_float", + "%outer_2 = OpSpecConstantComposite %outer_struct %inner_2 %signed_one", + } + }, + // composite constants touched by OpUndef should be skipped + { + // original + { + "%undef = OpUndef %float", + "%inner = OpConstantComposite %inner_struct %bool_true %signed_one %undef", + "%outer = OpSpecConstantComposite %outer_struct %inner %signed_one", + }, + // expected + { + "%undef = OpUndef %float", + "%inner = OpConstantComposite %inner_struct %bool_true %signed_one %undef", + "%outer = OpSpecConstantComposite %outer_struct %inner %signed_one", + }, + } + // clang-format on + }))); + +// Tests for operations that result in different types. 
+INSTANTIATE_TEST_CASE_P( + Cast, FoldSpecConstantOpAndCompositePassTest, + ::testing::ValuesIn( + std::vector({ + // clang-format off + // int -> bool scalar + { + // original + { + "%spec_bool_t = OpSpecConstantOp %bool INotEqual %signed_three %signed_zero", + "%spec_bool_f = OpSpecConstantOp %bool INotEqual %signed_zero %signed_zero", + "%spec_bool_from_null = OpSpecConstantOp %bool INotEqual %signed_null %signed_zero", + }, + // expected + { + "%spec_bool_t = OpConstantTrue %bool", + "%spec_bool_f = OpConstantFalse %bool", + "%spec_bool_from_null = OpConstantFalse %bool", + }, + }, + + // uint -> bool scalar + { + // original + { + "%spec_bool_t = OpSpecConstantOp %bool INotEqual %unsigned_three %unsigned_zero", + "%spec_bool_f = OpSpecConstantOp %bool INotEqual %unsigned_zero %unsigned_zero", + "%spec_bool_from_null = OpSpecConstantOp %bool INotEqual %unsigned_null %unsigned_zero", + }, + // expected + { + "%spec_bool_t = OpConstantTrue %bool", + "%spec_bool_f = OpConstantFalse %bool", + "%spec_bool_from_null = OpConstantFalse %bool", + }, + }, + + // bool -> int scalar + { + // original + { + "%spec_int_one = OpSpecConstantOp %int Select %bool_true %signed_one %signed_zero", + "%spec_int_zero = OpSpecConstantOp %int Select %bool_false %signed_one %signed_zero", + "%spec_int_from_null = OpSpecConstantOp %int Select %bool_null %signed_one %signed_zero", + }, + // expected + { + "%spec_int_one = OpConstant %int 1", + "%spec_int_zero = OpConstant %int 0", + "%spec_int_from_null = OpConstant %int 0", + }, + }, + + // uint -> int scalar + { + // original + { + "%spec_int_one = OpSpecConstantOp %int IAdd %unsigned_one %signed_zero", + "%spec_int_zero = OpSpecConstantOp %int IAdd %unsigned_zero %signed_zero", + "%spec_int_from_null = OpSpecConstantOp %int IAdd %unsigned_null %unsigned_zero", + }, + // expected + { + "%spec_int_one = OpConstant %int 1", + "%spec_int_zero = OpConstant %int 0", + "%spec_int_from_null = OpConstant %int 0", + }, + }, + + // bool -> uint 
scalar + { + // original + { + "%spec_uint_one = OpSpecConstantOp %uint Select %bool_true %unsigned_one %unsigned_zero", + "%spec_uint_zero = OpSpecConstantOp %uint Select %bool_false %unsigned_one %unsigned_zero", + "%spec_uint_from_null = OpSpecConstantOp %uint Select %bool_null %unsigned_one %unsigned_zero", + }, + // expected + { + "%spec_uint_one = OpConstant %uint 1", + "%spec_uint_zero = OpConstant %uint 0", + "%spec_uint_from_null = OpConstant %uint 0", + }, + }, + + // int -> uint scalar + { + // original + { + "%spec_uint_one = OpSpecConstantOp %uint IAdd %signed_one %unsigned_zero", + "%spec_uint_zero = OpSpecConstantOp %uint IAdd %signed_zero %unsigned_zero", + "%spec_uint_from_null = OpSpecConstantOp %uint IAdd %signed_null %unsigned_zero", + }, + // expected + { + "%spec_uint_one = OpConstant %uint 1", + "%spec_uint_zero = OpConstant %uint 0", + "%spec_uint_from_null = OpConstant %uint 0", + }, + }, + + // int -> bool vector + { + // original + { + "%spec_bool_t_vec = OpSpecConstantOp %v2bool INotEqual %signed_three_vec %signed_zero_vec", + "%spec_bool_f_vec = OpSpecConstantOp %v2bool INotEqual %signed_zero_vec %signed_zero_vec", + "%spec_bool_from_null = OpSpecConstantOp %v2bool INotEqual %signed_null_vec %signed_zero_vec", + }, + // expected + { + "%true = OpConstantTrue %bool", + "%true_0 = OpConstantTrue %bool", + "%spec_bool_t_vec = OpConstantComposite %v2bool %bool_true %bool_true", + "%false = OpConstantFalse %bool", + "%false_0 = OpConstantFalse %bool", + "%spec_bool_f_vec = OpConstantComposite %v2bool %bool_false %bool_false", + "%false_1 = OpConstantFalse %bool", + "%false_2 = OpConstantFalse %bool", + "%spec_bool_from_null = OpConstantComposite %v2bool %bool_false %bool_false", + }, + }, + + // uint -> bool vector + { + // original + { + "%spec_bool_t_vec = OpSpecConstantOp %v2bool INotEqual %unsigned_three_vec %unsigned_zero_vec", + "%spec_bool_f_vec = OpSpecConstantOp %v2bool INotEqual %unsigned_zero_vec %unsigned_zero_vec", + 
"%spec_bool_from_null = OpSpecConstantOp %v2bool INotEqual %unsigned_null_vec %unsigned_zero_vec", + }, + // expected + { + "%true = OpConstantTrue %bool", + "%true_0 = OpConstantTrue %bool", + "%spec_bool_t_vec = OpConstantComposite %v2bool %bool_true %bool_true", + "%false = OpConstantFalse %bool", + "%false_0 = OpConstantFalse %bool", + "%spec_bool_f_vec = OpConstantComposite %v2bool %bool_false %bool_false", + "%false_1 = OpConstantFalse %bool", + "%false_2 = OpConstantFalse %bool", + "%spec_bool_from_null = OpConstantComposite %v2bool %bool_false %bool_false", + }, + }, + + // bool -> int vector + { + // original + { + "%spec_int_one_vec = OpSpecConstantOp %v2int Select %bool_true_vec %signed_one_vec %signed_zero_vec", + "%spec_int_zero_vec = OpSpecConstantOp %v2int Select %bool_false_vec %signed_one_vec %signed_zero_vec", + "%spec_int_from_null = OpSpecConstantOp %v2int Select %bool_null_vec %signed_one_vec %signed_zero_vec", + }, + // expected + { + "%int_1 = OpConstant %int 1", + "%int_1_0 = OpConstant %int 1", + "%spec_int_one_vec = OpConstantComposite %v2int %signed_one %signed_one", + "%int_0 = OpConstant %int 0", + "%int_0_0 = OpConstant %int 0", + "%spec_int_zero_vec = OpConstantComposite %v2int %signed_zero %signed_zero", + "%int_0_1 = OpConstant %int 0", + "%int_0_2 = OpConstant %int 0", + "%spec_int_from_null = OpConstantComposite %v2int %signed_zero %signed_zero", + }, + }, + + // uint -> int vector + { + // original + { + "%spec_int_one_vec = OpSpecConstantOp %v2int IAdd %unsigned_one_vec %signed_zero_vec", + "%spec_int_zero_vec = OpSpecConstantOp %v2int IAdd %unsigned_zero_vec %signed_zero_vec", + "%spec_int_from_null = OpSpecConstantOp %v2int IAdd %unsigned_null_vec %signed_zero_vec", + }, + // expected + { + "%int_1 = OpConstant %int 1", + "%int_1_0 = OpConstant %int 1", + "%spec_int_one_vec = OpConstantComposite %v2int %signed_one %signed_one", + "%int_0 = OpConstant %int 0", + "%int_0_0 = OpConstant %int 0", + "%spec_int_zero_vec = 
OpConstantComposite %v2int %signed_zero %signed_zero", + "%int_0_1 = OpConstant %int 0", + "%int_0_2 = OpConstant %int 0", + "%spec_int_from_null = OpConstantComposite %v2int %signed_zero %signed_zero", + }, + }, + + // bool -> uint vector + { + // original + { + "%spec_uint_one_vec = OpSpecConstantOp %v2uint Select %bool_true_vec %unsigned_one_vec %unsigned_zero_vec", + "%spec_uint_zero_vec = OpSpecConstantOp %v2uint Select %bool_false_vec %unsigned_one_vec %unsigned_zero_vec", + "%spec_uint_from_null = OpSpecConstantOp %v2uint Select %bool_null_vec %unsigned_one_vec %unsigned_zero_vec", + }, + // expected + { + "%uint_1 = OpConstant %uint 1", + "%uint_1_0 = OpConstant %uint 1", + "%spec_uint_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one", + "%uint_0 = OpConstant %uint 0", + "%uint_0_0 = OpConstant %uint 0", + "%spec_uint_zero_vec = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", + "%uint_0_1 = OpConstant %uint 0", + "%uint_0_2 = OpConstant %uint 0", + "%spec_uint_from_null = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", + }, + }, + + // int -> uint vector + { + // original + { + "%spec_uint_one_vec = OpSpecConstantOp %v2uint IAdd %signed_one_vec %unsigned_zero_vec", + "%spec_uint_zero_vec = OpSpecConstantOp %v2uint IAdd %signed_zero_vec %unsigned_zero_vec", + "%spec_uint_from_null = OpSpecConstantOp %v2uint IAdd %signed_null_vec %unsigned_zero_vec", + }, + // expected + { + "%uint_1 = OpConstant %uint 1", + "%uint_1_0 = OpConstant %uint 1", + "%spec_uint_one_vec = OpConstantComposite %v2uint %unsigned_one %unsigned_one", + "%uint_0 = OpConstant %uint 0", + "%uint_0_0 = OpConstant %uint 0", + "%spec_uint_zero_vec = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", + "%uint_0_1 = OpConstant %uint 0", + "%uint_0_2 = OpConstant %uint 0", + "%spec_uint_from_null = OpConstantComposite %v2uint %unsigned_zero %unsigned_zero", + }, + }, + // clang-format on + }))); + +// Tests about boolean scalar logical operations 
and comparison operations with +// scalar int/uint type. +INSTANTIATE_TEST_CASE_P( + Logical, FoldSpecConstantOpAndCompositePassTest, + ::testing::ValuesIn(std::vector< + FoldSpecConstantOpAndCompositePassTestCase>({ + // clang-format off + // scalar integer comparison + { + // original + { + "%int_minus_1 = OpConstant %int -1", + + "%slt_0_1 = OpSpecConstantOp %bool SLessThan %signed_zero %signed_one", + "%sgt_0_1 = OpSpecConstantOp %bool SGreaterThan %signed_zero %signed_one", + "%sle_2_2 = OpSpecConstantOp %bool SLessThanEqual %signed_two %signed_two", + "%sge_2_1 = OpSpecConstantOp %bool SGreaterThanEqual %signed_two %signed_one", + "%sge_2_null = OpSpecConstantOp %bool SGreaterThanEqual %signed_two %signed_null", + "%sge_minus_1_null = OpSpecConstantOp %bool SGreaterThanEqual %int_minus_1 %signed_null", + + "%ult_0_1 = OpSpecConstantOp %bool ULessThan %unsigned_zero %unsigned_one", + "%ugt_0_1 = OpSpecConstantOp %bool UGreaterThan %unsigned_zero %unsigned_one", + "%ule_2_3 = OpSpecConstantOp %bool ULessThanEqual %unsigned_two %unsigned_three", + "%uge_1_1 = OpSpecConstantOp %bool UGreaterThanEqual %unsigned_one %unsigned_one", + "%uge_2_null = OpSpecConstantOp %bool UGreaterThanEqual %unsigned_two %unsigned_null", + "%uge_minus_1_null = OpSpecConstantOp %bool UGreaterThanEqual %int_minus_1 %unsigned_null", + }, + // expected + { + "%int_minus_1 = OpConstant %int -1", + + "%slt_0_1 = OpConstantTrue %bool", + "%sgt_0_1 = OpConstantFalse %bool", + "%sle_2_2 = OpConstantTrue %bool", + "%sge_2_1 = OpConstantTrue %bool", + "%sge_2_null = OpConstantTrue %bool", + "%sge_minus_1_null = OpConstantFalse %bool", + + "%ult_0_1 = OpConstantTrue %bool", + "%ugt_0_1 = OpConstantFalse %bool", + "%ule_2_3 = OpConstantTrue %bool", + "%uge_1_1 = OpConstantTrue %bool", + "%uge_2_null = OpConstantTrue %bool", + "%uge_minus_1_null = OpConstantTrue %bool", + }, + }, + // Logical or, and, not, equal, not-equal. 
+ { + // original + { + "%logical_or = OpSpecConstantOp %bool LogicalOr %bool_true %bool_false", + "%logical_and = OpSpecConstantOp %bool LogicalAnd %bool_true %bool_false", + "%logical_not = OpSpecConstantOp %bool LogicalNot %bool_true", + "%logical_eq = OpSpecConstantOp %bool LogicalEqual %bool_true %bool_true", + "%logical_neq = OpSpecConstantOp %bool LogicalNotEqual %bool_true %bool_true", + "%logical_and_null = OpSpecConstantOp %bool LogicalAnd %bool_true %bool_null", + }, + // expected + { + "%logical_or = OpConstantTrue %bool", + "%logical_and = OpConstantFalse %bool", + "%logical_not = OpConstantFalse %bool", + "%logical_eq = OpConstantTrue %bool", + "%logical_neq = OpConstantFalse %bool", + "%logical_and_null = OpConstantFalse %bool", + }, + }, + // clang-format on + }))); + +// Tests about arithmetic operations for scalar int and uint types. +INSTANTIATE_TEST_CASE_P( + ScalarArithmetic, FoldSpecConstantOpAndCompositePassTest, + ::testing::ValuesIn(std::vector< + FoldSpecConstantOpAndCompositePassTestCase>({ + // clang-format off + // scalar integer negate + { + // original + { + "%int_minus_1 = OpSpecConstantOp %int SNegate %signed_one", + "%int_minus_2 = OpSpecConstantOp %int SNegate %signed_two", + "%int_neg_null = OpSpecConstantOp %int SNegate %signed_null", + "%int_max = OpConstant %int 2147483647", + "%int_neg_max = OpSpecConstantOp %int SNegate %int_max", + }, + // expected + { + "%int_minus_1 = OpConstant %int -1", + "%int_minus_2 = OpConstant %int -2", + "%int_neg_null = OpConstant %int 0", + "%int_max = OpConstant %int 2147483647", + "%int_neg_max = OpConstant %int -2147483647", + }, + }, + // scalar integer not + { + // original + { + "%uint_4294967294 = OpSpecConstantOp %uint Not %unsigned_one", + "%uint_4294967293 = OpSpecConstantOp %uint Not %unsigned_two", + "%uint_neg_null = OpSpecConstantOp %uint Not %unsigned_null", + }, + // expected + { + "%uint_4294967294 = OpConstant %uint 4294967294", + "%uint_4294967293 = OpConstant %uint 
4294967293", + "%uint_neg_null = OpConstant %uint 4294967295", + }, + }, + // scalar integer add, sub, mul, div + { + // original + { + "%signed_max = OpConstant %int 2147483647", + "%signed_min = OpConstant %int -2147483648", + + "%spec_int_iadd = OpSpecConstantOp %int IAdd %signed_three %signed_two", + "%spec_int_isub = OpSpecConstantOp %int ISub %signed_one %spec_int_iadd", + "%spec_int_sdiv = OpSpecConstantOp %int SDiv %spec_int_isub %signed_two", + "%spec_int_imul = OpSpecConstantOp %int IMul %spec_int_sdiv %signed_three", + "%spec_int_iadd_null = OpSpecConstantOp %int IAdd %spec_int_imul %signed_null", + "%spec_int_imul_null = OpSpecConstantOp %int IMul %spec_int_iadd_null %signed_null", + "%spec_int_iadd_overflow = OpSpecConstantOp %int IAdd %signed_max %signed_three", + "%spec_int_isub_overflow = OpSpecConstantOp %int ISub %signed_min %signed_three", + + "%spec_uint_iadd = OpSpecConstantOp %uint IAdd %unsigned_three %unsigned_two", + "%spec_uint_isub = OpSpecConstantOp %uint ISub %unsigned_one %spec_uint_iadd", + "%spec_uint_udiv = OpSpecConstantOp %uint UDiv %spec_uint_isub %unsigned_three", + "%spec_uint_imul = OpSpecConstantOp %uint IMul %spec_uint_udiv %unsigned_two", + "%spec_uint_isub_null = OpSpecConstantOp %uint ISub %spec_uint_imul %signed_null", + }, + // expected + { + "%signed_max = OpConstant %int 2147483647", + "%signed_min = OpConstant %int -2147483648", + + "%spec_int_iadd = OpConstant %int 5", + "%spec_int_isub = OpConstant %int -4", + "%spec_int_sdiv = OpConstant %int -2", + "%spec_int_imul = OpConstant %int -6", + "%spec_int_iadd_null = OpConstant %int -6", + "%spec_int_imul_null = OpConstant %int 0", + "%spec_int_iadd_overflow = OpConstant %int -2147483646", + "%spec_int_isub_overflow = OpConstant %int 2147483645", + + "%spec_uint_iadd = OpConstant %uint 5", + "%spec_uint_isub = OpConstant %uint 4294967292", + "%spec_uint_udiv = OpConstant %uint 1431655764", + "%spec_uint_imul = OpConstant %uint 2863311528", + "%spec_uint_isub_null = 
OpConstant %uint 2863311528", + }, + }, + // scalar integer rem, mod + { + // original + { + // common constants + "%int_7 = OpConstant %int 7", + "%uint_7 = OpConstant %uint 7", + "%int_minus_7 = OpConstant %int -7", + "%int_minus_3 = OpConstant %int -3", + + // srem + "%7_srem_3 = OpSpecConstantOp %int SRem %int_7 %signed_three", + "%minus_7_srem_3 = OpSpecConstantOp %int SRem %int_minus_7 %signed_three", + "%7_srem_minus_3 = OpSpecConstantOp %int SRem %int_7 %int_minus_3", + "%minus_7_srem_minus_3 = OpSpecConstantOp %int SRem %int_minus_7 %int_minus_3", + // smod + "%7_smod_3 = OpSpecConstantOp %int SMod %int_7 %signed_three", + "%minus_7_smod_3 = OpSpecConstantOp %int SMod %int_minus_7 %signed_three", + "%7_smod_minus_3 = OpSpecConstantOp %int SMod %int_7 %int_minus_3", + "%minus_7_smod_minus_3 = OpSpecConstantOp %int SMod %int_minus_7 %int_minus_3", + // umod + "%7_umod_3 = OpSpecConstantOp %uint UMod %uint_7 %unsigned_three", + // null constant + "%null_srem_3 = OpSpecConstantOp %int SRem %signed_null %signed_three", + "%null_smod_3 = OpSpecConstantOp %int SMod %signed_null %signed_three", + "%null_umod_3 = OpSpecConstantOp %uint UMod %unsigned_null %unsigned_three", + }, + // expected + { + // common constants + "%int_7 = OpConstant %int 7", + "%uint_7 = OpConstant %uint 7", + "%int_minus_7 = OpConstant %int -7", + "%int_minus_3 = OpConstant %int -3", + + // srem + "%7_srem_3 = OpConstant %int 1", + "%minus_7_srem_3 = OpConstant %int -1", + "%7_srem_minus_3 = OpConstant %int 1", + "%minus_7_srem_minus_3 = OpConstant %int -1", + // smod + "%7_smod_3 = OpConstant %int 1", + "%minus_7_smod_3 = OpConstant %int 2", + "%7_smod_minus_3 = OpConstant %int -2", + "%minus_7_smod_minus_3 = OpConstant %int -1", + // umod + "%7_umod_3 = OpConstant %uint 1", + // null constant + "%null_srem_3 = OpConstant %int 0", + "%null_smod_3 = OpConstant %int 0", + "%null_umod_3 = OpConstant %uint 0", + }, + }, + // scalar integer bitwise and shift + { + // original + { + // bitwise + 
"%xor_1_3 = OpSpecConstantOp %int BitwiseXor %signed_one %signed_three", + "%and_1_2 = OpSpecConstantOp %int BitwiseAnd %signed_one %xor_1_3", + "%or_1_2 = OpSpecConstantOp %int BitwiseOr %signed_one %xor_1_3", + "%xor_3_null = OpSpecConstantOp %int BitwiseXor %or_1_2 %signed_null", + + // shift + "%unsigned_31 = OpConstant %uint 31", + "%unsigned_left_shift_max = OpSpecConstantOp %uint ShiftLeftLogical %unsigned_one %unsigned_31", + "%unsigned_right_shift_logical = OpSpecConstantOp %uint ShiftRightLogical %unsigned_left_shift_max %unsigned_31", + "%signed_right_shift_arithmetic = OpSpecConstantOp %int ShiftRightArithmetic %unsigned_left_shift_max %unsigned_31", + "%left_shift_null_31 = OpSpecConstantOp %uint ShiftLeftLogical %unsigned_null %unsigned_31", + "%right_shift_31_null = OpSpecConstantOp %uint ShiftRightLogical %unsigned_31 %unsigned_null", + }, + // expected + { + "%xor_1_3 = OpConstant %int 2", + "%and_1_2 = OpConstant %int 0", + "%or_1_2 = OpConstant %int 3", + "%xor_3_null = OpConstant %int 3", + + "%unsigned_31 = OpConstant %uint 31", + "%unsigned_left_shift_max = OpConstant %uint 2147483648", + "%unsigned_right_shift_logical = OpConstant %uint 1", + "%signed_right_shift_arithmetic = OpConstant %int -1", + "%left_shift_null_31 = OpConstant %uint 0", + "%right_shift_31_null = OpConstant %uint 31", + }, + }, + // Skip folding if any operands have undetermined value. + { + // original + { + "%spec_int = OpSpecConstant %int 1", + "%spec_iadd = OpSpecConstantOp %int IAdd %signed_three %spec_int", + }, + // expected + { + "%spec_int = OpSpecConstant %int 1", + "%spec_iadd = OpSpecConstantOp %int IAdd %signed_three %spec_int", + }, + }, + // clang-format on + }))); + +// Tests about arithmetic operations for vector int and uint types. 
+INSTANTIATE_TEST_CASE_P( + VectorArithmetic, FoldSpecConstantOpAndCompositePassTest, + ::testing::ValuesIn(std::vector< + FoldSpecConstantOpAndCompositePassTestCase>({ + // clang-format off + // vector integer negate + { + // original + { + "%v2int_minus_1 = OpSpecConstantOp %v2int SNegate %signed_one_vec", + "%v2int_minus_2 = OpSpecConstantOp %v2int SNegate %signed_two_vec", + "%v2int_neg_null = OpSpecConstantOp %v2int SNegate %signed_null_vec", + }, + // expected + { + "%int_n1 = OpConstant %int -1", + "%int_n1_0 = OpConstant %int -1", + "%v2int_minus_1 = OpConstantComposite %v2int %int_n1 %int_n1", + "%int_n2 = OpConstant %int -2", + "%int_n2_0 = OpConstant %int -2", + "%v2int_minus_2 = OpConstantComposite %v2int %int_n2 %int_n2", + "%int_0 = OpConstant %int 0", + "%int_0_0 = OpConstant %int 0", + "%v2int_neg_null = OpConstantComposite %v2int %signed_zero %signed_zero", + }, + }, + // vector integer (including null vectors) add, sub, div, mul + { + // original + { + "%spec_v2int_iadd = OpSpecConstantOp %v2int IAdd %signed_three_vec %signed_two_vec", + "%spec_v2int_isub = OpSpecConstantOp %v2int ISub %signed_one_vec %spec_v2int_iadd", + "%spec_v2int_sdiv = OpSpecConstantOp %v2int SDiv %spec_v2int_isub %signed_two_vec", + "%spec_v2int_imul = OpSpecConstantOp %v2int IMul %spec_v2int_sdiv %signed_three_vec", + "%spec_v2int_iadd_null = OpSpecConstantOp %v2int IAdd %spec_v2int_imul %signed_null_vec", + + "%spec_v2uint_iadd = OpSpecConstantOp %v2uint IAdd %unsigned_three_vec %unsigned_two_vec", + "%spec_v2uint_isub = OpSpecConstantOp %v2uint ISub %unsigned_one_vec %spec_v2uint_iadd", + "%spec_v2uint_udiv = OpSpecConstantOp %v2uint UDiv %spec_v2uint_isub %unsigned_three_vec", + "%spec_v2uint_imul = OpSpecConstantOp %v2uint IMul %spec_v2uint_udiv %unsigned_two_vec", + "%spec_v2uint_isub_null = OpSpecConstantOp %v2uint ISub %spec_v2uint_imul %signed_null_vec", + }, + // expected + { + "%int_5 = OpConstant %int 5", + "%int_5_0 = OpConstant %int 5", + "%spec_v2int_iadd = 
OpConstantComposite %v2int %int_5 %int_5", + "%int_n4 = OpConstant %int -4", + "%int_n4_0 = OpConstant %int -4", + "%spec_v2int_isub = OpConstantComposite %v2int %int_n4 %int_n4", + "%int_n2 = OpConstant %int -2", + "%int_n2_0 = OpConstant %int -2", + "%spec_v2int_sdiv = OpConstantComposite %v2int %int_n2 %int_n2", + "%int_n6 = OpConstant %int -6", + "%int_n6_0 = OpConstant %int -6", + "%spec_v2int_imul = OpConstantComposite %v2int %int_n6 %int_n6", + "%int_n6_1 = OpConstant %int -6", + "%int_n6_2 = OpConstant %int -6", + "%spec_v2int_iadd_null = OpConstantComposite %v2int %int_n6 %int_n6", + + "%uint_5 = OpConstant %uint 5", + "%uint_5_0 = OpConstant %uint 5", + "%spec_v2uint_iadd = OpConstantComposite %v2uint %uint_5 %uint_5", + "%uint_4294967292 = OpConstant %uint 4294967292", + "%uint_4294967292_0 = OpConstant %uint 4294967292", + "%spec_v2uint_isub = OpConstantComposite %v2uint %uint_4294967292 %uint_4294967292", + "%uint_1431655764 = OpConstant %uint 1431655764", + "%uint_1431655764_0 = OpConstant %uint 1431655764", + "%spec_v2uint_udiv = OpConstantComposite %v2uint %uint_1431655764 %uint_1431655764", + "%uint_2863311528 = OpConstant %uint 2863311528", + "%uint_2863311528_0 = OpConstant %uint 2863311528", + "%spec_v2uint_imul = OpConstantComposite %v2uint %uint_2863311528 %uint_2863311528", + "%uint_2863311528_1 = OpConstant %uint 2863311528", + "%uint_2863311528_2 = OpConstant %uint 2863311528", + "%spec_v2uint_isub_null = OpConstantComposite %v2uint %uint_2863311528 %uint_2863311528", + }, + }, + // vector integer rem, mod + { + // original + { + // common constants + "%int_7 = OpConstant %int 7", + "%v2int_7 = OpConstantComposite %v2int %int_7 %int_7", + "%uint_7 = OpConstant %uint 7", + "%v2uint_7 = OpConstantComposite %v2uint %uint_7 %uint_7", + "%int_minus_7 = OpConstant %int -7", + "%v2int_minus_7 = OpConstantComposite %v2int %int_minus_7 %int_minus_7", + "%int_minus_3 = OpConstant %int -3", + "%v2int_minus_3 = OpConstantComposite %v2int %int_minus_3 
%int_minus_3", + + // srem + "%7_srem_3 = OpSpecConstantOp %v2int SRem %v2int_7 %signed_three_vec", + "%minus_7_srem_3 = OpSpecConstantOp %v2int SRem %v2int_minus_7 %signed_three_vec", + "%7_srem_minus_3 = OpSpecConstantOp %v2int SRem %v2int_7 %v2int_minus_3", + "%minus_7_srem_minus_3 = OpSpecConstantOp %v2int SRem %v2int_minus_7 %v2int_minus_3", + // smod + "%7_smod_3 = OpSpecConstantOp %v2int SMod %v2int_7 %signed_three_vec", + "%minus_7_smod_3 = OpSpecConstantOp %v2int SMod %v2int_minus_7 %signed_three_vec", + "%7_smod_minus_3 = OpSpecConstantOp %v2int SMod %v2int_7 %v2int_minus_3", + "%minus_7_smod_minus_3 = OpSpecConstantOp %v2int SMod %v2int_minus_7 %v2int_minus_3", + // umod + "%7_umod_3 = OpSpecConstantOp %v2uint UMod %v2uint_7 %unsigned_three_vec", + }, + // expected + { + // common constants + "%int_7 = OpConstant %int 7", + "%v2int_7 = OpConstantComposite %v2int %int_7 %int_7", + "%uint_7 = OpConstant %uint 7", + "%v2uint_7 = OpConstantComposite %v2uint %uint_7 %uint_7", + "%int_minus_7 = OpConstant %int -7", + "%v2int_minus_7 = OpConstantComposite %v2int %int_minus_7 %int_minus_7", + "%int_minus_3 = OpConstant %int -3", + "%v2int_minus_3 = OpConstantComposite %v2int %int_minus_3 %int_minus_3", + + // srem + "%int_1 = OpConstant %int 1", + "%int_1_0 = OpConstant %int 1", + "%7_srem_3 = OpConstantComposite %v2int %signed_one %signed_one", + "%int_n1 = OpConstant %int -1", + "%int_n1_0 = OpConstant %int -1", + "%minus_7_srem_3 = OpConstantComposite %v2int %int_n1 %int_n1", + "%int_1_1 = OpConstant %int 1", + "%int_1_2 = OpConstant %int 1", + "%7_srem_minus_3 = OpConstantComposite %v2int %signed_one %signed_one", + "%int_n1_1 = OpConstant %int -1", + "%int_n1_2 = OpConstant %int -1", + "%minus_7_srem_minus_3 = OpConstantComposite %v2int %int_n1 %int_n1", + // smod + "%int_1_3 = OpConstant %int 1", + "%int_1_4 = OpConstant %int 1", + "%7_smod_3 = OpConstantComposite %v2int %signed_one %signed_one", + "%int_2 = OpConstant %int 2", + "%int_2_0 = OpConstant 
%int 2", + "%minus_7_smod_3 = OpConstantComposite %v2int %signed_two %signed_two", + "%int_n2 = OpConstant %int -2", + "%int_n2_0 = OpConstant %int -2", + "%7_smod_minus_3 = OpConstantComposite %v2int %int_n2 %int_n2", + "%int_n1_3 = OpConstant %int -1", + "%int_n1_4 = OpConstant %int -1", + "%minus_7_smod_minus_3 = OpConstantComposite %v2int %int_n1 %int_n1", + // umod + "%uint_1 = OpConstant %uint 1", + "%uint_1_0 = OpConstant %uint 1", + "%7_umod_3 = OpConstantComposite %v2uint %unsigned_one %unsigned_one", + }, + }, + // vector integer bitwise, shift + { + // original + { + "%xor_1_3 = OpSpecConstantOp %v2int BitwiseXor %signed_one_vec %signed_three_vec", + "%and_1_2 = OpSpecConstantOp %v2int BitwiseAnd %signed_one_vec %xor_1_3", + "%or_1_2 = OpSpecConstantOp %v2int BitwiseOr %signed_one_vec %xor_1_3", + + "%unsigned_31 = OpConstant %uint 31", + "%v2unsigned_31 = OpConstantComposite %v2uint %unsigned_31 %unsigned_31", + "%unsigned_left_shift_max = OpSpecConstantOp %v2uint ShiftLeftLogical %unsigned_one_vec %v2unsigned_31", + "%unsigned_right_shift_logical = OpSpecConstantOp %v2uint ShiftRightLogical %unsigned_left_shift_max %v2unsigned_31", + "%signed_right_shift_arithmetic = OpSpecConstantOp %v2int ShiftRightArithmetic %unsigned_left_shift_max %v2unsigned_31", + }, + // expected + { + "%int_2 = OpConstant %int 2", + "%int_2_0 = OpConstant %int 2", + "%xor_1_3 = OpConstantComposite %v2int %signed_two %signed_two", + "%int_0 = OpConstant %int 0", + "%int_0_0 = OpConstant %int 0", + "%and_1_2 = OpConstantComposite %v2int %signed_zero %signed_zero", + "%int_3 = OpConstant %int 3", + "%int_3_0 = OpConstant %int 3", + "%or_1_2 = OpConstantComposite %v2int %signed_three %signed_three", + + "%unsigned_31 = OpConstant %uint 31", + "%v2unsigned_31 = OpConstantComposite %v2uint %unsigned_31 %unsigned_31", + "%uint_2147483648 = OpConstant %uint 2147483648", + "%uint_2147483648_0 = OpConstant %uint 2147483648", + "%unsigned_left_shift_max = OpConstantComposite %v2uint 
%uint_2147483648 %uint_2147483648", + "%uint_1 = OpConstant %uint 1", + "%uint_1_0 = OpConstant %uint 1", + "%unsigned_right_shift_logical = OpConstantComposite %v2uint %unsigned_one %unsigned_one", + "%int_n1 = OpConstant %int -1", + "%int_n1_0 = OpConstant %int -1", + "%signed_right_shift_arithmetic = OpConstantComposite %v2int %int_n1 %int_n1", + }, + }, + // Skip folding if any vector operands or components of the operands + // have undetermined value. + { + // original + { + "%spec_int = OpSpecConstant %int 1", + "%spec_vec = OpSpecConstantComposite %v2int %signed_zero %spec_int", + "%spec_iadd = OpSpecConstantOp %v2int IAdd %signed_three_vec %spec_vec", + }, + // expected + { + "%spec_int = OpSpecConstant %int 1", + "%spec_vec = OpSpecConstantComposite %v2int %signed_zero %spec_int", + "%spec_iadd = OpSpecConstantOp %v2int IAdd %signed_three_vec %spec_vec", + }, + }, + // Skip folding if any vector operands are defined by OpUndef + { + // original + { + "%undef = OpUndef %int", + "%vec = OpConstantComposite %v2int %undef %signed_one", + "%spec_iadd = OpSpecConstantOp %v2int IAdd %signed_three_vec %vec", + }, + // expected + { + "%undef = OpUndef %int", + "%vec = OpConstantComposite %v2int %undef %signed_one", + "%spec_iadd = OpSpecConstantOp %v2int IAdd %signed_three_vec %vec", + }, + }, + // clang-format on + }))); + +// Tests for SpecConstantOp CompositeExtract instruction +INSTANTIATE_TEST_CASE_P( + CompositeExtract, FoldSpecConstantOpAndCompositePassTest, + ::testing::ValuesIn(std::vector< + FoldSpecConstantOpAndCompositePassTestCase>({ + // clang-format off + // normal vector + { + // original + { + "%r = OpSpecConstantOp %int CompositeExtract %signed_three_vec 0", + "%x = OpSpecConstantOp %int CompositeExtract %v4int_0_1_2_3 0", + "%y = OpSpecConstantOp %int CompositeExtract %v4int_0_1_2_3 1", + "%z = OpSpecConstantOp %int CompositeExtract %v4int_0_1_2_3 2", + "%w = OpSpecConstantOp %int CompositeExtract %v4int_0_1_2_3 3", + }, + // expected + { + "%r = 
OpConstant %int 3", + "%x = OpConstant %int 0", + "%y = OpConstant %int 1", + "%z = OpConstant %int 2", + "%w = OpConstant %int 3", + }, + }, + // null vector + { + // original + { + "%x = OpSpecConstantOp %int CompositeExtract %signed_null_vec 0", + "%y = OpSpecConstantOp %int CompositeExtract %signed_null_vec 1", + "%null_v4int = OpConstantNull %v4int", + "%z = OpSpecConstantOp %int CompositeExtract %signed_null_vec 2", + }, + // expected + { + "%x = OpConstantNull %int", + "%y = OpConstantNull %int", + "%null_v4int = OpConstantNull %v4int", + "%z = OpConstantNull %int", + } + }, + // normal flat struct + { + // original + { + "%float_1 = OpConstant %float 1", + "%flat_1 = OpConstantComposite %flat_struct %bool_true %signed_null %float_1", + "%extract_bool = OpSpecConstantOp %bool CompositeExtract %flat_1 0", + "%extract_int = OpSpecConstantOp %int CompositeExtract %flat_1 1", + "%extract_float_1 = OpSpecConstantOp %float CompositeExtract %flat_1 2", + // foldable composite constants built with OpSpecConstantComposite + // should also be processed. 
+ "%flat_2 = OpSpecConstantComposite %flat_struct %bool_true %signed_null %float_1", + "%extract_float_2 = OpSpecConstantOp %float CompositeExtract %flat_2 2", + }, + // expected + { + "%float_1 = OpConstant %float 1", + "%flat_1 = OpConstantComposite %flat_struct %bool_true %signed_null %float_1", + "%extract_bool = OpConstantTrue %bool", + "%extract_int = OpConstantNull %int", + "%extract_float_1 = OpConstant %float 1", + "%flat_2 = OpConstantComposite %flat_struct %bool_true %signed_null %float_1", + "%extract_float_2 = OpConstant %float 1", + }, + }, + // null flat struct + { + // original + { + "%flat = OpConstantNull %flat_struct", + "%extract_bool = OpSpecConstantOp %bool CompositeExtract %flat 0", + "%extract_int = OpSpecConstantOp %int CompositeExtract %flat 1", + "%extract_float = OpSpecConstantOp %float CompositeExtract %flat 2", + }, + // expected + { + "%flat = OpConstantNull %flat_struct", + "%extract_bool = OpConstantNull %bool", + "%extract_int = OpConstantNull %int", + "%extract_float = OpConstantNull %float", + }, + }, + // normal nested struct + { + // original + { + "%float_1 = OpConstant %float 1", + "%inner = OpConstantComposite %inner_struct %bool_true %signed_null %float_1", + "%outer = OpConstantComposite %outer_struct %inner %signed_one", + "%extract_inner = OpSpecConstantOp %inner_struct CompositeExtract %outer 0", + "%extract_int = OpSpecConstantOp %int CompositeExtract %outer 1", + "%extract_inner_float = OpSpecConstantOp %int CompositeExtract %outer 0 2", + }, + // expected + { + "%float_1 = OpConstant %float 1", + "%inner = OpConstantComposite %inner_struct %bool_true %signed_null %float_1", + "%outer = OpConstantComposite %outer_struct %inner %signed_one", + "%extract_inner = OpConstantComposite %flat_struct %bool_true %signed_null %float_1", + "%extract_int = OpConstant %int 1", + "%extract_inner_float = OpConstant %float 1", + }, + }, + // null nested struct + { + // original + { + "%outer = OpConstantNull %outer_struct", + 
"%extract_inner = OpSpecConstantOp %inner_struct CompositeExtract %outer 0", + "%extract_int = OpSpecConstantOp %int CompositeExtract %outer 1", + "%extract_inner_float = OpSpecConstantOp %float CompositeExtract %outer 0 2", + }, + // expected + { + "%outer = OpConstantNull %outer_struct", + "%extract_inner = OpConstantNull %inner_struct", + "%extract_int = OpConstantNull %int", + "%extract_inner_float = OpConstantNull %float", + }, + }, + // skip folding if the any composite constant's value are not fully + // determined, even though the extracting target might have + // determined value. + { + // original + { + "%float_1 = OpConstant %float 1", + "%spec_float = OpSpecConstant %float 1", + "%spec_inner = OpSpecConstantComposite %inner_struct %bool_true %signed_null %spec_float", + "%spec_outer = OpSpecConstantComposite %outer_struct %spec_inner %signed_one", + "%spec_vec = OpSpecConstantComposite %v2float %spec_float %float_1", + "%extract_inner = OpSpecConstantOp %int CompositeExtract %spec_inner 1", + "%extract_outer = OpSpecConstantOp %int CompositeExtract %spec_outer 1", + "%extract_vec = OpSpecConstantOp %float CompositeExtract %spec_vec 1", + }, + // expected + { + "%float_1 = OpConstant %float 1", + "%spec_float = OpSpecConstant %float 1", + "%spec_inner = OpSpecConstantComposite %inner_struct %bool_true %signed_null %spec_float", + "%spec_outer = OpSpecConstantComposite %outer_struct %spec_inner %signed_one", + "%spec_vec = OpSpecConstantComposite %v2float %spec_float %float_1", + "%extract_inner = OpSpecConstantOp %int CompositeExtract %spec_inner 1", + "%extract_outer = OpSpecConstantOp %int CompositeExtract %spec_outer 1", + "%extract_vec = OpSpecConstantOp %float CompositeExtract %spec_vec 1", + }, + }, + // skip if the composite constant depends on the result of OpUndef, + // even though the composite extract target element does not depends + // on the OpUndef. 
+ { + // original + { + "%undef = OpUndef %float", + "%inner = OpConstantComposite %inner_struct %bool_true %signed_one %undef", + "%outer = OpConstantComposite %outer_struct %inner %signed_one", + "%extract_inner = OpSpecConstantOp %int CompositeExtract %inner 1", + "%extract_outer = OpSpecConstantOp %int CompositeExtract %outer 1", + }, + // expected + { + "%undef = OpUndef %float", + "%inner = OpConstantComposite %inner_struct %bool_true %signed_one %undef", + "%outer = OpConstantComposite %outer_struct %inner %signed_one", + "%extract_inner = OpSpecConstantOp %int CompositeExtract %inner 1", + "%extract_outer = OpSpecConstantOp %int CompositeExtract %outer 1", + }, + }, + // TODO(qining): Add tests for Array and other composite type constants. + // clang-format on + }))); + +// Tests the swizzle operations for spec const vectors. +INSTANTIATE_TEST_CASE_P( + VectorShuffle, FoldSpecConstantOpAndCompositePassTest, + ::testing::ValuesIn(std::vector< + FoldSpecConstantOpAndCompositePassTestCase>({ + // clang-format off + // normal vector + { + // original + { + "%xy = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %v4int_0_1_2_3 0 1", + "%yz = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %v4int_0_1_2_3 1 2", + "%zw = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %v4int_0_1_2_3 2 3", + "%wx = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %v4int_0_1_2_3 3 0", + "%xx = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %v4int_0_1_2_3 0 0", + "%yyy = OpSpecConstantOp %v3int VectorShuffle %v4int_0_1_2_3 %v4int_0_1_2_3 1 1 1", + "%wwww = OpSpecConstantOp %v4int VectorShuffle %v4int_0_1_2_3 %v4int_0_1_2_3 2 2 2 2", + }, + // expected + { + "%xy = OpConstantComposite %v2int %signed_zero %signed_one", + "%yz = OpConstantComposite %v2int %signed_one %signed_two", + "%zw = OpConstantComposite %v2int %signed_two %signed_three", + "%wx = OpConstantComposite %v2int %signed_three %signed_zero", + "%xx = OpConstantComposite %v2int %signed_zero 
%signed_zero", + "%yyy = OpConstantComposite %v3int %signed_one %signed_one %signed_one", + "%wwww = OpConstantComposite %v4int %signed_two %signed_two %signed_two %signed_two", + }, + }, + // null vector + { + // original + { + "%a = OpSpecConstantOp %v2int VectorShuffle %signed_null_vec %v4int_0_1_2_3 0 1", + "%b = OpSpecConstantOp %v2int VectorShuffle %signed_null_vec %v4int_0_1_2_3 2 3", + "%c = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %signed_null_vec 3 4", + "%d = OpSpecConstantOp %v2int VectorShuffle %signed_null_vec %signed_null_vec 1 2", + }, + // expected + { + "%60 = OpConstantNull %int", + "%a = OpConstantComposite %v2int %signed_null %signed_null", + "%62 = OpConstantNull %int", + "%b = OpConstantComposite %v2int %signed_zero %signed_one", + "%64 = OpConstantNull %int", + "%c = OpConstantComposite %v2int %signed_three %signed_null", + "%66 = OpConstantNull %int", + "%d = OpConstantComposite %v2int %signed_null %signed_null", + } + }, + // skip if any of the components of the vector operands do not have + // determined value, even though the result vector might not be + // built with those undetermined values. + { + // original + { + "%spec_int = OpSpecConstant %int 1", + "%spec_ivec = OpSpecConstantComposite %v2int %signed_null %spec_int", + "%a = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %spec_ivec 0 1", + "%b = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %spec_ivec 3 4", + }, + // expected + { + "%spec_int = OpSpecConstant %int 1", + "%spec_ivec = OpSpecConstantComposite %v2int %signed_null %spec_int", + "%a = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %spec_ivec 0 1", + "%b = OpSpecConstantOp %v2int VectorShuffle %v4int_0_1_2_3 %spec_ivec 3 4", + }, + }, + // Skip if any components of the two vector operands depend on + // the result of OpUndef. Even though the selected components do + // not depend on the OpUndef result. 
+ { + // original + { + "%undef = OpUndef %int", + "%vec_1 = OpConstantComposite %v2int %undef %signed_one", + "%dep = OpSpecConstantOp %v2int VectorShuffle %vec_1 %signed_three_vec 0 3", + "%not_dep_element = OpSpecConstantOp %v2int VectorShuffle %vec_1 %signed_three_vec 1 3", + "%no_dep_vector = OpSpecConstantOp %v2int VectorShuffle %vec_1 %signed_three_vec 2 3", + }, + // expected + { + "%undef = OpUndef %int", + "%vec_1 = OpConstantComposite %v2int %undef %signed_one", + "%dep = OpSpecConstantOp %v2int VectorShuffle %vec_1 %signed_three_vec 0 3", + "%not_dep_element = OpSpecConstantOp %v2int VectorShuffle %vec_1 %signed_three_vec 1 3", + "%no_dep_vector = OpSpecConstantOp %v2int VectorShuffle %vec_1 %signed_three_vec 2 3", + }, + }, + // clang-format on + }))); + +// Test with long use-def chain. +INSTANTIATE_TEST_CASE_P( + LongDefUseChain, FoldSpecConstantOpAndCompositePassTest, + ::testing::ValuesIn(std::vector< + FoldSpecConstantOpAndCompositePassTestCase>({ + // clang-format off + // Long Def-Use chain with binary operations. 
+ { + // original + { + "%array_size = OpConstant %int 4", + "%type_arr_int_4 = OpTypeArray %int %array_size", + "%spec_int_0 = OpConstant %int 100", + "%spec_int_1 = OpConstant %int 1", + "%spec_int_2 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_1", + "%spec_int_3 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_2", + "%spec_int_4 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_3", + "%spec_int_5 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_4", + "%spec_int_6 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_5", + "%spec_int_7 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_6", + "%spec_int_8 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_7", + "%spec_int_9 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_8", + "%spec_int_10 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_9", + "%spec_int_11 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_10", + "%spec_int_12 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_11", + "%spec_int_13 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_12", + "%spec_int_14 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_13", + "%spec_int_15 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_14", + "%spec_int_16 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_15", + "%spec_int_17 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_16", + "%spec_int_18 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_17", + "%spec_int_19 = OpSpecConstantOp %int IAdd %spec_int_0 %spec_int_18", + "%spec_int_20 = OpSpecConstantOp %int ISub %spec_int_0 %spec_int_19", + "%used_vec_a = OpSpecConstantComposite %v2int %spec_int_18 %spec_int_19", + "%used_vec_b = OpSpecConstantOp %v2int IMul %used_vec_a %used_vec_a", + "%spec_int_21 = OpSpecConstantOp %int CompositeExtract %used_vec_b 0", + "%array = OpConstantComposite %type_arr_int_4 %spec_int_20 %spec_int_20 %spec_int_21 %spec_int_21", + // Spec constants whose values can not be fully resolved should + // not be processed. 
+ "%spec_int_22 = OpSpecConstant %int 123", + "%spec_int_23 = OpSpecConstantOp %int IAdd %spec_int_22 %signed_one", + }, + // expected + { + "%array_size = OpConstant %int 4", + "%type_arr_int_4 = OpTypeArray %int %array_size", + "%spec_int_0 = OpConstant %int 100", + "%spec_int_1 = OpConstant %int 1", + "%spec_int_2 = OpConstant %int 101", + "%spec_int_3 = OpConstant %int -1", + "%spec_int_4 = OpConstant %int 99", + "%spec_int_5 = OpConstant %int 1", + "%spec_int_6 = OpConstant %int 101", + "%spec_int_7 = OpConstant %int -1", + "%spec_int_8 = OpConstant %int 99", + "%spec_int_9 = OpConstant %int 1", + "%spec_int_10 = OpConstant %int 101", + "%spec_int_11 = OpConstant %int -1", + "%spec_int_12 = OpConstant %int 99", + "%spec_int_13 = OpConstant %int 1", + "%spec_int_14 = OpConstant %int 101", + "%spec_int_15 = OpConstant %int -1", + "%spec_int_16 = OpConstant %int 101", + "%spec_int_17 = OpConstant %int 201", + "%spec_int_18 = OpConstant %int -101", + "%spec_int_19 = OpConstant %int -1", + "%spec_int_20 = OpConstant %int 101", + "%used_vec_a = OpConstantComposite %v2int %spec_int_18 %spec_int_19", + "%int_10201 = OpConstant %int 10201", + "%int_1 = OpConstant %int 1", + "%used_vec_b = OpConstantComposite %v2int %int_10201 %signed_one", + "%spec_int_21 = OpConstant %int 10201", + "%array = OpConstantComposite %type_arr_int_4 %spec_int_20 %spec_int_20 %spec_int_21 %spec_int_21", + "%spec_int_22 = OpSpecConstant %int 123", + "%spec_int_23 = OpSpecConstantOp %int IAdd %spec_int_22 %signed_one", + }, + }, + // Long Def-Use chain with swizzle + }))); + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp new file mode 100644 index 000000000..487c18a83 --- /dev/null +++ b/test/opt/fold_test.cpp @@ -0,0 +1,6198 @@ +// Copyright (c) 2016 Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/fold.h" +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "spirv-tools/libspirv.hpp" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::Contains; + +std::string Disassemble(const std::string& original, IRContext* context, + uint32_t disassemble_options = 0) { + std::vector optimized_bin; + context->module()->ToBinary(&optimized_bin, true); + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + SpirvTools tools(target_env); + std::string optimized_asm; + EXPECT_TRUE( + tools.Disassemble(optimized_bin, &optimized_asm, disassemble_options)) + << "Disassembling failed for shader:\n" + << original << std::endl; + return optimized_asm; +} + +void Match(const std::string& original, IRContext* context, + uint32_t disassemble_options = 0) { + std::string disassembly = Disassemble(original, context, disassemble_options); + auto match_result = effcee::Match(disassembly, original); + EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) + << match_result.message() << "\nChecking result:\n" + << disassembly; +} + +template +struct InstructionFoldingCase { + InstructionFoldingCase(const std::string& tb, uint32_t id, ResultType result) + : test_body(tb), id_to_fold(id), expected_result(result) {} + + std::string test_body; + uint32_t 
id_to_fold; + ResultType expected_result; +}; + +using IntegerInstructionFoldingTest = + ::testing::TestWithParam>; + +TEST_P(IntegerInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Make sure the instruction folded as expected. + EXPECT_TRUE(succeeded); + if (inst != nullptr) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + EXPECT_EQ(inst->opcode(), SpvOpConstant); + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::IntConstant* result = + const_mrg->GetConstantFromInst(inst)->AsIntConstant(); + EXPECT_NE(result, nullptr); + if (result != nullptr) { + EXPECT_EQ(result->GetU32BitValue(), tc.expected_result); + } + } +} + +// Returns a common SPIR-V header for all of the tests that follow.
+#define INT_0_ID 100 +#define TRUE_ID 101 +#define VEC2_0_ID 102 +#define INT_7_ID 103 +#define FLOAT_0_ID 104 +#define DOUBLE_0_ID 105 +#define VEC4_0_ID 106 +#define DVEC4_0_ID 106 +#define HALF_0_ID 108 +const std::string& Header() { + static const std::string header = R"(OpCapability Shader +OpCapability Float16 +OpCapability Float64 +OpCapability Int16 +OpCapability Int64 +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +%void = OpTypeVoid +%void_func = OpTypeFunction %void +%bool = OpTypeBool +%float = OpTypeFloat 32 +%double = OpTypeFloat 64 +%half = OpTypeFloat 16 +%101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps. +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%bool_null = OpConstantNull %bool +%short = OpTypeInt 16 1 +%int = OpTypeInt 32 1 +%long = OpTypeInt 64 1 +%uint = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%v4int = OpTypeVector %int 4 +%v4float = OpTypeVector %float 4 +%v4double = OpTypeVector %double 4 +%v2float = OpTypeVector %float 2 +%v2double = OpTypeVector %double 2 +%v2bool = OpTypeVector %bool 2 +%struct_v2int_int_int = OpTypeStruct %v2int %int %int +%_ptr_int = OpTypePointer Function %int +%_ptr_uint = OpTypePointer Function %uint +%_ptr_bool = OpTypePointer Function %bool +%_ptr_float = OpTypePointer Function %float +%_ptr_double = OpTypePointer Function %double +%_ptr_half = OpTypePointer Function %half +%_ptr_long = OpTypePointer Function %long +%_ptr_v2int = OpTypePointer Function %v2int +%_ptr_v4int = OpTypePointer Function %v4int +%_ptr_v4float = OpTypePointer Function %v4float +%_ptr_v4double = OpTypePointer Function %v4double +%_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int +%_ptr_v2float = OpTypePointer Function %v2float +%_ptr_v2double = OpTypePointer Function %v2double +%short_0 = OpConstant %short 0 +%short_2 = 
OpConstant %short 2 +%short_3 = OpConstant %short 3 +%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps. +%103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps. +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_3 = OpConstant %int 3 +%int_4 = OpConstant %int 4 +%int_n24 = OpConstant %int -24 +%int_min = OpConstant %int -2147483648 +%int_max = OpConstant %int 2147483647 +%long_0 = OpConstant %long 0 +%long_2 = OpConstant %long 2 +%long_3 = OpConstant %long 3 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%uint_4 = OpConstant %uint 4 +%uint_32 = OpConstant %uint 32 +%uint_42 = OpConstant %uint 42 +%uint_max = OpConstant %uint 4294967295 +%v2int_undef = OpUndef %v2int +%v2int_0_0 = OpConstantComposite %v2int %int_0 %int_0 +%v2int_1_0 = OpConstantComposite %v2int %int_1 %int_0 +%v2int_2_2 = OpConstantComposite %v2int %int_2 %int_2 +%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3 +%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2 +%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4 +%v2bool_null = OpConstantNull %v2bool +%v2bool_true_false = OpConstantComposite %v2bool %true %false +%v2bool_false_true = OpConstantComposite %v2bool %false %true +%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int +%v2int_null = OpConstantNull %v2int +%102 = OpConstantComposite %v2int %103 %103 +%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 +%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0 +%float_n1 = OpConstant %float -1 +%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps. 
+%float_null = OpConstantNull %float +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%float_3 = OpConstant %float 3 +%float_4 = OpConstant %float 4 +%float_0p5 = OpConstant %float 0.5 +%v2float_0_0 = OpConstantComposite %v2float %float_0 %float_0 +%v2float_2_2 = OpConstantComposite %v2float %float_2 %float_2 +%v2float_2_3 = OpConstantComposite %v2float %float_2 %float_3 +%v2float_3_2 = OpConstantComposite %v2float %float_3 %float_2 +%v2float_4_4 = OpConstantComposite %v2float %float_4 %float_4 +%v2float_2_0p5 = OpConstantComposite %v2float %float_2 %float_0p5 +%v2float_null = OpConstantNull %v2float +%double_n1 = OpConstant %double -1 +%105 = OpConstant %double 0 ; Need a def with an numerical id to define id maps. +%double_null = OpConstantNull %double +%double_0 = OpConstant %double 0 +%double_1 = OpConstant %double 1 +%double_2 = OpConstant %double 2 +%double_3 = OpConstant %double 3 +%double_4 = OpConstant %double 4 +%double_0p5 = OpConstant %double 0.5 +%v2double_0_0 = OpConstantComposite %v2double %double_0 %double_0 +%v2double_2_2 = OpConstantComposite %v2double %double_2 %double_2 +%v2double_2_3 = OpConstantComposite %v2double %double_2 %double_3 +%v2double_3_2 = OpConstantComposite %v2double %double_3 %double_2 +%v2double_4_4 = OpConstantComposite %v2double %double_4 %double_4 +%v2double_2_0p5 = OpConstantComposite %v2double %double_2 %double_0p5 +%v2double_null = OpConstantNull %v2double +%108 = OpConstant %half 0 +%half_1 = OpConstant %half 1 +%106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1 +%v4float_0_1_0_0 = OpConstantComposite %v4float %float_0 %float_1 %float_null %float_0 +%v4float_1_1_1_1 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%107 = OpConstantComposite %v4double %double_0 
%double_0 %double_0 %double_0 +%v4double_0_0_0_0 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_0 +%v4double_0_0_0_1 = OpConstantComposite %v4double %double_0 %double_0 %double_0 %double_1 +%v4double_0_1_0_0 = OpConstantComposite %v4double %double_0 %double_1 %double_null %double_0 +%v4double_1_1_1_1 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_1 +%v4double_1_1_1_0p5 = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_0p5 +%v4double_null = OpConstantNull %v4double +%v4float_n1_2_1_3 = OpConstantComposite %v4float %float_n1 %float_2 %float_1 %float_3 +)"; + + return header; +} + +// Returns the header with definitions of float NaN and double NaN. Since FC +// "; CHECK: [[double_n0:%\\w+]] = OpConstant [[double]] -0\n" finds +// %double_nan = OpConstant %double -0x1.8p+1024 instead of +// %double_n0 = OpConstant %double -0, +// we separate those definitions from Header(). +const std::string& HeaderWithNaN() { + static const std::string headerWithNaN = + Header() + + R"(%float_nan = OpConstant %float -0x1.8p+128 +%double_nan = OpConstant %double -0x1.8p+1024 +)"; + + return headerWithNaN; +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(TestCase, IntegerInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold 0*n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpIMul %int %int_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: fold n*0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpIMul %int %load %int_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: fold 0/n (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n"
+ + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSDiv %int %int_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: fold n/0 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSDiv %int %load %int_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 4: fold 0/n (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpUDiv %uint %uint_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: fold n/0 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpUDiv %uint %load %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: fold 0 remainder n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSRem %int %int_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 7: fold n remainder 0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSRem %int %load %int_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 8: fold 0%n (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSMod %int %int_0 %load\n" + +
"OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 9: fold n%0 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSMod %int %load %int_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 10: fold 0%n (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpUMod %uint %uint_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 11: fold n%0 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpUMod %uint %load %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 12: fold n << 32 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpShiftLeftLogical %uint %load %uint_32\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 13: fold n >> 32 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpShiftRightLogical %uint %load %uint_32\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 14: fold n | 0xFFFFFFFF + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpBitwiseOr %uint %load %uint_max\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xFFFFFFFF), + // Test case 15: fold 0xFFFFFFFF | n + 
InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpBitwiseOr %uint %uint_max %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0xFFFFFFFF), + // Test case 16: fold n & 0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpBitwiseAnd %uint %load %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 17: fold 1/0 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSDiv %int %int_1 %int_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 18: fold 1/0 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpUDiv %uint %uint_1 %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 19: fold OpSRem 1 0 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSRem %int %int_1 %int_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 20: fold 1%0 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSMod %int %int_1 %int_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 21: fold 1%0 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpUMod %uint %uint_1 %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 22: fold unsigned n >> 42 (undefined, so set to zero). 
+ InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpShiftRightLogical %uint %load %uint_42\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 23: fold signed n >> 42 (undefined, so set to zero). + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpShiftRightLogical %int %load %uint_42\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 24: fold n << 42 (undefined, so set to zero). + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpShiftLeftLogical %int %load %uint_42\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 25: fold -24 >> 32 (defined as -1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpShiftRightArithmetic %int %int_n24 %uint_32\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -1), + // Test case 26: fold 2 >> 32 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpShiftRightArithmetic %int %int_2 %uint_32\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 27: fold 2 >> 32 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpShiftRightLogical %int %int_2 %uint_32\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 28: fold 2 << 32 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpShiftLeftLogical %int %int_2 %uint_32\n" + + "OpReturn\n" + + 
"OpFunctionEnd", + 2, 0), + // Test case 29: fold -INT_MIN + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSNegate %int %int_min\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, std::numeric_limits::min()) +)); +// clang-format on + +using IntVectorInstructionFoldingTest = + ::testing::TestWithParam>>; + +TEST_P(IntVectorInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + SpvOp original_opcode = inst->opcode(); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Make sure the instruction folded as expected. + EXPECT_EQ(succeeded, inst == nullptr || inst->opcode() != original_opcode); + if (succeeded && inst != nullptr) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + std::vector opcodes = {SpvOpConstantComposite}; + EXPECT_THAT(opcodes, Contains(inst->opcode())); + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::Constant* result = const_mrg->GetConstantFromInst(inst); + EXPECT_NE(result, nullptr); + if (result != nullptr) { + const std::vector& componenets = + result->AsVectorConstant()->GetComponents(); + EXPECT_EQ(componenets.size(), tc.expected_result.size()); + for (size_t i = 0; i < componenets.size(); i++) { + EXPECT_EQ(tc.expected_result[i], componenets[i]->GetU32()); + } + } + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(TestCase, IntVectorInstructionFoldingTest, +::testing::Values( + // Test case 0: fold 0*n + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None 
%void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpVectorShuffle %v2int %v2int_2_2 %v2int_2_3 0 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {2,3}), + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {0,3}), + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 4294967295 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {0,0}), + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpVectorShuffle %v2int %v2int_null %v2int_2_3 0 4294967295 \n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {0,0}) +)); +// clang-format on + +using BooleanInstructionFoldingTest = + ::testing::TestWithParam>; + +TEST_P(BooleanInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Make sure the instruction folded as expected. 
+ EXPECT_TRUE(succeeded); + if (inst != nullptr) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + std::vector bool_opcodes = {SpvOpConstantTrue, SpvOpConstantFalse}; + EXPECT_THAT(bool_opcodes, Contains(inst->opcode())); + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::BoolConstant* result = + const_mrg->GetConstantFromInst(inst)->AsBoolConstant(); + EXPECT_NE(result, nullptr); + if (result != nullptr) { + EXPECT_EQ(result->value(), tc.expected_result); + } + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(TestCase, BooleanInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold true || n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpLogicalOr %bool %true %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 1: fold n || true + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpLogicalOr %bool %load %true\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold false && n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpLogicalAnd %bool %false %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 3: fold n && false + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpLogicalAnd %bool %load %false\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 4: fold n < 0 (unsigned) + 
InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpULessThan %bool %load %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 5: fold UINT_MAX < n (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpULessThan %bool %uint_max %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold INT_MAX < n (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSLessThan %bool %int_max %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 7: fold n < INT_MIN (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSLessThan %bool %load %int_min\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold 0 > n (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpUGreaterThan %bool %uint_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold n > UINT_MAX (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpUGreaterThan %bool %load %uint_max\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 10: fold n > INT_MAX (signed) + 
InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSGreaterThan %bool %load %int_max\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 11: fold INT_MIN > n (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSGreaterThan %bool %int_min %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 12: fold 0 <= n (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpULessThanEqual %bool %uint_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 13: fold n <= UINT_MAX (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpULessThanEqual %bool %load %uint_max\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 14: fold INT_MIN <= n (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSLessThanEqual %bool %int_min %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 15: fold n <= INT_MAX (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSLessThanEqual %bool %load %int_max\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 16: fold n >= 0 (unsigned) +
InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpUGreaterThanEqual %bool %load %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 17: fold UINT_MAX >= n (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load = OpLoad %uint %n\n" + + "%2 = OpUGreaterThanEqual %bool %uint_max %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 18: fold n >= INT_MIN (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSGreaterThanEqual %bool %load %int_min\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 19: fold INT_MAX >= n (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSGreaterThanEqual %bool %int_max %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) +)); + +INSTANTIATE_TEST_CASE_P(FClampAndCmpLHS, BooleanInstructionFoldingTest, +::testing::Values( + // Test case 0: fold 0.0 > clamp(n, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold 0.0 > clamp(n, -1.0, -1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable 
%_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_n1\n" + + "%2 = OpFOrdGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold 0.0 >= clamp(n, 1, 2) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 3: fold 0.0 >= clamp(n, -1.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 4: fold 0.0 <= clamp(n, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: fold 0.0 <= clamp(n, -1.0, -1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_n1\n" + + "%2 = OpFOrdLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold 0.0 < clamp(n, 1, 2) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + 
"%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 7: fold 0.0 < clamp(n, -1.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold 0.0 > clamp(n, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold 0.0 > clamp(n, -1.0, -1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_n1\n" + + "%2 = OpFUnordGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 10: fold 0.0 >= clamp(n, 1, 2) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 11: fold 0.0 >= clamp(n, -1.0, 0.0) + InstructionFoldingCase( + Header() + "%main = 
OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 12: fold 0.0 <= clamp(n, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 13: fold 0.0 <= clamp(n, -1.0, -1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_n1\n" + + "%2 = OpFUnordLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 14: fold 0.0 < clamp(n, 1, 2) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 15: fold 0.0 < clamp(n, -1.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false) +)); + 
+INSTANTIATE_TEST_CASE_P(FClampAndCmpRHS, BooleanInstructionFoldingTest, +::testing::Values( + // Test case 0: fold clamp(n, 0.0, 1.0) > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold clamp(n, 1.0, 1.0) > 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold clamp(n, 1, 2) >= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdGreaterThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 3: fold clamp(n, 1.0, 2.0) >= 3.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdGreaterThanEqual %bool %clamp %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 4: fold clamp(n, 0.0, 1.0) <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 
FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdLessThanEqual %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: fold clamp(n, 1.0, 2.0) <= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdLessThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold clamp(n, 1, 2) < 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFOrdLessThan %bool %clamp %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 7: fold clamp(n, -1.0, 0.0) < -1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdLessThan %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold clamp(n, 0.0, 1.0) > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordGreaterThan %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold clamp(n, 1.0, 2.0) > 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + 
"%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordGreaterThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 10: fold clamp(n, 1, 2) >= 3.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordGreaterThanEqual %bool %clamp %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 11: fold clamp(n, -1.0, 0.0) >= -1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordGreaterThanEqual %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 12: fold clamp(n, 0.0, 1.0) <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 13: fold clamp(n, 1.0, 1.0) <= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 14: fold clamp(n, 1, 2) < 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable 
%_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_1 %float_2\n" + + "%2 = OpFUnordLessThan %bool %clamp %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 15: fold clamp(n, -1.0, 0.0) < -1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordLessThan %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 16: fold clamp(n, -1.0, 0.0) < -1.0 (one test for double) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%ld = OpLoad %double %n\n" + + "%clamp = OpExtInst %double %1 FClamp %ld %double_n1 %double_0\n" + + "%2 = OpFUnordLessThan %bool %clamp %double_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false) +)); +// clang-format on + +using FloatInstructionFoldingTest = + ::testing::TestWithParam>; + +// Folds the instruction |tc.id_to_fold| of |tc.test_body| and expects the result +// to be an OpCopyObject of an OpConstant whose float value is +// |tc.expected_result|. +TEST_P(FloatInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + // An unknown id must fail this case instead of handing null to the folder. + ASSERT_NE(inst, nullptr); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Make sure the instruction folded as expected.
+ // Each check is fatal because the next step dereferences what it validated. + ASSERT_TRUE(succeeded); + ASSERT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + ASSERT_NE(inst, nullptr); + ASSERT_EQ(inst->opcode(), SpvOpConstant); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + // Look the constant up separately so a missing entry is reported, not + // dereferenced. + const auto* constant = const_mgr->GetConstantFromInst(inst); + ASSERT_NE(constant, nullptr); + const analysis::FloatConstant* result = constant->AsFloatConstant(); + ASSERT_NE(result, nullptr); + EXPECT_EQ(result->GetFloatValue(), tc.expected_result); +} + +// Not testing NaNs because there are no expectations concerning NaNs according +// to the "Precision and Operation of SPIR-V Instructions" section of the Vulkan +// specification. + +// clang-format off +INSTANTIATE_TEST_CASE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold 2.0 - 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFSub %float %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 1.0), + // Test case 1: Fold 2.0 + 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFAdd %float %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3.0), + // Test case 2: Fold 3.0 * 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFMul %float %float_3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 6.0), + // Test case 3: Fold 1.0 / 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %float %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0.5), + // Test case 4: Fold 1.0 / 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %float %float_1 %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2,
std::numeric_limits::infinity()), + // Test case 5: Fold -1.0 / 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %float %float_n1 %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -std::numeric_limits::infinity()), + // Test case 6: Fold (2.0, 3.0) dot (2.0, 0.5) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpDot %float %v2float_2_3 %v2float_2_0p5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 5.5f), + // Test case 7: Fold (0.0, 0.0) dot v + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %v\n" + + "%3 = OpDot %float %v2float_0_0 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 8: Fold v dot (0.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %v\n" + + "%3 = OpDot %float %2 %v2float_0_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 9: Fold Null dot v + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %v\n" + + "%3 = OpDot %float %v2float_null %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 10: Fold v dot Null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %v\n" + + "%3 = OpDot %float %2 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 11: Fold -2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + +
"%2 = OpFNegate %float %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -2) +)); +// clang-format on + +using DoubleInstructionFoldingTest = + ::testing::TestWithParam>; + +// Folds the instruction |tc.id_to_fold| of |tc.test_body| and expects the result +// to be an OpCopyObject of an OpConstant whose double value is +// |tc.expected_result|. +TEST_P(DoubleInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + // An unknown id must fail this case instead of handing null to the folder. + ASSERT_NE(inst, nullptr); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Make sure the instruction folded as expected. + // Each check is fatal because the next step dereferences what it validated. + ASSERT_TRUE(succeeded); + ASSERT_EQ(inst->opcode(), SpvOpCopyObject); + inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + ASSERT_NE(inst, nullptr); + ASSERT_EQ(inst->opcode(), SpvOpConstant); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + // Look the constant up separately so a missing entry is reported, not + // dereferenced. + const auto* constant = const_mgr->GetConstantFromInst(inst); + ASSERT_NE(constant, nullptr); + const analysis::FloatConstant* result = constant->AsFloatConstant(); + ASSERT_NE(result, nullptr); + EXPECT_EQ(result->GetDoubleValue(), tc.expected_result); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold 2.0 - 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFSub %double %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 1.0), + // Test case 1: Fold 2.0 + 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFAdd %double %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3.0), + // Test case 2: Fold 3.0 * 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + +
"%2 = OpFMul %double %double_3 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 6.0), + // Test case 3: Fold 1.0 / 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %double %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0.5), + // Test case 4: Fold 1.0 / 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %double %double_1 %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, std::numeric_limits::infinity()), + // Test case 5: Fold -1.0 / 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFDiv %double %double_n1 %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -std::numeric_limits::infinity()), + // Test case 6: Fold (2.0, 3.0) dot (2.0, 0.5) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpDot %double %v2double_2_3 %v2double_2_0p5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 5.5f), + // Test case 7: Fold (0.0, 0.0) dot v + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2double Function\n" + + "%2 = OpLoad %v2double %v\n" + + "%3 = OpDot %double %v2double_0_0 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 8: Fold v dot (0.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2double Function\n" + + "%2 = OpLoad %v2double %v\n" + + "%3 = OpDot %double %2 %v2double_0_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 9: Fold Null dot v + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2double Function\n" + + 
"%2 = OpLoad %v2double %v\n" + + "%3 = OpDot %double %v2double_null %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 10: Fold v dot Null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%v = OpVariable %_ptr_v2double Function\n" + + "%2 = OpLoad %v2double %v\n" + + "%3 = OpDot %double %2 %v2double_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 0.0f), + // Test case 11: Fold -2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFNegate %double %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -2) +)); +// clang-format on + +// clang-format off +INSTANTIATE_TEST_CASE_P(DoubleOrderedCompareConstantFoldingTest, BooleanInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold 1.0 == 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdEqual %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold 1.0 != 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdNotEqual %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold 1.0 < 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThan %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 3: fold 1.0 > 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThan %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 4: fold 1.0 <= 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = 
OpFOrdLessThanEqual %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: fold 1.0 >= 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThanEqual %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold 1.0 == 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdEqual %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 7: fold 1.0 != 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdNotEqual %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold 1.0 < 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThan %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold 1.0 > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThan %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 10: fold 1.0 <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThanEqual %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 11: fold 1.0 >= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThanEqual %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 12: fold 2.0 < 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + 
"%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThan %bool %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 13: fold 2.0 > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThan %bool %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 14: fold 2.0 <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThanEqual %bool %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 15: fold 2.0 >= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThanEqual %bool %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) +)); + +INSTANTIATE_TEST_CASE_P(DoubleUnorderedCompareConstantFoldingTest, BooleanInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold 1.0 == 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordEqual %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold 1.0 != 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordNotEqual %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold 1.0 < 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThan %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 3: fold 1.0 > 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThan %bool %double_1 %double_2\n" + + "OpReturn\n" + 
+ "OpFunctionEnd", + 2, false), + // Test case 4: fold 1.0 <= 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThanEqual %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: fold 1.0 >= 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThanEqual %bool %double_1 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold 1.0 == 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordEqual %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 7: fold 1.0 != 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordNotEqual %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold 1.0 < 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThan %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold 1.0 > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThan %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 10: fold 1.0 <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThanEqual %bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 11: fold 1.0 >= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThanEqual 
%bool %double_1 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 12: fold 2.0 < 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThan %bool %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 13: fold 2.0 > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThan %bool %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 14: fold 2.0 <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThanEqual %bool %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 15: fold 2.0 >= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThanEqual %bool %double_2 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) +)); + +INSTANTIATE_TEST_CASE_P(FloatOrderedCompareConstantFoldingTest, BooleanInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold 1.0 == 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdEqual %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold 1.0 != 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdNotEqual %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold 1.0 < 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThan %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 3: fold 1.0 > 
2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThan %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 4: fold 1.0 <= 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThanEqual %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: fold 1.0 >= 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold 1.0 == 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdEqual %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 7: fold 1.0 != 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdNotEqual %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold 1.0 < 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThan %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold 1.0 > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThan %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 10: fold 1.0 <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThanEqual %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 11: 
fold 1.0 >= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 12: fold 2.0 < 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThan %bool %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 13: fold 2.0 > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThan %bool %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 14: fold 2.0 <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdLessThanEqual %bool %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 15: fold 2.0 >= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) +)); + +INSTANTIATE_TEST_CASE_P(FloatUnorderedCompareConstantFoldingTest, BooleanInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold 1.0 == 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordEqual %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold 1.0 != 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordNotEqual %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold 1.0 < 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + 
"%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThan %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 3: fold 1.0 > 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThan %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 4: fold 1.0 <= 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThanEqual %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: fold 1.0 >= 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThanEqual %bool %float_1 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 6: fold 1.0 == 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordEqual %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 7: fold 1.0 != 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordNotEqual %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 8: fold 1.0 < 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThan %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 9: fold 1.0 > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThan %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 10: fold 1.0 <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void 
None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThanEqual %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 11: fold 1.0 >= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThanEqual %bool %float_1 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 12: fold 2.0 < 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThan %bool %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 13: fold 2.0 > 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThan %bool %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 14: fold 2.0 <= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordLessThanEqual %bool %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 15: fold 2.0 >= 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordGreaterThanEqual %bool %float_2 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) +)); + +INSTANTIATE_TEST_CASE_P(DoubleNaNCompareConstantFoldingTest, BooleanInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold NaN == 0 (ord) + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdEqual %bool %double_nan %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold NaN == 0 (unord) + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = 
OpFUnordEqual %bool %double_nan %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold NaN != 0 (ord) + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdNotEqual %bool %double_nan %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 3: fold NaN != 0 (unord) + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordNotEqual %bool %double_nan %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) +)); + +INSTANTIATE_TEST_CASE_P(FloatNaNCompareConstantFoldingTest, BooleanInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold NaN == 0 (ord) + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdEqual %bool %float_nan %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 1: fold NaN == 0 (unord) + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordEqual %bool %float_nan %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: fold NaN != 0 (ord) + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFOrdNotEqual %bool %float_nan %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, false), + // Test case 3: fold NaN != 0 (unord) + InstructionFoldingCase( + HeaderWithNaN() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFUnordNotEqual %bool %float_nan %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) +)); +// clang-format on + +template +struct InstructionFoldingCaseWithMap { + InstructionFoldingCaseWithMap(const std::string& tb, uint32_t id, + ResultType result, + std::function map) + : 
test_body(tb), id_to_fold(id), expected_result(result), id_map(map) {} + + std::string test_body; + uint32_t id_to_fold; + ResultType expected_result; + std::function id_map; +}; + +using IntegerInstructionFoldingTestWithMap = + ::testing::TestWithParam>; + +TEST_P(IntegerInstructionFoldingTestWithMap, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + inst = context->get_instruction_folder().FoldInstructionToConstant(inst, + tc.id_map); + + // Make sure the instruction folded as expected. + EXPECT_NE(inst, nullptr); + if (inst != nullptr) { + EXPECT_EQ(inst->opcode(), SpvOpConstant); + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::IntConstant* result = + const_mrg->GetConstantFromInst(inst)->AsIntConstant(); + EXPECT_NE(result, nullptr); + if (result != nullptr) { + EXPECT_EQ(result->GetU32BitValue(), tc.expected_result); + } + } +} +// clang-format off +INSTANTIATE_TEST_CASE_P(TestCase, IntegerInstructionFoldingTestWithMap, + ::testing::Values( + // Test case 0: fold %3 = 0; %3 * n + InstructionFoldingCaseWithMap( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%3 = OpCopyObject %int %int_0\n" + "%2 = OpIMul %int %3 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0, [](uint32_t id) {return (id == 3 ? INT_0_ID : id);}) + )); +// clang-format on + +using BooleanInstructionFoldingTestWithMap = + ::testing::TestWithParam>; + +TEST_P(BooleanInstructionFoldingTestWithMap, Case) { + const auto& tc = GetParam(); + + // Build module. 
+ std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + inst = context->get_instruction_folder().FoldInstructionToConstant(inst, + tc.id_map); + + // Make sure the instruction folded as expected. + EXPECT_NE(inst, nullptr); + if (inst != nullptr) { + std::vector bool_opcodes = {SpvOpConstantTrue, SpvOpConstantFalse}; + EXPECT_THAT(bool_opcodes, Contains(inst->opcode())); + analysis::ConstantManager* const_mrg = context->get_constant_mgr(); + const analysis::BoolConstant* result = + const_mrg->GetConstantFromInst(inst)->AsBoolConstant(); + EXPECT_NE(result, nullptr); + if (result != nullptr) { + EXPECT_EQ(result->value(), tc.expected_result); + } + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(TestCase, BooleanInstructionFoldingTestWithMap, + ::testing::Values( + // Test case 0: fold %3 = true; %3 || n + InstructionFoldingCaseWithMap( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + "%load = OpLoad %bool %n\n" + + "%3 = OpCopyObject %bool %true\n" + + "%2 = OpLogicalOr %bool %3 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true, [](uint32_t id) {return (id == 3 ? TRUE_ID : id);}) + )); +// clang-format on + +using GeneralInstructionFoldingTest = + ::testing::TestWithParam>; + +TEST_P(GeneralInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. 
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + std::unique_ptr original_inst(inst->Clone(context.get())); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Check the outcome; an expected_result of 0 means the fold must not happen. + EXPECT_EQ(inst->result_id(), original_inst->result_id()); + EXPECT_EQ(inst->type_id(), original_inst->type_id()); + EXPECT_TRUE((!succeeded) == (tc.expected_result == 0)); + if (succeeded) { + EXPECT_EQ(inst->opcode(), SpvOpCopyObject); + EXPECT_EQ(inst->GetSingleWordInOperand(0), tc.expected_result); + } else { + EXPECT_EQ(inst->NumInOperands(), original_inst->NumInOperands()); + for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { + EXPECT_EQ(inst->GetOperand(i), original_inst->GetOperand(i)); + } + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTest, + ::testing::Values( + // Test case 0: Don't fold n * m + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%m = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%load_m = OpLoad %int %m\n" + + "%2 = OpIMul %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't fold n / m (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpUDiv %uint %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Don't fold n / m (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%m = OpVariable %_ptr_int Function\n" + + 
"%load_n = OpLoad %int %n\n" + + "%load_m = OpLoad %int %m\n" + + "%2 = OpSDiv %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Don't fold n remainder m + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%m = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%load_m = OpLoad %int %m\n" + + "%2 = OpSRem %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 4: Don't fold n % m (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%m = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%load_m = OpLoad %int %m\n" + + "%2 = OpSMod %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: Don't fold n % m (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpUMod %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: Don't fold n >> m + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpShiftRightLogical %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 7: Don't fold n << m + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = 
OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpShiftLeftLogical %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 8: Don't fold n | m + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpBitwiseOr %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 9: Don't fold n & m + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpBitwiseAnd %int %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 10: Don't fold n < m (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpULessThan %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 11: Don't fold n > m (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpUGreaterThan %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 12: Don't fold n <= m (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable 
%_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpULessThanEqual %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 13: Don't fold n >= m (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%m = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%load_m = OpLoad %uint %m\n" + + "%2 = OpUGreaterThanEqual %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 14: Don't fold n < m (signed operands, unsigned compare opcode) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%m = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%load_m = OpLoad %int %m\n" + + "%2 = OpULessThan %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 15: Don't fold n > m (signed operands, unsigned compare opcode) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%m = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%load_m = OpLoad %int %m\n" + + "%2 = OpUGreaterThan %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 16: Don't fold n <= m (signed operands, unsigned compare opcode) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%m = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%load_m = OpLoad %int %m\n" + + "%2 = OpULessThanEqual %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 17: Don't fold n >= m (signed operands, unsigned compare opcode) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = 
OpVariable %_ptr_int Function\n" + + "%m = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%load_m = OpLoad %int %m\n" + + "%2 = OpUGreaterThanEqual %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 18: Don't fold n || m + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + "%m = OpVariable %_ptr_bool Function\n" + + "%load_n = OpLoad %bool %n\n" + + "%load_m = OpLoad %bool %m\n" + + "%2 = OpLogicalOr %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 19: Don't fold n && m + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + "%m = OpVariable %_ptr_bool Function\n" + + "%load_n = OpLoad %bool %n\n" + + "%load_m = OpLoad %bool %m\n" + + "%2 = OpLogicalAnd %bool %load_n %load_m\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 20: Don't fold n * 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%2 = OpIMul %int %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 21: Don't fold n / 3 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpUDiv %uint %load_n %uint_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 22: Don't fold n / 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%2 = OpSDiv %int %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), 
+ // Test case 23: Don't fold n remainder 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%2 = OpSRem %int %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 24: Don't fold n % 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%2 = OpSMod %int %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 25: Don't fold n % 3 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpUMod %int %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 26: Don't fold n >> 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpShiftRightLogical %int %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 27: Don't fold n << 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpShiftLeftLogical %int %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 28: Don't fold n | 3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpBitwiseOr %int %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 29: Don't fold n & 3 + InstructionFoldingCase( + 
Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpBitwiseAnd %uint %load_n %uint_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 30: Don't fold n < 3 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpULessThan %bool %load_n %uint_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 31: Don't fold n > 3 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpUGreaterThan %bool %load_n %uint_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 32: Don't fold n <= 3 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpULessThanEqual %bool %load_n %uint_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 33: Don't fold n >= 3 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%load_n = OpLoad %uint %n\n" + + "%2 = OpUGreaterThanEqual %bool %load_n %uint_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 34: Don't fold n < 3 (signed operands, unsigned compare opcode) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%2 = OpULessThan %bool %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 35: Don't fold n > 3 (signed operands, unsigned compare opcode) + InstructionFoldingCase( + 
Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%2 = OpUGreaterThan %bool %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 36: Don't fold n <= 3 (signed operands, unsigned compare opcode) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%2 = OpULessThanEqual %bool %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 37: Don't fold n >= 3 (signed operands, unsigned compare opcode) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load_n = OpLoad %int %n\n" + + "%2 = OpUGreaterThanEqual %bool %load_n %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 38: Don't fold 2 + 3 (long), bad length + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpIAdd %long %long_2 %long_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 39: Don't fold 2 + 3 (short), bad length + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpIAdd %short %short_2 %short_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 40: fold 1*n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%3 = OpLoad %int %n\n" + + "%2 = OpIMul %int %int_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 41: fold n*1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%3 = OpLoad %int %n\n" + + "%2 = OpIMul %int %3 %int_1\n" + + "OpReturn\n" + + 
"OpFunctionEnd", + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(CompositeExtractFoldingTest, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: fold Insert feeding extract + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeInsert %v4int %2 %v4int_0_0_0_0 0\n" + + "%4 = OpCompositeInsert %v4int %int_1 %3 1\n" + + "%5 = OpCompositeInsert %v4int %int_1 %4 2\n" + + "%6 = OpCompositeInsert %v4int %int_1 %5 3\n" + + "%7 = OpCompositeExtract %int %6 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 2), + // Test case 1: fold Composite construct feeding extract (position 0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v4int %2 %int_0 %int_0 %int_0\n" + + "%4 = OpCompositeExtract %int %3 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, 2), + // Test case 2: fold Composite construct feeding extract (position 3) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v4int %2 %int_0 %int_0 %100\n" + + "%4 = OpCompositeExtract %int %3 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, INT_0_ID), + // Test case 3: fold Composite construct with vectors feeding extract (scalar element) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v2int %2 %int_0\n" + + "%4 = OpCompositeConstruct %v4int %3 %int_0 %100\n" + + "%5 = OpCompositeExtract %int %4 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, INT_0_ID), + // Test case 4: fold Composite 
construct with vectors feeding extract (start of vector element) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v2int %2 %int_0\n" + + "%4 = OpCompositeConstruct %v4int %3 %int_0 %100\n" + + "%5 = OpCompositeExtract %int %4 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, 2), + // Test case 5: fold Composite construct with vectors feeding extract (middle of vector element) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v2int %int_0 %2\n" + + "%4 = OpCompositeConstruct %v4int %3 %int_0 %100\n" + + "%5 = OpCompositeExtract %int %4 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, 2), + // Test case 6: fold Composite construct with multiple indices. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %n\n" + + "%3 = OpCompositeConstruct %v2int %int_0 %2\n" + + "%4 = OpCompositeConstruct %struct_v2int_int_int %3 %int_0 %100\n" + + "%5 = OpCompositeExtract %int %4 0 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, 2), + // Test case 7: fold constant extract. 
+ InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCompositeExtract %int %102 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, INT_7_ID), + // Test case 8: constant struct has OpUndef. (Don't fold) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCompositeExtract %int %struct_undef_0_0 0 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 9: Extracting a member of an element inserted via Insert + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_struct_v2int_int_int Function\n" + + "%2 = OpLoad %struct_v2int_int_int %n\n" + + "%3 = OpCompositeInsert %struct_v2int_int_int %102 %2 0\n" + + "%4 = OpCompositeExtract %int %3 0 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, 103), + // Test case 10: Extracting an element that is partially changed by Insert. (Don't fold) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_struct_v2int_int_int Function\n" + + "%2 = OpLoad %struct_v2int_int_int %n\n" + + "%3 = OpCompositeInsert %struct_v2int_int_int %int_0 %2 0 1\n" + + "%4 = OpCompositeExtract %v2int %3 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, 0), + // Test case 11: Extracting from result of vector shuffle (first input) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %n\n" + + "%3 = OpVectorShuffle %v2int %102 %2 3 0\n" + + "%4 = OpCompositeExtract %int %3 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, INT_7_ID), + // Test case 12: Extracting from result of vector shuffle (second input) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable 
%_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %n\n" + + "%3 = OpVectorShuffle %v2int %2 %102 2 0\n" + + "%4 = OpCompositeExtract %int %3 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, INT_7_ID) +)); + +INSTANTIATE_TEST_CASE_P(CompositeConstructFoldingTest, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: fold Extracts feeding construct + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCopyObject %v4int %v4int_0_0_0_0\n" + + "%3 = OpCompositeExtract %int %2 0\n" + + "%4 = OpCompositeExtract %int %2 1\n" + + "%5 = OpCompositeExtract %int %2 2\n" + + "%6 = OpCompositeExtract %int %2 3\n" + + "%7 = OpCompositeConstruct %v4int %3 %4 %5 %6\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 2), + // Test case 1: Don't fold Extracts feeding construct (Different source) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCopyObject %v4int %v4int_0_0_0_0\n" + + "%3 = OpCompositeExtract %int %2 0\n" + + "%4 = OpCompositeExtract %int %2 1\n" + + "%5 = OpCompositeExtract %int %2 2\n" + + "%6 = OpCompositeExtract %int %v4int_0_0_0_0 3\n" + + "%7 = OpCompositeConstruct %v4int %3 %4 %5 %6\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 0), + // Test case 2: Don't fold Extracts feeding construct (bad indices) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCopyObject %v4int %v4int_0_0_0_0\n" + + "%3 = OpCompositeExtract %int %2 0\n" + + "%4 = OpCompositeExtract %int %2 0\n" + + "%5 = OpCompositeExtract %int %2 2\n" + + "%6 = OpCompositeExtract %int %2 3\n" + + "%7 = OpCompositeConstruct %v4int %3 %4 %5 %6\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 0), + // Test case 3: Don't fold Extracts feeding construct (different type) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = 
OpCopyObject %struct_v2int_int_int %struct_v2int_int_int_null\n" + + "%3 = OpCompositeExtract %v2int %2 0\n" + + "%4 = OpCompositeExtract %int %2 1\n" + + "%5 = OpCompositeExtract %int %2 2\n" + + "%7 = OpCompositeConstruct %v4int %3 %4 %5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 7, 0), + // Test case 4: Fold construct with constants to constant. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpCompositeConstruct %v2int %103 %103\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, VEC2_0_ID) +)); + +INSTANTIATE_TEST_CASE_P(PhiFoldingTest, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold phi with the same values for all edges. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + " OpBranchConditional %true %l1 %l2\n" + + "%l1 = OpLabel\n" + + " OpBranch %merge_lab\n" + + "%l2 = OpLabel\n" + + " OpBranch %merge_lab\n" + + "%merge_lab = OpLabel\n" + + "%2 = OpPhi %int %100 %l1 %100 %l2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, INT_0_ID), + // Test case 1: Fold phi in pass through loop. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + " OpBranch %l1\n" + + "%l1 = OpLabel\n" + + "%2 = OpPhi %int %100 %main_lab %2 %l1\n" + + " OpBranchConditional %true %l1 %merge_lab\n" + + "%merge_lab = OpLabel\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, INT_0_ID), + // Test case 2: Don't Fold phi because of different values. 
+ InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + " OpBranch %l1\n" + + "%l1 = OpLabel\n" + + "%2 = OpPhi %int %int_0 %main_lab %int_3 %l1\n" + + " OpBranchConditional %true %l1 %merge_lab\n" + + "%merge_lab = OpLabel\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0) +)); + +INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest, + ::testing::Values( + // Test case 0: Don't fold n + 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFAdd %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't fold n - 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFSub %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Don't fold n * 2.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFMul %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Fold n + 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFAdd %float %3 %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 4: Fold 0.0 + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFAdd %float %float_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 5: 
Fold n - 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFSub %float %3 %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 6: Fold n * 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFMul %float %3 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 7: Fold 1.0 * n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFMul %float %float_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 8: Fold n / 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFDiv %float %3 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 9: Fold n * 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFMul %float %3 %104\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, FLOAT_0_ID), + // Test case 10: Fold 0.0 * n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFMul %float %104 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, FLOAT_0_ID), + // Test case 11: Fold 0.0 / n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float 
Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFDiv %float %104 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, FLOAT_0_ID), + // Test case 12: Don't fold mix(a, b, 2.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_float Function\n" + + "%b = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %a\n" + + "%4 = OpLoad %float %b\n" + + "%2 = OpExtInst %float %1 FMix %3 %4 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 13: Fold mix(a, b, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_float Function\n" + + "%b = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %a\n" + + "%4 = OpLoad %float %b\n" + + "%2 = OpExtInst %float %1 FMix %3 %4 %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 14: Fold mix(a, b, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_float Function\n" + + "%b = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %a\n" + + "%4 = OpLoad %float %b\n" + + "%2 = OpExtInst %float %1 FMix %3 %4 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 4), + // Test case 15: Fold vector fadd with null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %a\n" + + "%3 = OpFAdd %v2float %2 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 2), + // Test case 16: Fold vector fadd with null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %a\n" + + "%3 = OpFAdd %v2float %v2float_null %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 
2), + // Test case 17: Fold vector fsub with null + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %a\n" + + "%3 = OpFSub %v2float %2 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, 2), + // Test case 18: Fold 0.0(half) * n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_half Function\n" + + "%3 = OpLoad %half %n\n" + + "%2 = OpFMul %half %108 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, HALF_0_ID), + // Test case 19: Don't fold 1.0(half) * n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_half Function\n" + + "%3 = OpLoad %half %n\n" + + "%2 = OpFMul %half %half_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 20: Don't fold 1.0 * 1.0 (half) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFMul %half %half_1 %half_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0) +)); + +INSTANTIATE_TEST_CASE_P(DoubleRedundantFoldingTest, GeneralInstructionFoldingTest, + ::testing::Values( + // Test case 0: Don't fold n + 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFAdd %double %3 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't fold n - 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFSub %double %3 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Don't fold n * 2.0 +
InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFMul %double %3 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Fold n + 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFAdd %double %3 %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 4: Fold 0.0 + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFAdd %double %double_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 5: Fold n - 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFSub %double %3 %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 6: Fold n * 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFMul %double %3 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 7: Fold 1.0 * n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFMul %double %double_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 8: Fold n / 1.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double 
Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFDiv %double %3 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 9: Fold n * 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFMul %double %3 %105\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, DOUBLE_0_ID), + // Test case 10: Fold 0.0 * n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFMul %double %105 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, DOUBLE_0_ID), + // Test case 11: Fold 0.0 / n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFDiv %double %105 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, DOUBLE_0_ID), + // Test case 12: Don't fold mix(a, b, 2.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_double Function\n" + + "%b = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %a\n" + + "%4 = OpLoad %double %b\n" + + "%2 = OpExtInst %double %1 FMix %3 %4 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 13: Fold mix(a, b, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_double Function\n" + + "%b = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %a\n" + + "%4 = OpLoad %double %b\n" + + "%2 = OpExtInst %double %1 FMix %3 %4 %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 14: Fold mix(a, b, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void 
None %void_func\n" + + "%main_lab = OpLabel\n" + + "%a = OpVariable %_ptr_double Function\n" + + "%b = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %a\n" + + "%4 = OpLoad %double %b\n" + + "%2 = OpExtInst %double %1 FMix %3 %4 %double_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 4) +)); + +INSTANTIATE_TEST_CASE_P(FloatVectorRedundantFoldingTest, GeneralInstructionFoldingTest, + ::testing::Values( + // Test case 0: Don't fold a * vec4(0.0, 0.0, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%3 = OpLoad %v4float %n\n" + + "%2 = OpFMul %v4float %3 %v4float_0_0_0_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Fold a * vec4(0.0, 0.0, 0.0, 0.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%3 = OpLoad %v4float %n\n" + + "%2 = OpFMul %v4float %3 %106\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, VEC4_0_ID), + // Test case 2: Fold a * vec4(1.0, 1.0, 1.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%3 = OpLoad %v4float %n\n" + + "%2 = OpFMul %v4float %3 %v4float_1_1_1_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(DoubleVectorRedundantFoldingTest, GeneralInstructionFoldingTest, + ::testing::Values( + // Test case 0: Don't fold a * vec4(0.0, 0.0, 0.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%3 = OpLoad %v4double %n\n" + + "%2 = OpFMul %v4double %3 %v4double_0_0_0_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Fold a * vec4(0.0, 0.0, 0.0, 0.0) + InstructionFoldingCase( + Header() + 
"%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%3 = OpLoad %v4double %n\n" + + "%2 = OpFMul %v4double %3 %106\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, DVEC4_0_ID), + // Test case 2: Fold a * vec4(1.0, 1.0, 1.0, 1.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%3 = OpLoad %v4double %n\n" + + "%2 = OpFMul %v4double %3 %v4double_1_1_1_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(IntegerRedundantFoldingTest, GeneralInstructionFoldingTest, + ::testing::Values( + // Test case 0: Don't fold n + 1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %3 %uint_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't fold 1 + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %uint_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Fold n + 0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %3 %uint_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 3: Fold 0 + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_uint Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %uint_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 4: Don't fold n + (1,0) + InstructionFoldingCase( + Header() + "%main = 
OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %3 %v2int_1_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: Don't fold (1,0) + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %v2int_1_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: Fold n + (0,0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %3 %v2int_0_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 7: Fold (0,0) + n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%3 = OpLoad %v2int %n\n" + + "%2 = OpIAdd %v2int %v2int_0_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(ClampAndCmpLHS, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: Don't Fold 0.0 < clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't Fold 0.0 < clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdLessThan %bool %float_0 
%clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Don't Fold 0.0 <= clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Don't Fold 0.0 <= clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdLessThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 4: Don't Fold 0.0 > clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: Don't Fold 0.0 > clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: Don't Fold 0.0 >= clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = 
OpFUnordGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 7: Don't Fold 0.0 >= clamp(-1, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdGreaterThanEqual %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 8: Don't Fold 0.0 < clamp(0, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFUnordLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 9: Don't Fold 0.0 < clamp(0, 1) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdLessThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 10: Don't Fold 0.0 > clamp(-1, 0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 11: Don't Fold 0.0 > clamp(-1, 0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld 
%float_n1 %float_0\n" + + "%2 = OpFOrdGreaterThan %bool %float_0 %clamp\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0) +)); + +INSTANTIATE_TEST_CASE_P(ClampAndCmpRHS, GeneralInstructionFoldingTest, +::testing::Values( + // Test case 0: Don't Fold clamp(-1, 1) < 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordLessThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Don't Fold clamp(-1, 1) < 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdLessThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 2: Don't Fold clamp(-1, 1) <= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordLessThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Don't Fold clamp(-1, 1) <= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdLessThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 4: Don't Fold clamp(-1, 1) > 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + 
"%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordGreaterThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 5: Don't Fold clamp(-1, 1) > 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdGreaterThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 6: Don't Fold clamp(-1, 1) >= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFUnordGreaterThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 7: Don't Fold clamp(-1, 1) >= 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_1\n" + + "%2 = OpFOrdGreaterThanEqual %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 8: Don't Fold clamp(-1, 0) < 0.0 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordLessThan %bool %clamp %float_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 9: Don't Fold clamp(0, 1) < 1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + 
"%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_0 %float_1\n" + + "%2 = OpFOrdLessThan %bool %clamp %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 10: Don't Fold clamp(-1, 0) > -1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFUnordGreaterThan %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 11: Don't Fold clamp(-1, 0) > -1 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%ld = OpLoad %float %n\n" + + "%clamp = OpExtInst %float %1 FClamp %ld %float_n1 %float_0\n" + + "%2 = OpFOrdGreaterThan %bool %clamp %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0) +)); + +INSTANTIATE_TEST_CASE_P(FToIConstantFoldingTest, IntegerInstructionFoldingTest, + ::testing::Values( + // Test case 0: Fold int(3.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpConvertFToS %int %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 1: Fold uint(3.0) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpConvertFToU %int %float_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(IToFConstantFoldingTest, FloatInstructionFoldingTest, + ::testing::Values( + // Test case 0: Fold float(3) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpConvertSToF %float %int_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3.0), + // Test 
case 1: Fold float(3u) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpConvertUToF %float %uint_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3.0) +)); +// clang-format on + +using ToNegateFoldingTest = + ::testing::TestWithParam<InstructionFoldingCase<uint32_t>>; + +TEST_P(ToNegateFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module, preserving numeric ids so tc.id_to_fold names the instruction under test. + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test, keeping a clone of the original for comparison. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + std::unique_ptr<Instruction> original_inst(inst->Clone(context.get())); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + + // Make sure the instruction folded as expected: expected_result == 0 means it must not fold; otherwise it is the id the resulting OpFNegate must take as its operand. + EXPECT_EQ(inst->result_id(), original_inst->result_id()); + EXPECT_EQ(inst->type_id(), original_inst->type_id()); + EXPECT_TRUE((!succeeded) == (tc.expected_result == 0)); + if (succeeded) { + EXPECT_EQ(inst->opcode(), SpvOpFNegate); + EXPECT_EQ(inst->GetSingleWordInOperand(0), tc.expected_result); + } else { // Not folded: every in-operand must be unchanged. + EXPECT_EQ(inst->NumInOperands(), original_inst->NumInOperands()); + for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { + EXPECT_EQ(inst->GetInOperand(i), original_inst->GetInOperand(i)); + } + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(FloatRedundantSubFoldingTest, ToNegateFoldingTest, + ::testing::Values( + // Test case 0: Don't fold 1.0 - n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFSub %float %float_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Fold 0.0 - n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab =
OpLabel\n" + + "%n = OpVariable %_ptr_float Function\n" + + "%3 = OpLoad %float %n\n" + + "%2 = OpFSub %float %float_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 2: Don't fold (0,0,0,1) - n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%3 = OpLoad %v4float %n\n" + + "%2 = OpFSub %v4float %v4float_0_0_0_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Fold (0,0,0,0) - n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%3 = OpLoad %v4float %n\n" + + "%2 = OpFSub %v4float %v4float_0_0_0_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + +INSTANTIATE_TEST_CASE_P(DoubleRedundantSubFoldingTest, ToNegateFoldingTest, + ::testing::Values( + // Test case 0: Don't fold 1.0 - n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFSub %double %double_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 1: Fold 0.0 - n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%3 = OpLoad %double %n\n" + + "%2 = OpFSub %double %double_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 2: Don't fold (0,0,0,1) - n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%3 = OpLoad %v4double %n\n" + + "%2 = OpFSub %v4double %v4double_0_0_0_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0), + // Test case 3: Fold (0,0,0,0) - n + InstructionFoldingCase( + Header() + "%main = OpFunction %void None 
%void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%3 = OpLoad %v4double %n\n" + + "%2 = OpFSub %v4double %v4double_0_0_0_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3) +)); + +using MatchingInstructionFoldingTest = + ::testing::TestWithParam<InstructionFoldingCase<bool>>; + +TEST_P(MatchingInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module, preserving numeric ids so tc.id_to_fold names the instruction under test. + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test; expected_result says whether folding should succeed, and on success the module is matched against the CHECK lines in tc.test_body. + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* inst = def_use_mgr->GetDef(tc.id_to_fold); + std::unique_ptr<Instruction> original_inst(inst->Clone(context.get())); // NOTE(review): this clone is never read below. + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + EXPECT_EQ(succeeded, tc.expected_result); + if (succeeded) { + Match(tc.test_body, context.get()); + } +} + +INSTANTIATE_TEST_CASE_P(RedundantIntegerMatching, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold 0 + n (change sign; %int_0 + uint n) + InstructionFoldingCase( + Header() + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: %2 = OpBitcast [[uint]] %3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%3 = OpLoad %uint %n\n" + + "%2 = OpIAdd %uint %int_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 2, true), + // Test case 1: Fold 0 + n (change sign; %uint_0 + int n) + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: %2 = OpBitcast [[int]] %3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%3 = OpLoad %int %n\n" + + "%2 = OpIAdd %int %uint_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 2, true) +)); + +INSTANTIATE_TEST_CASE_P(MergeNegateTest, MatchingInstructionFoldingTest, +::testing::Values(
+ // Test case 0: fold consecutive fnegate + // -(-x) = x + InstructionFoldingCase( + Header() + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float:%\\w+]]\n" + + "; CHECK: %4 = OpCopyObject [[float]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFNegate %float %2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 1: fold fnegate(fmul with const). + // -(x * 2.0) = x * -2.0 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFMul %float %2 %float_2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 2: fold fnegate(fmul with const). + // -(2.0 * x) = x * -2.0 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFMul %float %float_2 %2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 3: fold fnegate(fdiv with const). 
+ // -(x / 2.0) = x * -0.5 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_n0p5:%\\w+]] = OpConstant [[float]] -0.5\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_n0p5]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %2 %float_2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 4: fold fnegate(fdiv with const). + // -(2.0 / x) = -2.0 / x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFDiv [[float]] [[float_n2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %float_2 %2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 5: fold fnegate(fadd with const). + // -(2.0 + x) = -2.0 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %float_2 %2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 6: fold fnegate(fadd with const). 
+ // -(x + 2.0) = -2.0 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %2 %float_2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 7: fold fnegate(fsub with const). + // -(2.0 - x) = x - 2.0 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[ld]] [[float_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %float_2 %2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 8: fold fnegate(fsub with const). 
+ // -(x - 2.0) = 2.0 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %2 %float_2\n" + + "%4 = OpFNegate %float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 9: fold consecutive snegate + // -(-x) = x + InstructionFoldingCase( + Header() + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int:%\\w+]]\n" + + "; CHECK: %4 = OpCopyObject [[int]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSNegate %int %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 10: fold consecutive vector negate + // -(-x) = x + InstructionFoldingCase( + Header() + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float:%\\w+]]\n" + + "; CHECK: %4 = OpCopyObject [[v2float]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %var\n" + + "%3 = OpFNegate %v2float %2\n" + + "%4 = OpFNegate %v2float %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 11: fold snegate(iadd with const). 
+ // -(2 + x) = -2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: OpConstant [[int]] -2147483648\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIAdd %int %int_2 %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 12: fold snegate(iadd with const). + // -(x + 2) = -2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: OpConstant [[int]] -2147483648\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpISub [[int]] [[int_n2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIAdd %int %2 %int_2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 13: fold snegate(isub with const). + // -(2 - x) = x - 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpISub [[int]] [[ld]] [[int_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpISub %int %int_2 %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 14: fold snegate(isub with const). 
+ // -(x - 2) = 2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpISub [[int]] [[int_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpISub %int %2 %int_2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 15: fold snegate(iadd with const). + // -(x + 2) = -2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpIAdd %long %2 %long_2\n" + + "%4 = OpSNegate %long %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 16: fold snegate(isub with const). + // -(2 - x) = x - 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpISub [[long]] [[ld]] [[long_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpISub %long %long_2 %2\n" + + "%4 = OpSNegate %long %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 17: fold snegate(isub with const). 
+ // -(x - 2) = 2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpISub %long %2 %long_2\n" + + "%4 = OpSNegate %long %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 18: fold -vec4(-1.0, 2.0, 1.0, 3.0) + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v4float:%\\w+]] = OpTypeVector [[float]] 4\n" + + "; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[float_n3:%\\w+]] = OpConstant [[float]] -3\n" + + "; CHECK: [[v4float_1_n2_n1_n3:%\\w+]] = OpConstantComposite [[v4float]] [[float_1]] [[float_n2]] [[float_n1]] [[float_n3]]\n" + + "; CHECK: %2 = OpCopyObject [[v4float]] [[v4float_1_n2_n1_n3]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFNegate %v4float %v4float_n1_2_1_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 19: fold vector fnegate with null + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[double_n0:%\\w+]] = OpConstant [[double]] -0\n" + + "; CHECK: [[v2double_0_0:%\\w+]] = OpConstantComposite [[v2double]] [[double_n0]] [[double_n0]]\n" + + "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_0_0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpFNegate %v2double %v2double_null\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true) +)); + 
+INSTANTIATE_TEST_CASE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: scalar reciprocal + // x / 0.5 = x * 2.0 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %3 = OpFMul [[float]] [[ld]] [[float_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %2 %float_0p5\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 3, true), + // Test case 1: Unfoldable + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_0:%\\w+]] = OpConstant [[float]] 0\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %3 = OpFDiv [[float]] [[ld]] [[float_0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %2 %104\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 3, false), + // Test case 2: Vector reciprocal + // x / {2.0, 0.5} = x * {0.5, 2.0} + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[float_0p5:%\\w+]] = OpConstant [[float]] 0.5\n" + + "; CHECK: [[v2float_0p5_2:%\\w+]] = OpConstantComposite [[v2float]] [[float_0p5]] [[float_2]]\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float]]\n" + + "; CHECK: %3 = OpFMul [[v2float]] [[ld]] [[v2float_0p5_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %var\n" + + "%3 = OpFDiv %v2float %2 %v2float_2_0p5\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 
3, true), + // Test case 3: double reciprocal + // x / 2.0 = x * 0.5 + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[double_0p5:%\\w+]] = OpConstant [[double]] 0.5\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[double]]\n" + + "; CHECK: %3 = OpFMul [[double]] [[ld]] [[double_0p5]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_double Function\n" + + "%2 = OpLoad %double %var\n" + + "%3 = OpFDiv %double %2 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 3, true), + // Test case 4: don't fold x / 0. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %var\n" + + "%3 = OpFDiv %v2float %2 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 3, false) +)); + +INSTANTIATE_TEST_CASE_P(MergeMulTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: fold consecutive fmuls + // (x * 3.0) * 2.0 = x * 6.0 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFMul %float %2 %float_3\n" + + "%4 = OpFMul %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 1: fold consecutive fmuls + // 2.0 * (x * 3.0) = x * 6.0 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" + + "%main = OpFunction %void None 
%void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFMul %float %2 %float_3\n" + + "%4 = OpFMul %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 2: fold consecutive fmuls + // 2.0 * (3.0 * x) = x * 6.0 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFMul [[float]] [[ld]] [[float_6]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFMul %float %float_3 %2\n" + + "%4 = OpFMul %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 3: fold vector fmul + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" + + "; CHECK: [[float_6:%\\w+]] = OpConstant [[float]] 6\n" + + "; CHECK: [[v2float_6_6:%\\w+]] = OpConstantComposite [[v2float]] [[float_6]] [[float_6]]\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float]]\n" + + "; CHECK: %4 = OpFMul [[v2float]] [[ld]] [[v2float_6_6]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %var\n" + + "%3 = OpFMul %v2float %2 %v2float_2_3\n" + + "%4 = OpFMul %v2float %3 %v2float_3_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 4: fold double fmuls + // (x * 3.0) * 2.0 = x * 6.0 + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[double_6:%\\w+]] = OpConstant [[double]] 6\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[double]]\n" + + "; CHECK: %4 = OpFMul [[double]] [[ld]] [[double_6]]\n" + + "%main = OpFunction %void None %void_func\n" + 
+ "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_double Function\n" + + "%2 = OpLoad %double %var\n" + + "%3 = OpFMul %double %2 %double_3\n" + + "%4 = OpFMul %double %3 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 5: fold 32 bit imuls + // (x * 3) * 2 = x * 6 + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int_6:%\\w+]] = OpConstant [[int]] 6\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_6]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %int_3\n" + + "%4 = OpIMul %int %3 %int_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 6: fold 64 bit imuls + // (x * 3) * 2 = x * 6 + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" + + "; CHECK: [[long_6:%\\w+]] = OpConstant [[long]] 6\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_6]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpIMul %long %2 %long_3\n" + + "%4 = OpIMul %long %3 %long_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 7: merge vector integer mults + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: [[int_6:%\\w+]] = OpConstant [[int]] 6\n" + + "; CHECK: [[v2int_6_6:%\\w+]] = OpConstantComposite [[v2int]] [[int_6]] [[int_6]]\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" + + "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_6_6]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int 
%var\n" + + "%3 = OpIMul %v2int %2 %v2int_2_3\n" + + "%4 = OpIMul %v2int %3 %v2int_3_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 8: merge fmul of fdiv + // 2.0 * (2.0 / x) = 4.0 / x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFDiv [[float]] [[float_4]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %float_2 %2\n" + + "%4 = OpFMul %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 9: merge fmul of fdiv + // (2.0 / x) * 2.0 = 4.0 / x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFDiv [[float]] [[float_4]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %float_2 %2\n" + + "%4 = OpFMul %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 10: Do not merge imul of sdiv + // 4 * (x / 2) + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %2 %int_2\n" + + "%4 = OpIMul %int %int_4 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 11: Do not merge imul of sdiv + // (x / 2) * 4 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv 
%int %2 %int_2\n" + + "%4 = OpIMul %int %3 %int_4\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 12: Do not merge imul of udiv + // 4 * (x / 2) + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_uint Function\n" + + "%2 = OpLoad %uint %var\n" + + "%3 = OpUDiv %uint %2 %uint_2\n" + + "%4 = OpIMul %uint %uint_4 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 13: Do not merge imul of udiv + // (x / 2) * 4 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_uint Function\n" + + "%2 = OpLoad %uint %var\n" + + "%3 = OpUDiv %uint %2 %uint_2\n" + + "%4 = OpIMul %uint %3 %uint_4\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 14: Don't fold + // (x / 3) * 4 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_uint Function\n" + + "%2 = OpLoad %uint %var\n" + + "%3 = OpUDiv %uint %2 %uint_3\n" + + "%4 = OpIMul %uint %3 %uint_4\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 15: merge vector fmul of fdiv + // (x / {2,2}) * {4,4} = x * {2,2} + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[v2float_2_2:%\\w+]] = OpConstantComposite [[v2float]] [[float_2]] [[float_2]]\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2float]]\n" + + "; CHECK: %4 = OpFMul [[v2float]] [[ld]] [[v2float_2_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %v2float %var\n" + + "%3 = OpFDiv %v2float %2 %v2float_2_2\n" + + "%4 = OpFMul %v2float %3 %v2float_4_4\n" + + 
"OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 16: merge vector imul of snegate + // (-x) * {2,2} = x * {-2,-2} + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: OpConstant [[int]] -2147483648\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" + + "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %var\n" + + "%3 = OpSNegate %v2int %2\n" + + "%4 = OpIMul %v2int %3 %v2int_2_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 17: merge vector imul of snegate + // {2,2} * (-x) = x * {-2,-2} + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: OpConstant [[int]] -2147483648\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[v2int_n2_n2:%\\w+]] = OpConstantComposite [[v2int]] [[int_n2]] [[int_n2]]\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v2int]]\n" + + "; CHECK: %4 = OpIMul [[v2int]] [[ld]] [[v2int_n2_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %var\n" + + "%3 = OpSNegate %v2int %2\n" + + "%4 = OpIMul %v2int %v2int_2_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 18: Fold OpVectorTimesScalar + // {4,4} = OpVectorTimesScalar v2float {2,2} 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" + + "; CHECK: [[float_4:%\\w+]] = OpConstant [[float]] 4\n" + + "; CHECK: 
[[v2float_4_4:%\\w+]] = OpConstantComposite [[v2float]] [[float_4]] [[float_4]]\n" + + "; CHECK: %2 = OpCopyObject [[v2float]] [[v2float_4_4]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVectorTimesScalar %v2float %v2float_2_2 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 19: Fold OpVectorTimesScalar + // {0,0} = OpVectorTimesScalar v2float v2float_null -1 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[v2float:%\\w+]] = OpTypeVector [[float]] 2\n" + + "; CHECK: [[v2float_null:%\\w+]] = OpConstantNull [[v2float]]\n" + + "; CHECK: %2 = OpCopyObject [[v2float]] [[v2float_null]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVectorTimesScalar %v2float %v2float_null %float_n1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 20: Fold OpVectorTimesScalar + // {4,4} = OpVectorTimesScalar v2double {2,2} 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[double_4:%\\w+]] = OpConstant [[double]] 4\n" + + "; CHECK: [[v2double_4_4:%\\w+]] = OpConstantComposite [[v2double]] [[double_4]] [[double_4]]\n" + + "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_4_4]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVectorTimesScalar %v2double %v2double_2_2 %double_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 21: Fold OpVectorTimesScalar + // {0,0} = OpVectorTimesScalar v2double {0,0} n + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: {{%\\w+}} = OpConstant [[double]] 0\n" + + "; CHECK: [[double_0:%\\w+]] = OpConstant [[double]] 0\n" + + "; CHECK: [[v2double_0_0:%\\w+]] = 
OpConstantComposite [[v2double]] [[double_0]] [[double_0]]\n" + + "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_0_0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_double Function\n" + + "%load = OpLoad %double %n\n" + + "%2 = OpVectorTimesScalar %v2double %v2double_0_0 %load\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 22: Fold OpVectorTimesScalar + // {0,0} = OpVectorTimesScalar v2double n 0 + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v2double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[v2double_null:%\\w+]] = OpConstantNull [[v2double]]\n" + + "; CHECK: %2 = OpCopyObject [[v2double]] [[v2double_null]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2double Function\n" + + "%load = OpLoad %v2double %n\n" + + "%2 = OpVectorTimesScalar %v2double %load %double_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 23: merge fmul of fdiv + // x * (y / x) = y + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[ldx:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: [[ldy:%\\w+]] = OpLoad [[float]] [[y:%\\w+]]\n" + + "; CHECK: %5 = OpCopyObject [[float]] [[ldy]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%x = OpVariable %_ptr_float Function\n" + + "%y = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %x\n" + + "%3 = OpLoad %float %y\n" + + "%4 = OpFDiv %float %3 %2\n" + + "%5 = OpFMul %float %2 %4\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 5, true), + // Test case 24: merge fmul of fdiv + // (y / x) * x = y + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[ldx:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: [[ldy:%\\w+]] = OpLoad [[float]] [[y:%\\w+]]\n" + + "; CHECK: %5 = OpCopyObject 
[[float]] [[ldy]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%x = OpVariable %_ptr_float Function\n" + + "%y = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %x\n" + + "%3 = OpLoad %float %y\n" + + "%4 = OpFDiv %float %3 %2\n" + + "%5 = OpFMul %float %4 %2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 5, true) +)); + +INSTANTIATE_TEST_CASE_P(MergeDivTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: merge consecutive fdiv + // 4.0 / (2.0 / x) = 2.0 * x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFMul [[float]] [[float_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %float_2 %2\n" + + "%4 = OpFDiv %float %float_4 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 1: merge consecutive fdiv + // 4.0 / (x / 2.0) = 8.0 / x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_8:%\\w+]] = OpConstant [[float]] 8\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFDiv [[float]] [[float_8]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %2 %float_2\n" + + "%4 = OpFDiv %float %float_4 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 2: merge consecutive fdiv + // (4.0 / x) / 2.0 = 2.0 / x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFDiv [[float]] 
[[float_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %float_4 %2\n" + + "%4 = OpFDiv %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 3: Do not merge consecutive sdiv + // 4 / (2 / x) + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %int_2 %2\n" + + "%4 = OpSDiv %int %int_4 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 4: Do not merge consecutive sdiv + // 4 / (x / 2) + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %2 %int_2\n" + + "%4 = OpSDiv %int %int_4 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 5: Do not merge consecutive sdiv + // (4 / x) / 2 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %int_4 %2\n" + + "%4 = OpSDiv %int %3 %int_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 6: Do not merge consecutive sdiv + // (x / 4) / 2 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %2 %int_4\n" + + "%4 = OpSDiv %int %3 %int_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 7: Do not merge sdiv of imul + // 4 / (2 * x) + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + 
"%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %int_2 %2\n" + + "%4 = OpSDiv %int %int_4 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 8: Do not merge sdiv of imul + // 4 / (x * 2) + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %int_2\n" + + "%4 = OpSDiv %int %int_4 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 9: Do not merge sdiv of imul + // (4 * x) / 2 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %int_4 %2\n" + + "%4 = OpSDiv %int %3 %int_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 10: Do not merge sdiv of imul + // (x * 4) / 2 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %int_4\n" + + "%4 = OpSDiv %int %3 %int_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 11: merge sdiv of snegate + // (-x) / 2 = x / -2 + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: OpConstant [[int]] -2147483648\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[int_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSNegate %int %2\n" + + "%4 = OpSDiv %int %3 %int_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 12: merge sdiv of snegate + // 2 / (-x) = 
-2 / x + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: OpConstant [[int]] -2147483648\n" + + "; CHECK: [[int_n2:%\\w+]] = OpConstant [[int]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[int_n2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSNegate %int %2\n" + + "%4 = OpSDiv %int %int_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 13: Don't merge + // (x / {null}) / {null} + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_v2float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFDiv %float %2 %v2float_null\n" + + "%4 = OpFDiv %float %3 %v2float_null\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, false), + // Test case 14: merge fdiv of fmul + // (y * x) / x = y + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[ldx:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: [[ldy:%\\w+]] = OpLoad [[float]] [[y:%\\w+]]\n" + + "; CHECK: %5 = OpCopyObject [[float]] [[ldy]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%x = OpVariable %_ptr_float Function\n" + + "%y = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %x\n" + + "%3 = OpLoad %float %y\n" + + "%4 = OpFMul %float %3 %2\n" + + "%5 = OpFDiv %float %4 %2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 5, true), + // Test case 15: merge fdiv of fmul + // (x * y) / x = y + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[ldx:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: [[ldy:%\\w+]] = OpLoad [[float]] [[y:%\\w+]]\n" + + "; CHECK: %5 = OpCopyObject [[float]] [[ldy]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab =
OpLabel\n" + + "%x = OpVariable %_ptr_float Function\n" + + "%y = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %x\n" + + "%3 = OpLoad %float %y\n" + + "%4 = OpFMul %float %2 %3\n" + + "%5 = OpFDiv %float %4 %2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 5, true) +)); + +INSTANTIATE_TEST_CASE_P(MergeAddTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: merge add of negate + // (-x) + 2 = 2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFNegate %float %2\n" + + "%4 = OpFAdd %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 1: merge add of negate + // 2 + (-x) = 2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpSNegate %float %2\n" + + "%4 = OpIAdd %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 2: merge add of negate + // (-x) + 2 = 2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + 
"%2 = OpLoad %long %var\n" + + "%3 = OpSNegate %long %2\n" + + "%4 = OpIAdd %long %3 %long_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 3: merge add of negate + // 2 + (-x) = 2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpISub [[long]] [[long_2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpSNegate %long %2\n" + + "%4 = OpIAdd %long %long_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 4: merge add of subtract + // (x - 1) + 2 = x + 1 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %2 %float_1\n" + + "%4 = OpFAdd %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 5: merge add of subtract + // (1 - x) + 2 = 3 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %float_1 %2\n" + + "%4 = OpFAdd %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 6: merge add of subtract + // 2 + (x - 1) = 
x + 1 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %2 %float_1\n" + + "%4 = OpFAdd %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 7: merge add of subtract + // 2 + (1 - x) = 3 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %float_1 %2\n" + + "%4 = OpFAdd %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 8: merge add of add + // (x + 1) + 2 = x + 3 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %2 %float_1\n" + + "%4 = OpFAdd %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 9: merge add of add + // (1 + x) + 2 = 3 + x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad 
[[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %float_1 %2\n" + + "%4 = OpFAdd %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 10: merge add of add + // 2 + (x + 1) = x + 3 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %2 %float_1\n" + + "%4 = OpFAdd %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 11: merge add of add + // 2 + (1 + x) = x + 3 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_3]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %float_1 %2\n" + + "%4 = OpFAdd %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true) +)); + +INSTANTIATE_TEST_CASE_P(MergeSubTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: merge sub of negate + // (-x) - 2 = -2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_n2:%\\w+]] = OpConstant [[float]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_n2]] [[ld]]\n" + + "%main = OpFunction
%void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFNegate %float %2\n" + + "%4 = OpFSub %float %3 %float_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 1: merge sub of negate + // 2 - (-x) = x + 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_2:%\\w+]] = OpConstant [[float]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFNegate %float %2\n" + + "%4 = OpFSub %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 2: merge sub of negate + // (-x) - 2 = -2 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpISub [[long]] [[long_n2]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpSNegate %long %2\n" + + "%4 = OpISub %long %3 %long_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 3: merge sub of negate + // 2 - (-x) = x + 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpIAdd [[long]] [[ld]] [[long_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpSNegate %long %2\n" + + "%4 = OpISub %long %long_2 %3\n" + + "OpReturn\n" + + 
"OpFunctionEnd\n", + 4, true), + // Test case 4: merge add of subtract + // (x + 2) - 1 = x + 1 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %2 %float_2\n" + + "%4 = OpFSub %float %3 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 5: merge add of subtract + // (2 + x) - 1 = x + 1 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %float_2 %2\n" + + "%4 = OpFSub %float %3 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 6: merge add of subtract + // 2 - (x + 1) = 1 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %2 %float_1\n" + + "%4 = OpFSub %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 7: merge add of subtract + // 2 - (1 + x) = 1 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + 
+ "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFAdd %float %float_1 %2\n" + + "%4 = OpFSub %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 8: merge subtract of subtract + // (x - 2) - 1 = x - 3 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[ld]] [[float_3]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %2 %float_2\n" + + "%4 = OpFSub %float %3 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 9: merge subtract of subtract + // (2 - x) - 1 = 1 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_1]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %float_2 %2\n" + + "%4 = OpFSub %float %3 %float_1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 10: merge subtract of subtract + // 2 - (x - 1) = 3 - x + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_3:%\\w+]] = OpConstant [[float]] 3\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFSub [[float]] [[float_3]] [[ld]]\n" + + 
"%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %2 %float_1\n" + + "%4 = OpFSub %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 11: merge subtract of subtract + // 1 - (2 - x) = x + (-1) + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_n1:%\\w+]] = OpConstant [[float]] -1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_n1]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %float_2 %2\n" + + "%4 = OpFSub %float %float_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 12: merge subtract of subtract + // 2 - (1 - x) = x + 1 + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: [[float_1:%\\w+]] = OpConstant [[float]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[float]]\n" + + "; CHECK: %4 = OpFAdd [[float]] [[ld]] [[float_1]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_float Function\n" + + "%2 = OpLoad %float %var\n" + + "%3 = OpFSub %float %float_1 %2\n" + + "%4 = OpFSub %float %float_2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true) +)); + +INSTANTIATE_TEST_CASE_P(SelectFoldingTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: Fold select with the same values for both sides + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" + + "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_bool Function\n" + + 
"%load = OpLoad %bool %n\n" + + "%2 = OpSelect %int %load %100 %100\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 1: Fold select true to left side + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" + + "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpSelect %int %true %100 %n\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 2: Fold select false to right side + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" + + "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %bool %n\n" + + "%2 = OpSelect %int %false %n %100\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 3: Fold select null to right side + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" + + "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSelect %int %bool_null %load %100\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 4: vector null + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: [[int2:%\\w+]] = OpConstant [[int]] 2\n" + + "; CHECK: [[v2int2_2:%\\w+]] = OpConstantComposite [[v2int]] [[int2]] [[int2]]\n" + + "; CHECK: %2 = OpCopyObject [[v2int]] [[v2int2_2]]\n" + + "%main = OpFunction %void None 
%void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%load = OpLoad %v2int %n\n" + + "%2 = OpSelect %v2int %v2bool_null %load %v2int_2_2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, true), + // Test case 5: vector select + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 0 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v2int Function\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %n\n" + + "%3 = OpLoad %v2int %n\n" + + "%4 = OpSelect %v2int %v2bool_true_false %2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 6: vector select + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" + + "; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 2 1\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v2int Function\n" + + "%n = OpVariable %_ptr_v2int Function\n" + + "%2 = OpLoad %v2int %n\n" + + "%3 = OpLoad %v2int %n\n" + + "%4 = OpSelect %v2int %v2bool_false_true %2 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true) +)); + +INSTANTIATE_TEST_CASE_P(CompositeExtractMatchingTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: Extracting from result of consecutive shuffles of differing + // size. 
+ InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: %5 = OpCompositeExtract [[int]] %2 2\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4int Function\n" + + "%2 = OpLoad %v4int %n\n" + + "%3 = OpVectorShuffle %v2int %2 %2 2 3\n" + + "%4 = OpVectorShuffle %v4int %2 %3 0 4 2 5\n" + + "%5 = OpCompositeExtract %int %4 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, true), + // Test case 1: Extracting from result of vector shuffle of differing + // input and result sizes. + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: %4 = OpCompositeExtract [[int]] %2 2\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4int Function\n" + + "%2 = OpLoad %v4int %n\n" + + "%3 = OpVectorShuffle %v2int %2 %2 2 3\n" + + "%4 = OpCompositeExtract %int %3 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 2: Extracting from result of vector shuffle of differing + // input and result sizes. + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: %4 = OpCompositeExtract [[int]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4int Function\n" + + "%2 = OpLoad %v4int %n\n" + + "%3 = OpVectorShuffle %v2int %2 %2 2 3\n" + + "%4 = OpCompositeExtract %int %3 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 3: Using fmix feeding extract with a 1 in the a position. 
+ InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 4\n" + + "; CHECK: [[ptr_v4double:%\\w+]] = OpTypePointer Function [[v4double]]\n" + + "; CHECK: [[m:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[n:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v4double]] [[n]]\n" + + "; CHECK: %5 = OpCompositeExtract [[double]] [[ld]] 1\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v4double Function\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %m\n" + + "%3 = OpLoad %v4double %n\n" + + "%4 = OpExtInst %v4double %1 FMix %2 %3 %v4double_0_1_0_0\n" + + "%5 = OpCompositeExtract %double %4 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, true), + // Test case 4: Using fmix feeding extract with a 0 in the a position. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 4\n" + + "; CHECK: [[ptr_v4double:%\\w+]] = OpTypePointer Function [[v4double]]\n" + + "; CHECK: [[m:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[n:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v4double]] [[m]]\n" + + "; CHECK: %5 = OpCompositeExtract [[double]] [[ld]] 2\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v4double Function\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %m\n" + + "%3 = OpLoad %v4double %n\n" + + "%4 = OpExtInst %v4double %1 FMix %2 %3 %v4double_0_1_0_0\n" + + "%5 = OpCompositeExtract %double %4 2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, true), + // Test case 5: Using fmix feeding extract with a null for the alpha + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" 
+ + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 4\n" + + "; CHECK: [[ptr_v4double:%\\w+]] = OpTypePointer Function [[v4double]]\n" + + "; CHECK: [[m:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[n:%\\w+]] = OpVariable [[ptr_v4double]] Function\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[v4double]] [[m]]\n" + + "; CHECK: %5 = OpCompositeExtract [[double]] [[ld]] 0\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v4double Function\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %m\n" + + "%3 = OpLoad %v4double %n\n" + + "%4 = OpExtInst %v4double %1 FMix %2 %3 %v4double_null\n" + + "%5 = OpCompositeExtract %double %4 0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, true), + // Test case 6: Don't fold: Using fmix feeding extract with 0.5 in the a + // position. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%m = OpVariable %_ptr_v4double Function\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %m\n" + + "%3 = OpLoad %v4double %n\n" + + "%4 = OpExtInst %v4double %1 FMix %2 %3 %v4double_1_1_1_0p5\n" + + "%5 = OpCompositeExtract %double %4 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 5, false), + // Test case 7: Extracting the undefined literal value from a vector + // shuffle. + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: %4 = OpUndef [[int]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4int Function\n" + + "%2 = OpLoad %v4int %n\n" + + "%3 = OpVectorShuffle %v2int %2 %2 2 4294967295\n" + + "%4 = OpCompositeExtract %int %3 1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true) +)); + +INSTANTIATE_TEST_CASE_P(DotProductMatchingTest, MatchingInstructionFoldingTest, +::testing::Values( + // Test case 0: Using OpDot to extract last element. 
+ InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: %3 = OpCompositeExtract [[float]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%2 = OpLoad %v4float %n\n" + + "%3 = OpDot %float %2 %v4float_0_0_0_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 1: Using OpDot to extract last element. + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: %3 = OpCompositeExtract [[float]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%2 = OpLoad %v4float %n\n" + + "%3 = OpDot %float %v4float_0_0_0_1 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 2: Using OpDot to extract second element. + InstructionFoldingCase( + Header() + + "; CHECK: [[float:%\\w+]] = OpTypeFloat 32\n" + + "; CHECK: %3 = OpCompositeExtract [[float]] %2 1\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4float Function\n" + + "%2 = OpLoad %v4float %n\n" + + "%3 = OpDot %float %v4float_0_1_0_0 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 3: Using OpDot to extract last element. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: %3 = OpCompositeExtract [[double]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %n\n" + + "%3 = OpDot %double %2 %v4double_0_0_0_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 4: Using OpDot to extract last element. 
+ InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: %3 = OpCompositeExtract [[double]] %2 3\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %n\n" + + "%3 = OpDot %double %v4double_0_0_0_1 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true), + // Test case 5: Using OpDot to extract second element. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: %3 = OpCompositeExtract [[double]] %2 1\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%2 = OpLoad %v4double %n\n" + + "%3 = OpDot %double %v4double_0_1_0_0 %2\n" + + "OpReturn\n" + + "OpFunctionEnd", + 3, true) +)); + +using MatchingInstructionWithNoResultFoldingTest = +::testing::TestWithParam>; + +// Test folding instructions that do not have a result. The instruction +// that will be folded is the last instruction before the return. If there +// are multiple returns, there is no guarantee which one is used. +TEST_P(MatchingInstructionWithNoResultFoldingTest, Case) { + const auto& tc = GetParam(); + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + + // Fold the instruction to test. + Instruction* inst = nullptr; + Function* func = &*context->module()->begin(); + for (auto& bb : *func) { + Instruction* terminator = bb.terminator(); + if (terminator->IsReturnOrAbort()) { + inst = terminator->PreviousNode(); + break; + } + } + assert(inst && "Invalid test.
Could not find instruction to fold."); + std::unique_ptr original_inst(inst->Clone(context.get())); + bool succeeded = context->get_instruction_folder().FoldInstruction(inst); + EXPECT_EQ(succeeded, tc.expected_result); + if (succeeded) { + Match(tc.test_body, context.get()); + } +} + +INSTANTIATE_TEST_CASE_P(StoreMatchingTest, MatchingInstructionWithNoResultFoldingTest, +::testing::Values( + // Test case 0: Remove store of undef. + InstructionFoldingCase( + Header() + + "; CHECK: OpLabel\n" + + "; CHECK-NOT: OpStore\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%undef = OpUndef %v4double\n" + + "OpStore %n %undef\n" + + "OpReturn\n" + + "OpFunctionEnd", + 0 /* OpStore */, true), + // Test case 1: Keep volatile store. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v4double Function\n" + + "%undef = OpUndef %v4double\n" + + "OpStore %n %undef Volatile\n" + + "OpReturn\n" + + "OpFunctionEnd", + 0 /* OpStore */, false) +)); + +INSTANTIATE_TEST_CASE_P(VectorShuffleMatchingTest, MatchingInstructionWithNoResultFoldingTest, +::testing::Values( + // Test case 0: Basic test 1 + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 %5 2 3 6 7\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 3 4 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 1: Basic test 2 + InstructionFoldingCase( + Header() + + "; 
CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %6 %7 0 1 4 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %8 %7 2 3 4 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 2: Basic test 3 + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 3 2 4 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %8 %7 1 0 4 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 3: Basic test 4 + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 %6 2 3 5 4\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 3 7 6\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 4: Don't fold, need both operands of the feeder. 
+ InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 3 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, false), + // Test case 5: Don't fold, need both operands of the feeder. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %8 %7 2 0 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, false), + // Test case 6: Fold, need both operands of the feeder, but they are the same. + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 0 2 7 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %5 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %8 %7 2 0 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 7: Fold, need both operands of the feeder, but they are the same. 
+ InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 %5 2 0 5 7\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %5 2 3 4 5\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 0 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 8: Replace first operand with a smaller vector. + InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 0 0 5 3\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v2double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v2double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v4double %5 %5 0 1 2 3\n" + + "%9 = OpVectorShuffle %v4double %8 %7 2 0 7 5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 9: Replace first operand with a larger vector. 
+ InstructionFoldingCase( + Header() + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 3 0 7 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 3\n" + + "%9 = OpVectorShuffle %v4double %8 %7 1 0 5 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 10: Replace unused operand with null. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[null:%\\w+]] = OpConstantNull [[v4double]]\n" + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} [[null]] %7 4 2 5 3\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 3\n" + + "%9 = OpVectorShuffle %v4double %8 %7 4 2 5 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 11: Replace unused operand with null. 
+ InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[null:%\\w+]] = OpConstantNull [[v4double]]\n" + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} [[null]] %5 2 2 5 5\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 3\n" + + "%9 = OpVectorShuffle %v4double %8 %8 2 2 3 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 12: Replace unused operand with null. + InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: [[null:%\\w+]] = OpConstantNull [[v4double]]\n" + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 [[null]] 2 0 1 3\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 3\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 0 1 3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true), + // Test case 13: Shuffle with undef literal. 
+ InstructionFoldingCase( + Header() + + "; CHECK: [[double:%\\w+]] = OpTypeFloat 64\n" + + "; CHECK: [[v4double:%\\w+]] = OpTypeVector [[double]] 2\n" + + "; CHECK: OpVectorShuffle\n" + + "; CHECK: OpVectorShuffle {{%\\w+}} %7 {{%\\w+}} 2 0 1 4294967295\n" + + "; CHECK: OpReturn\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpVariable %_ptr_v4double Function\n" + + "%3 = OpVariable %_ptr_v4double Function\n" + + "%4 = OpVariable %_ptr_v4double Function\n" + + "%5 = OpLoad %v4double %2\n" + + "%6 = OpLoad %v4double %3\n" + + "%7 = OpLoad %v4double %4\n" + + "%8 = OpVectorShuffle %v2double %5 %5 0 1\n" + + "%9 = OpVectorShuffle %v4double %7 %8 2 0 1 4294967295\n" + + "OpReturn\n" + + "OpFunctionEnd", + 9, true) +)); + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/freeze_spec_const_test.cpp b/test/opt/freeze_spec_const_test.cpp new file mode 100644 index 000000000..5cc7843b1 --- /dev/null +++ b/test/opt/freeze_spec_const_test.cpp @@ -0,0 +1,133 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +struct FreezeSpecConstantValueTypeTestCase { + const char* type_decl; + const char* spec_const; + const char* expected_frozen_const; +}; + +using FreezeSpecConstantValueTypeTest = + PassTest<::testing::TestWithParam>; + +TEST_P(FreezeSpecConstantValueTypeTest, PrimaryType) { + auto& test_case = GetParam(); + std::vector text = {"OpCapability Shader", + "OpMemoryModel Logical GLSL450", + test_case.type_decl, test_case.spec_const}; + std::vector expected = { + "OpCapability Shader", "OpMemoryModel Logical GLSL450", + test_case.type_decl, test_case.expected_frozen_const}; + SinglePassRunAndCheck( + JoinAllInsts(text), JoinAllInsts(expected), /* skip_nop = */ false); +} + +// Test each primary type. +INSTANTIATE_TEST_CASE_P( + PrimaryTypeSpecConst, FreezeSpecConstantValueTypeTest, + ::testing::ValuesIn(std::vector({ + // Type declaration, original spec constant definition, expected frozen + // spec constants. 
+ {"%int = OpTypeInt 32 1", "%2 = OpSpecConstant %int 1", + "%int_1 = OpConstant %int 1"}, + {"%uint = OpTypeInt 32 0", "%2 = OpSpecConstant %uint 1", + "%uint_1 = OpConstant %uint 1"}, + {"%float = OpTypeFloat 32", "%2 = OpSpecConstant %float 3.1415", + "%float_3_1415 = OpConstant %float 3.1415"}, + {"%double = OpTypeFloat 64", "%2 = OpSpecConstant %double 3.141592653", + "%double_3_141592653 = OpConstant %double 3.141592653"}, + {"%bool = OpTypeBool", "%2 = OpSpecConstantTrue %bool", + "%true = OpConstantTrue %bool"}, + {"%bool = OpTypeBool", "%2 = OpSpecConstantFalse %bool", + "%false = OpConstantFalse %bool"}, + }))); + +using FreezeSpecConstantValueRemoveDecorationTest = PassTest<::testing::Test>; + +TEST_F(FreezeSpecConstantValueRemoveDecorationTest, + RemoveDecorationInstWithSpecId) { + std::vector text = { + // clang-format off + "OpCapability Shader", + "OpCapability Float64", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpSource GLSL 450", + "OpSourceExtension \"GL_GOOGLE_cpp_style_line_directive\"", + "OpSourceExtension \"GL_GOOGLE_include_directive\"", + "OpName %main \"main\"", + "OpDecorate %3 SpecId 200", + "OpDecorate %4 SpecId 201", + "OpDecorate %5 SpecId 202", + "OpDecorate %6 SpecId 203", + "%void = OpTypeVoid", + "%8 = OpTypeFunction %void", + "%int = OpTypeInt 32 1", + "%3 = OpSpecConstant %int 3", + "%float = OpTypeFloat 32", + "%4 = OpSpecConstant %float 3.1415", + "%double = OpTypeFloat 64", + "%5 = OpSpecConstant %double 3.14159265358979", + "%bool = OpTypeBool", + "%6 = OpSpecConstantTrue %bool", + "%13 = OpSpecConstantFalse %bool", + "%main = OpFunction %void None %8", + "%14 = OpLabel", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + std::string expected_disassembly = SelectiveJoin(text, [](const char* line) { + return std::string(line).find("SpecId") != std::string::npos; + }); + std::vector> replacement_pairs = { + {"%3 = OpSpecConstant %int 3", 
"%int_3 = OpConstant %int 3"}, + {"%4 = OpSpecConstant %float 3.1415", + "%float_3_1415 = OpConstant %float 3.1415"}, + {"%5 = OpSpecConstant %double 3.14159265358979", + "%double_3_14159265358979 = OpConstant %double 3.14159265358979"}, + {"%6 = OpSpecConstantTrue ", "%true = OpConstantTrue "}, + {"%13 = OpSpecConstantFalse ", "%false = OpConstantFalse "}, + }; + for (auto& p : replacement_pairs) { + EXPECT_TRUE(FindAndReplace(&expected_disassembly, p.first, p.second)) + << "text:\n" + << expected_disassembly << "\n" + << "find_str:\n" + << p.first << "\n" + << "replace_str:\n" + << p.second << "\n"; + } + SinglePassRunAndCheck(JoinAllInsts(text), + expected_disassembly, + /* skip_nop = */ true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/function_test.cpp b/test/opt/function_test.cpp new file mode 100644 index 000000000..38ab29876 --- /dev/null +++ b/test/opt/function_test.cpp @@ -0,0 +1,173 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "function_utils.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::Eq; + +TEST(FunctionTest, IsNotRecursive) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +OpDecorate %2 DescriptorSet 439418829 +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%1 = OpFunction %void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %12 +OpUnreachable +OpFunctionEnd +%12 = OpFunction %_struct_6 None %7 +%13 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + std::unique_ptr ctx = + spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + auto* func = spvtest::GetFunction(ctx->module(), 9); + EXPECT_FALSE(func->IsRecursive()); + + func = spvtest::GetFunction(ctx->module(), 12); + EXPECT_FALSE(func->IsRecursive()); +} + +TEST(FunctionTest, IsDirectlyRecursive) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +OpDecorate %2 DescriptorSet 439418829 +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%1 = OpFunction %void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %9 +OpUnreachable +OpFunctionEnd +)"; + + std::unique_ptr ctx = + spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + 
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + auto* func = spvtest::GetFunction(ctx->module(), 9); + EXPECT_TRUE(func->IsRecursive()); +} + +TEST(FunctionTest, IsIndirectlyRecursive) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +OpDecorate %2 DescriptorSet 439418829 +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%1 = OpFunction %void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %12 +OpUnreachable +OpFunctionEnd +%12 = OpFunction %_struct_6 None %7 +%13 = OpLabel +%14 = OpFunctionCall %_struct_6 %9 +OpUnreachable +OpFunctionEnd +)"; + + std::unique_ptr ctx = + spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + auto* func = spvtest::GetFunction(ctx->module(), 9); + EXPECT_TRUE(func->IsRecursive()); + + func = spvtest::GetFunction(ctx->module(), 12); + EXPECT_TRUE(func->IsRecursive()); +} + +TEST(FunctionTest, IsNotRecuriseCallingRecursive) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +OpDecorate %2 DescriptorSet 439418829 +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%1 = OpFunction %void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %9 +OpUnreachable +OpFunctionEnd +)"; + + std::unique_ptr ctx = + spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + auto* func = 
spvtest::GetFunction(ctx->module(), 1); + EXPECT_FALSE(func->IsRecursive()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/function_utils.h b/test/opt/function_utils.h new file mode 100644 index 000000000..803cacdd5 --- /dev/null +++ b/test/opt/function_utils.h @@ -0,0 +1,55 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TEST_OPT_FUNCTION_UTILS_H_ +#define TEST_OPT_FUNCTION_UTILS_H_ + +#include "source/opt/function.h" +#include "source/opt/module.h" + +namespace spvtest { + +inline spvtools::opt::Function* GetFunction(spvtools::opt::Module* module, + uint32_t id) { + for (spvtools::opt::Function& f : *module) { + if (f.result_id() == id) { + return &f; + } + } + return nullptr; +} + +inline const spvtools::opt::Function* GetFunction( + const spvtools::opt::Module* module, uint32_t id) { + for (const spvtools::opt::Function& f : *module) { + if (f.result_id() == id) { + return &f; + } + } + return nullptr; +} + +inline const spvtools::opt::BasicBlock* GetBasicBlock( + const spvtools::opt::Function* fn, uint32_t id) { + for (const spvtools::opt::BasicBlock& bb : *fn) { + if (bb.id() == id) { + return &bb; + } + } + return nullptr; +} + +} // namespace spvtest + +#endif // TEST_OPT_FUNCTION_UTILS_H_ diff --git a/test/opt/if_conversion_test.cpp b/test/opt/if_conversion_test.cpp new file mode 100644 index 000000000..03932a95a --- /dev/null +++ 
b/test/opt/if_conversion_test.cpp @@ -0,0 +1,511 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using IfConversionTest = PassTest<::testing::Test>; + +TEST_F(IfConversionTest, TestSimpleIfThenElse) { + const std::string text = R"( +; CHECK: OpSelectionMerge [[merge:%\w+]] +; CHECK: [[merge]] = OpLabel +; CHECK-NOT: OpPhi +; CHECK: [[sel:%\w+]] = OpSelect %uint %true %uint_0 %uint_1 +; CHECK: OpStore {{%\w+}} [[sel]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%11 = OpTypeFunction %void +%1 = OpFunction %void None %11 +%12 = OpLabel +OpSelectionMerge %14 None +OpBranchConditional %true %15 %16 +%15 = OpLabel +OpBranch %14 +%16 = OpLabel +OpBranch %14 +%14 = OpLabel +%18 = OpPhi %uint %uint_0 %15 %uint_1 %16 +OpStore %2 %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, TestSimpleHalfIfTrue) { + const std::string text = R"( +; CHECK: OpSelectionMerge [[merge:%\w+]] +; CHECK: [[merge]] = 
OpLabel +; CHECK-NOT: OpPhi +; CHECK: [[sel:%\w+]] = OpSelect %uint %true %uint_0 %uint_1 +; CHECK: OpStore {{%\w+}} [[sel]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%11 = OpTypeFunction %void +%1 = OpFunction %void None %11 +%12 = OpLabel +OpSelectionMerge %14 None +OpBranchConditional %true %15 %14 +%15 = OpLabel +OpBranch %14 +%14 = OpLabel +%18 = OpPhi %uint %uint_0 %15 %uint_1 %12 +OpStore %2 %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, TestSimpleHalfIfExtraBlock) { + const std::string text = R"( +; CHECK: OpSelectionMerge [[merge:%\w+]] +; CHECK: [[merge]] = OpLabel +; CHECK-NOT: OpPhi +; CHECK: [[sel:%\w+]] = OpSelect %uint %true %uint_0 %uint_1 +; CHECK: OpStore {{%\w+}} [[sel]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%11 = OpTypeFunction %void +%1 = OpFunction %void None %11 +%12 = OpLabel +OpSelectionMerge %14 None +OpBranchConditional %true %15 %14 +%15 = OpLabel +OpBranch %16 +%16 = OpLabel +OpBranch %14 +%14 = OpLabel +%18 = OpPhi %uint %uint_0 %15 %uint_1 %12 +OpStore %2 %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, TestSimpleHalfIfFalse) { + const std::string text = R"( +; CHECK: OpSelectionMerge [[merge:%\w+]] +; CHECK: [[merge]] = OpLabel +; CHECK-NOT: OpPhi +; CHECK: [[sel:%\w+]] = OpSelect %uint %true %uint_0 %uint_1 +; CHECK: OpStore {{%\w+}} [[sel]] +OpCapability 
Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%11 = OpTypeFunction %void +%1 = OpFunction %void None %11 +%12 = OpLabel +OpSelectionMerge %14 None +OpBranchConditional %true %14 %15 +%15 = OpLabel +OpBranch %14 +%14 = OpLabel +%18 = OpPhi %uint %uint_0 %12 %uint_1 %15 +OpStore %2 %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, TestVectorSplat) { + const std::string text = R"( +; CHECK: [[bool_vec:%\w+]] = OpTypeVector %bool 2 +; CHECK: OpSelectionMerge [[merge:%\w+]] +; CHECK: [[merge]] = OpLabel +; CHECK-NOT: OpPhi +; CHECK: [[comp:%\w+]] = OpCompositeConstruct [[bool_vec]] %true %true +; CHECK: [[sel:%\w+]] = OpSelect {{%\w+}} [[comp]] +; CHECK: OpStore {{%\w+}} [[sel]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_vec2 = OpTypeVector %uint 2 +%vec2_01 = OpConstantComposite %uint_vec2 %uint_0 %uint_1 +%vec2_10 = OpConstantComposite %uint_vec2 %uint_1 %uint_0 +%_ptr_Output_uint = OpTypePointer Output %uint_vec2 +%2 = OpVariable %_ptr_Output_uint Output +%11 = OpTypeFunction %void +%1 = OpFunction %void None %11 +%12 = OpLabel +OpSelectionMerge %14 None +OpBranchConditional %true %15 %16 +%15 = OpLabel +OpBranch %14 +%16 = OpLabel +OpBranch %14 +%14 = OpLabel +%18 = OpPhi %uint_vec2 %vec2_01 %15 %vec2_10 %16 +OpStore %2 %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, CodeMotionSameValue) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: OpFunction +; CHECK: OpLabel 
+; CHECK-NOT: OpLabel +; CHECK: [[add:%\w+]] = OpIAdd %uint %uint_0 %uint_1 +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] None +; CHECK-NEXT: OpBranchConditional +; CHECK: [[merge_lab]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpStore [[var]] [[add]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "func" %2 + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint + %2 = OpVariable %_ptr_Output_uint Output + %8 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %1 = OpFunction %void None %8 + %11 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %true %13 %15 + %13 = OpLabel + %14 = OpIAdd %uint %uint_0 %uint_1 + OpBranch %12 + %15 = OpLabel + %16 = OpIAdd %uint %uint_0 %uint_1 + OpBranch %12 + %12 = OpLabel + %17 = OpPhi %uint %16 %15 %14 %13 + OpStore %2 %17 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, CodeMotionMultipleInstructions) { + const std::string text = R"( +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK-NOT: OpLabel +; CHECK: [[a1:%\w+]] = OpIAdd %uint %uint_0 %uint_1 +; CHECK: [[a2:%\w+]] = OpIAdd %uint [[a1]] %uint_1 +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] None +; CHECK-NEXT: OpBranchConditional +; CHECK: [[merge_lab]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpStore [[var]] [[a2]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "func" %2 + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint + %2 = OpVariable %_ptr_Output_uint Output + %8 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %1 = OpFunction %void None %8 + %11 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %true %13 %15 + %13 = OpLabel + %a1 = OpIAdd %uint 
%uint_0 %uint_1 + %a2 = OpIAdd %uint %a1 %uint_1 + OpBranch %12 + %15 = OpLabel + %b1 = OpIAdd %uint %uint_0 %uint_1 + %b2 = OpIAdd %uint %b1 %uint_1 + OpBranch %12 + %12 = OpLabel + %17 = OpPhi %uint %b2 %15 %a2 %13 + OpStore %2 %17 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(IfConversionTest, NoCommonDominator) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%8 = OpTypeFunction %void +%1 = OpFunction %void None %8 +%9 = OpLabel +OpBranch %10 +%11 = OpLabel +OpBranch %10 +%10 = OpLabel +%12 = OpPhi %uint %uint_0 %9 %uint_1 %11 +OpStore %2 %12 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(IfConversionTest, LoopUntouched) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%8 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%1 = OpFunction %void None %8 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +%13 = OpPhi %uint %uint_0 %11 %uint_1 %12 +OpLoopMerge %14 %12 None +OpBranchConditional %true %14 %12 +%14 = OpLabel +OpStore %2 %13 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(IfConversionTest, TooManyPredecessors) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable 
%_ptr_Output_uint Output +%8 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%1 = OpFunction %void None %8 +%11 = OpLabel +OpSelectionMerge %12 None +OpBranchConditional %true %13 %12 +%13 = OpLabel +OpBranchConditional %true %14 %15 +%14 = OpLabel +OpBranch %12 +%15 = OpLabel +OpBranch %12 +%12 = OpLabel +%16 = OpPhi %uint %uint_0 %11 %uint_0 %14 %uint_1 %15 +OpStore %2 %16 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(IfConversionTest, NoCodeMotion) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%8 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%1 = OpFunction %void None %8 +%11 = OpLabel +OpSelectionMerge %12 None +OpBranchConditional %true %13 %12 +%13 = OpLabel +%14 = OpIAdd %uint %uint_0 %uint_1 +OpBranch %12 +%12 = OpLabel +%15 = OpPhi %uint %uint_0 %11 %14 %13 +OpStore %2 %15 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(IfConversionTest, NoCodeMotionImmovableInst) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %1 "func" %2 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Output_uint = OpTypePointer Output %uint +%2 = OpVariable %_ptr_Output_uint Output +%8 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%1 = OpFunction %void None %8 +%11 = OpLabel +OpSelectionMerge %12 None +OpBranchConditional %true %13 %14 +%13 = OpLabel +OpSelectionMerge %15 None +OpBranchConditional %true %16 %15 +%16 = OpLabel +%17 = OpIAdd %uint %uint_0 %uint_1 +OpBranch %15 +%15 = OpLabel +%18 = OpPhi %uint %uint_0 %13 %17 %16 +%19 = 
OpIAdd %uint %18 %uint_1 +OpBranch %12 +%14 = OpLabel +OpSelectionMerge %20 None +OpBranchConditional %true %21 %20 +%21 = OpLabel +%22 = OpIAdd %uint %uint_0 %uint_1 +OpBranch %20 +%20 = OpLabel +%23 = OpPhi %uint %uint_0 %14 %22 %21 +%24 = OpIAdd %uint %23 %uint_1 +OpBranch %12 +%12 = OpLabel +%25 = OpPhi %uint %24 %20 %19 %15 +OpStore %2 %25 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(IfConversionTest, InvalidCommonDominator) { + const std::string text = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%1 = OpTypeFunction %void +%2 = OpFunction %void None %1 +%3 = OpLabel +OpBranch %4 +%4 = OpLabel +OpLoopMerge %5 %6 None +OpBranch %7 +%7 = OpLabel +OpSelectionMerge %8 None +OpBranchConditional %true %8 %9 +%9 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %true %10 %5 +%10 = OpLabel +OpBranch %8 +%8 = OpLabel +OpBranch %6 +%6 = OpLabel +OpBranchConditional %true %4 %5 +%5 = OpLabel +%11 = OpPhi %float %float_0 %6 %float_1 %9 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(text, text, true, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/inline_opaque_test.cpp b/test/opt/inline_opaque_test.cpp new file mode 100644 index 000000000..d10913aec --- /dev/null +++ b/test/opt/inline_opaque_test.cpp @@ -0,0 +1,412 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using InlineOpaqueTest = PassTest<::testing::Test>; + +TEST_F(InlineOpaqueTest, InlineCallWithStructArgContainingSampledImage) { + // Function with opaque argument is inlined. + // TODO(greg-lunarg): Add HLSL code + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %outColor %texCoords +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpMemberName %S_t 2 "smp" +OpName %foo_struct_S_t_vf2_vf21_ "foo(struct-S_t-vf2-vf21;" +OpName %s "s" +OpName %outColor "outColor" +OpName %sampler15 "sampler15" +OpName %s0 "s0" +OpName %texCoords "texCoords" +OpName %param "param" +OpDecorate %sampler15 DescriptorSet 0 +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outColor = OpVariable %_ptr_Output_v4float Output +%17 = OpTypeImage %float 2D 0 0 0 1 Unknown +%18 = OpTypeSampledImage %17 +%S_t = OpTypeStruct %v2float %v2float %18 +%_ptr_Function_S_t = OpTypePointer Function %S_t +%20 = OpTypeFunction %void %_ptr_Function_S_t +%_ptr_UniformConstant_18 = OpTypePointer UniformConstant %18 +%_ptr_Function_18 = OpTypePointer Function %18 +%sampler15 = OpVariable 
%_ptr_UniformConstant_18 UniformConstant +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%texCoords = OpVariable %_ptr_Input_v2float Input +)"; + + const std::string before = + R"(%main = OpFunction %void None %12 +%28 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%param = OpVariable %_ptr_Function_S_t Function +%29 = OpLoad %v2float %texCoords +%30 = OpAccessChain %_ptr_Function_v2float %s0 %int_0 +OpStore %30 %29 +%31 = OpLoad %18 %sampler15 +%32 = OpAccessChain %_ptr_Function_18 %s0 %int_2 +OpStore %32 %31 +%33 = OpLoad %S_t %s0 +OpStore %param %33 +%34 = OpFunctionCall %void %foo_struct_S_t_vf2_vf21_ %param +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %12 +%28 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%param = OpVariable %_ptr_Function_S_t Function +%29 = OpLoad %v2float %texCoords +%30 = OpAccessChain %_ptr_Function_v2float %s0 %int_0 +OpStore %30 %29 +%31 = OpLoad %18 %sampler15 +%32 = OpAccessChain %_ptr_Function_18 %s0 %int_2 +OpStore %32 %31 +%33 = OpLoad %S_t %s0 +OpStore %param %33 +%41 = OpAccessChain %_ptr_Function_18 %param %int_2 +%42 = OpLoad %18 %41 +%43 = OpAccessChain %_ptr_Function_v2float %param %int_0 +%44 = OpLoad %v2float %43 +%45 = OpImageSampleImplicitLod %v4float %42 %44 +OpStore %outColor %45 +OpReturn +OpFunctionEnd +)"; + + const std::string post_defs = + R"(%foo_struct_S_t_vf2_vf21_ = OpFunction %void None %20 +%s = OpFunctionParameter %_ptr_Function_S_t +%35 = OpLabel +%36 = OpAccessChain %_ptr_Function_18 %s %int_2 +%37 = OpLoad %18 %36 +%38 = OpAccessChain %_ptr_Function_v2float %s %int_0 +%39 = OpLoad %v2float %38 +%40 = OpImageSampleImplicitLod %v4float %37 %39 +OpStore %outColor %40 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before + post_defs, predefs + after + post_defs, true, true); +} 
+ +TEST_F(InlineOpaqueTest, InlineOpaqueReturn) { + // Function with opaque return value is inlined. + // TODO(greg-lunarg): Add HLSL code + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %texCoords %outColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_ "foo(" +OpName %texCoords "texCoords" +OpName %outColor "outColor" +OpName %sampler15 "sampler15" +OpName %sampler16 "sampler16" +OpDecorate %sampler15 DescriptorSet 0 +OpDecorate %sampler16 DescriptorSet 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%bool = OpTypeBool +%false = OpConstantFalse %bool +%_ptr_Input_v2float = OpTypePointer Input %v2float +%texCoords = OpVariable %_ptr_Input_v2float Input +%float_0 = OpConstant %float 0 +%16 = OpConstantComposite %v2float %float_0 %float_0 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outColor = OpVariable %_ptr_Output_v4float Output +%19 = OpTypeImage %float 2D 0 0 0 1 Unknown +%20 = OpTypeSampledImage %19 +%21 = OpTypeFunction %20 +%_ptr_UniformConstant_20 = OpTypePointer UniformConstant %20 +%_ptr_Function_20 = OpTypePointer Function %20 +%sampler15 = OpVariable %_ptr_UniformConstant_20 UniformConstant +%sampler16 = OpVariable %_ptr_UniformConstant_20 UniformConstant +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%24 = OpLabel +%25 = OpVariable %_ptr_Function_20 Function +%26 = OpFunctionCall %20 %foo_ +OpStore %25 %26 +%27 = OpLoad %20 %25 +%28 = OpLoad %v2float %texCoords +%29 = OpImageSampleImplicitLod %v4float %27 %28 +OpStore %outColor %29 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%24 = OpLabel +%34 = OpVariable %_ptr_Function_20 Function +%35 = OpVariable %_ptr_Function_20 Function +%25 = OpVariable 
%_ptr_Function_20 Function +%36 = OpLoad %20 %sampler16 +OpStore %34 %36 +%37 = OpLoad %20 %34 +OpStore %35 %37 +%26 = OpLoad %20 %35 +OpStore %25 %26 +%27 = OpLoad %20 %25 +%28 = OpLoad %v2float %texCoords +%29 = OpImageSampleImplicitLod %v4float %27 %28 +OpStore %outColor %29 +OpReturn +OpFunctionEnd +)"; + + const std::string post_defs = + R"(%foo_ = OpFunction %20 None %21 +%30 = OpLabel +%31 = OpVariable %_ptr_Function_20 Function +%32 = OpLoad %20 %sampler16 +OpStore %31 %32 +%33 = OpLoad %20 %31 +OpReturnValue %33 +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before + post_defs, predefs + after + post_defs, true, true); +} + +TEST_F(InlineOpaqueTest, InlineInNonEntryPointFunction) { + // This demonstrates opaque inlining in a function that is not + // an entry point function (main2) but is in the call tree of an + // entry point function (main). + // TODO(greg-lunarg): Add HLSL code + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %outColor %texCoords +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %main2 "main2" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpMemberName %S_t 2 "smp" +OpName %foo_struct_S_t_vf2_vf21_ "foo(struct-S_t-vf2-vf21;" +OpName %s "s" +OpName %outColor "outColor" +OpName %sampler15 "sampler15" +OpName %s0 "s0" +OpName %texCoords "texCoords" +OpName %param "param" +OpDecorate %sampler15 DescriptorSet 0 +%void = OpTypeVoid +%13 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outColor = OpVariable %_ptr_Output_v4float Output +%18 = OpTypeImage %float 2D 0 0 0 1 Unknown +%19 = OpTypeSampledImage %18 +%S_t = OpTypeStruct %v2float %v2float %19 +%_ptr_Function_S_t = OpTypePointer Function %S_t +%21 = OpTypeFunction %void 
%_ptr_Function_S_t +%_ptr_UniformConstant_19 = OpTypePointer UniformConstant %19 +%_ptr_Function_19 = OpTypePointer Function %19 +%sampler15 = OpVariable %_ptr_UniformConstant_19 UniformConstant +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%texCoords = OpVariable %_ptr_Input_v2float Input +)"; + + const std::string before = + R"(%main2 = OpFunction %void None %13 +%29 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%param = OpVariable %_ptr_Function_S_t Function +%30 = OpLoad %v2float %texCoords +%31 = OpAccessChain %_ptr_Function_v2float %s0 %int_0 +OpStore %31 %30 +%32 = OpLoad %19 %sampler15 +%33 = OpAccessChain %_ptr_Function_19 %s0 %int_2 +OpStore %33 %32 +%34 = OpLoad %S_t %s0 +OpStore %param %34 +%35 = OpFunctionCall %void %foo_struct_S_t_vf2_vf21_ %param +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main2 = OpFunction %void None %13 +%29 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%param = OpVariable %_ptr_Function_S_t Function +%30 = OpLoad %v2float %texCoords +%31 = OpAccessChain %_ptr_Function_v2float %s0 %int_0 +OpStore %31 %30 +%32 = OpLoad %19 %sampler15 +%33 = OpAccessChain %_ptr_Function_19 %s0 %int_2 +OpStore %33 %32 +%34 = OpLoad %S_t %s0 +OpStore %param %34 +%44 = OpAccessChain %_ptr_Function_19 %param %int_2 +%45 = OpLoad %19 %44 +%46 = OpAccessChain %_ptr_Function_v2float %param %int_0 +%47 = OpLoad %v2float %46 +%48 = OpImageSampleImplicitLod %v4float %45 %47 +OpStore %outColor %48 +OpReturn +OpFunctionEnd +)"; + + const std::string post_defs = + R"(%main = OpFunction %void None %13 +%36 = OpLabel +%37 = OpFunctionCall %void %main2 +OpReturn +OpFunctionEnd +%foo_struct_S_t_vf2_vf21_ = OpFunction %void None %21 +%s = OpFunctionParameter %_ptr_Function_S_t +%38 = OpLabel +%39 = OpAccessChain %_ptr_Function_19 %s %int_2 +%40 = OpLoad %19 %39 +%41 = OpAccessChain 
%_ptr_Function_v2float %s %int_0 +%42 = OpLoad %v2float %41 +%43 = OpImageSampleImplicitLod %v4float %40 %42 +OpStore %outColor %43 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before + post_defs, predefs + after + post_defs, true, true); +} + +TEST_F(InlineOpaqueTest, NoInlineNoOpaque) { + // Function without opaque interface is not inlined. + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // return bar.x + bar.y; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_ "foo(vf4;" +OpName %bar "bar" +OpName %color "color" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%14 = OpTypeFunction %float %_ptr_Function_v4float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%21 = OpLabel +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %BaseColor +OpStore %param %22 +%23 = OpFunctionCall %float %foo_vf4_ %param +%24 = OpCompositeConstruct %v4float %23 %23 %23 %23 +OpStore %color %24 +%25 = OpLoad %v4float %color +OpStore %gl_FragColor %25 +OpReturn +OpFunctionEnd 
+%foo_vf4_ = OpFunction %float None %14 +%bar = OpFunctionParameter %_ptr_Function_v4float +%26 = OpLabel +%27 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%28 = OpLoad %float %27 +%29 = OpAccessChain %_ptr_Function_float %bar %uint_1 +%30 = OpLoad %float %29 +%31 = OpFAdd %float %28 %30 +OpReturnValue %31 +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp new file mode 100644 index 000000000..44a969858 --- /dev/null +++ b/test/opt/inline_test.cpp @@ -0,0 +1,3133 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using InlineTest = PassTest<::testing::Test>; + +TEST_F(InlineTest, Simple) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // return bar.x + bar.y; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 140", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %color \"color\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%10 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%14 = OpTypeFunction %float %_ptr_Function_v4float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint_1 = OpConstant %uint 1", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %14", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%26 = OpLabel", + "%27 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%28 = OpLoad %float %27", + "%29 = OpAccessChain %_ptr_Function_float %bar %uint_1", + "%30 = OpLoad %float %29", + "%31 = OpFAdd %float 
%28 %30", + "OpReturnValue %31", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %10", + "%21 = OpLabel", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%22 = OpLoad %v4float %BaseColor", + "OpStore %param %22", + "%23 = OpFunctionCall %float %foo_vf4_ %param", + "%24 = OpCompositeConstruct %v4float %23 %23 %23 %23", + "OpStore %color %24", + "%25 = OpLoad %v4float %color", + "OpStore %gl_FragColor %25", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %10", + "%21 = OpLabel", + "%32 = OpVariable %_ptr_Function_float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%22 = OpLoad %v4float %BaseColor", + "OpStore %param %22", + "%33 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%34 = OpLoad %float %33", + "%35 = OpAccessChain %_ptr_Function_float %param %uint_1", + "%36 = OpLoad %float %35", + "%37 = OpFAdd %float %34 %36", + "OpStore %32 %37", + "%23 = OpLoad %float %32", + "%24 = OpCompositeConstruct %v4float %23 %23 %23 %23", + "OpStore %color %24", + "%25 = OpLoad %v4float %color", + "OpStore %gl_FragColor %25", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, Nested) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo2(float f, float f2) + // { + // return f * f2; + // } + // + // float foo(vec4 bar) + // { + // return foo2(bar.x + bar.y, bar.z); + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + const std::vector 
predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 140", + "OpName %main \"main\"", + "OpName %foo2_f1_f1_ \"foo2(f1;f1;\"", + "OpName %f \"f\"", + "OpName %f2 \"f2\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %param \"param\"", + "OpName %param_0 \"param\"", + "OpName %color \"color\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param_1 \"param\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%15 = OpTypeFunction %void", + "%float = OpTypeFloat 32", +"%_ptr_Function_float = OpTypePointer Function %float", + "%18 = OpTypeFunction %float %_ptr_Function_float %_ptr_Function_float", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%21 = OpTypeFunction %float %_ptr_Function_v4float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%uint_1 = OpConstant %uint 1", + "%uint_2 = OpConstant %uint 2", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off +"%foo2_f1_f1_ = OpFunction %float None %18", + "%f = OpFunctionParameter %_ptr_Function_float", + "%f2 = OpFunctionParameter %_ptr_Function_float", + "%33 = OpLabel", + "%34 = OpLoad %float %f", + "%35 = OpLoad %float %f2", + "%36 = OpFMul %float %34 %35", + "OpReturnValue %36", + "OpFunctionEnd", + "%foo_vf4_ = OpFunction %float None %21", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%37 = OpLabel", + "%param = OpVariable %_ptr_Function_float Function", + "%param_0 = OpVariable %_ptr_Function_float Function", + "%38 = 
OpAccessChain %_ptr_Function_float %bar %uint_0", + "%39 = OpLoad %float %38", + "%40 = OpAccessChain %_ptr_Function_float %bar %uint_1", + "%41 = OpLoad %float %40", + "%42 = OpFAdd %float %39 %41", + "OpStore %param %42", + "%43 = OpAccessChain %_ptr_Function_float %bar %uint_2", + "%44 = OpLoad %float %43", + "OpStore %param_0 %44", + "%45 = OpFunctionCall %float %foo2_f1_f1_ %param %param_0", + "OpReturnValue %45", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %15", + "%28 = OpLabel", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param_1 = OpVariable %_ptr_Function_v4float Function", + "%29 = OpLoad %v4float %BaseColor", + "OpStore %param_1 %29", + "%30 = OpFunctionCall %float %foo_vf4_ %param_1", + "%31 = OpCompositeConstruct %v4float %30 %30 %30 %30", + "OpStore %color %31", + "%32 = OpLoad %v4float %color", + "OpStore %gl_FragColor %32", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %15", + "%28 = OpLabel", + "%57 = OpVariable %_ptr_Function_float Function", + "%46 = OpVariable %_ptr_Function_float Function", + "%47 = OpVariable %_ptr_Function_float Function", + "%48 = OpVariable %_ptr_Function_float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param_1 = OpVariable %_ptr_Function_v4float Function", + "%29 = OpLoad %v4float %BaseColor", + "OpStore %param_1 %29", + "%49 = OpAccessChain %_ptr_Function_float %param_1 %uint_0", + "%50 = OpLoad %float %49", + "%51 = OpAccessChain %_ptr_Function_float %param_1 %uint_1", + "%52 = OpLoad %float %51", + "%53 = OpFAdd %float %50 %52", + "OpStore %46 %53", + "%54 = OpAccessChain %_ptr_Function_float %param_1 %uint_2", + "%55 = OpLoad %float %54", + "OpStore %47 %55", + "%58 = OpLoad %float %46", + "%59 = OpLoad %float %47", + "%60 = OpFMul %float %58 %59", + "OpStore %57 %60", + "%56 = OpLoad 
%float %57", + "OpStore %48 %56", + "%30 = OpLoad %float %48", + "%31 = OpCompositeConstruct %v4float %30 %30 %30 %30", + "OpStore %color %31", + "%32 = OpLoad %v4float %color", + "OpStore %gl_FragColor %32", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, InOutParameter) { + // #version 400 + // + // in vec4 Basecolor; + // + // void foo(inout vec4 bar) + // { + // bar.z = bar.x + bar.y; + // } + // + // void main() + // { + // vec4 b = Basecolor; + // foo(b); + // vec4 color = vec4(b.z); + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %Basecolor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 400", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %b \"b\"", + "OpName %Basecolor \"Basecolor\"", + "OpName %param \"param\"", + "OpName %color \"color\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%11 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%15 = OpTypeFunction %void %_ptr_Function_v4float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint_1 = OpConstant %uint 1", + "%uint_2 = OpConstant %uint 2", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%Basecolor = OpVariable %_ptr_Input_v4float Input", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const 
std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %void None %15", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%32 = OpLabel", + "%33 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%34 = OpLoad %float %33", + "%35 = OpAccessChain %_ptr_Function_float %bar %uint_1", + "%36 = OpLoad %float %35", + "%37 = OpFAdd %float %34 %36", + "%38 = OpAccessChain %_ptr_Function_float %bar %uint_2", + "OpStore %38 %37", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %11", + "%23 = OpLabel", + "%b = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%24 = OpLoad %v4float %Basecolor", + "OpStore %b %24", + "%25 = OpLoad %v4float %b", + "OpStore %param %25", + "%26 = OpFunctionCall %void %foo_vf4_ %param", + "%27 = OpLoad %v4float %param", + "OpStore %b %27", + "%28 = OpAccessChain %_ptr_Function_float %b %uint_2", + "%29 = OpLoad %float %28", + "%30 = OpCompositeConstruct %v4float %29 %29 %29 %29", + "OpStore %color %30", + "%31 = OpLoad %v4float %color", + "OpStore %gl_FragColor %31", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %11", + "%23 = OpLabel", + "%b = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%24 = OpLoad %v4float %Basecolor", + "OpStore %b %24", + "%25 = OpLoad %v4float %b", + "OpStore %param %25", + "%39 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%40 = OpLoad %float %39", + "%41 = OpAccessChain %_ptr_Function_float %param %uint_1", + "%42 = OpLoad %float %41", + "%43 = OpFAdd %float %40 %42", + "%44 = OpAccessChain %_ptr_Function_float %param %uint_2", + "OpStore %44 %43", + "%27 = 
OpLoad %v4float %param", + "OpStore %b %27", + "%28 = OpAccessChain %_ptr_Function_float %b %uint_2", + "%29 = OpLoad %float %28", + "%30 = OpCompositeConstruct %v4float %29 %29 %29 %29", + "OpStore %color %30", + "%31 = OpLoad %v4float %color", + "OpStore %gl_FragColor %31", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, BranchInCallee) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // float r = bar.x; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 140", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color \"color\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%11 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%15 = OpTypeFunction %float %_ptr_Function_v4float", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = 
OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %15", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%28 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%29 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%30 = OpLoad %float %29", + "OpStore %r %30", + "%31 = OpLoad %float %r", + "%32 = OpFOrdLessThan %bool %31 %float_0", + "OpSelectionMerge %33 None", + "OpBranchConditional %32 %34 %33", + "%34 = OpLabel", + "%35 = OpLoad %float %r", + "%36 = OpFNegate %float %35", + "OpStore %r %36", + "OpBranch %33", + "%33 = OpLabel", + "%37 = OpLoad %float %r", + "OpReturnValue %37", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %11", + "%23 = OpLabel", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%24 = OpLoad %v4float %BaseColor", + "OpStore %param %24", + "%25 = OpFunctionCall %float %foo_vf4_ %param", + "%26 = OpCompositeConstruct %v4float %25 %25 %25 %25", + "OpStore %color %26", + "%27 = OpLoad %v4float %color", + "OpStore %gl_FragColor %27", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %11", + "%23 = OpLabel", + "%38 = OpVariable %_ptr_Function_float Function", + "%39 = OpVariable %_ptr_Function_float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%24 = OpLoad %v4float %BaseColor", + "OpStore %param %24", + "%40 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%41 = OpLoad %float %40", + "OpStore %38 %41", + "%42 = OpLoad %float %38", + "%43 = OpFOrdLessThan %bool %42 %float_0", + "OpSelectionMerge %44 None", + "OpBranchConditional %43 %45 %44", + "%45 = OpLabel", + "%46 = 
OpLoad %float %38", + "%47 = OpFNegate %float %46", + "OpStore %38 %47", + "OpBranch %44", + "%44 = OpLabel", + "%48 = OpLoad %float %38", + "OpStore %39 %48", + "%25 = OpLoad %float %39", + "%26 = OpCompositeConstruct %v4float %25 %25 %25 %25", + "OpStore %color %26", + "%27 = OpLoad %v4float %color", + "OpStore %gl_FragColor %27", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, PhiAfterCall) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(float bar) + // { + // float r = bar; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color = BaseColor; + // if (foo(color.x) > 2.0 && foo(color.y) > 2.0) + // color = vec4(0.0); + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 140", + "OpName %main \"main\"", + "OpName %foo_f1_ \"foo(f1;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color \"color\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %param_0 \"param\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%12 = OpTypeFunction %void", + "%float = OpTypeFloat 32", +"%_ptr_Function_float = OpTypePointer Function %float", + "%15 = OpTypeFunction %float %_ptr_Function_float", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", + 
"%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_2 = OpConstant %float 2", + "%uint_1 = OpConstant %uint 1", + "%25 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_f1_ = OpFunction %float None %15", + "%bar = OpFunctionParameter %_ptr_Function_float", + "%43 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%44 = OpLoad %float %bar", + "OpStore %r %44", + "%45 = OpLoad %float %r", + "%46 = OpFOrdLessThan %bool %45 %float_0", + "OpSelectionMerge %47 None", + "OpBranchConditional %46 %48 %47", + "%48 = OpLabel", + "%49 = OpLoad %float %r", + "%50 = OpFNegate %float %49", + "OpStore %r %50", + "OpBranch %47", + "%47 = OpLabel", + "%51 = OpLoad %float %r", + "OpReturnValue %51", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %12", + "%27 = OpLabel", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_float Function", + "%param_0 = OpVariable %_ptr_Function_float Function", + "%28 = OpLoad %v4float %BaseColor", + "OpStore %color %28", + "%29 = OpAccessChain %_ptr_Function_float %color %uint_0", + "%30 = OpLoad %float %29", + "OpStore %param %30", + "%31 = OpFunctionCall %float %foo_f1_ %param", + "%32 = OpFOrdGreaterThan %bool %31 %float_2", + "OpSelectionMerge %33 None", + "OpBranchConditional %32 %34 %33", + "%34 = OpLabel", + "%35 = OpAccessChain %_ptr_Function_float %color %uint_1", + "%36 = OpLoad %float %35", + "OpStore %param_0 %36", + "%37 = OpFunctionCall %float %foo_f1_ %param_0", + "%38 = OpFOrdGreaterThan %bool %37 %float_2", + "OpBranch %33", + "%33 = OpLabel", + "%39 = OpPhi %bool %32 %27 %38 %34", + "OpSelectionMerge %40 None", + "OpBranchConditional %39 %41 %40", 
+ "%41 = OpLabel", + "OpStore %color %25", + "OpBranch %40", + "%40 = OpLabel", + "%42 = OpLoad %v4float %color", + "OpStore %gl_FragColor %42", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %12", + "%27 = OpLabel", + "%62 = OpVariable %_ptr_Function_float Function", + "%63 = OpVariable %_ptr_Function_float Function", + "%52 = OpVariable %_ptr_Function_float Function", + "%53 = OpVariable %_ptr_Function_float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_float Function", + "%param_0 = OpVariable %_ptr_Function_float Function", + "%28 = OpLoad %v4float %BaseColor", + "OpStore %color %28", + "%29 = OpAccessChain %_ptr_Function_float %color %uint_0", + "%30 = OpLoad %float %29", + "OpStore %param %30", + "%54 = OpLoad %float %param", + "OpStore %52 %54", + "%55 = OpLoad %float %52", + "%56 = OpFOrdLessThan %bool %55 %float_0", + "OpSelectionMerge %57 None", + "OpBranchConditional %56 %58 %57", + "%58 = OpLabel", + "%59 = OpLoad %float %52", + "%60 = OpFNegate %float %59", + "OpStore %52 %60", + "OpBranch %57", + "%57 = OpLabel", + "%61 = OpLoad %float %52", + "OpStore %53 %61", + "%31 = OpLoad %float %53", + "%32 = OpFOrdGreaterThan %bool %31 %float_2", + "OpSelectionMerge %33 None", + "OpBranchConditional %32 %34 %33", + "%34 = OpLabel", + "%35 = OpAccessChain %_ptr_Function_float %color %uint_1", + "%36 = OpLoad %float %35", + "OpStore %param_0 %36", + "%64 = OpLoad %float %param_0", + "OpStore %62 %64", + "%65 = OpLoad %float %62", + "%66 = OpFOrdLessThan %bool %65 %float_0", + "OpSelectionMerge %67 None", + "OpBranchConditional %66 %68 %67", + "%68 = OpLabel", + "%69 = OpLoad %float %62", + "%70 = OpFNegate %float %69", + "OpStore %62 %70", + "OpBranch %67", + "%67 = OpLabel", + "%71 = OpLoad %float %62", + "OpStore %63 %71", + "%37 = OpLoad %float %63", + "%38 = OpFOrdGreaterThan %bool %37 %float_2", + 
"OpBranch %33", + "%33 = OpLabel", + "%39 = OpPhi %bool %32 %57 %38 %67", + "OpSelectionMerge %40 None", + "OpBranchConditional %39 %41 %40", + "%41 = OpLabel", + "OpStore %color %25", + "OpBranch %40", + "%40 = OpLabel", + "%42 = OpLoad %v4float %color", + "OpStore %gl_FragColor %42", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, OpSampledImageOutOfBlock) { + // #version 450 + // + // uniform texture2D t2D; + // uniform sampler samp; + // out vec4 FragColor; + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // float r = bar.x; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color1 = texture(sampler2D(t2D, samp), vec2(1.0)); + // vec4 color2 = vec4(foo(BaseColor)); + // vec4 color3 = texture(sampler2D(t2D, samp), vec2(0.5)); + // FragColor = (color1 + color2 + color3)/3; + // } + // + // Note: the before SPIR-V will need to be edited to create a use of + // the OpSampledImage across the function call. 
+ const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 450", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color1 \"color1\"", + "OpName %t2D \"t2D\"", + "OpName %samp \"samp\"", + "OpName %color2 \"color2\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %color3 \"color3\"", + "OpName %FragColor \"FragColor\"", + "OpDecorate %t2D DescriptorSet 0", + "OpDecorate %samp DescriptorSet 0", + "%void = OpTypeVoid", + "%15 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%19 = OpTypeFunction %float %_ptr_Function_v4float", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", + "%25 = OpTypeImage %float 2D 0 0 0 1 Unknown", +"%_ptr_UniformConstant_25 = OpTypePointer UniformConstant %25", + "%t2D = OpVariable %_ptr_UniformConstant_25 UniformConstant", + "%27 = OpTypeSampler", +"%_ptr_UniformConstant_27 = OpTypePointer UniformConstant %27", + "%samp = OpVariable %_ptr_UniformConstant_27 UniformConstant", + "%29 = OpTypeSampledImage %25", + "%v2float = OpTypeVector %float 2", + "%float_1 = OpConstant %float 1", + "%32 = OpConstantComposite %v2float %float_1 %float_1", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", + "%float_0_5 = OpConstant %float 0.5", + "%35 = OpConstantComposite %v2float %float_0_5 %float_0_5", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", + "%FragColor = OpVariable %_ptr_Output_v4float Output", + "%float_3 = OpConstant %float 3", + 
// clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %19", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%56 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%57 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%58 = OpLoad %float %57", + "OpStore %r %58", + "%59 = OpLoad %float %r", + "%60 = OpFOrdLessThan %bool %59 %float_0", + "OpSelectionMerge %61 None", + "OpBranchConditional %60 %62 %61", + "%62 = OpLabel", + "%63 = OpLoad %float %r", + "%64 = OpFNegate %float %63", + "OpStore %r %64", + "OpBranch %61", + "%61 = OpLabel", + "%65 = OpLoad %float %r", + "OpReturnValue %65", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %15", + "%38 = OpLabel", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%39 = OpLoad %25 %t2D", + "%40 = OpLoad %27 %samp", + "%41 = OpSampledImage %29 %39 %40", + "%42 = OpImageSampleImplicitLod %v4float %41 %32", + "OpStore %color1 %42", + "%43 = OpLoad %v4float %BaseColor", + "OpStore %param %43", + "%44 = OpFunctionCall %float %foo_vf4_ %param", + "%45 = OpCompositeConstruct %v4float %44 %44 %44 %44", + "OpStore %color2 %45", + "%46 = OpLoad %25 %t2D", + "%47 = OpLoad %27 %samp", + "%48 = OpImageSampleImplicitLod %v4float %41 %35", + "OpStore %color3 %48", + "%49 = OpLoad %v4float %color1", + "%50 = OpLoad %v4float %color2", + "%51 = OpFAdd %v4float %49 %50", + "%52 = OpLoad %v4float %color3", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%55 = OpFDiv %v4float %53 %54", + "OpStore %FragColor %55", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // 
clang-format off + "%main = OpFunction %void None %15", + "%38 = OpLabel", + "%66 = OpVariable %_ptr_Function_float Function", + "%67 = OpVariable %_ptr_Function_float Function", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%39 = OpLoad %25 %t2D", + "%40 = OpLoad %27 %samp", + "%41 = OpSampledImage %29 %39 %40", + "%42 = OpImageSampleImplicitLod %v4float %41 %32", + "OpStore %color1 %42", + "%43 = OpLoad %v4float %BaseColor", + "OpStore %param %43", + "%68 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%69 = OpLoad %float %68", + "OpStore %66 %69", + "%70 = OpLoad %float %66", + "%71 = OpFOrdLessThan %bool %70 %float_0", + "OpSelectionMerge %72 None", + "OpBranchConditional %71 %73 %72", + "%73 = OpLabel", + "%74 = OpLoad %float %66", + "%75 = OpFNegate %float %74", + "OpStore %66 %75", + "OpBranch %72", + "%72 = OpLabel", + "%76 = OpLoad %float %66", + "OpStore %67 %76", + "%44 = OpLoad %float %67", + "%45 = OpCompositeConstruct %v4float %44 %44 %44 %44", + "OpStore %color2 %45", + "%46 = OpLoad %25 %t2D", + "%47 = OpLoad %27 %samp", + "%77 = OpSampledImage %29 %39 %40", + "%48 = OpImageSampleImplicitLod %v4float %77 %35", + "OpStore %color3 %48", + "%49 = OpLoad %v4float %color1", + "%50 = OpLoad %v4float %color2", + "%51 = OpFAdd %v4float %49 %50", + "%52 = OpLoad %v4float %color3", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%55 = OpFDiv %v4float %53 %54", + "OpStore %FragColor %55", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, OpImageOutOfBlock) { + // 
#version 450 + // + // uniform texture2D t2D; + // uniform sampler samp; + // uniform sampler samp2; + // + // out vec4 FragColor; + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // float r = bar.x; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color1 = texture(sampler2D(t2D, samp), vec2(1.0)); + // vec4 color2 = vec4(foo(BaseColor)); + // vec4 color3 = texture(sampler2D(t2D, samp2), vec2(0.5)); + // FragColor = (color1 + color2 + color3)/3; + // } + // Note: the before SPIR-V will need to be edited to create an OpImage + // from the first OpSampledImage, place it before the call and use it + // in the second OpSampledImage following the call. + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 450", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color1 \"color1\"", + "OpName %t2D \"t2D\"", + "OpName %samp \"samp\"", + "OpName %color2 \"color2\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %color3 \"color3\"", + "OpName %samp2 \"samp2\"", + "OpName %FragColor \"FragColor\"", + "OpDecorate %t2D DescriptorSet 0", + "OpDecorate %samp DescriptorSet 0", + "OpDecorate %samp2 DescriptorSet 0", + "%void = OpTypeVoid", + "%16 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%20 = OpTypeFunction %float %_ptr_Function_v4float", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", + "%26 = OpTypeImage %float 2D 0 0 0 1 Unknown", +"%_ptr_UniformConstant_26 
= OpTypePointer UniformConstant %26", + "%t2D = OpVariable %_ptr_UniformConstant_26 UniformConstant", + "%28 = OpTypeSampler", +"%_ptr_UniformConstant_28 = OpTypePointer UniformConstant %28", + "%samp = OpVariable %_ptr_UniformConstant_28 UniformConstant", + "%30 = OpTypeSampledImage %26", + "%v2float = OpTypeVector %float 2", + "%float_1 = OpConstant %float 1", + "%33 = OpConstantComposite %v2float %float_1 %float_1", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", + "%samp2 = OpVariable %_ptr_UniformConstant_28 UniformConstant", + "%float_0_5 = OpConstant %float 0.5", + "%36 = OpConstantComposite %v2float %float_0_5 %float_0_5", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", + "%FragColor = OpVariable %_ptr_Output_v4float Output", + "%float_3 = OpConstant %float 3", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %20", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%58 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%59 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%60 = OpLoad %float %59", + "OpStore %r %60", + "%61 = OpLoad %float %r", + "%62 = OpFOrdLessThan %bool %61 %float_0", + "OpSelectionMerge %63 None", + "OpBranchConditional %62 %64 %63", + "%64 = OpLabel", + "%65 = OpLoad %float %r", + "%66 = OpFNegate %float %65", + "OpStore %r %66", + "OpBranch %63", + "%63 = OpLabel", + "%67 = OpLoad %float %r", + "OpReturnValue %67", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %16", + "%39 = OpLabel", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%40 = OpLoad %26 %t2D", + "%41 = OpLoad %28 %samp", + "%42 = 
OpSampledImage %30 %40 %41", + "%43 = OpImageSampleImplicitLod %v4float %42 %33", + "%44 = OpImage %26 %42", + "%45 = OpLoad %28 %samp2", + "OpStore %color1 %43", + "%46 = OpLoad %v4float %BaseColor", + "OpStore %param %46", + "%47 = OpFunctionCall %float %foo_vf4_ %param", + "%48 = OpCompositeConstruct %v4float %47 %47 %47 %47", + "OpStore %color2 %48", + "%49 = OpSampledImage %30 %44 %45", + "%50 = OpImageSampleImplicitLod %v4float %49 %36", + "OpStore %color3 %50", + "%51 = OpLoad %v4float %color1", + "%52 = OpLoad %v4float %color2", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpLoad %v4float %color3", + "%55 = OpFAdd %v4float %53 %54", + "%56 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%57 = OpFDiv %v4float %55 %56", + "OpStore %FragColor %57", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %16", + "%39 = OpLabel", + "%68 = OpVariable %_ptr_Function_float Function", + "%69 = OpVariable %_ptr_Function_float Function", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%40 = OpLoad %26 %t2D", + "%41 = OpLoad %28 %samp", + "%42 = OpSampledImage %30 %40 %41", + "%43 = OpImageSampleImplicitLod %v4float %42 %33", + "%44 = OpImage %26 %42", + "%45 = OpLoad %28 %samp2", + "OpStore %color1 %43", + "%46 = OpLoad %v4float %BaseColor", + "OpStore %param %46", + "%70 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%71 = OpLoad %float %70", + "OpStore %68 %71", + "%72 = OpLoad %float %68", + "%73 = OpFOrdLessThan %bool %72 %float_0", + "OpSelectionMerge %74 None", + "OpBranchConditional %73 %75 %74", + "%75 = OpLabel", + "%76 = OpLoad %float %68", + "%77 = OpFNegate %float %76", + "OpStore %68 %77", + "OpBranch %74", + "%74 = OpLabel", + "%78 = OpLoad %float 
%68", + "OpStore %69 %78", + "%47 = OpLoad %float %69", + "%48 = OpCompositeConstruct %v4float %47 %47 %47 %47", + "OpStore %color2 %48", + "%79 = OpSampledImage %30 %40 %41", + "%80 = OpImage %26 %79", + "%49 = OpSampledImage %30 %80 %45", + "%50 = OpImageSampleImplicitLod %v4float %49 %36", + "OpStore %color3 %50", + "%51 = OpLoad %v4float %color1", + "%52 = OpLoad %v4float %color2", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpLoad %v4float %color3", + "%55 = OpFAdd %v4float %53 %54", + "%56 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%57 = OpFDiv %v4float %55 %56", + "OpStore %FragColor %57", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) { + // #version 450 + // + // uniform texture2D t2D; + // uniform sampler samp; + // uniform sampler samp2; + // + // out vec4 FragColor; + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // float r = bar.x; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color1 = texture(sampler2D(t2D, samp), vec2(1.0)); + // vec4 color2 = vec4(foo(BaseColor)); + // vec4 color3 = texture(sampler2D(t2D, samp2), vec2(0.5)); + // FragColor = (color1 + color2 + color3)/3; + // } + // Note: the before SPIR-V will need to be edited to create an OpImage + // and subsequent OpSampledImage that is used across the function call. 
+ const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 450", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color1 \"color1\"", + "OpName %t2D \"t2D\"", + "OpName %samp \"samp\"", + "OpName %color2 \"color2\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %color3 \"color3\"", + "OpName %samp2 \"samp2\"", + "OpName %FragColor \"FragColor\"", + "OpDecorate %t2D DescriptorSet 0", + "OpDecorate %samp DescriptorSet 0", + "OpDecorate %samp2 DescriptorSet 0", + "%void = OpTypeVoid", + "%16 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%20 = OpTypeFunction %float %_ptr_Function_v4float", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", + "%26 = OpTypeImage %float 2D 0 0 0 1 Unknown", +"%_ptr_UniformConstant_26 = OpTypePointer UniformConstant %26", + "%t2D = OpVariable %_ptr_UniformConstant_26 UniformConstant", + "%28 = OpTypeSampler", +"%_ptr_UniformConstant_28 = OpTypePointer UniformConstant %28", + "%samp = OpVariable %_ptr_UniformConstant_28 UniformConstant", + "%30 = OpTypeSampledImage %26", + "%v2float = OpTypeVector %float 2", + "%float_1 = OpConstant %float 1", + "%33 = OpConstantComposite %v2float %float_1 %float_1", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", + "%samp2 = OpVariable %_ptr_UniformConstant_28 UniformConstant", + "%float_0_5 = OpConstant %float 0.5", + "%36 = OpConstantComposite %v2float %float_0_5 %float_0_5", 
+"%_ptr_Output_v4float = OpTypePointer Output %v4float", + "%FragColor = OpVariable %_ptr_Output_v4float Output", + "%float_3 = OpConstant %float 3", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %20", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%58 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%59 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%60 = OpLoad %float %59", + "OpStore %r %60", + "%61 = OpLoad %float %r", + "%62 = OpFOrdLessThan %bool %61 %float_0", + "OpSelectionMerge %63 None", + "OpBranchConditional %62 %64 %63", + "%64 = OpLabel", + "%65 = OpLoad %float %r", + "%66 = OpFNegate %float %65", + "OpStore %r %66", + "OpBranch %63", + "%63 = OpLabel", + "%67 = OpLoad %float %r", + "OpReturnValue %67", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %16", + "%39 = OpLabel", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%40 = OpLoad %26 %t2D", + "%41 = OpLoad %28 %samp", + "%42 = OpSampledImage %30 %40 %41", + "%43 = OpImageSampleImplicitLod %v4float %42 %33", + "%44 = OpImage %26 %42", + "%45 = OpLoad %28 %samp2", + "%46 = OpSampledImage %30 %44 %45", + "OpStore %color1 %43", + "%47 = OpLoad %v4float %BaseColor", + "OpStore %param %47", + "%48 = OpFunctionCall %float %foo_vf4_ %param", + "%49 = OpCompositeConstruct %v4float %48 %48 %48 %48", + "OpStore %color2 %49", + "%50 = OpImageSampleImplicitLod %v4float %46 %36", + "OpStore %color3 %50", + "%51 = OpLoad %v4float %color1", + "%52 = OpLoad %v4float %color2", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpLoad %v4float %color3", + "%55 = OpFAdd %v4float %53 %54", + "%56 = OpCompositeConstruct %v4float %float_3 
%float_3 %float_3 %float_3", + "%57 = OpFDiv %v4float %55 %56", + "OpStore %FragColor %57", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %16", + "%39 = OpLabel", + "%68 = OpVariable %_ptr_Function_float Function", + "%69 = OpVariable %_ptr_Function_float Function", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%40 = OpLoad %26 %t2D", + "%41 = OpLoad %28 %samp", + "%42 = OpSampledImage %30 %40 %41", + "%43 = OpImageSampleImplicitLod %v4float %42 %33", + "%44 = OpImage %26 %42", + "%45 = OpLoad %28 %samp2", + "%46 = OpSampledImage %30 %44 %45", + "OpStore %color1 %43", + "%47 = OpLoad %v4float %BaseColor", + "OpStore %param %47", + "%70 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%71 = OpLoad %float %70", + "OpStore %68 %71", + "%72 = OpLoad %float %68", + "%73 = OpFOrdLessThan %bool %72 %float_0", + "OpSelectionMerge %74 None", + "OpBranchConditional %73 %75 %74", + "%75 = OpLabel", + "%76 = OpLoad %float %68", + "%77 = OpFNegate %float %76", + "OpStore %68 %77", + "OpBranch %74", + "%74 = OpLabel", + "%78 = OpLoad %float %68", + "OpStore %69 %78", + "%48 = OpLoad %float %69", + "%49 = OpCompositeConstruct %v4float %48 %48 %48 %48", + "OpStore %color2 %49", + "%79 = OpSampledImage %30 %40 %41", + "%80 = OpImage %26 %79", + "%81 = OpSampledImage %30 %80 %45", + "%50 = OpImageSampleImplicitLod %v4float %81 %36", + "OpStore %color3 %50", + "%51 = OpLoad %v4float %color1", + "%52 = OpLoad %v4float %color2", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpLoad %v4float %color3", + "%55 = OpFAdd %v4float %53 %54", + "%56 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%57 = OpFDiv %v4float %55 %56", + "OpStore %FragColor %57", + "OpReturn", + 
"OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(Concat(Concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(Concat(Concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, EarlyReturnFunctionInlined) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // if (bar.x < 0.0) + // return 0.0; + // return bar.x; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_ "foo(vf4;" +OpName %bar "bar" +OpName %color "color" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%14 = OpTypeFunction %float %_ptr_Function_v4float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string nonEntryFuncs = + R"(%foo_vf4_ = OpFunction %float None %14 +%bar = OpFunctionParameter %_ptr_Function_v4float +%27 = OpLabel +%28 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%29 = OpLoad %float %28 +%30 = OpFOrdLessThan %bool %29 %float_0 +OpSelectionMerge %31 None +OpBranchConditional %30 %32 %31 +%32 = OpLabel +OpReturnValue %float_0 +%31 = OpLabel +%33 = OpAccessChain 
%_ptr_Function_float %bar %uint_0 +%34 = OpLoad %float %33 +OpReturnValue %34 +OpFunctionEnd +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%22 = OpLabel +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %param %23 +%24 = OpFunctionCall %float %foo_vf4_ %param +%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %color %25 +%26 = OpLoad %v4float %color +OpStore %gl_FragColor %26 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%false = OpConstantFalse %bool +%main = OpFunction %void None %10 +%22 = OpLabel +%35 = OpVariable %_ptr_Function_float Function +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %param %23 +OpBranch %36 +%36 = OpLabel +OpLoopMerge %37 %38 None +OpBranch %39 +%39 = OpLabel +%40 = OpAccessChain %_ptr_Function_float %param %uint_0 +%41 = OpLoad %float %40 +%42 = OpFOrdLessThan %bool %41 %float_0 +OpSelectionMerge %43 None +OpBranchConditional %42 %44 %43 +%44 = OpLabel +OpStore %35 %float_0 +OpBranch %37 +%43 = OpLabel +%45 = OpAccessChain %_ptr_Function_float %param %uint_0 +%46 = OpLoad %float %45 +OpStore %35 %46 +OpBranch %37 +%38 = OpLabel +OpBranchConditional %false %36 %37 +%37 = OpLabel +%24 = OpLoad %float %35 +%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %color %25 +%26 = OpLoad %v4float %color +OpStore %gl_FragColor %26 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before + nonEntryFuncs, + predefs + after + nonEntryFuncs, + false, true); +} + +TEST_F(InlineTest, EarlyReturnNotAppearingLastInFunctionInlined) { + // Example from https://github.com/KhronosGroup/SPIRV-Tools/issues/755 + // + // Original example is derived from: + // + // #version 450 + // + // float foo() { + // if (true) { + // } + // } + // + // void main() { foo(); } + // + // But the 
order of basic blocks in foo is changed so that the return + // block is listed second-last. There is only one return in the callee + // but it does not appear last. + + const std::string predefs = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 450 +OpName %main "main" +OpName %foo_ "foo(" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +)"; + + const std::string nonEntryFuncs = + R"(%foo_ = OpFunction %void None %4 +%7 = OpLabel +OpSelectionMerge %8 None +OpBranchConditional %true %9 %8 +%8 = OpLabel +OpReturn +%9 = OpLabel +OpBranch %8 +OpFunctionEnd +)"; + + const std::string before = + R"(%main = OpFunction %void None %4 +%10 = OpLabel +%11 = OpFunctionCall %void %foo_ +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %4 +%10 = OpLabel +OpSelectionMerge %12 None +OpBranchConditional %true %13 %12 +%12 = OpLabel +OpBranch %14 +%13 = OpLabel +OpBranch %12 +%14 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); +} + +TEST_F(InlineTest, ForwardReferencesInPhiInlined) { + // The basic structure of the test case is like this: + // + // int foo() { + // int result = 1; + // if (true) { + // result = 1; + // } + // return result; + // } + // + // void main() { + // int x = foo(); + // } + // + // but with modifications: Using Phi instead of load/store, and the + // return block in foo appears before the "then" block. 
+ + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpSource GLSL 450 +OpName %main "main" +OpName %foo_ "foo(" +OpName %x "x" +%void = OpTypeVoid +%6 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%8 = OpTypeFunction %int +%bool = OpTypeBool +%true = OpConstantTrue %bool +%int_0 = OpConstant %int 0 +%_ptr_Function_int = OpTypePointer Function %int +)"; + + const std::string nonEntryFuncs = + R"(%foo_ = OpFunction %int None %8 +%13 = OpLabel +%14 = OpCopyObject %int %int_0 +OpSelectionMerge %15 None +OpBranchConditional %true %16 %15 +%15 = OpLabel +%17 = OpPhi %int %14 %13 %18 %16 +OpReturnValue %17 +%16 = OpLabel +%18 = OpCopyObject %int %int_0 +OpBranch %15 +OpFunctionEnd +)"; + + const std::string before = + R"(%main = OpFunction %void None %6 +%19 = OpLabel +%x = OpVariable %_ptr_Function_int Function +%20 = OpFunctionCall %int %foo_ +OpStore %x %20 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %6 +%19 = OpLabel +%21 = OpVariable %_ptr_Function_int Function +%x = OpVariable %_ptr_Function_int Function +%22 = OpCopyObject %int %int_0 +OpSelectionMerge %23 None +OpBranchConditional %true %24 %23 +%23 = OpLabel +%26 = OpPhi %int %22 %19 %25 %24 +OpStore %21 %26 +OpBranch %27 +%24 = OpLabel +%25 = OpCopyObject %int %int_0 +OpBranch %23 +%27 = OpLabel +%20 = OpLoad %int %21 +OpStore %x %20 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); +} + +TEST_F(InlineTest, EarlyReturnInLoopIsNotInlined) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // while (true) { + // if (bar.x < 0.0) + // return 0.0; + // return bar.x; + // } + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + + const std::string assembly = + R"(OpCapability 
Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_ "foo(vf4;" +OpName %bar "bar" +OpName %color "color" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%14 = OpTypeFunction %float %_ptr_Function_v4float +%bool = OpTypeBool +%true = OpConstantTrue %bool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%23 = OpLabel +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%24 = OpLoad %v4float %BaseColor +OpStore %param %24 +%25 = OpFunctionCall %float %foo_vf4_ %param +%26 = OpCompositeConstruct %v4float %25 %25 %25 %25 +OpStore %color %26 +%27 = OpLoad %v4float %color +OpStore %gl_FragColor %27 +OpReturn +OpFunctionEnd +%foo_vf4_ = OpFunction %float None %14 +%bar = OpFunctionParameter %_ptr_Function_v4float +%28 = OpLabel +OpBranch %29 +%29 = OpLabel +OpLoopMerge %30 %31 None +OpBranch %32 +%32 = OpLabel +OpBranchConditional %true %33 %30 +%33 = OpLabel +%34 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%35 = OpLoad %float %34 +%36 = OpFOrdLessThan %bool %35 %float_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %37 +%38 = OpLabel +OpReturnValue %float_0 +%37 = OpLabel +%39 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%40 = OpLoad %float %39 +OpReturnValue %40 +%31 = OpLabel 
+OpBranch %29 +%30 = OpLabel +%41 = OpUndef %float +OpReturnValue %41 +OpFunctionEnd +)"; + + SinglePassRunAndCheck<InlineExhaustivePass>(assembly, assembly, false, true); +} + +TEST_F(InlineTest, ExternalFunctionIsNotInlined) { + // In particular, don't crash. + // See report https://github.com/KhronosGroup/SPIRV-Tools/issues/605 + const std::string assembly = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %1 "entry_pt" +OpDecorate %2 LinkageAttributes "external" Import +%void = OpTypeVoid +%4 = OpTypeFunction %void +%2 = OpFunction %void None %4 +OpFunctionEnd +%1 = OpFunction %void None %4 +%5 = OpLabel +%6 = OpFunctionCall %void %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck<InlineExhaustivePass>(assembly, assembly, false, true); +} + +TEST_F(InlineTest, SingleBlockLoopCallsMultiBlockCallee) { + // Example from https://github.com/KhronosGroup/SPIRV-Tools/issues/787 + // + // CFG structure is: + // foo: + // fooentry -> fooexit + // + // main: + // entry -> loop + // loop -> loop, merge + // loop calls foo() + // merge + // + // Since the callee has multiple blocks, it will split the calling block + // into at least two, resulting in a new "back-half" block that contains + // the instructions after the inlined function call. If the calling block + // has an OpLoopMerge that points back to the calling block itself, then + // the OpLoopMerge can't remain in the back-half block, but must be + // moved to the end of the original calling block, and its continue target + // operand updated to point to the back-half block.
+ + const std::string predefs = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource OpenCL_C 120 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%void = OpTypeVoid +)"; + + const std::string nonEntryFuncs = + R"(%5 = OpTypeFunction %void +%6 = OpFunction %void None %5 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%1 = OpFunction %void None %5 +%9 = OpLabel +OpBranch %10 +%10 = OpLabel +%11 = OpFunctionCall %void %6 +OpLoopMerge %12 %10 None +OpBranchConditional %true %10 %12 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%1 = OpFunction %void None %5 +%9 = OpLabel +OpBranch %10 +%10 = OpLabel +OpLoopMerge %12 %13 None +OpBranch %13 +%13 = OpLabel +OpBranchConditional %true %10 %12 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); +} + +TEST_F(InlineTest, MultiBlockLoopHeaderCallsMultiBlockCallee) { + // Like SingleBlockLoopCallsMultiBlockCallee but the loop has several + // blocks, but the function call still occurs in the loop header. 
+ // Example from https://github.com/KhronosGroup/SPIRV-Tools/issues/800 + + const std::string predefs = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource OpenCL_C 120 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_3 = OpConstant %int 3 +%int_4 = OpConstant %int 4 +%int_5 = OpConstant %int 5 +%void = OpTypeVoid +%11 = OpTypeFunction %void +)"; + + const std::string nonEntryFuncs = + R"(%12 = OpFunction %void None %11 +%13 = OpLabel +%14 = OpCopyObject %int %int_1 +OpBranch %15 +%15 = OpLabel +%16 = OpCopyObject %int %int_2 +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%1 = OpFunction %void None %11 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpCopyObject %int %int_3 +%20 = OpFunctionCall %void %12 +%21 = OpCopyObject %int %int_4 +OpLoopMerge %22 %23 None +OpBranchConditional %true %23 %22 +%23 = OpLabel +%24 = OpCopyObject %int %int_5 +OpBranchConditional %true %18 %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%1 = OpFunction %void None %11 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpCopyObject %int %int_3 +%25 = OpCopyObject %int %int_1 +OpLoopMerge %22 %23 None +OpBranch %26 +%26 = OpLabel +%27 = OpCopyObject %int %int_2 +%21 = OpCopyObject %int %int_4 +OpBranchConditional %true %23 %22 +%23 = OpLabel +%24 = OpCopyObject %int %int_5 +OpBranchConditional %true %18 %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck<InlineExhaustivePass>(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); +} + +TEST_F(InlineTest, SingleBlockLoopCallsMultiBlockCalleeHavingSelectionMerge) { + // This is similar to SingleBlockLoopCallsMultiBlockCallee except + // that the callee also has a merge instruction in its first block.
+ // That merge instruction must be an OpSelectionMerge (because the entry + // block of a function can't be the header of a loop since the entry + // block can't be the target of a branch). + // + // In this case the OpLoopMerge can't be placed in the same block as + // the OpSelectionMerge, so inlining must create a new block to contain + // the callee contents. + // + // Additionally, we have two dummy OpCopyObject instructions to prove that + // the OpLoopMerge is moved to the right location. + // + // Also ensure that OpPhis within the cloned callee code are valid. + // We need to test that the predecessor blocks are remapped correctly so that + // dominance rules are satisfied + + const std::string predefs = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource OpenCL_C 120 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%void = OpTypeVoid +%6 = OpTypeFunction %void +)"; + + // This callee has multiple blocks, and an OpPhi in the last block + // that references a value from the first block. This tests that + // cloned block IDs are remapped appropriately. The OpPhi dominance + // requires that the remapped %9 must be in a block that dominates + // the remapped %8. + const std::string nonEntryFuncs = + R"(%7 = OpFunction %void None %6 +%8 = OpLabel +%9 = OpCopyObject %bool %true +OpSelectionMerge %10 None +OpBranchConditional %true %10 %10 +%10 = OpLabel +%11 = OpPhi %bool %9 %8 +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%1 = OpFunction %void None %6 +%12 = OpLabel +OpBranch %13 +%13 = OpLabel +%14 = OpCopyObject %bool %false +%15 = OpFunctionCall %void %7 +OpLoopMerge %16 %13 None +OpBranchConditional %true %13 %16 +%16 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // Note the remapped Phi uses %17 as the parent instead + // of %13, demonstrating that the parent block has been remapped + // correctly. 
+ const std::string after = + R"(%1 = OpFunction %void None %6 +%12 = OpLabel +OpBranch %13 +%13 = OpLabel +%14 = OpCopyObject %bool %false +OpLoopMerge %16 %19 None +OpBranch %17 +%17 = OpLabel +%18 = OpCopyObject %bool %true +OpSelectionMerge %19 None +OpBranchConditional %true %19 %19 +%19 = OpLabel +%20 = OpPhi %bool %18 %17 +OpBranchConditional %true %13 %16 +%16 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); +} + +TEST_F(InlineTest, + MultiBlockLoopHeaderCallsFromToMultiBlockCalleeHavingSelectionMerge) { + // This is similar to SingleBlockLoopCallsMultiBlockCalleeHavingSelectionMerge + // but the call is in the header block of a multi block loop. + + const std::string predefs = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource OpenCL_C 120 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_3 = OpConstant %int 3 +%int_4 = OpConstant %int 4 +%int_5 = OpConstant %int 5 +%void = OpTypeVoid +%11 = OpTypeFunction %void +)"; + + const std::string nonEntryFuncs = + R"(%12 = OpFunction %void None %11 +%13 = OpLabel +%14 = OpCopyObject %int %int_1 +OpSelectionMerge %15 None +OpBranchConditional %true %15 %15 +%15 = OpLabel +%16 = OpCopyObject %int %int_2 +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%1 = OpFunction %void None %11 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpCopyObject %int %int_3 +%20 = OpFunctionCall %void %12 +%21 = OpCopyObject %int %int_4 +OpLoopMerge %22 %23 None +OpBranchConditional %true %23 %22 +%23 = OpLabel +%24 = OpCopyObject %int %int_5 +OpBranchConditional %true %18 %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%1 = OpFunction %void None %11 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpCopyObject %int %int_3 +OpLoopMerge %22 %23 None 
+OpBranch %25 +%25 = OpLabel +%26 = OpCopyObject %int %int_1 +OpSelectionMerge %27 None +OpBranchConditional %true %27 %27 +%27 = OpLabel +%28 = OpCopyObject %int %int_2 +%21 = OpCopyObject %int %int_4 +OpBranchConditional %true %23 %22 +%23 = OpLabel +%24 = OpCopyObject %int %int_5 +OpBranchConditional %true %18 %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); +} + +TEST_F( + InlineTest, + SingleBlockLoopCallsMultiBlockCalleeHavingSelectionMergeAndMultiReturns) { + // This is similar to SingleBlockLoopCallsMultiBlockCalleeHavingSelectionMerge + // except that in addition to starting with a selection header, the + // callee also has multi returns. + // + // So now we have to accommodate: + // - The caller's OpLoopMerge (which must move to the first block) + // - The single-trip loop to wrap the multi returns, and + // - The callee's selection merge in its first block. + // Each of these must go into their own blocks. 
+ + const std::string predefs = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource OpenCL_C 120 +%bool = OpTypeBool +%int = OpTypeInt 32 1 +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_3 = OpConstant %int 3 +%int_4 = OpConstant %int 4 +%void = OpTypeVoid +%12 = OpTypeFunction %void +)"; + + const std::string nonEntryFuncs = + R"(%13 = OpFunction %void None %12 +%14 = OpLabel +%15 = OpCopyObject %int %int_0 +OpReturn +%16 = OpLabel +%17 = OpCopyObject %int %int_1 +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%1 = OpFunction %void None %12 +%18 = OpLabel +OpBranch %19 +%19 = OpLabel +%20 = OpCopyObject %int %int_2 +%21 = OpFunctionCall %void %13 +%22 = OpCopyObject %int %int_3 +OpLoopMerge %23 %19 None +OpBranchConditional %true %19 %23 +%23 = OpLabel +%24 = OpCopyObject %int %int_4 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%1 = OpFunction %void None %12 +%18 = OpLabel +OpBranch %19 +%19 = OpLabel +%20 = OpCopyObject %int %int_2 +%25 = OpCopyObject %int %int_0 +OpLoopMerge %23 %26 None +OpBranch %26 +%27 = OpLabel +%28 = OpCopyObject %int %int_1 +OpBranch %26 +%26 = OpLabel +%22 = OpCopyObject %int %int_3 +OpBranchConditional %true %19 %23 +%23 = OpLabel +%24 = OpCopyObject %int %int_4 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); +} + +TEST_F(InlineTest, CalleeWithMultiReturnAndPhiRequiresEntryBlockRemapping) { + // The case from https://github.com/KhronosGroup/SPIRV-Tools/issues/790 + // + // The callee has multiple returns, and so must be wrapped with a single-trip + // loop. That code must remap the callee entry block ID to the introduced + // loop body's ID. Otherwise you can get a dominance error in a cloned OpPhi. 
+ + const std::string predefs = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource OpenCL_C 120 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_3 = OpConstant %int 3 +%int_4 = OpConstant %int 4 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%bool = OpTypeBool +%false = OpConstantFalse %bool +)"; + + // This callee has multiple returns, and a Phi in the second block referencing + // a value generated in the entry block. + const std::string nonEntryFuncs = + R"(%12 = OpFunction %void None %9 +%13 = OpLabel +%14 = OpCopyObject %int %int_0 +OpBranch %15 +%15 = OpLabel +%16 = OpPhi %int %14 %13 +%17 = OpCopyObject %int %int_1 +OpReturn +%18 = OpLabel +%19 = OpCopyObject %int %int_2 +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%1 = OpFunction %void None %9 +%20 = OpLabel +%21 = OpCopyObject %int %int_3 +%22 = OpFunctionCall %void %12 +%23 = OpCopyObject %int %int_4 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%1 = OpFunction %void None %9 +%20 = OpLabel +%21 = OpCopyObject %int %int_3 +%24 = OpCopyObject %int %int_0 +OpBranch %25 +%25 = OpLabel +%26 = OpPhi %int %24 %20 +%27 = OpCopyObject %int %int_1 +OpBranch %28 +%29 = OpLabel +%30 = OpCopyObject %int %int_2 +OpBranch %28 +%28 = OpLabel +%23 = OpCopyObject %int %int_4 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + nonEntryFuncs + before, + predefs + nonEntryFuncs + after, + false, true); +} + +TEST_F(InlineTest, NonInlinableCalleeWithSingleReturn) { + // The case from https://github.com/KhronosGroup/SPIRV-Tools/issues/2018 + // + // The callee has a single return, but cannot be inlined because the + // return is inside a loop. 
+ + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %_GLF_color +OpExecutionMode %main OriginUpperLeft +OpSource ESSL 310 +OpName %main "main" +OpName %f_ "f(" +OpName %i "i" +OpName %_GLF_color "_GLF_color" +OpDecorate %_GLF_color Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%9 = OpTypeFunction %float +%float_1 = OpConstant %float 1 +%bool = OpTypeBool +%false = OpConstantFalse %bool +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_GLF_color = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +%20 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%21 = OpConstantComposite %v4float %float_0 %float_1 %float_0 %float_1 +)"; + + const std::string caller = + R"(%main = OpFunction %void None %7 +%22 = OpLabel +%i = OpVariable %_ptr_Function_int Function +OpStore %i %int_0 +OpBranch %23 +%23 = OpLabel +OpLoopMerge %24 %25 None +OpBranch %26 +%26 = OpLabel +%27 = OpLoad %int %i +%28 = OpSLessThan %bool %27 %int_1 +OpBranchConditional %28 %29 %24 +%29 = OpLabel +OpStore %_GLF_color %20 +%30 = OpFunctionCall %float %f_ +OpBranch %25 +%25 = OpLabel +%31 = OpLoad %int %i +%32 = OpIAdd %int %31 %int_1 +OpStore %i %32 +OpBranch %23 +%24 = OpLabel +OpStore %_GLF_color %21 +OpReturn +OpFunctionEnd +)"; + + const std::string callee = + R"(%f_ = OpFunction %float None %9 +%33 = OpLabel +OpBranch %34 +%34 = OpLabel +OpLoopMerge %35 %36 None +OpBranch %37 +%37 = OpLabel +OpReturnValue %float_1 +%36 = OpLabel +OpBranch %34 +%35 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + caller + callee, predefs + caller + callee, false, true); +} + +TEST_F(InlineTest, 
CalleeWithSingleReturnNeedsSingleTripLoopWrapper) { + // The case from https://github.com/KhronosGroup/SPIRV-Tools/issues/2018 + // + // The callee has a single return, but needs single-trip loop wrapper + // to be inlined because the return is in a selection structure. + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %_GLF_color +OpExecutionMode %main OriginUpperLeft +OpSource ESSL 310 +OpName %main "main" +OpName %f_ "f(" +OpName %i "i" +OpName %_GLF_color "_GLF_color" +OpDecorate %_GLF_color Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%9 = OpTypeFunction %float +%float_1 = OpConstant %float 1 +%bool = OpTypeBool +%false = OpConstantFalse %bool +%true = OpConstantTrue %bool +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_GLF_color = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +%21 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%22 = OpConstantComposite %v4float %float_0 %float_1 %float_0 %float_1 +)"; + + const std::string new_predefs = + R"(%_ptr_Function_float = OpTypePointer Function %float +)"; + + const std::string main_before = + R"(%main = OpFunction %void None %7 +%23 = OpLabel +%i = OpVariable %_ptr_Function_int Function +OpStore %i %int_0 +OpBranch %24 +%24 = OpLabel +OpLoopMerge %25 %26 None +OpBranch %27 +%27 = OpLabel +%28 = OpLoad %int %i +%29 = OpSLessThan %bool %28 %int_1 +OpBranchConditional %29 %30 %25 +%30 = OpLabel +OpStore %_GLF_color %21 +%31 = OpFunctionCall %float %f_ +OpBranch %26 +%26 = OpLabel +%32 = OpLoad %int %i +%33 = OpIAdd %int %32 %int_1 +OpStore %i %33 +OpBranch %24 +%25 = OpLabel +OpStore %_GLF_color %22 +OpReturn +OpFunctionEnd +)"; + + const std::string 
main_after = + R"(%main = OpFunction %void None %7 +%23 = OpLabel +%38 = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %i %int_0 +OpBranch %24 +%24 = OpLabel +OpLoopMerge %25 %26 None +OpBranch %27 +%27 = OpLabel +%28 = OpLoad %int %i +%29 = OpSLessThan %bool %28 %int_1 +OpBranchConditional %29 %30 %25 +%30 = OpLabel +OpStore %_GLF_color %21 +OpBranch %39 +%39 = OpLabel +OpLoopMerge %40 %41 None +OpBranch %42 +%42 = OpLabel +OpSelectionMerge %43 None +OpBranchConditional %true %44 %43 +%44 = OpLabel +OpStore %38 %float_1 +OpBranch %40 +%43 = OpLabel +OpUnreachable +%41 = OpLabel +OpBranchConditional %false %39 %40 +%40 = OpLabel +%31 = OpLoad %float %38 +OpBranch %26 +%26 = OpLabel +%32 = OpLoad %int %i +%33 = OpIAdd %int %32 %int_1 +OpStore %i %33 +OpBranch %24 +%25 = OpLabel +OpStore %_GLF_color %22 +OpReturn +OpFunctionEnd +)"; + + const std::string callee = + R"(%f_ = OpFunction %float None %9 +%34 = OpLabel +OpSelectionMerge %35 None +OpBranchConditional %true %36 %35 +%36 = OpLabel +OpReturnValue %float_1 +%35 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + main_before + callee, + predefs + new_predefs + main_after + callee, false, true); +} + +TEST_F(InlineTest, Decorated1) { + // Same test as Simple with the difference + // that OpFAdd in the outlined function is + // decorated with RelaxedPrecision + // Expected result is an equal decoration + // of the corresponding inlined instruction + // + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // return bar.x + bar.y; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" 
+OpName %foo_vf4_ "foo(vf4;" +OpName %bar "bar" +OpName %color "color" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %9 RelaxedPrecision +)"; + + const std::string before = + R"(%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%15 = OpTypeFunction %float %_ptr_Function_v4float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %11 +%22 = OpLabel +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %param %23 +%24 = OpFunctionCall %float %foo_vf4_ %param +%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %color %25 +%26 = OpLoad %v4float %color +OpStore %gl_FragColor %26 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpDecorate %37 RelaxedPrecision +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%15 = OpTypeFunction %float %_ptr_Function_v4float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %11 +%22 = OpLabel +%32 = OpVariable %_ptr_Function_float Function +%color = OpVariable %_ptr_Function_v4float 
Function +%param = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %param %23 +%33 = OpAccessChain %_ptr_Function_float %param %uint_0 +%34 = OpLoad %float %33 +%35 = OpAccessChain %_ptr_Function_float %param %uint_1 +%36 = OpLoad %float %35 +%37 = OpFAdd %float %34 %36 +OpStore %32 %37 +%24 = OpLoad %float %32 +%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %color %25 +%26 = OpLoad %v4float %color +OpStore %gl_FragColor %26 +OpReturn +OpFunctionEnd +)"; + + const std::string nonEntryFuncs = + R"(%foo_vf4_ = OpFunction %float None %15 +%bar = OpFunctionParameter %_ptr_Function_v4float +%27 = OpLabel +%28 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%29 = OpLoad %float %28 +%30 = OpAccessChain %_ptr_Function_float %bar %uint_1 +%31 = OpLoad %float %30 +%9 = OpFAdd %float %29 %31 +OpReturnValue %9 +OpFunctionEnd +)"; + SinglePassRunAndCheck(predefs + before + nonEntryFuncs, + predefs + after + nonEntryFuncs, + false, true); +} + +TEST_F(InlineTest, Decorated2) { + // Same test as Simple with the difference + // that the Result of the outlined OpFunction + // is decorated with RelaxedPrecision + // Expected result is an equal decoration + // of the created return variable + // + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // return bar.x + bar.y; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_ "foo(vf4;" +OpName %bar "bar" +OpName %color "color" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %foo_vf4_ RelaxedPrecision +)"; + + const std::string before = + R"(%void = OpTypeVoid 
+%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%14 = OpTypeFunction %float %_ptr_Function_v4float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%21 = OpLabel +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %BaseColor +OpStore %param %22 +%23 = OpFunctionCall %float %foo_vf4_ %param +%24 = OpCompositeConstruct %v4float %23 %23 %23 %23 +OpStore %color %24 +%25 = OpLoad %v4float %color +OpStore %gl_FragColor %25 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpDecorate %32 RelaxedPrecision +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%14 = OpTypeFunction %float %_ptr_Function_v4float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%uint_1 = OpConstant %uint 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%21 = OpLabel +%32 = OpVariable %_ptr_Function_float Function +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %BaseColor +OpStore %param %22 +%33 = OpAccessChain %_ptr_Function_float %param %uint_0 +%34 = OpLoad %float %33 +%35 = OpAccessChain %_ptr_Function_float %param %uint_1 
+%36 = OpLoad %float %35 +%37 = OpFAdd %float %34 %36 +OpStore %32 %37 +%23 = OpLoad %float %32 +%24 = OpCompositeConstruct %v4float %23 %23 %23 %23 +OpStore %color %24 +%25 = OpLoad %v4float %color +OpStore %gl_FragColor %25 +OpReturn +OpFunctionEnd +)"; + + const std::string nonEntryFuncs = + R"(%foo_vf4_ = OpFunction %float None %14 +%bar = OpFunctionParameter %_ptr_Function_v4float +%26 = OpLabel +%27 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%28 = OpLoad %float %27 +%29 = OpAccessChain %_ptr_Function_float %bar %uint_1 +%30 = OpLoad %float %29 +%31 = OpFAdd %float %28 %30 +OpReturnValue %31 +OpFunctionEnd +)"; + SinglePassRunAndCheck(predefs + before + nonEntryFuncs, + predefs + after + nonEntryFuncs, + false, true); +} + +TEST_F(InlineTest, DeleteName) { + // Test that the name of the result id of the call is deleted. + const std::string before = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpName %main "main" + OpName %main_entry "main_entry" + OpName %foo_result "foo_result" + OpName %void_fn "void_fn" + OpName %foo "foo" + OpName %foo_entry "foo_entry" + %void = OpTypeVoid + %void_fn = OpTypeFunction %void + %foo = OpFunction %void None %void_fn + %foo_entry = OpLabel + OpReturn + OpFunctionEnd + %main = OpFunction %void None %void_fn + %main_entry = OpLabel + %foo_result = OpFunctionCall %void %foo + OpReturn + OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" +OpName %main "main" +OpName %main_entry "main_entry" +OpName %void_fn "void_fn" +OpName %foo "foo" +OpName %foo_entry "foo_entry" +%void = OpTypeVoid +%void_fn = OpTypeFunction %void +%foo = OpFunction %void None %void_fn +%foo_entry = OpLabel +OpReturn +OpFunctionEnd +%main = OpFunction %void None %void_fn +%main_entry = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, false, true); +} + +TEST_F(InlineTest, SetParent) { 
+ // Test that after inlining all basic blocks have the correct parent. + const std::string text = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpName %main "main" + OpName %main_entry "main_entry" + OpName %foo_result "foo_result" + OpName %void_fn "void_fn" + OpName %foo "foo" + OpName %foo_entry "foo_entry" + %void = OpTypeVoid + %void_fn = OpTypeFunction %void + %foo = OpFunction %void None %void_fn + %foo_entry = OpLabel + OpReturn + OpFunctionEnd + %main = OpFunction %void None %void_fn + %main_entry = OpLabel + %foo_result = OpFunctionCall %void %foo + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + InlineExhaustivePass pass; + pass.Run(context.get()); + + for (Function& func : *context->module()) { + for (BasicBlock& bb : func) { + EXPECT_TRUE(bb.GetParent() == &func); + } + } +} + +TEST_F(InlineTest, OpKill) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%voidfuncty = OpTypeFunction %void +%main = OpFunction %void None %voidfuncty +%1 = OpLabel +%2 = OpFunctionCall %void %func +OpReturn +OpFunctionEnd +%func = OpFunction %void None %voidfuncty +%3 = OpLabel +OpKill +OpFunctionEnd +)"; + + SinglePassRunAndMatch<InlineExhaustivePass>(text, true); +} + +TEST_F(InlineTest, OpKillWithTrailingInstructions) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: [[var:%\w+]] = OpVariable +; CHECK-NEXT: OpKill +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpStore [[var]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid
+%bool = OpTypeBool +%true = OpConstantTrue %bool +%bool_func_ptr = OpTypePointer Function %bool +%voidfuncty = OpTypeFunction %void +%main = OpFunction %void None %voidfuncty +%1 = OpLabel +%2 = OpVariable %bool_func_ptr Function +%3 = OpFunctionCall %void %func +OpStore %2 %true +OpReturn +OpFunctionEnd +%func = OpFunction %void None %voidfuncty +%4 = OpLabel +OpKill +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(InlineTest, OpKillInIf) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NEXT: [[ld:%\w+]] = OpLoad {{%\w+}} [[var]] +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpLoopMerge [[loop_merge:%\w+]] [[continue:%\w+]] None +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpSelectionMerge [[sel_merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[kill_label:%\w+]] [[label:%\w+]] +; CHECK-NEXT: [[kill_label]] = OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK-NEXT: [[sel_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[loop_merge]] +; CHECK-NEXT: [[continue]] = OpLabel +; CHECK-NEXT: OpBranchConditional +; CHECK-NEXT: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[ld]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%bool_func_ptr = OpTypePointer Function %bool +%voidfuncty = OpTypeFunction %void +%main = OpFunction %void None %voidfuncty +%1 = OpLabel +%2 = OpVariable %bool_func_ptr Function +%3 = OpLoad %bool %2 +%4 = OpFunctionCall %void %func +OpStore %2 %3 +OpReturn +OpFunctionEnd +%func = OpFunction %void None %voidfuncty +%5 = OpLabel +OpSelectionMerge %6 None +OpBranchConditional %true %7 %8 +%7 = 
OpLabel +OpKill +%8 = OpLabel +OpReturn +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(InlineTest, OpKillInLoop) { + const std::string text = R"( +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NEXT: [[ld:%\w+]] = OpLoad {{%\w+}} [[var]] +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK-NEXT: [[loop]] = OpLabel +; CHECK-NEXT: OpLoopMerge [[loop_merge:%\w+]] [[continue:%\w+]] None +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpKill +; CHECK-NEXT: [[loop_merge]] = OpLabel +; CHECK-NEXT: OpBranch [[label:%\w+]] +; CHECK-NEXT: [[continue]] = OpLabel +; CHECK-NEXT: OpBranch [[loop]] +; CHECK-NEXT: [[label]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[ld]] +; CHECK-NEXT: OpReturn +; CHECK-NEXT: OpFunctionEnd +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%voidfuncty = OpTypeFunction %void +%bool_func_ptr = OpTypePointer Function %bool +%main = OpFunction %void None %voidfuncty +%1 = OpLabel +%2 = OpVariable %bool_func_ptr Function +%3 = OpLoad %bool %2 +%4 = OpFunctionCall %void %func +OpStore %2 %3 +OpReturn +OpFunctionEnd +%func = OpFunction %void None %voidfuncty +%5 = OpLabel +OpBranch %10 +%10 = OpLabel +OpLoopMerge %6 %7 None +OpBranch %8 +%8 = OpLabel +OpKill +%6 = OpLabel +OpReturn +%7 = OpLabel +OpBranch %10 +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(InlineTest, OpVariableWithInit) { + // Check that there is a store that corresponds to the initializer. This + // test makes sure that there is a store to the variable in the loop and before any + // load. 
+ const std::string text = R"( +; CHECK: OpFunction +; CHECK-NOT: OpFunctionEnd +; CHECK: [[var:%\w+]] = OpVariable %_ptr_Function_float Function %float_0 +; CHECK: OpLoopMerge [[outer_merge:%\w+]] +; CHECK-NOT: OpLoad %float [[var]] +; CHECK: OpStore [[var]] %float_0 +; CHECK: OpFunctionEnd + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %o + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpDecorate %o Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %7 = OpTypeFunction %float +%_ptr_Function_float = OpTypePointer Function %float + %float_0 = OpConstant %float 0 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 +%_ptr_Output_float = OpTypePointer Output %float + %o = OpVariable %_ptr_Output_float Output + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_int = OpTypePointer Input %int + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 + %main = OpFunction %void None %3 + %5 = OpLabel + OpStore %o %float_0 + OpBranch %34 + %34 = OpLabel + %39 = OpPhi %int %int_0 %5 %47 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %41 = OpSLessThan %bool %39 %int_2 + OpBranchConditional %41 %35 %36 + %35 = OpLabel + %42 = OpFunctionCall %float %foo_ + %43 = OpLoad %float %o + %44 = OpFAdd %float %43 %42 + OpStore %o %44 + OpBranch %37 + %37 = OpLabel + %47 = OpIAdd %int %39 %int_1 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + %foo_ = OpFunction %float None %7 + %9 = OpLabel + %n = OpVariable %_ptr_Function_float Function %float_0 + %13 = OpLoad %float %n + %15 = OpFOrdEqual %bool %13 %float_0 + OpSelectionMerge %17 None + OpBranchConditional %15 %16 %17 + %16 = OpLabel + %19 = OpLoad %float %n + %20 = OpFAdd %float %19 %float_1 + OpStore %n %20 + OpBranch %17 + %17 = OpLabel + %21 = OpLoad %float %n + OpReturnValue %21 + OpFunctionEnd +)"; + + 
SinglePassRunAndMatch(text, true); +} + +TEST_F(InlineTest, DontInlineDirectlyRecursiveFunc) { + // Test that a call to a directly recursive function is not inlined. + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +OpDecorate %2 DescriptorSet 439418829 +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%1 = OpFunction %void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %9 +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndCheck(test, test, false, true); +} + +TEST_F(InlineTest, DontInlineInDirectlyRecursiveFunc) { + // Test that calls to indirectly (mutually) recursive functions are not inlined. + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +OpDecorate %2 DescriptorSet 439418829 +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%1 = OpFunction %void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %12 +OpUnreachable +OpFunctionEnd +%12 = OpFunction %_struct_6 None %7 +%13 = OpLabel +%14 = OpFunctionCall %_struct_6 %9 +OpUnreachable +OpFunctionEnd +)"; + + SinglePassRunAndCheck(test, test, false, true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// Empty modules +// Modules without function definitions +// Modules in which all functions do not call other functions +// Caller and callee both accessing the same global variable +// Functions with OpLine & OpNoLine +// Others? 
+ +// TODO(dneto): Test suggestions from code review +// https://github.com/KhronosGroup/SPIRV-Tools/pull/534 +// +// Callee function returns a value generated outside the callee, +// e.g. a constant value. This might exercise some logic not yet +// exercised by the current tests: the false branch in the "if" +// inside the SpvOpReturnValue case in InlinePass::GenInlineCode? +// SampledImage before function call, but callee is only single block. +// Then the SampledImage instruction is not cloned. Documents existing +// behaviour. +// SampledImage after function call. It is not cloned or changed. + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/insert_extract_elim_test.cpp b/test/opt/insert_extract_elim_test.cpp new file mode 100644 index 000000000..c5169750b --- /dev/null +++ b/test/opt/insert_extract_elim_test.cpp @@ -0,0 +1,900 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "source/opt/simplification_pass.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using InsertExtractElimTest = PassTest<::testing::Test>; + +TEST_F(InsertExtractElimTest, Simple) { + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. 
+ // + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpLoad %S_t %s0 +%20 = OpCompositeInsert %S_t %18 %19 1 +OpStore %s0 %20 +%21 = OpCompositeExtract %v4float %20 1 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpLoad %S_t %s0 +%20 = OpCompositeInsert %S_t %18 %19 1 +OpStore %s0 %20 +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(InsertExtractElimTest, OptimizeAcrossNonConflictingInsert) { + // Note: The SPIR-V assembly has had store/load 
elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // s0.v0[2] = 0.0; + // gl_FragColor = s0.v1; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%float_0 = OpConstant %float 0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%18 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%19 = OpLoad %v4float %BaseColor +%20 = OpLoad %S_t %s0 +%21 = OpCompositeInsert %S_t %19 %20 1 +%22 = OpCompositeInsert %S_t %float_0 %21 0 2 +OpStore %s0 %22 +%23 = OpCompositeExtract %v4float %22 1 +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%18 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%19 = OpLoad %v4float %BaseColor +%20 = OpLoad %S_t %s0 +%21 = OpCompositeInsert %S_t %19 %20 1 +%22 = OpCompositeInsert %S_t %float_0 %21 0 2 +OpStore %s0 %22 +OpStore 
%gl_FragColor %19 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(InsertExtractElimTest, OptimizeOpaque) { + // SPIR-V not representable in GLSL; not generatable from HLSL + // for the moment. + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %outColor %texCoords +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpMemberName %S_t 2 "smp" +OpName %outColor "outColor" +OpName %sampler15 "sampler15" +OpName %s0 "s0" +OpName %texCoords "texCoords" +OpDecorate %sampler15 DescriptorSet 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outColor = OpVariable %_ptr_Output_v4float Output +%14 = OpTypeImage %float 2D 0 0 0 1 Unknown +%15 = OpTypeSampledImage %14 +%S_t = OpTypeStruct %v2float %v2float %15 +%_ptr_Function_S_t = OpTypePointer Function %S_t +%17 = OpTypeFunction %void %_ptr_Function_S_t +%_ptr_UniformConstant_15 = OpTypePointer UniformConstant %15 +%_ptr_Function_15 = OpTypePointer Function %15 +%sampler15 = OpVariable %_ptr_UniformConstant_15 UniformConstant +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%texCoords = OpVariable %_ptr_Input_v2float Input +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%25 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%26 = OpLoad %v2float %texCoords +%27 = OpLoad %S_t %s0 +%28 = OpCompositeInsert %S_t %26 %27 0 +%29 = OpLoad %15 %sampler15 +%30 = OpCompositeInsert %S_t %29 %28 2 +OpStore %s0 %30 +%31 = OpCompositeExtract %15 %30 
2 +%32 = OpCompositeExtract %v2float %30 0 +%33 = OpImageSampleImplicitLod %v4float %31 %32 +OpStore %outColor %33 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%25 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%26 = OpLoad %v2float %texCoords +%27 = OpLoad %S_t %s0 +%28 = OpCompositeInsert %S_t %26 %27 0 +%29 = OpLoad %15 %sampler15 +%30 = OpCompositeInsert %S_t %29 %28 2 +OpStore %s0 %30 +%33 = OpImageSampleImplicitLod %v4float %29 %26 +OpStore %outColor %33 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(InsertExtractElimTest, OptimizeNestedStruct) { + // The following HLSL has been pre-optimized to get the SPIR-V: + // struct S0 + // { + // int x; + // SamplerState ss; + // }; + // + // struct S1 + // { + // float b; + // S0 s0; + // }; + // + // struct S2 + // { + // int a1; + // S1 resources; + // }; + // + // SamplerState samp; + // Texture2D tex; + // + // float4 main(float4 vpos : VPOS) : COLOR0 + // { + // S1 s1; + // S2 s2; + // s1.s0.ss = samp; + // s2.resources = s1; + // return tex.Sample(s2.resources.s0.ss, float2(0.5)); + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %_entryPointOutput +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 500 +OpName %main "main" +OpName %S0 "S0" +OpMemberName %S0 0 "x" +OpMemberName %S0 1 "ss" +OpName %S1 "S1" +OpMemberName %S1 0 "b" +OpMemberName %S1 1 "s0" +OpName %samp "samp" +OpName %S2 "S2" +OpMemberName %S2 0 "a1" +OpMemberName %S2 1 "resources" +OpName %tex "tex" +OpName %_entryPointOutput "@entryPointOutput" +OpDecorate %samp DescriptorSet 0 +OpDecorate %tex DescriptorSet 0 +OpDecorate %_entryPointOutput Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = 
OpTypePointer Function %v4float +%14 = OpTypeFunction %v4float %_ptr_Function_v4float +%int = OpTypeInt 32 1 +%16 = OpTypeSampler +%S0 = OpTypeStruct %int %16 +%S1 = OpTypeStruct %float %S0 +%_ptr_Function_S1 = OpTypePointer Function %S1 +%int_1 = OpConstant %int 1 +%_ptr_UniformConstant_16 = OpTypePointer UniformConstant %16 +%samp = OpVariable %_ptr_UniformConstant_16 UniformConstant +%_ptr_Function_16 = OpTypePointer Function %16 +%S2 = OpTypeStruct %int %S1 +%_ptr_Function_S2 = OpTypePointer Function %S2 +%22 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_22 = OpTypePointer UniformConstant %22 +%tex = OpVariable %_ptr_UniformConstant_22 UniformConstant +%24 = OpTypeSampledImage %22 +%v2float = OpTypeVector %float 2 +%float_0_5 = OpConstant %float 0.5 +%27 = OpConstantComposite %v2float %float_0_5 %float_0_5 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%30 = OpLabel +%31 = OpVariable %_ptr_Function_S1 Function +%32 = OpVariable %_ptr_Function_S2 Function +%33 = OpLoad %16 %samp +%34 = OpLoad %S1 %31 +%35 = OpCompositeInsert %S1 %33 %34 1 1 +OpStore %31 %35 +%36 = OpLoad %S2 %32 +%37 = OpCompositeInsert %S2 %35 %36 1 +OpStore %32 %37 +%38 = OpLoad %22 %tex +%39 = OpCompositeExtract %16 %37 1 1 1 +%40 = OpSampledImage %24 %38 %39 +%41 = OpImageSampleImplicitLod %v4float %40 %27 +OpStore %_entryPointOutput %41 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%30 = OpLabel +%31 = OpVariable %_ptr_Function_S1 Function +%32 = OpVariable %_ptr_Function_S2 Function +%33 = OpLoad %16 %samp +%34 = OpLoad %S1 %31 +%35 = OpCompositeInsert %S1 %33 %34 1 1 +OpStore %31 %35 +%36 = OpLoad %S2 %32 +%37 = OpCompositeInsert %S2 %35 %36 1 +OpStore %32 %37 +%38 = OpLoad %22 %tex +%40 = OpSampledImage %24 %38 %33 +%41 = 
OpImageSampleImplicitLod %v4float %40 %27 +OpStore %_entryPointOutput %41 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(InsertExtractElimTest, ConflictingInsertPreventsOptimization) { + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // s0.v1[2] = 0.0; + // gl_FragColor = s0.v1; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%float_0 = OpConstant %float 0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %8 +%18 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%19 = OpLoad %v4float %BaseColor +%20 = OpLoad %S_t %s0 +%21 = OpCompositeInsert %S_t %19 %20 1 +%22 = OpCompositeInsert %S_t %float_0 %21 1 2 +OpStore %s0 %22 +%23 = OpCompositeExtract %v4float %22 1 +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, 
assembly, true, true); +} + +TEST_F(InsertExtractElimTest, ConflictingInsertPreventsOptimization2) { + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1[1] = 1.0; // dead + // s0.v1 = BaseColor; + // gl_FragColor = vec4(s0.v1[1], 0.0, 0.0, 0.0); + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%float_1 = OpConstant %float 1 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" 
+OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%float_1 = OpConstant %float 1 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%22 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%23 = OpLoad %S_t %s0 +%24 = OpCompositeInsert %S_t %float_1 %23 1 1 +%25 = OpLoad %v4float %BaseColor +%26 = OpCompositeInsert %S_t %25 %24 1 +%27 = OpCompositeExtract %float %26 1 1 +%28 = OpCompositeConstruct %v4float %27 %float_0 %float_0 %float_0 +OpStore %gl_FragColor %28 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%22 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%23 = OpLoad %S_t %s0 +%24 = OpCompositeInsert %S_t %float_1 %23 1 1 +%25 = OpLoad %v4float %BaseColor +%26 = OpCompositeInsert %S_t %25 %24 1 +%27 = OpCompositeExtract %float %25 1 +%28 = OpCompositeConstruct %v4float %27 %float_0 %float_0 %float_0 +OpStore %gl_FragColor %28 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(InsertExtractElimTest, MixWithConstants) { + // Extract component of FMix with 0.0 or 1.0 as the a-value. 
+ // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in float bc; + // layout (location=1) in float bc2; + // layout (location=2) in float m; + // layout (location=3) in float m2; + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec4 bcv = vec4(bc, bc2, 0.0, 1.0); + // vec4 bcv2 = vec4(bc2, bc, 1.0, 0.0); + // vec4 v = mix(bcv, bcv2, vec4(0.0,1.0,m,m2)); + // OutColor = vec4(v.y); + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %bc %bc2 %m %m2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %bc "bc" +OpName %bc2 "bc2" +OpName %m "m" +OpName %m2 "m2" +OpName %OutColor "OutColor" +OpDecorate %bc Location 0 +OpDecorate %bc2 Location 1 +OpDecorate %m Location 2 +OpDecorate %m2 Location 3 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_float = OpTypePointer Input %float +%bc = OpVariable %_ptr_Input_float Input +%bc2 = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%m = OpVariable %_ptr_Input_float Input +%m2 = OpVariable %_ptr_Input_float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%20 = OpLoad %float %bc +%21 = OpLoad %float %bc2 +%22 = OpCompositeConstruct %v4float %20 %21 %float_0 %float_1 +%23 = OpLoad %float %bc2 +%24 = OpLoad %float %bc +%25 = OpCompositeConstruct %v4float 
%23 %24 %float_1 %float_0 +%26 = OpLoad %float %m +%27 = OpLoad %float %m2 +%28 = OpCompositeConstruct %v4float %float_0 %float_1 %26 %27 +%29 = OpExtInst %v4float %1 FMix %22 %25 %28 +%30 = OpCompositeExtract %float %29 1 +%31 = OpCompositeConstruct %v4float %30 %30 %30 %30 +OpStore %OutColor %31 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%20 = OpLoad %float %bc +%21 = OpLoad %float %bc2 +%22 = OpCompositeConstruct %v4float %20 %21 %float_0 %float_1 +%23 = OpLoad %float %bc2 +%24 = OpLoad %float %bc +%25 = OpCompositeConstruct %v4float %23 %24 %float_1 %float_0 +%26 = OpLoad %float %m +%27 = OpLoad %float %m2 +%28 = OpCompositeConstruct %v4float %float_0 %float_1 %26 %27 +%29 = OpExtInst %v4float %1 FMix %22 %25 %28 +%31 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %OutColor %31 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + true, true); +} + +TEST_F(InsertExtractElimTest, VectorShuffle1) { + // Extract component from first vector in VectorShuffle + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. 
+ // + // #version 450 + // + // layout (location=0) in float bc; + // layout (location=1) in float bc2; + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec4 bcv = vec4(bc, bc2, 0.0, 1.0); + // vec4 v = bcv.zwxy; + // OutColor = vec4(v.y); + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %bc %bc2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %bc "bc" +OpName %bc2 "bc2" +OpName %OutColor "OutColor" +OpDecorate %bc Location 0 +OpDecorate %bc2 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_float = OpTypePointer Input %float +%bc = OpVariable %_ptr_Input_float Input +%bc2 = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +)"; + + const std::string predefs_after = predefs_before + + "%24 = OpConstantComposite %v4float " + "%float_1 %float_1 %float_1 %float_1\n"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%17 = OpLabel +%18 = OpLoad %float %bc +%19 = OpLoad %float %bc2 +%20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1 +%21 = OpVectorShuffle %v4float %20 %20 2 3 0 1 +%22 = OpCompositeExtract %float %21 1 +%23 = OpCompositeConstruct %v4float %22 %22 %22 %22 +OpStore %OutColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%17 = OpLabel +%18 = OpLoad %float %bc +%19 = OpLoad %float %bc2 +%20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1 +%21 = OpVectorShuffle 
%v4float %20 %20 2 3 0 1 +OpStore %OutColor %24 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs_before + before, + predefs_after + after, true, true); +} + +TEST_F(InsertExtractElimTest, VectorShuffle2) { + // Extract component from second vector in VectorShuffle + // Identical to test VectorShuffle1 except for the vector + // shuffle index of 7. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in float bc; + // layout (location=1) in float bc2; + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec4 bcv = vec4(bc, bc2, 0.0, 1.0); + // vec4 v = bcv.zwxy; + // OutColor = vec4(v.y); + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %bc %bc2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %bc "bc" +OpName %bc2 "bc2" +OpName %OutColor "OutColor" +OpDecorate %bc Location 0 +OpDecorate %bc2 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_float = OpTypePointer Input %float +%bc = OpVariable %_ptr_Input_float Input +%bc2 = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +)"; + + const std::string predefs_after = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %bc %bc2 %OutColor +OpExecutionMode %main OriginUpperLeft 
+OpSource GLSL 450 +OpName %main "main" +OpName %bc "bc" +OpName %bc2 "bc2" +OpName %OutColor "OutColor" +OpDecorate %bc Location 0 +OpDecorate %bc2 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_float = OpTypePointer Input %float +%bc = OpVariable %_ptr_Input_float Input +%bc2 = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%24 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%17 = OpLabel +%18 = OpLoad %float %bc +%19 = OpLoad %float %bc2 +%20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1 +%21 = OpVectorShuffle %v4float %20 %20 2 7 0 1 +%22 = OpCompositeExtract %float %21 1 +%23 = OpCompositeConstruct %v4float %22 %22 %22 %22 +OpStore %OutColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%17 = OpLabel +%18 = OpLoad %float %bc +%19 = OpLoad %float %bc2 +%20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1 +%21 = OpVectorShuffle %v4float %20 %20 2 7 0 1 +OpStore %OutColor %24 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs_before + before, + predefs_after + after, true, true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/inst_bindless_check_test.cpp b/test/opt/inst_bindless_check_test.cpp new file mode 100644 index 000000000..a426ce04c --- /dev/null +++ b/test/opt/inst_bindless_check_test.cpp @@ -0,0 +1,1857 @@ +// Copyright (c) 2017 Valve Corporation +// 
Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using InstBindlessTest = PassTest<::testing::Test>; + +TEST_F(InstBindlessTest, Simple) { + // Texture2D g_tColor[128]; + // + // layout(push_constant) cbuffer PerViewConstantBuffer_t + // { + // uint g_nDataIdx; + // }; + // + // SamplerState g_sAniso; + // + // struct PS_INPUT + // { + // float2 vTextureCoords : TEXCOORD2; + // }; + // + // struct PS_OUTPUT + // { + // float4 vColor : SV_Target0; + // }; + // + // PS_OUTPUT MainPs(PS_INPUT i) + // { + // PS_OUTPUT ps_output; + // ps_output.vColor = + // g_tColor[ g_nDataIdx ].Sample(g_sAniso, i.vTextureCoords.xy); + // return ps_output; + // } + + const std::string entry_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +)"; + + const std::string entry_after = + R"(OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor %gl_FragCoord +OpExecutionMode %MainPs OriginUpperLeft +OpSource 
HLSL 500 +)"; + + const std::string names_annots = + R"(OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %g_sAniso "g_sAniso" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %g_sAniso DescriptorSet 0 +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +)"; + + const std::string new_annots = + R"(OpDecorate %_runtimearr_uint ArrayStride 4 +OpDecorate %_struct_55 Block +OpMemberDecorate %_struct_55 0 Offset 0 +OpMemberDecorate %_struct_55 1 Offset 4 +OpDecorate %57 DescriptorSet 7 +OpDecorate %57 Binding 0 +OpDecorate %gl_FragCoord BuiltIn FragCoord +)"; + + const std::string consts_types_vars = + R"(%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%16 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%_arr_16_uint_128 = OpTypeArray %16 %uint_128 +%_ptr_UniformConstant__arr_16_uint_128 = OpTypePointer UniformConstant %_arr_16_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_16_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_16 = OpTypePointer UniformConstant %16 +%24 = OpTypeSampler +%_ptr_UniformConstant_24 = OpTypePointer UniformConstant %24 +%g_sAniso = OpVariable 
%_ptr_UniformConstant_24 UniformConstant +%26 = OpTypeSampledImage %16 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string new_consts_types_vars = + R"(%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%48 = OpTypeFunction %void %uint %uint %uint %uint +%_runtimearr_uint = OpTypeRuntimeArray %uint +%_struct_55 = OpTypeStruct %uint %_runtimearr_uint +%_ptr_StorageBuffer__struct_55 = OpTypePointer StorageBuffer %_struct_55 +%57 = OpVariable %_ptr_StorageBuffer__struct_55 StorageBuffer +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%uint_9 = OpConstant %uint 9 +%uint_4 = OpConstant %uint 4 +%uint_1 = OpConstant %uint 1 +%uint_23 = OpConstant %uint 23 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%v4uint = OpTypeVector %uint 4 +%uint_5 = OpConstant %uint 5 +%uint_6 = OpConstant %uint 6 +%uint_7 = OpConstant %uint 7 +%uint_8 = OpConstant %uint 8 +%uint_56 = OpConstant %uint 56 +%103 = OpConstantNull %v4float +)"; + + const std::string func_pt1 = + R"(%MainPs = OpFunction %void None %10 +%29 = OpLabel +%30 = OpLoad %v2float %i_vTextureCoords +%31 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%32 = OpLoad %uint %31 +%33 = OpAccessChain %_ptr_UniformConstant_16 %g_tColor %32 +%34 = OpLoad %16 %33 +%35 = OpLoad %24 %g_sAniso +%36 = OpSampledImage %26 %34 %35 +)"; + + const std::string func_pt2_before = + R"(%37 = OpImageSampleImplicitLod %v4float %36 %30 +OpStore %_entryPointOutput_vColor %37 +OpReturn +OpFunctionEnd +)"; + + const std::string func_pt2_after = + R"(%40 = OpULessThan %bool %32 %uint_128 +OpSelectionMerge %41 None +OpBranchConditional %40 %42 %43 +%42 = OpLabel +%44 = OpLoad %16 %33 +%45 = OpSampledImage %26 %44 %35 
+%46 = OpImageSampleImplicitLod %v4float %45 %30 +OpBranch %41 +%43 = OpLabel +%102 = OpFunctionCall %void %47 %uint_56 %uint_0 %32 %uint_128 +OpBranch %41 +%41 = OpLabel +%104 = OpPhi %v4float %46 %42 %103 %43 +OpStore %_entryPointOutput_vColor %104 +OpReturn +OpFunctionEnd +)"; + + const std::string output_func = + R"(%47 = OpFunction %void None %48 +%49 = OpFunctionParameter %uint +%50 = OpFunctionParameter %uint +%51 = OpFunctionParameter %uint +%52 = OpFunctionParameter %uint +%53 = OpLabel +%59 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_0 +%62 = OpAtomicIAdd %uint %59 %uint_4 %uint_0 %uint_9 +%63 = OpIAdd %uint %62 %uint_9 +%64 = OpArrayLength %uint %57 1 +%65 = OpULessThanEqual %bool %63 %64 +OpSelectionMerge %66 None +OpBranchConditional %65 %67 %66 +%67 = OpLabel +%68 = OpIAdd %uint %62 %uint_0 +%70 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %68 +OpStore %70 %uint_9 +%72 = OpIAdd %uint %62 %uint_1 +%73 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %72 +OpStore %73 %uint_23 +%75 = OpIAdd %uint %62 %uint_2 +%76 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %75 +OpStore %76 %49 +%78 = OpIAdd %uint %62 %uint_3 +%79 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %78 +OpStore %79 %uint_4 +%82 = OpLoad %v4float %gl_FragCoord +%84 = OpBitcast %v4uint %82 +%85 = OpCompositeExtract %uint %84 0 +%86 = OpIAdd %uint %62 %uint_4 +%87 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %86 +OpStore %87 %85 +%88 = OpCompositeExtract %uint %84 1 +%90 = OpIAdd %uint %62 %uint_5 +%91 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %90 +OpStore %91 %88 +%93 = OpIAdd %uint %62 %uint_6 +%94 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %93 +OpStore %94 %50 +%96 = OpIAdd %uint %62 %uint_7 +%97 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %96 +OpStore %97 %51 +%99 = OpIAdd %uint %62 %uint_8 +%100 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %99 +OpStore %100 %52 +OpBranch %66 +%66 = OpLabel +OpReturn 
+OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + entry_before + names_annots + consts_types_vars + func_pt1 + + func_pt2_before, + entry_after + names_annots + new_annots + consts_types_vars + + new_consts_types_vars + func_pt1 + func_pt2_after + output_func, + true, true); +} + +TEST_F(InstBindlessTest, NoInstrumentConstIndexInbounds) { + // Texture2D g_tColor[128]; + // + // SamplerState g_sAniso; + // + // struct PS_INPUT + // { + // float2 vTextureCoords : TEXCOORD2; + // }; + // + // struct PS_OUTPUT + // { + // float4 vColor : SV_Target0; + // }; + // + // PS_OUTPUT MainPs(PS_INPUT i) + // { + // PS_OUTPUT ps_output; + // + // ps_output.vColor = g_tColor[ 37 ].Sample(g_sAniso, i.vTextureCoords.xy); + // return ps_output; + // } + + const std::string before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %g_sAniso "g_sAniso" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpDecorate %g_sAniso DescriptorSet 0 +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_37 = OpConstant %int 37 +%15 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%_arr_15_uint_128 = OpTypeArray %15 %uint_128 +%_ptr_UniformConstant__arr_15_uint_128 = OpTypePointer UniformConstant %_arr_15_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_15_uint_128 
UniformConstant +%_ptr_UniformConstant_15 = OpTypePointer UniformConstant %15 +%21 = OpTypeSampler +%_ptr_UniformConstant_21 = OpTypePointer UniformConstant %21 +%g_sAniso = OpVariable %_ptr_UniformConstant_21 UniformConstant +%23 = OpTypeSampledImage %15 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%MainPs = OpFunction %void None %8 +%26 = OpLabel +%27 = OpLoad %v2float %i_vTextureCoords +%28 = OpAccessChain %_ptr_UniformConstant_15 %g_tColor %int_37 +%29 = OpLoad %15 %28 +%30 = OpLoad %21 %g_sAniso +%31 = OpSampledImage %23 %29 %30 +%32 = OpImageSampleImplicitLod %v4float %31 %27 +OpStore %_entryPointOutput_vColor %32 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, before, true, true); +} + +TEST_F(InstBindlessTest, InstrumentMultipleInstructions) { + // Texture2D g_tColor[128]; + // + // layout(push_constant) cbuffer PerViewConstantBuffer_t + // { + // uint g_nDataIdx; + // uint g_nDataIdx2; + // }; + // + // SamplerState g_sAniso; + // + // struct PS_INPUT + // { + // float2 vTextureCoords : TEXCOORD2; + // }; + // + // struct PS_OUTPUT + // { + // float4 vColor : SV_Target0; + // }; + // + // PS_OUTPUT MainPs(PS_INPUT i) + // { + // PS_OUTPUT ps_output; + // + // float t = g_tColor[g_nDataIdx ].Sample(g_sAniso, i.vTextureCoords.xy); + // float t2 = g_tColor[g_nDataIdx2].Sample(g_sAniso, i.vTextureCoords.xy); + // ps_output.vColor = t + t2; + // return ps_output; + // } + + const std::string defs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" 
+OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %g_sAniso "g_sAniso" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpMemberDecorate %PerViewConstantBuffer_t 1 Offset 4 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %g_sAniso DescriptorSet 0 +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%17 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%_arr_17_uint_128 = OpTypeArray %17 %uint_128 +%_ptr_UniformConstant__arr_17_uint_128 = OpTypePointer UniformConstant %_arr_17_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_17_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_17 = OpTypePointer UniformConstant %17 +%25 = OpTypeSampler +%_ptr_UniformConstant_25 = OpTypePointer UniformConstant %25 +%g_sAniso = OpVariable %_ptr_UniformConstant_25 UniformConstant +%27 = OpTypeSampledImage %17 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string defs_after = + R"(OpCapability Shader +OpExtension 
"SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor %gl_FragCoord +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %g_sAniso "g_sAniso" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpMemberDecorate %PerViewConstantBuffer_t 1 Offset 4 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %g_sAniso DescriptorSet 0 +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +OpDecorate %_runtimearr_uint ArrayStride 4 +OpDecorate %_struct_63 Block +OpMemberDecorate %_struct_63 0 Offset 0 +OpMemberDecorate %_struct_63 1 Offset 4 +OpDecorate %65 DescriptorSet 7 +OpDecorate %65 Binding 0 +OpDecorate %gl_FragCoord BuiltIn FragCoord +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%17 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%_arr_17_uint_128 = OpTypeArray %17 %uint_128 +%_ptr_UniformConstant__arr_17_uint_128 = OpTypePointer UniformConstant %_arr_17_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_17_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer 
PushConstant %uint +%_ptr_UniformConstant_17 = OpTypePointer UniformConstant %17 +%25 = OpTypeSampler +%_ptr_UniformConstant_25 = OpTypePointer UniformConstant %25 +%g_sAniso = OpVariable %_ptr_UniformConstant_25 UniformConstant +%27 = OpTypeSampledImage %17 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%56 = OpTypeFunction %void %uint %uint %uint %uint +%_runtimearr_uint = OpTypeRuntimeArray %uint +%_struct_63 = OpTypeStruct %uint %_runtimearr_uint +%_ptr_StorageBuffer__struct_63 = OpTypePointer StorageBuffer %_struct_63 +%65 = OpVariable %_ptr_StorageBuffer__struct_63 StorageBuffer +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%uint_9 = OpConstant %uint 9 +%uint_4 = OpConstant %uint 4 +%uint_1 = OpConstant %uint 1 +%uint_23 = OpConstant %uint 23 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%v4uint = OpTypeVector %uint 4 +%uint_5 = OpConstant %uint 5 +%uint_6 = OpConstant %uint 6 +%uint_7 = OpConstant %uint 7 +%uint_8 = OpConstant %uint 8 +%uint_58 = OpConstant %uint 58 +%111 = OpConstantNull %v4float +%uint_64 = OpConstant %uint 64 +)"; + + const std::string func_before = + R"(%MainPs = OpFunction %void None %10 +%30 = OpLabel +%31 = OpLoad %v2float %i_vTextureCoords +%32 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%33 = OpLoad %uint %32 +%34 = OpAccessChain %_ptr_UniformConstant_17 %g_tColor %33 +%35 = OpLoad %17 %34 +%36 = OpLoad %25 %g_sAniso +%37 = OpSampledImage %27 %35 %36 +%38 = OpImageSampleImplicitLod %v4float %37 %31 +%39 = OpAccessChain %_ptr_PushConstant_uint %_ %int_1 +%40 = OpLoad %uint %39 +%41 = OpAccessChain %_ptr_UniformConstant_17 %g_tColor %40 +%42 = OpLoad %17 %41 
+%43 = OpSampledImage %27 %42 %36 +%44 = OpImageSampleImplicitLod %v4float %43 %31 +%45 = OpFAdd %v4float %38 %44 +OpStore %_entryPointOutput_vColor %45 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%MainPs = OpFunction %void None %10 +%30 = OpLabel +%31 = OpLoad %v2float %i_vTextureCoords +%32 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%33 = OpLoad %uint %32 +%34 = OpAccessChain %_ptr_UniformConstant_17 %g_tColor %33 +%35 = OpLoad %17 %34 +%36 = OpLoad %25 %g_sAniso +%37 = OpSampledImage %27 %35 %36 +%48 = OpULessThan %bool %33 %uint_128 +OpSelectionMerge %49 None +OpBranchConditional %48 %50 %51 +%50 = OpLabel +%52 = OpLoad %17 %34 +%53 = OpSampledImage %27 %52 %36 +%54 = OpImageSampleImplicitLod %v4float %53 %31 +OpBranch %49 +%51 = OpLabel +%110 = OpFunctionCall %void %55 %uint_58 %uint_0 %33 %uint_128 +OpBranch %49 +%49 = OpLabel +%112 = OpPhi %v4float %54 %50 %111 %51 +%39 = OpAccessChain %_ptr_PushConstant_uint %_ %int_1 +%40 = OpLoad %uint %39 +%41 = OpAccessChain %_ptr_UniformConstant_17 %g_tColor %40 +%42 = OpLoad %17 %41 +%43 = OpSampledImage %27 %42 %36 +%113 = OpULessThan %bool %40 %uint_128 +OpSelectionMerge %114 None +OpBranchConditional %113 %115 %116 +%115 = OpLabel +%117 = OpLoad %17 %41 +%118 = OpSampledImage %27 %117 %36 +%119 = OpImageSampleImplicitLod %v4float %118 %31 +OpBranch %114 +%116 = OpLabel +%121 = OpFunctionCall %void %55 %uint_64 %uint_0 %40 %uint_128 +OpBranch %114 +%114 = OpLabel +%122 = OpPhi %v4float %119 %115 %111 %116 +%45 = OpFAdd %v4float %112 %122 +OpStore %_entryPointOutput_vColor %45 +OpReturn +OpFunctionEnd +)"; + + const std::string output_func = + R"(%55 = OpFunction %void None %56 +%57 = OpFunctionParameter %uint +%58 = OpFunctionParameter %uint +%59 = OpFunctionParameter %uint +%60 = OpFunctionParameter %uint +%61 = OpLabel +%67 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_0 +%70 = OpAtomicIAdd %uint %67 %uint_4 %uint_0 %uint_9 +%71 = OpIAdd %uint %70 %uint_9 +%72 = OpArrayLength 
%uint %65 1 +%73 = OpULessThanEqual %bool %71 %72 +OpSelectionMerge %74 None +OpBranchConditional %73 %75 %74 +%75 = OpLabel +%76 = OpIAdd %uint %70 %uint_0 +%78 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %76 +OpStore %78 %uint_9 +%80 = OpIAdd %uint %70 %uint_1 +%81 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %80 +OpStore %81 %uint_23 +%83 = OpIAdd %uint %70 %uint_2 +%84 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %83 +OpStore %84 %57 +%86 = OpIAdd %uint %70 %uint_3 +%87 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %86 +OpStore %87 %uint_4 +%90 = OpLoad %v4float %gl_FragCoord +%92 = OpBitcast %v4uint %90 +%93 = OpCompositeExtract %uint %92 0 +%94 = OpIAdd %uint %70 %uint_4 +%95 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %94 +OpStore %95 %93 +%96 = OpCompositeExtract %uint %92 1 +%98 = OpIAdd %uint %70 %uint_5 +%99 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %98 +OpStore %99 %96 +%101 = OpIAdd %uint %70 %uint_6 +%102 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %101 +OpStore %102 %58 +%104 = OpIAdd %uint %70 %uint_7 +%105 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %104 +OpStore %105 %59 +%107 = OpIAdd %uint %70 %uint_8 +%108 = OpAccessChain %_ptr_StorageBuffer_uint %65 %uint_1 %107 +OpStore %108 %60 +OpBranch %74 +%74 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + defs_before + func_before, defs_after + func_after + output_func, true, + true); +} + +TEST_F(InstBindlessTest, ReuseConstsTypesBuiltins) { + // This test verifies that the pass reuses existing constants, types + // and builtin variables. This test was created by editing the SPIR-V + // from the Simple test.
+ + const std::string defs_before = + R"(OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor %gl_FragCoord +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %g_sAniso "g_sAniso" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %g_sAniso DescriptorSet 0 +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +OpDecorate %85 DescriptorSet 7 +OpDecorate %85 Binding 0 +OpDecorate %gl_FragCoord BuiltIn FragCoord +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%20 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%_arr_20_uint_128 = OpTypeArray %20 %uint_128 +%_ptr_UniformConstant__arr_20_uint_128 = OpTypePointer UniformConstant %_arr_20_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_20_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_20 = OpTypePointer UniformConstant %20 +%35 = OpTypeSampler +%_ptr_UniformConstant_35 = OpTypePointer UniformConstant %35 +%g_sAniso = 
OpVariable %_ptr_UniformConstant_35 UniformConstant +%39 = OpTypeSampledImage %20 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%_runtimearr_uint = OpTypeRuntimeArray %uint +%_struct_83 = OpTypeStruct %uint %_runtimearr_uint +%_ptr_StorageBuffer__struct_83 = OpTypePointer StorageBuffer %_struct_83 +%85 = OpVariable %_ptr_StorageBuffer__struct_83 StorageBuffer +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%uint_10 = OpConstant %uint 10 +%uint_4 = OpConstant %uint 4 +%uint_1 = OpConstant %uint 1 +%uint_23 = OpConstant %uint 23 +%uint_2 = OpConstant %uint 2 +%uint_9 = OpConstant %uint 9 +%uint_3 = OpConstant %uint 3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%v4uint = OpTypeVector %uint 4 +%uint_5 = OpConstant %uint 5 +%uint_6 = OpConstant %uint 6 +%uint_7 = OpConstant %uint 7 +%uint_8 = OpConstant %uint 8 +%131 = OpConstantNull %v4float +)"; + + const std::string defs_after = + R"(OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor %gl_FragCoord +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %g_sAniso "g_sAniso" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block 
+OpDecorate %g_sAniso DescriptorSet 0 +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +OpDecorate %10 DescriptorSet 7 +OpDecorate %10 Binding 0 +OpDecorate %gl_FragCoord BuiltIn FragCoord +OpDecorate %_runtimearr_uint ArrayStride 4 +OpDecorate %_struct_34 Block +OpMemberDecorate %_struct_34 0 Offset 0 +OpMemberDecorate %_struct_34 1 Offset 4 +OpDecorate %74 DescriptorSet 7 +OpDecorate %74 Binding 0 +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%18 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%_arr_18_uint_128 = OpTypeArray %18 %uint_128 +%_ptr_UniformConstant__arr_18_uint_128 = OpTypePointer UniformConstant %_arr_18_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_18_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_18 = OpTypePointer UniformConstant %18 +%26 = OpTypeSampler +%_ptr_UniformConstant_26 = OpTypePointer UniformConstant %26 +%g_sAniso = OpVariable %_ptr_UniformConstant_26 UniformConstant +%28 = OpTypeSampledImage %18 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%_runtimearr_uint = OpTypeRuntimeArray %uint +%_struct_34 = OpTypeStruct %uint %_runtimearr_uint +%_ptr_StorageBuffer__struct_34 = OpTypePointer StorageBuffer %_struct_34 +%10 = OpVariable %_ptr_StorageBuffer__struct_34 StorageBuffer 
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%uint_10 = OpConstant %uint 10 +%uint_4 = OpConstant %uint 4 +%uint_1 = OpConstant %uint 1 +%uint_23 = OpConstant %uint 23 +%uint_2 = OpConstant %uint 2 +%uint_9 = OpConstant %uint 9 +%uint_3 = OpConstant %uint 3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%v4uint = OpTypeVector %uint 4 +%uint_5 = OpConstant %uint 5 +%uint_6 = OpConstant %uint 6 +%uint_7 = OpConstant %uint 7 +%uint_8 = OpConstant %uint 8 +%50 = OpConstantNull %v4float +%68 = OpTypeFunction %void %uint %uint %uint %uint +%74 = OpVariable %_ptr_StorageBuffer__struct_34 StorageBuffer +%uint_82 = OpConstant %uint 82 +)"; + + const std::string func_before = + R"(%MainPs = OpFunction %void None %3 +%5 = OpLabel +%53 = OpLoad %v2float %i_vTextureCoords +%63 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%64 = OpLoad %uint %63 +%65 = OpAccessChain %_ptr_UniformConstant_20 %g_tColor %64 +%67 = OpLoad %35 %g_sAniso +%78 = OpLoad %20 %65 +%79 = OpSampledImage %39 %78 %67 +%71 = OpImageSampleImplicitLod %v4float %79 %53 +OpStore %_entryPointOutput_vColor %71 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%MainPs = OpFunction %void None %12 +%51 = OpLabel +%52 = OpLoad %v2float %i_vTextureCoords +%53 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%54 = OpLoad %uint %53 +%55 = OpAccessChain %_ptr_UniformConstant_18 %g_tColor %54 +%56 = OpLoad %26 %g_sAniso +%57 = OpLoad %18 %55 +%58 = OpSampledImage %28 %57 %56 +%60 = OpULessThan %bool %54 %uint_128 +OpSelectionMerge %61 None +OpBranchConditional %60 %62 %63 +%62 = OpLabel +%64 = OpLoad %18 %55 +%65 = OpSampledImage %28 %64 %56 +%66 = OpImageSampleImplicitLod %v4float %65 %52 +OpBranch %61 +%63 = OpLabel +%105 = OpFunctionCall %void %67 %uint_82 %uint_0 %54 %uint_128 +OpBranch %61 +%61 = OpLabel +%106 = OpPhi %v4float %66 %62 %50 %63 +OpStore %_entryPointOutput_vColor %106 +OpReturn +OpFunctionEnd +)"; + + 
const std::string output_func = + R"(%67 = OpFunction %void None %68 +%69 = OpFunctionParameter %uint +%70 = OpFunctionParameter %uint +%71 = OpFunctionParameter %uint +%72 = OpFunctionParameter %uint +%73 = OpLabel +%75 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_0 +%76 = OpAtomicIAdd %uint %75 %uint_4 %uint_0 %uint_9 +%77 = OpIAdd %uint %76 %uint_9 +%78 = OpArrayLength %uint %74 1 +%79 = OpULessThanEqual %bool %77 %78 +OpSelectionMerge %80 None +OpBranchConditional %79 %81 %80 +%81 = OpLabel +%82 = OpIAdd %uint %76 %uint_0 +%83 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %82 +OpStore %83 %uint_9 +%84 = OpIAdd %uint %76 %uint_1 +%85 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %84 +OpStore %85 %uint_23 +%86 = OpIAdd %uint %76 %uint_2 +%87 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %86 +OpStore %87 %69 +%88 = OpIAdd %uint %76 %uint_3 +%89 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %88 +OpStore %89 %uint_4 +%90 = OpLoad %v4float %gl_FragCoord +%91 = OpBitcast %v4uint %90 +%92 = OpCompositeExtract %uint %91 0 +%93 = OpIAdd %uint %76 %uint_4 +%94 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %93 +OpStore %94 %92 +%95 = OpCompositeExtract %uint %91 1 +%96 = OpIAdd %uint %76 %uint_5 +%97 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %96 +OpStore %97 %95 +%98 = OpIAdd %uint %76 %uint_6 +%99 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %98 +OpStore %99 %70 +%100 = OpIAdd %uint %76 %uint_7 +%101 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %100 +OpStore %101 %71 +%102 = OpIAdd %uint %76 %uint_8 +%103 = OpAccessChain %_ptr_StorageBuffer_uint %74 %uint_1 %102 +OpStore %103 %72 +OpBranch %80 +%80 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + defs_before + func_before, defs_after + func_after + output_func, true, + true); +} + +TEST_F(InstBindlessTest, InstrumentOpImage) { + // This test verifies that the pass 
will correctly instrument shader + // using OpImage. This test was created by editing the SPIR-V + // from the Simple test. + + const std::string defs_before = + R"(OpCapability Shader +OpCapability StorageImageReadWithoutFormat +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%v2int = OpTypeVector %int 2 +%int_0 = OpConstant %int 0 +%20 = OpTypeImage %float 2D 0 0 0 0 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%39 = OpTypeSampledImage %20 +%_arr_39_uint_128 = OpTypeArray %39 %uint_128 +%_ptr_UniformConstant__arr_39_uint_128 = OpTypePointer UniformConstant %_arr_39_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_39_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_39 = OpTypePointer UniformConstant %39 +%_ptr_Input_v2int = OpTypePointer Input %v2int +%i_vTextureCoords = OpVariable %_ptr_Input_v2int Input +%_ptr_Output_v4float = 
OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string defs_after = + R"(OpCapability Shader +OpCapability StorageImageReadWithoutFormat +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor %gl_FragCoord +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +OpDecorate %_runtimearr_uint ArrayStride 4 +OpDecorate %_struct_51 Block +OpMemberDecorate %_struct_51 0 Offset 0 +OpMemberDecorate %_struct_51 1 Offset 4 +OpDecorate %53 DescriptorSet 7 +OpDecorate %53 Binding 0 +OpDecorate %gl_FragCoord BuiltIn FragCoord +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%v2int = OpTypeVector %int 2 +%int_0 = OpConstant %int 0 +%15 = OpTypeImage %float 2D 0 0 0 0 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%18 = OpTypeSampledImage %15 +%_arr_18_uint_128 = OpTypeArray %18 %uint_128 +%_ptr_UniformConstant__arr_18_uint_128 = OpTypePointer UniformConstant %_arr_18_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_18_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable 
%_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_18 = OpTypePointer UniformConstant %18 +%_ptr_Input_v2int = OpTypePointer Input %v2int +%i_vTextureCoords = OpVariable %_ptr_Input_v2int Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%44 = OpTypeFunction %void %uint %uint %uint %uint +%_runtimearr_uint = OpTypeRuntimeArray %uint +%_struct_51 = OpTypeStruct %uint %_runtimearr_uint +%_ptr_StorageBuffer__struct_51 = OpTypePointer StorageBuffer %_struct_51 +%53 = OpVariable %_ptr_StorageBuffer__struct_51 StorageBuffer +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%uint_9 = OpConstant %uint 9 +%uint_4 = OpConstant %uint 4 +%uint_1 = OpConstant %uint 1 +%uint_23 = OpConstant %uint 23 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%v4uint = OpTypeVector %uint 4 +%uint_5 = OpConstant %uint 5 +%uint_6 = OpConstant %uint 6 +%uint_7 = OpConstant %uint 7 +%uint_8 = OpConstant %uint 8 +%uint_51 = OpConstant %uint 51 +%99 = OpConstantNull %v4float +)"; + + const std::string func_before = + R"(%MainPs = OpFunction %void None %3 +%5 = OpLabel +%53 = OpLoad %v2int %i_vTextureCoords +%63 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%64 = OpLoad %uint %63 +%65 = OpAccessChain %_ptr_UniformConstant_39 %g_tColor %64 +%66 = OpLoad %39 %65 +%75 = OpImage %20 %66 +%71 = OpImageRead %v4float %75 %53 +OpStore %_entryPointOutput_vColor %71 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%MainPs = OpFunction %void None %9 +%26 = OpLabel +%27 = OpLoad %v2int %i_vTextureCoords +%28 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%29 = OpLoad %uint %28 +%30 = OpAccessChain %_ptr_UniformConstant_18 %g_tColor %29 
+%31 = OpLoad %18 %30 +%32 = OpImage %15 %31 +%36 = OpULessThan %bool %29 %uint_128 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %39 +%38 = OpLabel +%40 = OpLoad %18 %30 +%41 = OpImage %15 %40 +%42 = OpImageRead %v4float %41 %27 +OpBranch %37 +%39 = OpLabel +%98 = OpFunctionCall %void %43 %uint_51 %uint_0 %29 %uint_128 +OpBranch %37 +%37 = OpLabel +%100 = OpPhi %v4float %42 %38 %99 %39 +OpStore %_entryPointOutput_vColor %100 +OpReturn +OpFunctionEnd +)"; + + const std::string output_func = + R"(%43 = OpFunction %void None %44 +%45 = OpFunctionParameter %uint +%46 = OpFunctionParameter %uint +%47 = OpFunctionParameter %uint +%48 = OpFunctionParameter %uint +%49 = OpLabel +%55 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_0 +%58 = OpAtomicIAdd %uint %55 %uint_4 %uint_0 %uint_9 +%59 = OpIAdd %uint %58 %uint_9 +%60 = OpArrayLength %uint %53 1 +%61 = OpULessThanEqual %bool %59 %60 +OpSelectionMerge %62 None +OpBranchConditional %61 %63 %62 +%63 = OpLabel +%64 = OpIAdd %uint %58 %uint_0 +%66 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_1 %64 +OpStore %66 %uint_9 +%68 = OpIAdd %uint %58 %uint_1 +%69 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_1 %68 +OpStore %69 %uint_23 +%71 = OpIAdd %uint %58 %uint_2 +%72 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_1 %71 +OpStore %72 %45 +%74 = OpIAdd %uint %58 %uint_3 +%75 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_1 %74 +OpStore %75 %uint_4 +%78 = OpLoad %v4float %gl_FragCoord +%80 = OpBitcast %v4uint %78 +%81 = OpCompositeExtract %uint %80 0 +%82 = OpIAdd %uint %58 %uint_4 +%83 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_1 %82 +OpStore %83 %81 +%84 = OpCompositeExtract %uint %80 1 +%86 = OpIAdd %uint %58 %uint_5 +%87 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_1 %86 +OpStore %87 %84 +%89 = OpIAdd %uint %58 %uint_6 +%90 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_1 %89 +OpStore %90 %46 +%92 = OpIAdd %uint %58 %uint_7 +%93 = OpAccessChain %_ptr_StorageBuffer_uint %53 
%uint_1 %92 +OpStore %93 %47 +%95 = OpIAdd %uint %58 %uint_8 +%96 = OpAccessChain %_ptr_StorageBuffer_uint %53 %uint_1 %95 +OpStore %96 %48 +OpBranch %62 +%62 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + defs_before + func_before, defs_after + func_after + output_func, true, + true); +} + +TEST_F(InstBindlessTest, InstrumentSampledImage) { + // This test verifies that the pass will correctly instrument shader + // using sampled image. This test was created by editing the SPIR-V + // from the Simple test. + + const std::string defs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%20 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%39 = OpTypeSampledImage %20 +%_arr_39_uint_128 = OpTypeArray %39 %uint_128 +%_ptr_UniformConstant__arr_39_uint_128 = OpTypePointer UniformConstant %_arr_39_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_39_uint_128 UniformConstant +%PerViewConstantBuffer_t = 
OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_39 = OpTypePointer UniformConstant %39 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string defs_after = + R"(OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor %gl_FragCoord +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +OpDecorate %_runtimearr_uint ArrayStride 4 +OpDecorate %_struct_49 Block +OpMemberDecorate %_struct_49 0 Offset 0 +OpMemberDecorate %_struct_49 1 Offset 4 +OpDecorate %51 DescriptorSet 7 +OpDecorate %51 Binding 0 +OpDecorate %gl_FragCoord BuiltIn FragCoord +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%15 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%18 = 
OpTypeSampledImage %15 +%_arr_18_uint_128 = OpTypeArray %18 %uint_128 +%_ptr_UniformConstant__arr_18_uint_128 = OpTypePointer UniformConstant %_arr_18_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_18_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_18 = OpTypePointer UniformConstant %18 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%42 = OpTypeFunction %void %uint %uint %uint %uint +%_runtimearr_uint = OpTypeRuntimeArray %uint +%_struct_49 = OpTypeStruct %uint %_runtimearr_uint +%_ptr_StorageBuffer__struct_49 = OpTypePointer StorageBuffer %_struct_49 +%51 = OpVariable %_ptr_StorageBuffer__struct_49 StorageBuffer +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%uint_9 = OpConstant %uint 9 +%uint_4 = OpConstant %uint 4 +%uint_1 = OpConstant %uint 1 +%uint_23 = OpConstant %uint 23 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%v4uint = OpTypeVector %uint 4 +%uint_5 = OpConstant %uint 5 +%uint_6 = OpConstant %uint 6 +%uint_7 = OpConstant %uint 7 +%uint_8 = OpConstant %uint 8 +%uint_49 = OpConstant %uint 49 +%97 = OpConstantNull %v4float +)"; + + const std::string func_before = + R"(%MainPs = OpFunction %void None %3 +%5 = OpLabel +%53 = OpLoad %v2float %i_vTextureCoords +%63 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%64 = OpLoad %uint %63 +%65 = OpAccessChain %_ptr_UniformConstant_39 %g_tColor %64 +%66 = 
OpLoad %39 %65 +%71 = OpImageSampleImplicitLod %v4float %66 %53 +OpStore %_entryPointOutput_vColor %71 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%MainPs = OpFunction %void None %9 +%26 = OpLabel +%27 = OpLoad %v2float %i_vTextureCoords +%28 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%29 = OpLoad %uint %28 +%30 = OpAccessChain %_ptr_UniformConstant_18 %g_tColor %29 +%31 = OpLoad %18 %30 +%35 = OpULessThan %bool %29 %uint_128 +OpSelectionMerge %36 None +OpBranchConditional %35 %37 %38 +%37 = OpLabel +%39 = OpLoad %18 %30 +%40 = OpImageSampleImplicitLod %v4float %39 %27 +OpBranch %36 +%38 = OpLabel +%96 = OpFunctionCall %void %41 %uint_49 %uint_0 %29 %uint_128 +OpBranch %36 +%36 = OpLabel +%98 = OpPhi %v4float %40 %37 %97 %38 +OpStore %_entryPointOutput_vColor %98 +OpReturn +OpFunctionEnd +)"; + + const std::string output_func = + R"(%41 = OpFunction %void None %42 +%43 = OpFunctionParameter %uint +%44 = OpFunctionParameter %uint +%45 = OpFunctionParameter %uint +%46 = OpFunctionParameter %uint +%47 = OpLabel +%53 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_0 +%56 = OpAtomicIAdd %uint %53 %uint_4 %uint_0 %uint_9 +%57 = OpIAdd %uint %56 %uint_9 +%58 = OpArrayLength %uint %51 1 +%59 = OpULessThanEqual %bool %57 %58 +OpSelectionMerge %60 None +OpBranchConditional %59 %61 %60 +%61 = OpLabel +%62 = OpIAdd %uint %56 %uint_0 +%64 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 %62 +OpStore %64 %uint_9 +%66 = OpIAdd %uint %56 %uint_1 +%67 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 %66 +OpStore %67 %uint_23 +%69 = OpIAdd %uint %56 %uint_2 +%70 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 %69 +OpStore %70 %43 +%72 = OpIAdd %uint %56 %uint_3 +%73 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 %72 +OpStore %73 %uint_4 +%76 = OpLoad %v4float %gl_FragCoord +%78 = OpBitcast %v4uint %76 +%79 = OpCompositeExtract %uint %78 0 +%80 = OpIAdd %uint %56 %uint_4 +%81 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 
%80 +OpStore %81 %79 +%82 = OpCompositeExtract %uint %78 1 +%84 = OpIAdd %uint %56 %uint_5 +%85 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 %84 +OpStore %85 %82 +%87 = OpIAdd %uint %56 %uint_6 +%88 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 %87 +OpStore %88 %44 +%90 = OpIAdd %uint %56 %uint_7 +%91 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 %90 +OpStore %91 %45 +%93 = OpIAdd %uint %56 %uint_8 +%94 = OpAccessChain %_ptr_StorageBuffer_uint %51 %uint_1 %93 +OpStore %94 %46 +OpBranch %60 +%60 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + defs_before + func_before, defs_after + func_after + output_func, true, + true); +} + +TEST_F(InstBindlessTest, InstrumentImageWrite) { + // This test verifies that the pass will correctly instrument shader + // doing bindless image write. This test was created by editing the SPIR-V + // from the Simple test. + + const std::string defs_before = + R"(OpCapability Shader +OpCapability StorageImageWriteWithoutFormat +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 
+%v2int = OpTypeVector %int 2 +%int_0 = OpConstant %int 0 +%20 = OpTypeImage %float 2D 0 0 0 0 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%80 = OpConstantNull %v4float +%_arr_20_uint_128 = OpTypeArray %20 %uint_128 +%_ptr_UniformConstant__arr_20_uint_128 = OpTypePointer UniformConstant %_arr_20_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_20_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_20 = OpTypePointer UniformConstant %20 +%_ptr_Input_v2int = OpTypePointer Input %v2int +%i_vTextureCoords = OpVariable %_ptr_Input_v2int Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string defs_after = + R"(OpCapability Shader +OpCapability StorageImageWriteWithoutFormat +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor %gl_FragCoord +OpExecutionMode %MainPs OriginUpperLeft +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %g_tColor "g_tColor" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpName %_ "" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpDecorate %g_tColor DescriptorSet 3 +OpDecorate %g_tColor Binding 0 +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +OpDecorate %_runtimearr_uint ArrayStride 4 +OpDecorate %_struct_48 Block +OpMemberDecorate 
%_struct_48 0 Offset 0 +OpMemberDecorate %_struct_48 1 Offset 4 +OpDecorate %50 DescriptorSet 7 +OpDecorate %50 Binding 0 +OpDecorate %gl_FragCoord BuiltIn FragCoord +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%int = OpTypeInt 32 1 +%v2int = OpTypeVector %int 2 +%int_0 = OpConstant %int 0 +%16 = OpTypeImage %float 2D 0 0 0 0 Unknown +%uint = OpTypeInt 32 0 +%uint_128 = OpConstant %uint 128 +%19 = OpConstantNull %v4float +%_arr_16_uint_128 = OpTypeArray %16 %uint_128 +%_ptr_UniformConstant__arr_16_uint_128 = OpTypePointer UniformConstant %_arr_16_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_16_uint_128 UniformConstant +%PerViewConstantBuffer_t = OpTypeStruct %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%_ptr_UniformConstant_16 = OpTypePointer UniformConstant %16 +%_ptr_Input_v2int = OpTypePointer Input %v2int +%i_vTextureCoords = OpVariable %_ptr_Input_v2int Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%41 = OpTypeFunction %void %uint %uint %uint %uint +%_runtimearr_uint = OpTypeRuntimeArray %uint +%_struct_48 = OpTypeStruct %uint %_runtimearr_uint +%_ptr_StorageBuffer__struct_48 = OpTypePointer StorageBuffer %_struct_48 +%50 = OpVariable %_ptr_StorageBuffer__struct_48 StorageBuffer +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%uint_9 = OpConstant %uint 9 +%uint_4 = OpConstant %uint 4 +%uint_1 = OpConstant %uint 1 +%uint_23 = OpConstant %uint 23 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%v4uint 
= OpTypeVector %uint 4 +%uint_5 = OpConstant %uint 5 +%uint_6 = OpConstant %uint 6 +%uint_7 = OpConstant %uint 7 +%uint_8 = OpConstant %uint 8 +%uint_51 = OpConstant %uint 51 +)"; + + const std::string func_before = + R"(%MainPs = OpFunction %void None %3 +%5 = OpLabel +%53 = OpLoad %v2int %i_vTextureCoords +%63 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%64 = OpLoad %uint %63 +%65 = OpAccessChain %_ptr_UniformConstant_20 %g_tColor %64 +%66 = OpLoad %20 %65 +OpImageWrite %66 %53 %80 +OpStore %_entryPointOutput_vColor %80 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%MainPs = OpFunction %void None %9 +%27 = OpLabel +%28 = OpLoad %v2int %i_vTextureCoords +%29 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%30 = OpLoad %uint %29 +%31 = OpAccessChain %_ptr_UniformConstant_16 %g_tColor %30 +%32 = OpLoad %16 %31 +%35 = OpULessThan %bool %30 %uint_128 +OpSelectionMerge %36 None +OpBranchConditional %35 %37 %38 +%37 = OpLabel +%39 = OpLoad %16 %31 +OpImageWrite %39 %28 %19 +OpBranch %36 +%38 = OpLabel +%95 = OpFunctionCall %void %40 %uint_51 %uint_0 %30 %uint_128 +OpBranch %36 +%36 = OpLabel +OpStore %_entryPointOutput_vColor %19 +OpReturn +OpFunctionEnd +)"; + + const std::string output_func = + R"(%40 = OpFunction %void None %41 +%42 = OpFunctionParameter %uint +%43 = OpFunctionParameter %uint +%44 = OpFunctionParameter %uint +%45 = OpFunctionParameter %uint +%46 = OpLabel +%52 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_0 +%55 = OpAtomicIAdd %uint %52 %uint_4 %uint_0 %uint_9 +%56 = OpIAdd %uint %55 %uint_9 +%57 = OpArrayLength %uint %50 1 +%58 = OpULessThanEqual %bool %56 %57 +OpSelectionMerge %59 None +OpBranchConditional %58 %60 %59 +%60 = OpLabel +%61 = OpIAdd %uint %55 %uint_0 +%63 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %61 +OpStore %63 %uint_9 +%65 = OpIAdd %uint %55 %uint_1 +%66 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %65 +OpStore %66 %uint_23 +%68 = OpIAdd %uint %55 %uint_2 +%69 = 
OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %68 +OpStore %69 %42 +%71 = OpIAdd %uint %55 %uint_3 +%72 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %71 +OpStore %72 %uint_4 +%75 = OpLoad %v4float %gl_FragCoord +%77 = OpBitcast %v4uint %75 +%78 = OpCompositeExtract %uint %77 0 +%79 = OpIAdd %uint %55 %uint_4 +%80 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %79 +OpStore %80 %78 +%81 = OpCompositeExtract %uint %77 1 +%83 = OpIAdd %uint %55 %uint_5 +%84 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %83 +OpStore %84 %81 +%86 = OpIAdd %uint %55 %uint_6 +%87 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %86 +OpStore %87 %43 +%89 = OpIAdd %uint %55 %uint_7 +%90 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %89 +OpStore %90 %44 +%92 = OpIAdd %uint %55 %uint_8 +%93 = OpAccessChain %_ptr_StorageBuffer_uint %50 %uint_1 %92 +OpStore %93 %45 +OpBranch %59 +%59 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + defs_before + func_before, defs_after + func_after + output_func, true, + true); +} + +TEST_F(InstBindlessTest, InstrumentVertexSimple) { + // This test verifies that the pass will correctly instrument a vertex + // shader doing a bindless sampled-image lookup. This test was created by + // editing the SPIR-V from the Simple test. 
+ + const std::string defs_before = + R"(OpCapability Shader +OpCapability Sampled1D +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" %_ %coords2D +OpSource GLSL 450 +OpName %main "main" +OpName %lod "lod" +OpName %coords1D "coords1D" +OpName %gl_PerVertex "gl_PerVertex" +OpMemberName %gl_PerVertex 0 "gl_Position" +OpMemberName %gl_PerVertex 1 "gl_PointSize" +OpMemberName %gl_PerVertex 2 "gl_ClipDistance" +OpMemberName %gl_PerVertex 3 "gl_CullDistance" +OpName %_ "" +OpName %texSampler1D "texSampler1D" +OpName %foo "foo" +OpMemberName %foo 0 "g_idx" +OpName %__0 "" +OpName %coords2D "coords2D" +OpMemberDecorate %gl_PerVertex 0 BuiltIn Position +OpMemberDecorate %gl_PerVertex 1 BuiltIn PointSize +OpMemberDecorate %gl_PerVertex 2 BuiltIn ClipDistance +OpMemberDecorate %gl_PerVertex 3 BuiltIn CullDistance +OpDecorate %gl_PerVertex Block +OpDecorate %texSampler1D DescriptorSet 0 +OpDecorate %texSampler1D Binding 3 +OpMemberDecorate %foo 0 Offset 0 +OpDecorate %foo Block +OpDecorate %__0 DescriptorSet 0 +OpDecorate %__0 Binding 5 +OpDecorate %coords2D Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_3 = OpConstant %float 3 +%float_1_78900003 = OpConstant %float 1.78900003 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%_arr_float_uint_1 = OpTypeArray %float %uint_1 +%gl_PerVertex = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1 +%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex +%_ = OpVariable %_ptr_Output_gl_PerVertex Output +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%21 = OpTypeImage %float 1D 0 0 0 1 Unknown +%22 = OpTypeSampledImage %21 +%uint_128 = OpConstant %uint 128 +%_arr_22_uint_128 = OpTypeArray %22 %uint_128 +%_ptr_UniformConstant__arr_22_uint_128 = OpTypePointer UniformConstant %_arr_22_uint_128 +%texSampler1D = OpVariable 
%_ptr_UniformConstant__arr_22_uint_128 UniformConstant +%foo = OpTypeStruct %int +%_ptr_Uniform_foo = OpTypePointer Uniform %foo +%__0 = OpVariable %_ptr_Uniform_foo Uniform +%_ptr_Uniform_int = OpTypePointer Uniform %int +%_ptr_UniformConstant_22 = OpTypePointer UniformConstant %22 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%v2float = OpTypeVector %float 2 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%coords2D = OpVariable %_ptr_Input_v2float Input +)"; + + const std::string defs_after = + R"(OpCapability Shader +OpCapability Sampled1D +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" %_ %coords2D %gl_VertexIndex %gl_InstanceIndex +OpSource GLSL 450 +OpName %main "main" +OpName %lod "lod" +OpName %coords1D "coords1D" +OpName %gl_PerVertex "gl_PerVertex" +OpMemberName %gl_PerVertex 0 "gl_Position" +OpMemberName %gl_PerVertex 1 "gl_PointSize" +OpMemberName %gl_PerVertex 2 "gl_ClipDistance" +OpMemberName %gl_PerVertex 3 "gl_CullDistance" +OpName %_ "" +OpName %texSampler1D "texSampler1D" +OpName %foo "foo" +OpMemberName %foo 0 "g_idx" +OpName %__0 "" +OpName %coords2D "coords2D" +OpMemberDecorate %gl_PerVertex 0 BuiltIn Position +OpMemberDecorate %gl_PerVertex 1 BuiltIn PointSize +OpMemberDecorate %gl_PerVertex 2 BuiltIn ClipDistance +OpMemberDecorate %gl_PerVertex 3 BuiltIn CullDistance +OpDecorate %gl_PerVertex Block +OpDecorate %texSampler1D DescriptorSet 0 +OpDecorate %texSampler1D Binding 3 +OpMemberDecorate %foo 0 Offset 0 +OpDecorate %foo Block +OpDecorate %__0 DescriptorSet 0 +OpDecorate %__0 Binding 5 +OpDecorate %coords2D Location 0 +OpDecorate %_runtimearr_uint ArrayStride 4 +OpDecorate %_struct_61 Block +OpMemberDecorate %_struct_61 0 Offset 0 +OpMemberDecorate %_struct_61 1 Offset 4 +OpDecorate %63 DescriptorSet 7 +OpDecorate %63 Binding 0 +OpDecorate %gl_VertexIndex BuiltIn VertexIndex +OpDecorate %gl_InstanceIndex BuiltIn 
InstanceIndex +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_3 = OpConstant %float 3 +%float_1_78900003 = OpConstant %float 1.78900003 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%_arr_float_uint_1 = OpTypeArray %float %uint_1 +%gl_PerVertex = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1 +%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex +%_ = OpVariable %_ptr_Output_gl_PerVertex Output +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%24 = OpTypeImage %float 1D 0 0 0 1 Unknown +%25 = OpTypeSampledImage %24 +%uint_128 = OpConstant %uint 128 +%_arr_25_uint_128 = OpTypeArray %25 %uint_128 +%_ptr_UniformConstant__arr_25_uint_128 = OpTypePointer UniformConstant %_arr_25_uint_128 +%texSampler1D = OpVariable %_ptr_UniformConstant__arr_25_uint_128 UniformConstant +%foo = OpTypeStruct %int +%_ptr_Uniform_foo = OpTypePointer Uniform %foo +%__0 = OpVariable %_ptr_Uniform_foo Uniform +%_ptr_Uniform_int = OpTypePointer Uniform %int +%_ptr_UniformConstant_25 = OpTypePointer UniformConstant %25 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%v2float = OpTypeVector %float 2 +%_ptr_Input_v2float = OpTypePointer Input %v2float +%coords2D = OpVariable %_ptr_Input_v2float Input +%uint_0 = OpConstant %uint 0 +%bool = OpTypeBool +%54 = OpTypeFunction %void %uint %uint %uint %uint +%_runtimearr_uint = OpTypeRuntimeArray %uint +%_struct_61 = OpTypeStruct %uint %_runtimearr_uint +%_ptr_StorageBuffer__struct_61 = OpTypePointer StorageBuffer %_struct_61 +%63 = OpVariable %_ptr_StorageBuffer__struct_61 StorageBuffer +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint +%uint_9 = OpConstant %uint 9 +%uint_4 = OpConstant %uint 4 +%uint_23 = OpConstant %uint 23 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%_ptr_Input_uint = OpTypePointer Input %uint +%gl_VertexIndex = OpVariable 
%_ptr_Input_uint Input +%gl_InstanceIndex = OpVariable %_ptr_Input_uint Input +%uint_5 = OpConstant %uint 5 +%uint_6 = OpConstant %uint 6 +%uint_7 = OpConstant %uint 7 +%uint_8 = OpConstant %uint 8 +%uint_74 = OpConstant %uint 74 +%106 = OpConstantNull %v4float +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %3 +%5 = OpLabel +%lod = OpVariable %_ptr_Function_float Function +%coords1D = OpVariable %_ptr_Function_float Function +OpStore %lod %float_3 +OpStore %coords1D %float_1_78900003 +%31 = OpAccessChain %_ptr_Uniform_int %__0 %int_0 +%32 = OpLoad %int %31 +%34 = OpAccessChain %_ptr_UniformConstant_22 %texSampler1D %32 +%35 = OpLoad %22 %34 +%36 = OpLoad %float %coords1D +%37 = OpLoad %float %lod +%38 = OpImageSampleExplicitLod %v4float %35 %36 Lod %37 +%40 = OpAccessChain %_ptr_Output_v4float %_ %int_0 +OpStore %40 %38 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %12 +%35 = OpLabel +%lod = OpVariable %_ptr_Function_float Function +%coords1D = OpVariable %_ptr_Function_float Function +OpStore %lod %float_3 +OpStore %coords1D %float_1_78900003 +%36 = OpAccessChain %_ptr_Uniform_int %__0 %int_0 +%37 = OpLoad %int %36 +%38 = OpAccessChain %_ptr_UniformConstant_25 %texSampler1D %37 +%39 = OpLoad %25 %38 +%40 = OpLoad %float %coords1D +%41 = OpLoad %float %lod +%46 = OpULessThan %bool %37 %uint_128 +OpSelectionMerge %47 None +OpBranchConditional %46 %48 %49 +%48 = OpLabel +%50 = OpLoad %25 %38 +%51 = OpImageSampleExplicitLod %v4float %50 %40 Lod %41 +OpBranch %47 +%49 = OpLabel +%52 = OpBitcast %uint %37 +%105 = OpFunctionCall %void %53 %uint_74 %uint_0 %52 %uint_128 +OpBranch %47 +%47 = OpLabel +%107 = OpPhi %v4float %51 %48 %106 %49 +%43 = OpAccessChain %_ptr_Output_v4float %_ %int_0 +OpStore %43 %107 +OpReturn +OpFunctionEnd +)"; + + const std::string output_func = + R"(%53 = OpFunction %void None %54 +%55 = OpFunctionParameter %uint +%56 = OpFunctionParameter %uint +%57 = 
OpFunctionParameter %uint +%58 = OpFunctionParameter %uint +%59 = OpLabel +%65 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_0 +%68 = OpAtomicIAdd %uint %65 %uint_4 %uint_0 %uint_9 +%69 = OpIAdd %uint %68 %uint_9 +%70 = OpArrayLength %uint %63 1 +%71 = OpULessThanEqual %bool %69 %70 +OpSelectionMerge %72 None +OpBranchConditional %71 %73 %72 +%73 = OpLabel +%74 = OpIAdd %uint %68 %uint_0 +%75 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %74 +OpStore %75 %uint_9 +%77 = OpIAdd %uint %68 %uint_1 +%78 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %77 +OpStore %78 %uint_23 +%80 = OpIAdd %uint %68 %uint_2 +%81 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %80 +OpStore %81 %55 +%83 = OpIAdd %uint %68 %uint_3 +%84 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %83 +OpStore %84 %uint_0 +%87 = OpLoad %uint %gl_VertexIndex +%88 = OpIAdd %uint %68 %uint_4 +%89 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %88 +OpStore %89 %87 +%91 = OpLoad %uint %gl_InstanceIndex +%93 = OpIAdd %uint %68 %uint_5 +%94 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %93 +OpStore %94 %91 +%96 = OpIAdd %uint %68 %uint_6 +%97 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %96 +OpStore %97 %56 +%99 = OpIAdd %uint %68 %uint_7 +%100 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %99 +OpStore %100 %57 +%102 = OpIAdd %uint %68 %uint_8 +%103 = OpAccessChain %_ptr_StorageBuffer_uint %63 %uint_1 %102 +OpStore %103 %58 +OpBranch %72 +%72 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + defs_before + func_before, defs_after + func_after + output_func, true, + true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// TODO(greg-lunarg): Come up with cases to put here :) + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/instruction_list_test.cpp b/test/opt/instruction_list_test.cpp new file mode 
100644
index 000000000..e745790a3
--- /dev/null
+++ b/test/opt/instruction_list_test.cpp
@@ -0,0 +1,115 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "source/opt/instruction.h"
+#include "source/opt/instruction_list.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using ::testing::ContainerEq;
+using ::testing::ElementsAre;
+using InstructionListTest = ::testing::Test;
+
+// An Instruction that records its own construction and destruction.
+class TestInstruction : public Instruction {
+ public:
+  TestInstruction() : Instruction() { created_instructions_.push_back(this); }
+
+  ~TestInstruction() { deleted_instructions_.push_back(this); }
+
+  static std::vector<TestInstruction*> created_instructions_;
+  static std::vector<TestInstruction*> deleted_instructions_;
+};
+
+std::vector<TestInstruction*> TestInstruction::created_instructions_;
+std::vector<TestInstruction*> TestInstruction::deleted_instructions_;
+
+// Test that destroying an InstructionList destroys every element it owns.
+// Elements must be TestInstructions so that the trace vectors are populated.
+TEST(InstructionListTest, Destructor) {
+  InstructionList* list = new InstructionList();
+  list->push_back(std::unique_ptr<Instruction>(new TestInstruction()));
+  list->push_back(std::unique_ptr<Instruction>(new TestInstruction()));
+  delete list;
+
+  // Sort both traces: creation and destruction order need not match.  A
+  // generic sort is used just in case the setup above changes.
+  std::sort(TestInstruction::created_instructions_.begin(),
+            TestInstruction::created_instructions_.end());
+  std::sort(TestInstruction::deleted_instructions_.begin(),
+            TestInstruction::deleted_instructions_.end());
+  EXPECT_THAT(TestInstruction::created_instructions_,
+              ContainerEq(TestInstruction::deleted_instructions_));
+}
+
+// Test the |InsertBefore| with a single instruction in the iterator class.
+// Need to make sure the elements are inserted in the correct order, and the
+// return value points to the correct location.
+//
+// Comparing addresses to make sure they remain stable, so other data structures
+// can have pointers to instructions in InstructionList.
+TEST(InstructionListTest, InsertBefore1) {
+  InstructionList list;
+  std::vector<Instruction*> inserted_instructions;
+  for (int i = 0; i < 4; i++) {
+    std::unique_ptr<Instruction> inst(new Instruction());
+    inserted_instructions.push_back(inst.get());
+    auto new_element = list.end().InsertBefore(std::move(inst));
+    EXPECT_EQ(&*new_element, inserted_instructions.back());
+  }
+
+  std::vector<Instruction*> output;
+  for (auto& i : list) {
+    output.push_back(&i);
+  }
+  EXPECT_THAT(output, ContainerEq(inserted_instructions));
+}
+
+// Test inserting an entire vector of instructions using InsertBefore. Checking
+// the order of insertion and the return value.
+//
+// Comparing addresses to make sure they remain stable, so other data structures
+// can have pointers to instructions in InstructionList.
+TEST(InstructionListTest, InsertBefore2) {
+  InstructionList list;
+  std::vector<std::unique_ptr<Instruction>> new_instructions;
+  std::vector<Instruction*> created_instructions;  // raw addresses, in order
+  for (int i = 0; i < 4; i++) {
+    std::unique_ptr<Instruction> inst(new Instruction());
+    created_instructions.push_back(inst.get());
+    new_instructions.push_back(std::move(inst));
+  }
+  auto new_element = list.begin().InsertBefore(std::move(new_instructions));
+  EXPECT_TRUE(new_instructions.empty());  // the vector was consumed
+  EXPECT_EQ(&*new_element, created_instructions.front());
+
+  std::vector<Instruction*> output;
+  for (auto& i : list) {
+    output.push_back(&i);
+  }
+  EXPECT_THAT(output, ContainerEq(created_instructions));
+}
+
+}  // namespace
+}  // namespace opt
+}  // namespace spvtools
diff --git a/test/opt/instruction_test.cpp b/test/opt/instruction_test.cpp
new file mode 100644
index 000000000..a6972011f
--- /dev/null
+++ b/test/opt/instruction_test.cpp
@@ -0,0 +1,1135 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "source/opt/instruction.h"
+#include "source/opt/ir_context.h"
+#include "spirv-tools/libspirv.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+#include "test/unit_spirv.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using spvtest::MakeInstruction;
+using ::testing::Eq;
+using DescriptorTypeTest = PassTest<::testing::Test>;
+using OpaqueTypeTest = PassTest<::testing::Test>;
+using GetBaseTest = PassTest<::testing::Test>;
+using ValidBasePointerTest = PassTest<::testing::Test>;
+
+TEST(InstructionTest, CreateTrivial) {
+  Instruction empty;
+  EXPECT_EQ(SpvOpNop, empty.opcode());
+  EXPECT_EQ(0u, empty.type_id());
+  EXPECT_EQ(0u, empty.result_id());
+  EXPECT_EQ(0u, empty.NumOperands());
+  EXPECT_EQ(0u, empty.NumOperandWords());
+  EXPECT_EQ(0u, empty.NumInOperandWords());
+  EXPECT_EQ(empty.cend(), empty.cbegin());
+  EXPECT_EQ(empty.end(), empty.begin());
+}
+
+TEST(InstructionTest, CreateWithOpcodeAndNoOperands) {
+  IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr);
+  Instruction inst(&context, SpvOpReturn);
+  EXPECT_EQ(SpvOpReturn, inst.opcode());
+  EXPECT_EQ(0u, inst.type_id());
+  EXPECT_EQ(0u, inst.result_id());
+  EXPECT_EQ(0u, inst.NumOperands());
+  EXPECT_EQ(0u, inst.NumOperandWords());
+  EXPECT_EQ(0u, inst.NumInOperandWords());
+  EXPECT_EQ(inst.cend(), inst.cbegin());
+  EXPECT_EQ(inst.end(), inst.begin());
+}
+
+// The words for an OpTypeInt for 32-bit signed integer resulting in Id 44.
+uint32_t kSampleInstructionWords[] = {(4 << 16) | uint32_t(SpvOpTypeInt), 44,
+                                      32, 1};
+// The operands that would be parsed from kSampleInstructionWords
+spv_parsed_operand_t kSampleParsedOperands[] = {
+    {1, 1, SPV_OPERAND_TYPE_RESULT_ID, SPV_NUMBER_NONE, 0},
+    {2, 1, SPV_OPERAND_TYPE_LITERAL_INTEGER, SPV_NUMBER_UNSIGNED_INT, 32},
+    {3, 1, SPV_OPERAND_TYPE_LITERAL_INTEGER, SPV_NUMBER_UNSIGNED_INT, 1},
+};
+
+// A valid parse of kSampleInstructionWords.
+spv_parsed_instruction_t kSampleParsedInstruction = {kSampleInstructionWords,
+                                                     uint16_t(4),
+                                                     uint16_t(SpvOpTypeInt),
+                                                     SPV_EXT_INST_TYPE_NONE,
+                                                     0,   // type id
+                                                     44,  // result id
+                                                     kSampleParsedOperands,
+                                                     3};
+
+// The words for an OpAccessChain instruction.
+uint32_t kSampleAccessChainInstructionWords[] = {
+    (7 << 16) | uint32_t(SpvOpAccessChain), 100, 101, 102, 103, 104, 105};
+
+// The operands that would be parsed from kSampleAccessChainInstructionWords.
+spv_parsed_operand_t kSampleAccessChainOperands[] = {
+    {1, 1, SPV_OPERAND_TYPE_TYPE_ID, SPV_NUMBER_NONE, 0},    // word 100
+    {2, 1, SPV_OPERAND_TYPE_RESULT_ID, SPV_NUMBER_NONE, 0},  // word 101
+    {3, 1, SPV_OPERAND_TYPE_ID, SPV_NUMBER_NONE, 0},
+    {4, 1, SPV_OPERAND_TYPE_ID, SPV_NUMBER_NONE, 0},
+    {5, 1, SPV_OPERAND_TYPE_ID, SPV_NUMBER_NONE, 0},
+    {6, 1, SPV_OPERAND_TYPE_ID, SPV_NUMBER_NONE, 0},
+};
+
+// A valid parse of kSampleAccessChainInstructionWords
+spv_parsed_instruction_t kSampleAccessChainInstruction = {
+    kSampleAccessChainInstructionWords,
+    uint16_t(7),
+    uint16_t(SpvOpAccessChain),
+    SPV_EXT_INST_TYPE_NONE,
+    100,  // type id
+    101,  // result id
+    kSampleAccessChainOperands,
+    6};
+
+// The words for an OpControlBarrier instruction.
+uint32_t kSampleControlBarrierInstructionWords[] = {
+    (4 << 16) | uint32_t(SpvOpControlBarrier), 100, 101, 102};
+
+// The operands that would be parsed from kSampleControlBarrierInstructionWords.
+spv_parsed_operand_t kSampleControlBarrierOperands[] = { + {1, 1, SPV_OPERAND_TYPE_SCOPE_ID, SPV_NUMBER_NONE, 0}, // Execution + {2, 1, SPV_OPERAND_TYPE_SCOPE_ID, SPV_NUMBER_NONE, 0}, // Memory + {3, 1, SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, SPV_NUMBER_NONE, + 0}, // Semantics +}; + +// A valid parse of kSampleControlBarrierInstructionWords +spv_parsed_instruction_t kSampleControlBarrierInstruction = { + kSampleControlBarrierInstructionWords, + uint16_t(4), + uint16_t(SpvOpControlBarrier), + SPV_EXT_INST_TYPE_NONE, + 0, // type id + 0, // result id + kSampleControlBarrierOperands, + 3}; + +TEST(InstructionTest, CreateWithOpcodeAndOperands) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&context, kSampleParsedInstruction); + EXPECT_EQ(SpvOpTypeInt, inst.opcode()); + EXPECT_EQ(0u, inst.type_id()); + EXPECT_EQ(44u, inst.result_id()); + EXPECT_EQ(3u, inst.NumOperands()); + EXPECT_EQ(3u, inst.NumOperandWords()); + EXPECT_EQ(2u, inst.NumInOperandWords()); +} + +TEST(InstructionTest, GetOperand) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&context, kSampleParsedInstruction); + EXPECT_THAT(inst.GetOperand(0).words, Eq(std::vector{44})); + EXPECT_THAT(inst.GetOperand(1).words, Eq(std::vector{32})); + EXPECT_THAT(inst.GetOperand(2).words, Eq(std::vector{1})); +} + +TEST(InstructionTest, GetInOperand) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&context, kSampleParsedInstruction); + EXPECT_THAT(inst.GetInOperand(0).words, Eq(std::vector{32})); + EXPECT_THAT(inst.GetInOperand(1).words, Eq(std::vector{1})); +} + +TEST(InstructionTest, OperandConstIterators) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&context, kSampleParsedInstruction); + // Spot check iteration across operands. 
+ auto cbegin = inst.cbegin(); + auto cend = inst.cend(); + EXPECT_NE(cend, inst.cbegin()); + + auto citer = inst.cbegin(); + for (int i = 0; i < 3; ++i, ++citer) { + const auto& operand = *citer; + EXPECT_THAT(operand.type, Eq(kSampleParsedOperands[i].type)); + EXPECT_THAT(operand.words, + Eq(std::vector{kSampleInstructionWords[i + 1]})); + EXPECT_NE(cend, citer); + } + EXPECT_EQ(cend, citer); + + // Check that cbegin and cend have not changed. + EXPECT_EQ(cbegin, inst.cbegin()); + EXPECT_EQ(cend, inst.cend()); + + // Check arithmetic. + const Operand& operand2 = *(inst.cbegin() + 2); + EXPECT_EQ(SPV_OPERAND_TYPE_LITERAL_INTEGER, operand2.type); +} + +TEST(InstructionTest, OperandIterators) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&context, kSampleParsedInstruction); + // Spot check iteration across operands, with mutable iterators. + auto begin = inst.begin(); + auto end = inst.end(); + EXPECT_NE(end, inst.begin()); + + auto iter = inst.begin(); + for (int i = 0; i < 3; ++i, ++iter) { + const auto& operand = *iter; + EXPECT_THAT(operand.type, Eq(kSampleParsedOperands[i].type)); + EXPECT_THAT(operand.words, + Eq(std::vector{kSampleInstructionWords[i + 1]})); + EXPECT_NE(end, iter); + } + EXPECT_EQ(end, iter); + + // Check that begin and end have not changed. + EXPECT_EQ(begin, inst.begin()); + EXPECT_EQ(end, inst.end()); + + // Check arithmetic. + Operand& operand2 = *(inst.begin() + 2); + EXPECT_EQ(SPV_OPERAND_TYPE_LITERAL_INTEGER, operand2.type); + + // Check mutation through an iterator. 
+ operand2.type = SPV_OPERAND_TYPE_TYPE_ID; + EXPECT_EQ(SPV_OPERAND_TYPE_TYPE_ID, (*(inst.cbegin() + 2)).type); +} + +TEST(InstructionTest, ForInIdStandardIdTypes) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&context, kSampleAccessChainInstruction); + + std::vector ids; + inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); }); + EXPECT_THAT(ids, Eq(std::vector{102, 103, 104, 105})); + + ids.clear(); + inst.ForEachInId([&ids](uint32_t* idptr) { ids.push_back(*idptr); }); + EXPECT_THAT(ids, Eq(std::vector{102, 103, 104, 105})); +} + +TEST(InstructionTest, ForInIdNonstandardIdTypes) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&context, kSampleControlBarrierInstruction); + + std::vector ids; + inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); }); + EXPECT_THAT(ids, Eq(std::vector{100, 101, 102})); + + ids.clear(); + inst.ForEachInId([&ids](uint32_t* idptr) { ids.push_back(*idptr); }); + EXPECT_THAT(ids, Eq(std::vector{100, 101, 102})); +} + +TEST(InstructionTest, UniqueIds) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst1(&context); + Instruction inst2(&context); + EXPECT_NE(inst1.unique_id(), inst2.unique_id()); +} + +TEST(InstructionTest, CloneUniqueIdDifferent) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&context); + std::unique_ptr clone(inst.Clone(&context)); + EXPECT_EQ(inst.context(), clone->context()); + EXPECT_NE(inst.unique_id(), clone->unique_id()); +} + +TEST(InstructionTest, CloneDifferentContext) { + IRContext c1(SPV_ENV_UNIVERSAL_1_2, nullptr); + IRContext c2(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&c1); + std::unique_ptr clone(inst.Clone(&c2)); + EXPECT_EQ(&c1, inst.context()); + EXPECT_EQ(&c2, clone->context()); + EXPECT_NE(&c1, &c2); +} + +TEST(InstructionTest, CloneDifferentContextDifferentUniqueId) { + IRContext c1(SPV_ENV_UNIVERSAL_1_2, nullptr); + IRContext 
c2(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction inst(&c1); + Instruction other(&c2); + std::unique_ptr clone(inst.Clone(&c2)); + EXPECT_EQ(&c2, clone->context()); + EXPECT_NE(other.unique_id(), clone->unique_id()); +} + +TEST(InstructionTest, EqualsEqualsOperator) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction i1(&context); + Instruction i2(&context); + std::unique_ptr clone(i1.Clone(&context)); + EXPECT_TRUE(i1 == i1); + EXPECT_FALSE(i1 == i2); + EXPECT_FALSE(i1 == *clone); + EXPECT_FALSE(i2 == *clone); +} + +TEST(InstructionTest, LessThanOperator) { + IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr); + Instruction i1(&context); + Instruction i2(&context); + std::unique_ptr clone(i1.Clone(&context)); + EXPECT_TRUE(i1 < i2); + EXPECT_TRUE(i1 < *clone); + EXPECT_TRUE(i2 < *clone); +} + +TEST_F(DescriptorTypeTest, StorageImage) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "myStorageImage" + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeFloat 32 + %7 = OpTypeImage %6 2D 0 0 0 2 R32f + %8 = OpTypePointer UniformConstant %7 + %3 = OpVariable %8 UniformConstant + %2 = OpFunction %4 None %5 + %9 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* type = context->get_def_use_mgr()->GetDef(8); + EXPECT_TRUE(type->IsVulkanStorageImage()); + EXPECT_FALSE(type->IsVulkanSampledImage()); + EXPECT_FALSE(type->IsVulkanStorageTexelBuffer()); + EXPECT_FALSE(type->IsVulkanStorageBuffer()); + EXPECT_FALSE(type->IsVulkanUniformBuffer()); + + Instruction* variable = context->get_def_use_mgr()->GetDef(3); + EXPECT_FALSE(variable->IsReadOnlyVariable()); +} + +TEST_F(DescriptorTypeTest, SampledImage) { + const std::string text = R"( + 
OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "myStorageImage" + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeFloat 32 + %7 = OpTypeImage %6 2D 0 0 0 1 Unknown + %8 = OpTypePointer UniformConstant %7 + %3 = OpVariable %8 UniformConstant + %2 = OpFunction %4 None %5 + %9 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* type = context->get_def_use_mgr()->GetDef(8); + EXPECT_FALSE(type->IsVulkanStorageImage()); + EXPECT_TRUE(type->IsVulkanSampledImage()); + EXPECT_FALSE(type->IsVulkanStorageTexelBuffer()); + EXPECT_FALSE(type->IsVulkanStorageBuffer()); + EXPECT_FALSE(type->IsVulkanUniformBuffer()); + + Instruction* variable = context->get_def_use_mgr()->GetDef(3); + EXPECT_TRUE(variable->IsReadOnlyVariable()); +} + +TEST_F(DescriptorTypeTest, StorageTexelBuffer) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "myStorageImage" + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeFloat 32 + %7 = OpTypeImage %6 Buffer 0 0 0 2 R32f + %8 = OpTypePointer UniformConstant %7 + %3 = OpVariable %8 UniformConstant + %2 = OpFunction %4 None %5 + %9 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* type = context->get_def_use_mgr()->GetDef(8); + EXPECT_FALSE(type->IsVulkanStorageImage()); + EXPECT_FALSE(type->IsVulkanSampledImage()); + EXPECT_TRUE(type->IsVulkanStorageTexelBuffer()); + EXPECT_FALSE(type->IsVulkanStorageBuffer()); + 
EXPECT_FALSE(type->IsVulkanUniformBuffer()); + + Instruction* variable = context->get_def_use_mgr()->GetDef(3); + EXPECT_FALSE(variable->IsReadOnlyVariable()); +} + +TEST_F(DescriptorTypeTest, StorageBuffer) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "myStorageImage" + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + OpDecorate %9 BufferBlock + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypeRuntimeArray %7 + %9 = OpTypeStruct %8 + %10 = OpTypePointer Uniform %9 + %3 = OpVariable %10 Uniform + %2 = OpFunction %4 None %5 + %11 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* type = context->get_def_use_mgr()->GetDef(10); + EXPECT_FALSE(type->IsVulkanStorageImage()); + EXPECT_FALSE(type->IsVulkanSampledImage()); + EXPECT_FALSE(type->IsVulkanStorageTexelBuffer()); + EXPECT_TRUE(type->IsVulkanStorageBuffer()); + EXPECT_FALSE(type->IsVulkanUniformBuffer()); + + Instruction* variable = context->get_def_use_mgr()->GetDef(3); + EXPECT_FALSE(variable->IsReadOnlyVariable()); +} + +TEST_F(DescriptorTypeTest, UniformBuffer) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "myStorageImage" + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + OpDecorate %9 Block + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypeRuntimeArray %7 + %9 = OpTypeStruct %8 + %10 = OpTypePointer Uniform %9 + %3 = OpVariable %10 Uniform + %2 = OpFunction %4 None %5 + %11 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr 
context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* type = context->get_def_use_mgr()->GetDef(10); + EXPECT_FALSE(type->IsVulkanStorageImage()); + EXPECT_FALSE(type->IsVulkanSampledImage()); + EXPECT_FALSE(type->IsVulkanStorageTexelBuffer()); + EXPECT_FALSE(type->IsVulkanStorageBuffer()); + EXPECT_TRUE(type->IsVulkanUniformBuffer()); + + Instruction* variable = context->get_def_use_mgr()->GetDef(3); + EXPECT_TRUE(variable->IsReadOnlyVariable()); +} + +TEST_F(DescriptorTypeTest, NonWritableIsReadOnly) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "myStorageImage" + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + OpDecorate %9 BufferBlock + OpDecorate %3 NonWritable + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypeRuntimeArray %7 + %9 = OpTypeStruct %8 + %10 = OpTypePointer Uniform %9 + %3 = OpVariable %10 Uniform + %2 = OpFunction %4 None %5 + %11 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* variable = context->get_def_use_mgr()->GetDef(3); + EXPECT_TRUE(variable->IsReadOnlyVariable()); +} + +TEST_F(OpaqueTypeTest, BaseOpaqueTypesShader) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypeImage %5 2D 1 0 0 1 Unknown + %7 = OpTypeSampler + %8 = OpTypeSampledImage %6 + %9 = OpTypeRuntimeArray %5 + %2 = OpFunction %3 None %4 + %10 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* 
image_type = context->get_def_use_mgr()->GetDef(6); + EXPECT_TRUE(image_type->IsOpaqueType()); + Instruction* sampler_type = context->get_def_use_mgr()->GetDef(7); + EXPECT_TRUE(sampler_type->IsOpaqueType()); + Instruction* sampled_image_type = context->get_def_use_mgr()->GetDef(8); + EXPECT_TRUE(sampled_image_type->IsOpaqueType()); + Instruction* runtime_array_type = context->get_def_use_mgr()->GetDef(9); + EXPECT_TRUE(runtime_array_type->IsOpaqueType()); + Instruction* float_type = context->get_def_use_mgr()->GetDef(5); + EXPECT_FALSE(float_type->IsOpaqueType()); + Instruction* void_type = context->get_def_use_mgr()->GetDef(3); + EXPECT_FALSE(void_type->IsOpaqueType()); +} + +TEST_F(OpaqueTypeTest, OpaqueStructTypes) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypeRuntimeArray %5 + %7 = OpTypeStruct %6 %6 + %8 = OpTypeStruct %5 %6 + %9 = OpTypeStruct %6 %5 + %10 = OpTypeStruct %7 + %2 = OpFunction %3 None %4 + %11 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + for (int i = 7; i <= 10; i++) { + Instruction* type = context->get_def_use_mgr()->GetDef(i); + EXPECT_TRUE(type->IsOpaqueType()); + } +} + +TEST_F(GetBaseTest, SampleImage) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "myStorageImage" + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 2 + %8 = OpTypeVector %6 4 + %9 = OpConstant %6 0 + %10 = OpConstantComposite %7 %9 %9 + %11 = OpTypeImage %6 2D 0 0 0 1 R32f + %12 
= OpTypePointer UniformConstant %11 + %3 = OpVariable %12 UniformConstant + %13 = OpTypeSampledImage %11 + %14 = OpTypeSampler + %15 = OpTypePointer UniformConstant %14 + %16 = OpVariable %15 UniformConstant + %2 = OpFunction %4 None %5 + %17 = OpLabel + %18 = OpLoad %11 %3 + %19 = OpLoad %14 %16 + %20 = OpSampledImage %13 %18 %19 + %21 = OpImageSampleImplicitLod %8 %20 %10 + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* load = context->get_def_use_mgr()->GetDef(21); + Instruction* base = context->get_def_use_mgr()->GetDef(20); + EXPECT_TRUE(load->GetBaseAddress() == base); +} + +TEST_F(GetBaseTest, PtrAccessChain) { + const std::string text = R"( + OpCapability VariablePointers + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "PSMain" %2 + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 8388353 + %int_0 = OpConstant %int 0 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %2 = OpVariable %_ptr_Function_v4float Input + %1 = OpFunction %void None %4 + %10 = OpLabel + %11 = OpPtrAccessChain %_ptr_Function_v4float %2 %int_0 + %12 = OpLoad %v4float %11 + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* load = context->get_def_use_mgr()->GetDef(12); + Instruction* base = context->get_def_use_mgr()->GetDef(2); + EXPECT_TRUE(load->GetBaseAddress() == base); +} + +TEST_F(GetBaseTest, ImageRead) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "myStorageImage" + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeInt 32 0 + %7 = OpTypeVector %6 
2 + %8 = OpConstant %6 0 + %9 = OpConstantComposite %7 %8 %8 + %10 = OpTypeImage %6 2D 0 0 0 2 R32f + %11 = OpTypePointer UniformConstant %10 + %3 = OpVariable %11 UniformConstant + %2 = OpFunction %4 None %5 + %12 = OpLabel + %13 = OpLoad %10 %3 + %14 = OpImageRead %6 %13 %9 + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + Instruction* load = context->get_def_use_mgr()->GetDef(14); + Instruction* base = context->get_def_use_mgr()->GetDef(13); + EXPECT_TRUE(load->GetBaseAddress() == base); +} + +TEST_F(ValidBasePointerTest, OpSelectBadNoVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpVariable %4 StorageBuffer +%6 = OpTypeFunction %2 +%7 = OpTypeBool +%8 = OpConstantTrue %7 +%1 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpSelect %4 %8 %5 %5 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* select = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(select, nullptr); + EXPECT_FALSE(select->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpSelectBadNoVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpVariable %4 Workgroup +%6 = OpTypeFunction %2 +%7 = OpTypeBool +%8 = OpConstantTrue %7 +%1 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpSelect %4 %8 %5 %5 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* select = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(select, nullptr); + 
EXPECT_FALSE(select->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpSelectGoodVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpVariable %4 StorageBuffer +%6 = OpTypeFunction %2 +%7 = OpTypeBool +%8 = OpConstantTrue %7 +%1 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpSelect %4 %8 %5 %5 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* select = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(select, nullptr); + EXPECT_TRUE(select->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpSelectGoodVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpVariable %4 Workgroup +%6 = OpTypeFunction %2 +%7 = OpTypeBool +%8 = OpConstantTrue %7 +%1 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpSelect %4 %8 %5 %5 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* select = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(select, nullptr); + EXPECT_TRUE(select->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpConstantNullBadNoVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + 
BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(5); + EXPECT_NE(null_inst, nullptr); + EXPECT_FALSE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpConstantNullBadNoVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(5); + EXPECT_NE(null_inst, nullptr); + EXPECT_FALSE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpConstantNullGoodVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(5); + EXPECT_NE(null_inst, nullptr); + EXPECT_TRUE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpConstantNullGoodVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel 
+OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(5); + EXPECT_NE(null_inst, nullptr); + EXPECT_TRUE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpPhiBadNoVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpVariable %4 StorageBuffer +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %4 %5 %7 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* phi = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(phi, nullptr); + EXPECT_FALSE(phi->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpPhiBadNoVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpVariable %4 Workgroup +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %4 %5 %7 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* phi = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(phi, nullptr); + EXPECT_FALSE(phi->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpPhiGoodVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer 
StorageBuffer %3 +%5 = OpVariable %4 StorageBuffer +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %4 %5 %7 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* phi = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(phi, nullptr); + EXPECT_TRUE(phi->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpPhiGoodVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpVariable %4 Workgroup +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %4 %5 %7 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* phi = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(phi, nullptr); + EXPECT_TRUE(phi->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpFunctionCallBadNoVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%7 = OpTypeFunction %4 +%1 = OpFunction %2 None %6 +%8 = OpLabel +%9 = OpFunctionCall %4 %10 +OpReturn +OpFunctionEnd +%10 = OpFunction %4 None %7 +%11 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(null_inst, nullptr); + EXPECT_FALSE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpFunctionCallBadNoVariablePointers) { 
+ const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%7 = OpTypeFunction %4 +%1 = OpFunction %2 None %6 +%8 = OpLabel +%9 = OpFunctionCall %4 %10 +OpReturn +OpFunctionEnd +%10 = OpFunction %4 None %7 +%11 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(null_inst, nullptr); + EXPECT_FALSE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpFunctionCallGoodVariablePointersStorageBuffer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointersStorageBuffer +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer StorageBuffer %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%7 = OpTypeFunction %4 +%1 = OpFunction %2 None %6 +%8 = OpLabel +%9 = OpFunctionCall %4 %10 +OpReturn +OpFunctionEnd +%10 = OpFunction %4 None %7 +%11 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(null_inst, nullptr); + EXPECT_TRUE(null_inst->IsValidBasePointer()); +} + +TEST_F(ValidBasePointerTest, OpFunctionCallGoodVariablePointers) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Workgroup %3 +%5 = OpConstantNull %4 +%6 = OpTypeFunction %2 +%7 = OpTypeFunction %4 +%1 = OpFunction %2 None %6 +%8 = OpLabel +%9 = OpFunctionCall 
%4 %10 +OpReturn +OpFunctionEnd +%10 = OpFunction %4 None %7 +%11 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text); + EXPECT_NE(context, nullptr); + Instruction* null_inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_NE(null_inst, nullptr); + EXPECT_TRUE(null_inst->IsValidBasePointer()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/ir_builder.cpp b/test/opt/ir_builder.cpp new file mode 100644 index 000000000..f800ca437 --- /dev/null +++ b/test/opt/ir_builder.cpp @@ -0,0 +1,439 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/basic_block.h" +#include "source/opt/build_module.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/type_manager.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { +namespace { + +using Analysis = IRContext::Analysis; +using IRBuilderTest = ::testing::Test; + +bool Validate(const std::vector& bin) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {bin.data(), bin.size()}; + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + return error == 0; +} + +void Match(const std::string& original, IRContext* context, + bool do_validation = true) { + std::vector bin; + context->module()->ToBinary(&bin, true); + if (do_validation) { + EXPECT_TRUE(Validate(bin)); + } + std::string assembly; + SpirvTools tools(SPV_ENV_UNIVERSAL_1_2); + EXPECT_TRUE( + tools.Disassemble(bin, &assembly, SpirvTools::kDefaultDisassembleOption)) + << "Disassembling failed for shader:\n" + << assembly << std::endl; + auto match_result = effcee::Match(assembly, original); + EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) + << match_result.message() << "\nChecking result:\n" + << assembly; +} + +TEST_F(IRBuilderTest, TestInsnAddition) { + const std::string text = R"( +; CHECK: %18 = OpLabel +; CHECK: OpPhi %int %int_0 %14 +; CHECK: OpPhi %bool %16 %14 +; CHECK: OpBranch %17 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %3 "c" + 
OpDecorate %3 Location 0 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %10 = OpTypeBool + %11 = OpTypeFloat 32 + %12 = OpTypeVector %11 4 + %13 = OpTypePointer Output %12 + %3 = OpVariable %13 Output + %2 = OpFunction %5 None %6 + %14 = OpLabel + %4 = OpVariable %8 Function + OpStore %4 %9 + %15 = OpLoad %7 %4 + %16 = OpINotEqual %10 %15 %9 + OpSelectionMerge %17 None + OpBranchConditional %16 %18 %17 + %18 = OpLabel + OpBranch %17 + %17 = OpLabel + OpReturn + OpFunctionEnd +)"; + + { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + BasicBlock* bb = context->cfg()->block(18); + + // Build managers. + context->get_def_use_mgr(); + context->get_instr_block(nullptr); + + InstructionBuilder builder(context.get(), &*bb->begin()); + Instruction* phi1 = builder.AddPhi(7, {9, 14}); + Instruction* phi2 = builder.AddPhi(10, {16, 14}); + + // Make sure the InstructionBuilder did not update the def/use manager. + EXPECT_EQ(context->get_def_use_mgr()->GetDef(phi1->result_id()), nullptr); + EXPECT_EQ(context->get_def_use_mgr()->GetDef(phi2->result_id()), nullptr); + EXPECT_EQ(context->get_instr_block(phi1), nullptr); + EXPECT_EQ(context->get_instr_block(phi2), nullptr); + + Match(text, context.get()); + } + + { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + // Build managers. 
+ context->get_def_use_mgr(); + context->get_instr_block(nullptr); + + BasicBlock* bb = context->cfg()->block(18); + InstructionBuilder builder( + context.get(), &*bb->begin(), + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + Instruction* phi1 = builder.AddPhi(7, {9, 14}); + Instruction* phi2 = builder.AddPhi(10, {16, 14}); + + // Make sure InstructionBuilder updated the def/use manager + EXPECT_NE(context->get_def_use_mgr()->GetDef(phi1->result_id()), nullptr); + EXPECT_NE(context->get_def_use_mgr()->GetDef(phi2->result_id()), nullptr); + EXPECT_NE(context->get_instr_block(phi1), nullptr); + EXPECT_NE(context->get_instr_block(phi2), nullptr); + + Match(text, context.get()); + } +} + +TEST_F(IRBuilderTest, TestCondBranchAddition) { + const std::string text = R"( +; CHECK: %main = OpFunction %void None %6 +; CHECK-NEXT: %15 = OpLabel +; CHECK-NEXT: OpSelectionMerge %13 None +; CHECK-NEXT: OpBranchConditional %true %14 %13 +; CHECK-NEXT: %14 = OpLabel +; CHECK-NEXT: OpBranch %13 +; CHECK-NEXT: %13 = OpLabel +; CHECK-NEXT: OpReturn + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %3 "c" + OpDecorate %3 Location 0 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeBool + %8 = OpTypePointer Private %7 + %9 = OpConstantTrue %7 + %10 = OpTypeFloat 32 + %11 = OpTypeVector %10 4 + %12 = OpTypePointer Output %11 + %3 = OpVariable %12 Output + %4 = OpVariable %8 Private + %2 = OpFunction %5 None %6 + %13 = OpLabel + OpReturn + OpFunctionEnd +)"; + + { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + Function& fn = *context->module()->begin(); + + BasicBlock& bb_merge = *fn.begin(); + + // TODO(1841): Handle id overflow. 
+ fn.begin().InsertBefore(std::unique_ptr( + new BasicBlock(std::unique_ptr(new Instruction( + context.get(), SpvOpLabel, 0, context->TakeNextId(), {}))))); + BasicBlock& bb_true = *fn.begin(); + { + InstructionBuilder builder(context.get(), &*bb_true.begin()); + builder.AddBranch(bb_merge.id()); + } + + // TODO(1841): Handle id overflow. + fn.begin().InsertBefore(std::unique_ptr( + new BasicBlock(std::unique_ptr(new Instruction( + context.get(), SpvOpLabel, 0, context->TakeNextId(), {}))))); + BasicBlock& bb_cond = *fn.begin(); + + InstructionBuilder builder(context.get(), &bb_cond); + // This also tests consecutive instruction insertion: merge selection + + // branch. + builder.AddConditionalBranch(9, bb_true.id(), bb_merge.id(), bb_merge.id()); + + Match(text, context.get()); + } +} + +TEST_F(IRBuilderTest, AddSelect) { + const std::string text = R"( +; CHECK: [[bool:%\w+]] = OpTypeBool +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[true:%\w+]] = OpConstantTrue [[bool]] +; CHECK: [[u0:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[u1:%\w+]] = OpConstant [[uint]] 1 +; CHECK: OpSelect [[uint]] [[true]] [[u0]] [[u1]] +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%1 = OpTypeVoid +%2 = OpTypeBool +%3 = OpTypeInt 32 0 +%4 = OpConstantTrue %2 +%5 = OpConstant %3 0 +%6 = OpConstant %3 1 +%7 = OpTypeFunction %1 +%8 = OpFunction %1 None %7 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + EXPECT_NE(nullptr, context); + + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); + EXPECT_NE(nullptr, builder.AddSelect(3u, 4u, 5u, 6u)); + + Match(text, context.get()); +} + +TEST_F(IRBuilderTest, AddCompositeConstruct) { + const std::string text = R"( +; CHECK: [[uint:%\w+]] = OpTypeInt +; CHECK: [[u0:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[u1:%\w+]] = OpConstant [[uint]] 1 +; CHECK: [[struct:%\w+]] = OpTypeStruct [[uint]] [[uint]] 
[[uint]] [[uint]] +; CHECK: OpCompositeConstruct [[struct]] [[u0]] [[u1]] [[u1]] [[u0]] +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpConstant %2 1 +%5 = OpTypeStruct %2 %2 %2 %2 +%6 = OpTypeFunction %1 +%7 = OpFunction %1 None %6 +%8 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + EXPECT_NE(nullptr, context); + + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); + std::vector ids = {3u, 4u, 4u, 3u}; + EXPECT_NE(nullptr, builder.AddCompositeConstruct(5u, ids)); + + Match(text, context.get()); +} + +TEST_F(IRBuilderTest, ConstantAdder) { + const std::string text = R"( +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: OpConstant [[uint]] 13 +; CHECK: [[sint:%\w+]] = OpTypeInt 32 1 +; CHECK: OpConstant [[sint]] -1 +; CHECK: OpConstant [[uint]] 1 +; CHECK: OpConstant [[sint]] 34 +; CHECK: OpConstant [[uint]] 0 +; CHECK: OpConstant [[sint]] 0 +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + EXPECT_NE(nullptr, context); + + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); + EXPECT_NE(nullptr, builder.GetUintConstant(13)); + EXPECT_NE(nullptr, builder.GetSintConstant(-1)); + + // Try adding the same constants again to make sure they aren't added. + EXPECT_NE(nullptr, builder.GetUintConstant(13)); + EXPECT_NE(nullptr, builder.GetSintConstant(-1)); + + // Try adding different constants to make sure the type is reused. + EXPECT_NE(nullptr, builder.GetUintConstant(1)); + EXPECT_NE(nullptr, builder.GetSintConstant(34)); + + // Try adding 0 as both signed and unsigned. 
+ EXPECT_NE(nullptr, builder.GetUintConstant(0)); + EXPECT_NE(nullptr, builder.GetSintConstant(0)); + + Match(text, context.get()); +} + +TEST_F(IRBuilderTest, ConstantAdderTypeAlreadyExists) { + const std::string text = R"( +; CHECK: OpConstant %uint 13 +; CHECK: OpConstant %int -1 +; CHECK: OpConstant %uint 1 +; CHECK: OpConstant %int 34 +; CHECK: OpConstant %uint 0 +; CHECK: OpConstant %int 0 +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%uint = OpTypeInt 32 0 +%int = OpTypeInt 32 1 +%4 = OpTypeFunction %1 +%5 = OpFunction %1 None %4 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + EXPECT_NE(nullptr, context); + + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); + Instruction* const_1 = builder.GetUintConstant(13); + Instruction* const_2 = builder.GetSintConstant(-1); + + EXPECT_NE(nullptr, const_1); + EXPECT_NE(nullptr, const_2); + + // Try adding the same constants again to make sure they aren't added. + EXPECT_EQ(const_1, builder.GetUintConstant(13)); + EXPECT_EQ(const_2, builder.GetSintConstant(-1)); + + Instruction* const_3 = builder.GetUintConstant(1); + Instruction* const_4 = builder.GetSintConstant(34); + + // Try adding different constants to make sure the type is reused. + EXPECT_NE(nullptr, const_3); + EXPECT_NE(nullptr, const_4); + + Instruction* const_5 = builder.GetUintConstant(0); + Instruction* const_6 = builder.GetSintConstant(0); + + // Try adding 0 as both signed and unsigned. + EXPECT_NE(nullptr, const_5); + EXPECT_NE(nullptr, const_6); + + // They have the same value but different types so should be unique. + EXPECT_NE(const_5, const_6); + + // Check the types are correct. 
+ uint32_t type_id_unsigned = const_1->GetSingleWordOperand(0); + uint32_t type_id_signed = const_2->GetSingleWordOperand(0); + + EXPECT_NE(type_id_unsigned, type_id_signed); + + EXPECT_EQ(const_3->GetSingleWordOperand(0), type_id_unsigned); + EXPECT_EQ(const_5->GetSingleWordOperand(0), type_id_unsigned); + + EXPECT_EQ(const_4->GetSingleWordOperand(0), type_id_signed); + EXPECT_EQ(const_6->GetSingleWordOperand(0), type_id_signed); + + Match(text, context.get()); +} + +TEST_F(IRBuilderTest, AccelerationStructureNV) { + const std::string text = R"( +; CHECK: OpTypeAccelerationStructureNV +OpCapability Shader +OpCapability RayTracingNV +OpExtension "SPV_NV_ray_tracing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %8 "main" +OpExecutionMode %8 OriginUpperLeft +%1 = OpTypeVoid +%2 = OpTypeBool +%3 = OpTypeAccelerationStructureNV +%7 = OpTypeFunction %1 +%8 = OpFunction %1 None %7 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + EXPECT_NE(nullptr, context); + + InstructionBuilder builder(context.get(), + &*context->module()->begin()->begin()->begin()); + Match(text, context.get()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/ir_context_test.cpp b/test/opt/ir_context_test.cpp new file mode 100644 index 000000000..4e2f5b2c4 --- /dev/null +++ b/test/opt/ir_context_test.cpp @@ -0,0 +1,669 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using Analysis = IRContext::Analysis; +using ::testing::Each; +using ::testing::UnorderedElementsAre; + +// Test pass that does no work: Process() returns the status supplied at +// construction. It does not override GetPreservedAnalyses(). +class DummyPassPreservesNothing : public Pass { + public: + DummyPassPreservesNothing(Status s) : Pass(), status_to_return_(s) {} + + const char* name() const override { return "dummy-pass"; } + Status Process() override { return status_to_return_; } + + private: + Status status_to_return_; +}; + +// Test pass that returns the supplied status from Process() and reports +// Analysis(IRContext::kAnalysisEnd - 1) as preserved, i.e. every bit below +// kAnalysisEnd (presumably the full set of analyses -- confirm against +// IRContext::Analysis). +class DummyPassPreservesAll : public Pass { + public: + DummyPassPreservesAll(Status s) : Pass(), status_to_return_(s) {} + + const char* name() const override { return "dummy-pass"; } + Status Process() override { return status_to_return_; } + + Analysis GetPreservedAnalyses() override { + return Analysis(IRContext::kAnalysisEnd - 1); + } + + private: + Status status_to_return_; +}; + +// Test pass that returns the supplied status from Process() and reports only +// IRContext::kAnalysisBegin as preserved. +class DummyPassPreservesFirst : public Pass { + public: + DummyPassPreservesFirst(Status s) : Pass(), status_to_return_(s) {} + + const char* name() const override { return "dummy-pass"; } + Status Process() override { return status_to_return_; } + + Analysis GetPreservedAnalyses() override { return IRContext::kAnalysisBegin; } + + private: + Status status_to_return_; +}; + +using IRContextTest = PassTest<::testing::Test>; + +// Builds each analysis bit in [kAnalysisBegin, kAnalysisEnd) one at a time on +// a context wrapping a default-constructed Module, and checks that the bit is +// reported valid right after it is built. +TEST_F(IRContextTest, IndividualValidAfterBuild) { + std::unique_ptr module(new Module()); + IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), + spvtools::MessageConsumer()); + + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + localContext.BuildInvalidAnalyses(i); + EXPECT_TRUE(localContext.AreAnalysesValid(i)); + } +} + 
+TEST_F(IRContextTest, AllValidAfterBuild) { + std::unique_ptr module = MakeUnique(); + IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), + spvtools::MessageConsumer()); + + Analysis built_analyses = IRContext::kAnalysisNone; + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + localContext.BuildInvalidAnalyses(i); + built_analyses |= i; + } + EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses)); +} + +TEST_F(IRContextTest, AllValidAfterPassNoChange) { + std::unique_ptr module = MakeUnique(); + IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), + spvtools::MessageConsumer()); + + Analysis built_analyses = IRContext::kAnalysisNone; + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + localContext.BuildInvalidAnalyses(i); + built_analyses |= i; + } + + DummyPassPreservesNothing pass(Pass::Status::SuccessWithoutChange); + Pass::Status s = pass.Run(&localContext); + EXPECT_EQ(s, Pass::Status::SuccessWithoutChange); + EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses)); +} + +TEST_F(IRContextTest, NoneValidAfterPassWithChange) { + std::unique_ptr module = MakeUnique(); + IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), + spvtools::MessageConsumer()); + + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + localContext.BuildInvalidAnalyses(i); + } + + DummyPassPreservesNothing pass(Pass::Status::SuccessWithChange); + Pass::Status s = pass.Run(&localContext); + EXPECT_EQ(s, Pass::Status::SuccessWithChange); + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + EXPECT_FALSE(localContext.AreAnalysesValid(i)); + } +} + +TEST_F(IRContextTest, AllPreservedAfterPassWithChange) { + std::unique_ptr module = MakeUnique(); + IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), + spvtools::MessageConsumer()); + + for (Analysis i = IRContext::kAnalysisBegin; i < 
IRContext::kAnalysisEnd; + i <<= 1) { + localContext.BuildInvalidAnalyses(i); + } + + DummyPassPreservesAll pass(Pass::Status::SuccessWithChange); + Pass::Status s = pass.Run(&localContext); + EXPECT_EQ(s, Pass::Status::SuccessWithChange); + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + EXPECT_TRUE(localContext.AreAnalysesValid(i)); + } +} + +TEST_F(IRContextTest, PreserveFirstOnlyAfterPassWithChange) { + std::unique_ptr module = MakeUnique(); + IRContext localContext(SPV_ENV_UNIVERSAL_1_2, std::move(module), + spvtools::MessageConsumer()); + + for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; + i <<= 1) { + localContext.BuildInvalidAnalyses(i); + } + + DummyPassPreservesFirst pass(Pass::Status::SuccessWithChange); + Pass::Status s = pass.Run(&localContext); + EXPECT_EQ(s, Pass::Status::SuccessWithChange); + EXPECT_TRUE(localContext.AreAnalysesValid(IRContext::kAnalysisBegin)); + for (Analysis i = IRContext::kAnalysisBegin << 1; i < IRContext::kAnalysisEnd; + i <<= 1) { + EXPECT_FALSE(localContext.AreAnalysesValid(i)); + } +} + +TEST_F(IRContextTest, KillMemberName) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %3 "stuff" + OpMemberName %3 0 "refZ" + OpMemberDecorate %3 0 Offset 0 + OpDecorate %3 Block + %4 = OpTypeFloat 32 + %3 = OpTypeStruct %4 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %2 = OpFunction %5 None %6 + %7 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + // Build the decoration manager. + context->get_decoration_mgr(); + + // Delete the OpTypeStruct. Should delete the OpName, OpMemberName, and + // OpMemberDecorate associated with it. + context->KillDef(3); + + // Make sure all of the names are removed. 
+ for (auto& inst : context->debugs2()) { + EXPECT_EQ(inst.opcode(), SpvOpNop); + } + + // Make sure all of the decorations are removed. + for (auto& inst : context->annotations()) { + EXPECT_EQ(inst.opcode(), SpvOpNop); + } +} + +TEST_F(IRContextTest, KillGroupDecoration) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpDecorate %3 Restrict + %3 = OpDecorationGroup + OpGroupDecorate %3 %4 %5 + %6 = OpTypeFloat 32 + %7 = OpTypePointer Function %6 + %8 = OpTypeStruct %6 + %9 = OpTypeVoid + %10 = OpTypeFunction %9 + %2 = OpFunction %9 None %10 + %11 = OpLabel + %4 = OpVariable %7 Function + %5 = OpVariable %7 Function + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + // Build the decoration manager. + context->get_decoration_mgr(); + + // Delete the second variable. + context->KillDef(5); + + // The three decorations instructions should still be there. The first two + // should be the same, but the third should have %5 removed. + + // Check the OpDecorate instruction + auto inst = context->annotation_begin(); + EXPECT_EQ(inst->opcode(), SpvOpDecorate); + EXPECT_EQ(inst->GetSingleWordInOperand(0), 3); + + // Check the OpDecorationGroup Instruction + ++inst; + EXPECT_EQ(inst->opcode(), SpvOpDecorationGroup); + EXPECT_EQ(inst->result_id(), 3); + + // Check that %5 is no longer part of the group. + ++inst; + EXPECT_EQ(inst->opcode(), SpvOpGroupDecorate); + EXPECT_EQ(inst->NumInOperands(), 2); + EXPECT_EQ(inst->GetSingleWordInOperand(0), 3); + EXPECT_EQ(inst->GetSingleWordInOperand(1), 4); + + // Check that we are at the end. 
+ ++inst; + EXPECT_EQ(inst, context->annotation_end()); +} + +TEST_F(IRContextTest, TakeNextUniqueIdIncrementing) { + const uint32_t NUM_TESTS = 1000; + IRContext localContext(SPV_ENV_UNIVERSAL_1_2, nullptr); + for (uint32_t i = 1; i < NUM_TESTS; ++i) + EXPECT_EQ(i, localContext.TakeNextUniqueId()); +} + +TEST_F(IRContextTest, KillGroupDecorationWitNoDecorations) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpDecorationGroup + OpGroupDecorate %3 %4 %5 + %6 = OpTypeFloat 32 + %7 = OpTypePointer Function %6 + %8 = OpTypeStruct %6 + %9 = OpTypeVoid + %10 = OpTypeFunction %9 + %2 = OpFunction %9 None %10 + %11 = OpLabel + %4 = OpVariable %7 Function + %5 = OpVariable %7 Function + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + // Build the decoration manager. + context->get_decoration_mgr(); + + // Delete the second variable. + context->KillDef(5); + + // The two decoration instructions should still be there. The first one + // should be the same, but the second should have %5 removed. + + // Check the OpDecorationGroup Instruction + auto inst = context->annotation_begin(); + EXPECT_EQ(inst->opcode(), SpvOpDecorationGroup); + EXPECT_EQ(inst->result_id(), 3); + + // Check that %5 is no longer part of the group. + ++inst; + EXPECT_EQ(inst->opcode(), SpvOpGroupDecorate); + EXPECT_EQ(inst->NumInOperands(), 2); + EXPECT_EQ(inst->GetSingleWordInOperand(0), 3); + EXPECT_EQ(inst->GetSingleWordInOperand(1), 4); + + // Check that we are at the end. 
+ ++inst; + EXPECT_EQ(inst, context->annotation_end()); +} + +TEST_F(IRContextTest, KillDecorationGroup) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpDecorationGroup + OpGroupDecorate %3 %4 %5 + %6 = OpTypeFloat 32 + %7 = OpTypePointer Function %6 + %8 = OpTypeStruct %6 + %9 = OpTypeVoid + %10 = OpTypeFunction %9 + %2 = OpFunction %9 None %10 + %11 = OpLabel + %4 = OpVariable %7 Function + %5 = OpVariable %7 Function + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + + // Build the decoration manager. + context->get_decoration_mgr(); + + // Delete the decoration group. + context->KillDef(3); + + // Check that no annotation instructions remain. + EXPECT_TRUE(context->annotations().empty()); +} + +TEST_F(IRContextTest, BasicVisitFromEntryPoint) { + // Make sure we visit the entry point, and the function it calls. + // Do not visit Dead or Exported. 
+ const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %10 "main" + OpName %10 "main" + OpName %Dead "Dead" + OpName %11 "Constant" + OpName %ExportedFunc "ExportedFunc" + OpDecorate %ExportedFunc LinkageAttributes "ExportedFunc" Export + %void = OpTypeVoid + %6 = OpTypeFunction %void + %10 = OpFunction %void None %6 + %14 = OpLabel + %15 = OpFunctionCall %void %11 + %16 = OpFunctionCall %void %11 + OpReturn + OpFunctionEnd + %11 = OpFunction %void None %6 + %18 = OpLabel + OpReturn + OpFunctionEnd + %Dead = OpFunction %void None %6 + %19 = OpLabel + OpReturn + OpFunctionEnd +%ExportedFunc = OpFunction %void None %7 + %20 = OpLabel + %21 = OpFunctionCall %void %11 + OpReturn + OpFunctionEnd +)"; + // clang-format on + + std::unique_ptr<IRContext> localContext = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; + std::vector<uint32_t> processed; + Pass::ProcessFunction mark_visited = [&processed](Function* fp) { + processed.push_back(fp->result_id()); + return false; + }; + localContext->ProcessEntryPointCallTree(mark_visited); + EXPECT_THAT(processed, UnorderedElementsAre(10, 11)); +} + +TEST_F(IRContextTest, BasicVisitReachable) { + // Make sure we visit the entry point, exported function, and the function + // they call. Do not visit Dead.
+ const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %10 "main" + OpName %10 "main" + OpName %Dead "Dead" + OpName %11 "Constant" + OpName %12 "ExportedFunc" + OpName %13 "Constant2" + OpDecorate %12 LinkageAttributes "ExportedFunc" Export + %void = OpTypeVoid + %6 = OpTypeFunction %void + %10 = OpFunction %void None %6 + %14 = OpLabel + %15 = OpFunctionCall %void %11 + %16 = OpFunctionCall %void %11 + OpReturn + OpFunctionEnd + %11 = OpFunction %void None %6 + %18 = OpLabel + OpReturn + OpFunctionEnd + %Dead = OpFunction %void None %6 + %19 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %void None %6 + %20 = OpLabel + %21 = OpFunctionCall %void %13 + OpReturn + OpFunctionEnd + %13 = OpFunction %void None %6 + %22 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + + std::unique_ptr<IRContext> localContext = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; + + std::vector<uint32_t> processed; + Pass::ProcessFunction mark_visited = [&processed](Function* fp) { + processed.push_back(fp->result_id()); + return false; + }; + localContext->ProcessReachableCallTree(mark_visited); + EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12, 13)); +} + +TEST_F(IRContextTest, BasicVisitOnlyOnce) { + // Make sure we visit %12 only once, even if it is called from two different + // functions.
+ const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %10 "main" + OpName %10 "main" + OpName %Dead "Dead" + OpName %11 "Constant" + OpName %12 "ExportedFunc" + OpDecorate %12 LinkageAttributes "ExportedFunc" Export + %void = OpTypeVoid + %6 = OpTypeFunction %void + %10 = OpFunction %void None %6 + %14 = OpLabel + %15 = OpFunctionCall %void %11 + %16 = OpFunctionCall %void %12 + OpReturn + OpFunctionEnd + %11 = OpFunction %void None %6 + %18 = OpLabel + %19 = OpFunctionCall %void %12 + OpReturn + OpFunctionEnd + %Dead = OpFunction %void None %6 + %20 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %void None %6 + %21 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + + std::unique_ptr<IRContext> localContext = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; + + std::vector<uint32_t> processed; + Pass::ProcessFunction mark_visited = [&processed](Function* fp) { + processed.push_back(fp->result_id()); + return false; + }; + localContext->ProcessReachableCallTree(mark_visited); + EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12)); +} + +TEST_F(IRContextTest, BasicDontVisitExportedVariable) { + // Make sure we only visit functions and not exported variables.
+ const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %10 "main" + OpExecutionMode %10 OriginUpperLeft + OpSource GLSL 150 + OpName %10 "main" + OpName %12 "export_var" + OpDecorate %12 LinkageAttributes "export_var" Export + %void = OpTypeVoid + %6 = OpTypeFunction %void + %float = OpTypeFloat 32 + %float_1 = OpConstant %float 1 + %12 = OpVariable %float Output + %10 = OpFunction %void None %6 + %14 = OpLabel + OpStore %12 %float_1 + OpReturn + OpFunctionEnd +)"; + // clang-format on + + std::unique_ptr<IRContext> localContext = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; + + std::vector<uint32_t> processed; + Pass::ProcessFunction mark_visited = [&processed](Function* fp) { + processed.push_back(fp->result_id()); + return false; + }; + localContext->ProcessReachableCallTree(mark_visited); + EXPECT_THAT(processed, UnorderedElementsAre(10)); +} + +TEST_F(IRContextTest, IdBoundTestAtLimit) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +OpReturn +OpFunctionEnd)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + uint32_t current_bound = context->module()->id_bound(); + context->set_max_id_bound(current_bound); + uint32_t next_id_bound = context->TakeNextId(); + EXPECT_EQ(next_id_bound, 0); + EXPECT_EQ(current_bound, context->module()->id_bound()); + next_id_bound = context->TakeNextId(); + EXPECT_EQ(next_id_bound, 0); +} + +TEST_F(IRContextTest, IdBoundTestBelowLimit) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel
+OpReturn +OpFunctionEnd)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + uint32_t current_bound = context->module()->id_bound(); + context->set_max_id_bound(current_bound + 100); + uint32_t next_id_bound = context->TakeNextId(); + EXPECT_EQ(next_id_bound, current_bound); + EXPECT_EQ(current_bound + 1, context->module()->id_bound()); + next_id_bound = context->TakeNextId(); + EXPECT_EQ(next_id_bound, current_bound + 1); +} + +TEST_F(IRContextTest, IdBoundTestNearLimit) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +OpReturn +OpFunctionEnd)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + uint32_t current_bound = context->module()->id_bound(); + context->set_max_id_bound(current_bound + 1); + uint32_t next_id_bound = context->TakeNextId(); + EXPECT_EQ(next_id_bound, current_bound); + EXPECT_EQ(current_bound + 1, context->module()->id_bound()); + next_id_bound = context->TakeNextId(); + EXPECT_EQ(next_id_bound, 0); +} + +TEST_F(IRContextTest, IdBoundTestUIntMax) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4294967294 = OpLabel ; ID is UINT_MAX-1 +OpReturn +OpFunctionEnd)"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + uint32_t current_bound = context->module()->id_bound(); + + // Expecting |BuildModule| to preserve the numeric ids.
+ EXPECT_EQ(current_bound, std::numeric_limits<uint32_t>::max()); + + context->set_max_id_bound(current_bound); + uint32_t next_id_bound = context->TakeNextId(); + EXPECT_EQ(next_id_bound, 0); + EXPECT_EQ(current_bound, context->module()->id_bound()); +} +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/ir_loader_test.cpp b/test/opt/ir_loader_test.cpp new file mode 100644 index 000000000..ac5c52075 --- /dev/null +++ b/test/opt/ir_loader_test.cpp @@ -0,0 +1,451 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include <algorithm> +#include <memory> +#include <string> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { +namespace { + +void DoRoundTripCheck(const std::string& text) { + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + ASSERT_NE(nullptr, context) << "Failed to assemble\n" << text; + + std::vector<uint32_t> binary; + context->module()->ToBinary(&binary, /* skip_nop = */ false); + + std::string disassembled_text; + EXPECT_TRUE(t.Disassemble(binary, &disassembled_text)); + EXPECT_EQ(text, disassembled_text); +} + +TEST(IrBuilder, RoundTrip) { + // #version 310 es + // int add(int a, int b) { return a + b; } + // void main() { add(1, 2); } + DoRoundTripCheck( + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "OpSource ESSL 310\n" + "OpSourceExtension \"GL_GOOGLE_cpp_style_line_directive\"\n" + "OpSourceExtension \"GL_GOOGLE_include_directive\"\n" + "OpName %main \"main\"\n" + "OpName %add_i1_i1_ \"add(i1;i1;\"\n" + "OpName %a \"a\"\n" + "OpName %b \"b\"\n" + "OpName %param \"param\"\n" + "OpName %param_0 \"param\"\n" + "%void = OpTypeVoid\n" + "%9 = OpTypeFunction %void\n" + "%int = OpTypeInt 32 1\n" + "%_ptr_Function_int = OpTypePointer Function %int\n" + "%12 = OpTypeFunction %int %_ptr_Function_int %_ptr_Function_int\n" + "%int_1 = OpConstant %int 1\n" + "%int_2 = OpConstant %int 2\n" + "%main = OpFunction %void None %9\n" + "%15 = OpLabel\n" + "%param = OpVariable %_ptr_Function_int Function\n" + "%param_0 = OpVariable %_ptr_Function_int Function\n" + "OpStore %param %int_1\n" + "OpStore %param_0 %int_2\n" + "%16 = OpFunctionCall %int %add_i1_i1_ %param %param_0\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%add_i1_i1_ = OpFunction %int None %12\n" + "%a
= OpFunctionParameter %_ptr_Function_int\n" + "%b = OpFunctionParameter %_ptr_Function_int\n" + "%17 = OpLabel\n" + "%18 = OpLoad %int %a\n" + "%19 = OpLoad %int %b\n" + "%20 = OpIAdd %int %18 %19\n" + "OpReturnValue %20\n" + "OpFunctionEnd\n"); + // clang-format on +} + +TEST(IrBuilder, RoundTripIncompleteBasicBlock) { + DoRoundTripCheck( + "%2 = OpFunction %1 None %3\n" + "%4 = OpLabel\n" + "OpNop\n"); +} + +TEST(IrBuilder, RoundTripIncompleteFunction) { + DoRoundTripCheck("%2 = OpFunction %1 None %3\n"); +} + +TEST(IrBuilder, KeepLineDebugInfo) { + // #version 310 es + // void main() {} + DoRoundTripCheck( + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "%3 = OpString \"minimal.vert\"\n" + "OpSource ESSL 310\n" + "OpName %main \"main\"\n" + "OpLine %3 10 10\n" + "%void = OpTypeVoid\n" + "OpLine %3 100 100\n" + "%5 = OpTypeFunction %void\n" + "%main = OpFunction %void None %5\n" + "OpLine %3 1 1\n" + "OpNoLine\n" + "OpLine %3 2 2\n" + "OpLine %3 3 3\n" + "%6 = OpLabel\n" + "OpLine %3 4 4\n" + "OpNoLine\n" + "OpReturn\n" + "OpFunctionEnd\n"); + // clang-format on +} + +TEST(IrBuilder, LocalGlobalVariables) { + // #version 310 es + // + // float gv1 = 10.; + // float gv2 = 100.; + // + // float f() { + // float lv1 = gv1 + gv2; + // float lv2 = gv1 * gv2; + // return lv1 / lv2; + // } + // + // void main() { + // float lv1 = gv1 - gv2; + // } + DoRoundTripCheck( + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "OpSource ESSL 310\n" + "OpName %main \"main\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv1\"\n" + "OpName %gv2 \"gv2\"\n" + "OpName %lv1 \"lv1\"\n" + "OpName %lv2 \"lv2\"\n" + "OpName %lv1_0 \"lv1\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction 
%float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"); + // clang-format on +} + +TEST(IrBuilder, OpUndefOutsideFunction) { + // #version 310 es + // void main() {} + const std::string text = + // clang-format off + "OpMemoryModel Logical GLSL450\n" + "%int = OpTypeInt 32 1\n" + "%uint = OpTypeInt 32 0\n" + "%float = OpTypeFloat 32\n" + "%4 = OpUndef %int\n" + "%int_10 = OpConstant %int 10\n" + "%6 = OpUndef %uint\n" + "%bool = OpTypeBool\n" + "%8 = OpUndef %float\n" + "%double = OpTypeFloat 64\n"; + // clang-format on + + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + ASSERT_NE(nullptr, context); + + const auto opundef_count = std::count_if( + context->module()->types_values_begin(), + context->module()->types_values_end(), + [](const Instruction& inst) { return inst.opcode() ==
SpvOpUndef; }); + EXPECT_EQ(3, opundef_count); + + std::vector<uint32_t> binary; + context->module()->ToBinary(&binary, /* skip_nop = */ false); + + std::string disassembled_text; + EXPECT_TRUE(t.Disassemble(binary, &disassembled_text)); + EXPECT_EQ(text, disassembled_text); +} + +TEST(IrBuilder, OpUndefInBasicBlock) { + DoRoundTripCheck( + // clang-format off + "OpMemoryModel Logical GLSL450\n" + "OpName %main \"main\"\n" + "%void = OpTypeVoid\n" + "%uint = OpTypeInt 32 0\n" + "%double = OpTypeFloat 64\n" + "%5 = OpTypeFunction %void\n" + "%main = OpFunction %void None %5\n" + "%6 = OpLabel\n" + "%7 = OpUndef %uint\n" + "%8 = OpUndef %double\n" + "OpReturn\n" + "OpFunctionEnd\n"); + // clang-format on +} + +TEST(IrBuilder, KeepLineDebugInfoBeforeType) { + DoRoundTripCheck( + // clang-format off + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "%1 = OpString \"minimal.vert\"\n" + "OpLine %1 1 1\n" + "OpNoLine\n" + "%void = OpTypeVoid\n" + "OpLine %1 2 2\n" + "%3 = OpTypeFunction %void\n"); + // clang-format on +} + +TEST(IrBuilder, KeepLineDebugInfoBeforeLabel) { + DoRoundTripCheck( + // clang-format off + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "%1 = OpString \"minimal.vert\"\n" + "%void = OpTypeVoid\n" + "%3 = OpTypeFunction %void\n" + "%4 = OpFunction %void None %3\n" + "%5 = OpLabel\n" + "OpBranch %6\n" + "OpLine %1 1 1\n" + "OpLine %1 2 2\n" + "%6 = OpLabel\n" + "OpBranch %7\n" + "OpLine %1 100 100\n" + "%7 = OpLabel\n" + "OpReturn\n" + "OpFunctionEnd\n"); + // clang-format on +} + +TEST(IrBuilder, KeepLineDebugInfoBeforeFunctionEnd) { + DoRoundTripCheck( + // clang-format off + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "%1 = OpString \"minimal.vert\"\n" + "%void = OpTypeVoid\n" + "%3 = OpTypeFunction %void\n" + "%4 = OpFunction %void None %3\n" + "OpLine %1 1 1\n" + "OpLine %1 2 2\n" + "OpFunctionEnd\n"); + // clang-format on +} + +TEST(IrBuilder, KeepModuleProcessedInRightPlace) { + DoRoundTripCheck( + //
clang-format off + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "%1 = OpString \"minimal.vert\"\n" + "OpName %void \"void\"\n" + "OpModuleProcessed \"Made it faster\"\n" + "OpModuleProcessed \".. and smaller\"\n" + "%void = OpTypeVoid\n"); + // clang-format on +} + +// Checks the given |error_message| is reported when trying to build a module +// from the given |assembly|. +void DoErrorMessageCheck(const std::string& assembly, + const std::string& error_message, uint32_t line_num) { + auto consumer = [error_message, line_num](spv_message_level_t, const char*, + const spv_position_t& position, + const char* m) { + EXPECT_EQ(error_message, m); + EXPECT_EQ(line_num, position.line); + }; + + SpirvTools t(SPV_ENV_UNIVERSAL_1_1); + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, std::move(consumer), assembly); + EXPECT_EQ(nullptr, context); +} + +TEST(IrBuilder, FunctionInsideFunction) { + DoErrorMessageCheck("%2 = OpFunction %1 None %3\n%5 = OpFunction %4 None %6", + "function inside function", 2); +} + +TEST(IrBuilder, MismatchOpFunctionEnd) { + DoErrorMessageCheck("OpFunctionEnd", + "OpFunctionEnd without corresponding OpFunction", 1); +} + +TEST(IrBuilder, OpFunctionEndInsideBasicBlock) { + DoErrorMessageCheck( + "%2 = OpFunction %1 None %3\n" + "%4 = OpLabel\n" + "OpFunctionEnd", + "OpFunctionEnd inside basic block", 3); +} + +TEST(IrBuilder, BasicBlockOutsideFunction) { + DoErrorMessageCheck("OpCapability Shader\n%1 = OpLabel", + "OpLabel outside function", 2); +} + +TEST(IrBuilder, OpLabelInsideBasicBlock) { + DoErrorMessageCheck( + "%2 = OpFunction %1 None %3\n" + "%4 = OpLabel\n" + "%5 = OpLabel", + "OpLabel inside basic block", 3); +} + +TEST(IrBuilder, TerminatorOutsideFunction) { + DoErrorMessageCheck("OpReturn", "terminator instruction outside function", 1); +} + +TEST(IrBuilder, TerminatorOutsideBasicBlock) { + DoErrorMessageCheck("%2 = OpFunction %1 None %3\nOpReturn", + "terminator instruction outside basic block", 2); +} +
+TEST(IrBuilder, NotAllowedInstAppearingInFunction) { + DoErrorMessageCheck("%2 = OpFunction %1 None %3\n%5 = OpVariable %4 Function", + "Non-OpFunctionParameter (opcode: 59) found inside " + "function but outside basic block", + 2); +} + +TEST(IrBuilder, UniqueIds) { + const std::string text = + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "OpSource ESSL 310\n" + "OpName %main \"main\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv1\"\n" + "OpName %gv2 \"gv2\"\n" + "OpName %lv1 \"lv1\"\n" + "OpName %lv2 \"lv2\"\n" + "OpName %lv1_0 \"lv1\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + + 
std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + ASSERT_NE(nullptr, context); + + std::unordered_set ids; + context->module()->ForEachInst([&ids](const Instruction* inst) { + EXPECT_TRUE(ids.insert(inst->unique_id()).second); + }); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/iterator_test.cpp b/test/opt/iterator_test.cpp new file mode 100644 index 000000000..d61bc1ab8 --- /dev/null +++ b/test/opt/iterator_test.cpp @@ -0,0 +1,267 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gmock/gmock.h" + +#include "source/opt/iterator.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::ContainerEq; + +TEST(Iterator, IncrementDeref) { + const int count = 100; + std::vector> data; + for (int i = 0; i < count; ++i) { + data.emplace_back(new int(i)); + } + + UptrVectorIterator it(&data, data.begin()); + UptrVectorIterator end(&data, data.end()); + + EXPECT_EQ(*data[0], *it); + for (int i = 1; i < count; ++i) { + EXPECT_NE(end, it); + EXPECT_EQ(*data[i], *(++it)); + } + EXPECT_EQ(end, ++it); +} + +TEST(Iterator, DecrementDeref) { + const int count = 100; + std::vector> data; + for (int i = 0; i < count; ++i) { + data.emplace_back(new int(i)); + } + + UptrVectorIterator begin(&data, data.begin()); + UptrVectorIterator it(&data, data.end()); + + for (int i = count - 1; i >= 0; --i) { + EXPECT_NE(begin, it); + EXPECT_EQ(*data[i], *(--it)); + } + EXPECT_EQ(begin, it); +} + +TEST(Iterator, PostIncrementDeref) { + const int count = 100; + std::vector> data; + for (int i = 0; i < count; ++i) { + data.emplace_back(new int(i)); + } + + UptrVectorIterator it(&data, data.begin()); + UptrVectorIterator end(&data, data.end()); + + for (int i = 0; i < count; ++i) { + EXPECT_NE(end, it); + EXPECT_EQ(*data[i], *(it++)); + } + EXPECT_EQ(end, it); +} + +TEST(Iterator, PostDecrementDeref) { + const int count = 100; + std::vector> data; + for (int i = 0; i < count; ++i) { + data.emplace_back(new int(i)); + } + + UptrVectorIterator begin(&data, data.begin()); + UptrVectorIterator end(&data, data.end()); + UptrVectorIterator it(&data, data.end()); + + EXPECT_EQ(end, it--); + for (int i = count - 1; i >= 1; --i) { + EXPECT_EQ(*data[i], *(it--)); + } + // Decrementing .begin() is undefined behavior. 
+ EXPECT_EQ(*data[0], *it); +} + +TEST(Iterator, Access) { + const int count = 100; + std::vector<std::unique_ptr<int>> data; + for (int i = 0; i < count; ++i) { + data.emplace_back(new int(i)); + } + + UptrVectorIterator<int> it(&data, data.begin()); + + for (int i = 0; i < count; ++i) EXPECT_EQ(*data[i], it[i]); +} + +TEST(Iterator, Comparison) { + const int count = 100; + std::vector<std::unique_ptr<int>> data; + for (int i = 0; i < count; ++i) { + data.emplace_back(new int(i)); + } + + UptrVectorIterator<int> it(&data, data.begin()); + UptrVectorIterator<int> end(&data, data.end()); + + for (int i = 0; i < count; ++i, ++it) EXPECT_TRUE(it < end); + EXPECT_EQ(end, it); +} + +TEST(Iterator, InsertBeginEnd) { + const int count = 100; + + std::vector<std::unique_ptr<int>> data; + std::vector<int> expected; + std::vector<int> actual; + + for (int i = 0; i < count; ++i) { + data.emplace_back(new int(i)); + expected.push_back(i); + } + + // Insert at the beginning + expected.insert(expected.begin(), -100); + UptrVectorIterator<int> begin(&data, data.begin()); + auto insert_point = begin.InsertBefore(MakeUnique<int>(-100)); + for (int i = 0; i < count + 1; ++i) { + actual.push_back(*(insert_point++)); + } + EXPECT_THAT(actual, ContainerEq(expected)); + + // Insert at the end + expected.push_back(-42); + expected.push_back(-36); + expected.push_back(-77); + UptrVectorIterator<int> end(&data, data.end()); + end = end.InsertBefore(MakeUnique<int>(-77)); + end = end.InsertBefore(MakeUnique<int>(-36)); + end = end.InsertBefore(MakeUnique<int>(-42)); + + actual.clear(); + begin = UptrVectorIterator<int>(&data, data.begin()); + for (int i = 0; i < count + 4; ++i) { + actual.push_back(*(begin++)); + } + EXPECT_THAT(actual, ContainerEq(expected)); +} + +TEST(Iterator, InsertMiddle) { + const int count = 100; + + std::vector<std::unique_ptr<int>> data; + std::vector<int> expected; + std::vector<int> actual; + + for (int i = 0; i < count; ++i) { + data.emplace_back(new int(i)); + expected.push_back(i); + } + + const int insert_pos = 42; + expected.insert(expected.begin() + insert_pos, -100); + expected.insert(expected.begin() +
insert_pos, -42); + + UptrVectorIterator<int> it(&data, data.begin()); + for (int i = 0; i < insert_pos; ++i) ++it; + it = it.InsertBefore(MakeUnique<int>(-100)); + it = it.InsertBefore(MakeUnique<int>(-42)); + auto begin = UptrVectorIterator<int>(&data, data.begin()); + for (int i = 0; i < count + 2; ++i) { + actual.push_back(*(begin++)); + } + EXPECT_THAT(actual, ContainerEq(expected)); +} + +TEST(IteratorRange, Interface) { + const uint32_t count = 100; + + std::vector<std::unique_ptr<uint32_t>> data; + + for (uint32_t i = 0; i < count; ++i) { + data.emplace_back(new uint32_t(i)); + } + + auto b = UptrVectorIterator<uint32_t>(&data, data.begin()); + auto e = UptrVectorIterator<uint32_t>(&data, data.end()); + auto range = IteratorRange<UptrVectorIterator<uint32_t>>(b, e); + + EXPECT_EQ(b, range.begin()); + EXPECT_EQ(e, range.end()); + EXPECT_FALSE(range.empty()); + EXPECT_EQ(count, range.size()); + EXPECT_EQ(0u, *range.begin()); + EXPECT_EQ(99u, *(--range.end())); + + // IteratorRange itself is immutable. + ++b, --e; + EXPECT_EQ(count, range.size()); + ++range.begin(), --range.end(); + EXPECT_EQ(count, range.size()); +} + +TEST(Iterator, FilterIterator) { + struct Placeholder { + int val; + }; + std::vector<Placeholder> data = {{1}, {2}, {3}, {4}, {5}, + {6}, {7}, {8}, {9}, {10}}; + + // Predicate to only consider odd values.
+ struct Predicate { + bool operator()(const Placeholder& data) { return data.val % 2; } + }; + Predicate pred; + + auto filter_range = MakeFilterIteratorRange(data.begin(), data.end(), pred); + + EXPECT_EQ(filter_range.begin().Get(), data.begin()); + EXPECT_EQ(filter_range.end(), filter_range.begin().GetEnd()); + + for (Placeholder& data : filter_range) { + EXPECT_EQ(data.val % 2, 1); + } + + for (auto it = filter_range.begin(); it != filter_range.end(); it++) { + EXPECT_EQ(it->val % 2, 1); + EXPECT_EQ((*it).val % 2, 1); + } + + for (auto it = filter_range.begin(); it != filter_range.end(); ++it) { + EXPECT_EQ(it->val % 2, 1); + EXPECT_EQ((*it).val % 2, 1); + } + + EXPECT_EQ(MakeFilterIterator(data.begin(), data.end(), pred).Get(), + data.begin()); + EXPECT_EQ(MakeFilterIterator(data.end(), data.end(), pred).Get(), data.end()); + EXPECT_EQ(MakeFilterIterator(data.begin(), data.end(), pred).GetEnd(), + MakeFilterIterator(data.end(), data.end(), pred)); + EXPECT_NE(MakeFilterIterator(data.begin(), data.end(), pred), + MakeFilterIterator(data.end(), data.end(), pred)); + + // Empty range: no value satisfies the predicate. + auto empty_range = MakeFilterIteratorRange( + data.begin(), data.end(), + [](const Placeholder& data) { return data.val > 10; }); + EXPECT_EQ(empty_range.begin(), empty_range.end()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/line_debug_info_test.cpp b/test/opt/line_debug_info_test.cpp new file mode 100644 index 000000000..6a20a0136 --- /dev/null +++ b/test/opt/line_debug_info_test.cpp @@ -0,0 +1,113 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +// A pass turning all non-debug-line instructions into Nop. +class NopifyPass : public Pass { + public: + const char* name() const override { return "NopifyPass"; } + Status Process() override { + bool modified = false; + context()->module()->ForEachInst( + [&modified](Instruction* inst) { + inst->ToNop(); + modified = true; + }, + /* run_on_debug_line_insts = */ false); + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; + } +}; + +using PassTestForLineDebugInfo = PassTest<::testing::Test>; + +// This test's purpose is to show our implementation choice: line debug info is +// preserved even if the following instruction is killed. It serves as a guard +// of potential behavior changes.
+TEST_F(PassTestForLineDebugInfo, KeepLineDebugInfo) { + // clang-format off + const char* text = + "OpCapability Shader " + "%1 = OpExtInstImport \"GLSL.std.450\" " + "OpMemoryModel Logical GLSL450 " + "OpEntryPoint Vertex %2 \"main\" " + "%3 = OpString \"minimal.vert\" " + "OpNoLine " + "OpLine %3 10 10 " + "%void = OpTypeVoid " + "OpLine %3 100 100 " + "%5 = OpTypeFunction %void " + "%2 = OpFunction %void None %5 " + "OpLine %3 1 1 " + "OpNoLine " + "OpLine %3 2 2 " + "OpLine %3 3 3 " + "%6 = OpLabel " + "OpLine %3 4 4 " + "OpNoLine " + "OpReturn " + "OpLine %3 4 4 " + "OpNoLine " + "OpFunctionEnd "; + // clang-format on + + const char* result_keep_nop = + "OpNop\n" + "OpNop\n" + "OpNop\n" + "OpNop\n" + "OpNop\n" + "OpNoLine\n" + "OpLine %3 10 10\n" + "OpNop\n" + "OpLine %3 100 100\n" + "OpNop\n" + "OpNop\n" + "OpLine %3 1 1\n" + "OpNoLine\n" + "OpLine %3 2 2\n" + "OpLine %3 3 3\n" + "OpNop\n" + "OpLine %3 4 4\n" + "OpNoLine\n" + "OpNop\n" + "OpLine %3 4 4\n" + "OpNoLine\n" + "OpNop\n"; + SinglePassRunAndCheck<NopifyPass>(text, result_keep_nop, + /* skip_nop = */ false); + const char* result_skip_nop = + "OpNoLine\n" + "OpLine %3 10 10\n" + "OpLine %3 100 100\n" + "OpLine %3 1 1\n" + "OpNoLine\n" + "OpLine %3 2 2\n" + "OpLine %3 3 3\n" + "OpLine %3 4 4\n" + "OpNoLine\n" + "OpLine %3 4 4\n" + "OpNoLine\n"; + SinglePassRunAndCheck<NopifyPass>(text, result_skip_nop, + /* skip_nop = */ true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/local_access_chain_convert_test.cpp b/test/opt/local_access_chain_convert_test.cpp new file mode 100644 index 000000000..154824b1e --- /dev/null +++ b/test/opt/local_access_chain_convert_test.cpp @@ -0,0 +1,714 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <string> + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using LocalAccessChainConvertTest = PassTest<::testing::Test>; + +TEST_F(LocalAccessChainConvertTest, StructOfVecsOfFloatConverted) { + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"( +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S_t %s0 +; CHECK:
[[ex1:%\w+]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ex1]] +; CHECK: [[ld2:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +%21 = OpLoad %v4float %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<LocalAccessChainConvertPass>(predefs_before + before, + true); +} + +TEST_F(LocalAccessChainConvertTest, InBoundsAccessChainsConverted) { + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"( +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S_t %s0 +;
CHECK: [[ex1:%\w+]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ex1]] +; CHECK: [[ld2:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpInBoundsAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpInBoundsAccessChain %_ptr_Function_v4float %s0 %int_1 +%21 = OpLoad %v4float %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<LocalAccessChainConvertPass>(predefs_before + before, + true); +} + +TEST_F(LocalAccessChainConvertTest, TwoUsesofSingleChainConverted) { + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"( +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]]
= OpLoad %S_t %s0 +; CHECK: [[ex1:%\w+]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ex1]] +; CHECK: [[ld2:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpLoad %v4float %19 +OpStore %gl_FragColor %20 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<LocalAccessChainConvertPass>(predefs_before + before, + true); +} + +TEST_F(LocalAccessChainConvertTest, OpaqueConverted) { + // SPIR-V not representable in GLSL; not generatable from HLSL + // at the moment + + const std::string predefs = + R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %outColor %texCoords +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpMemberName %S_t 2 "smp" +OpName %foo_struct_S_t_vf2_vf21_ "foo(struct-S_t-vf2-vf21;" +OpName %s "s" +OpName %outColor "outColor" +OpName %sampler15 "sampler15" +OpName %s0 "s0" +OpName %texCoords "texCoords" +OpName %param "param" +OpDecorate %sampler15 DescriptorSet 0 +%void = OpTypeVoid +%12 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outColor = OpVariable %_ptr_Output_v4float Output +%17 = OpTypeImage %float 2D 0 0 0 1 Unknown +%18 = OpTypeSampledImage %17 +%S_t = OpTypeStruct %v2float %v2float %18 +%_ptr_Function_S_t = OpTypePointer Function %S_t +%20 = OpTypeFunction %void %_ptr_Function_S_t +%_ptr_UniformConstant_18 = OpTypePointer UniformConstant %18 +%_ptr_Function_18 = OpTypePointer Function %18 +%sampler15 = OpVariable %_ptr_UniformConstant_18 UniformConstant
+%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%texCoords = OpVariable %_ptr_Input_v2float Input +)"; + + const std::string before = + R"( +; CHECK: [[l1:%\w+]] = OpLoad %S_t %param +; CHECK: [[e1:%\w+]] = OpCompositeExtract {{%\w+}} [[l1]] 2 +; CHECK: [[l2:%\w+]] = OpLoad %S_t %param +; CHECK: [[e2:%\w+]] = OpCompositeExtract {{%\w+}} [[l2]] 0 +; CHECK: OpImageSampleImplicitLod {{%\w+}} [[e1]] [[e2]] +%main = OpFunction %void None %12 +%28 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%param = OpVariable %_ptr_Function_S_t Function +%29 = OpLoad %v2float %texCoords +%30 = OpAccessChain %_ptr_Function_v2float %s0 %int_0 +OpStore %30 %29 +%31 = OpLoad %18 %sampler15 +%32 = OpAccessChain %_ptr_Function_18 %s0 %int_2 +OpStore %32 %31 +%33 = OpLoad %S_t %s0 +OpStore %param %33 +%34 = OpAccessChain %_ptr_Function_18 %param %int_2 +%35 = OpLoad %18 %34 +%36 = OpAccessChain %_ptr_Function_v2float %param %int_0 +%37 = OpLoad %v2float %36 +%38 = OpImageSampleImplicitLod %v4float %35 %37 +OpStore %outColor %38 +OpReturn +OpFunctionEnd +)"; + + const std::string remain = + R"(%foo_struct_S_t_vf2_vf21_ = OpFunction %void None %20 +%s = OpFunctionParameter %_ptr_Function_S_t +%39 = OpLabel +%40 = OpAccessChain %_ptr_Function_18 %s %int_2 +%41 = OpLoad %18 %40 +%42 = OpAccessChain %_ptr_Function_v2float %s %int_0 +%43 = OpLoad %v2float %42 +%44 = OpImageSampleImplicitLod %v4float %41 %43 +OpStore %outColor %44 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<LocalAccessChainConvertPass>(predefs + before + remain, + true); +} + +TEST_F(LocalAccessChainConvertTest, NestedStructsConverted) { + // #version 140 + // + // in vec4 BaseColor; + // + // struct S1_t { + // vec4 v1; + // }; + // + // struct S2_t { + // vec4 v2; + // S1_t s1; + // }; + // + // void main() + // { + // S2_t s2; + // s2.s1.v1 = BaseColor; + // gl_FragColor = s2.s1.v1; + // } + +
const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S1_t "S1_t" +OpMemberName %S1_t 0 "v1" +OpName %S2_t "S2_t" +OpMemberName %S2_t 0 "v2" +OpMemberName %S2_t 1 "s1" +OpName %s2 "s2" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S1_t = OpTypeStruct %v4float +%S2_t = OpTypeStruct %v4float %S1_t +%_ptr_Function_S2_t = OpTypePointer Function %S2_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%int_0 = OpConstant %int 0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"( +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S2_t %s2 +; CHECK: [[ex1:%\w+]] = OpCompositeInsert %S2_t [[st_id]] [[ld1]] 1 0 +; CHECK: OpStore %s2 [[ex1]] +; CHECK: [[ld2:%\w+]] = OpLoad %S2_t %s2 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 0 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %9 +%19 = OpLabel +%s2 = OpVariable %_ptr_Function_S2_t Function +%20 = OpLoad %v4float %BaseColor +%21 = OpAccessChain %_ptr_Function_v4float %s2 %int_1 %int_0 +OpStore %21 %20 +%22 = OpAccessChain %_ptr_Function_v4float %s2 %int_1 %int_0 +%23 = OpLoad %v4float %22 +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<LocalAccessChainConvertPass>(predefs_before + before, + true); +} + +TEST_F(LocalAccessChainConvertTest, SomeAccessChainsHaveNoUse) { + // Based on HLSL source code: + // struct S { + // float f; + // }; + +
// float main(float input : A) : B { + // S local = { input }; + // return local.f; + // } + + const std::string predefs = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %main "main" %in_var_A %out_var_B +OpName %main "main" +OpName %in_var_A "in.var.A" +OpName %out_var_B "out.var.B" +OpName %S "S" +OpName %local "local" +%int = OpTypeInt 32 1 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Input_float = OpTypePointer Input %float +%_ptr_Output_float = OpTypePointer Output %float +%S = OpTypeStruct %float +%_ptr_Function_S = OpTypePointer Function %S +%int_0 = OpConstant %int 0 +%in_var_A = OpVariable %_ptr_Input_float Input +%out_var_B = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %8 +%15 = OpLabel +%local = OpVariable %_ptr_Function_S Function +%16 = OpLoad %float %in_var_A +%17 = OpCompositeConstruct %S %16 +OpStore %local %17 +)"; + + const std::string before = + R"( +; CHECK: [[ld:%\w+]] = OpLoad %S %local +; CHECK: [[ex:%\w+]] = OpCompositeExtract %float [[ld]] 0 +; CHECK: OpStore %out_var_B [[ex]] +%18 = OpAccessChain %_ptr_Function_float %local %int_0 +%19 = OpAccessChain %_ptr_Function_float %local %int_0 +%20 = OpLoad %float %18 +OpStore %out_var_B %20 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<LocalAccessChainConvertPass>(predefs + before, true); +} + +TEST_F(LocalAccessChainConvertTest, + StructOfVecsOfFloatConvertedWithDecorationOnLoad) { + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %21 RelaxedPrecision +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"( +; CHECK: OpDecorate +; CHECK: OpDecorate [[ld2:%\w+]] RelaxedPrecision +; CHECK-NOT: OpDecorate +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ins:%\w+]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ins]] +; CHECK: [[ld2]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +%21 = OpLoad %v4float %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<LocalAccessChainConvertPass>(predefs_before + before, + true); +} + +TEST_F(LocalAccessChainConvertTest, + StructOfVecsOfFloatConvertedWithDecorationOnStore) { + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %s0 RelaxedPrecision +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"( +; CHECK: OpDecorate +; CHECK: OpDecorate [[ld1:%\w+]] RelaxedPrecision +; CHECK: OpDecorate [[ins:%\w+]] RelaxedPrecision +; CHECK-NOT: OpDecorate +; CHECK: [[st_id:%\w+]] = OpLoad %v4float %BaseColor +; CHECK: [[ld1]] = OpLoad %S_t %s0 +; CHECK: [[ins]] = OpCompositeInsert %S_t [[st_id]] [[ld1]] 1 +; CHECK: OpStore %s0 [[ins]] +; CHECK: [[ld2:%\w+]] = OpLoad %S_t %s0 +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %v4float [[ld2]] 1 +; CHECK: OpStore %gl_FragColor [[ex2]] +%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +%21 = OpLoad %v4float %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<LocalAccessChainConvertPass>(predefs_before + before, + true); +} + +TEST_F(LocalAccessChainConvertTest, DynamicallyIndexedVarNotConverted) { + // #version 140 + // + // in vec4 BaseColor; + // flat in int Idx; + // in float Bi; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + //
void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // s0.v1[Idx] = Bi; + // gl_FragColor = s0.v1; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Idx %Bi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %Idx "Idx" +OpName %Bi "Bi" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %Idx Flat +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_int = OpTypePointer Input %int +%Idx = OpVariable %_ptr_Input_int Input +%_ptr_Input_float = OpTypePointer Input %float +%Bi = OpVariable %_ptr_Input_float Input +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%22 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%23 = OpLoad %v4float %BaseColor +%24 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %24 %23 +%25 = OpLoad %int %Idx +%26 = OpLoad %float %Bi +%27 = OpAccessChain %_ptr_Function_float %s0 %int_1 %25 +OpStore %27 %26 +%28 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +%29 = OpLoad %v4float %28 +OpStore %gl_FragColor %29 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck<LocalAccessChainConvertPass>(assembly, assembly, false, + true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// Assorted vector and matrix types +//
Assorted struct array types +// Assorted scalar types +// Assorted non-target types +// OpInBoundsAccessChain +// Others? + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/local_redundancy_elimination_test.cpp b/test/opt/local_redundancy_elimination_test.cpp new file mode 100644 index 000000000..bc4635e29 --- /dev/null +++ b/test/opt/local_redundancy_elimination_test.cpp @@ -0,0 +1,159 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <string> + +#include "gmock/gmock.h" +#include "source/opt/build_module.h" +#include "source/opt/value_number_table.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; +using LocalRedundancyEliminationTest = PassTest<::testing::Test>; + +// Remove an instruction when it was already computed.
+TEST_F(LocalRedundancyEliminationTest, RemoveRedundantAdd) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 +; CHECK: OpFAdd +; CHECK-NOT: OpFAdd + %11 = OpFAdd %5 %9 %9 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch<LocalRedundancyEliminationPass>(text, false); +} + +// Make sure we keep instructions that are different, but look similar. +TEST_F(LocalRedundancyEliminationTest, KeepDifferentAdd) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 +; CHECK: OpFAdd + OpStore %8 %10 + %11 = OpLoad %5 %8 +; CHECK: %11 = OpLoad + %12 = OpFAdd %5 %11 %11 +; CHECK: OpFAdd [[:%\w+]] %11 %11 + OpReturn + OpFunctionEnd + )"; + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch<LocalRedundancyEliminationPass>(text, false); +} + +// This test checks that the values are being propagated properly, and that +// we are able to identify sequences of instructions that are not needed.
+TEST_F(LocalRedundancyEliminationTest, RemoveMultipleInstructions) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Uniform %5 + %8 = OpVariable %6 Uniform + %2 = OpFunction %3 None %4 + %7 = OpLabel +; CHECK: [[r1:%\w+]] = OpLoad + %9 = OpLoad %5 %8 +; CHECK-NEXT: [[r2:%\w+]] = OpFAdd [[:%\w+]] [[r1]] [[r1]] + %10 = OpFAdd %5 %9 %9 +; CHECK-NEXT: [[r3:%\w+]] = OpFMul [[:%\w+]] [[r2]] [[r1]] + %11 = OpFMul %5 %10 %9 +; CHECK-NOT: OpLoad + %12 = OpLoad %5 %8 +; CHECK-NOT: OpFAdd [[:%\w+]] %12 %12 + %13 = OpFAdd %5 %12 %12 +; CHECK-NOT: OpFMul + %14 = OpFMul %5 %13 %12 +; CHECK-NEXT: [[:%\w+]] = OpFAdd [[:%\w+]] [[r3]] [[r3]] + %15 = OpFAdd %5 %14 %11 + OpReturn + OpFunctionEnd + )"; + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch<LocalRedundancyEliminationPass>(text, false); +} + +// Redundant instructions in different blocks should be kept.
+TEST_F(LocalRedundancyEliminationTest, KeepInstructionsInDifferentBlocks) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %bb1 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 +; CHECK: OpFAdd + OpBranch %bb2 + %bb2 = OpLabel +; CHECK: OpFAdd + %11 = OpFAdd %5 %9 %9 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch<LocalRedundancyEliminationPass>(text, false); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/local_single_block_elim.cpp b/test/opt/local_single_block_elim.cpp new file mode 100644 index 000000000..e2efc5e3c --- /dev/null +++ b/test/opt/local_single_block_elim.cpp @@ -0,0 +1,1074 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using LocalSingleBlockLoadStoreElimTest = PassTest<::testing::Test>; + +TEST_F(LocalSingleBlockLoadStoreElimTest, SimpleStoreLoadElim) { + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // gl_FragColor = v; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +%15 = OpLoad %v4float %v +OpStore %gl_FragColor %15 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpStore %gl_FragColor %14 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + before, predefs_before + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, SimpleLoadLoadElim) { + // #version 140 + // + // in vec4 BaseColor; + // in float fi; + // + // void main() + // { + // vec4 v = BaseColor; + // if (fi < 0) + // v = vec4(0.0); + // gl_FragData[0] = v; + // 
gl_FragData[1] = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragData +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %fi "fi" +OpName %gl_FragData "gl_FragData" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%16 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%_arr_v4float_uint_32 = OpTypeArray %v4float %uint_32 +%_ptr_Output__arr_v4float_uint_32 = OpTypePointer Output %_arr_v4float_uint_32 +%gl_FragData = OpVariable %_ptr_Output__arr_v4float_uint_32 Output +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%int_1 = OpConstant %int 1 +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%25 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%26 = OpLoad %v4float %BaseColor +OpStore %v %26 +%27 = OpLoad %float %fi +%28 = OpFOrdLessThan %bool %27 %float_0 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +OpStore %v %16 +OpBranch %29 +%29 = OpLabel +%31 = OpLoad %v4float %v +%32 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0 +OpStore %32 %31 +%33 = OpLoad %v4float %v +%34 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1 +OpStore %34 %33 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%25 = OpLabel +%v = OpVariable 
%_ptr_Function_v4float Function +%26 = OpLoad %v4float %BaseColor +OpStore %v %26 +%27 = OpLoad %float %fi +%28 = OpFOrdLessThan %bool %27 %float_0 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +OpStore %v %16 +OpBranch %29 +%29 = OpLabel +%31 = OpLoad %v4float %v +%32 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0 +OpStore %32 %31 +%34 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1 +OpStore %34 %31 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, StoreStoreElim) { + // + // Note first store to v is eliminated + // + // #version 450 + // + // layout(location = 0) in vec4 BaseColor; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // vec4 v = BaseColor; + // v = v * 0.5; + // OutColor = v; + // } + + const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %OutColor "OutColor" +OpDecorate %BaseColor Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%float_0_5 = OpConstant %float 0.5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%14 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%15 = OpLoad %v4float %BaseColor +OpStore %v %15 +%16 = OpLoad %v4float %v +%17 = OpVectorTimesScalar %v4float %16 %float_0_5 +OpStore %v %17 +%18 = 
OpLoad %v4float %v +OpStore %OutColor %18 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%14 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%15 = OpLoad %v4float %BaseColor +%17 = OpVectorTimesScalar %v4float %15 %float_0_5 +OpStore %v %17 +OpStore %OutColor %17 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs_before + before, predefs_before + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, + NoStoreElimIfInterveningAccessChainLoad) { + // + // Note the first Store to %v is not eliminated due to the following access + // chain reference. + // + // #version 450 + // + // layout(location = 0) in vec4 BaseColor0; + // layout(location = 1) in vec4 BaseColor1; + // layout(location = 2) flat in int Idx; + // layout(location = 0) out vec4 OutColor; + // + // void main() + // { + // vec4 v = BaseColor0; + // float f = v[Idx]; + // v = BaseColor1 + vec4(0.1); + // OutColor = v/f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor0 %Idx %BaseColor1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor0 "BaseColor0" +OpName %f "f" +OpName %Idx "Idx" +OpName %BaseColor1 "BaseColor1" +OpName %OutColor "OutColor" +OpDecorate %BaseColor0 Location 0 +OpDecorate %Idx Flat +OpDecorate %Idx Location 2 +OpDecorate %BaseColor1 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%Idx = OpVariable %_ptr_Input_int 
Input +%BaseColor1 = OpVariable %_ptr_Input_v4float Input +%float_0_100000001 = OpConstant %float 0.100000001 +%19 = OpConstantComposite %v4float %float_0_100000001 %float_0_100000001 %float_0_100000001 %float_0_100000001 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f = OpVariable %_ptr_Function_float Function +%22 = OpLoad %v4float %BaseColor0 +OpStore %v %22 +%23 = OpLoad %int %Idx +%24 = OpAccessChain %_ptr_Function_float %v %23 +%25 = OpLoad %float %24 +OpStore %f %25 +%26 = OpLoad %v4float %BaseColor1 +%27 = OpFAdd %v4float %26 %19 +OpStore %v %27 +%28 = OpLoad %v4float %v +%29 = OpLoad %float %f +%30 = OpCompositeConstruct %v4float %29 %29 %29 %29 +%31 = OpFDiv %v4float %28 %30 +OpStore %OutColor %31 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f = OpVariable %_ptr_Function_float Function +%22 = OpLoad %v4float %BaseColor0 +OpStore %v %22 +%23 = OpLoad %int %Idx +%24 = OpAccessChain %_ptr_Function_float %v %23 +%25 = OpLoad %float %24 +OpStore %f %25 +%26 = OpLoad %v4float %BaseColor1 +%27 = OpFAdd %v4float %26 %19 +OpStore %v %27 +%30 = OpCompositeConstruct %v4float %25 %25 %25 %25 +%31 = OpFDiv %v4float %27 %30 +OpStore %OutColor %31 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfInterveningAccessChainStore) { + // #version 140 + // + // in vec4 BaseColor; + // flat in int Idx; + // + // void main() + // { + // vec4 v = BaseColor; + // v[Idx] = 0; + // gl_FragColor = v; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment 
%main "main" %BaseColor %Idx %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %Idx "Idx" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %Idx Flat +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%Idx = OpVariable %_ptr_Input_int Input +%float_0 = OpConstant %float 0 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %8 +%18 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%19 = OpLoad %v4float %BaseColor +OpStore %v %19 +%20 = OpLoad %int %Idx +%21 = OpAccessChain %_ptr_Function_float %v %20 +OpStore %21 %float_0 +%22 = OpLoad %v4float %v +OpStore %gl_FragColor %22 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, + false, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfInterveningFunctionCall) { + // #version 140 + // + // in vec4 BaseColor; + // + // void foo() { + // } + // + // void main() + // { + // vec4 v = BaseColor; + // foo(); + // gl_FragColor = v; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_ "foo(" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float 
+%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %8 +%14 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%15 = OpLoad %v4float %BaseColor +OpStore %v %15 +%16 = OpFunctionCall %void %foo_ +%17 = OpLoad %v4float %v +OpStore %gl_FragColor %17 +OpReturn +OpFunctionEnd +%foo_ = OpFunction %void None %8 +%18 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, + false, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, ElimIfCopyObjectInFunction) { + // Note: SPIR-V hand edited to insert CopyObject + // + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v1 = BaseColor; + // gl_FragData[0] = v1; + // vec4 v2 = BaseColor * 0.5; + // gl_FragData[1] = v2; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragData +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v1 "v1" +OpName %BaseColor "BaseColor" +OpName %gl_FragData "gl_FragData" +OpName %v2 "v2" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%_arr_v4float_uint_32 = OpTypeArray %v4float %uint_32 +%_ptr_Output__arr_v4float_uint_32 = OpTypePointer Output %_arr_v4float_uint_32 +%gl_FragData = OpVariable %_ptr_Output__arr_v4float_uint_32 Output +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%float_0_5 = OpConstant %float 0.5 +%int_1 = OpConstant %int 1 
+)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%22 = OpLabel +%v1 = OpVariable %_ptr_Function_v4float Function +%v2 = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %v1 %23 +%24 = OpLoad %v4float %v1 +%25 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0 +OpStore %25 %24 +%26 = OpLoad %v4float %BaseColor +%27 = OpVectorTimesScalar %v4float %26 %float_0_5 +%28 = OpCopyObject %_ptr_Function_v4float %v2 +OpStore %28 %27 +%29 = OpLoad %v4float %28 +%30 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1 +OpStore %30 %29 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%22 = OpLabel +%v1 = OpVariable %_ptr_Function_v4float Function +%v2 = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %v1 %23 +%25 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0 +OpStore %25 %23 +%26 = OpLoad %v4float %BaseColor +%27 = OpVectorTimesScalar %v4float %26 %float_0_5 +%28 = OpCopyObject %_ptr_Function_v4float %v2 +OpStore %28 %27 +%30 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1 +OpStore %30 %27 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, ElimOpaque) { + // SPIR-V not representable in GLSL; not generatable from HLSL + // at the moment + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %outColor %texCoords +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpMemberName %S_t 2 "smp" +OpName %outColor "outColor" +OpName %sampler15 "sampler15" +OpName %s0 "s0" +OpName %texCoords "texCoords" +OpName %param "param" +OpDecorate %sampler15 DescriptorSet 0 +%void = OpTypeVoid +%12 = OpTypeFunction 
%void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%outColor = OpVariable %_ptr_Output_v4float Output +%17 = OpTypeImage %float 2D 0 0 0 1 Unknown +%18 = OpTypeSampledImage %17 +%S_t = OpTypeStruct %v2float %v2float %18 +%_ptr_Function_S_t = OpTypePointer Function %S_t +%20 = OpTypeFunction %void %_ptr_Function_S_t +%_ptr_UniformConstant_18 = OpTypePointer UniformConstant %18 +%_ptr_Function_18 = OpTypePointer Function %18 +%sampler15 = OpVariable %_ptr_UniformConstant_18 UniformConstant +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%texCoords = OpVariable %_ptr_Input_v2float Input +)"; + + const std::string before = + R"(%main = OpFunction %void None %12 +%28 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%param = OpVariable %_ptr_Function_S_t Function +%29 = OpLoad %v2float %texCoords +%30 = OpLoad %S_t %s0 +%31 = OpCompositeInsert %S_t %29 %30 0 +OpStore %s0 %31 +%32 = OpLoad %18 %sampler15 +%33 = OpLoad %S_t %s0 +%34 = OpCompositeInsert %S_t %32 %33 2 +OpStore %s0 %34 +%35 = OpLoad %S_t %s0 +OpStore %param %35 +%36 = OpLoad %S_t %param +%37 = OpCompositeExtract %18 %36 2 +%38 = OpLoad %S_t %param +%39 = OpCompositeExtract %v2float %38 0 +%40 = OpImageSampleImplicitLod %v4float %37 %39 +OpStore %outColor %40 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %12 +%28 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%param = OpVariable %_ptr_Function_S_t Function +%29 = OpLoad %v2float %texCoords +%30 = OpLoad %S_t %s0 +%31 = OpCompositeInsert %S_t %29 %30 0 +%32 = OpLoad %18 %sampler15 +%34 = OpCompositeInsert %S_t %32 %31 2 +OpStore %s0 %34 +OpStore %param %34 +%37 = OpCompositeExtract %18 %34 2 +%39 = OpCompositeExtract %v2float %34 0 +%40 = 
OpImageSampleImplicitLod %v4float %37 %39 +OpStore %outColor %40 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, PositiveAndNegativeCallTree) { + // Note that the call tree function bar is optimized, but foo is not + // + // #version 140 + // + // in vec4 BaseColor; + // + // vec4 foo(vec4 v1) + // { + // vec4 t = v1; + // return t; + // } + // + // vec4 bar(vec4 v1) + // { + // vec4 t = v1; + // return t; + // } + // + // void main() + // { + // gl_FragColor = bar(BaseColor); + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_ "foo(vf4;" +OpName %v1 "v1" +OpName %bar_vf4_ "bar(vf4;" +OpName %v1_0 "v1" +OpName %t "t" +OpName %t_0 "t" +OpName %gl_FragColor "gl_FragColor" +OpName %BaseColor "BaseColor" +OpName %param "param" +%void = OpTypeVoid +%13 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%17 = OpTypeFunction %v4float %_ptr_Function_v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%main = OpFunction %void None %13 +%20 = OpLabel +%param = OpVariable %_ptr_Function_v4float Function +%21 = OpLoad %v4float %BaseColor +OpStore %param %21 +%22 = OpFunctionCall %v4float %bar_vf4_ %param +OpStore %gl_FragColor %22 +OpReturn +OpFunctionEnd +)"; + + const std::string before = + R"(%foo_vf4_ = OpFunction %v4float None %17 +%v1 = OpFunctionParameter %_ptr_Function_v4float +%23 = OpLabel +%t = 
OpVariable %_ptr_Function_v4float Function +%24 = OpLoad %v4float %v1 +OpStore %t %24 +%25 = OpLoad %v4float %t +OpReturnValue %25 +OpFunctionEnd +%bar_vf4_ = OpFunction %v4float None %17 +%v1_0 = OpFunctionParameter %_ptr_Function_v4float +%26 = OpLabel +%t_0 = OpVariable %_ptr_Function_v4float Function +%27 = OpLoad %v4float %v1_0 +OpStore %t_0 %27 +%28 = OpLoad %v4float %t_0 +OpReturnValue %28 +OpFunctionEnd +)"; + + const std::string after = + R"(%foo_vf4_ = OpFunction %v4float None %17 +%v1 = OpFunctionParameter %_ptr_Function_v4float +%23 = OpLabel +%t = OpVariable %_ptr_Function_v4float Function +%24 = OpLoad %v4float %v1 +OpStore %t %24 +%25 = OpLoad %v4float %t +OpReturnValue %25 +OpFunctionEnd +%bar_vf4_ = OpFunction %v4float None %17 +%v1_0 = OpFunctionParameter %_ptr_Function_v4float +%26 = OpLabel +%t_0 = OpVariable %_ptr_Function_v4float Function +%27 = OpLoad %v4float %v1_0 +OpStore %t_0 %27 +OpReturnValue %27 +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, PointerVariable) { + // Test that checks if a pointer variable is removed. 
+ + const std::string before = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpMemberDecorate %_struct_3 0 Offset 0 +OpDecorate %_runtimearr__struct_3 ArrayStride 16 +OpMemberDecorate %_struct_5 0 Offset 0 +OpDecorate %_struct_5 BufferBlock +OpMemberDecorate %_struct_6 0 Offset 0 +OpDecorate %_struct_6 BufferBlock +OpDecorate %2 Location 0 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_struct_3 = OpTypeStruct %v4float +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 +%_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 +%_struct_6 = OpTypeStruct %int +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 +%_ptr_Function__ptr_Uniform__struct_5 = OpTypePointer Function %_ptr_Uniform__struct_5 +%_ptr_Function__ptr_Uniform__struct_6 = OpTypePointer Function %_ptr_Uniform__struct_6 +%int_0 = OpConstant %int 0 +%uint_0 = OpConstant %uint 0 +%2 = OpVariable %_ptr_Output_v4float Output +%7 = OpVariable %_ptr_Uniform__struct_5 Uniform +%1 = OpFunction %void None %10 +%23 = OpLabel +%24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function +OpStore %24 %7 +%26 = OpLoad %_ptr_Uniform__struct_5 %24 +%27 = OpAccessChain %_ptr_Uniform_v4float %26 %int_0 %uint_0 %int_0 +%28 = OpLoad %v4float %27 +%29 = OpCopyObject %v4float %28 +OpStore %2 %28 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpMemberDecorate %_struct_3 0 Offset 0 +OpDecorate %_runtimearr__struct_3 ArrayStride 16 +OpMemberDecorate %_struct_5 0 Offset 0 
+OpDecorate %_struct_5 BufferBlock +OpMemberDecorate %_struct_6 0 Offset 0 +OpDecorate %_struct_6 BufferBlock +OpDecorate %2 Location 0 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_struct_3 = OpTypeStruct %v4float +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 +%_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 +%_struct_6 = OpTypeStruct %int +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 +%_ptr_Function__ptr_Uniform__struct_5 = OpTypePointer Function %_ptr_Uniform__struct_5 +%_ptr_Function__ptr_Uniform__struct_6 = OpTypePointer Function %_ptr_Uniform__struct_6 +%int_0 = OpConstant %int 0 +%uint_0 = OpConstant %uint 0 +%2 = OpVariable %_ptr_Output_v4float Output +%7 = OpVariable %_ptr_Uniform__struct_5 Uniform +%1 = OpFunction %void None %10 +%23 = OpLabel +%24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function +OpStore %24 %7 +%27 = OpAccessChain %_ptr_Uniform_v4float %7 %int_0 %uint_0 %int_0 +%28 = OpLoad %v4float %27 +%29 = OpCopyObject %v4float %28 +OpStore %2 %28 +OpReturn +OpFunctionEnd +)"; + + // Relax logical pointers to allow pointer allocations. + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ValidatorOptions()->relax_logical_pointer = true; + SinglePassRunAndCheck(before, after, true, + true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, RedundantStore) { + // Test that checks if a pointer variable is removed. 
+ const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %16 +%16 = OpLabel +%15 = OpLoad %v4float %v +OpStore %v %15 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %16 +%16 = OpLabel +%15 = OpLoad %v4float %v +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + predefs_before + before, predefs_before + after, true, true); +} + +TEST_F(LocalSingleBlockLoadStoreElimTest, RedundantStore2) { + // Test that checks if a pointer variable is removed. 
+ const std::string predefs_before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %16 +%16 = OpLabel +%15 = OpLoad %v4float %v +OpStore %v %15 +%17 = OpLoad %v4float %v +OpStore %v %17 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %7 +%13 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%14 = OpLoad %v4float %BaseColor +OpStore %v %14 +OpBranch %16 +%16 = OpLabel +%15 = OpLoad %v4float %v +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + predefs_before + before, predefs_before + after, true, true); +} + +// Test that an unused OpAccessChain between two stores does not hinder the +// removal of the first store. We need to check this because +// local-access-chain-convert does not always remove the OpAccessChain +// instructions that become dead. 
+ +TEST_F(LocalSingleBlockLoadStoreElimTest, + StoreElimIfInterveningUnusedAccessChain) { + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor0 %Idx %BaseColor1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %BaseColor0 "BaseColor0" +OpName %Idx "Idx" +OpName %BaseColor1 "BaseColor1" +OpName %OutColor "OutColor" +OpDecorate %BaseColor0 Location 0 +OpDecorate %Idx Flat +OpDecorate %Idx Location 2 +OpDecorate %BaseColor1 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%Idx = OpVariable %_ptr_Input_int Input +%BaseColor1 = OpVariable %_ptr_Input_v4float Input +%float_0_100000001 = OpConstant %float 0.100000001 +%19 = OpConstantComposite %v4float %float_0_100000001 %float_0_100000001 %float_0_100000001 %float_0_100000001 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %10 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %BaseColor0 +OpStore %v %22 +%23 = OpLoad %int %Idx +%24 = OpAccessChain %_ptr_Function_float %v %23 +%26 = OpLoad %v4float %BaseColor1 +%27 = OpFAdd %v4float %26 %19 +OpStore %v %27 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%22 = OpLoad %v4float %BaseColor0 +%23 = OpLoad %int %Idx +%24 = OpAccessChain 
%_ptr_Function_float %v %23 +%26 = OpLoad %v4float %BaseColor1 +%27 = OpFAdd %v4float %26 %19 +OpStore %v %27 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// Other target variable types +// InBounds Access Chains +// Check for correctness in the presence of function calls +// Others? + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/local_single_store_elim_test.cpp b/test/opt/local_single_store_elim_test.cpp new file mode 100644 index 000000000..5fd0217e3 --- /dev/null +++ b/test/opt/local_single_store_elim_test.cpp @@ -0,0 +1,857 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using LocalSingleStoreElimTest = PassTest<::testing::Test>; + +TEST_F(LocalSingleStoreElimTest, PositiveAndNegative) { + // Single store to v is optimized. Multiple store to + // f is not optimized. 
+ // + // #version 140 + // + // in vec4 BaseColor; + // in float fi; + // + // void main() + // { + // vec4 v = BaseColor; + // float f = fi; + // if (f < 0) + // f = 0.0; + // gl_FragColor = v + f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %f "f" +OpName %fi "fi" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f = OpVariable %_ptr_Function_float Function +%20 = OpLoad %v4float %BaseColor +OpStore %v %20 +%21 = OpLoad %float %fi +OpStore %f %21 +%22 = OpLoad %float %f +%23 = OpFOrdLessThan %bool %22 %float_0 +OpSelectionMerge %24 None +OpBranchConditional %23 %25 %24 +%25 = OpLabel +OpStore %f %float_0 +OpBranch %24 +%24 = OpLabel +%26 = OpLoad %v4float %v +%27 = OpLoad %float %f +%28 = OpCompositeConstruct %v4float %27 %27 %27 %27 +%29 = OpFAdd %v4float %26 %28 +OpStore %gl_FragColor %29 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f = OpVariable %_ptr_Function_float Function 
+%20 = OpLoad %v4float %BaseColor +OpStore %v %20 +%21 = OpLoad %float %fi +OpStore %f %21 +%22 = OpLoad %float %f +%23 = OpFOrdLessThan %bool %22 %float_0 +OpSelectionMerge %24 None +OpBranchConditional %23 %25 %24 +%25 = OpLabel +OpStore %f %float_0 +OpBranch %24 +%24 = OpLabel +%27 = OpLoad %float %f +%28 = OpCompositeConstruct %v4float %27 %27 %27 %27 +%29 = OpFAdd %v4float %20 %28 +OpStore %gl_FragColor %29 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSingleStoreElimTest, ThreeStores) { + // Three stores to multiple loads of v is not optimized. + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %fi "fi" +OpName %r "r" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%r = OpVariable %_ptr_Function_v4float Function +%20 = OpLoad %v4float %BaseColor +OpStore %v %20 +%21 = OpLoad %float %fi +%22 = OpFOrdLessThan %bool %21 %float_0 +OpSelectionMerge %23 None +OpBranchConditional %22 %24 %25 +%24 = OpLabel +%26 = OpLoad %v4float %v +OpStore %v %26 +OpStore 
%r %26 +OpBranch %23 +%25 = OpLabel +%27 = OpLoad %v4float %v +%28 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1 +OpStore %v %28 +%29 = OpFSub %v4float %28 %27 +OpStore %r %29 +OpBranch %23 +%23 = OpLabel +%30 = OpLoad %v4float %r +OpStore %gl_FragColor %30 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + before, true, true); +} + +TEST_F(LocalSingleStoreElimTest, MultipleLoads) { + // Single store to multiple loads of v is optimized. + // + // #version 140 + // + // in vec4 BaseColor; + // in float fi; + // + // void main() + // { + // vec4 v = BaseColor; + // float f = fi; + // if (f < 0) + // f = 0.0; + // gl_FragColor = v + f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %fi "fi" +OpName %r "r" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%r = OpVariable %_ptr_Function_v4float Function +%20 = OpLoad %v4float %BaseColor +OpStore %v %20 +%21 = OpLoad %float %fi +%22 = OpFOrdLessThan %bool %21 %float_0 +OpSelectionMerge %23 None +OpBranchConditional %22 %24 
%25 +%24 = OpLabel +%26 = OpLoad %v4float %v +OpStore %r %26 +OpBranch %23 +%25 = OpLabel +%27 = OpLoad %v4float %v +%28 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1 +%29 = OpFSub %v4float %28 %27 +OpStore %r %29 +OpBranch %23 +%23 = OpLabel +%30 = OpLoad %v4float %r +OpStore %gl_FragColor %30 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%r = OpVariable %_ptr_Function_v4float Function +%20 = OpLoad %v4float %BaseColor +OpStore %v %20 +%21 = OpLoad %float %fi +%22 = OpFOrdLessThan %bool %21 %float_0 +OpSelectionMerge %23 None +OpBranchConditional %22 %24 %25 +%24 = OpLabel +OpStore %r %20 +OpBranch %23 +%25 = OpLabel +%28 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1 +%29 = OpFSub %v4float %28 %20 +OpStore %r %29 +OpBranch %23 +%23 = OpLabel +%30 = OpLoad %v4float %r +OpStore %gl_FragColor %30 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSingleStoreElimTest, NoStoreElimWithInterveningAccessChainLoad) { + // Last load of v is eliminated, but access chain load and store of v isn't + // + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // float f = v[3]; + // gl_FragColor = v * f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %f "f" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable 
%_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%uint = OpTypeInt 32 0 +%uint_3 = OpConstant %uint 3 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f = OpVariable %_ptr_Function_float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +%19 = OpAccessChain %_ptr_Function_float %v %uint_3 +%20 = OpLoad %float %19 +OpStore %f %20 +%21 = OpLoad %v4float %v +%22 = OpLoad %float %f +%23 = OpVectorTimesScalar %v4float %21 %22 +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f = OpVariable %_ptr_Function_float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +%19 = OpAccessChain %_ptr_Function_float %v %uint_3 +%20 = OpLoad %float %19 +OpStore %f %20 +%23 = OpVectorTimesScalar %v4float %18 %20 +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSingleStoreElimTest, NoReplaceOfDominatingPartialStore) { + // Note: SPIR-V hand edited to initialize v to vec4(0.0) + // + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v; + // float v[1] = 1.0; + // gl_FragColor = v; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %gl_FragColor "gl_FragColor" +OpName %BaseColor "BaseColor" +%void = OpTypeVoid +%7 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = 
OpTypePointer Function %v4float +%float_0 = OpConstant %float 0 +%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%float_1 = OpConstant %float 1 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%main = OpFunction %void None %7 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function %12 +%20 = OpAccessChain %_ptr_Function_float %v %uint_1 +OpStore %20 %float_1 +%21 = OpLoad %v4float %v +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, + true); +} + +TEST_F(LocalSingleStoreElimTest, ElimIfCopyObjectInFunction) { + // Note: hand edited to insert OpCopyObject + // + // #version 140 + // + // in vec4 BaseColor; + // in float fi; + // + // void main() + // { + // vec4 v = BaseColor; + // float f = fi; + // if (f < 0) + // f = 0.0; + // gl_FragColor = v + f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %f "f" +OpName %fi "fi" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool 
+%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f = OpVariable %_ptr_Function_float Function +%20 = OpLoad %v4float %BaseColor +OpStore %v %20 +%21 = OpLoad %float %fi +OpStore %f %21 +%22 = OpLoad %float %f +%23 = OpFOrdLessThan %bool %22 %float_0 +OpSelectionMerge %24 None +OpBranchConditional %23 %25 %24 +%25 = OpLabel +OpStore %f %float_0 +OpBranch %24 +%24 = OpLabel +%26 = OpCopyObject %_ptr_Function_v4float %v +%27 = OpLoad %v4float %26 +%28 = OpLoad %float %f +%29 = OpCompositeConstruct %v4float %28 %28 %28 %28 +%30 = OpFAdd %v4float %27 %29 +OpStore %gl_FragColor %30 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%19 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%f = OpVariable %_ptr_Function_float Function +%20 = OpLoad %v4float %BaseColor +OpStore %v %20 +%21 = OpLoad %float %fi +OpStore %f %21 +%22 = OpLoad %float %f +%23 = OpFOrdLessThan %bool %22 %float_0 +OpSelectionMerge %24 None +OpBranchConditional %23 %25 %24 +%25 = OpLabel +OpStore %f %float_0 +OpBranch %24 +%24 = OpLabel +%26 = OpCopyObject %_ptr_Function_v4float %v +%28 = OpLoad %float %f +%29 = OpCompositeConstruct %v4float %28 %28 %28 %28 +%30 = OpFAdd %v4float %20 %29 +OpStore %gl_FragColor %30 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSingleStoreElimTest, NoOptIfStoreNotDominating) { + // Single store to f not optimized because it does not dominate + // the load. 
+ // + // #version 140 + // + // in vec4 BaseColor; + // in float fi; + // + // void main() + // { + // float f; + // if (fi < 0) + // f = 0.5; + // if (fi < 0) + // gl_FragColor = BaseColor * f; + // else + // gl_FragColor = BaseColor; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %fi %gl_FragColor %BaseColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %fi "fi" +OpName %f "f" +OpName %gl_FragColor "gl_FragColor" +OpName %BaseColor "BaseColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Function_float = OpTypePointer Function %float +%float_0_5 = OpConstant %float 0.5 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%main = OpFunction %void None %8 +%18 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%19 = OpLoad %float %fi +%20 = OpFOrdLessThan %bool %19 %float_0 +OpSelectionMerge %21 None +OpBranchConditional %20 %22 %21 +%22 = OpLabel +OpStore %f %float_0_5 +OpBranch %21 +%21 = OpLabel +%23 = OpLoad %float %fi +%24 = OpFOrdLessThan %bool %23 %float_0 +OpSelectionMerge %25 None +OpBranchConditional %24 %26 %27 +%26 = OpLabel +%28 = OpLoad %v4float %BaseColor +%29 = OpLoad %float %f +%30 = OpVectorTimesScalar %v4float %28 %29 +OpStore %gl_FragColor %30 +OpBranch %25 +%27 = OpLabel +%31 = OpLoad %v4float %BaseColor +OpStore %gl_FragColor %31 +OpBranch %25 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, true, + true); +} + +TEST_F(LocalSingleStoreElimTest, 
OptInitializedVariableLikeStore) { + // Initialized variable f is optimized like it was a store. + // Note: The SPIR-V was edited to turn the store to f to an + // an initialization. + // + // #version 140 + // + // void main() + // { + // float f = 0.0; + // gl_FragColor = vec4(f); + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %f "f" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %gl_FragColor Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %6 +%12 = OpLabel +%f = OpVariable %_ptr_Function_float Function %float_0 +%13 = OpLoad %float %f +%14 = OpCompositeConstruct %v4float %13 %13 %13 %13 +OpStore %gl_FragColor %14 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %6 +%12 = OpLabel +%f = OpVariable %_ptr_Function_float Function %float_0 +%14 = OpCompositeConstruct %v4float %float_0 %float_0 %float_0 %float_0 +OpStore %gl_FragColor %14 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSingleStoreElimTest, PointerVariable) { + // Test that checks if a pointer variable is removed. 
+ + const std::string before = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpMemberDecorate %_struct_3 0 Offset 0 +OpDecorate %_runtimearr__struct_3 ArrayStride 16 +OpMemberDecorate %_struct_5 0 Offset 0 +OpDecorate %_struct_5 BufferBlock +OpMemberDecorate %_struct_6 0 Offset 0 +OpDecorate %_struct_6 BufferBlock +OpDecorate %2 Location 0 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_struct_3 = OpTypeStruct %v4float +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 +%_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 +%_struct_6 = OpTypeStruct %int +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 +%_ptr_Function__ptr_Uniform__struct_5 = OpTypePointer Function %_ptr_Uniform__struct_5 +%_ptr_Function__ptr_Uniform__struct_6 = OpTypePointer Function %_ptr_Uniform__struct_6 +%int_0 = OpConstant %int 0 +%uint_0 = OpConstant %uint 0 +%2 = OpVariable %_ptr_Output_v4float Output +%7 = OpVariable %_ptr_Uniform__struct_5 Uniform +%1 = OpFunction %void None %10 +%23 = OpLabel +%24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function +OpStore %24 %7 +%26 = OpLoad %_ptr_Uniform__struct_5 %24 +%27 = OpAccessChain %_ptr_Uniform_v4float %26 %int_0 %uint_0 %int_0 +%28 = OpLoad %v4float %27 +%29 = OpCopyObject %v4float %28 +OpStore %2 %28 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpMemberDecorate %_struct_3 0 Offset 0 +OpDecorate %_runtimearr__struct_3 ArrayStride 16 +OpMemberDecorate %_struct_5 0 Offset 0 
+OpDecorate %_struct_5 BufferBlock +OpMemberDecorate %_struct_6 0 Offset 0 +OpDecorate %_struct_6 BufferBlock +OpDecorate %2 Location 0 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_struct_3 = OpTypeStruct %v4float +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 +%_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 +%_struct_6 = OpTypeStruct %int +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 +%_ptr_Function__ptr_Uniform__struct_5 = OpTypePointer Function %_ptr_Uniform__struct_5 +%_ptr_Function__ptr_Uniform__struct_6 = OpTypePointer Function %_ptr_Uniform__struct_6 +%int_0 = OpConstant %int 0 +%uint_0 = OpConstant %uint 0 +%2 = OpVariable %_ptr_Output_v4float Output +%7 = OpVariable %_ptr_Uniform__struct_5 Uniform +%1 = OpFunction %void None %10 +%23 = OpLabel +%24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function +OpStore %24 %7 +%27 = OpAccessChain %_ptr_Uniform_v4float %7 %int_0 %uint_0 %int_0 +%28 = OpLoad %v4float %27 +%29 = OpCopyObject %v4float %28 +OpStore %2 %28 +OpReturn +OpFunctionEnd +)"; + + // Relax logical pointers to allow pointer allocations. + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ValidatorOptions()->relax_logical_pointer = true; + SinglePassRunAndCheck(before, after, true, true); +} + +// Test that that an unused OpAccessChain between a store and a use does does +// not hinders the replacement of the use. We need to check this because +// local-access-chain-convert does always remove the OpAccessChain instructions +// that become dead. 
+ +TEST_F(LocalSingleStoreElimTest, + StoreElimWithUnusedInterveningAccessChainLoad) { + // Last load of v is eliminated, but access chain load and store of v isn't + // + // #version 140 + // + // in vec4 BaseColor; + // + // void main() + // { + // vec4 v = BaseColor; + // float f = v[3]; + // gl_FragColor = v * f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_float = OpTypePointer Function %float +%uint = OpTypeInt 32 0 +%uint_3 = OpConstant %uint 3 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +%19 = OpAccessChain %_ptr_Function_float %v %uint_3 +%21 = OpLoad %v4float %v +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%18 = OpLoad %v4float %BaseColor +OpStore %v %18 +%19 = OpAccessChain %_ptr_Function_float %v %uint_3 +OpStore %gl_FragColor %18 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// 
Other types +// Others? + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/local_ssa_elim_test.cpp b/test/opt/local_ssa_elim_test.cpp new file mode 100644 index 000000000..b2961a23c --- /dev/null +++ b/test/opt/local_ssa_elim_test.cpp @@ -0,0 +1,1834 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using LocalSSAElimTest = PassTest<::testing::Test>; + +TEST_F(LocalSSAElimTest, ForLoop) { + // #version 140 + // + // in vec4 BC; + // out float fo; + // + // void main() + // { + // float f = 0.0; + // for (int i=0; i<4; i++) { + // f = f + BC[i]; + // } + // fo = f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BC %fo +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %f "f" +OpName %i "i" +OpName %BC "BC" +OpName %fo "fo" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_4 = OpConstant %int 4 +%bool = OpTypeBool +%v4float = 
OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BC = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%fo = OpVariable %_ptr_Output_float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%22 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %23 +%23 = OpLabel +OpLoopMerge %24 %25 None +OpBranch %26 +%26 = OpLabel +%27 = OpLoad %int %i +%28 = OpSLessThan %bool %27 %int_4 +OpBranchConditional %28 %29 %24 +%29 = OpLabel +%30 = OpLoad %float %f +%31 = OpLoad %int %i +%32 = OpAccessChain %_ptr_Input_float %BC %31 +%33 = OpLoad %float %32 +%34 = OpFAdd %float %30 %33 +OpStore %f %34 +OpBranch %25 +%25 = OpLabel +%35 = OpLoad %int %i +%36 = OpIAdd %int %35 %int_1 +OpStore %i %36 +OpBranch %23 +%24 = OpLabel +%37 = OpLoad %float %f +OpStore %fo %37 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%22 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %23 +%23 = OpLabel +%39 = OpPhi %float %float_0 %22 %34 %25 +%38 = OpPhi %int %int_0 %22 %36 %25 +OpLoopMerge %24 %25 None +OpBranch %26 +%26 = OpLabel +%28 = OpSLessThan %bool %38 %int_4 +OpBranchConditional %28 %29 %24 +%29 = OpLabel +%32 = OpAccessChain %_ptr_Input_float %BC %38 +%33 = OpLoad %float %32 +%34 = OpFAdd %float %39 %33 +OpStore %f %34 +OpBranch %25 +%25 = OpLabel +%36 = OpIAdd %int %38 %int_1 +OpStore %i %36 +OpBranch %23 +%24 = OpLabel +OpStore %fo %39 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, NestedForLoop) { + // #version 450 + // + // layout (location=0) in mat4 BC; + // layout (location=0) 
out float fo; + // + // void main() + // { + // float f = 0.0; + // for (int i=0; i<4; i++) + // for (int j=0; j<4; j++) + // f = f + BC[i][j]; + // fo = f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BC %fo +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %f "f" +OpName %i "i" +OpName %j "j" +OpName %BC "BC" +OpName %fo "fo" +OpDecorate %BC Location 0 +OpDecorate %fo Location 0 +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_4 = OpConstant %int 4 +%bool = OpTypeBool +%v4float = OpTypeVector %float 4 +%mat4v4float = OpTypeMatrix %v4float 4 +%_ptr_Input_mat4v4float = OpTypePointer Input %mat4v4float +%BC = OpVariable %_ptr_Input_mat4v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%fo = OpVariable %_ptr_Output_float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%24 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%j = OpVariable %_ptr_Function_int Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %25 +%25 = OpLabel +%26 = OpLoad %int %i +%27 = OpSLessThan %bool %26 %int_4 +OpLoopMerge %28 %29 None +OpBranchConditional %27 %30 %28 +%30 = OpLabel +OpStore %j %int_0 +OpBranch %31 +%31 = OpLabel +%32 = OpLoad %int %j +%33 = OpSLessThan %bool %32 %int_4 +OpLoopMerge %29 %34 None +OpBranchConditional %33 %34 %29 +%34 = OpLabel +%35 = OpLoad %float %f +%36 = OpLoad %int %i +%37 = OpLoad %int %j +%38 = OpAccessChain %_ptr_Input_float %BC %36 %37 +%39 = OpLoad %float %38 +%40 = OpFAdd %float %35 %39 +OpStore %f %40 
+%41 = OpLoad %int %j +%42 = OpIAdd %int %41 %int_1 +OpStore %j %42 +OpBranch %31 +%29 = OpLabel +%43 = OpLoad %int %i +%44 = OpIAdd %int %43 %int_1 +OpStore %i %44 +OpBranch %25 +%28 = OpLabel +%45 = OpLoad %float %f +OpStore %fo %45 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%24 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%j = OpVariable %_ptr_Function_int Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %25 +%25 = OpLabel +%47 = OpPhi %float %float_0 %24 %50 %29 +%46 = OpPhi %int %int_0 %24 %44 %29 +%27 = OpSLessThan %bool %46 %int_4 +OpLoopMerge %28 %29 None +OpBranchConditional %27 %30 %28 +%30 = OpLabel +OpStore %j %int_0 +OpBranch %31 +%31 = OpLabel +%50 = OpPhi %float %47 %30 %40 %34 +%48 = OpPhi %int %int_0 %30 %42 %34 +%33 = OpSLessThan %bool %48 %int_4 +OpLoopMerge %29 %34 None +OpBranchConditional %33 %34 %29 +%34 = OpLabel +%38 = OpAccessChain %_ptr_Input_float %BC %46 %48 +%39 = OpLoad %float %38 +%40 = OpFAdd %float %50 %39 +OpStore %f %40 +%42 = OpIAdd %int %48 %int_1 +OpStore %j %42 +OpBranch %31 +%29 = OpLabel +%44 = OpIAdd %int %46 %int_1 +OpStore %i %44 +OpBranch %25 +%28 = OpLabel +OpStore %fo %47 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, ForLoopWithContinue) { + // #version 140 + // + // in vec4 BC; + // out float fo; + // + // void main() + // { + // float f = 0.0; + // for (int i=0; i<4; i++) { + // float t = BC[i]; + // if (t < 0.0) + // continue; + // f = f + t; + // } + // fo = f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BC %fo +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +)"; + + const std::string names = + R"(OpName %main "main" +OpName %f "f" +OpName %i "i" +OpName %t "t" +OpName 
%BC "BC" +OpName %fo "fo" +)"; + + const std::string predefs2 = + R"(%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_4 = OpConstant %int 4 +%bool = OpTypeBool +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BC = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%fo = OpVariable %_ptr_Output_float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%23 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%t = OpVariable %_ptr_Function_float Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %24 +%24 = OpLabel +OpLoopMerge %25 %26 None +OpBranch %27 +%27 = OpLabel +%28 = OpLoad %int %i +%29 = OpSLessThan %bool %28 %int_4 +OpBranchConditional %29 %30 %25 +%30 = OpLabel +%31 = OpLoad %int %i +%32 = OpAccessChain %_ptr_Input_float %BC %31 +%33 = OpLoad %float %32 +OpStore %t %33 +%34 = OpLoad %float %t +%35 = OpFOrdLessThan %bool %34 %float_0 +OpSelectionMerge %36 None +OpBranchConditional %35 %37 %36 +%37 = OpLabel +OpBranch %26 +%36 = OpLabel +%38 = OpLoad %float %f +%39 = OpLoad %float %t +%40 = OpFAdd %float %38 %39 +OpStore %f %40 +OpBranch %26 +%26 = OpLabel +%41 = OpLoad %int %i +%42 = OpIAdd %int %41 %int_1 +OpStore %i %42 +OpBranch %24 +%25 = OpLabel +%43 = OpLoad %float %f +OpStore %fo %43 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%23 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%t = OpVariable %_ptr_Function_float Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %24 +%24 = OpLabel +%45 = 
OpPhi %float %float_0 %23 %47 %26 +%44 = OpPhi %int %int_0 %23 %42 %26 +OpLoopMerge %25 %26 None +OpBranch %27 +%27 = OpLabel +%29 = OpSLessThan %bool %44 %int_4 +OpBranchConditional %29 %30 %25 +%30 = OpLabel +%32 = OpAccessChain %_ptr_Input_float %BC %44 +%33 = OpLoad %float %32 +OpStore %t %33 +%35 = OpFOrdLessThan %bool %33 %float_0 +OpSelectionMerge %36 None +OpBranchConditional %35 %37 %36 +%37 = OpLabel +OpBranch %26 +%36 = OpLabel +%40 = OpFAdd %float %45 %33 +OpStore %f %40 +OpBranch %26 +%26 = OpLabel +%47 = OpPhi %float %45 %37 %40 %36 +%42 = OpIAdd %int %44 %int_1 +OpStore %i %42 +OpBranch %24 +%25 = OpLabel +OpStore %fo %45 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + names + predefs2 + before, predefs + names + predefs2 + after, + true, true); +} + +TEST_F(LocalSSAElimTest, ForLoopWithBreak) { + // #version 140 + // + // in vec4 BC; + // out float fo; + // + // void main() + // { + // float f = 0.0; + // for (int i=0; i<4; i++) { + // float t = f + BC[i]; + // if (t > 1.0) + // break; + // f = t; + // } + // fo = f; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BC %fo +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %f "f" +OpName %i "i" +OpName %t "t" +OpName %BC "BC" +OpName %fo "fo" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_4 = OpConstant %int 4 +%bool = OpTypeBool +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BC = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float 
+%fo = OpVariable %_ptr_Output_float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%24 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%t = OpVariable %_ptr_Function_float Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %25 +%25 = OpLabel +OpLoopMerge %26 %27 None +OpBranch %28 +%28 = OpLabel +%29 = OpLoad %int %i +%30 = OpSLessThan %bool %29 %int_4 +OpBranchConditional %30 %31 %26 +%31 = OpLabel +%32 = OpLoad %float %f +%33 = OpLoad %int %i +%34 = OpAccessChain %_ptr_Input_float %BC %33 +%35 = OpLoad %float %34 +%36 = OpFAdd %float %32 %35 +OpStore %t %36 +%37 = OpLoad %float %t +%38 = OpFOrdGreaterThan %bool %37 %float_1 +OpSelectionMerge %39 None +OpBranchConditional %38 %40 %39 +%40 = OpLabel +OpBranch %26 +%39 = OpLabel +%41 = OpLoad %float %t +OpStore %f %41 +OpBranch %27 +%27 = OpLabel +%42 = OpLoad %int %i +%43 = OpIAdd %int %42 %int_1 +OpStore %i %43 +OpBranch %25 +%26 = OpLabel +%44 = OpLoad %float %f +OpStore %fo %44 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%24 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%t = OpVariable %_ptr_Function_float Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %25 +%25 = OpLabel +%46 = OpPhi %float %float_0 %24 %36 %27 +%45 = OpPhi %int %int_0 %24 %43 %27 +OpLoopMerge %26 %27 None +OpBranch %28 +%28 = OpLabel +%30 = OpSLessThan %bool %45 %int_4 +OpBranchConditional %30 %31 %26 +%31 = OpLabel +%34 = OpAccessChain %_ptr_Input_float %BC %45 +%35 = OpLoad %float %34 +%36 = OpFAdd %float %46 %35 +OpStore %t %36 +%38 = OpFOrdGreaterThan %bool %36 %float_1 +OpSelectionMerge %39 None +OpBranchConditional %38 %40 %39 +%40 = OpLabel +OpBranch %26 +%39 = OpLabel +OpStore %f %36 +OpBranch %27 +%27 = OpLabel +%43 = OpIAdd %int %45 %int_1 +OpStore %i %43 +OpBranch %25 +%26 = OpLabel +OpStore %fo %46 +OpReturn 
+OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, SwapProblem) { + // #version 140 + // + // in float fe; + // out float fo; + // + // void main() + // { + // float f1 = 0.0; + // float f2 = 1.0; + // int ie = int(fe); + // for (int i=0; i(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, LostCopyProblem) { + // #version 140 + // + // in vec4 BC; + // out float fo; + // + // void main() + // { + // float f = 0.0; + // float t; + // for (int i=0; i<4; i++) { + // t = f; + // f = f + BC[i]; + // if (f > 1.0) + // break; + // } + // fo = t; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BC %fo +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %f "f" +OpName %i "i" +OpName %t "t" +OpName %BC "BC" +OpName %fo "fo" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_4 = OpConstant %int 4 +%bool = OpTypeBool +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BC = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%fo = OpVariable %_ptr_Output_float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%24 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%t = OpVariable %_ptr_Function_float Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %25 +%25 = OpLabel +OpLoopMerge %26 %27 None +OpBranch %28 +%28 = OpLabel +%29 = OpLoad %int %i 
+%30 = OpSLessThan %bool %29 %int_4 +OpBranchConditional %30 %31 %26 +%31 = OpLabel +%32 = OpLoad %float %f +OpStore %t %32 +%33 = OpLoad %float %f +%34 = OpLoad %int %i +%35 = OpAccessChain %_ptr_Input_float %BC %34 +%36 = OpLoad %float %35 +%37 = OpFAdd %float %33 %36 +OpStore %f %37 +%38 = OpLoad %float %f +%39 = OpFOrdGreaterThan %bool %38 %float_1 +OpSelectionMerge %40 None +OpBranchConditional %39 %41 %40 +%41 = OpLabel +OpBranch %26 +%40 = OpLabel +OpBranch %27 +%27 = OpLabel +%42 = OpLoad %int %i +%43 = OpIAdd %int %42 %int_1 +OpStore %i %43 +OpBranch %25 +%26 = OpLabel +%44 = OpLoad %float %t +OpStore %fo %44 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%49 = OpUndef %float +%main = OpFunction %void None %9 +%24 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +%t = OpVariable %_ptr_Function_float Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %25 +%25 = OpLabel +%46 = OpPhi %float %float_0 %24 %37 %27 +%45 = OpPhi %int %int_0 %24 %43 %27 +%48 = OpPhi %float %49 %24 %46 %27 +OpLoopMerge %26 %27 None +OpBranch %28 +%28 = OpLabel +%30 = OpSLessThan %bool %45 %int_4 +OpBranchConditional %30 %31 %26 +%31 = OpLabel +OpStore %t %46 +%35 = OpAccessChain %_ptr_Input_float %BC %45 +%36 = OpLoad %float %35 +%37 = OpFAdd %float %46 %36 +OpStore %f %37 +%39 = OpFOrdGreaterThan %bool %37 %float_1 +OpSelectionMerge %40 None +OpBranchConditional %39 %41 %40 +%41 = OpLabel +OpBranch %26 +%40 = OpLabel +OpBranch %27 +%27 = OpLabel +%43 = OpIAdd %int %45 %int_1 +OpStore %i %43 +OpBranch %25 +%26 = OpLabel +%47 = OpPhi %float %48 %28 %46 %41 +OpStore %fo %47 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, IfThenElse) { + // #version 140 + // + // in vec4 BaseColor; + // in float f; + // + // void main() + // { + // vec4 v; + // if (f >= 0) + // v = BaseColor * 0.5; + // else + // v = BaseColor + 
vec4(1.0,1.0,1.0,1.0); + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %f %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %f "f" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float +%f = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%float_0_5 = OpConstant %float 0.5 +%float_1 = OpConstant %float 1 +%18 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%20 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%21 = OpLoad %float %f +%22 = OpFOrdGreaterThanEqual %bool %21 %float_0 +OpSelectionMerge %23 None +OpBranchConditional %22 %24 %25 +%24 = OpLabel +%26 = OpLoad %v4float %BaseColor +%27 = OpVectorTimesScalar %v4float %26 %float_0_5 +OpStore %v %27 +OpBranch %23 +%25 = OpLabel +%28 = OpLoad %v4float %BaseColor +%29 = OpFAdd %v4float %28 %18 +OpStore %v %29 +OpBranch %23 +%23 = OpLabel +%30 = OpLoad %v4float %v +OpStore %gl_FragColor %30 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%20 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%21 = OpLoad %float %f +%22 = OpFOrdGreaterThanEqual %bool %21 %float_0 +OpSelectionMerge %23 None +OpBranchConditional %22 %24 %25 +%24 = OpLabel +%26 = OpLoad %v4float %BaseColor 
+%27 = OpVectorTimesScalar %v4float %26 %float_0_5 +OpStore %v %27 +OpBranch %23 +%25 = OpLabel +%28 = OpLoad %v4float %BaseColor +%29 = OpFAdd %v4float %28 %18 +OpStore %v %29 +OpBranch %23 +%23 = OpLabel +%31 = OpPhi %v4float %27 %24 %29 %25 +OpStore %gl_FragColor %31 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, IfThen) { + // #version 140 + // + // in vec4 BaseColor; + // in float f; + // + // void main() + // { + // vec4 v = BaseColor; + // if (f <= 0) + // v = v * 0.5; + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %f %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %f "f" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%f = OpVariable %_ptr_Input_float Input +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%float_0_5 = OpConstant %float 0.5 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%18 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%19 = OpLoad %v4float %BaseColor +OpStore %v %19 +%20 = OpLoad %float %f +%21 = OpFOrdLessThanEqual %bool %20 %float_0 +OpSelectionMerge %22 None +OpBranchConditional %21 %23 %22 +%23 = OpLabel +%24 = OpLoad %v4float %v +%25 = OpVectorTimesScalar %v4float %24 %float_0_5 +OpStore %v %25 +OpBranch %22 +%22 = OpLabel +%26 = OpLoad %v4float 
%v +OpStore %gl_FragColor %26 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%18 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%19 = OpLoad %v4float %BaseColor +OpStore %v %19 +%20 = OpLoad %float %f +%21 = OpFOrdLessThanEqual %bool %20 %float_0 +OpSelectionMerge %22 None +OpBranchConditional %21 %23 %22 +%23 = OpLabel +%25 = OpVectorTimesScalar %v4float %19 %float_0_5 +OpStore %v %25 +OpBranch %22 +%22 = OpLabel +%27 = OpPhi %v4float %19 %18 %25 %23 +OpStore %gl_FragColor %27 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, Switch) { + // #version 140 + // + // in vec4 BaseColor; + // in float f; + // + // void main() + // { + // vec4 v = BaseColor; + // int i = int(f); + // switch (i) { + // case 0: + // v = v * 0.25; + // break; + // case 1: + // v = v * 0.625; + // break; + // case 2: + // v = v * 0.75; + // break; + // default: + // break; + // } + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %f %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %i "i" +OpName %f "f" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_float = OpTypePointer Input %float +%f = OpVariable %_ptr_Input_float Input +%float_0_25 = OpConstant %float 0.25 +%float_0_625 = OpConstant %float 0.625 +%float_0_75 = OpConstant %float 0.75 +%_ptr_Output_v4float = 
OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%i = OpVariable %_ptr_Function_int Function +%22 = OpLoad %v4float %BaseColor +OpStore %v %22 +%23 = OpLoad %float %f +%24 = OpConvertFToS %int %23 +OpStore %i %24 +%25 = OpLoad %int %i +OpSelectionMerge %26 None +OpSwitch %25 %27 0 %28 1 %29 2 %30 +%27 = OpLabel +OpBranch %26 +%28 = OpLabel +%31 = OpLoad %v4float %v +%32 = OpVectorTimesScalar %v4float %31 %float_0_25 +OpStore %v %32 +OpBranch %26 +%29 = OpLabel +%33 = OpLoad %v4float %v +%34 = OpVectorTimesScalar %v4float %33 %float_0_625 +OpStore %v %34 +OpBranch %26 +%30 = OpLabel +%35 = OpLoad %v4float %v +%36 = OpVectorTimesScalar %v4float %35 %float_0_75 +OpStore %v %36 +OpBranch %26 +%26 = OpLabel +%37 = OpLoad %v4float %v +OpStore %gl_FragColor %37 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%i = OpVariable %_ptr_Function_int Function +%22 = OpLoad %v4float %BaseColor +OpStore %v %22 +%23 = OpLoad %float %f +%24 = OpConvertFToS %int %23 +OpStore %i %24 +OpSelectionMerge %26 None +OpSwitch %24 %27 0 %28 1 %29 2 %30 +%27 = OpLabel +OpBranch %26 +%28 = OpLabel +%32 = OpVectorTimesScalar %v4float %22 %float_0_25 +OpStore %v %32 +OpBranch %26 +%29 = OpLabel +%34 = OpVectorTimesScalar %v4float %22 %float_0_625 +OpStore %v %34 +OpBranch %26 +%30 = OpLabel +%36 = OpVectorTimesScalar %v4float %22 %float_0_75 +OpStore %v %36 +OpBranch %26 +%26 = OpLabel +%38 = OpPhi %v4float %22 %27 %32 %28 %34 %29 %36 %30 +OpStore %gl_FragColor %38 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, SwitchWithFallThrough) { + // #version 140 + // + // in vec4 BaseColor; + // in float f; + // + // void main() + 
// { + // vec4 v = BaseColor; + // int i = int(f); + // switch (i) { + // case 0: + // v = v * 0.25; + // break; + // case 1: + // v = v + 0.25; + // case 2: + // v = v * 0.75; + // break; + // default: + // break; + // } + // gl_FragColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %f %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %v "v" +OpName %BaseColor "BaseColor" +OpName %i "i" +OpName %f "f" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_float = OpTypePointer Input %float +%f = OpVariable %_ptr_Input_float Input +%float_0_25 = OpConstant %float 0.25 +%float_0_75 = OpConstant %float 0.75 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %9 +%20 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%i = OpVariable %_ptr_Function_int Function +%21 = OpLoad %v4float %BaseColor +OpStore %v %21 +%22 = OpLoad %float %f +%23 = OpConvertFToS %int %22 +OpStore %i %23 +%24 = OpLoad %int %i +OpSelectionMerge %25 None +OpSwitch %24 %26 0 %27 1 %28 2 %29 +%26 = OpLabel +OpBranch %25 +%27 = OpLabel +%30 = OpLoad %v4float %v +%31 = OpVectorTimesScalar %v4float %30 %float_0_25 +OpStore %v %31 +OpBranch %25 +%28 = OpLabel +%32 = OpLoad %v4float %v +%33 = OpCompositeConstruct %v4float %float_0_25 %float_0_25 %float_0_25 %float_0_25 +%34 = OpFAdd %v4float %32 %33 +OpStore %v %34 +OpBranch %29 +%29 = OpLabel +%35 = 
OpLoad %v4float %v +%36 = OpVectorTimesScalar %v4float %35 %float_0_75 +OpStore %v %36 +OpBranch %25 +%25 = OpLabel +%37 = OpLoad %v4float %v +OpStore %gl_FragColor %37 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %9 +%20 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function +%i = OpVariable %_ptr_Function_int Function +%21 = OpLoad %v4float %BaseColor +OpStore %v %21 +%22 = OpLoad %float %f +%23 = OpConvertFToS %int %22 +OpStore %i %23 +OpSelectionMerge %25 None +OpSwitch %23 %26 0 %27 1 %28 2 %29 +%26 = OpLabel +OpBranch %25 +%27 = OpLabel +%31 = OpVectorTimesScalar %v4float %21 %float_0_25 +OpStore %v %31 +OpBranch %25 +%28 = OpLabel +%33 = OpCompositeConstruct %v4float %float_0_25 %float_0_25 %float_0_25 %float_0_25 +%34 = OpFAdd %v4float %21 %33 +OpStore %v %34 +OpBranch %29 +%29 = OpLabel +%38 = OpPhi %v4float %21 %20 %34 %28 +%36 = OpVectorTimesScalar %v4float %38 %float_0_75 +OpStore %v %36 +OpBranch %25 +%25 = OpLabel +%39 = OpPhi %v4float %21 %26 %31 %27 %36 %29 +OpStore %gl_FragColor %39 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, + predefs + after, true, true); +} + +TEST_F(LocalSSAElimTest, DontPatchPhiInLoopHeaderThatIsNotAVar) { + // From https://github.com/KhronosGroup/SPIRV-Tools/issues/826 + // Don't try patching the (%16 %7) value/predecessor pair in the OpPhi. + // That OpPhi is unrelated to this optimization: we did not set that up + // in the SSA initialization for the loop header block. + // The pass should be a no-op on this module. 
+ + const std::string before = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%float_1 = OpConstant %float 1 +%1 = OpFunction %void None %3 +%6 = OpLabel +OpBranch %7 +%7 = OpLabel +%8 = OpPhi %float %float_1 %6 %9 %7 +%9 = OpFAdd %float %8 %float_1 +OpLoopMerge %10 %7 None +OpBranch %7 +%10 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, before, true, true); +} + +TEST_F(LocalSSAElimTest, OptInitializedVariableLikeStore) { + // Note: SPIR-V edited to change store to v into variable initialization + // + // #version 450 + // + // layout (location=0) in vec4 iColor; + // layout (location=1) in float fi; + // layout (location=0) out vec4 oColor; + // + // void main() + // { + // vec4 v = vec4(0.0); + // if (fi < 0.0) + // v.x = iColor.x; + // oColor = v; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %fi %iColor %oColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %v "v" +OpName %fi "fi" +OpName %iColor "iColor" +OpName %oColor "oColor" +OpDecorate %fi Location 1 +OpDecorate %iColor Location 0 +OpDecorate %oColor Location 0 +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%float_0 = OpConstant %float 0 +%13 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%_ptr_Input_float = OpTypePointer Input %float +%fi = OpVariable %_ptr_Input_float Input +%bool = OpTypeBool +%_ptr_Input_v4float = OpTypePointer Input %v4float +%iColor = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%oColor = 
OpVariable %_ptr_Output_v4float Output +)"; + + const std::string func_before = + R"(%main = OpFunction %void None %8 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function %13 +%22 = OpLoad %float %fi +%23 = OpFOrdLessThan %bool %22 %float_0 +OpSelectionMerge %24 None +OpBranchConditional %23 %25 %24 +%25 = OpLabel +%26 = OpAccessChain %_ptr_Input_float %iColor %uint_0 +%27 = OpLoad %float %26 +%28 = OpLoad %v4float %v +%29 = OpCompositeInsert %v4float %27 %28 0 +OpStore %v %29 +OpBranch %24 +%24 = OpLabel +%30 = OpLoad %v4float %v +OpStore %oColor %30 +OpReturn +OpFunctionEnd +)"; + + const std::string func_after = + R"(%main = OpFunction %void None %8 +%21 = OpLabel +%v = OpVariable %_ptr_Function_v4float Function %13 +%22 = OpLoad %float %fi +%23 = OpFOrdLessThan %bool %22 %float_0 +OpSelectionMerge %24 None +OpBranchConditional %23 %25 %24 +%25 = OpLabel +%26 = OpAccessChain %_ptr_Input_float %iColor %uint_0 +%27 = OpLoad %float %26 +%29 = OpCompositeInsert %v4float %27 %13 0 +OpStore %v %29 +OpBranch %24 +%24 = OpLabel +%31 = OpPhi %v4float %13 %21 %29 %25 +OpStore %oColor %31 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + func_before, predefs + func_after, true, true); +} + +TEST_F(LocalSSAElimTest, PointerVariable) { + // Test that checks if a pointer variable is removed. 
+ + const std::string before = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpMemberDecorate %_struct_3 0 Offset 0 +OpDecorate %_runtimearr__struct_3 ArrayStride 16 +OpMemberDecorate %_struct_5 0 Offset 0 +OpDecorate %_struct_5 BufferBlock +OpMemberDecorate %_struct_6 0 Offset 0 +OpDecorate %_struct_6 BufferBlock +OpDecorate %2 Location 0 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_struct_3 = OpTypeStruct %v4float +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 +%_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 +%_struct_6 = OpTypeStruct %int +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 +%_ptr_Function__ptr_Uniform__struct_5 = OpTypePointer Function %_ptr_Uniform__struct_5 +%_ptr_Function__ptr_Uniform__struct_6 = OpTypePointer Function %_ptr_Uniform__struct_6 +%int_0 = OpConstant %int 0 +%uint_0 = OpConstant %uint 0 +%2 = OpVariable %_ptr_Output_v4float Output +%7 = OpVariable %_ptr_Uniform__struct_5 Uniform +%1 = OpFunction %void None %10 +%23 = OpLabel +%24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function +OpStore %24 %7 +%26 = OpLoad %_ptr_Uniform__struct_5 %24 +%27 = OpAccessChain %_ptr_Uniform_v4float %26 %int_0 %uint_0 %int_0 +%28 = OpLoad %v4float %27 +%29 = OpCopyObject %v4float %28 +OpStore %2 %28 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpMemberDecorate %_struct_3 0 Offset 0 +OpDecorate %_runtimearr__struct_3 ArrayStride 16 +OpMemberDecorate %_struct_5 0 Offset 0 
+OpDecorate %_struct_5 BufferBlock +OpMemberDecorate %_struct_6 0 Offset 0 +OpDecorate %_struct_6 BufferBlock +OpDecorate %2 Location 0 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%_struct_3 = OpTypeStruct %v4float +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 +%_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 +%_struct_6 = OpTypeStruct %int +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 +%_ptr_Function__ptr_Uniform__struct_5 = OpTypePointer Function %_ptr_Uniform__struct_5 +%_ptr_Function__ptr_Uniform__struct_6 = OpTypePointer Function %_ptr_Uniform__struct_6 +%int_0 = OpConstant %int 0 +%uint_0 = OpConstant %uint 0 +%2 = OpVariable %_ptr_Output_v4float Output +%7 = OpVariable %_ptr_Uniform__struct_5 Uniform +%1 = OpFunction %void None %10 +%23 = OpLabel +%24 = OpVariable %_ptr_Function__ptr_Uniform__struct_5 Function +OpStore %24 %7 +%27 = OpAccessChain %_ptr_Uniform_v4float %7 %int_0 %uint_0 %int_0 +%28 = OpLoad %v4float %27 +%29 = OpCopyObject %v4float %28 +OpStore %2 %28 +OpReturn +OpFunctionEnd +)"; + + // Relax logical pointers to allow pointer allocations. 
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ValidatorOptions()->relax_logical_pointer = true; + SinglePassRunAndCheck(before, after, true, true); +} + +TEST_F(LocalSSAElimTest, VerifyInstToBlockMap) { + // #version 140 + // + // in vec4 BC; + // out float fo; + // + // void main() + // { + // float f = 0.0; + // for (int i=0; i<4; i++) { + // f = f + BC[i]; + // } + // fo = f; + // } + + const std::string text = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BC %fo +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %f "f" +OpName %i "i" +OpName %BC "BC" +OpName %fo "fo" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_4 = OpConstant %int 4 +%bool = OpTypeBool +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BC = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%int_1 = OpConstant %int 1 +%_ptr_Output_float = OpTypePointer Output %float +%fo = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %8 +%22 = OpLabel +%f = OpVariable %_ptr_Function_float Function +%i = OpVariable %_ptr_Function_int Function +OpStore %f %float_0 +OpStore %i %int_0 +OpBranch %23 +%23 = OpLabel +OpLoopMerge %24 %25 None +OpBranch %26 +%26 = OpLabel +%27 = OpLoad %int %i +%28 = OpSLessThan %bool %27 %int_4 +OpBranchConditional %28 %29 %24 +%29 = OpLabel +%30 = OpLoad %float %f +%31 = OpLoad %int %i +%32 = OpAccessChain %_ptr_Input_float %BC %31 +%33 = OpLoad %float %32 +%34 = OpFAdd %float %30 %33 +OpStore %f %34 +OpBranch %25 +%25 = OpLabel +%35 = OpLoad %int %i +%36 = OpIAdd %int %35 %int_1 +OpStore %i %36 +OpBranch %23 +%24 = OpLabel +%37 = 
OpLoad %float %f +OpStore %fo %37 +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(nullptr, context); + + // Force the instruction to block mapping to get built. + context->get_instr_block(27u); + + auto pass = MakeUnique(); + pass->SetMessageConsumer(nullptr); + const auto status = pass->Run(context.get()); + EXPECT_TRUE(status == Pass::Status::SuccessWithChange); +} + +TEST_F(LocalSSAElimTest, CompositeExtractProblem) { + const std::string spv_asm = R"( + OpCapability Tessellation + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint TessellationControl %2 "main" %16 %17 %18 %20 %22 %26 %27 %30 %31 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %uint = OpTypeInt 32 0 + %uint_3 = OpConstant %uint 3 + %v3float = OpTypeVector %float 3 + %v2float = OpTypeVector %float 2 + %_struct_11 = OpTypeStruct %v4float %v4float %v4float %v3float %v3float %v2float %v2float +%_arr__struct_11_uint_3 = OpTypeArray %_struct_11 %uint_3 +%_ptr_Function__arr__struct_11_uint_3 = OpTypePointer Function %_arr__struct_11_uint_3 +%_arr_v4float_uint_3 = OpTypeArray %v4float %uint_3 +%_ptr_Input__arr_v4float_uint_3 = OpTypePointer Input %_arr_v4float_uint_3 + %16 = OpVariable %_ptr_Input__arr_v4float_uint_3 Input + %17 = OpVariable %_ptr_Input__arr_v4float_uint_3 Input + %18 = OpVariable %_ptr_Input__arr_v4float_uint_3 Input +%_ptr_Input_uint = OpTypePointer Input %uint + %20 = OpVariable %_ptr_Input_uint Input +%_ptr_Output__arr_v4float_uint_3 = OpTypePointer Output %_arr_v4float_uint_3 + %22 = OpVariable %_ptr_Output__arr_v4float_uint_3 Output +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_arr_v3float_uint_3 = OpTypeArray %v3float %uint_3 +%_ptr_Input__arr_v3float_uint_3 = OpTypePointer Input %_arr_v3float_uint_3 + %26 = OpVariable %_ptr_Input__arr_v3float_uint_3 
Input + %27 = OpVariable %_ptr_Input__arr_v3float_uint_3 Input +%_arr_v2float_uint_3 = OpTypeArray %v2float %uint_3 +%_ptr_Input__arr_v2float_uint_3 = OpTypePointer Input %_arr_v2float_uint_3 + %30 = OpVariable %_ptr_Input__arr_v2float_uint_3 Input + %31 = OpVariable %_ptr_Input__arr_v2float_uint_3 Input +%_ptr_Function__struct_11 = OpTypePointer Function %_struct_11 + %2 = OpFunction %void None %4 + %33 = OpLabel + %66 = OpVariable %_ptr_Function__arr__struct_11_uint_3 Function + %34 = OpLoad %_arr_v4float_uint_3 %16 + %35 = OpLoad %_arr_v4float_uint_3 %17 + %36 = OpLoad %_arr_v4float_uint_3 %18 + %37 = OpLoad %_arr_v3float_uint_3 %26 + %38 = OpLoad %_arr_v3float_uint_3 %27 + %39 = OpLoad %_arr_v2float_uint_3 %30 + %40 = OpLoad %_arr_v2float_uint_3 %31 + %41 = OpCompositeExtract %v4float %34 0 + %42 = OpCompositeExtract %v4float %35 0 + %43 = OpCompositeExtract %v4float %36 0 + %44 = OpCompositeExtract %v3float %37 0 + %45 = OpCompositeExtract %v3float %38 0 + %46 = OpCompositeExtract %v2float %39 0 + %47 = OpCompositeExtract %v2float %40 0 + %48 = OpCompositeConstruct %_struct_11 %41 %42 %43 %44 %45 %46 %47 + %49 = OpCompositeExtract %v4float %34 1 + %50 = OpCompositeExtract %v4float %35 1 + %51 = OpCompositeExtract %v4float %36 1 + %52 = OpCompositeExtract %v3float %37 1 + %53 = OpCompositeExtract %v3float %38 1 + %54 = OpCompositeExtract %v2float %39 1 + %55 = OpCompositeExtract %v2float %40 1 + %56 = OpCompositeConstruct %_struct_11 %49 %50 %51 %52 %53 %54 %55 + %57 = OpCompositeExtract %v4float %34 2 + %58 = OpCompositeExtract %v4float %35 2 + %59 = OpCompositeExtract %v4float %36 2 + %60 = OpCompositeExtract %v3float %37 2 + %61 = OpCompositeExtract %v3float %38 2 + %62 = OpCompositeExtract %v2float %39 2 + %63 = OpCompositeExtract %v2float %40 2 + %64 = OpCompositeConstruct %_struct_11 %57 %58 %59 %60 %61 %62 %63 + %65 = OpCompositeConstruct %_arr__struct_11_uint_3 %48 %56 %64 + %67 = OpLoad %uint %20 + +; CHECK OpStore {{%\d+}} [[store_source:%\d+]] + 
OpStore %66 %65 + %68 = OpAccessChain %_ptr_Function__struct_11 %66 %67 + +; This load was being removed, because %_ptr_Function__struct_11 was being +; wrongfully considered an SSA target. +; CHECK OpLoad %_struct_11 %68 + %69 = OpLoad %_struct_11 %68 + +; Similarly, %69 cannot be replaced with %65. +; CHECK-NOT: OpCompositeExtract %v4float [[store_source]] 0 + %70 = OpCompositeExtract %v4float %69 0 + + %71 = OpAccessChain %_ptr_Output_v4float %22 %67 + OpStore %71 %70 + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(spv_asm, true); +} + +// Test that the RelaxedPrecision decoration on the variable is added to the +// result of the OpPhi instruction. +TEST_F(LocalSSAElimTest, DecoratedVariable) { + const std::string spv_asm = R"( +; CHECK: OpDecorate [[var:%\w+]] RelaxedPrecision +; CHECK: OpDecorate [[phi_id:%\w+]] RelaxedPrecision +; CHECK: [[phi_id]] = OpPhi + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %2 "main" + OpDecorate %v RelaxedPrecision + %void = OpTypeVoid + %func_t = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %int = OpTypeInt 32 0 + %int_p = OpTypePointer Function %int + %int_1 = OpConstant %int 1 + %int_0 = OpConstant %int 0 + %2 = OpFunction %void None %func_t + %33 = OpLabel + %v = OpVariable %int_p Function + OpSelectionMerge %merge None + OpBranchConditional %true %l1 %l2 + %l1 = OpLabel + OpStore %v %int_1 + OpBranch %merge + %l2 = OpLabel + OpStore %v %int_0 + OpBranch %merge + %merge = OpLabel + %ld = OpLoad %int %v + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(spv_asm, true); +} + +// Test that a block reached by several OpSwitch edges from the same +// predecessor gets only one OpPhi operand pair for that predecessor. 
+TEST_F(LocalSSAElimTest, MultipleEdges) { + const std::string spv_asm = R"( + ; CHECK: OpSelectionMerge + ; CHECK: [[header_bb:%\w+]] = OpLabel + ; CHECK-NOT: OpLabel + ; CHECK: OpSwitch {{%\w+}} {{%\w+}} 76 [[bb1:%\w+]] 17 [[bb2:%\w+]] + ; CHECK-SAME: 4 [[bb2]] + ; CHECK: [[bb2]] = OpLabel + ; CHECK-NEXT: OpPhi [[type:%\w+]] [[val:%\w+]] [[header_bb]] %int_0 [[bb1]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %bool = OpTypeBool + %true = OpConstantTrue %bool + %false = OpConstantFalse %bool + %int_1 = OpConstant %int 1 + %4 = OpFunction %void None %3 + %5 = OpLabel + %8 = OpVariable %_ptr_Function_int Function + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + OpBranchConditional %true %11 %12 + %11 = OpLabel + OpSelectionMerge %19 None + OpBranchConditional %false %18 %19 + %18 = OpLabel + OpSelectionMerge %22 None + OpSwitch %int_0 %22 76 %20 17 %21 4 %21 + %20 = OpLabel + %23 = OpLoad %int %8 + OpStore %8 %int_0 + OpBranch %21 + %21 = OpLabel + OpBranch %22 + %22 = OpLabel + OpBranch %19 + %19 = OpLabel + OpBranch %13 + %13 = OpLabel + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spv_asm, true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// No optimization in the presence of +// access chains +// function calls +// OpCopyMemory? +// unsupported extensions +// Others? 
+ +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/CMakeLists.txt b/test/opt/loop_optimizations/CMakeLists.txt new file mode 100644 index 000000000..e3620787d --- /dev/null +++ b/test/opt/loop_optimizations/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +add_spvtools_unittest(TARGET opt_loops + SRCS ../function_utils.h + dependence_analysis.cpp + dependence_analysis_helpers.cpp + fusion_compatibility.cpp + fusion_illegal.cpp + fusion_legal.cpp + fusion_pass.cpp + hoist_all_loop_types.cpp + hoist_double_nested_loops.cpp + hoist_from_independent_loops.cpp + hoist_simple_case.cpp + hoist_single_nested_loops.cpp + hoist_without_preheader.cpp + lcssa.cpp + loop_descriptions.cpp + loop_fission.cpp + nested_loops.cpp + peeling.cpp + peeling_pass.cpp + unroll_assumptions.cpp + unroll_simple.cpp + unswitch.cpp + LIBS SPIRV-Tools-opt + PCH_FILE pch_test_opt_loop +) diff --git a/test/opt/loop_optimizations/dependence_analysis.cpp b/test/opt/loop_optimizations/dependence_analysis.cpp new file mode 100644 index 000000000..8aeb20afc --- /dev/null +++ b/test/opt/loop_optimizations/dependence_analysis.cpp @@ -0,0 +1,4205 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/tree_iterator.h" +#include "test/opt//assembly_builder.h" +#include "test/opt//function_utils.h" +#include "test/opt//pass_fixture.h" +#include "test/opt//pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using DependencyAnalysis = ::testing::Test; + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void main(){ + int[10] arr; + int[10] arr2; + int a = 2; + for (int i = 0; i < 10; i++) { + arr[a] = arr[3]; + arr[a*2] = arr[a+3]; + arr[6] = arr2[6]; + arr[a+5] = arr2[7]; + } +} +*/ +TEST(DependencyAnalysis, ZIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %25 "arr" + OpName %39 "arr2" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %11 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %27 = OpConstant %6 3 + %38 = OpConstant %6 6 + %44 = OpConstant %6 5 + %46 = OpConstant %6 7 + %51 = OpConstant %6 1 + %4 = OpFunction %2 
None %3 + %5 = OpLabel + %25 = OpVariable %24 Function + %39 = OpVariable %24 Function + OpBranch %12 + %12 = OpLabel + %53 = OpPhi %6 %11 %5 %52 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %53 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %28 = OpAccessChain %7 %25 %27 + %29 = OpLoad %6 %28 + %30 = OpAccessChain %7 %25 %9 + OpStore %30 %29 + %32 = OpIMul %6 %9 %9 + %34 = OpIAdd %6 %9 %27 + %35 = OpAccessChain %7 %25 %34 + %36 = OpLoad %6 %35 + %37 = OpAccessChain %7 %25 %32 + OpStore %37 %36 + %40 = OpAccessChain %7 %39 %38 + %41 = OpLoad %6 %40 + %42 = OpAccessChain %7 %25 %38 + OpStore %42 %41 + %45 = OpIAdd %6 %9 %44 + %47 = OpAccessChain %7 %39 %46 + %48 = OpLoad %6 %47 + %49 = OpAccessChain %7 %25 %45 + OpStore %49 %48 + OpBranch %15 + %15 = OpLabel + %52 = OpIAdd %6 %53 %51 + OpBranch %12 + %14 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 13)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // 29 -> 30 tests looking through constants. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(29), + store[0], &distance_vector)); + } + + // 36 -> 37 tests looking through additions. 
+ { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(36), + store[1], &distance_vector)); + } + + // 41 -> 42 tests looking at same index across two different arrays. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(41), + store[2], &distance_vector)); + } + + // 48 -> 49 tests looking through additions for same index in two different + // arrays. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(48), + store[3], &distance_vector)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 c; +void main(){ + int[10] arr; + int[10] arr2; + int[10] arr3; + int[10] arr4; + int[10] arr5; + int N = int(c.x); + for (int i = 0; i < N; i++) { + arr[2*N] = arr[N]; + arr2[2*N+1] = arr2[N]; + arr3[2*N] = arr3[N-1]; + arr4[N] = arr5[N]; + } +} +*/ +TEST(DependencyAnalysis, SymbolicZIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %12 "c" + OpName %33 "arr" + OpName %41 "arr2" + OpName %50 "arr3" + OpName %58 "arr4" + OpName %60 "arr5" + OpDecorate %12 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %30 = OpConstant %13 10 + %31 = OpTypeArray %6 %30 + %32 = OpTypePointer Function %31 + %34 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %33 = 
OpVariable %32 Function + %41 = OpVariable %32 Function + %50 = OpVariable %32 Function + %58 = OpVariable %32 Function + %60 = OpVariable %32 Function + %16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpBranch %21 + %21 = OpLabel + %67 = OpPhi %6 %20 %5 %66 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSLessThan %28 %67 %18 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + %36 = OpIMul %6 %34 %18 + %38 = OpAccessChain %7 %33 %18 + %39 = OpLoad %6 %38 + %40 = OpAccessChain %7 %33 %36 + OpStore %40 %39 + %43 = OpIMul %6 %34 %18 + %45 = OpIAdd %6 %43 %44 + %47 = OpAccessChain %7 %41 %18 + %48 = OpLoad %6 %47 + %49 = OpAccessChain %7 %41 %45 + OpStore %49 %48 + %52 = OpIMul %6 %34 %18 + %54 = OpISub %6 %18 %44 + %55 = OpAccessChain %7 %50 %54 + %56 = OpLoad %6 %55 + %57 = OpAccessChain %7 %50 %52 + OpStore %57 %56 + %62 = OpAccessChain %7 %60 %18 + %63 = OpLoad %6 %62 + %64 = OpAccessChain %7 %58 %18 + OpStore %64 %63 + OpBranch %24 + %24 = OpLabel + %66 = OpIAdd %6 %67 %44 + OpBranch %21 + %23 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // independent due to loop bounds (won't enter if N <= 0). 
+ // 39 -> 40 tests looking through symbols and multiplicaiton. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(39), + store[0], &distance_vector)); + } + + // 48 -> 49 tests looking through symbols and multiplication + addition. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(48), + store[1], &distance_vector)); + } + + // 56 -> 57 tests looking through symbols and arithmetic on load and store. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(56), + store[2], &distance_vector)); + } + + // independent as different arrays + // 63 -> 64 tests looking through symbols and load/store from/to different + // arrays. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(63), + store[3], &distance_vector)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a(){ + int[10] arr; + int[11] arr2; + int[20] arr3; + int[20] arr4; + int a = 2; + for (int i = 0; i < 10; i++) { + arr[i] = arr[i]; + arr2[i] = arr2[i+1]; + arr3[i] = arr3[i-1]; + arr4[2*i] = arr4[i]; + } +} +void b(){ + int[10] arr; + int[11] arr2; + int[20] arr3; + int[20] arr4; + int a = 2; + for (int i = 10; i > 0; i--) { + arr[i] = arr[i]; + arr2[i] = arr2[i+1]; + arr3[i] = arr3[i-1]; + arr4[2*i] = arr4[i]; + } +} + +void main() { + a(); + b(); +} +*/ +TEST(DependencyAnalysis, SIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %12 "a" + OpName %14 "i" + OpName %29 "arr" + OpName %38 "arr2" + OpName %49 "arr3" + OpName %56 
"arr4" + OpName %65 "a" + OpName %66 "i" + OpName %74 "arr" + OpName %80 "arr2" + OpName %87 "arr3" + OpName %94 "arr4" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %13 = OpConstant %10 2 + %15 = OpConstant %10 0 + %22 = OpConstant %10 10 + %23 = OpTypeBool + %25 = OpTypeInt 32 0 + %26 = OpConstant %25 10 + %27 = OpTypeArray %10 %26 + %28 = OpTypePointer Function %27 + %35 = OpConstant %25 11 + %36 = OpTypeArray %10 %35 + %37 = OpTypePointer Function %36 + %41 = OpConstant %10 1 + %46 = OpConstant %25 20 + %47 = OpTypeArray %10 %46 + %48 = OpTypePointer Function %47 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %103 = OpFunctionCall %2 %6 + %104 = OpFunctionCall %2 %8 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %12 = OpVariable %11 Function + %14 = OpVariable %11 Function + %29 = OpVariable %28 Function + %38 = OpVariable %37 Function + %49 = OpVariable %48 Function + %56 = OpVariable %48 Function + OpStore %12 %13 + OpStore %14 %15 + OpBranch %16 + %16 = OpLabel + %105 = OpPhi %10 %15 %7 %64 %19 + OpLoopMerge %18 %19 None + OpBranch %20 + %20 = OpLabel + %24 = OpSLessThan %23 %105 %22 + OpBranchConditional %24 %17 %18 + %17 = OpLabel + %32 = OpAccessChain %11 %29 %105 + %33 = OpLoad %10 %32 + %34 = OpAccessChain %11 %29 %105 + OpStore %34 %33 + %42 = OpIAdd %10 %105 %41 + %43 = OpAccessChain %11 %38 %42 + %44 = OpLoad %10 %43 + %45 = OpAccessChain %11 %38 %105 + OpStore %45 %44 + %52 = OpISub %10 %105 %41 + %53 = OpAccessChain %11 %49 %52 + %54 = OpLoad %10 %53 + %55 = OpAccessChain %11 %49 %105 + OpStore %55 %54 + %58 = OpIMul %10 %13 %105 + %60 = OpAccessChain %11 %56 %105 + %61 = OpLoad %10 %60 + %62 = OpAccessChain %11 %56 %58 + OpStore %62 %61 + OpBranch %19 + %19 = OpLabel + %64 = OpIAdd %10 %105 %41 + OpStore %14 %64 + OpBranch %16 + %18 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %65 = OpVariable %11 Function + %66 = OpVariable %11 
Function + %74 = OpVariable %28 Function + %80 = OpVariable %37 Function + %87 = OpVariable %48 Function + %94 = OpVariable %48 Function + OpStore %65 %13 + OpStore %66 %22 + OpBranch %67 + %67 = OpLabel + %106 = OpPhi %10 %22 %9 %102 %70 + OpLoopMerge %69 %70 None + OpBranch %71 + %71 = OpLabel + %73 = OpSGreaterThan %23 %106 %15 + OpBranchConditional %73 %68 %69 + %68 = OpLabel + %77 = OpAccessChain %11 %74 %106 + %78 = OpLoad %10 %77 + %79 = OpAccessChain %11 %74 %106 + OpStore %79 %78 + %83 = OpIAdd %10 %106 %41 + %84 = OpAccessChain %11 %80 %83 + %85 = OpLoad %10 %84 + %86 = OpAccessChain %11 %80 %106 + OpStore %86 %85 + %90 = OpISub %10 %106 %41 + %91 = OpAccessChain %11 %87 %90 + %92 = OpLoad %10 %91 + %93 = OpAccessChain %11 %87 %106 + OpStore %93 %92 + %96 = OpIMul %10 %13 %106 + %98 = OpAccessChain %11 %94 %106 + %99 = OpLoad %10 %98 + %100 = OpAccessChain %11 %94 %96 + OpStore %100 %99 + OpBranch %70 + %70 = OpLabel + %102 = OpISub %10 %106 %41 + OpStore %66 %102 + OpBranch %67 + %69 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // For the loop in function a. + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 17)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // = dependence + // 33 -> 34 tests looking at SIV in same array. 
+ { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(33), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // > -1 dependence + // 44 -> 45 tests looking at SIV in same array with addition. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(44), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::GT); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, -1); + } + + // < 1 dependence + // 54 -> 55 tests looking at SIV in same array with subtraction. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(54), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::LT); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 1); + } + + // <=> dependence + // 61 -> 62 tests looking at SIV in same array with multiplication. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(61), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::UNKNOWN); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::ALL); + } + } + // For the loop in function b. 
+ { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 68)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // = dependence + // 78 -> 79 tests looking at SIV in same array. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(78), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // < 1 dependence + // 85 -> 86 tests looking at SIV in same array with addition. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(85), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::LT); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 1); + } + + // > -1 dependence + // 92 -> 93 tests looking at SIV in same array with subtraction. 
+ { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(92), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::GT); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, -1); + } + + // <=> dependence + // 99 -> 100 tests looking at SIV in same array with multiplication. + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(99), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::UNKNOWN); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::ALL); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 c; +void a() { + int[13] arr; + int[15] arr2; + int[18] arr3; + int[18] arr4; + int N = int(c.x); + int C = 2; + int a = 2; + for (int i = 0; i < N; i++) { // Bounds are N - 1 + arr[i+2*N] = arr[i+N]; // |distance| = N + arr2[i+N] = arr2[i+2*N] + C; // |distance| = N + arr3[2*i+2*N+1] = arr3[2*i+N+1]; // |distance| = N + arr4[a*i+N+1] = arr4[a*i+2*N+1]; // |distance| = N + } +} +void b() { + int[13] arr; + int[15] arr2; + int[18] arr3; + int[18] arr4; + int N = int(c.x); + int C = 2; + int a = 2; + for (int i = N; i > 0; i--) { // Bounds are N - 1 + arr[i+2*N] = arr[i+N]; // |distance| = N + arr2[i+N] = arr2[i+2*N] + C; // |distance| = N + arr3[2*i+2*N+1] = arr3[2*i+N+1]; // |distance| = N + arr4[a*i+N+1] = arr4[a*i+2*N+1]; // |distance| = N + } +} +void main(){ + a(); + b(); +}*/ +TEST(DependencyAnalysis, SymbolicSIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical 
GLSL450 + OpEntryPoint Fragment %4 "main" %16 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %12 "N" + OpName %16 "c" + OpName %23 "C" + OpName %25 "a" + OpName %26 "i" + OpName %40 "arr" + OpName %54 "arr2" + OpName %70 "arr3" + OpName %86 "arr4" + OpName %105 "N" + OpName %109 "C" + OpName %110 "a" + OpName %111 "i" + OpName %120 "arr" + OpName %131 "arr2" + OpName %144 "arr3" + OpName %159 "arr4" + OpDecorate %16 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Input %14 + %16 = OpVariable %15 Input + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 0 + %19 = OpTypePointer Input %13 + %24 = OpConstant %10 2 + %27 = OpConstant %10 0 + %35 = OpTypeBool + %37 = OpConstant %17 13 + %38 = OpTypeArray %10 %37 + %39 = OpTypePointer Function %38 + %51 = OpConstant %17 15 + %52 = OpTypeArray %10 %51 + %53 = OpTypePointer Function %52 + %67 = OpConstant %17 18 + %68 = OpTypeArray %10 %67 + %69 = OpTypePointer Function %68 + %76 = OpConstant %10 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %178 = OpFunctionCall %2 %6 + %179 = OpFunctionCall %2 %8 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %12 = OpVariable %11 Function + %23 = OpVariable %11 Function + %25 = OpVariable %11 Function + %26 = OpVariable %11 Function + %40 = OpVariable %39 Function + %54 = OpVariable %53 Function + %70 = OpVariable %69 Function + %86 = OpVariable %69 Function + %20 = OpAccessChain %19 %16 %18 + %21 = OpLoad %13 %20 + %22 = OpConvertFToS %10 %21 + OpStore %12 %22 + OpStore %23 %24 + OpStore %25 %24 + OpStore %26 %27 + OpBranch %28 + %28 = OpLabel + %180 = OpPhi %10 %27 %7 %104 %31 + OpLoopMerge %30 %31 None + OpBranch %32 + %32 = OpLabel + %36 = OpSLessThan %35 %180 %22 + OpBranchConditional %36 %29 %30 + %29 = OpLabel + %43 = OpIMul %10 %24 %22 + %44 = OpIAdd %10 
%180 %43 + %47 = OpIAdd %10 %180 %22 + %48 = OpAccessChain %11 %40 %47 + %49 = OpLoad %10 %48 + %50 = OpAccessChain %11 %40 %44 + OpStore %50 %49 + %57 = OpIAdd %10 %180 %22 + %60 = OpIMul %10 %24 %22 + %61 = OpIAdd %10 %180 %60 + %62 = OpAccessChain %11 %54 %61 + %63 = OpLoad %10 %62 + %65 = OpIAdd %10 %63 %24 + %66 = OpAccessChain %11 %54 %57 + OpStore %66 %65 + %72 = OpIMul %10 %24 %180 + %74 = OpIMul %10 %24 %22 + %75 = OpIAdd %10 %72 %74 + %77 = OpIAdd %10 %75 %76 + %79 = OpIMul %10 %24 %180 + %81 = OpIAdd %10 %79 %22 + %82 = OpIAdd %10 %81 %76 + %83 = OpAccessChain %11 %70 %82 + %84 = OpLoad %10 %83 + %85 = OpAccessChain %11 %70 %77 + OpStore %85 %84 + %89 = OpIMul %10 %24 %180 + %91 = OpIAdd %10 %89 %22 + %92 = OpIAdd %10 %91 %76 + %95 = OpIMul %10 %24 %180 + %97 = OpIMul %10 %24 %22 + %98 = OpIAdd %10 %95 %97 + %99 = OpIAdd %10 %98 %76 + %100 = OpAccessChain %11 %86 %99 + %101 = OpLoad %10 %100 + %102 = OpAccessChain %11 %86 %92 + OpStore %102 %101 + OpBranch %31 + %31 = OpLabel + %104 = OpIAdd %10 %180 %76 + OpStore %26 %104 + OpBranch %28 + %30 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %105 = OpVariable %11 Function + %109 = OpVariable %11 Function + %110 = OpVariable %11 Function + %111 = OpVariable %11 Function + %120 = OpVariable %39 Function + %131 = OpVariable %53 Function + %144 = OpVariable %69 Function + %159 = OpVariable %69 Function + %106 = OpAccessChain %19 %16 %18 + %107 = OpLoad %13 %106 + %108 = OpConvertFToS %10 %107 + OpStore %105 %108 + OpStore %109 %24 + OpStore %110 %24 + OpStore %111 %108 + OpBranch %113 + %113 = OpLabel + %181 = OpPhi %10 %108 %9 %177 %116 + OpLoopMerge %115 %116 None + OpBranch %117 + %117 = OpLabel + %119 = OpSGreaterThan %35 %181 %27 + OpBranchConditional %119 %114 %115 + %114 = OpLabel + %123 = OpIMul %10 %24 %108 + %124 = OpIAdd %10 %181 %123 + %127 = OpIAdd %10 %181 %108 + %128 = OpAccessChain %11 %120 %127 + %129 = OpLoad %10 %128 + %130 = OpAccessChain %11 %120 %124 + 
OpStore %130 %129 + %134 = OpIAdd %10 %181 %108 + %137 = OpIMul %10 %24 %108 + %138 = OpIAdd %10 %181 %137 + %139 = OpAccessChain %11 %131 %138 + %140 = OpLoad %10 %139 + %142 = OpIAdd %10 %140 %24 + %143 = OpAccessChain %11 %131 %134 + OpStore %143 %142 + %146 = OpIMul %10 %24 %181 + %148 = OpIMul %10 %24 %108 + %149 = OpIAdd %10 %146 %148 + %150 = OpIAdd %10 %149 %76 + %152 = OpIMul %10 %24 %181 + %154 = OpIAdd %10 %152 %108 + %155 = OpIAdd %10 %154 %76 + %156 = OpAccessChain %11 %144 %155 + %157 = OpLoad %10 %156 + %158 = OpAccessChain %11 %144 %150 + OpStore %158 %157 + %162 = OpIMul %10 %24 %181 + %164 = OpIAdd %10 %162 %108 + %165 = OpIAdd %10 %164 %76 + %168 = OpIMul %10 %24 %181 + %170 = OpIMul %10 %24 %108 + %171 = OpIAdd %10 %168 %170 + %172 = OpIAdd %10 %171 %76 + %173 = OpAccessChain %11 %159 %172 + %174 = OpLoad %10 %173 + %175 = OpAccessChain %11 %159 %165 + OpStore %175 %174 + OpBranch %116 + %116 = OpLabel + %177 = OpISub %10 %181 %76 + OpStore %111 %177 + OpBranch %113 + %115 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // For the loop in function a. 
+ { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // independent due to loop bounds (won't enter when N <= 0) + // 49 -> 50 tests looking through SIV and symbols with multiplication + { + DistanceVector distance_vector{loops.size()}; + // Independent but not yet supported. + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(49), store[0], &distance_vector)); + } + + // 63 -> 66 tests looking through SIV and symbols with multiplication and + + // C + { + DistanceVector distance_vector{loops.size()}; + // Independent. + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(63), + store[1], &distance_vector)); + } + + // 84 -> 85 tests looking through arithmetic on SIV and symbols + { + DistanceVector distance_vector{loops.size()}; + // Independent but not yet supported. + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(84), store[2], &distance_vector)); + } + + // 101 -> 102 tests looking through symbol arithmetic on SIV and symbols + { + DistanceVector distance_vector{loops.size()}; + // Independent. + EXPECT_TRUE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(101), store[3], &distance_vector)); + } + } + // For the loop in function b. 
+ { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 114)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // independent due to loop bounds (won't enter when N <= 0). + // 129 -> 130 tests looking through SIV and symbols with multiplication. + { + DistanceVector distance_vector{loops.size()}; + // Independent but not yet supported. + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(129), store[0], &distance_vector)); + } + + // 140 -> 143 tests looking through SIV and symbols with multiplication and + // + C. + { + DistanceVector distance_vector{loops.size()}; + // Independent. + EXPECT_TRUE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(140), store[1], &distance_vector)); + } + + // 157 -> 158 tests looking through arithmetic on SIV and symbols. + { + DistanceVector distance_vector{loops.size()}; + // Independent but not yet supported. + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(157), store[2], &distance_vector)); + } + + // 174 -> 175 tests looking through symbol arithmetic on SIV and symbols. + { + DistanceVector distance_vector{loops.size()}; + // Independent. 
+ EXPECT_TRUE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(174), store[3], &distance_vector)); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a() { + int[6] arr; + int N = 5; + for (int i = 1; i < N; i++) { + arr[i] = arr[N-i]; + } +} +void b() { + int[6] arr; + int N = 5; + for (int i = 1; i < N; i++) { + arr[N-i] = arr[i]; + } +} +void c() { + int[11] arr; + int N = 10; + for (int i = 1; i < N; i++) { + arr[i] = arr[N-i+1]; + } +} +void d() { + int[11] arr; + int N = 10; + for (int i = 1; i < N; i++) { + arr[N-i+1] = arr[i]; + } +} +void e() { + int[6] arr; + int N = 5; + for (int i = N; i > 0; i--) { + arr[i] = arr[N-i]; + } +} +void f() { + int[6] arr; + int N = 5; + for (int i = N; i > 0; i--) { + arr[N-i] = arr[i]; + } +} +void g() { + int[11] arr; + int N = 10; + for (int i = N; i > 0; i--) { + arr[i] = arr[N-i+1]; + } +} +void h() { + int[11] arr; + int N = 10; + for (int i = N; i > 0; i--) { + arr[N-i+1] = arr[i]; + } +} +void main(){ + a(); + b(); + c(); + d(); + e(); + f(); + g(); + h(); +} +*/ +TEST(DependencyAnalysis, Crossing) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %14 "e(" + OpName %16 "f(" + OpName %18 "g(" + OpName %20 "h(" + OpName %24 "N" + OpName %26 "i" + OpName %41 "arr" + OpName %51 "N" + OpName %52 "i" + OpName %61 "arr" + OpName %71 "N" + OpName %73 "i" + OpName %85 "arr" + OpName %96 "N" + OpName %97 "i" + OpName %106 "arr" + OpName %117 "N" + OpName %118 "i" + OpName %128 "arr" + OpName %138 "N" + OpName %139 "i" + OpName %148 "arr" + OpName %158 "N" + OpName %159 "i" + OpName %168 "arr" + OpName %179 "N" + OpName %180 "i" + OpName %189 "arr" + %2 = OpTypeVoid + %3 = 
OpTypeFunction %2 + %22 = OpTypeInt 32 1 + %23 = OpTypePointer Function %22 + %25 = OpConstant %22 5 + %27 = OpConstant %22 1 + %35 = OpTypeBool + %37 = OpTypeInt 32 0 + %38 = OpConstant %37 6 + %39 = OpTypeArray %22 %38 + %40 = OpTypePointer Function %39 + %72 = OpConstant %22 10 + %82 = OpConstant %37 11 + %83 = OpTypeArray %22 %82 + %84 = OpTypePointer Function %83 + %126 = OpConstant %22 0 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %200 = OpFunctionCall %2 %6 + %201 = OpFunctionCall %2 %8 + %202 = OpFunctionCall %2 %10 + %203 = OpFunctionCall %2 %12 + %204 = OpFunctionCall %2 %14 + %205 = OpFunctionCall %2 %16 + %206 = OpFunctionCall %2 %18 + %207 = OpFunctionCall %2 %20 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %24 = OpVariable %23 Function + %26 = OpVariable %23 Function + %41 = OpVariable %40 Function + OpStore %24 %25 + OpStore %26 %27 + OpBranch %28 + %28 = OpLabel + %208 = OpPhi %22 %27 %7 %50 %31 + OpLoopMerge %30 %31 None + OpBranch %32 + %32 = OpLabel + %36 = OpSLessThan %35 %208 %25 + OpBranchConditional %36 %29 %30 + %29 = OpLabel + %45 = OpISub %22 %25 %208 + %46 = OpAccessChain %23 %41 %45 + %47 = OpLoad %22 %46 + %48 = OpAccessChain %23 %41 %208 + OpStore %48 %47 + OpBranch %31 + %31 = OpLabel + %50 = OpIAdd %22 %208 %27 + OpStore %26 %50 + OpBranch %28 + %30 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %51 = OpVariable %23 Function + %52 = OpVariable %23 Function + %61 = OpVariable %40 Function + OpStore %51 %25 + OpStore %52 %27 + OpBranch %53 + %53 = OpLabel + %209 = OpPhi %22 %27 %9 %70 %56 + OpLoopMerge %55 %56 None + OpBranch %57 + %57 = OpLabel + %60 = OpSLessThan %35 %209 %25 + OpBranchConditional %60 %54 %55 + %54 = OpLabel + %64 = OpISub %22 %25 %209 + %66 = OpAccessChain %23 %61 %209 + %67 = OpLoad %22 %66 + %68 = OpAccessChain %23 %61 %64 + OpStore %68 %67 + OpBranch %56 + %56 = OpLabel + %70 = OpIAdd %22 %209 %27 + OpStore %52 %70 + OpBranch %53 + %55 = OpLabel + 
OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %71 = OpVariable %23 Function + %73 = OpVariable %23 Function + %85 = OpVariable %84 Function + OpStore %71 %72 + OpStore %73 %27 + OpBranch %74 + %74 = OpLabel + %210 = OpPhi %22 %27 %11 %95 %77 + OpLoopMerge %76 %77 None + OpBranch %78 + %78 = OpLabel + %81 = OpSLessThan %35 %210 %72 + OpBranchConditional %81 %75 %76 + %75 = OpLabel + %89 = OpISub %22 %72 %210 + %90 = OpIAdd %22 %89 %27 + %91 = OpAccessChain %23 %85 %90 + %92 = OpLoad %22 %91 + %93 = OpAccessChain %23 %85 %210 + OpStore %93 %92 + OpBranch %77 + %77 = OpLabel + %95 = OpIAdd %22 %210 %27 + OpStore %73 %95 + OpBranch %74 + %76 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %96 = OpVariable %23 Function + %97 = OpVariable %23 Function + %106 = OpVariable %84 Function + OpStore %96 %72 + OpStore %97 %27 + OpBranch %98 + %98 = OpLabel + %211 = OpPhi %22 %27 %13 %116 %101 + OpLoopMerge %100 %101 None + OpBranch %102 + %102 = OpLabel + %105 = OpSLessThan %35 %211 %72 + OpBranchConditional %105 %99 %100 + %99 = OpLabel + %109 = OpISub %22 %72 %211 + %110 = OpIAdd %22 %109 %27 + %112 = OpAccessChain %23 %106 %211 + %113 = OpLoad %22 %112 + %114 = OpAccessChain %23 %106 %110 + OpStore %114 %113 + OpBranch %101 + %101 = OpLabel + %116 = OpIAdd %22 %211 %27 + OpStore %97 %116 + OpBranch %98 + %100 = OpLabel + OpReturn + OpFunctionEnd + %14 = OpFunction %2 None %3 + %15 = OpLabel + %117 = OpVariable %23 Function + %118 = OpVariable %23 Function + %128 = OpVariable %40 Function + OpStore %117 %25 + OpStore %118 %25 + OpBranch %120 + %120 = OpLabel + %212 = OpPhi %22 %25 %15 %137 %123 + OpLoopMerge %122 %123 None + OpBranch %124 + %124 = OpLabel + %127 = OpSGreaterThan %35 %212 %126 + OpBranchConditional %127 %121 %122 + %121 = OpLabel + %132 = OpISub %22 %25 %212 + %133 = OpAccessChain %23 %128 %132 + %134 = OpLoad %22 %133 + %135 = OpAccessChain %23 %128 %212 + OpStore %135 %134 + OpBranch %123 + %123 
= OpLabel + %137 = OpISub %22 %212 %27 + OpStore %118 %137 + OpBranch %120 + %122 = OpLabel + OpReturn + OpFunctionEnd + %16 = OpFunction %2 None %3 + %17 = OpLabel + %138 = OpVariable %23 Function + %139 = OpVariable %23 Function + %148 = OpVariable %40 Function + OpStore %138 %25 + OpStore %139 %25 + OpBranch %141 + %141 = OpLabel + %213 = OpPhi %22 %25 %17 %157 %144 + OpLoopMerge %143 %144 None + OpBranch %145 + %145 = OpLabel + %147 = OpSGreaterThan %35 %213 %126 + OpBranchConditional %147 %142 %143 + %142 = OpLabel + %151 = OpISub %22 %25 %213 + %153 = OpAccessChain %23 %148 %213 + %154 = OpLoad %22 %153 + %155 = OpAccessChain %23 %148 %151 + OpStore %155 %154 + OpBranch %144 + %144 = OpLabel + %157 = OpISub %22 %213 %27 + OpStore %139 %157 + OpBranch %141 + %143 = OpLabel + OpReturn + OpFunctionEnd + %18 = OpFunction %2 None %3 + %19 = OpLabel + %158 = OpVariable %23 Function + %159 = OpVariable %23 Function + %168 = OpVariable %84 Function + OpStore %158 %72 + OpStore %159 %72 + OpBranch %161 + %161 = OpLabel + %214 = OpPhi %22 %72 %19 %178 %164 + OpLoopMerge %163 %164 None + OpBranch %165 + %165 = OpLabel + %167 = OpSGreaterThan %35 %214 %126 + OpBranchConditional %167 %162 %163 + %162 = OpLabel + %172 = OpISub %22 %72 %214 + %173 = OpIAdd %22 %172 %27 + %174 = OpAccessChain %23 %168 %173 + %175 = OpLoad %22 %174 + %176 = OpAccessChain %23 %168 %214 + OpStore %176 %175 + OpBranch %164 + %164 = OpLabel + %178 = OpISub %22 %214 %27 + OpStore %159 %178 + OpBranch %161 + %163 = OpLabel + OpReturn + OpFunctionEnd + %20 = OpFunction %2 None %3 + %21 = OpLabel + %179 = OpVariable %23 Function + %180 = OpVariable %23 Function + %189 = OpVariable %84 Function + OpStore %179 %72 + OpStore %180 %72 + OpBranch %182 + %182 = OpLabel + %215 = OpPhi %22 %72 %21 %199 %185 + OpLoopMerge %184 %185 None + OpBranch %186 + %186 = OpLabel + %188 = OpSGreaterThan %35 %215 %126 + OpBranchConditional %188 %183 %184 + %183 = OpLabel + %192 = OpISub %22 %72 %215 + %193 = OpIAdd %22 
%192 %27 + %195 = OpAccessChain %23 %189 %215 + %196 = OpLoad %22 %195 + %197 = OpAccessChain %23 %189 %193 + OpStore %197 %196 + OpBranch %185 + %185 = OpLabel + %199 = OpISub %22 %215 %27 + OpStore %180 %199 + OpBranch %182 + %184 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + // First two tests can be split into two loops. + // Tests even crossing subscripts from low to high indexes. + // 47 -> 48 + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(47), + store, &distance_vector)); + } + + // Tests even crossing subscripts from high to low indexes. + // 67 -> 68 + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 54)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(67), + store, &distance_vector)); + } + + // Next two tests can have an end peeled, then be split. 
+ // Tests uneven crossing subscripts from low to high indexes. + // 92 -> 93 + { + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 75)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(92), + store, &distance_vector)); + } + + // Tests uneven crossing subscripts from high to low indexes. + // 113 -> 114 + { + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 99)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(113), + store, &distance_vector)); + } + + // First two tests can be split into two loops. + // Tests even crossing subscripts from low to high indexes. 
+ // 134 -> 135 + { + const Function* f = spvtest::GetFunction(module, 14); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 121)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(134), + store, &distance_vector)); + } + + // Tests even crossing subscripts from high to low indexes. + // 154 -> 155 + { + const Function* f = spvtest::GetFunction(module, 16); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 142)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(154), + store, &distance_vector)); + } + + // Next two tests can have an end peeled, then be split. + // Tests uneven crossing subscripts from low to high indexes. 
+ // 175 -> 176 + { + const Function* f = spvtest::GetFunction(module, 18); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 162)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(175), + store, &distance_vector)); + } + + // Tests uneven crossing subscripts from high to low indexes. + // 196 -> 197 + { + const Function* f = spvtest::GetFunction(module, 20); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 183)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + } + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(196), + store, &distance_vector)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a() { + int[10] arr; + for (int i = 0; i < 10; i++) { + arr[0] = arr[i]; // peel first + arr[i] = arr[0]; // peel first + arr[9] = arr[i]; // peel last + arr[i] = arr[9]; // peel last + } +} +void b() { + int[11] arr; + for (int i = 0; i <= 10; i++) { + arr[0] = arr[i]; // peel first + arr[i] = arr[0]; // peel first + arr[10] = arr[i]; // peel last + arr[i] = arr[10]; // peel last + + } +} +void c() { + int[11] arr; + for (int i = 10; i > 0; i--) { + arr[10] = arr[i]; // peel first + arr[i] = arr[10]; // peel first + arr[1] = arr[i]; // peel last + arr[i] = arr[1]; // peel last + + } +} +void d() { + int[11] arr; + for 
(int i = 10; i >= 0; i--) { + arr[10] = arr[i]; // peel first + arr[i] = arr[10]; // peel first + arr[0] = arr[i]; // peel last + arr[i] = arr[0]; // peel last + + } +} +void main(){ + a(); + b(); + c(); + d(); +} +*/ +TEST(DependencyAnalysis, WeakZeroSIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %16 "i" + OpName %31 "arr" + OpName %52 "i" + OpName %63 "arr" + OpName %82 "i" + OpName %90 "arr" + OpName %109 "i" + OpName %117 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %14 = OpTypeInt 32 1 + %15 = OpTypePointer Function %14 + %17 = OpConstant %14 0 + %24 = OpConstant %14 10 + %25 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %14 %28 + %30 = OpTypePointer Function %29 + %40 = OpConstant %14 9 + %50 = OpConstant %14 1 + %60 = OpConstant %27 11 + %61 = OpTypeArray %14 %60 + %62 = OpTypePointer Function %61 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %136 = OpFunctionCall %2 %6 + %137 = OpFunctionCall %2 %8 + %138 = OpFunctionCall %2 %10 + %139 = OpFunctionCall %2 %12 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %16 = OpVariable %15 Function + %31 = OpVariable %30 Function + OpStore %16 %17 + OpBranch %18 + %18 = OpLabel + %140 = OpPhi %14 %17 %7 %51 %21 + OpLoopMerge %20 %21 None + OpBranch %22 + %22 = OpLabel + %26 = OpSLessThan %25 %140 %24 + OpBranchConditional %26 %19 %20 + %19 = OpLabel + %33 = OpAccessChain %15 %31 %140 + %34 = OpLoad %14 %33 + %35 = OpAccessChain %15 %31 %17 + OpStore %35 %34 + %37 = OpAccessChain %15 %31 %17 + %38 = OpLoad %14 %37 + %39 = OpAccessChain %15 %31 %140 + OpStore %39 %38 + %42 = OpAccessChain %15 %31 %140 + %43 = OpLoad %14 %42 + %44 = OpAccessChain %15 %31 %40 + OpStore %44 %43 + %46 = 
OpAccessChain %15 %31 %40 + %47 = OpLoad %14 %46 + %48 = OpAccessChain %15 %31 %140 + OpStore %48 %47 + OpBranch %21 + %21 = OpLabel + %51 = OpIAdd %14 %140 %50 + OpStore %16 %51 + OpBranch %18 + %20 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %52 = OpVariable %15 Function + %63 = OpVariable %62 Function + OpStore %52 %17 + OpBranch %53 + %53 = OpLabel + %141 = OpPhi %14 %17 %9 %81 %56 + OpLoopMerge %55 %56 None + OpBranch %57 + %57 = OpLabel + %59 = OpSLessThanEqual %25 %141 %24 + OpBranchConditional %59 %54 %55 + %54 = OpLabel + %65 = OpAccessChain %15 %63 %141 + %66 = OpLoad %14 %65 + %67 = OpAccessChain %15 %63 %17 + OpStore %67 %66 + %69 = OpAccessChain %15 %63 %17 + %70 = OpLoad %14 %69 + %71 = OpAccessChain %15 %63 %141 + OpStore %71 %70 + %73 = OpAccessChain %15 %63 %141 + %74 = OpLoad %14 %73 + %75 = OpAccessChain %15 %63 %24 + OpStore %75 %74 + %77 = OpAccessChain %15 %63 %24 + %78 = OpLoad %14 %77 + %79 = OpAccessChain %15 %63 %141 + OpStore %79 %78 + OpBranch %56 + %56 = OpLabel + %81 = OpIAdd %14 %141 %50 + OpStore %52 %81 + OpBranch %53 + %55 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %82 = OpVariable %15 Function + %90 = OpVariable %62 Function + OpStore %82 %24 + OpBranch %83 + %83 = OpLabel + %142 = OpPhi %14 %24 %11 %108 %86 + OpLoopMerge %85 %86 None + OpBranch %87 + %87 = OpLabel + %89 = OpSGreaterThan %25 %142 %17 + OpBranchConditional %89 %84 %85 + %84 = OpLabel + %92 = OpAccessChain %15 %90 %142 + %93 = OpLoad %14 %92 + %94 = OpAccessChain %15 %90 %24 + OpStore %94 %93 + %96 = OpAccessChain %15 %90 %24 + %97 = OpLoad %14 %96 + %98 = OpAccessChain %15 %90 %142 + OpStore %98 %97 + %100 = OpAccessChain %15 %90 %142 + %101 = OpLoad %14 %100 + %102 = OpAccessChain %15 %90 %50 + OpStore %102 %101 + %104 = OpAccessChain %15 %90 %50 + %105 = OpLoad %14 %104 + %106 = OpAccessChain %15 %90 %142 + OpStore %106 %105 + OpBranch %86 + %86 = OpLabel + %108 = OpISub %14 %142 %50 
+ OpStore %82 %108 + OpBranch %83 + %85 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %109 = OpVariable %15 Function + %117 = OpVariable %62 Function + OpStore %109 %24 + OpBranch %110 + %110 = OpLabel + %143 = OpPhi %14 %24 %13 %135 %113 + OpLoopMerge %112 %113 None + OpBranch %114 + %114 = OpLabel + %116 = OpSGreaterThanEqual %25 %143 %17 + OpBranchConditional %116 %111 %112 + %111 = OpLabel + %119 = OpAccessChain %15 %117 %143 + %120 = OpLoad %14 %119 + %121 = OpAccessChain %15 %117 %24 + OpStore %121 %120 + %123 = OpAccessChain %15 %117 %24 + %124 = OpLoad %14 %123 + %125 = OpAccessChain %15 %117 %143 + OpStore %125 %124 + %127 = OpAccessChain %15 %117 %143 + %128 = OpLoad %14 %127 + %129 = OpAccessChain %15 %117 %17 + OpStore %129 %128 + %131 = OpAccessChain %15 %117 %17 + %132 = OpLoad %14 %131 + %133 = OpAccessChain %15 %117 %143 + OpStore %133 %132 + OpBranch %113 + %113 = OpLabel + %135 = OpISub %14 %143 %50 + OpStore %109 %135 + OpBranch %110 + %112 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // For the loop in function a + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 19)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. 
+ // 34 -> 35 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(34), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 38 -> 39 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(38), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel last with weak zero with destination as zero + // index. + // 43 -> 44 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(43), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + + // Tests identifying peel last with weak zero with source as zero index.
+ // 47 -> 48 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(47), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + } + // For the loop in function b + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 54)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 66 -> 67 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(66), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 70 -> 71 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(70), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel last with weak zero with destination as zero + // index.
+ // 74 -> 75 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(74), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + + // Tests identifying peel last with weak zero with source as zero index. + // 78 -> 79 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(78), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + } + // For the loop in function c + { + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 84)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // Tests identifying peel first with weak zero with destination as zero + // index. + // 93 -> 94 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(93), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with source as zero index.
+ // 97 -> 98 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(97), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel last with weak zero with destination as zero + // index. + // 101 -> 102 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(101), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + + // Tests identifying peel last with weak zero with source as zero index. + // 105 -> 106 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(105), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + } + // For the loop in function d + { + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[4]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 111)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(store[i]); + } + + // Tests identifying peel first with weak zero with destination as zero + // index.
+ // 120 -> 121 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(120), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel first with weak zero with source as zero index. + // 124 -> 125 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(124), store[1], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_first); + } + + // Tests identifying peel last with weak zero with destination as zero + // index. + // 128 -> 129 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(128), store[2], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + + // Tests identifying peel last with weak zero with source as zero index.
+ // 132 -> 133 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(132), store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::PEEL); + EXPECT_TRUE(distance_vector.GetEntries()[0].peel_last); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void main(){ + int[10][10] arr; + for (int i = 0; i < 10; i++) { + arr[i][i] = arr[i][i]; + arr[0][i] = arr[1][i]; + arr[1][i] = arr[0][i]; + arr[i][0] = arr[i][1]; + arr[i][1] = arr[i][0]; + arr[0][1] = arr[1][0]; + } +} +*/ +TEST(DependencyAnalysis, MultipleSubscriptZIVSIV) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %24 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypeArray %21 %20 + %23 = OpTypePointer Function %22 + %33 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %24 = OpVariable %23 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %58 = OpPhi %6 %9 %5 %57 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %58 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %29 = OpAccessChain %7 %24 %58 %58 + %30 = OpLoad %6 %29 + %31 = OpAccessChain %7 %24 %58 %58 + OpStore %31 %30 + %35 = OpAccessChain %7 %24 %33 %58 + %36 = OpLoad %6 %35 + %37 = OpAccessChain %7 %24 %9 %58 + OpStore %37 %36 + %40 = OpAccessChain %7 %24 %9 %58 + %41 = OpLoad %6 %40 + %42 = OpAccessChain %7 %24 %33 %58 + 
OpStore %42 %41 + %45 = OpAccessChain %7 %24 %58 %33 + %46 = OpLoad %6 %45 + %47 = OpAccessChain %7 %24 %58 %9 + OpStore %47 %46 + %50 = OpAccessChain %7 %24 %58 %9 + %51 = OpLoad %6 %50 + %52 = OpAccessChain %7 %24 %58 %33 + OpStore %52 %51 + %53 = OpAccessChain %7 %24 %33 %9 + %54 = OpLoad %6 %53 + %55 = OpAccessChain %7 %24 %9 %33 + OpStore %55 %54 + OpBranch %13 + %13 = OpLabel + %57 = OpIAdd %6 %58 %33 + OpStore %8 %57 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[6]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 6; ++i) { + EXPECT_TRUE(store[i]); + } + + // 30 -> 31: arr[i][i] = arr[i][i] + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(30), + store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // 36 -> 37: arr[0][i] = arr[1][i] + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(36), + store[1], &distance_vector)); + } + + // 41 -> 42: arr[1][i] = arr[0][i] + { + DistanceVector distance_vector{loops.size()}; +
EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(41), + store[2], &distance_vector)); + } + + // 46 -> 47: arr[i][0] = arr[i][1] + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(46), + store[3], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // 51 -> 52: arr[i][1] = arr[i][0] + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(51), + store[4], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::EQ); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + } + + // 54 -> 55: arr[0][1] = arr[1][0] + { + DistanceVector distance_vector{loops.size()}; + EXPECT_TRUE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(54), + store[5], &distance_vector)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a(){ + int[10] arr; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + arr[j] = arr[j]; + } + } +} +void b(){ + int[10] arr; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + arr[i] = arr[i]; + } + } +} +void main() { + a(); + b(); +} +*/ +TEST(DependencyAnalysis, IrrelevantSubscripts) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %12 "i" + OpName %23 "j" + OpName %35 "arr" + OpName %46 "i" + OpName %54 "j" + OpName %62
"arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %13 = OpConstant %10 0 + %20 = OpConstant %10 10 + %21 = OpTypeBool + %31 = OpTypeInt 32 0 + %32 = OpConstant %31 10 + %33 = OpTypeArray %10 %32 + %34 = OpTypePointer Function %33 + %42 = OpConstant %10 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %72 = OpFunctionCall %2 %6 + %73 = OpFunctionCall %2 %8 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %12 = OpVariable %11 Function + %23 = OpVariable %11 Function + %35 = OpVariable %34 Function + OpStore %12 %13 + OpBranch %14 + %14 = OpLabel + %74 = OpPhi %10 %13 %7 %45 %17 + OpLoopMerge %16 %17 None + OpBranch %18 + %18 = OpLabel + %22 = OpSLessThan %21 %74 %20 + OpBranchConditional %22 %15 %16 + %15 = OpLabel + OpStore %23 %13 + OpBranch %24 + %24 = OpLabel + %75 = OpPhi %10 %13 %15 %43 %27 + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %30 = OpSLessThan %21 %75 %20 + OpBranchConditional %30 %25 %26 + %25 = OpLabel + %38 = OpAccessChain %11 %35 %75 + %39 = OpLoad %10 %38 + %40 = OpAccessChain %11 %35 %75 + OpStore %40 %39 + OpBranch %27 + %27 = OpLabel + %43 = OpIAdd %10 %75 %42 + OpStore %23 %43 + OpBranch %24 + %26 = OpLabel + OpBranch %17 + %17 = OpLabel + %45 = OpIAdd %10 %74 %42 + OpStore %12 %45 + OpBranch %14 + %16 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %46 = OpVariable %11 Function + %54 = OpVariable %11 Function + %62 = OpVariable %34 Function + OpStore %46 %13 + OpBranch %47 + %47 = OpLabel + %77 = OpPhi %10 %13 %9 %71 %50 + OpLoopMerge %49 %50 None + OpBranch %51 + %51 = OpLabel + %53 = OpSLessThan %21 %77 %20 + OpBranchConditional %53 %48 %49 + %48 = OpLabel + OpStore %54 %13 + OpBranch %55 + %55 = OpLabel + %78 = OpPhi %10 %13 %48 %69 %58 + OpLoopMerge %57 %58 None + OpBranch %59 + %59 = OpLabel + %61 = OpSLessThan %21 %78 %20 + OpBranchConditional %61 %56 %57 + %56 = OpLabel + %65 = OpAccessChain %11 %62 %77 + 
%66 = OpLoad %10 %65 + %67 = OpAccessChain %11 %62 %77 + OpStore %67 %66 + OpBranch %58 + %58 = OpLabel + %69 = OpIAdd %10 %78 %42 + OpStore %54 %69 + OpBranch %55 + %57 = OpLabel + OpBranch %50 + %50 = OpLabel + %71 = OpIAdd %10 %77 %42 + OpStore %46 %71 + OpBranch %47 + %49 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // For the loop in function a + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + std::vector loops{&ld.GetLoopByIndex(1), + &ld.GetLoopByIndex(0)}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[1]; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 25)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 1; ++i) { + EXPECT_TRUE(store[i]); + } + + // 39 -> 40 + { + DistanceVector distance_vector{loops.size()}; + analysis.SetDebugStream(std::cout); + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(39), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::IRRELEVANT); + EXPECT_EQ(distance_vector.GetEntries()[1].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[1].distance, 0); + } + } + + // For the loop in function b + { + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + std::vector loops{&ld.GetLoopByIndex(1), + &ld.GetLoopByIndex(0)}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[1]; + int stores_found = 0; + 
for (const Instruction& inst : *spvtest::GetBasicBlock(f, 56)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 1; ++i) { + EXPECT_TRUE(store[i]); + } + + // 66 -> 67 + { + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.GetDependence( + context->get_def_use_mgr()->GetDef(66), store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::DISTANCE); + EXPECT_EQ(distance_vector.GetEntries()[0].distance, 0); + EXPECT_EQ(distance_vector.GetEntries()[1].dependence_information, + DistanceEntry::DependenceInformation::IRRELEVANT); + } + } +} + +// Test helper: calls analysis->GetDependence(source, destination, ...) with a +// scratch DistanceVector and expects the returned flag to equal +// |expected_dependence| and the filled-in vector to equal |expected_distance|. +// NOTE(review): the scratch vector is constructed with a fixed size of 2, +// whereas the tests above size theirs with loops.size(); presumably this helper +// is only meant for two-loop nests -- confirm against its callers. +void CheckDependenceAndDirection(const Instruction* source, + const Instruction* destination, + bool expected_dependence, + DistanceVector expected_distance, + LoopDependenceAnalysis* analysis) { + DistanceVector dv_entry(2); + EXPECT_EQ(expected_dependence, + analysis->GetDependence(source, destination, &dv_entry)); + EXPECT_EQ(expected_distance, dv_entry); +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 c; +void main(){ + int[10] arr; + int a = 2; + int b = 3; + int N = int(c.x); + for (int i = 0; i < 10; i++) { + for (int j = 2; j < 10; j++) { + arr[i] = arr[j]; // 0 + arr[j] = arr[i]; // 1 + arr[j-2] = arr[i+3]; // 2 + arr[j-a] = arr[i+b]; // 3 + arr[2*i] = arr[4*j+3]; // 4, independent + arr[2*i] = arr[4*j]; // 5 + arr[i+j] = arr[i+j]; // 6 + arr[10*i+j] = arr[10*i+j]; // 7 + arr[10*i+10*j] = arr[10*i+10*j+3]; // 8, independent + arr[10*i+10*j] = arr[10*i+N*j+3]; // 9, bail out because of N coefficient + arr[10*i+10*j] = arr[10*i+10*j+N]; // 10, bail out because of N constant + // term + arr[10*i+N*j] = arr[10*i+10*j+3]; // 11, bail out because of N coefficient + arr[10*i+10*j+N] = arr[10*i+10*j]; // 12, bail out because of N constant + // term + arr[10*i]
= arr[5*j]; // 13, independent + arr[5*i] = arr[10*j]; // 14, independent + arr[9*i] = arr[3*j]; // 15, independent + arr[3*i] = arr[9*j]; // 16, independent + } + } +} +*/ +TEST(DependencyAnalysis, MIV) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %16 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "a" + OpName %10 "b" + OpName %12 "N" + OpName %16 "c" + OpName %23 "i" + OpName %34 "j" + OpName %45 "arr" + OpDecorate %16 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %11 = OpConstant %6 3 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Input %14 + %16 = OpVariable %15 Input + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 0 + %19 = OpTypePointer Input %13 + %24 = OpConstant %6 0 + %31 = OpConstant %6 10 + %32 = OpTypeBool + %42 = OpConstant %17 10 + %43 = OpTypeArray %6 %42 + %44 = OpTypePointer Function %43 + %74 = OpConstant %6 4 + %184 = OpConstant %6 5 + %197 = OpConstant %6 9 + %213 = OpConstant %6 1 + %218 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %12 = OpVariable %7 Function + %23 = OpVariable %7 Function + %34 = OpVariable %7 Function + %45 = OpVariable %44 Function + OpStore %8 %9 + OpStore %10 %11 + %20 = OpAccessChain %19 %16 %18 + %21 = OpLoad %13 %20 + %22 = OpConvertFToS %6 %21 + OpStore %12 %22 + OpStore %23 %24 + OpBranch %25 + %25 = OpLabel + %217 = OpPhi %6 %24 %5 %216 %28 + %219 = OpPhi %6 %218 %5 %220 %28 + OpLoopMerge %27 %28 None + OpBranch %29 + %29 = OpLabel + %33 = OpSLessThan %32 %217 %31 + OpBranchConditional %33 %26 %27 + %26 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %220 = OpPhi %6 %9 %26 %214 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %32 %220 
%31 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %48 = OpAccessChain %7 %45 %220 + %49 = OpLoad %6 %48 + %50 = OpAccessChain %7 %45 %217 + OpStore %50 %49 + %53 = OpAccessChain %7 %45 %217 + %54 = OpLoad %6 %53 + %55 = OpAccessChain %7 %45 %220 + OpStore %55 %54 + %57 = OpISub %6 %220 %9 + %59 = OpIAdd %6 %217 %11 + %60 = OpAccessChain %7 %45 %59 + %61 = OpLoad %6 %60 + %62 = OpAccessChain %7 %45 %57 + OpStore %62 %61 + %65 = OpISub %6 %220 %9 + %68 = OpIAdd %6 %217 %11 + %69 = OpAccessChain %7 %45 %68 + %70 = OpLoad %6 %69 + %71 = OpAccessChain %7 %45 %65 + OpStore %71 %70 + %73 = OpIMul %6 %9 %217 + %76 = OpIMul %6 %74 %220 + %77 = OpIAdd %6 %76 %11 + %78 = OpAccessChain %7 %45 %77 + %79 = OpLoad %6 %78 + %80 = OpAccessChain %7 %45 %73 + OpStore %80 %79 + %82 = OpIMul %6 %9 %217 + %84 = OpIMul %6 %74 %220 + %85 = OpAccessChain %7 %45 %84 + %86 = OpLoad %6 %85 + %87 = OpAccessChain %7 %45 %82 + OpStore %87 %86 + %90 = OpIAdd %6 %217 %220 + %93 = OpIAdd %6 %217 %220 + %94 = OpAccessChain %7 %45 %93 + %95 = OpLoad %6 %94 + %96 = OpAccessChain %7 %45 %90 + OpStore %96 %95 + %98 = OpIMul %6 %31 %217 + %100 = OpIAdd %6 %98 %220 + %102 = OpIMul %6 %31 %217 + %104 = OpIAdd %6 %102 %220 + %105 = OpAccessChain %7 %45 %104 + %106 = OpLoad %6 %105 + %107 = OpAccessChain %7 %45 %100 + OpStore %107 %106 + %109 = OpIMul %6 %31 %217 + %111 = OpIMul %6 %31 %220 + %112 = OpIAdd %6 %109 %111 + %114 = OpIMul %6 %31 %217 + %116 = OpIMul %6 %31 %220 + %117 = OpIAdd %6 %114 %116 + %118 = OpIAdd %6 %117 %11 + %119 = OpAccessChain %7 %45 %118 + %120 = OpLoad %6 %119 + %121 = OpAccessChain %7 %45 %112 + OpStore %121 %120 + %123 = OpIMul %6 %31 %217 + %125 = OpIMul %6 %31 %220 + %126 = OpIAdd %6 %123 %125 + %128 = OpIMul %6 %31 %217 + %131 = OpIMul %6 %22 %220 + %132 = OpIAdd %6 %128 %131 + %133 = OpIAdd %6 %132 %11 + %134 = OpAccessChain %7 %45 %133 + %135 = OpLoad %6 %134 + %136 = OpAccessChain %7 %45 %126 + OpStore %136 %135 + %138 = OpIMul %6 %31 %217 + %140 = OpIMul %6 %31 %220 
+ %141 = OpIAdd %6 %138 %140 + %143 = OpIMul %6 %31 %217 + %145 = OpIMul %6 %31 %220 + %146 = OpIAdd %6 %143 %145 + %148 = OpIAdd %6 %146 %22 + %149 = OpAccessChain %7 %45 %148 + %150 = OpLoad %6 %149 + %151 = OpAccessChain %7 %45 %141 + OpStore %151 %150 + %153 = OpIMul %6 %31 %217 + %156 = OpIMul %6 %22 %220 + %157 = OpIAdd %6 %153 %156 + %159 = OpIMul %6 %31 %217 + %161 = OpIMul %6 %31 %220 + %162 = OpIAdd %6 %159 %161 + %163 = OpIAdd %6 %162 %11 + %164 = OpAccessChain %7 %45 %163 + %165 = OpLoad %6 %164 + %166 = OpAccessChain %7 %45 %157 + OpStore %166 %165 + %168 = OpIMul %6 %31 %217 + %170 = OpIMul %6 %31 %220 + %171 = OpIAdd %6 %168 %170 + %173 = OpIAdd %6 %171 %22 + %175 = OpIMul %6 %31 %217 + %177 = OpIMul %6 %31 %220 + %178 = OpIAdd %6 %175 %177 + %179 = OpAccessChain %7 %45 %178 + %180 = OpLoad %6 %179 + %181 = OpAccessChain %7 %45 %173 + OpStore %181 %180 + %183 = OpIMul %6 %31 %217 + %186 = OpIMul %6 %184 %220 + %187 = OpAccessChain %7 %45 %186 + %188 = OpLoad %6 %187 + %189 = OpAccessChain %7 %45 %183 + OpStore %189 %188 + %191 = OpIMul %6 %184 %217 + %193 = OpIMul %6 %31 %220 + %194 = OpAccessChain %7 %45 %193 + %195 = OpLoad %6 %194 + %196 = OpAccessChain %7 %45 %191 + OpStore %196 %195 + %199 = OpIMul %6 %197 %217 + %201 = OpIMul %6 %11 %220 + %202 = OpAccessChain %7 %45 %201 + %203 = OpLoad %6 %202 + %204 = OpAccessChain %7 %45 %199 + OpStore %204 %203 + %206 = OpIMul %6 %11 %217 + %208 = OpIMul %6 %197 %220 + %209 = OpAccessChain %7 %45 %208 + %210 = OpLoad %6 %209 + %211 = OpAccessChain %7 %45 %206 + OpStore %211 %210 + OpBranch %38 + %38 = OpLabel + %214 = OpIAdd %6 %220 %213 + OpStore %34 %214 + OpBranch %35 + %37 = OpLabel + OpBranch %28 + %28 = OpLabel + %216 = OpIAdd %6 %217 %213 + OpStore %23 %216 + OpBranch %25 + %27 = OpLabel + OpReturn + OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + 
EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + std::vector loops{&ld.GetLoopByIndex(0), &ld.GetLoopByIndex(1)}; + + LoopDependenceAnalysis analysis{context.get(), loops}; + + const int instructions_expected = 17; + const Instruction* store[instructions_expected]; + const Instruction* load[instructions_expected]; + int stores_found = 0; + int loads_found = 0; + + int block_id = 36; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load[loads_found] = &inst; + ++loads_found; + } + } + + EXPECT_EQ(instructions_expected, stores_found); + EXPECT_EQ(instructions_expected, loads_found); + + auto directions_all = DistanceEntry(DistanceEntry::Directions::ALL); + auto directions_none = DistanceEntry(DistanceEntry::Directions::NONE); + + auto dependent = DistanceVector({directions_all, directions_all}); + auto independent = DistanceVector({directions_none, directions_none}); + + CheckDependenceAndDirection(load[0], store[0], false, dependent, &analysis); + CheckDependenceAndDirection(load[1], store[1], false, dependent, &analysis); + CheckDependenceAndDirection(load[2], store[2], false, dependent, &analysis); + CheckDependenceAndDirection(load[3], store[3], false, dependent, &analysis); + CheckDependenceAndDirection(load[4], store[4], true, independent, &analysis); + CheckDependenceAndDirection(load[5], store[5], false, dependent, &analysis); + CheckDependenceAndDirection(load[6], store[6], false, dependent, &analysis); + CheckDependenceAndDirection(load[7], store[7], false, dependent, &analysis); + CheckDependenceAndDirection(load[8], store[8], true, independent, &analysis); + 
CheckDependenceAndDirection(load[9], store[9], false, dependent, &analysis); + CheckDependenceAndDirection(load[10], store[10], false, dependent, &analysis); + CheckDependenceAndDirection(load[11], store[11], false, dependent, &analysis); + CheckDependenceAndDirection(load[12], store[12], false, dependent, &analysis); + CheckDependenceAndDirection(load[13], store[13], true, independent, + &analysis); + CheckDependenceAndDirection(load[14], store[14], true, independent, + &analysis); + CheckDependenceAndDirection(load[15], store[15], true, independent, + &analysis); + CheckDependenceAndDirection(load[16], store[16], true, independent, + &analysis); +} + +void PartitionSubscripts(const Instruction* instruction_0, + const Instruction* instruction_1, + LoopDependenceAnalysis* analysis, + std::vector> expected_ids) { + auto subscripts_0 = analysis->GetSubscripts(instruction_0); + auto subscripts_1 = analysis->GetSubscripts(instruction_1); + + std::vector>> + expected_partition{}; + + for (const auto& partition : expected_ids) { + expected_partition.push_back( + std::set>{}); + for (auto id : partition) { + expected_partition.back().insert({subscripts_0[id], subscripts_1[id]}); + } + } + + EXPECT_EQ(expected_partition, + analysis->PartitionSubscripts(subscripts_0, subscripts_1)); +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void main(){ + int[10][10][10][10] arr; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + for (int l = 0; l < 10; l++) { + arr[i][j][k][l] = arr[i][j][k][l]; // 0, all independent + arr[i][j][k][l] = arr[i][j][l][0]; // 1, last 2 coupled + arr[i][j][k][l] = arr[j][i][k][l]; // 2, first 2 coupled + arr[i][j][k][l] = arr[l][j][k][i]; // 3, first & last coupled + arr[i][j][k][l] = arr[i][k][j][l]; // 4, middle 2 coupled + arr[i+j][j][k][l] = arr[i][j][k][l]; // 5, first 2 coupled + arr[i+j+k][j][k][l] = arr[i][j][k][l]; // 6, first 
3 coupled + arr[i+j+k+l][j][k][l] = arr[i][j][k][l]; // 7, all 4 coupled + arr[i][j][k][l] = arr[i][l][j][k]; // 8, last 3 coupled + arr[i][j-k][k][l] = arr[i][j][l][k]; // 9, last 3 coupled + arr[i][j][k][l] = arr[l][i][j][k]; // 10, all 4 coupled + arr[i][j][k][l] = arr[j][i][l][k]; // 11, 2 coupled partitions (i,j) & +(l&k) + arr[i][j][k][l] = arr[k][l][i][j]; // 12, 2 coupled partitions (i,k) & +(j&l) + } + } + } + } +} +*/ +TEST(DependencyAnalysis, SubscriptPartitioning) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %27 "k" + OpName %35 "l" + OpName %50 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %43 = OpTypeInt 32 0 + %44 = OpConstant %43 10 + %45 = OpTypeArray %6 %44 + %46 = OpTypeArray %45 %44 + %47 = OpTypeArray %46 %44 + %48 = OpTypeArray %47 %44 + %49 = OpTypePointer Function %48 + %208 = OpConstant %6 1 + %217 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %27 = OpVariable %7 Function + %35 = OpVariable %7 Function + %50 = OpVariable %49 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %216 = OpPhi %6 %9 %5 %215 %13 + %218 = OpPhi %6 %217 %5 %221 %13 + %219 = OpPhi %6 %217 %5 %222 %13 + %220 = OpPhi %6 %217 %5 %223 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %216 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %221 = OpPhi %6 %9 %11 %213 %23 + %222 = OpPhi %6 %219 %11 %224 %23 + %223 = OpPhi %6 %220 %11 %225 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %221 %16 + OpBranchConditional 
%26 %21 %22 + %21 = OpLabel + OpStore %27 %9 + OpBranch %28 + %28 = OpLabel + %224 = OpPhi %6 %9 %21 %211 %31 + %225 = OpPhi %6 %223 %21 %226 %31 + OpLoopMerge %30 %31 None + OpBranch %32 + %32 = OpLabel + %34 = OpSLessThan %17 %224 %16 + OpBranchConditional %34 %29 %30 + %29 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %226 = OpPhi %6 %9 %29 %209 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %226 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %59 = OpAccessChain %7 %50 %216 %221 %224 %226 + %60 = OpLoad %6 %59 + %61 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %61 %60 + %69 = OpAccessChain %7 %50 %216 %221 %226 %9 + %70 = OpLoad %6 %69 + %71 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %71 %70 + %80 = OpAccessChain %7 %50 %221 %216 %224 %226 + %81 = OpLoad %6 %80 + %82 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %82 %81 + %91 = OpAccessChain %7 %50 %226 %221 %224 %216 + %92 = OpLoad %6 %91 + %93 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %93 %92 + %102 = OpAccessChain %7 %50 %216 %224 %221 %226 + %103 = OpLoad %6 %102 + %104 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %104 %103 + %107 = OpIAdd %6 %216 %221 + %115 = OpAccessChain %7 %50 %216 %221 %224 %226 + %116 = OpLoad %6 %115 + %117 = OpAccessChain %7 %50 %107 %221 %224 %226 + OpStore %117 %116 + %120 = OpIAdd %6 %216 %221 + %122 = OpIAdd %6 %120 %224 + %130 = OpAccessChain %7 %50 %216 %221 %224 %226 + %131 = OpLoad %6 %130 + %132 = OpAccessChain %7 %50 %122 %221 %224 %226 + OpStore %132 %131 + %135 = OpIAdd %6 %216 %221 + %137 = OpIAdd %6 %135 %224 + %139 = OpIAdd %6 %137 %226 + %147 = OpAccessChain %7 %50 %216 %221 %224 %226 + %148 = OpLoad %6 %147 + %149 = OpAccessChain %7 %50 %139 %221 %224 %226 + OpStore %149 %148 + %158 = OpAccessChain %7 %50 %216 %226 %221 %224 + %159 = OpLoad %6 %158 + %160 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %160 %159 + %164 = OpISub %6 %221 %224 + %171 = 
OpAccessChain %7 %50 %216 %221 %226 %224 + %172 = OpLoad %6 %171 + %173 = OpAccessChain %7 %50 %216 %164 %224 %226 + OpStore %173 %172 + %182 = OpAccessChain %7 %50 %226 %216 %221 %224 + %183 = OpLoad %6 %182 + %184 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %184 %183 + %193 = OpAccessChain %7 %50 %221 %216 %226 %224 + %194 = OpLoad %6 %193 + %195 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %195 %194 + %204 = OpAccessChain %7 %50 %224 %226 %216 %221 + %205 = OpLoad %6 %204 + %206 = OpAccessChain %7 %50 %216 %221 %224 %226 + OpStore %206 %205 + OpBranch %39 + %39 = OpLabel + %209 = OpIAdd %6 %226 %208 + OpStore %35 %209 + OpBranch %36 + %38 = OpLabel + OpBranch %31 + %31 = OpLabel + %211 = OpIAdd %6 %224 %208 + OpStore %27 %211 + OpBranch %28 + %30 = OpLabel + OpBranch %23 + %23 = OpLabel + %213 = OpIAdd %6 %221 %208 + OpStore %19 %213 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %215 = OpIAdd %6 %216 %208 + OpStore %8 %215 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + std::vector loop_nest{ + &ld.GetLoopByIndex(0), &ld.GetLoopByIndex(1), &ld.GetLoopByIndex(2), + &ld.GetLoopByIndex(3)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + const int instructions_expected = 13; + const Instruction* store[instructions_expected]; + const Instruction* load[instructions_expected]; + int stores_found = 0; + int loads_found = 0; + + int block_id = 37; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + 
store[stores_found] = &inst; + ++stores_found; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load[loads_found] = &inst; + ++loads_found; + } + } + + EXPECT_EQ(instructions_expected, stores_found); + EXPECT_EQ(instructions_expected, loads_found); + + PartitionSubscripts(load[0], store[0], &analysis, {{0}, {1}, {2}, {3}}); + PartitionSubscripts(load[1], store[1], &analysis, {{0}, {1}, {2, 3}}); + PartitionSubscripts(load[2], store[2], &analysis, {{0, 1}, {2}, {3}}); + PartitionSubscripts(load[3], store[3], &analysis, {{0, 3}, {1}, {2}}); + PartitionSubscripts(load[4], store[4], &analysis, {{0}, {1, 2}, {3}}); + PartitionSubscripts(load[5], store[5], &analysis, {{0, 1}, {2}, {3}}); + PartitionSubscripts(load[6], store[6], &analysis, {{0, 1, 2}, {3}}); + PartitionSubscripts(load[7], store[7], &analysis, {{0, 1, 2, 3}}); + PartitionSubscripts(load[8], store[8], &analysis, {{0}, {1, 2, 3}}); + PartitionSubscripts(load[9], store[9], &analysis, {{0}, {1, 2, 3}}); + PartitionSubscripts(load[10], store[10], &analysis, {{0, 1, 2, 3}}); + PartitionSubscripts(load[11], store[11], &analysis, {{0, 1}, {2, 3}}); + PartitionSubscripts(load[12], store[12], &analysis, {{0, 2}, {1, 3}}); +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store + +#version 440 core +void a() { + int[10][10] arr; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + // Dependent, distance vector (1, -1) + arr[i+1][i+j] = arr[i][i+j]; + } + } +} + +void b() { + int[10][10] arr; + for (int i = 0; i < 10; ++i) { + // Independent + arr[i+1][i+2] = arr[i][i] + 2; + } +} + +void c() { + int[10][10] arr; + for (int i = 0; i < 10; ++i) { + // Dependence point (1,2) + arr[i][i] = arr[1][i-1] + 2; + } +} + +void d() { + int[10][10][10] arr; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + for (int k = 0; k < 10; ++k) { + // Dependent, distance vector (1,1,-1) + arr[j-i][i+1][j+k] = arr[j-i][i][j+k]; + } + } + } +} + +void e() { + 
int[10][10] arr; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + // Independent with GCD after propagation + arr[i][2*j+i] = arr[i][2*j-i+5]; + } + } +} + +void main(){ + a(); + b(); + c(); + d(); + e(); +} +*/ +TEST(DependencyAnalysis, Delta) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %14 "e(" + OpName %18 "i" + OpName %29 "j" + OpName %42 "arr" + OpName %60 "i" + OpName %68 "arr" + OpName %82 "i" + OpName %90 "arr" + OpName %101 "i" + OpName %109 "j" + OpName %117 "k" + OpName %127 "arr" + OpName %152 "i" + OpName %160 "j" + OpName %168 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %16 = OpTypeInt 32 1 + %17 = OpTypePointer Function %16 + %19 = OpConstant %16 0 + %26 = OpConstant %16 10 + %27 = OpTypeBool + %37 = OpTypeInt 32 0 + %38 = OpConstant %37 10 + %39 = OpTypeArray %16 %38 + %40 = OpTypeArray %39 %38 + %41 = OpTypePointer Function %40 + %44 = OpConstant %16 1 + %72 = OpConstant %16 2 + %125 = OpTypeArray %40 %38 + %126 = OpTypePointer Function %125 + %179 = OpConstant %16 5 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %188 = OpFunctionCall %2 %6 + %189 = OpFunctionCall %2 %8 + %190 = OpFunctionCall %2 %10 + %191 = OpFunctionCall %2 %12 + %192 = OpFunctionCall %2 %14 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %18 = OpVariable %17 Function + %29 = OpVariable %17 Function + %42 = OpVariable %41 Function + OpStore %18 %19 + OpBranch %20 + %20 = OpLabel + %193 = OpPhi %16 %19 %7 %59 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %28 = OpSLessThan %27 %193 %26 + OpBranchConditional %28 %21 %22 + %21 = OpLabel + OpStore %29 %19 + OpBranch %30 + %30 = OpLabel + %194 = OpPhi %16 %19 %21 %57 %33 + OpLoopMerge %32 %33 None + 
OpBranch %34 + %34 = OpLabel + %36 = OpSLessThan %27 %194 %26 + OpBranchConditional %36 %31 %32 + %31 = OpLabel + %45 = OpIAdd %16 %193 %44 + %48 = OpIAdd %16 %193 %194 + %52 = OpIAdd %16 %193 %194 + %53 = OpAccessChain %17 %42 %193 %52 + %54 = OpLoad %16 %53 + %55 = OpAccessChain %17 %42 %45 %48 + OpStore %55 %54 + OpBranch %33 + %33 = OpLabel + %57 = OpIAdd %16 %194 %44 + OpStore %29 %57 + OpBranch %30 + %32 = OpLabel + OpBranch %23 + %23 = OpLabel + %59 = OpIAdd %16 %193 %44 + OpStore %18 %59 + OpBranch %20 + %22 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %60 = OpVariable %17 Function + %68 = OpVariable %41 Function + OpStore %60 %19 + OpBranch %61 + %61 = OpLabel + %196 = OpPhi %16 %19 %9 %81 %64 + OpLoopMerge %63 %64 None + OpBranch %65 + %65 = OpLabel + %67 = OpSLessThan %27 %196 %26 + OpBranchConditional %67 %62 %63 + %62 = OpLabel + %70 = OpIAdd %16 %196 %44 + %73 = OpIAdd %16 %196 %72 + %76 = OpAccessChain %17 %68 %196 %196 + %77 = OpLoad %16 %76 + %78 = OpIAdd %16 %77 %72 + %79 = OpAccessChain %17 %68 %70 %73 + OpStore %79 %78 + OpBranch %64 + %64 = OpLabel + %81 = OpIAdd %16 %196 %44 + OpStore %60 %81 + OpBranch %61 + %63 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %82 = OpVariable %17 Function + %90 = OpVariable %41 Function + OpStore %82 %19 + OpBranch %83 + %83 = OpLabel + %197 = OpPhi %16 %19 %11 %100 %86 + OpLoopMerge %85 %86 None + OpBranch %87 + %87 = OpLabel + %89 = OpSLessThan %27 %197 %26 + OpBranchConditional %89 %84 %85 + %84 = OpLabel + %94 = OpISub %16 %197 %44 + %95 = OpAccessChain %17 %90 %44 %94 + %96 = OpLoad %16 %95 + %97 = OpIAdd %16 %96 %72 + %98 = OpAccessChain %17 %90 %197 %197 + OpStore %98 %97 + OpBranch %86 + %86 = OpLabel + %100 = OpIAdd %16 %197 %44 + OpStore %82 %100 + OpBranch %83 + %85 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %101 = OpVariable %17 Function + %109 = OpVariable %17 Function + %117 = 
OpVariable %17 Function + %127 = OpVariable %126 Function + OpStore %101 %19 + OpBranch %102 + %102 = OpLabel + %198 = OpPhi %16 %19 %13 %151 %105 + OpLoopMerge %104 %105 None + OpBranch %106 + %106 = OpLabel + %108 = OpSLessThan %27 %198 %26 + OpBranchConditional %108 %103 %104 + %103 = OpLabel + OpStore %109 %19 + OpBranch %110 + %110 = OpLabel + %199 = OpPhi %16 %19 %103 %149 %113 + OpLoopMerge %112 %113 None + OpBranch %114 + %114 = OpLabel + %116 = OpSLessThan %27 %199 %26 + OpBranchConditional %116 %111 %112 + %111 = OpLabel + OpStore %117 %19 + OpBranch %118 + %118 = OpLabel + %201 = OpPhi %16 %19 %111 %147 %121 + OpLoopMerge %120 %121 None + OpBranch %122 + %122 = OpLabel + %124 = OpSLessThan %27 %201 %26 + OpBranchConditional %124 %119 %120 + %119 = OpLabel + %130 = OpISub %16 %199 %198 + %132 = OpIAdd %16 %198 %44 + %135 = OpIAdd %16 %199 %201 + %138 = OpISub %16 %199 %198 + %142 = OpIAdd %16 %199 %201 + %143 = OpAccessChain %17 %127 %138 %198 %142 + %144 = OpLoad %16 %143 + %145 = OpAccessChain %17 %127 %130 %132 %135 + OpStore %145 %144 + OpBranch %121 + %121 = OpLabel + %147 = OpIAdd %16 %201 %44 + OpStore %117 %147 + OpBranch %118 + %120 = OpLabel + OpBranch %113 + %113 = OpLabel + %149 = OpIAdd %16 %199 %44 + OpStore %109 %149 + OpBranch %110 + %112 = OpLabel + OpBranch %105 + %105 = OpLabel + %151 = OpIAdd %16 %198 %44 + OpStore %101 %151 + OpBranch %102 + %104 = OpLabel + OpReturn + OpFunctionEnd + %14 = OpFunction %2 None %3 + %15 = OpLabel + %152 = OpVariable %17 Function + %160 = OpVariable %17 Function + %168 = OpVariable %41 Function + OpStore %152 %19 + OpBranch %153 + %153 = OpLabel + %204 = OpPhi %16 %19 %15 %187 %156 + OpLoopMerge %155 %156 None + OpBranch %157 + %157 = OpLabel + %159 = OpSLessThan %27 %204 %26 + OpBranchConditional %159 %154 %155 + %154 = OpLabel + OpStore %160 %19 + OpBranch %161 + %161 = OpLabel + %205 = OpPhi %16 %19 %154 %185 %164 + OpLoopMerge %163 %164 None + OpBranch %165 + %165 = OpLabel + %167 = OpSLessThan %27 
%205 %26 + OpBranchConditional %167 %162 %163 + %162 = OpLabel + %171 = OpIMul %16 %72 %205 + %173 = OpIAdd %16 %171 %204 + %176 = OpIMul %16 %72 %205 + %178 = OpISub %16 %176 %204 + %180 = OpIAdd %16 %178 %179 + %181 = OpAccessChain %17 %168 %204 %180 + %182 = OpLoad %16 %181 + %183 = OpAccessChain %17 %168 %204 %173 + OpStore %183 %182 + OpBranch %164 + %164 = OpLabel + %185 = OpIAdd %16 %205 %44 + OpStore %160 %185 + OpBranch %161 + %163 = OpLabel + OpBranch %156 + %156 = OpLabel + %187 = OpIAdd %16 %204 %44 + OpStore %152 %187 + OpBranch %153 + %155 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(nullptr, context); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + { + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 31; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{&ld.GetLoopByIndex(0), + &ld.GetLoopByIndex(1)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + + std::vector expected_entries{ + DistanceEntry(DistanceEntry::Directions::LT, 1), + DistanceEntry(DistanceEntry::Directions::LT, 1)}; + + DistanceVector expected_distance_vector(expected_entries); + + auto is_independent = analysis.GetDependence(load, store, &dv_entry); + + EXPECT_FALSE(is_independent); + EXPECT_EQ(expected_distance_vector, dv_entry); + } + + { + const Function* 
f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 62; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{&ld.GetLoopByIndex(0)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + auto is_independent = analysis.GetDependence(load, store, &dv_entry); + + EXPECT_TRUE(is_independent); + } + + { + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 84; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{&ld.GetLoopByIndex(0)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + auto is_independent = analysis.GetDependence(load, store, &dv_entry); + + DistanceVector expected_distance_vector({DistanceEntry(1, 2)}); + + EXPECT_FALSE(is_independent); + EXPECT_EQ(expected_distance_vector, dv_entry); + } + + { + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 119; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& 
inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{ + &ld.GetLoopByIndex(0), &ld.GetLoopByIndex(1), &ld.GetLoopByIndex(2)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + + std::vector expected_entries{ + DistanceEntry(DistanceEntry::Directions::LT, 1), + DistanceEntry(DistanceEntry::Directions::LT, 1), + DistanceEntry(DistanceEntry::Directions::GT, -1)}; + + DistanceVector expected_distance_vector(expected_entries); + + auto is_independent = analysis.GetDependence(store, load, &dv_entry); + + EXPECT_FALSE(is_independent); + EXPECT_EQ(expected_distance_vector, dv_entry); + } + + { + const Function* f = spvtest::GetFunction(module, 14); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + + int block_id = 162; + ASSERT_TRUE(spvtest::GetBasicBlock(f, block_id)); + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, block_id)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(nullptr, store); + EXPECT_NE(nullptr, load); + + std::vector loop_nest{&ld.GetLoopByIndex(0), + &ld.GetLoopByIndex(1)}; + LoopDependenceAnalysis analysis{context.get(), loop_nest}; + + DistanceVector dv_entry(loop_nest.size()); + auto is_independent = analysis.GetDependence(load, store, &dv_entry); + + EXPECT_TRUE(is_independent); + } +} + +TEST(DependencyAnalysis, ConstraintIntersection) { + LoopDependenceAnalysis analysis{nullptr, std::vector{}}; + auto scalar_evolution = analysis.GetScalarEvolution(); + { + // One is none. 
Other should be returned + auto none = analysis.make_constraint(); + auto x = scalar_evolution->CreateConstant(1); + auto y = scalar_evolution->CreateConstant(10); + auto point = analysis.make_constraint(x, y, nullptr); + + auto ret_0 = analysis.IntersectConstraints(none, point, nullptr, nullptr); + + auto ret_point_0 = ret_0->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_0); + EXPECT_EQ(*x, *ret_point_0->GetSource()); + EXPECT_EQ(*y, *ret_point_0->GetDestination()); + + auto ret_1 = analysis.IntersectConstraints(point, none, nullptr, nullptr); + + auto ret_point_1 = ret_1->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_1); + EXPECT_EQ(*x, *ret_point_1->GetSource()); + EXPECT_EQ(*y, *ret_point_1->GetDestination()); + } + + { + // Both distances + auto x = scalar_evolution->CreateConstant(1); + auto y = scalar_evolution->CreateConstant(10); + + auto distance_0 = analysis.make_constraint(x, nullptr); + auto distance_1 = analysis.make_constraint(y, nullptr); + + // Equal distances + auto ret_0 = + analysis.IntersectConstraints(distance_1, distance_1, nullptr, nullptr); + + auto ret_distance = ret_0->AsDependenceDistance(); + ASSERT_NE(nullptr, ret_distance); + EXPECT_EQ(*y, *ret_distance->GetDistance()); + + // Non-equal distances + auto ret_1 = + analysis.IntersectConstraints(distance_0, distance_1, nullptr, nullptr); + EXPECT_NE(nullptr, ret_1->AsDependenceEmpty()); + } + + { + // Both points + auto x = scalar_evolution->CreateConstant(1); + auto y = scalar_evolution->CreateConstant(10); + + auto point_0 = analysis.make_constraint(x, y, nullptr); + auto point_1 = analysis.make_constraint(x, y, nullptr); + auto point_2 = analysis.make_constraint(y, y, nullptr); + + // Equal points + auto ret_0 = + analysis.IntersectConstraints(point_0, point_1, nullptr, nullptr); + auto ret_point_0 = ret_0->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_0); + EXPECT_EQ(*x, *ret_point_0->GetSource()); + EXPECT_EQ(*y, *ret_point_0->GetDestination()); + + // Non-equal 
points + auto ret_1 = + analysis.IntersectConstraints(point_0, point_2, nullptr, nullptr); + EXPECT_NE(nullptr, ret_1->AsDependenceEmpty()); + } + + { + // Both lines, parallel + auto a0 = scalar_evolution->CreateConstant(3); + auto b0 = scalar_evolution->CreateConstant(6); + auto c0 = scalar_evolution->CreateConstant(9); + + auto a1 = scalar_evolution->CreateConstant(6); + auto b1 = scalar_evolution->CreateConstant(12); + auto c1 = scalar_evolution->CreateConstant(18); + + auto line_0 = analysis.make_constraint(a0, b0, c0, nullptr); + auto line_1 = analysis.make_constraint(a1, b1, c1, nullptr); + + // Same line, both ways + auto ret_0 = + analysis.IntersectConstraints(line_0, line_1, nullptr, nullptr); + auto ret_1 = + analysis.IntersectConstraints(line_1, line_0, nullptr, nullptr); + + auto ret_line_0 = ret_0->AsDependenceLine(); + auto ret_line_1 = ret_1->AsDependenceLine(); + + EXPECT_NE(nullptr, ret_line_0); + EXPECT_NE(nullptr, ret_line_1); + + // Non-intersecting parallel lines + auto c2 = scalar_evolution->CreateConstant(12); + auto line_2 = analysis.make_constraint(a1, b1, c2, nullptr); + + auto ret_2 = + analysis.IntersectConstraints(line_0, line_2, nullptr, nullptr); + auto ret_3 = + analysis.IntersectConstraints(line_2, line_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_3->AsDependenceEmpty()); + + auto c3 = scalar_evolution->CreateConstant(20); + auto line_3 = analysis.make_constraint(a1, b1, c3, nullptr); + + auto ret_4 = + analysis.IntersectConstraints(line_0, line_3, nullptr, nullptr); + auto ret_5 = + analysis.IntersectConstraints(line_3, line_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_4->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_5->AsDependenceEmpty()); + } + + { + // Non-constant line + auto unknown = scalar_evolution->CreateCantComputeNode(); + auto constant = scalar_evolution->CreateConstant(10); + + auto line_0 = analysis.make_constraint(constant, constant, + constant, nullptr); + 
auto line_1 = analysis.make_constraint(unknown, unknown, + unknown, nullptr); + + auto ret_0 = + analysis.IntersectConstraints(line_0, line_1, nullptr, nullptr); + auto ret_1 = + analysis.IntersectConstraints(line_1, line_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_0->AsDependenceNone()); + EXPECT_NE(nullptr, ret_1->AsDependenceNone()); + } + + { + auto bound_0 = scalar_evolution->CreateConstant(0); + auto bound_1 = scalar_evolution->CreateConstant(20); + + auto a0 = scalar_evolution->CreateConstant(1); + auto b0 = scalar_evolution->CreateConstant(2); + auto c0 = scalar_evolution->CreateConstant(6); + + auto a1 = scalar_evolution->CreateConstant(-1); + auto b1 = scalar_evolution->CreateConstant(2); + auto c1 = scalar_evolution->CreateConstant(2); + + auto line_0 = analysis.make_constraint(a0, b0, c0, nullptr); + auto line_1 = analysis.make_constraint(a1, b1, c1, nullptr); + + // Intersecting lines, has integer solution, in bounds + auto ret_0 = + analysis.IntersectConstraints(line_0, line_1, bound_0, bound_1); + auto ret_1 = + analysis.IntersectConstraints(line_1, line_0, bound_0, bound_1); + + auto ret_point_0 = ret_0->AsDependencePoint(); + auto ret_point_1 = ret_1->AsDependencePoint(); + + EXPECT_NE(nullptr, ret_point_0); + EXPECT_NE(nullptr, ret_point_1); + + auto const_2 = scalar_evolution->CreateConstant(2); + + EXPECT_EQ(*const_2, *ret_point_0->GetSource()); + EXPECT_EQ(*const_2, *ret_point_0->GetDestination()); + + EXPECT_EQ(*const_2, *ret_point_1->GetSource()); + EXPECT_EQ(*const_2, *ret_point_1->GetDestination()); + + // Intersecting lines, has integer solution, out of bounds + auto ret_2 = + analysis.IntersectConstraints(line_0, line_1, bound_0, bound_0); + auto ret_3 = + analysis.IntersectConstraints(line_1, line_0, bound_0, bound_0); + + EXPECT_NE(nullptr, ret_2->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_3->AsDependenceEmpty()); + + auto a2 = scalar_evolution->CreateConstant(-4); + auto b2 = scalar_evolution->CreateConstant(1); + auto c2 = 
scalar_evolution->CreateConstant(0); + + auto a3 = scalar_evolution->CreateConstant(4); + auto b3 = scalar_evolution->CreateConstant(1); + auto c3 = scalar_evolution->CreateConstant(4); + + auto line_2 = analysis.make_constraint(a2, b2, c2, nullptr); + auto line_3 = analysis.make_constraint(a3, b3, c3, nullptr); + + // Intersecting, no integer solution + auto ret_4 = + analysis.IntersectConstraints(line_2, line_3, bound_0, bound_1); + auto ret_5 = + analysis.IntersectConstraints(line_3, line_2, bound_0, bound_1); + + EXPECT_NE(nullptr, ret_4->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_5->AsDependenceEmpty()); + + auto unknown = scalar_evolution->CreateCantComputeNode(); + + // Non-constant bound + auto ret_6 = + analysis.IntersectConstraints(line_0, line_1, unknown, bound_1); + auto ret_7 = + analysis.IntersectConstraints(line_1, line_0, bound_0, unknown); + + EXPECT_NE(nullptr, ret_6->AsDependenceNone()); + EXPECT_NE(nullptr, ret_7->AsDependenceNone()); + } + + { + auto constant_0 = scalar_evolution->CreateConstant(0); + auto constant_1 = scalar_evolution->CreateConstant(1); + auto constant_neg_1 = scalar_evolution->CreateConstant(-1); + auto constant_2 = scalar_evolution->CreateConstant(2); + auto constant_neg_2 = scalar_evolution->CreateConstant(-2); + + auto point_0_0 = analysis.make_constraint( + constant_0, constant_0, nullptr); + auto point_0_1 = analysis.make_constraint( + constant_0, constant_1, nullptr); + auto point_1_0 = analysis.make_constraint( + constant_1, constant_0, nullptr); + auto point_1_1 = analysis.make_constraint( + constant_1, constant_1, nullptr); + auto point_1_2 = analysis.make_constraint( + constant_1, constant_2, nullptr); + auto point_1_neg_1 = analysis.make_constraint( + constant_1, constant_neg_1, nullptr); + auto point_neg_1_1 = analysis.make_constraint( + constant_neg_1, constant_1, nullptr); + + auto line_y_0 = analysis.make_constraint( + constant_0, constant_1, constant_0, nullptr); + auto line_y_1 = 
analysis.make_constraint( + constant_0, constant_1, constant_1, nullptr); + auto line_y_2 = analysis.make_constraint( + constant_0, constant_1, constant_2, nullptr); + + // Parallel horizontal lines, y = 0 & y = 1, should return no intersection + auto ret = + analysis.IntersectConstraints(line_y_0, line_y_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret->AsDependenceEmpty()); + + // Parallel horizontal lines, y = 1 & y = 2, should return no intersection + auto ret_y_12 = + analysis.IntersectConstraints(line_y_1, line_y_2, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_y_12->AsDependenceEmpty()); + + // Same horizontal lines, y = 0 & y = 0, should return the line + auto ret_y_same_0 = + analysis.IntersectConstraints(line_y_0, line_y_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_y_same_0->AsDependenceLine()); + + // Same horizontal lines, y = 1 & y = 1, should return the line + auto ret_y_same_1 = + analysis.IntersectConstraints(line_y_1, line_y_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_y_same_1->AsDependenceLine()); + + auto line_x_0 = analysis.make_constraint( + constant_1, constant_0, constant_0, nullptr); + auto line_x_1 = analysis.make_constraint( + constant_1, constant_0, constant_1, nullptr); + auto line_x_2 = analysis.make_constraint( + constant_1, constant_0, constant_2, nullptr); + auto line_2x_1 = analysis.make_constraint( + constant_2, constant_0, constant_1, nullptr); + auto line_2x_2 = analysis.make_constraint( + constant_2, constant_0, constant_2, nullptr); + + // Parallel vertical lines, x = 0 & x = 1, should return no intersection + auto ret_x = + analysis.IntersectConstraints(line_x_0, line_x_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_x->AsDependenceEmpty()); + + // Parallel vertical lines, x = 1 & x = 2, should return no intersection + auto ret_x_12 = + analysis.IntersectConstraints(line_x_1, line_x_2, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_x_12->AsDependenceEmpty()); + + // Parallel vertical lines, 2x = 1 & 2x = 2, should 
return no intersection + auto ret_2x_2_2x_1 = + analysis.IntersectConstraints(line_2x_2, line_2x_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2x_2_2x_1->AsDependenceEmpty()); + + // same line, 2x=2 & x = 1 + auto ret_2x_2_x_1 = + analysis.IntersectConstraints(line_2x_2, line_x_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2x_2_x_1->AsDependenceLine()); + + // Same vertical lines, x = 0 & x = 0, should return the line + auto ret_x_same_0 = + analysis.IntersectConstraints(line_x_0, line_x_0, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_x_same_0->AsDependenceLine()); + // EXPECT_EQ(*line_x_0, *ret_x_same_0->AsDependenceLine()); + + // Same vertical lines, x = 1 & x = 1, should return the line + auto ret_x_same_1 = + analysis.IntersectConstraints(line_x_1, line_x_1, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_x_same_1->AsDependenceLine()); + EXPECT_EQ(*line_x_1, *ret_x_same_1->AsDependenceLine()); + + // x=1 & y = 0, intersect at (1, 0) + auto ret_1_0 = analysis.IntersectConstraints(line_x_1, line_y_0, + constant_neg_1, constant_2); + + auto ret_point_1_0 = ret_1_0->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_1_0); + EXPECT_EQ(*point_1_0, *ret_point_1_0); + + // x=1 & y = 1, intersect at (1, 1) + auto ret_1_1 = analysis.IntersectConstraints(line_x_1, line_y_1, + constant_neg_1, constant_2); + + auto ret_point_1_1 = ret_1_1->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_1_1); + EXPECT_EQ(*point_1_1, *ret_point_1_1); + + // x=0 & y = 0, intersect at (0, 0) + auto ret_0_0 = analysis.IntersectConstraints(line_x_0, line_y_0, + constant_neg_1, constant_2); + + auto ret_point_0_0 = ret_0_0->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_0_0); + EXPECT_EQ(*point_0_0, *ret_point_0_0); + + // x=0 & y = 1, intersect at (0, 1) + auto ret_0_1 = analysis.IntersectConstraints(line_x_0, line_y_1, + constant_neg_1, constant_2); + auto ret_point_0_1 = ret_0_1->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_0_1); + EXPECT_EQ(*point_0_1, *ret_point_0_1); 
+ + // x = 1 & y = 2 + auto ret_1_2 = analysis.IntersectConstraints(line_x_1, line_y_2, + constant_neg_1, constant_2); + auto ret_point_1_2 = ret_1_2->AsDependencePoint(); + EXPECT_NE(nullptr, ret_point_1_2); + EXPECT_EQ(*point_1_2, *ret_point_1_2); + + auto line_x_y_0 = analysis.make_constraint( + constant_1, constant_1, constant_0, nullptr); + auto line_x_y_1 = analysis.make_constraint( + constant_1, constant_1, constant_1, nullptr); + + // x+y=0 & x=0, intersect (0, 0) + auto ret_xy_0_x_0 = analysis.IntersectConstraints( + line_x_y_0, line_x_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_0_x_0->AsDependencePoint()); + EXPECT_EQ(*point_0_0, *ret_xy_0_x_0); + + // x+y=0 & y=0, intersect (0, 0) + auto ret_xy_0_y_0 = analysis.IntersectConstraints( + line_x_y_0, line_y_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_0_y_0->AsDependencePoint()); + EXPECT_EQ(*point_0_0, *ret_xy_0_y_0); + + // x+y=0 & x=1, intersect (1, -1) + auto ret_xy_0_x_1 = analysis.IntersectConstraints( + line_x_y_0, line_x_1, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_xy_0_x_1->AsDependencePoint()); + EXPECT_EQ(*point_1_neg_1, *ret_xy_0_x_1); + + // x+y=0 & y=1, intersect (-1, 1) + auto ret_xy_0_y_1 = analysis.IntersectConstraints( + line_x_y_0, line_y_1, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_xy_0_y_1->AsDependencePoint()); + EXPECT_EQ(*point_neg_1_1, *ret_xy_0_y_1); + + // x=0 & x+y=0, intersect (0, 0) + auto ret_x_0_xy_0 = analysis.IntersectConstraints( + line_x_0, line_x_y_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_x_0_xy_0->AsDependencePoint()); + EXPECT_EQ(*point_0_0, *ret_x_0_xy_0); + + // y=0 & x+y=0, intersect (0, 0) + auto ret_y_0_xy_0 = analysis.IntersectConstraints( + line_y_0, line_x_y_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_y_0_xy_0->AsDependencePoint()); + EXPECT_EQ(*point_0_0, *ret_y_0_xy_0); + + // x=1 & x+y=0, intersect (1, -1) + auto ret_x_1_xy_0 = analysis.IntersectConstraints( + 
line_x_1, line_x_y_0, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_x_1_xy_0->AsDependencePoint()); + EXPECT_EQ(*point_1_neg_1, *ret_x_1_xy_0); + + // y=1 & x+y=0, intersect (-1, 1) + auto ret_y_1_xy_0 = analysis.IntersectConstraints( + line_y_1, line_x_y_0, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_y_1_xy_0->AsDependencePoint()); + EXPECT_EQ(*point_neg_1_1, *ret_y_1_xy_0); + + // x+y=1 & x=0, intersect (0, 1) + auto ret_xy_1_x_0 = analysis.IntersectConstraints( + line_x_y_1, line_x_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_1_x_0->AsDependencePoint()); + EXPECT_EQ(*point_0_1, *ret_xy_1_x_0); + + // x+y=1 & y=0, intersect (1, 0) + auto ret_xy_1_y_0 = analysis.IntersectConstraints( + line_x_y_1, line_y_0, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_1_y_0->AsDependencePoint()); + EXPECT_EQ(*point_1_0, *ret_xy_1_y_0); + + // x+y=1 & x=1, intersect (1, 0) + auto ret_xy_1_x_1 = analysis.IntersectConstraints( + line_x_y_1, line_x_1, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_1_x_1->AsDependencePoint()); + EXPECT_EQ(*point_1_0, *ret_xy_1_x_1); + + // x+y=1 & y=1, intersect (0, 1) + auto ret_xy_1_y_1 = analysis.IntersectConstraints( + line_x_y_1, line_y_1, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_xy_1_y_1->AsDependencePoint()); + EXPECT_EQ(*point_0_1, *ret_xy_1_y_1); + + // x=0 & x+y=1, intersect (0, 1) + auto ret_x_0_xy_1 = analysis.IntersectConstraints( + line_x_0, line_x_y_1, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_x_0_xy_1->AsDependencePoint()); + EXPECT_EQ(*point_0_1, *ret_x_0_xy_1); + + // y=0 & x+y=1, intersect (1, 0) + auto ret_y_0_xy_1 = analysis.IntersectConstraints( + line_y_0, line_x_y_1, constant_neg_1, constant_2); + + EXPECT_NE(nullptr, ret_y_0_xy_1->AsDependencePoint()); + EXPECT_EQ(*point_1_0, *ret_y_0_xy_1); + + // x=1 & x+y=1, intersect (1, 0) + auto ret_x_1_xy_1 = analysis.IntersectConstraints( + line_x_1, line_x_y_1, constant_neg_2, constant_2); + + 
EXPECT_NE(nullptr, ret_x_1_xy_1->AsDependencePoint()); + EXPECT_EQ(*point_1_0, *ret_x_1_xy_1); + + // y=1 & x+y=1, intersect (0, 1) + auto ret_y_1_xy_1 = analysis.IntersectConstraints( + line_y_1, line_x_y_1, constant_neg_2, constant_2); + + EXPECT_NE(nullptr, ret_y_1_xy_1->AsDependencePoint()); + EXPECT_EQ(*point_0_1, *ret_y_1_xy_1); + } + + { + // Line and point + auto a = scalar_evolution->CreateConstant(3); + auto b = scalar_evolution->CreateConstant(10); + auto c = scalar_evolution->CreateConstant(16); + + auto line = analysis.make_constraint(a, b, c, nullptr); + + // Point on line + auto x = scalar_evolution->CreateConstant(2); + auto y = scalar_evolution->CreateConstant(1); + auto point_0 = analysis.make_constraint(x, y, nullptr); + + auto ret_0 = analysis.IntersectConstraints(line, point_0, nullptr, nullptr); + auto ret_1 = analysis.IntersectConstraints(point_0, line, nullptr, nullptr); + + auto ret_point_0 = ret_0->AsDependencePoint(); + auto ret_point_1 = ret_1->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_0); + ASSERT_NE(nullptr, ret_point_1); + + EXPECT_EQ(*x, *ret_point_0->GetSource()); + EXPECT_EQ(*y, *ret_point_0->GetDestination()); + + EXPECT_EQ(*x, *ret_point_1->GetSource()); + EXPECT_EQ(*y, *ret_point_1->GetDestination()); + + // Point not on line + auto point_1 = analysis.make_constraint(a, a, nullptr); + + auto ret_2 = analysis.IntersectConstraints(line, point_1, nullptr, nullptr); + auto ret_3 = analysis.IntersectConstraints(point_1, line, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_3->AsDependenceEmpty()); + + // Non-constant + auto unknown = scalar_evolution->CreateCantComputeNode(); + + auto point_2 = + analysis.make_constraint(unknown, x, nullptr); + + auto ret_4 = analysis.IntersectConstraints(line, point_2, nullptr, nullptr); + auto ret_5 = analysis.IntersectConstraints(point_2, line, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_4->AsDependenceNone()); + EXPECT_NE(nullptr, 
ret_5->AsDependenceNone()); + } + + { + // Distance and point + auto d = scalar_evolution->CreateConstant(5); + auto distance = analysis.make_constraint(d, nullptr); + + // Point on line + auto x = scalar_evolution->CreateConstant(10); + auto point_0 = analysis.make_constraint(d, x, nullptr); + + auto ret_0 = + analysis.IntersectConstraints(distance, point_0, nullptr, nullptr); + auto ret_1 = + analysis.IntersectConstraints(point_0, distance, nullptr, nullptr); + + auto ret_point_0 = ret_0->AsDependencePoint(); + auto ret_point_1 = ret_1->AsDependencePoint(); + ASSERT_NE(nullptr, ret_point_0); + ASSERT_NE(nullptr, ret_point_1); + + // Point not on line + auto point_1 = analysis.make_constraint(x, x, nullptr); + + auto ret_2 = + analysis.IntersectConstraints(distance, point_1, nullptr, nullptr); + auto ret_3 = + analysis.IntersectConstraints(point_1, distance, nullptr, nullptr); + + EXPECT_NE(nullptr, ret_2->AsDependenceEmpty()); + EXPECT_NE(nullptr, ret_3->AsDependenceEmpty()); + + // Non-constant + auto unknown = scalar_evolution->CreateCantComputeNode(); + auto unknown_distance = + analysis.make_constraint(unknown, nullptr); + + auto ret_4 = analysis.IntersectConstraints(unknown_distance, point_1, + nullptr, nullptr); + auto ret_5 = analysis.IntersectConstraints(point_1, unknown_distance, + nullptr, nullptr); + + EXPECT_NE(nullptr, ret_4->AsDependenceNone()); + EXPECT_NE(nullptr, ret_5->AsDependenceNone()); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/dependence_analysis_helpers.cpp b/test/opt/loop_optimizations/dependence_analysis_helpers.cpp new file mode 100644 index 000000000..715cf541d --- /dev/null +++ b/test/opt/loop_optimizations/dependence_analysis_helpers.cpp @@ -0,0 +1,3017 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using DependencyAnalysisHelpers = ::testing::Test; + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a() { + int[10][10] arr; + int i = 0; + int j = 0; + for (; i < 10 && j < 10; i++, j++) { + arr[i][j] = arr[i][j]; + } +} +void b() { + int[10] arr; + for (int i = 0; i < 10; i+=2) { + arr[i] = arr[i]; + } +} +void main(){ + a(); + b(); +} +*/ +TEST(DependencyAnalysisHelpers, UnsupportedLoops) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %12 "i" + OpName %14 "j" + OpName %32 "arr" + OpName %45 "i" + OpName %54 "arr" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %13 = OpConstant %10 0 + %21 = OpConstant %10 10 + %22 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %10 %28 + 
%30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %41 = OpConstant %10 1 + %53 = OpTypePointer Function %29 + %60 = OpConstant %10 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %63 = OpFunctionCall %2 %6 + %64 = OpFunctionCall %2 %8 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %12 = OpVariable %11 Function + %14 = OpVariable %11 Function + %32 = OpVariable %31 Function + OpStore %12 %13 + OpStore %14 %13 + OpBranch %15 + %15 = OpLabel + %65 = OpPhi %10 %13 %7 %42 %18 + %66 = OpPhi %10 %13 %7 %44 %18 + OpLoopMerge %17 %18 None + OpBranch %19 + %19 = OpLabel + %23 = OpSLessThan %22 %65 %21 + %25 = OpSLessThan %22 %66 %21 + %26 = OpLogicalAnd %22 %23 %25 + OpBranchConditional %26 %16 %17 + %16 = OpLabel + %37 = OpAccessChain %11 %32 %65 %66 + %38 = OpLoad %10 %37 + %39 = OpAccessChain %11 %32 %65 %66 + OpStore %39 %38 + OpBranch %18 + %18 = OpLabel + %42 = OpIAdd %10 %65 %41 + OpStore %12 %42 + %44 = OpIAdd %10 %66 %41 + OpStore %14 %44 + OpBranch %15 + %17 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %45 = OpVariable %11 Function + %54 = OpVariable %53 Function + OpStore %45 %13 + OpBranch %46 + %46 = OpLabel + %67 = OpPhi %10 %13 %9 %62 %49 + OpLoopMerge %48 %49 None + OpBranch %50 + %50 = OpLabel + %52 = OpSLessThan %22 %67 %21 + OpBranchConditional %52 %47 %48 + %47 = OpLabel + %57 = OpAccessChain %11 %54 %67 + %58 = OpLoad %10 %57 + %59 = OpAccessChain %11 %54 %67 + OpStore %59 %58 + OpBranch %49 + %49 = OpLabel + %62 = OpIAdd %10 %67 %60 + OpStore %45 %62 + OpBranch %46 + %48 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + { + // Function a + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = 
*context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[1] = {nullptr}; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 16)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + // 38 -> 39 + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.IsSupportedLoop(loops[0])); + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(38), + store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::UNKNOWN); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::ALL); + } + { + // Function b + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* store[1] = {nullptr}; + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 47)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store[stores_found] = &inst; + ++stores_found; + } + } + // 58 -> 59 + DistanceVector distance_vector{loops.size()}; + EXPECT_FALSE(analysis.IsSupportedLoop(loops[0])); + EXPECT_FALSE(analysis.GetDependence(context->get_def_use_mgr()->GetDef(58), + store[0], &distance_vector)); + EXPECT_EQ(distance_vector.GetEntries()[0].dependence_information, + DistanceEntry::DependenceInformation::UNKNOWN); + EXPECT_EQ(distance_vector.GetEntries()[0].direction, + DistanceEntry::Directions::ALL); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void a() { + for (int i = -10; i < 0; i++) { + + } +} +void b() { + for (int i = -5; i < 5; i++) { + + } +} +void c() { + for 
(int i = 0; i < 10; i++) { + + } +} +void d() { + for (int i = 5; i < 15; i++) { + + } +} +void e() { + for (int i = -10; i <= 0; i++) { + + } +} +void f() { + for (int i = -5; i <= 5; i++) { + + } +} +void g() { + for (int i = 0; i <= 10; i++) { + + } +} +void h() { + for (int i = 5; i <= 15; i++) { + + } +} +void i() { + for (int i = 0; i > -10; i--) { + + } +} +void j() { + for (int i = 5; i > -5; i--) { + + } +} +void k() { + for (int i = 10; i > 0; i--) { + + } +} +void l() { + for (int i = 15; i > 5; i--) { + + } +} +void m() { + for (int i = 0; i >= -10; i--) { + + } +} +void n() { + for (int i = 5; i >= -5; i--) { + + } +} +void o() { + for (int i = 10; i >= 0; i--) { + + } +} +void p() { + for (int i = 15; i >= 5; i--) { + + } +} +void main(){ + a(); + b(); + c(); + d(); + e(); + f(); + g(); + h(); + i(); + j(); + k(); + l(); + m(); + n(); + o(); + p(); +} +*/ +TEST(DependencyAnalysisHelpers, loop_information) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %14 "e(" + OpName %16 "f(" + OpName %18 "g(" + OpName %20 "h(" + OpName %22 "i(" + OpName %24 "j(" + OpName %26 "k(" + OpName %28 "l(" + OpName %30 "m(" + OpName %32 "n(" + OpName %34 "o(" + OpName %36 "p(" + OpName %40 "i" + OpName %54 "i" + OpName %66 "i" + OpName %77 "i" + OpName %88 "i" + OpName %98 "i" + OpName %108 "i" + OpName %118 "i" + OpName %128 "i" + OpName %138 "i" + OpName %148 "i" + OpName %158 "i" + OpName %168 "i" + OpName %178 "i" + OpName %188 "i" + OpName %198 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %38 = OpTypeInt 32 1 + %39 = OpTypePointer Function %38 + %41 = OpConstant %38 -10 + %48 = OpConstant %38 0 + %49 = OpTypeBool + %52 = OpConstant %38 1 + %55 = OpConstant %38 -5 + %62 = OpConstant %38 5 + %73 = 
OpConstant %38 10 + %84 = OpConstant %38 15 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %208 = OpFunctionCall %2 %6 + %209 = OpFunctionCall %2 %8 + %210 = OpFunctionCall %2 %10 + %211 = OpFunctionCall %2 %12 + %212 = OpFunctionCall %2 %14 + %213 = OpFunctionCall %2 %16 + %214 = OpFunctionCall %2 %18 + %215 = OpFunctionCall %2 %20 + %216 = OpFunctionCall %2 %22 + %217 = OpFunctionCall %2 %24 + %218 = OpFunctionCall %2 %26 + %219 = OpFunctionCall %2 %28 + %220 = OpFunctionCall %2 %30 + %221 = OpFunctionCall %2 %32 + %222 = OpFunctionCall %2 %34 + %223 = OpFunctionCall %2 %36 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %40 = OpVariable %39 Function + OpStore %40 %41 + OpBranch %42 + %42 = OpLabel + %224 = OpPhi %38 %41 %7 %53 %45 + OpLoopMerge %44 %45 None + OpBranch %46 + %46 = OpLabel + %50 = OpSLessThan %49 %224 %48 + OpBranchConditional %50 %43 %44 + %43 = OpLabel + OpBranch %45 + %45 = OpLabel + %53 = OpIAdd %38 %224 %52 + OpStore %40 %53 + OpBranch %42 + %44 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %54 = OpVariable %39 Function + OpStore %54 %55 + OpBranch %56 + %56 = OpLabel + %225 = OpPhi %38 %55 %9 %65 %59 + OpLoopMerge %58 %59 None + OpBranch %60 + %60 = OpLabel + %63 = OpSLessThan %49 %225 %62 + OpBranchConditional %63 %57 %58 + %57 = OpLabel + OpBranch %59 + %59 = OpLabel + %65 = OpIAdd %38 %225 %52 + OpStore %54 %65 + OpBranch %56 + %58 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %66 = OpVariable %39 Function + OpStore %66 %48 + OpBranch %67 + %67 = OpLabel + %226 = OpPhi %38 %48 %11 %76 %70 + OpLoopMerge %69 %70 None + OpBranch %71 + %71 = OpLabel + %74 = OpSLessThan %49 %226 %73 + OpBranchConditional %74 %68 %69 + %68 = OpLabel + OpBranch %70 + %70 = OpLabel + %76 = OpIAdd %38 %226 %52 + OpStore %66 %76 + OpBranch %67 + %69 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %77 = OpVariable %39 Function + 
OpStore %77 %62 + OpBranch %78 + %78 = OpLabel + %227 = OpPhi %38 %62 %13 %87 %81 + OpLoopMerge %80 %81 None + OpBranch %82 + %82 = OpLabel + %85 = OpSLessThan %49 %227 %84 + OpBranchConditional %85 %79 %80 + %79 = OpLabel + OpBranch %81 + %81 = OpLabel + %87 = OpIAdd %38 %227 %52 + OpStore %77 %87 + OpBranch %78 + %80 = OpLabel + OpReturn + OpFunctionEnd + %14 = OpFunction %2 None %3 + %15 = OpLabel + %88 = OpVariable %39 Function + OpStore %88 %41 + OpBranch %89 + %89 = OpLabel + %228 = OpPhi %38 %41 %15 %97 %92 + OpLoopMerge %91 %92 None + OpBranch %93 + %93 = OpLabel + %95 = OpSLessThanEqual %49 %228 %48 + OpBranchConditional %95 %90 %91 + %90 = OpLabel + OpBranch %92 + %92 = OpLabel + %97 = OpIAdd %38 %228 %52 + OpStore %88 %97 + OpBranch %89 + %91 = OpLabel + OpReturn + OpFunctionEnd + %16 = OpFunction %2 None %3 + %17 = OpLabel + %98 = OpVariable %39 Function + OpStore %98 %55 + OpBranch %99 + %99 = OpLabel + %229 = OpPhi %38 %55 %17 %107 %102 + OpLoopMerge %101 %102 None + OpBranch %103 + %103 = OpLabel + %105 = OpSLessThanEqual %49 %229 %62 + OpBranchConditional %105 %100 %101 + %100 = OpLabel + OpBranch %102 + %102 = OpLabel + %107 = OpIAdd %38 %229 %52 + OpStore %98 %107 + OpBranch %99 + %101 = OpLabel + OpReturn + OpFunctionEnd + %18 = OpFunction %2 None %3 + %19 = OpLabel + %108 = OpVariable %39 Function + OpStore %108 %48 + OpBranch %109 + %109 = OpLabel + %230 = OpPhi %38 %48 %19 %117 %112 + OpLoopMerge %111 %112 None + OpBranch %113 + %113 = OpLabel + %115 = OpSLessThanEqual %49 %230 %73 + OpBranchConditional %115 %110 %111 + %110 = OpLabel + OpBranch %112 + %112 = OpLabel + %117 = OpIAdd %38 %230 %52 + OpStore %108 %117 + OpBranch %109 + %111 = OpLabel + OpReturn + OpFunctionEnd + %20 = OpFunction %2 None %3 + %21 = OpLabel + %118 = OpVariable %39 Function + OpStore %118 %62 + OpBranch %119 + %119 = OpLabel + %231 = OpPhi %38 %62 %21 %127 %122 + OpLoopMerge %121 %122 None + OpBranch %123 + %123 = OpLabel + %125 = OpSLessThanEqual %49 %231 %84 + 
OpBranchConditional %125 %120 %121 + %120 = OpLabel + OpBranch %122 + %122 = OpLabel + %127 = OpIAdd %38 %231 %52 + OpStore %118 %127 + OpBranch %119 + %121 = OpLabel + OpReturn + OpFunctionEnd + %22 = OpFunction %2 None %3 + %23 = OpLabel + %128 = OpVariable %39 Function + OpStore %128 %48 + OpBranch %129 + %129 = OpLabel + %232 = OpPhi %38 %48 %23 %137 %132 + OpLoopMerge %131 %132 None + OpBranch %133 + %133 = OpLabel + %135 = OpSGreaterThan %49 %232 %41 + OpBranchConditional %135 %130 %131 + %130 = OpLabel + OpBranch %132 + %132 = OpLabel + %137 = OpISub %38 %232 %52 + OpStore %128 %137 + OpBranch %129 + %131 = OpLabel + OpReturn + OpFunctionEnd + %24 = OpFunction %2 None %3 + %25 = OpLabel + %138 = OpVariable %39 Function + OpStore %138 %62 + OpBranch %139 + %139 = OpLabel + %233 = OpPhi %38 %62 %25 %147 %142 + OpLoopMerge %141 %142 None + OpBranch %143 + %143 = OpLabel + %145 = OpSGreaterThan %49 %233 %55 + OpBranchConditional %145 %140 %141 + %140 = OpLabel + OpBranch %142 + %142 = OpLabel + %147 = OpISub %38 %233 %52 + OpStore %138 %147 + OpBranch %139 + %141 = OpLabel + OpReturn + OpFunctionEnd + %26 = OpFunction %2 None %3 + %27 = OpLabel + %148 = OpVariable %39 Function + OpStore %148 %73 + OpBranch %149 + %149 = OpLabel + %234 = OpPhi %38 %73 %27 %157 %152 + OpLoopMerge %151 %152 None + OpBranch %153 + %153 = OpLabel + %155 = OpSGreaterThan %49 %234 %48 + OpBranchConditional %155 %150 %151 + %150 = OpLabel + OpBranch %152 + %152 = OpLabel + %157 = OpISub %38 %234 %52 + OpStore %148 %157 + OpBranch %149 + %151 = OpLabel + OpReturn + OpFunctionEnd + %28 = OpFunction %2 None %3 + %29 = OpLabel + %158 = OpVariable %39 Function + OpStore %158 %84 + OpBranch %159 + %159 = OpLabel + %235 = OpPhi %38 %84 %29 %167 %162 + OpLoopMerge %161 %162 None + OpBranch %163 + %163 = OpLabel + %165 = OpSGreaterThan %49 %235 %62 + OpBranchConditional %165 %160 %161 + %160 = OpLabel + OpBranch %162 + %162 = OpLabel + %167 = OpISub %38 %235 %52 + OpStore %158 %167 + OpBranch 
%159 + %161 = OpLabel + OpReturn + OpFunctionEnd + %30 = OpFunction %2 None %3 + %31 = OpLabel + %168 = OpVariable %39 Function + OpStore %168 %48 + OpBranch %169 + %169 = OpLabel + %236 = OpPhi %38 %48 %31 %177 %172 + OpLoopMerge %171 %172 None + OpBranch %173 + %173 = OpLabel + %175 = OpSGreaterThanEqual %49 %236 %41 + OpBranchConditional %175 %170 %171 + %170 = OpLabel + OpBranch %172 + %172 = OpLabel + %177 = OpISub %38 %236 %52 + OpStore %168 %177 + OpBranch %169 + %171 = OpLabel + OpReturn + OpFunctionEnd + %32 = OpFunction %2 None %3 + %33 = OpLabel + %178 = OpVariable %39 Function + OpStore %178 %62 + OpBranch %179 + %179 = OpLabel + %237 = OpPhi %38 %62 %33 %187 %182 + OpLoopMerge %181 %182 None + OpBranch %183 + %183 = OpLabel + %185 = OpSGreaterThanEqual %49 %237 %55 + OpBranchConditional %185 %180 %181 + %180 = OpLabel + OpBranch %182 + %182 = OpLabel + %187 = OpISub %38 %237 %52 + OpStore %178 %187 + OpBranch %179 + %181 = OpLabel + OpReturn + OpFunctionEnd + %34 = OpFunction %2 None %3 + %35 = OpLabel + %188 = OpVariable %39 Function + OpStore %188 %73 + OpBranch %189 + %189 = OpLabel + %238 = OpPhi %38 %73 %35 %197 %192 + OpLoopMerge %191 %192 None + OpBranch %193 + %193 = OpLabel + %195 = OpSGreaterThanEqual %49 %238 %48 + OpBranchConditional %195 %190 %191 + %190 = OpLabel + OpBranch %192 + %192 = OpLabel + %197 = OpISub %38 %238 %52 + OpStore %188 %197 + OpBranch %189 + %191 = OpLabel + OpReturn + OpFunctionEnd + %36 = OpFunction %2 None %3 + %37 = OpLabel + %198 = OpVariable %39 Function + OpStore %198 %84 + OpBranch %199 + %199 = OpLabel + %239 = OpPhi %38 %84 %37 %207 %202 + OpLoopMerge %201 %202 None + OpBranch %203 + %203 = OpLabel + %205 = OpSGreaterThanEqual %49 %239 %62 + OpBranchConditional %205 %200 %201 + %200 = OpLabel + OpBranch %202 + %202 = OpLabel + %207 = OpISub %38 %239 %52 + OpStore %198 %207 + OpBranch %199 + %201 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, 
text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + { + // Function a + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -10); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -1); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(-10)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(-1)); + } + { + // Function b + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 4); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(-5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(4)); + } + { + // Function c + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector 
loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 9); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(0)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(9)); + } + { + // Function d + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 14); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(14)); + } + { + // Function e + const Function* f = spvtest::GetFunction(module, 14); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -10); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + 
EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(-10)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(0)); + } + { + // Function f + const Function* f = spvtest::GetFunction(module, 16); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(-5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(5)); + } + { + // Function g + const Function* f = spvtest::GetFunction(module, 18); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(0)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(10)); + } + { + // Function h + const Function* f = spvtest::GetFunction(module, 20); + LoopDescriptor& ld = 
*context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 15); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(1)), + analysis.GetScalarEvolution()->CreateConstant(15)); + } + { + // Function i + const Function* f = spvtest::GetFunction(module, 22); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -9); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(0)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(-9)); + } + { + // Function j + const Function* f = spvtest::GetFunction(module, 24); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -4); + + EXPECT_EQ( + 
analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(-4)); + } + { + // Function k + const Function* f = spvtest::GetFunction(module, 26); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 1); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(10)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(1)); + } + { + // Function l + const Function* f = spvtest::GetFunction(module, 28); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 15); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 6); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(15)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(6)); + } + { + // Function m + 
const Function* f = spvtest::GetFunction(module, 30); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -10); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(0)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(-10)); + } + { + // Function n + const Function* f = spvtest::GetFunction(module, 32); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + -5); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(5)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(-5)); + } + { + // Function o + const Function* f = spvtest::GetFunction(module, 34); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 10); + EXPECT_EQ( + 
analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 0); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(10)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(0)); + } + { + // Function p + const Function* f = spvtest::GetFunction(module, 36); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_EQ( + analysis.GetLowerBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 15); + EXPECT_EQ( + analysis.GetUpperBound(loop)->AsSEConstantNode()->FoldToSingleValue(), + 5); + + EXPECT_EQ( + analysis.GetTripCount(loop)->AsSEConstantNode()->FoldToSingleValue(), + 11); + + EXPECT_EQ(analysis.GetFirstTripInductionNode(loop), + analysis.GetScalarEvolution()->CreateConstant(15)); + + EXPECT_EQ(analysis.GetFinalTripInductionNode( + loop, analysis.GetScalarEvolution()->CreateConstant(-1)), + analysis.GetScalarEvolution()->CreateConstant(5)); + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +void main(){ + for (int i = 0; i < 10; i++) { + + } +} +*/ +TEST(DependencyAnalysisHelpers, bounds_checks) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + 
OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %22 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %22 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %22 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + // We need a shader that includes a loop for this test so we can build a + // LoopDependenceAnalysis. + const Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + EXPECT_TRUE(analysis.IsWithinBounds(0, 0, 0)); + EXPECT_TRUE(analysis.IsWithinBounds(0, -1, 0)); + EXPECT_TRUE(analysis.IsWithinBounds(0, 0, 1)); + EXPECT_TRUE(analysis.IsWithinBounds(0, -1, 1)); + EXPECT_TRUE(analysis.IsWithinBounds(-2, -2, -2)); + EXPECT_TRUE(analysis.IsWithinBounds(-2, -3, 0)); + EXPECT_TRUE(analysis.IsWithinBounds(-2, 0, -3)); + EXPECT_TRUE(analysis.IsWithinBounds(2, 2, 2)); + EXPECT_TRUE(analysis.IsWithinBounds(2, 3, 0)); + + EXPECT_FALSE(analysis.IsWithinBounds(2, 3, 3)); + EXPECT_FALSE(analysis.IsWithinBounds(0, 1, 5)); + EXPECT_FALSE(analysis.IsWithinBounds(0, -1, -4)); + EXPECT_FALSE(analysis.IsWithinBounds(-2, -4, -3)); +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 in_vec; +// Loop iterates from constant to symbolic +void a() { + int N = int(in_vec.x); + int arr[10]; + for (int i = 0; i < N; i++) { // Bounds are N - 0 - 1 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} 
+void b() { + int N = int(in_vec.x); + int arr[10]; + for (int i = 0; i <= N; i++) { // Bounds are N - 0 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void c() { + int N = int(in_vec.x); + int arr[10]; + for (int i = 9; i > N; i--) { // Bounds are 9 - N - 1 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void d() { + int N = int(in_vec.x); + int arr[10]; + for (int i = 9; i >= N; i--) { // Bounds are 9 - N + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void main(){ + a(); + b(); + c(); + d(); +} +*/ +TEST(DependencyAnalysisHelpers, const_to_symbolic) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %20 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %6 "a(" + OpName %8 "b(" + OpName %10 "c(" + OpName %12 "d(" + OpName %16 "N" + OpName %20 "in_vec" + OpName %27 "i" + OpName %41 "arr" + OpName %59 "N" + OpName %63 "i" + OpName %72 "arr" + OpName %89 "N" + OpName %93 "i" + OpName %103 "arr" + OpName %120 "N" + OpName %124 "i" + OpName %133 "arr" + OpDecorate %20 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %14 = OpTypeInt 32 1 + %15 = OpTypePointer Function %14 + %17 = OpTypeFloat 32 + %18 = OpTypeVector %17 4 + %19 = OpTypePointer Input %18 + %20 = OpVariable %19 Input + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 0 + %23 = OpTypePointer Input %17 + %28 = OpConstant %14 0 + %36 = OpTypeBool + %38 = OpConstant %21 10 + %39 = OpTypeArray %14 %38 + %40 = OpTypePointer Function %39 + %57 = OpConstant %14 1 + %94 = OpConstant %14 9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %150 = OpFunctionCall %2 %6 + %151 = OpFunctionCall %2 %8 + %152 = OpFunctionCall %2 %10 + %153 = OpFunctionCall %2 %12 + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %16 = OpVariable %15 Function + %27 = 
OpVariable %15 Function + %41 = OpVariable %40 Function + %24 = OpAccessChain %23 %20 %22 + %25 = OpLoad %17 %24 + %26 = OpConvertFToS %14 %25 + OpStore %16 %26 + OpStore %27 %28 + OpBranch %29 + %29 = OpLabel + %154 = OpPhi %14 %28 %7 %58 %32 + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %37 = OpSLessThan %36 %154 %26 + OpBranchConditional %37 %30 %31 + %30 = OpLabel + %45 = OpIAdd %14 %154 %26 + %46 = OpAccessChain %15 %41 %45 + %47 = OpLoad %14 %46 + %48 = OpAccessChain %15 %41 %154 + OpStore %48 %47 + %51 = OpIAdd %14 %154 %26 + %53 = OpAccessChain %15 %41 %154 + %54 = OpLoad %14 %53 + %55 = OpAccessChain %15 %41 %51 + OpStore %55 %54 + OpBranch %32 + %32 = OpLabel + %58 = OpIAdd %14 %154 %57 + OpStore %27 %58 + OpBranch %29 + %31 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %59 = OpVariable %15 Function + %63 = OpVariable %15 Function + %72 = OpVariable %40 Function + %60 = OpAccessChain %23 %20 %22 + %61 = OpLoad %17 %60 + %62 = OpConvertFToS %14 %61 + OpStore %59 %62 + OpStore %63 %28 + OpBranch %64 + %64 = OpLabel + %155 = OpPhi %14 %28 %9 %88 %67 + OpLoopMerge %66 %67 None + OpBranch %68 + %68 = OpLabel + %71 = OpSLessThanEqual %36 %155 %62 + OpBranchConditional %71 %65 %66 + %65 = OpLabel + %76 = OpIAdd %14 %155 %62 + %77 = OpAccessChain %15 %72 %76 + %78 = OpLoad %14 %77 + %79 = OpAccessChain %15 %72 %155 + OpStore %79 %78 + %82 = OpIAdd %14 %155 %62 + %84 = OpAccessChain %15 %72 %155 + %85 = OpLoad %14 %84 + %86 = OpAccessChain %15 %72 %82 + OpStore %86 %85 + OpBranch %67 + %67 = OpLabel + %88 = OpIAdd %14 %155 %57 + OpStore %63 %88 + OpBranch %64 + %66 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %3 + %11 = OpLabel + %89 = OpVariable %15 Function + %93 = OpVariable %15 Function + %103 = OpVariable %40 Function + %90 = OpAccessChain %23 %20 %22 + %91 = OpLoad %17 %90 + %92 = OpConvertFToS %14 %91 + OpStore %89 %92 + OpStore %93 %94 + OpBranch %95 + %95 = OpLabel + %156 = OpPhi %14 
%94 %11 %119 %98 + OpLoopMerge %97 %98 None + OpBranch %99 + %99 = OpLabel + %102 = OpSGreaterThan %36 %156 %92 + OpBranchConditional %102 %96 %97 + %96 = OpLabel + %107 = OpIAdd %14 %156 %92 + %108 = OpAccessChain %15 %103 %107 + %109 = OpLoad %14 %108 + %110 = OpAccessChain %15 %103 %156 + OpStore %110 %109 + %113 = OpIAdd %14 %156 %92 + %115 = OpAccessChain %15 %103 %156 + %116 = OpLoad %14 %115 + %117 = OpAccessChain %15 %103 %113 + OpStore %117 %116 + OpBranch %98 + %98 = OpLabel + %119 = OpISub %14 %156 %57 + OpStore %93 %119 + OpBranch %95 + %97 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %3 + %13 = OpLabel + %120 = OpVariable %15 Function + %124 = OpVariable %15 Function + %133 = OpVariable %40 Function + %121 = OpAccessChain %23 %20 %22 + %122 = OpLoad %17 %121 + %123 = OpConvertFToS %14 %122 + OpStore %120 %123 + OpStore %124 %94 + OpBranch %125 + %125 = OpLabel + %157 = OpPhi %14 %94 %13 %149 %128 + OpLoopMerge %127 %128 None + OpBranch %129 + %129 = OpLabel + %132 = OpSGreaterThanEqual %36 %157 %123 + OpBranchConditional %132 %126 %127 + %126 = OpLabel + %137 = OpIAdd %14 %157 %123 + %138 = OpAccessChain %15 %133 %137 + %139 = OpLoad %14 %138 + %140 = OpAccessChain %15 %133 %157 + OpStore %140 %139 + %143 = OpIAdd %14 %157 %123 + %145 = OpAccessChain %15 %133 %157 + %146 = OpLoad %14 %145 + %147 = OpAccessChain %15 %133 %143 + OpStore %147 %146 + OpBranch %128 + %128 = OpLabel + %149 = OpISub %14 %157 %57 + OpStore %124 %149 + OpBranch %125 + %127 = OpLabel + OpReturn + OpFunctionEnd +)"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + { + // Function a + const Function* f = spvtest::GetFunction(module, 6); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector 
loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2] = {nullptr, nullptr}; // Null until found. + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + ASSERT_TRUE(stores[i]); // Fatal: both stores are dereferenced below. + } + + // 47 -> 48 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(47) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent and supported. + EXPECT_TRUE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 54 -> 55 + { + // Analyse and simplify the instruction behind the access chain of this + // load. 
+ Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(54) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function b + const Function* f = spvtest::GetFunction(module, 8); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2] = {nullptr, nullptr}; // Null until found. + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 65)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + ASSERT_TRUE(stores[i]); // Fatal: both stores are dereferenced below. + } + + // 78 -> 79 + { + // Analyse and simplify the instruction behind the access chain of this + // load. 
+ Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(78) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 85 -> 86 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(85) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. 
+ Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function c + const Function* f = spvtest::GetFunction(module, 10); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2] = {nullptr, nullptr}; // Null until found. + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 96)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + ASSERT_TRUE(stores[i]); // Fatal: both stores are dereferenced below. + } + + // 109 -> 110 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(109) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. 
+ Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 116 -> 117 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(116) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Independent but not supported. 
+ EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } + { + // Function d + const Function* f = spvtest::GetFunction(module, 12); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + Loop* loop = &ld.GetLoopByIndex(0); + std::vector loops{loop}; + LoopDependenceAnalysis analysis{context.get(), loops}; + + const Instruction* stores[2] = {nullptr, nullptr}; // Null until found. + int stores_found = 0; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 126)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[stores_found] = &inst; + ++stores_found; + } + } + + for (int i = 0; i < 2; ++i) { + ASSERT_TRUE(stores[i]); // Fatal: both stores are dereferenced below. + } + + // 139 -> 140 + { + // Analyse and simplify the instruction behind the access chain of this + // load. + Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(139) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[0]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + + // 146 -> 147 + { + // Analyse and simplify the instruction behind the access chain of this + // load. 
+ Instruction* load_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(context->get_def_use_mgr() + ->GetDef(146) + ->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* load = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(load_var)); + // Analyse and simplify the instruction behind the access chain of this + // store. + Instruction* store_var = context->get_def_use_mgr()->GetDef( + context->get_def_use_mgr() + ->GetDef(stores[1]->GetSingleWordInOperand(0)) + ->GetSingleWordInOperand(1)); + SENode* store = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->AnalyzeInstruction(store_var)); + + SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression( + analysis.GetScalarEvolution()->CreateSubtraction(load, store)); + + // Dependent. + EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds( + loop, delta, store->AsSERecurrentNode()->GetCoefficient())); + } + } +} + +/* + Generated from the following GLSL fragment shader + with --eliminate-local-multi-store +#version 440 core +layout(location = 0) in vec4 in_vec; +// Loop iterates from symbolic to constant +void a() { + int N = int(in_vec.x); + int arr[10]; + for (int i = N; i < 9; i++) { // Bounds are 9 - N - 1 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void b() { + int N = int(in_vec.x); + int arr[10]; + for (int i = N; i <= 9; i++) { // Bounds are 9 - N + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void c() { + int N = int(in_vec.x); + int arr[10]; + for (int i = N; i > 0; i--) { // Bounds are N - 0 - 1 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void d() { + int N = int(in_vec.x); + int arr[10]; + for (int i = N; i >= 0; i--) { // Bounds are N - 0 + arr[i] = arr[i+N]; // |distance| = N + arr[i+N] = arr[i]; // |distance| = N + } +} +void main(){ + a(); 
+  b();
+  c();
+  d();
+}
+*/
+TEST(DependencyAnalysisHelpers, symbolic_to_const) {
+  const std::string text = R"(               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main" %20
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 440
+               OpName %4 "main"
+               OpName %6 "a("
+               OpName %8 "b("
+               OpName %10 "c("
+               OpName %12 "d("
+               OpName %16 "N"
+               OpName %20 "in_vec"
+               OpName %27 "i"
+               OpName %41 "arr"
+               OpName %59 "N"
+               OpName %63 "i"
+               OpName %72 "arr"
+               OpName %89 "N"
+               OpName %93 "i"
+               OpName %103 "arr"
+               OpName %120 "N"
+               OpName %124 "i"
+               OpName %133 "arr"
+               OpDecorate %20 Location 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+         %14 = OpTypeInt 32 1
+         %15 = OpTypePointer Function %14
+         %17 = OpTypeFloat 32
+         %18 = OpTypeVector %17 4
+         %19 = OpTypePointer Input %18
+         %20 = OpVariable %19 Input
+         %21 = OpTypeInt 32 0
+         %22 = OpConstant %21 0
+         %23 = OpTypePointer Input %17
+         %35 = OpConstant %14 9
+         %36 = OpTypeBool
+         %38 = OpConstant %21 10
+         %39 = OpTypeArray %14 %38
+         %40 = OpTypePointer Function %39
+         %57 = OpConstant %14 1
+        %101 = OpConstant %14 0
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+        %150 = OpFunctionCall %2 %6
+        %151 = OpFunctionCall %2 %8
+        %152 = OpFunctionCall %2 %10
+        %153 = OpFunctionCall %2 %12
+               OpReturn
+               OpFunctionEnd
+          %6 = OpFunction %2 None %3
+          %7 = OpLabel
+         %16 = OpVariable %15 Function
+         %27 = OpVariable %15 Function
+         %41 = OpVariable %40 Function
+         %24 = OpAccessChain %23 %20 %22
+         %25 = OpLoad %17 %24
+         %26 = OpConvertFToS %14 %25
+               OpStore %16 %26
+               OpStore %27 %26
+               OpBranch %29
+         %29 = OpLabel
+        %154 = OpPhi %14 %26 %7 %58 %32
+               OpLoopMerge %31 %32 None
+               OpBranch %33
+         %33 = OpLabel
+         %37 = OpSLessThan %36 %154 %35
+               OpBranchConditional %37 %30 %31
+         %30 = OpLabel
+         %45 = OpIAdd %14 %154 %26
+         %46 = OpAccessChain %15 %41 %45
+         %47 = OpLoad %14 %46
+         %48 = OpAccessChain %15 %41 %154
+               OpStore %48 %47
+         %51 = OpIAdd %14 %154 %26
+         %53 = OpAccessChain %15 %41 %154
+         %54 = OpLoad %14 %53
+         %55 = OpAccessChain %15 %41 %51
+               OpStore %55 %54
+               OpBranch %32
+         %32 = OpLabel
+         %58 = OpIAdd %14 %154 %57
+               OpStore %27 %58
+               OpBranch %29
+         %31 = OpLabel
+               OpReturn
+               OpFunctionEnd
+          %8 = OpFunction %2 None %3
+          %9 = OpLabel
+         %59 = OpVariable %15 Function
+         %63 = OpVariable %15 Function
+         %72 = OpVariable %40 Function
+         %60 = OpAccessChain %23 %20 %22
+         %61 = OpLoad %17 %60
+         %62 = OpConvertFToS %14 %61
+               OpStore %59 %62
+               OpStore %63 %62
+               OpBranch %65
+         %65 = OpLabel
+        %155 = OpPhi %14 %62 %9 %88 %68
+               OpLoopMerge %67 %68 None
+               OpBranch %69
+         %69 = OpLabel
+         %71 = OpSLessThanEqual %36 %155 %35
+               OpBranchConditional %71 %66 %67
+         %66 = OpLabel
+         %76 = OpIAdd %14 %155 %62
+         %77 = OpAccessChain %15 %72 %76
+         %78 = OpLoad %14 %77
+         %79 = OpAccessChain %15 %72 %155
+               OpStore %79 %78
+         %82 = OpIAdd %14 %155 %62
+         %84 = OpAccessChain %15 %72 %155
+         %85 = OpLoad %14 %84
+         %86 = OpAccessChain %15 %72 %82
+               OpStore %86 %85
+               OpBranch %68
+         %68 = OpLabel
+         %88 = OpIAdd %14 %155 %57
+               OpStore %63 %88
+               OpBranch %65
+         %67 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %10 = OpFunction %2 None %3
+         %11 = OpLabel
+         %89 = OpVariable %15 Function
+         %93 = OpVariable %15 Function
+        %103 = OpVariable %40 Function
+         %90 = OpAccessChain %23 %20 %22
+         %91 = OpLoad %17 %90
+         %92 = OpConvertFToS %14 %91
+               OpStore %89 %92
+               OpStore %93 %92
+               OpBranch %95
+         %95 = OpLabel
+        %156 = OpPhi %14 %92 %11 %119 %98
+               OpLoopMerge %97 %98 None
+               OpBranch %99
+         %99 = OpLabel
+        %102 = OpSGreaterThan %36 %156 %101
+               OpBranchConditional %102 %96 %97
+         %96 = OpLabel
+        %107 = OpIAdd %14 %156 %92
+        %108 = OpAccessChain %15 %103 %107
+        %109 = OpLoad %14 %108
+        %110 = OpAccessChain %15 %103 %156
+               OpStore %110 %109
+        %113 = OpIAdd %14 %156 %92
+        %115 = OpAccessChain %15 %103 %156
+        %116 = OpLoad %14 %115
+        %117 = OpAccessChain %15 %103 %113
+               OpStore %117 %116
+               OpBranch %98
+         %98 = OpLabel
+        %119 = OpISub %14 %156 %57
+               OpStore %93 %119
+               OpBranch %95
+         %97 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %12 = OpFunction %2 None %3
+         %13 = OpLabel
+        %120 = OpVariable %15 Function
+        %124 = OpVariable %15 Function
+        %133 = OpVariable %40 Function
+        %121 = OpAccessChain %23 %20 %22
+        %122 = OpLoad %17 %121
+        %123 = OpConvertFToS %14 %122
+               OpStore %120 %123
+               OpStore %124 %123
+               OpBranch %126
+        %126 = OpLabel
+        %157 = OpPhi %14 %123 %13 %149 %129
+               OpLoopMerge %128 %129 None
+               OpBranch %130
+        %130 = OpLabel
+        %132 = OpSGreaterThanEqual %36 %157 %101
+               OpBranchConditional %132 %127 %128
+        %127 = OpLabel
+        %137 = OpIAdd %14 %157 %123
+        %138 = OpAccessChain %15 %133 %137
+        %139 = OpLoad %14 %138
+        %140 = OpAccessChain %15 %133 %157
+               OpStore %140 %139
+        %143 = OpIAdd %14 %157 %123
+        %145 = OpAccessChain %15 %133 %157
+        %146 = OpLoad %14 %145
+        %147 = OpAccessChain %15 %133 %143
+               OpStore %147 %146
+               OpBranch %129
+        %129 = OpLabel
+        %149 = OpISub %14 %157 %57
+               OpStore %124 %149
+               OpBranch %126
+        %128 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+  std::unique_ptr<IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  {
+    // Function a: i starts at N and loops while i < 9.
+    const Function* f = spvtest::GetFunction(module, 6);
+    LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+    Loop* loop = &ld.GetLoopByIndex(0);
+    std::vector<const Loop*> loops{loop};
+    LoopDependenceAnalysis analysis{context.get(), loops};
+
+    const Instruction* stores[2] = {nullptr, nullptr};
+    int stores_found = 0;
+    for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) {
+      if (inst.opcode() == SpvOp::SpvOpStore && stores_found < 2) {
+        stores[stores_found] = &inst;
+        ++stores_found;
+      }
+    }
+
+    for (int i = 0; i < 2; ++i) {
+      ASSERT_TRUE(stores[i] != nullptr);
+    }
+
+    // 47 -> 48
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(47)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[0]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      // Independent but not supported.
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+
+    // 54 -> 55
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(54)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[1]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      // Independent but not supported.
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+  }
+  {
+    // Function b: i starts at N and loops while i <= 9.
+    const Function* f = spvtest::GetFunction(module, 8);
+    LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+    Loop* loop = &ld.GetLoopByIndex(0);
+    std::vector<const Loop*> loops{loop};
+    LoopDependenceAnalysis analysis{context.get(), loops};
+
+    const Instruction* stores[2] = {nullptr, nullptr};
+    int stores_found = 0;
+    for (const Instruction& inst : *spvtest::GetBasicBlock(f, 66)) {
+      if (inst.opcode() == SpvOp::SpvOpStore && stores_found < 2) {
+        stores[stores_found] = &inst;
+        ++stores_found;
+      }
+    }
+
+    for (int i = 0; i < 2; ++i) {
+      ASSERT_TRUE(stores[i] != nullptr);
+    }
+
+    // 78 -> 79
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(78)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[0]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      // Dependent.
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+
+    // 85 -> 86
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(85)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[1]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      // Dependent.
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+  }
+  {
+    // Function c: i starts at N and loops while i > 0.
+    const Function* f = spvtest::GetFunction(module, 10);
+    LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+    Loop* loop = &ld.GetLoopByIndex(0);
+    std::vector<const Loop*> loops{loop};
+    LoopDependenceAnalysis analysis{context.get(), loops};
+
+    const Instruction* stores[2] = {nullptr, nullptr};
+    int stores_found = 0;
+    for (const Instruction& inst : *spvtest::GetBasicBlock(f, 96)) {
+      if (inst.opcode() == SpvOp::SpvOpStore && stores_found < 2) {
+        stores[stores_found] = &inst;
+        ++stores_found;
+      }
+    }
+
+    for (int i = 0; i < 2; ++i) {
+      ASSERT_TRUE(stores[i] != nullptr);
+    }
+
+    // 109 -> 110
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(109)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[0]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      // Independent and supported.
+      EXPECT_TRUE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+
+    // 116 -> 117
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(116)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[1]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      // Independent but not supported.
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+  }
+  {
+    // Function d: i starts at N and loops while i >= 0.
+    const Function* f = spvtest::GetFunction(module, 12);
+    LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+    Loop* loop = &ld.GetLoopByIndex(0);
+    std::vector<const Loop*> loops{loop};
+    LoopDependenceAnalysis analysis{context.get(), loops};
+
+    const Instruction* stores[2] = {nullptr, nullptr};
+    int stores_found = 0;
+    for (const Instruction& inst : *spvtest::GetBasicBlock(f, 127)) {
+      if (inst.opcode() == SpvOp::SpvOpStore && stores_found < 2) {
+        stores[stores_found] = &inst;
+        ++stores_found;
+      }
+    }
+
+    for (int i = 0; i < 2; ++i) {
+      ASSERT_TRUE(stores[i] != nullptr);
+    }
+
+    // 139 -> 140
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(139)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[0]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      // Dependent.
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+
+    // 146 -> 147
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(146)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[1]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      // Dependent.
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+  }
+}
+
+/*
+  Generated from the following GLSL fragment shader
+  with --eliminate-local-multi-store
+#version 440 core
+layout(location = 0) in vec4 in_vec;
+// Loop iterates from symbolic to symbolic
+void a() {
+  int M = int(in_vec.x);
+  int N = int(in_vec.y);
+  int arr[10];
+  for (int i = M; i < N; i++) { // Bounds are N - M - 1
+    arr[i+M+N] = arr[i+M+2*N]; // |distance| = N
+    arr[i+M+2*N] = arr[i+M+N]; // |distance| = N
+  }
+}
+void b() {
+  int M = int(in_vec.x);
+  int N = int(in_vec.y);
+  int arr[10];
+  for (int i = M; i <= N; i++) { // Bounds are N - M
+    arr[i+M+N] = arr[i+M+2*N]; // |distance| = N
+    arr[i+M+2*N] = arr[i+M+N]; // |distance| = N
+  }
+}
+void c() {
+  int M = int(in_vec.x);
+  int N = int(in_vec.y);
+  int arr[10];
+  for (int i = M; i > N; i--) { // Bounds are M - N - 1
+    arr[i+M+N] = arr[i+M+2*N]; // |distance| = N
+    arr[i+M+2*N] = arr[i+M+N]; // |distance| = N
+  }
+}
+void d() {
+  int M = int(in_vec.x);
+  int N = int(in_vec.y);
+  int arr[10];
+  for (int i = M; i >= N; i--) { // Bounds are M - N
+    arr[i+M+N] = arr[i+M+2*N]; // |distance| = N
+    arr[i+M+2*N] = arr[i+M+N]; // |distance| = N
+  }
+}
+void main(){
+  a();
+  b();
+  c();
+  d();
+}
+*/
+TEST(DependencyAnalysisHelpers, symbolic_to_symbolic) {
+  const std::string text = R"(               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main" %20
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 440
+               OpName %4 "main"
+               OpName %6 "a("
+               OpName %8 "b("
+               OpName %10 "c("
+               OpName %12 "d("
+               OpName %16 "M"
+               OpName %20 "in_vec"
+               OpName %27 "N"
+               OpName %32 "i"
+               OpName %46 "arr"
+               OpName %79 "M"
+               OpName %83 "N"
+               OpName %87 "i"
+               OpName %97 "arr"
+               OpName %128 "M"
+               OpName %132 "N"
+               OpName %136 "i"
+               OpName %146 "arr"
+               OpName %177 "M"
+               OpName %181 "N"
+               OpName %185 "i"
+               OpName %195 "arr"
+               OpDecorate %20 Location 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+         %14 = OpTypeInt 32 1
+         %15 = OpTypePointer Function %14
+         %17 = OpTypeFloat 32
+         %18 = OpTypeVector %17 4
+         %19 = OpTypePointer Input %18
+         %20 = OpVariable %19 Input
+         %21 = OpTypeInt 32 0
+         %22 = OpConstant %21 0
+         %23 = OpTypePointer Input %17
+         %28 = OpConstant %21 1
+         %41 = OpTypeBool
+         %43 = OpConstant %21 10
+         %44 = OpTypeArray %14 %43
+         %45 = OpTypePointer Function %44
+         %55 = OpConstant %14 2
+         %77 = OpConstant %14 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+        %226 = OpFunctionCall %2 %6
+        %227 = OpFunctionCall %2 %8
+        %228 = OpFunctionCall %2 %10
+        %229 = OpFunctionCall %2 %12
+               OpReturn
+               OpFunctionEnd
+          %6 = OpFunction %2 None %3
+          %7 = OpLabel
+         %16 = OpVariable %15 Function
+         %27 = OpVariable %15 Function
+         %32 = OpVariable %15 Function
+         %46 = OpVariable %45 Function
+         %24 = OpAccessChain %23 %20 %22
+         %25 = OpLoad %17 %24
+         %26 = OpConvertFToS %14 %25
+               OpStore %16 %26
+         %29 = OpAccessChain %23 %20 %28
+         %30 = OpLoad %17 %29
+         %31 = OpConvertFToS %14 %30
+               OpStore %27 %31
+               OpStore %32 %26
+               OpBranch %34
+         %34 = OpLabel
+        %230 = OpPhi %14 %26 %7 %78 %37
+               OpLoopMerge %36 %37 None
+               OpBranch %38
+         %38 = OpLabel
+         %42 = OpSLessThan %41 %230 %31
+               OpBranchConditional %42 %35 %36
+         %35 = OpLabel
+         %49 = OpIAdd %14 %230 %26
+         %51 = OpIAdd %14 %49 %31
+         %54 = OpIAdd %14 %230 %26
+         %57 = OpIMul %14 %55 %31
+         %58 = OpIAdd %14 %54 %57
+         %59 = OpAccessChain %15 %46 %58
+         %60 = OpLoad %14 %59
+         %61 = OpAccessChain %15 %46 %51
+               OpStore %61 %60
+         %64 = OpIAdd %14 %230 %26
+         %66 = OpIMul %14 %55 %31
+         %67 = OpIAdd %14 %64 %66
+         %70 = OpIAdd %14 %230 %26
+         %72 = OpIAdd %14 %70 %31
+         %73 = OpAccessChain %15 %46 %72
+         %74 = OpLoad %14 %73
+         %75 = OpAccessChain %15 %46 %67
+               OpStore %75 %74
+               OpBranch %37
+         %37 = OpLabel
+         %78 = OpIAdd %14 %230 %77
+               OpStore %32 %78
+               OpBranch %34
+         %36 = OpLabel
+               OpReturn
+               OpFunctionEnd
+          %8 = OpFunction %2 None %3
+          %9 = OpLabel
+         %79 = OpVariable %15 Function
+         %83 = OpVariable %15 Function
+         %87 = OpVariable %15 Function
+         %97 = OpVariable %45 Function
+         %80 = OpAccessChain %23 %20 %22
+         %81 = OpLoad %17 %80
+         %82 = OpConvertFToS %14 %81
+               OpStore %79 %82
+         %84 = OpAccessChain %23 %20 %28
+         %85 = OpLoad %17 %84
+         %86 = OpConvertFToS %14 %85
+               OpStore %83 %86
+               OpStore %87 %82
+               OpBranch %89
+         %89 = OpLabel
+        %231 = OpPhi %14 %82 %9 %127 %92
+               OpLoopMerge %91 %92 None
+               OpBranch %93
+         %93 = OpLabel
+         %96 = OpSLessThanEqual %41 %231 %86
+               OpBranchConditional %96 %90 %91
+         %90 = OpLabel
+        %100 = OpIAdd %14 %231 %82
+        %102 = OpIAdd %14 %100 %86
+        %105 = OpIAdd %14 %231 %82
+        %107 = OpIMul %14 %55 %86
+        %108 = OpIAdd %14 %105 %107
+        %109 = OpAccessChain %15 %97 %108
+        %110 = OpLoad %14 %109
+        %111 = OpAccessChain %15 %97 %102
+               OpStore %111 %110
+        %114 = OpIAdd %14 %231 %82
+        %116 = OpIMul %14 %55 %86
+        %117 = OpIAdd %14 %114 %116
+        %120 = OpIAdd %14 %231 %82
+        %122 = OpIAdd %14 %120 %86
+        %123 = OpAccessChain %15 %97 %122
+        %124 = OpLoad %14 %123
+        %125 = OpAccessChain %15 %97 %117
+               OpStore %125 %124
+               OpBranch %92
+         %92 = OpLabel
+        %127 = OpIAdd %14 %231 %77
+               OpStore %87 %127
+               OpBranch %89
+         %91 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %10 = OpFunction %2 None %3
+         %11 = OpLabel
+        %128 = OpVariable %15 Function
+        %132 = OpVariable %15 Function
+        %136 = OpVariable %15 Function
+        %146 = OpVariable %45 Function
+        %129 = OpAccessChain %23 %20 %22
+        %130 = OpLoad %17 %129
+        %131 = OpConvertFToS %14 %130
+               OpStore %128 %131
+        %133 = OpAccessChain %23 %20 %28
+        %134 = OpLoad %17 %133
+        %135 = OpConvertFToS %14 %134
+               OpStore %132 %135
+               OpStore %136 %131
+               OpBranch %138
+        %138 = OpLabel
+        %232 = OpPhi %14 %131 %11 %176 %141
+               OpLoopMerge %140 %141 None
+               OpBranch %142
+        %142 = OpLabel
+        %145 = OpSGreaterThan %41 %232 %135
+               OpBranchConditional %145 %139 %140
+        %139 = OpLabel
+        %149 = OpIAdd %14 %232 %131
+        %151 = OpIAdd %14 %149 %135
+        %154 = OpIAdd %14 %232 %131
+        %156 = OpIMul %14 %55 %135
+        %157 = OpIAdd %14 %154 %156
+        %158 = OpAccessChain %15 %146 %157
+        %159 = OpLoad %14 %158
+        %160 = OpAccessChain %15 %146 %151
+               OpStore %160 %159
+        %163 = OpIAdd %14 %232 %131
+        %165 = OpIMul %14 %55 %135
+        %166 = OpIAdd %14 %163 %165
+        %169 = OpIAdd %14 %232 %131
+        %171 = OpIAdd %14 %169 %135
+        %172 = OpAccessChain %15 %146 %171
+        %173 = OpLoad %14 %172
+        %174 = OpAccessChain %15 %146 %166
+               OpStore %174 %173
+               OpBranch %141
+        %141 = OpLabel
+        %176 = OpISub %14 %232 %77
+               OpStore %136 %176
+               OpBranch %138
+        %140 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %12 = OpFunction %2 None %3
+         %13 = OpLabel
+        %177 = OpVariable %15 Function
+        %181 = OpVariable %15 Function
+        %185 = OpVariable %15 Function
+        %195 = OpVariable %45 Function
+        %178 = OpAccessChain %23 %20 %22
+        %179 = OpLoad %17 %178
+        %180 = OpConvertFToS %14 %179
+               OpStore %177 %180
+        %182 = OpAccessChain %23 %20 %28
+        %183 = OpLoad %17 %182
+        %184 = OpConvertFToS %14 %183
+               OpStore %181 %184
+               OpStore %185 %180
+               OpBranch %187
+        %187 = OpLabel
+        %233 = OpPhi %14 %180 %13 %225 %190
+               OpLoopMerge %189 %190 None
+               OpBranch %191
+        %191 = OpLabel
+        %194 = OpSGreaterThanEqual %41 %233 %184
+               OpBranchConditional %194 %188 %189
+        %188 = OpLabel
+        %198 = OpIAdd %14 %233 %180
+        %200 = OpIAdd %14 %198 %184
+        %203 = OpIAdd %14 %233 %180
+        %205 = OpIMul %14 %55 %184
+        %206 = OpIAdd %14 %203 %205
+        %207 = OpAccessChain %15 %195 %206
+        %208 = OpLoad %14 %207
+        %209 = OpAccessChain %15 %195 %200
+               OpStore %209 %208
+        %212 = OpIAdd %14 %233 %180
+        %214 = OpIMul %14 %55 %184
+        %215 = OpIAdd %14 %212 %214
+        %218 = OpIAdd %14 %233 %180
+        %220 = OpIAdd %14 %218 %184
+        %221 = OpAccessChain %15 %195 %220
+        %222 = OpLoad %14 %221
+        %223 = OpAccessChain %15 %195 %215
+               OpStore %223 %222
+               OpBranch %190
+        %190 = OpLabel
+        %225 = OpISub %14 %233 %77
+               OpStore %185 %225
+               OpBranch %187
+        %189 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+  std::unique_ptr<IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  {
+    // Function a: for (int i = M; i < N; i++)
+    const Function* f = spvtest::GetFunction(module, 6);
+    LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+    Loop* loop = &ld.GetLoopByIndex(0);
+    std::vector<const Loop*> loops{loop};
+    LoopDependenceAnalysis analysis{context.get(), loops};
+
+    const Instruction* stores[2] = {nullptr, nullptr};
+    int stores_found = 0;
+    for (const Instruction& inst : *spvtest::GetBasicBlock(f, 35)) {
+      if (inst.opcode() == SpvOp::SpvOpStore && stores_found < 2) {
+        stores[stores_found] = &inst;
+        ++stores_found;
+      }
+    }
+
+    for (int i = 0; i < 2; ++i) {
+      ASSERT_TRUE(stores[i] != nullptr);
+    }
+
+    // 60 -> 61
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(60)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[0]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+
+    // 74 -> 75
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(74)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[1]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+  }
+  {
+    // Function b: for (int i = M; i <= N; i++)
+    const Function* f = spvtest::GetFunction(module, 8);
+    LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+    Loop* loop = &ld.GetLoopByIndex(0);
+    std::vector<const Loop*> loops{loop};
+    LoopDependenceAnalysis analysis{context.get(), loops};
+
+    const Instruction* stores[2] = {nullptr, nullptr};
+    int stores_found = 0;
+    for (const Instruction& inst : *spvtest::GetBasicBlock(f, 90)) {
+      if (inst.opcode() == SpvOp::SpvOpStore && stores_found < 2) {
+        stores[stores_found] = &inst;
+        ++stores_found;
+      }
+    }
+
+    for (int i = 0; i < 2; ++i) {
+      ASSERT_TRUE(stores[i] != nullptr);
+    }
+
+    // 110 -> 111
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(110)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[0]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+
+    // 124 -> 125
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(124)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[1]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+  }
+  {
+    // Function c: for (int i = M; i > N; i--)
+    const Function* f = spvtest::GetFunction(module, 10);
+    LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+    Loop* loop = &ld.GetLoopByIndex(0);
+    std::vector<const Loop*> loops{loop};
+    LoopDependenceAnalysis analysis{context.get(), loops};
+
+    const Instruction* stores[2] = {nullptr, nullptr};
+    int stores_found = 0;
+    for (const Instruction& inst : *spvtest::GetBasicBlock(f, 139)) {
+      if (inst.opcode() == SpvOp::SpvOpStore && stores_found < 2) {
+        stores[stores_found] = &inst;
+        ++stores_found;
+      }
+    }
+
+    for (int i = 0; i < 2; ++i) {
+      ASSERT_TRUE(stores[i] != nullptr);
+    }
+
+    // 159 -> 160
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(159)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[0]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+
+    // 173 -> 174
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(173)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[1]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+  }
+  {
+    // Function d: for (int i = M; i >= N; i--)
+    const Function* f = spvtest::GetFunction(module, 12);
+    LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+    Loop* loop = &ld.GetLoopByIndex(0);
+    std::vector<const Loop*> loops{loop};
+    LoopDependenceAnalysis analysis{context.get(), loops};
+
+    const Instruction* stores[2] = {nullptr, nullptr};
+    int stores_found = 0;
+    for (const Instruction& inst : *spvtest::GetBasicBlock(f, 188)) {
+      if (inst.opcode() == SpvOp::SpvOpStore && stores_found < 2) {
+        stores[stores_found] = &inst;
+        ++stores_found;
+      }
+    }
+
+    for (int i = 0; i < 2; ++i) {
+      ASSERT_TRUE(stores[i] != nullptr);
+    }
+
+    // 208 -> 209
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(208)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[0]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+
+    // 222 -> 223
+    {
+      // Analyse and simplify the instruction behind the access chain of this
+      // load.
+      Instruction* load_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(context->get_def_use_mgr()
+                           ->GetDef(222)
+                           ->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* load = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(load_var));
+
+      // Analyse and simplify the instruction behind the access chain of this
+      // store.
+      Instruction* store_var = context->get_def_use_mgr()->GetDef(
+          context->get_def_use_mgr()
+              ->GetDef(stores[1]->GetSingleWordInOperand(0))
+              ->GetSingleWordInOperand(1));
+      SENode* store = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->AnalyzeInstruction(store_var));
+
+      SENode* delta = analysis.GetScalarEvolution()->SimplifyExpression(
+          analysis.GetScalarEvolution()->CreateSubtraction(load, store));
+
+      EXPECT_FALSE(analysis.IsProvablyOutsideOfLoopBounds(
+          loop, delta, store->AsSERecurrentNode()->GetCoefficient()));
+    }
+  }
+}
+
+}  // namespace
+}  // namespace opt
+}  // namespace spvtools
diff --git a/test/opt/loop_optimizations/fusion_compatibility.cpp b/test/opt/loop_optimizations/fusion_compatibility.cpp
new file mode 100644
index 000000000..cda8576c5
--- /dev/null
+++ b/test/opt/loop_optimizations/fusion_compatibility.cpp
@@ -0,0 +1,1785 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_fusion.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using FusionCompatibilityTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int i = 0; // Can't fuse, i=0 in first & i=10 in second + for (; i < 10; i++) {} + for (; i < 10; i++) {} +} +*/ +TEST_F(FusionCompatibilityTest, SameInductionVariableDifferentBounds) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %31 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %31 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %31 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpBranch %22 + %22 = OpLabel + %32 = OpPhi %6 %31 %12 %30 %25 + OpLoopMerge %24 %25 None + OpBranch %26 + %26 = OpLabel + %28 = OpSLessThan %17 %32 %16 + OpBranchConditional %28 %23 %24 + %23 = OpLabel + OpBranch %25 + %25 = OpLabel + %30 = OpIAdd %6 %32 %20 + OpStore %8 %30 + OpBranch %22 + %24 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling 
failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 1 +#version 440 core +void main() { + for (int i = 0; i < 10; i++) {} + for (int i = 0; i < 10; i++) {} +} +*/ +TEST_F(FusionCompatibilityTest, Compatible) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %32 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %32 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %32 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %33 = OpPhi %6 %9 %12 %31 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %33 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %31 = OpIAdd %6 %33 %20 + OpStore %22 %31 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) 
<< "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 2 +#version 440 core +void main() { + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 10; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, DifferentName) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %32 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %32 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %32 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %33 = OpPhi %6 %9 %12 %31 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %33 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %31 = OpIAdd %6 %33 %20 + OpStore %22 %31 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + 
EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + // Can't fuse, different step + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 10; j=j+2) {} +} + +*/ +TEST_F(FusionCompatibilityTest, SameBoundsDifferentStep) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %31 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %33 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %34 = OpPhi %6 %9 %12 %32 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %34 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %32 = OpIAdd %6 %34 %31 + OpStore %22 %32 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + 
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 4 +#version 440 core +void main() { + // Can't fuse, different upper bound + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 20; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, DifferentUpperBound) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %29 = OpConstant %6 20 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %33 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %34 = OpPhi %6 %9 %12 %32 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %30 = OpSLessThan %17 %34 %29 + OpBranchConditional %30 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %32 = OpIAdd %6 %34 %20 + OpStore %22 %32 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd 
+ )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 5 +#version 440 core +void main() { + // Can't fuse, different lower bound + for (int i = 5; i < 10; i++) {} + for (int j = 0; j < 10; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, DifferentLowerBound) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 5 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %23 = OpConstant %6 0 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %33 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %23 + OpBranch %24 + %24 = OpLabel + %34 = OpPhi %6 %23 %12 %32 %27 + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %30 = OpSLessThan %17 %34 %16 + OpBranchConditional %30 %25 %26 + %25 = OpLabel + OpBranch %27 + %27 = OpLabel + %32 = 
OpIAdd %6 %34 %20 + OpStore %22 %32 + OpBranch %24 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 6 +#version 440 core +void main() { + // Can't fuse, break in first loop + for (int i = 0; i < 10; i++) { + if (i == 5) { + break; + } + } + for (int j = 0; j < 10; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, Break) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %28 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 5 + %26 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %28 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %38 = OpPhi %6 %9 %5 %27 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %38 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %21 = OpIEqual %17 %38 %20 + OpSelectionMerge %23 None + OpBranchConditional %21 %22 %23 + %22 = OpLabel + OpBranch %12 + %23 = OpLabel + OpBranch %13 + %13 = OpLabel + %27 = OpIAdd %6 %38 %26 + OpStore %8 %27 + OpBranch %10 + %12 = OpLabel + OpStore %28 %9 
+ OpBranch %29 + %29 = OpLabel + %39 = OpPhi %6 %9 %12 %37 %32 + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %35 = OpSLessThan %17 %39 %16 + OpBranchConditional %35 %30 %31 + %30 = OpLabel + OpBranch %32 + %32 = OpLabel + %37 = OpIAdd %6 %39 %26 + OpStore %28 %37 + OpBranch %29 + %31 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +layout(location = 0) in vec4 c; +void main() { + int N = int(c.x); + for (int i = 0; i < N; i++) {} + for (int j = 0; j < N; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, UnknownButSameUpperBound) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "N" + OpName %12 "c" + OpName %19 "i" + OpName %33 "j" + OpDecorate %12 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %33 = OpVariable %7 Function + 
%16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpStore %8 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + %44 = OpPhi %6 %20 %5 %32 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSLessThan %28 %44 %18 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + OpBranch %24 + %24 = OpLabel + %32 = OpIAdd %6 %44 %31 + OpStore %19 %32 + OpBranch %21 + %23 = OpLabel + OpStore %33 %20 + OpBranch %34 + %34 = OpLabel + %46 = OpPhi %6 %20 %23 %43 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %41 = OpSLessThan %28 %46 %18 + OpBranchConditional %41 %35 %36 + %35 = OpLabel + OpBranch %37 + %37 = OpLabel + %43 = OpIAdd %6 %46 %31 + OpStore %33 %43 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +layout(location = 0) in vec4 c; +void main() { + int N = int(c.x); + for (int i = 0; N > j; i++) {} + for (int j = 0; N > j; j++) {} +} +*/ +TEST_F(FusionCompatibilityTest, UnknownButSameUpperBoundReverseCondition) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "N" + OpName %12 "c" + OpName %19 "i" + OpName %33 "j" + OpDecorate %12 Location 0 + %2 = OpTypeVoid + %3 = 
OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %33 = OpVariable %7 Function + %16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpStore %8 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + %45 = OpPhi %6 %20 %5 %32 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSGreaterThan %28 %18 %45 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + OpBranch %24 + %24 = OpLabel + %32 = OpIAdd %6 %45 %31 + OpStore %19 %32 + OpBranch %21 + %23 = OpLabel + OpStore %33 %20 + OpBranch %34 + %34 = OpLabel + %47 = OpPhi %6 %20 %23 %43 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %41 = OpSGreaterThan %28 %18 %47 + OpBranchConditional %41 %35 %36 + %35 = OpLabel + OpBranch %37 + %37 = OpLabel + %43 = OpIAdd %6 %47 %31 + OpStore %33 %43 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +layout(location = 0) in vec4 c; +void main() { + // Can't fuse different bound + int N = int(c.x); + for (int 
i = 0; i < N; i++) {} + for (int j = 0; j < N+1; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, UnknownUpperBoundAddition) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "N" + OpName %12 "c" + OpName %19 "i" + OpName %33 "j" + OpDecorate %12 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %33 = OpVariable %7 Function + %16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpStore %8 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + %45 = OpPhi %6 %20 %5 %32 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSLessThan %28 %45 %18 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + OpBranch %24 + %24 = OpLabel + %32 = OpIAdd %6 %45 %31 + OpStore %19 %32 + OpBranch %21 + %23 = OpLabel + OpStore %33 %20 + OpBranch %34 + %34 = OpLabel + %47 = OpPhi %6 %20 %23 %44 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %41 = OpIAdd %6 %18 %31 + %42 = OpSLessThan %28 %47 %41 + OpBranchConditional %42 %35 %36 + %35 = OpLabel + OpBranch %37 + %37 = OpLabel + %44 = OpIAdd %6 %47 %31 + OpStore %33 %44 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for 
shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 10 +#version 440 core +void main() { + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 10; j++) {} + for (int k = 0; k < 10; k++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, SeveralAdjacentLoops) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + OpName %32 "k" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + %32 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %42 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %42 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %42 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %43 = OpPhi %6 %9 %12 %31 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %43 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %31 = OpIAdd %6 %43 %20 + OpStore %22 %31 + OpBranch %23 + %25 = OpLabel + OpStore %32 %9 + OpBranch %33 + %33 = OpLabel + %44 = OpPhi %6 %9 %25 %41 %36 + OpLoopMerge %35 %36 None + OpBranch %37 + %37 = OpLabel + %39 = 
OpSLessThan %17 %44 %16 + OpBranchConditional %39 %34 %35 + %34 = OpLabel + OpBranch %36 + %36 = OpLabel + %41 = OpIAdd %6 %44 %20 + OpStore %32 %41 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_0).AreCompatible()); + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_2).AreCompatible()); + EXPECT_FALSE(LoopFusion(context.get(), loop_1, loop_0).AreCompatible()); + EXPECT_TRUE(LoopFusion(context.get(), loop_0, loop_1).AreCompatible()); + EXPECT_TRUE(LoopFusion(context.get(), loop_1, loop_2).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + // Can't fuse, not adjacent + int x = 0; + for (int i = 0; i < 10; i++) { + if (i > 10) { + x++; + } + } + x++; + for (int j = 0; j < 10; j++) {} + for (int k = 0; k < 10; k++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, NonAdjacentLoops) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "x" + OpName %10 "i" + OpName %31 "j" + OpName %41 "k" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %25 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = 
OpVariable %7 Function + %10 = OpVariable %7 Function + %31 = OpVariable %7 Function + %41 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %52 = OpPhi %6 %9 %5 %56 %14 + %51 = OpPhi %6 %9 %5 %28 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %51 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %21 = OpSGreaterThan %18 %52 %17 + OpSelectionMerge %23 None + OpBranchConditional %21 %22 %23 + %22 = OpLabel + %26 = OpIAdd %6 %52 %25 + OpStore %8 %26 + OpBranch %23 + %23 = OpLabel + %56 = OpPhi %6 %52 %12 %26 %22 + OpBranch %14 + %14 = OpLabel + %28 = OpIAdd %6 %51 %25 + OpStore %10 %28 + OpBranch %11 + %13 = OpLabel + %30 = OpIAdd %6 %52 %25 + OpStore %8 %30 + OpStore %31 %9 + OpBranch %32 + %32 = OpLabel + %53 = OpPhi %6 %9 %13 %40 %35 + OpLoopMerge %34 %35 None + OpBranch %36 + %36 = OpLabel + %38 = OpSLessThan %18 %53 %17 + OpBranchConditional %38 %33 %34 + %33 = OpLabel + OpBranch %35 + %35 = OpLabel + %40 = OpIAdd %6 %53 %25 + OpStore %31 %40 + OpBranch %32 + %34 = OpLabel + OpStore %41 %9 + OpBranch %42 + %42 = OpLabel + %54 = OpPhi %6 %9 %34 %50 %45 + OpLoopMerge %44 %45 None + OpBranch %46 + %46 = OpLabel + %48 = OpSLessThan %18 %54 %17 + OpBranchConditional %48 %43 %44 + %43 = OpLabel + OpBranch %45 + %45 = OpLabel + %50 = OpIAdd %6 %54 %25 + OpStore %41 %50 + OpBranch %42 + %44 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + EXPECT_FALSE(LoopFusion(context.get(), 
loop_0, loop_0).AreCompatible()); + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_2).AreCompatible()); + EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_1).AreCompatible()); + EXPECT_TRUE(LoopFusion(context.get(), loop_1, loop_2).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 12 +#version 440 core +void main() { + int j = 0; + int i = 0; + for (; i < 10; i++) {} + for (; j < 10; j++) {} +} + +*/ +TEST_F(FusionCompatibilityTest, CompatibleInitDeclaredBeforeLoops) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "j" + OpName %10 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %21 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %32 = OpPhi %6 %9 %5 %22 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %32 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + OpBranch %14 + %14 = OpLabel + %22 = OpIAdd %6 %32 %21 + OpStore %10 %22 + OpBranch %11 + %13 = OpLabel + OpBranch %23 + %23 = OpLabel + %33 = OpPhi %6 %9 %13 %31 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %18 %33 %17 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %31 = OpIAdd %6 %33 %21 + OpStore %8 %31 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for 
shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 13 regenerate! +#version 440 core +void main() { + int[10] a; + int[10] b; + // Can't fuse, several induction variables + for (int j = 0; j < 10; j++) { + b[i] = a[i]; + } + for (int i = 0, j = 0; i < 10; i++, j = j+2) { + } +} + +*/ +TEST_F(FusionCompatibilityTest, SeveralInductionVariables) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "j" + OpName %23 "b" + OpName %25 "a" + OpName %33 "i" + OpName %34 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %31 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %33 = OpVariable %7 Function + %34 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %50 = OpPhi %6 %9 %5 %32 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %50 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %50 + %28 = OpLoad %6 %27 + %29 = OpAccessChain %7 %23 %50 + OpStore %29 %28 + OpBranch %13 + %13 = OpLabel + %32 = OpIAdd %6 %50 %31 + OpStore %8 %32 + OpBranch %10 + %12 = OpLabel + OpStore %33 %9 + OpStore %34 %9 + OpBranch %35 + %35 = 
OpLabel + %52 = OpPhi %6 %9 %12 %49 %38 + %51 = OpPhi %6 %9 %12 %46 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %51 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %44 = OpAccessChain %7 %25 %52 + OpStore %44 %51 + OpBranch %38 + %38 = OpLabel + %46 = OpIAdd %6 %51 %31 + OpStore %33 %46 + %49 = OpIAdd %6 %52 %48 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 14 +#version 440 core +void main() { + // Fine + for (int i = 0; i < 10; i = i + 2) {} + for (int j = 0; j < 10; j = j + 2) {} +} + +*/ +TEST_F(FusionCompatibilityTest, CompatibleNonIncrementStep) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "j" + OpName %10 "i" + OpName %11 "i" + OpName %24 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %22 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %24 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %34 = OpPhi %6 %9 
%5 %23 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %34 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + OpBranch %15 + %15 = OpLabel + %23 = OpIAdd %6 %34 %22 + OpStore %11 %23 + OpBranch %12 + %14 = OpLabel + OpStore %24 %9 + OpBranch %25 + %25 = OpLabel + %35 = OpPhi %6 %9 %14 %33 %28 + OpLoopMerge %27 %28 None + OpBranch %29 + %29 = OpLabel + %31 = OpSLessThan %19 %35 %18 + OpBranchConditional %31 %26 %27 + %26 = OpLabel + OpBranch %28 + %28 = OpLabel + %33 = OpIAdd %6 %35 %22 + OpStore %24 %33 + OpBranch %25 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 15 +#version 440 core + +int j = 0; + +void main() { + // Not compatible, unknown init for second. 
+ for (int i = 0; i < 10; i = i + 2) {} + for (; j < 10; j = j + 2) {} +} + +*/ +TEST_F(FusionCompatibilityTest, UnknonInitForSecondLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "j" + OpName %11 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Private %6 + %8 = OpVariable %7 Private + %9 = OpConstant %6 0 + %10 = OpTypePointer Function %6 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %22 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %11 = OpVariable %10 Function + OpStore %8 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %33 = OpPhi %6 %9 %5 %23 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %33 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + OpBranch %15 + %15 = OpLabel + %23 = OpIAdd %6 %33 %22 + OpStore %11 %23 + OpBranch %12 + %14 = OpLabel + OpBranch %24 + %24 = OpLabel + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %29 = OpLoad %6 %8 + %30 = OpSLessThan %19 %29 %18 + OpBranchConditional %30 %25 %26 + %25 = OpLabel + OpBranch %27 + %27 = OpLabel + %31 = OpLoad %6 %8 + %32 = OpIAdd %6 %31 %22 + OpStore %8 %32 + OpBranch %24 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + 
--eliminate-local-multi-store + +// 16 +#version 440 core +void main() { + // Not compatible, continue in loop 0 + for (int i = 0; i < 10; ++i) { + if (i % 2 == 1) { + continue; + } + } + for (int j = 0; j < 10; ++j) {} +} + +*/ +TEST_F(FusionCompatibilityTest, Continue) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %29 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 2 + %22 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %29 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %39 = OpPhi %6 %9 %5 %28 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %39 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %21 = OpSMod %6 %39 %20 + %23 = OpIEqual %17 %21 %22 + OpSelectionMerge %25 None + OpBranchConditional %23 %24 %25 + %24 = OpLabel + OpBranch %13 + %25 = OpLabel + OpBranch %13 + %13 = OpLabel + %28 = OpIAdd %6 %39 %22 + OpStore %8 %28 + OpBranch %10 + %12 = OpLabel + OpStore %29 %9 + OpBranch %30 + %30 = OpLabel + %40 = OpPhi %6 %9 %12 %38 %33 + OpLoopMerge %32 %33 None + OpBranch %34 + %34 = OpLabel + %36 = OpSLessThan %17 %40 %16 + OpBranchConditional %36 %31 %32 + %31 = OpLabel + OpBranch %33 + %33 = OpLabel + %38 = OpIAdd %6 %40 %22 + OpStore %29 %38 + OpBranch %30 + %32 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); 
+ LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + // Compatible + for (int i = 0; i < 10; ++i) { + if (i % 2 == 1) { + } else { + a[i] = i; + } + } + for (int j = 0; j < 10; ++j) {} +} + +*/ +TEST_F(FusionCompatibilityTest, IfElseInLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %31 "a" + OpName %37 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 2 + %22 = OpConstant %6 1 + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypePointer Function %29 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %31 = OpVariable %30 Function + %37 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %47 = OpPhi %6 %9 %5 %36 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %47 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %21 = OpSMod %6 %47 %20 + %23 = OpIEqual %17 %21 %22 + OpSelectionMerge %25 None + OpBranchConditional %23 %24 %26 + %24 = OpLabel + OpBranch %25 + %26 = OpLabel + %34 = OpAccessChain %7 %31 %47 + OpStore %34 %47 + OpBranch %25 + %25 = OpLabel + OpBranch %13 + %13 = OpLabel + %36 = OpIAdd %6 %47 %22 + OpStore %8 %36 + OpBranch %10 + %12 = OpLabel + OpStore %37 %9 + OpBranch %38 + %38 = OpLabel + %48 = OpPhi %6 %9 %12 %46 %41 + OpLoopMerge %40 %41 None + OpBranch %42 + %42 = OpLabel + %44 = 
OpSLessThan %17 %48 %16 + OpBranchConditional %44 %39 %40 + %39 = OpLabel + OpBranch %41 + %41 = OpLabel + %46 = OpIAdd %6 %48 %22 + OpStore %37 %46 + OpBranch %38 + %40 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/fusion_illegal.cpp b/test/opt/loop_optimizations/fusion_illegal.cpp new file mode 100644 index 000000000..26d54457d --- /dev/null +++ b/test/opt/loop_optimizations/fusion_illegal.cpp @@ -0,0 +1,1592 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_fusion.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using FusionIllegalTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Illegal, loop-independent dependence will become a + // backward loop-carried antidependence + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i+1] + 2; + } +} + +*/ +TEST_F(FusionIllegalTest, PositiveDistanceCreatedRAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %53 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %53 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %53 %29 + OpStore %8 %33 + 
OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %54 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpIAdd %6 %54 %29 + %46 = OpAccessChain %7 %23 %45 + %47 = OpLoad %6 %46 + %49 = OpIAdd %6 %47 %48 + %50 = OpAccessChain %7 %42 %54 + OpStore %50 %49 + OpBranch %38 + %38 = OpLabel + %52 = OpIAdd %6 %54 %29 + OpStore %34 %52 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core + +int func() { + return 10; +} + +void main() { + int[10] a; + int[10] b; + // Illegal, function call + for (int i = 0; i < 10; i++) { + a[i] = func(); + } + for (int i = 0; i < 10; i++) { + b[i] = a[i]; + } +} +*/ +TEST_F(FusionIllegalTest, FunctionCall) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "func(" + OpName %14 "i" + OpName %28 "a" + OpName %35 "i" + OpName %43 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeFunction %6 + %10 = OpConstant %6 10 + %13 = OpTypePointer Function %6 + %15 = OpConstant %6 0 + %22 = OpTypeBool + %24 = OpTypeInt 32 0 
+ %25 = OpConstant %24 10 + %26 = OpTypeArray %6 %25 + %27 = OpTypePointer Function %26 + %33 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %14 = OpVariable %13 Function + %28 = OpVariable %27 Function + %35 = OpVariable %13 Function + %43 = OpVariable %27 Function + OpStore %14 %15 + OpBranch %16 + %16 = OpLabel + %51 = OpPhi %6 %15 %5 %34 %19 + OpLoopMerge %18 %19 None + OpBranch %20 + %20 = OpLabel + %23 = OpSLessThan %22 %51 %10 + OpBranchConditional %23 %17 %18 + %17 = OpLabel + %30 = OpFunctionCall %6 %8 + %31 = OpAccessChain %13 %28 %51 + OpStore %31 %30 + OpBranch %19 + %19 = OpLabel + %34 = OpIAdd %6 %51 %33 + OpStore %14 %34 + OpBranch %16 + %18 = OpLabel + OpStore %35 %15 + OpBranch %36 + %36 = OpLabel + %52 = OpPhi %6 %15 %18 %50 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %22 %52 %10 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpAccessChain %13 %28 %52 + %47 = OpLoad %6 %46 + %48 = OpAccessChain %13 %43 %52 + OpStore %48 %47 + OpBranch %39 + %39 = OpLabel + %50 = OpIAdd %6 %52 %33 + OpStore %35 %50 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %6 None %7 + %9 = OpLabel + OpReturnValue %10 + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 16 +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Illegal outer. 
+ for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[i][j] = c[i+1][j] + 10; + } + } +} + +*/ +TEST_F(FusionIllegalTest, PositiveDistanceCreatedRAWOuterLoop) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %48 "i" + OpName %56 "j" + OpName %64 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %48 = OpVariable %7 Function + %56 = OpVariable %7 Function + %64 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %78 = OpPhi %6 %9 %5 %47 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %78 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %82 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %82 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %78 %82 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %78 %82 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %82 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %47 = OpIAdd %6 %78 %44 + OpStore %8 %47 + 
OpBranch %10 + %12 = OpLabel + OpStore %48 %9 + OpBranch %49 + %49 = OpLabel + %79 = OpPhi %6 %9 %12 %77 %52 + OpLoopMerge %51 %52 None + OpBranch %53 + %53 = OpLabel + %55 = OpSLessThan %17 %79 %16 + OpBranchConditional %55 %50 %51 + %50 = OpLabel + OpStore %56 %9 + OpBranch %57 + %57 = OpLabel + %80 = OpPhi %6 %9 %50 %75 %60 + OpLoopMerge %59 %60 None + OpBranch %61 + %61 = OpLabel + %63 = OpSLessThan %17 %80 %16 + OpBranchConditional %63 %58 %59 + %58 = OpLabel + %68 = OpIAdd %6 %79 %44 + %70 = OpAccessChain %7 %32 %68 %80 + %71 = OpLoad %6 %70 + %72 = OpIAdd %6 %71 %16 + %73 = OpAccessChain %7 %64 %79 %80 + OpStore %73 %72 + OpBranch %60 + %60 = OpLabel + %75 = OpIAdd %6 %80 %44 + OpStore %56 %75 + OpBranch %57 + %59 = OpLabel + OpBranch %52 + %52 = OpLabel + %77 = OpIAdd %6 %79 %44 + OpStore %48 %77 + OpBranch %49 + %51 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + auto loop_3 = loops[3]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_2, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 19 +#version 440 core +void main() { + int[10] 
a; + int[10] b; + int[10] c; + // Illegal, would create a backward loop-carried anti-dependence. + for (int i = 0; i < 10; i++) { + c[i] = a[i] + 1; + } + for (int i = 0; i < 10; i++) { + a[i+1] = c[i] + 2; + } +} + +*/ +TEST_F(FusionIllegalTest, PositiveDistanceCreatedWAR) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "c" + OpName %25 "a" + OpName %34 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %47 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %52 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %52 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %52 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %52 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %52 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %53 = OpPhi %6 %9 %12 %51 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %53 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %43 = OpIAdd %6 %53 %29 + %45 = OpAccessChain %7 %23 %53 + %46 = OpLoad %6 %45 + %48 = OpIAdd %6 %46 %47 + %49 = OpAccessChain %7 %25 %43 + OpStore %49 %48 + OpBranch %38 + %38 = OpLabel + %51 = OpIAdd %6 %53 %29 + OpStore %34 %51 + OpBranch %35 + %37 = 
OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 21 +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Illegal, would create a backward loop-carried anti-dependence. + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + a[i+1] = c[i+1] + 2; + } +} + +*/ +TEST_F(FusionIllegalTest, PositiveDistanceCreatedWAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %44 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %49 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %44 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %54 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + 
%18 = OpSLessThan %17 %54 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %54 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %54 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %54 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %55 = OpPhi %6 %9 %12 %53 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %55 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %43 = OpIAdd %6 %55 %29 + %46 = OpIAdd %6 %55 %29 + %47 = OpAccessChain %7 %44 %46 + %48 = OpLoad %6 %47 + %50 = OpIAdd %6 %48 %49 + %51 = OpAccessChain %7 %23 %43 + OpStore %51 %50 + OpBranch %38 + %38 = OpLabel + %53 = OpIAdd %6 %55 %29 + OpStore %34 %53 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 28 +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + + // Illegal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_0 += b[j]; + } +} + +*/ +TEST_F(FusionIllegalTest, SameReductionVariable) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName 
%4 "main" + OpName %8 "sum_0" + OpName %10 "i" + OpName %24 "a" + OpName %33 "j" + OpName %41 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %33 = OpVariable %7 Function + %41 = OpVariable %23 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %52 = OpPhi %6 %9 %5 %29 %14 + %49 = OpPhi %6 %9 %5 %32 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %49 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %26 = OpAccessChain %7 %24 %49 + %27 = OpLoad %6 %26 + %29 = OpIAdd %6 %52 %27 + OpStore %8 %29 + OpBranch %14 + %14 = OpLabel + %32 = OpIAdd %6 %49 %31 + OpStore %10 %32 + OpBranch %11 + %13 = OpLabel + OpStore %33 %9 + OpBranch %34 + %34 = OpLabel + %51 = OpPhi %6 %52 %13 %46 %37 + %50 = OpPhi %6 %9 %13 %48 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %18 %50 %17 + OpBranchConditional %40 %35 %36 + %35 = OpLabel + %43 = OpAccessChain %7 %41 %50 + %44 = OpLoad %6 %43 + %46 = OpIAdd %6 %51 %44 + OpStore %8 %46 + OpBranch %37 + %37 = OpLabel + %48 = OpIAdd %6 %50 %31 + OpStore %33 %48 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + 
+ LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 28 +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + + // Illegal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_0 += b[j]; + } +} + +*/ +TEST_F(FusionIllegalTest, SameReductionVariableLCSSA) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "i" + OpName %24 "a" + OpName %33 "j" + OpName %41 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %33 = OpVariable %7 Function + %41 = OpVariable %23 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %52 = OpPhi %6 %9 %5 %29 %14 + %49 = OpPhi %6 %9 %5 %32 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %49 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %26 = OpAccessChain %7 %24 %49 + %27 = OpLoad %6 %26 + %29 = OpIAdd %6 %52 %27 + OpStore %8 %29 + OpBranch %14 + %14 = OpLabel + %32 = OpIAdd %6 %49 %31 + OpStore %10 %32 + OpBranch %11 + %13 = OpLabel + OpStore %33 %9 + OpBranch %34 + %34 = OpLabel + %51 = OpPhi %6 %52 %13 %46 %37 + %50 = OpPhi %6 %9 %13 %48 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %18 %50 %17 + OpBranchConditional %40 
%35 %36 + %35 = OpLabel + %43 = OpAccessChain %7 %41 %50 + %44 = OpLoad %6 %43 + %46 = OpIAdd %6 %51 %44 + OpStore %8 %46 + OpBranch %37 + %37 = OpLabel + %48 = OpIAdd %6 %50 %31 + OpStore %33 %48 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopUtils utils_0(context.get(), loops[0]); + utils_0.MakeLoopClosedSSA(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 30 +#version 440 core +int x; +void main() { + int[10] a; + int[10] b; + + // Illegal, x is unknown. 
+ for (int i = 0; i < 10; i++) { + a[x] = a[i]; + } + for (int j = 0; j < 10; j++) { + a[j] = b[j]; + } +} + +*/ +TEST_F(FusionIllegalTest, UnknownIndexVariable) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "x" + OpName %34 "j" + OpName %43 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %24 = OpTypePointer Private %6 + %25 = OpVariable %24 Private + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %50 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %50 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpLoad %6 %25 + %28 = OpAccessChain %7 %23 %50 + %29 = OpLoad %6 %28 + %30 = OpAccessChain %7 %23 %26 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %50 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %51 = OpPhi %6 %9 %12 %49 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %51 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %43 %51 + %46 = OpLoad %6 %45 + %47 = OpAccessChain %7 %23 %51 + OpStore %47 %46 + OpBranch %38 + %38 = OpLabel + %49 = OpIAdd %6 %51 %32 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, 
text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum = 0; + + // Illegal, accumulator used for indexing. + for (int i = 0; i < 10; i++) { + sum += a[i]; + b[sum] = a[i]; + } + for (int j = 0; j < 10; j++) { + b[j] = b[j]+1; + } +} + +*/ +TEST_F(FusionIllegalTest, AccumulatorIndexing) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum" + OpName %10 "i" + OpName %24 "a" + OpName %30 "b" + OpName %39 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %37 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %30 = OpVariable %23 Function + %39 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %57 = OpPhi %6 %9 %5 %29 %14 + %55 = OpPhi %6 %9 %5 %38 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %55 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %26 = OpAccessChain %7 %24 %55 + 
%27 = OpLoad %6 %26 + %29 = OpIAdd %6 %57 %27 + OpStore %8 %29 + %33 = OpAccessChain %7 %24 %55 + %34 = OpLoad %6 %33 + %35 = OpAccessChain %7 %30 %29 + OpStore %35 %34 + OpBranch %14 + %14 = OpLabel + %38 = OpIAdd %6 %55 %37 + OpStore %10 %38 + OpBranch %11 + %13 = OpLabel + OpStore %39 %9 + OpBranch %40 + %40 = OpLabel + %56 = OpPhi %6 %9 %13 %54 %43 + OpLoopMerge %42 %43 None + OpBranch %44 + %44 = OpLabel + %46 = OpSLessThan %18 %56 %17 + OpBranchConditional %46 %41 %42 + %41 = OpLabel + %49 = OpAccessChain %7 %30 %56 + %50 = OpLoad %6 %49 + %51 = OpIAdd %6 %50 %37 + %52 = OpAccessChain %7 %30 %56 + OpStore %52 %51 + OpBranch %43 + %43 = OpLabel + %54 = OpIAdd %6 %56 %37 + OpStore %39 %54 + OpBranch %40 + %42 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 33 +#version 440 core +void main() { + int[10] a; + int[10] b; + + // Illegal, barrier. 
+ for (int i = 0; i < 10; i++) { + a[i] = a[i] * 2; + memoryBarrier(); + } + for (int j = 0; j < 10; j++) { + b[j] = b[j] + 1; + } +} + +*/ +TEST_F(FusionIllegalTest, Barrier) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %36 "j" + OpName %44 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %31 = OpConstant %19 1 + %32 = OpConstant %19 3400 + %34 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %36 = OpVariable %7 Function + %44 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %35 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %53 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %53 + OpStore %30 %29 + OpMemoryBarrier %31 %32 + OpBranch %13 + %13 = OpLabel + %35 = OpIAdd %6 %53 %34 + OpStore %8 %35 + OpBranch %10 + %12 = OpLabel + OpStore %36 %9 + OpBranch %37 + %37 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %40 + OpLoopMerge %39 %40 None + OpBranch %41 + %41 = OpLabel + %43 = OpSLessThan %17 %54 %16 + OpBranchConditional %43 %38 %39 + %38 = OpLabel + %47 = OpAccessChain %7 %44 %54 + %48 = OpLoad %6 %47 + %49 = OpIAdd %6 %48 %34 + %50 = OpAccessChain %7 %44 %54 + OpStore %50 %49 + OpBranch %40 + %40 = OpLabel + %52 = OpIAdd %6 %54 %34 + OpStore %36 %52 + OpBranch %37 + %39 = OpLabel + OpReturn + OpFunctionEnd + )"; + + 
std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +struct TestStruct { + int[10] a; + int b; +}; + +void main() { + TestStruct test_0; + TestStruct test_1; + + for (int i = 0; i < 10; i++) { + test_0.a[i] = i; + } + for (int j = 0; j < 10; j++) { + test_0 = test_1; + } +} + +*/ +TEST_F(FusionIllegalTest, ArrayInStruct) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "TestStruct" + OpMemberName %22 0 "a" + OpMemberName %22 1 "b" + OpName %24 "test_0" + OpName %31 "j" + OpName %39 "test_1" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypeStruct %21 %6 + %23 = OpTypePointer Function %22 + %29 = OpConstant %6 1 + %47 = OpUndef %22 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %24 = OpVariable %23 Function + %31 = OpVariable %7 Function + %39 = OpVariable %23 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %43 = OpPhi %6 %9 %5 %30 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 
%43 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %24 %9 %43 + OpStore %27 %43 + OpBranch %13 + %13 = OpLabel + %30 = OpIAdd %6 %43 %29 + OpStore %8 %30 + OpBranch %10 + %12 = OpLabel + OpStore %31 %9 + OpBranch %32 + %32 = OpLabel + %44 = OpPhi %6 %9 %12 %42 %35 + OpLoopMerge %34 %35 None + OpBranch %36 + %36 = OpLabel + %38 = OpSLessThan %17 %44 %16 + OpBranchConditional %38 %33 %34 + %33 = OpLabel + OpStore %24 %47 + OpBranch %35 + %35 = OpLabel + %42 = OpIAdd %6 %44 %29 + OpStore %31 %42 + OpBranch %32 + %34 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 450 + +struct P {float x,y,z;}; +uniform G { int a; P b[2]; int c; } g; +layout(location = 0) out float o; + +void main() +{ + P p[2]; + for (int i = 0; i < 2; ++i) { + p = g.b; + } + for (int j = 0; j < 2; ++j) { + o = p[g.a].x; + } +} + +*/ +TEST_F(FusionIllegalTest, NestedAccessChain) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %64 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 450 + OpName %4 "main" + OpName %8 "i" + OpName %20 "P" + OpMemberName %20 0 "x" + OpMemberName %20 1 "y" + OpMemberName %20 2 "z" + OpName %25 "p" + OpName %26 "P" + OpMemberName %26 0 "x" + OpMemberName %26 1 "y" + OpMemberName %26 2 "z" + 
OpName %28 "G" + OpMemberName %28 0 "a" + OpMemberName %28 1 "b" + OpMemberName %28 2 "c" + OpName %30 "g" + OpName %55 "j" + OpName %64 "o" + OpMemberDecorate %26 0 Offset 0 + OpMemberDecorate %26 1 Offset 4 + OpMemberDecorate %26 2 Offset 8 + OpDecorate %27 ArrayStride 16 + OpMemberDecorate %28 0 Offset 0 + OpMemberDecorate %28 1 Offset 16 + OpMemberDecorate %28 2 Offset 48 + OpDecorate %28 Block + OpDecorate %30 DescriptorSet 0 + OpDecorate %64 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 2 + %17 = OpTypeBool + %19 = OpTypeFloat 32 + %20 = OpTypeStruct %19 %19 %19 + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 2 + %23 = OpTypeArray %20 %22 + %24 = OpTypePointer Function %23 + %26 = OpTypeStruct %19 %19 %19 + %27 = OpTypeArray %26 %22 + %28 = OpTypeStruct %6 %27 %6 + %29 = OpTypePointer Uniform %28 + %30 = OpVariable %29 Uniform + %31 = OpConstant %6 1 + %32 = OpTypePointer Uniform %27 + %36 = OpTypePointer Function %20 + %39 = OpTypePointer Function %19 + %63 = OpTypePointer Output %19 + %64 = OpVariable %63 Output + %65 = OpTypePointer Uniform %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %25 = OpVariable %24 Function + %55 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %72 = OpPhi %6 %9 %5 %54 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %72 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %33 = OpAccessChain %32 %30 %31 + %34 = OpLoad %27 %33 + %35 = OpCompositeExtract %26 %34 0 + %37 = OpAccessChain %36 %25 %9 + %38 = OpCompositeExtract %19 %35 0 + %40 = OpAccessChain %39 %37 %9 + OpStore %40 %38 + %41 = OpCompositeExtract %19 %35 1 + %42 = OpAccessChain %39 %37 %31 + OpStore %42 %41 + %43 = OpCompositeExtract %19 %35 2 + %44 = OpAccessChain %39 %37 %16 + OpStore %44 %43 + %45 = OpCompositeExtract %26 %34 1 + %46 = OpAccessChain %36 %25 %31 + %47 
= OpCompositeExtract %19 %45 0 + %48 = OpAccessChain %39 %46 %9 + OpStore %48 %47 + %49 = OpCompositeExtract %19 %45 1 + %50 = OpAccessChain %39 %46 %31 + OpStore %50 %49 + %51 = OpCompositeExtract %19 %45 2 + %52 = OpAccessChain %39 %46 %16 + OpStore %52 %51 + OpBranch %13 + %13 = OpLabel + %54 = OpIAdd %6 %72 %31 + OpStore %8 %54 + OpBranch %10 + %12 = OpLabel + OpStore %55 %9 + OpBranch %56 + %56 = OpLabel + %73 = OpPhi %6 %9 %12 %71 %59 + OpLoopMerge %58 %59 None + OpBranch %60 + %60 = OpLabel + %62 = OpSLessThan %17 %73 %16 + OpBranchConditional %62 %57 %58 + %57 = OpLabel + %66 = OpAccessChain %65 %30 %9 + %67 = OpLoad %6 %66 + %68 = OpAccessChain %39 %25 %67 %9 + %69 = OpLoad %19 %68 + OpStore %64 %69 + OpBranch %59 + %59 = OpLabel + %71 = OpIAdd %6 %73 %31 + OpStore %55 %71 + OpBranch %56 + %58 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_FALSE(fusion.IsLegal()); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/fusion_legal.cpp b/test/opt/loop_optimizations/fusion_legal.cpp new file mode 100644 index 000000000..41d796fd9 --- /dev/null +++ b/test/opt/loop_optimizations/fusion_legal.cpp @@ -0,0 +1,4578 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_fusion.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using FusionLegalTest = PassTest<::testing::Test>; + +bool Validate(const std::vector& bin) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {bin.data(), bin.size()}; + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + return error == 0; +} + +void Match(const std::string& checks, IRContext* context) { + // Silence unused warnings with !defined(SPIRV_EFFCE) + (void)checks; + + std::vector bin; + context->module()->ToBinary(&bin, true); + EXPECT_TRUE(Validate(bin)); + std::string assembly; + SpirvTools tools(SPV_ENV_UNIVERSAL_1_2); + EXPECT_TRUE( + tools.Disassemble(bin, &assembly, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER)) + << "Disassembling failed for shader:\n" + << assembly << std::endl; + auto match_result = effcee::Match(assembly, checks); + EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) + << match_result.message() << "\nChecking result:\n" + << assembly; +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + 
// No dependence, legal + for (int i = 0; i < 10; i++) { + a[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + b[i] = b[i]+2; + } +} + +*/ +TEST_F(FusionLegalTest, DifferentArraysInLoops) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %34 "i" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %51 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %51 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %42 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %28 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, 
nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Only loads to the same array, legal + for (int i = 0; i < 10; i++) { + b[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i]+2; + } +} + +*/ +TEST_F(FusionLegalTest, OnlyLoadsToSameArray) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "b" + OpName %25 "a" + OpName %35 "i" + OpName %43 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 
= OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 2 + %33 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %35 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %52 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %52 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %52 + %28 = OpLoad %6 %27 + %30 = OpIMul %6 %28 %29 + %31 = OpAccessChain %7 %23 %52 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %52 %33 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %53 = OpPhi %6 %9 %12 %51 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %53 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpAccessChain %7 %25 %53 + %47 = OpLoad %6 %46 + %48 = OpIAdd %6 %47 %29 + %49 = OpAccessChain %7 %43 %53 + OpStore %49 %48 + OpBranch %39 + %39 = OpLabel + %51 = OpIAdd %6 %53 %33 + OpStore %35 %51 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = 
OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + a[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + b[i] = a[i]+2; + } +} + +*/ +TEST_F(FusionLegalTest, NoLoopCarriedDependences) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %34 "i" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %51 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %51 
+ OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %28 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // 
Parallelism inhibiting, but legal. + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i] + c[i-1]; + } +} + +*/ +TEST_F(FusionLegalTest, ExistingLoopCarriedDependence) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %55 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %55 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %55 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %55 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %55 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %56 = OpPhi %6 %9 %12 %54 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %56 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %56 + %46 = OpLoad %6 %45 + %48 = OpISub %6 %56 %29 + %49 = OpAccessChain %7 %42 %48 + %50 = OpLoad %6 %49 + %51 = OpIAdd %6 %46 %50 + %52 = OpAccessChain %7 %42 %56 + OpStore %52 %51 + OpBranch %38 + %38 = OpLabel + %54 = OpIAdd %6 %56 %29 + OpStore %34 %54 + 
OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[I_1:%\w+]] = OpISub {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Creates a loop-carried dependence, but negative, so legal + for (int i = 0; i < 10; i++) { + a[i+1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NegativeDistanceCreatedRAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" 
+ OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %27 "b" + OpName %35 "i" + OpName %43 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %27 = OpVariable %22 Function + %35 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %53 %25 + %29 = OpAccessChain %7 %27 %53 + %30 = OpLoad %6 %29 + %31 = OpIAdd %6 %30 %25 + %32 = OpAccessChain %7 %23 %26 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %53 %25 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %54 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpAccessChain %7 %23 %54 + %47 = OpLoad %6 %46 + %49 = OpIAdd %6 %47 %48 + %50 = OpAccessChain %7 %43 %54 + OpStore %50 %49 + OpBranch %39 + %39 = OpLabel + %52 = OpIAdd %6 %54 %25 + OpStore %35 %52 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = 
*context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + auto& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal + for (int i = 0; i < 10; i++) { + a[i+1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i+1] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NoLoopCarriedDependencesAdjustedIndex) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %27 "b" + OpName %35 "i" + OpName %43 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 1 + %49 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel 
+ %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %27 = OpVariable %22 Function + %35 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %54 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %54 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %54 %25 + %29 = OpAccessChain %7 %27 %54 + %30 = OpLoad %6 %29 + %31 = OpIAdd %6 %30 %25 + %32 = OpAccessChain %7 %23 %26 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %54 %25 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %55 = OpPhi %6 %9 %12 %53 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %55 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpIAdd %6 %55 %25 + %47 = OpAccessChain %7 %23 %46 + %48 = OpLoad %6 %47 + %50 = OpIAdd %6 %48 %49 + %51 = OpAccessChain %7 %43 %55 + OpStore %51 %50 + OpBranch %39 + %39 = OpLabel + %53 = OpIAdd %6 %55 %25 + OpStore %35 %53 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: 
[[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, independent locations in |a|, SIV + for (int i = 0; i < 10; i++) { + a[2*i+1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[2*i] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, IndependentSIV) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %29 "b" + OpName %37 "i" + OpName %45 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %24 = OpConstant %6 2 + %27 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %29 = OpVariable %22 Function + %37 = OpVariable %7 Function + %45 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %55 = OpPhi %6 %9 %5 %36 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %55 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIMul %6 %24 %55 + %28 = OpIAdd %6 %26 %27 + %31 = 
OpAccessChain %7 %29 %55 + %32 = OpLoad %6 %31 + %33 = OpIAdd %6 %32 %27 + %34 = OpAccessChain %7 %23 %28 + OpStore %34 %33 + OpBranch %13 + %13 = OpLabel + %36 = OpIAdd %6 %55 %27 + OpStore %8 %36 + OpBranch %10 + %12 = OpLabel + OpStore %37 %9 + OpBranch %38 + %38 = OpLabel + %56 = OpPhi %6 %9 %12 %54 %41 + OpLoopMerge %40 %41 None + OpBranch %42 + %42 = OpLabel + %44 = OpSLessThan %17 %56 %16 + OpBranchConditional %44 %39 %40 + %39 = OpLabel + %48 = OpIMul %6 %24 %56 + %49 = OpAccessChain %7 %23 %48 + %50 = OpLoad %6 %49 + %51 = OpIAdd %6 %50 %24 + %52 = OpAccessChain %7 %45 %56 + OpStore %52 %51 + OpBranch %41 + %41 = OpLabel + %54 = OpIAdd %6 %56 %27 + OpStore %37 %54 + OpBranch %38 + %40 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[I_2:%\w+]] = OpIMul {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[I_2_1:%\w+]] = OpIAdd {{%\w+}} [[I_2]] {{%\w+}} +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_2_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[I_2:%\w+]] = OpIMul {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_2]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] 
+CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, independent locations in |a|, ZIV + for (int i = 0; i < 10; i++) { + a[1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[9] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, IndependentZIV) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %33 "i" + OpName %41 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %24 = OpConstant %6 1 + %43 = OpConstant %6 9 + %46 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %33 = OpVariable %7 Function + %41 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %32 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %51 + %28 = OpLoad %6 %27 + %29 = OpIAdd %6 %28 %24 + %30 = OpAccessChain %7 %23 %24 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %32 = OpIAdd %6 %51 %24 + OpStore %8 %32 + OpBranch %10 + %12 = OpLabel + OpStore %33 %9 + OpBranch %34 + %34 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %17 
%52 %16 + OpBranchConditional %40 %35 %36 + %35 = OpLabel + %44 = OpAccessChain %7 %23 %43 + %45 = OpLoad %6 %44 + %47 = OpIAdd %6 %45 %46 + %48 = OpAccessChain %7 %41 %52 + OpStore %48 %47 + OpBranch %37 + %37 = OpLabel + %50 = OpIAdd %6 %52 %24 + OpStore %33 %50 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NOT: OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK: OpStore +CHECK-NOT: OpPhi +CHECK-NOT: OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK: OpLoad +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[20] a; + int[10] b; + int[10] c; + // Legal, non-overlapping sections in |a| + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i+10] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NonOverlappingAccesses) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 
"main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %28 "b" + OpName %37 "i" + OpName %45 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 20 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %19 10 + %26 = OpTypeArray %6 %25 + %27 = OpTypePointer Function %26 + %32 = OpConstant %6 1 + %51 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %28 = OpVariable %27 Function + %37 = OpVariable %7 Function + %45 = OpVariable %27 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %56 = OpPhi %6 %9 %5 %36 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %56 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %30 = OpAccessChain %7 %28 %56 + %31 = OpLoad %6 %30 + %33 = OpIAdd %6 %31 %32 + %34 = OpAccessChain %7 %23 %56 + OpStore %34 %33 + OpBranch %13 + %13 = OpLabel + %36 = OpIAdd %6 %56 %32 + OpStore %8 %36 + OpBranch %10 + %12 = OpLabel + OpStore %37 %9 + OpBranch %38 + %38 = OpLabel + %57 = OpPhi %6 %9 %12 %55 %41 + OpLoopMerge %40 %41 None + OpBranch %42 + %42 = OpLabel + %44 = OpSLessThan %17 %57 %16 + OpBranchConditional %44 %39 %40 + %39 = OpLabel + %48 = OpIAdd %6 %57 %16 + %49 = OpAccessChain %7 %23 %48 + %50 = OpLoad %6 %49 + %52 = OpIAdd %6 %50 %51 + %53 = OpAccessChain %7 %45 %57 + OpStore %53 %52 + OpBranch %41 + %41 = OpLabel + %55 = OpIAdd %6 %57 %32 + OpStore %37 %55 + OpBranch %38 + %40 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text 
<< std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NOT: OpPhi +CHECK: [[I_10:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_10]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +)"; + + Match(checks, context.get()); + + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, 3 adjacent loops + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i] + 2; + } + for (int i = 0; i < 10; i++) { + b[i] = c[i] + 10; + } +} + +*/ +TEST_F(FusionLegalTest, AdjacentLoops) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + OpName %52 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = 
OpTypePointer Function %21 + %29 = OpConstant %6 1 + %47 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + %52 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %68 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %68 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %68 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %68 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %68 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %69 = OpPhi %6 %9 %12 %51 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %69 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %69 + %46 = OpLoad %6 %45 + %48 = OpIAdd %6 %46 %47 + %49 = OpAccessChain %7 %42 %69 + OpStore %49 %48 + OpBranch %38 + %38 = OpLabel + %51 = OpIAdd %6 %69 %29 + OpStore %34 %51 + OpBranch %35 + %37 = OpLabel + OpStore %52 %9 + OpBranch %53 + %53 = OpLabel + %70 = OpPhi %6 %9 %37 %67 %56 + OpLoopMerge %55 %56 None + OpBranch %57 + %57 = OpLabel + %59 = OpSLessThan %17 %70 %16 + OpBranchConditional %59 %54 %55 + %54 = OpLabel + %62 = OpAccessChain %7 %42 %70 + %63 = OpLoad %6 %62 + %64 = OpIAdd %6 %63 %16 + %65 = OpAccessChain %7 %25 %70 + OpStore %65 %64 + OpBranch %56 + %56 = OpLabel + %67 = OpIAdd %6 %70 %29 + OpStore %52 %67 + OpBranch %53 + %55 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = 
*module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[1], loops[2]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_2]] + )"; + + Match(checks, context.get()); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + std::string checks_ = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +CHECK-NOT: OpPhi +CHECK: 
[[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_2]] + )"; + + Match(checks_, context.get()); + + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal inner loop fusion + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + for (int j = 0; j < 10; j++) { + b[i][j] = c[i][j] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, InnerLoopFusion) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %46 "j" + OpName %54 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %46 = OpVariable %7 Function + %54 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %67 = OpPhi %6 %9 %5 %66 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %67 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %68 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = 
OpLabel + %26 = OpSLessThan %17 %68 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %67 %68 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %67 %68 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %68 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpStore %46 %9 + OpBranch %47 + %47 = OpLabel + %69 = OpPhi %6 %9 %22 %64 %50 + OpLoopMerge %49 %50 None + OpBranch %51 + %51 = OpLabel + %53 = OpSLessThan %17 %69 %16 + OpBranchConditional %53 %48 %49 + %48 = OpLabel + %59 = OpAccessChain %7 %32 %67 %69 + %60 = OpLoad %6 %59 + %61 = OpIAdd %6 %60 %16 + %62 = OpAccessChain %7 %54 %67 %69 + OpStore %62 %61 + OpBranch %50 + %50 = OpLabel + %64 = OpIAdd %6 %69 %44 + OpStore %46 %64 + OpBranch %47 + %49 = OpLabel + OpBranch %13 + %13 = OpLabel + %66 = OpIAdd %6 %67 %44 + OpStore %8 %66 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain 
{{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + + auto& ld_final = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld_final.NumLoops(), 2u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// 12 +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal both + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[i][j] = c[i][j] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, OuterAndInnerLoop) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %48 "i" + OpName %56 "j" + OpName %64 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %48 = OpVariable %7 Function + %56 = OpVariable %7 Function + %64 = OpVariable %31 Function + OpStore %8 %9 + OpBranch 
%10 + %10 = OpLabel + %77 = OpPhi %6 %9 %5 %47 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %77 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %81 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %81 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %77 %81 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %77 %81 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %81 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %47 = OpIAdd %6 %77 %44 + OpStore %8 %47 + OpBranch %10 + %12 = OpLabel + OpStore %48 %9 + OpBranch %49 + %49 = OpLabel + %78 = OpPhi %6 %9 %12 %76 %52 + OpLoopMerge %51 %52 None + OpBranch %53 + %53 = OpLabel + %55 = OpSLessThan %17 %78 %16 + OpBranchConditional %55 %50 %51 + %50 = OpLabel + OpStore %56 %9 + OpBranch %57 + %57 = OpLabel + %79 = OpPhi %6 %9 %50 %74 %60 + OpLoopMerge %59 %60 None + OpBranch %61 + %61 = OpLabel + %63 = OpSLessThan %17 %79 %16 + OpBranchConditional %63 %58 %59 + %58 = OpLabel + %69 = OpAccessChain %7 %32 %78 %79 + %70 = OpLoad %6 %69 + %71 = OpIAdd %6 %70 %16 + %72 = OpAccessChain %7 %64 %78 %79 + OpStore %72 %71 + OpBranch %60 + %60 = OpLabel + %74 = OpIAdd %6 %79 %44 + OpStore %56 %74 + OpBranch %57 + %59 = OpLabel + OpBranch %52 + %52 = OpLabel + %76 = OpIAdd %6 %78 %44 + OpStore %48 %76 + OpBranch %49 + %51 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = 
ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + auto loop_3 = loops[3]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_2, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_2:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + auto& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + 
std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + auto& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal both, more complex + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (i % 2 == 0 && j % 2 == 0) { + c[i][j] = a[i][j] + 2; + } + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[i][j] = c[i][j] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, OuterAndInnerLoopMoreComplex) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %44 "c" + OpName %47 "a" + OpName %59 "i" + OpName %67 "j" + OpName %75 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %28 = OpConstant %6 2 + %39 = OpTypeInt 32 0 + %40 = OpConstant %39 10 + %41 = OpTypeArray %6 %40 + %42 = OpTypeArray %41 %40 + %43 = OpTypePointer Function %42 + %55 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = 
OpVariable %7 Function + %19 = OpVariable %7 Function + %44 = OpVariable %43 Function + %47 = OpVariable %43 Function + %59 = OpVariable %7 Function + %67 = OpVariable %7 Function + %75 = OpVariable %43 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %88 = OpPhi %6 %9 %5 %58 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %88 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %92 = OpPhi %6 %9 %11 %56 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %92 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %29 = OpSMod %6 %88 %28 + %30 = OpIEqual %17 %29 %9 + OpSelectionMerge %32 None + OpBranchConditional %30 %31 %32 + %31 = OpLabel + %34 = OpSMod %6 %92 %28 + %35 = OpIEqual %17 %34 %9 + OpBranch %32 + %32 = OpLabel + %36 = OpPhi %17 %30 %21 %35 %31 + OpSelectionMerge %38 None + OpBranchConditional %36 %37 %38 + %37 = OpLabel + %50 = OpAccessChain %7 %47 %88 %92 + %51 = OpLoad %6 %50 + %52 = OpIAdd %6 %51 %28 + %53 = OpAccessChain %7 %44 %88 %92 + OpStore %53 %52 + OpBranch %38 + %38 = OpLabel + OpBranch %23 + %23 = OpLabel + %56 = OpIAdd %6 %92 %55 + OpStore %19 %56 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %58 = OpIAdd %6 %88 %55 + OpStore %8 %58 + OpBranch %10 + %12 = OpLabel + OpStore %59 %9 + OpBranch %60 + %60 = OpLabel + %89 = OpPhi %6 %9 %12 %87 %63 + OpLoopMerge %62 %63 None + OpBranch %64 + %64 = OpLabel + %66 = OpSLessThan %17 %89 %16 + OpBranchConditional %66 %61 %62 + %61 = OpLabel + OpStore %67 %9 + OpBranch %68 + %68 = OpLabel + %90 = OpPhi %6 %9 %61 %85 %71 + OpLoopMerge %70 %71 None + OpBranch %72 + %72 = OpLabel + %74 = OpSLessThan %17 %90 %16 + OpBranchConditional %74 %69 %70 + %69 = OpLabel + %80 = OpAccessChain %7 %44 %89 %90 + %81 = OpLoad %6 %80 + %82 = OpIAdd %6 %81 %16 + %83 = OpAccessChain %7 %75 %89 %90 + OpStore %83 %82 + OpBranch %71 + %71 = OpLabel + %85 = OpIAdd %6 %90 %55 + 
OpStore %67 %85 + OpBranch %68 + %70 = OpLabel + OpBranch %63 + %63 = OpLabel + %87 = OpIAdd %6 %89 %55 + OpStore %59 %87 + OpBranch %60 + %62 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + auto loop_3 = loops[3]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_2, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_3); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: OpPhi +CHECK-NEXT: OpSelectionMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_2:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpStore 
[[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: OpPhi +CHECK-NEXT: OpSelectionMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Outer would have been illegal to fuse, but since written + // like this, inner loop fusion is legal. 
+ for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + for (int j = 0; j < 10; j++) { + b[i][j] = c[i+1][j] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, InnerWithExistingDependenceOnOuter) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %46 "j" + OpName %54 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %46 = OpVariable %7 Function + %54 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %68 = OpPhi %6 %9 %5 %67 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %68 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %69 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %69 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %68 %69 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %68 %69 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %69 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpStore %46 %9 + OpBranch %47 + %47 = OpLabel + %70 = OpPhi %6 %9 %22 %65 %50 + OpLoopMerge %49 %50 None + OpBranch %51 + %51 = OpLabel + %53 = OpSLessThan %17 %70 
%16 + OpBranchConditional %53 %48 %49 + %48 = OpLabel + %58 = OpIAdd %6 %68 %44 + %60 = OpAccessChain %7 %32 %58 %70 + %61 = OpLoad %6 %60 + %62 = OpIAdd %6 %61 %16 + %63 = OpAccessChain %7 %54 %68 %70 + OpStore %63 %62 + OpBranch %50 + %50 = OpLabel + %65 = OpIAdd %6 %70 %44 + OpStore %46 %65 + OpBranch %47 + %49 = OpLabel + OpBranch %13 + %13 = OpLabel + %67 = OpIAdd %6 %68 %44 + OpStore %8 %67 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI_0]] {{%\w+}} +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} 
[[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // One dimensional arrays. Legal, outer dist 0, inner independent. + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i] = a[j] + 2; + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[j] = c[i] + 10; + } + } +} + +*/ +TEST_F(FusionLegalTest, OuterAndInnerLoopOneDimArrays) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %31 "c" + OpName %33 "a" + OpName %45 "i" + OpName %53 "j" + OpName %61 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypePointer Function %29 + %37 = OpConstant %6 2 + %41 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %31 = OpVariable %30 Function + %33 = OpVariable %30 Function + %45 = OpVariable %7 Function + %53 = OpVariable %7 Function + %61 = OpVariable %30 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %72 = OpPhi %6 %9 %5 %44 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %72 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %76 = OpPhi %6 %9 %11 %42 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %76 %16 + OpBranchConditional %26 %21 %22 + 
%21 = OpLabel + %35 = OpAccessChain %7 %33 %76 + %36 = OpLoad %6 %35 + %38 = OpIAdd %6 %36 %37 + %39 = OpAccessChain %7 %31 %72 + OpStore %39 %38 + OpBranch %23 + %23 = OpLabel + %42 = OpIAdd %6 %76 %41 + OpStore %19 %42 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %44 = OpIAdd %6 %72 %41 + OpStore %8 %44 + OpBranch %10 + %12 = OpLabel + OpStore %45 %9 + OpBranch %46 + %46 = OpLabel + %73 = OpPhi %6 %9 %12 %71 %49 + OpLoopMerge %48 %49 None + OpBranch %50 + %50 = OpLabel + %52 = OpSLessThan %17 %73 %16 + OpBranchConditional %52 %47 %48 + %47 = OpLabel + OpStore %53 %9 + OpBranch %54 + %54 = OpLabel + %74 = OpPhi %6 %9 %47 %69 %57 + OpLoopMerge %56 %57 None + OpBranch %58 + %58 = OpLabel + %60 = OpSLessThan %17 %74 %16 + OpBranchConditional %60 %55 %56 + %55 = OpLabel + %64 = OpAccessChain %7 %31 %73 + %65 = OpLoad %6 %64 + %66 = OpIAdd %6 %65 %16 + %67 = OpAccessChain %7 %61 %74 + OpStore %67 %66 + OpBranch %57 + %57 = OpLabel + %69 = OpIAdd %6 %74 %41 + OpStore %53 %69 + OpBranch %54 + %56 = OpLabel + OpBranch %49 + %49 = OpLabel + %71 = OpIAdd %6 %73 %41 + OpStore %45 %71 + OpBranch %46 + %48 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + auto loop_3 = loops[3]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_2, loop_3); + 
EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_2:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_2]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + auto loop_0 = loops[0]; + auto loop_1 = loops[1]; + auto loop_2 = loops[2]; + + { + LoopFusion fusion(context.get(), loop_0, loop_1); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_0, loop_2); + EXPECT_FALSE(fusion.AreCompatible()); + } + + { + LoopFusion fusion(context.get(), loop_1, loop_2); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] 
+CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, creates a loop-carried dependence, but has negative distance + for (int i = 0; i < 10; i++) { + c[i] = a[i+1] + 1; + } + for (int i = 0; i < 10; i++) { + a[i] = c[i] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NegativeDistanceCreatedWAR) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "c" + OpName %25 "a" + OpName %35 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %27 = OpConstant %6 1 + %47 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %35 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %52 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %52 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %28 = OpIAdd %6 %52 %27 + %29 = OpAccessChain %7 %25 %28 + %30 = OpLoad %6 %29 + %31 = OpIAdd %6 %30 %27 + %32 = OpAccessChain %7 %23 %52 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %52 %27 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %53 = OpPhi %6 %9 %12 %51 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %53 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %45 = OpAccessChain %7 %23 %53 
+ %46 = OpLoad %6 %45 + %48 = OpIAdd %6 %46 %47 + %49 = OpAccessChain %7 %25 %53 + OpStore %49 %48 + OpBranch %39 + %39 = OpLabel + %51 = OpIAdd %6 %53 %27 + OpStore %35 %51 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } + + { + auto& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, creates a loop-carried dependence, but has negative distance + for (int i = 0; i < 10; i++) { + a[i+1] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + a[i] = c[i+1] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NegativeDistanceCreatedWAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + 
OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %27 "b" + OpName %35 "i" + OpName %44 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 1 + %49 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %27 = OpVariable %22 Function + %35 = OpVariable %7 Function + %44 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %54 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %54 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %54 %25 + %29 = OpAccessChain %7 %27 %54 + %30 = OpLoad %6 %29 + %31 = OpIAdd %6 %30 %25 + %32 = OpAccessChain %7 %23 %26 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %54 %25 + OpStore %8 %34 + OpBranch %10 + %12 = OpLabel + OpStore %35 %9 + OpBranch %36 + %36 = OpLabel + %55 = OpPhi %6 %9 %12 %53 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %17 %55 %16 + OpBranchConditional %42 %37 %38 + %37 = OpLabel + %46 = OpIAdd %6 %55 %25 + %47 = OpAccessChain %7 %44 %46 + %48 = OpLoad %6 %47 + %50 = OpIAdd %6 %48 %49 + %51 = OpAccessChain %7 %23 %55 + OpStore %51 %50 + OpBranch %39 + %39 = OpLabel + %53 = OpIAdd %6 %55 %25 + OpStore %35 %53 + OpBranch %36 + %38 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << 
std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpStore +CHECK-NOT: OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Legal, no loop-carried dependence + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + a[i] = c[i+1] + 2; + } +} + +*/ +TEST_F(FusionLegalTest, NoLoopCarriedDependencesWAW) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %43 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 
= OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %43 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %53 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %53 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %53 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %54 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpIAdd %6 %54 %29 + %46 = OpAccessChain %7 %43 %45 + %47 = OpLoad %6 %46 + %49 = OpIAdd %6 %47 %48 + %50 = OpAccessChain %7 %23 %54 + OpStore %50 %49 + OpBranch %38 + %38 = OpLabel + %52 = OpIAdd %6 %54 %29 + OpStore %34 %52 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( 
+CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[I_1:%\w+]] = OpIAdd {{%\w+}} [[PHI]] {{%\w+}} +CHECK-NEXT: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[I_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal outer. Continue and break are fine if nested in inner loops + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (j % 2 == 0) { + c[i][j] = a[i][j] + 2; + } else { + continue; + } + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (j % 2 == 0) { + b[i][j] = c[i][j] + 10; + } else { + break; + } + } + } +} + +*/ +TEST_F(FusionLegalTest, OuterloopWithBreakContinueInInner) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %38 "c" + OpName %41 "a" + OpName %55 "i" + OpName %63 "j" + OpName %76 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %28 = OpConstant %6 2 + %33 = OpTypeInt 32 0 + %34 = OpConstant %33 10 + %35 = OpTypeArray %6 %34 + %36 = OpTypeArray %35 %34 + %37 = OpTypePointer Function %36 + %51 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %38 = 
OpVariable %37 Function + %41 = OpVariable %37 Function + %55 = OpVariable %7 Function + %63 = OpVariable %7 Function + %76 = OpVariable %37 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %91 = OpPhi %6 %9 %5 %54 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %91 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %96 = OpPhi %6 %9 %11 %52 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %96 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %29 = OpSMod %6 %96 %28 + %30 = OpIEqual %17 %29 %9 + OpSelectionMerge %23 None + OpBranchConditional %30 %31 %48 + %31 = OpLabel + %44 = OpAccessChain %7 %41 %91 %96 + %45 = OpLoad %6 %44 + %46 = OpIAdd %6 %45 %28 + %47 = OpAccessChain %7 %38 %91 %96 + OpStore %47 %46 + OpBranch %32 + %48 = OpLabel + OpBranch %23 + %32 = OpLabel + OpBranch %23 + %23 = OpLabel + %52 = OpIAdd %6 %96 %51 + OpStore %19 %52 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %54 = OpIAdd %6 %91 %51 + OpStore %8 %54 + OpBranch %10 + %12 = OpLabel + OpStore %55 %9 + OpBranch %56 + %56 = OpLabel + %92 = OpPhi %6 %9 %12 %90 %59 + OpLoopMerge %58 %59 None + OpBranch %60 + %60 = OpLabel + %62 = OpSLessThan %17 %92 %16 + OpBranchConditional %62 %57 %58 + %57 = OpLabel + OpStore %63 %9 + OpBranch %64 + %64 = OpLabel + %93 = OpPhi %6 %9 %57 %88 %67 + OpLoopMerge %66 %67 None + OpBranch %68 + %68 = OpLabel + %70 = OpSLessThan %17 %93 %16 + OpBranchConditional %70 %65 %66 + %65 = OpLabel + %72 = OpSMod %6 %93 %28 + %73 = OpIEqual %17 %72 %9 + OpSelectionMerge %75 None + OpBranchConditional %73 %74 %66 + %74 = OpLabel + %81 = OpAccessChain %7 %38 %92 %93 + %82 = OpLoad %6 %81 + %83 = OpIAdd %6 %82 %16 + %84 = OpAccessChain %7 %76 %92 %93 + OpStore %84 %83 + OpBranch %75 + %75 = OpLabel + OpBranch %67 + %67 = OpLabel + %88 = OpIAdd %6 %93 %51 + OpStore %63 %88 + OpBranch %64 + %66 = OpLabel + OpBranch %59 + 
%59 = OpLabel + %90 = OpIAdd %6 %92 %51 + OpStore %55 %90 + OpBranch %56 + %58 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 4u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[2]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[1], loops[2]); + EXPECT_FALSE(fusion.AreCompatible()); + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK: [[PHI_2:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] [[PHI_2]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// j loop preheader removed manually +#version 440 core +void main() { + int[10] a; + int[10] b; + int i = 0; + int j = 0; + // No loop-carried dependences, legal + for (; i < 10; i++) { + a[i] = a[i]*2; + } + for (; j < 10; j++) { + b[j] = 
a[j]+2; + } +} + +*/ +TEST_F(FusionLegalTest, DifferentArraysInLoopsNoPreheader) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %10 "j" + OpName %24 "a" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %29 = OpConstant %6 2 + %33 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %42 = OpVariable %23 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %51 = OpPhi %6 %9 %5 %34 %14 + OpLoopMerge %35 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %51 %17 + OpBranchConditional %19 %12 %35 + %12 = OpLabel + %27 = OpAccessChain %7 %24 %51 + %28 = OpLoad %6 %27 + %30 = OpIMul %6 %28 %29 + %31 = OpAccessChain %7 %24 %51 + OpStore %31 %30 + OpBranch %14 + %14 = OpLabel + %34 = OpIAdd %6 %51 %33 + OpStore %8 %34 + OpBranch %11 + %35 = OpLabel + %52 = OpPhi %6 %9 %15 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %18 %52 %17 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %24 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %29 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %33 + OpStore %10 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling 
failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + { + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); + } + + ld.CreatePreHeaderBlocksIfMissing(); + + { + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +// j & k loop preheaders removed manually +#version 440 core +void main() { + int[10] a; + int[10] b; + int i = 0; + int j = 0; + int k = 0; + // No loop-carried dependences, legal + for (; i < 10; i++) { + a[i] = a[i]*2; + } + for (; j < 10; j++) { + b[j] = a[j]+2; + } + for (; k < 10; k++) { + a[k] = a[k]*2; + } +} + +*/ +TEST_F(FusionLegalTest, AdjacentLoopsNoPreheaders) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %10 "j" + OpName %11 "k" + OpName %25 "a" + OpName %43 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + 
%6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %30 = OpConstant %6 2 + %34 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %25 = OpVariable %24 Function + %43 = OpVariable %24 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %67 = OpPhi %6 %9 %5 %35 %15 + OpLoopMerge %36 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %67 %18 + OpBranchConditional %20 %13 %36 + %13 = OpLabel + %28 = OpAccessChain %7 %25 %67 + %29 = OpLoad %6 %28 + %31 = OpIMul %6 %29 %30 + %32 = OpAccessChain %7 %25 %67 + OpStore %32 %31 + OpBranch %15 + %15 = OpLabel + %35 = OpIAdd %6 %67 %34 + OpStore %8 %35 + OpBranch %12 + %36 = OpLabel + %68 = OpPhi %6 %9 %16 %51 %39 + OpLoopMerge %52 %39 None + OpBranch %40 + %40 = OpLabel + %42 = OpSLessThan %19 %68 %18 + OpBranchConditional %42 %37 %52 + %37 = OpLabel + %46 = OpAccessChain %7 %25 %68 + %47 = OpLoad %6 %46 + %48 = OpIAdd %6 %47 %30 + %49 = OpAccessChain %7 %43 %68 + OpStore %49 %48 + OpBranch %39 + %39 = OpLabel + %51 = OpIAdd %6 %68 %34 + OpStore %10 %51 + OpBranch %36 + %52 = OpLabel + %70 = OpPhi %6 %9 %40 %66 %55 + OpLoopMerge %54 %55 None + OpBranch %56 + %56 = OpLabel + %58 = OpSLessThan %19 %70 %18 + OpBranchConditional %58 %53 %54 + %53 = OpLabel + %61 = OpAccessChain %7 %25 %70 + %62 = OpLoad %6 %61 + %63 = OpIMul %6 %62 %30 + %64 = OpAccessChain %7 %25 %70 + OpStore %64 %63 + OpBranch %55 + %55 = OpLabel + %66 = OpIAdd %6 %70 %34 + OpStore %11 %66 + OpBranch %52 + %54 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + 
EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 3u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + { + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_FALSE(fusion.AreCompatible()); + } + + ld.CreatePreHeaderBlocksIfMissing(); + + { + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + std::string checks = R"( +CHECK: [[PHI_0:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: [[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_0]] +CHECK-NEXT: OpStore [[STORE_1]] +CHECK: [[PHI_1:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI_1]] +CHECK-NEXT: OpStore [[STORE_2]] + )"; + + Match(checks, context.get()); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_0]] +CHECK: 
[[STORE_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpLoad {{%\w+}} [[LOAD_2]] +CHECK: [[STORE_2:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_2]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + int sum_1 = 0; + + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_1 += b[j]; + } + + int total = sum_0 + sum_1; +} + +*/ +TEST_F(FusionLegalTest, IndependentReductions) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "sum_1" + OpName %11 "i" + OpName %25 "a" + OpName %34 "j" + OpName %42 "b" + OpName %50 "total" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %25 = OpVariable %24 Function + %34 = OpVariable %7 Function + %42 = OpVariable %24 Function + %50 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + 
%57 = OpPhi %6 %9 %5 %30 %15 + %54 = OpPhi %6 %9 %5 %33 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %54 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %27 = OpAccessChain %7 %25 %54 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %57 %28 + OpStore %8 %30 + OpBranch %15 + %15 = OpLabel + %33 = OpIAdd %6 %54 %32 + OpStore %11 %33 + OpBranch %12 + %14 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %58 = OpPhi %6 %9 %14 %47 %38 + %55 = OpPhi %6 %9 %14 %49 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %19 %55 %18 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %44 = OpAccessChain %7 %42 %55 + %45 = OpLoad %6 %44 + %47 = OpIAdd %6 %58 %45 + OpStore %10 %47 + OpBranch %38 + %38 = OpLabel + %49 = OpIAdd %6 %55 %32 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + %53 = OpIAdd %6 %57 %58 + OpStore %50 %53 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[SUM_0:%\w+]] = OpPhi +CHECK-NEXT: [[SUM_1:%\w+]] = OpPhi +CHECK-NEXT: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[SUM_0]] [[LOAD_RES_0]] +CHECK-NEXT: 
OpStore {{%\w+}} [[ADD_RES_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_1:%\w+]] = OpLoad {{%\w+}} [[LOAD_1]] +CHECK-NEXT: [[ADD_RES_1:%\w+]] = OpIAdd {{%\w+}} [[SUM_1]] [[LOAD_RES_1]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + int sum_1 = 0; + + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_1 += b[j]; + } + + int total = sum_0 + sum_1; +} + +*/ +TEST_F(FusionLegalTest, IndependentReductionsOneLCSSA) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "sum_1" + OpName %11 "i" + OpName %25 "a" + OpName %34 "j" + OpName %42 "b" + OpName %50 "total" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %25 = OpVariable %24 Function + %34 = OpVariable %7 Function + %42 = OpVariable %24 Function + %50 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %57 = OpPhi %6 %9 %5 %30 %15 + %54 = OpPhi %6 %9 %5 %33 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %54 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %27 = OpAccessChain %7 %25 %54 + 
%28 = OpLoad %6 %27 + %30 = OpIAdd %6 %57 %28 + OpStore %8 %30 + OpBranch %15 + %15 = OpLabel + %33 = OpIAdd %6 %54 %32 + OpStore %11 %33 + OpBranch %12 + %14 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %58 = OpPhi %6 %9 %14 %47 %38 + %55 = OpPhi %6 %9 %14 %49 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %19 %55 %18 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %44 = OpAccessChain %7 %42 %55 + %45 = OpLoad %6 %44 + %47 = OpIAdd %6 %58 %45 + OpStore %10 %47 + OpBranch %38 + %38 = OpLabel + %49 = OpIAdd %6 %55 %32 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + %53 = OpIAdd %6 %57 %58 + OpStore %50 %53 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopUtils utils_0(context.get(), loops[0]); + utils_0.MakeLoopClosedSSA(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[SUM_0:%\w+]] = OpPhi +CHECK-NEXT: [[SUM_1:%\w+]] = OpPhi +CHECK-NEXT: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[SUM_0]] [[LOAD_RES_0]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_1:%\w+]] = OpLoad 
{{%\w+}} [[LOAD_1]] +CHECK-NEXT: [[ADD_RES_1:%\w+]] = OpIAdd {{%\w+}} [[SUM_1]] [[LOAD_RES_1]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + int sum_1 = 0; + + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + sum_1 += b[j]; + } + + int total = sum_0 + sum_1; +} + +*/ +TEST_F(FusionLegalTest, IndependentReductionsBothLCSSA) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "sum_1" + OpName %11 "i" + OpName %25 "a" + OpName %34 "j" + OpName %42 "b" + OpName %50 "total" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %18 = OpConstant %6 10 + %19 = OpTypeBool + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %25 = OpVariable %24 Function + %34 = OpVariable %7 Function + %42 = OpVariable %24 Function + %50 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpBranch %12 + %12 = OpLabel + %57 = OpPhi %6 %9 %5 %30 %15 + %54 = OpPhi %6 %9 %5 %33 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %19 %54 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %27 = OpAccessChain %7 %25 %54 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %57 %28 + OpStore %8 %30 + OpBranch %15 + %15 = OpLabel + %33 = OpIAdd %6 %54 %32 + OpStore %11 %33 + OpBranch %12 + 
%14 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %58 = OpPhi %6 %9 %14 %47 %38 + %55 = OpPhi %6 %9 %14 %49 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %19 %55 %18 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %44 = OpAccessChain %7 %42 %55 + %45 = OpLoad %6 %44 + %47 = OpIAdd %6 %58 %45 + OpStore %10 %47 + OpBranch %38 + %38 = OpLabel + %49 = OpIAdd %6 %55 %32 + OpStore %34 %49 + OpBranch %35 + %37 = OpLabel + %53 = OpIAdd %6 %57 %58 + OpStore %50 %53 + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopUtils utils_0(context.get(), loops[0]); + utils_0.MakeLoopClosedSSA(); + LoopUtils utils_1(context.get(), loops[1]); + utils_1.MakeLoopClosedSSA(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: [[SUM_0:%\w+]] = OpPhi +CHECK-NEXT: [[SUM_1:%\w+]] = OpPhi +CHECK-NEXT: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[SUM_0]] [[LOAD_RES_0]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_1:%\w+]] = OpLoad {{%\w+}} [[LOAD_1]] +CHECK-NEXT: [[ADD_RES_1:%\w+]] = OpIAdd {{%\w+}} [[SUM_1]] 
[[LOAD_RES_1]] +CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + + int sum_0 = 0; + + // No loop-carried dependences, legal + for (int i = 0; i < 10; i++) { + sum_0 += a[i]; + } + for (int j = 0; j < 10; j++) { + a[j] = b[j]; + } +} + +*/ +TEST_F(FusionLegalTest, LoadStoreReductionAndNonLoopCarriedDependence) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "sum_0" + OpName %10 "i" + OpName %24 "a" + OpName %33 "j" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 10 + %18 = OpTypeBool + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %6 %21 + %23 = OpTypePointer Function %22 + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %24 = OpVariable %23 Function + %33 = OpVariable %7 Function + %42 = OpVariable %23 Function + OpStore %8 %9 + OpStore %10 %9 + OpBranch %11 + %11 = OpLabel + %51 = OpPhi %6 %9 %5 %29 %14 + %49 = OpPhi %6 %9 %5 %32 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %18 %49 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %26 = OpAccessChain %7 %24 %49 + %27 = OpLoad %6 %26 + %29 = OpIAdd %6 %51 %27 + OpStore %8 %29 + OpBranch %14 + %14 = OpLabel + %32 = OpIAdd %6 %49 %31 + OpStore %10 %32 + OpBranch %11 + %13 = OpLabel + OpStore %33 %9 + OpBranch %34 + %34 = OpLabel + %50 = OpPhi %6 %9 %13 %48 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %18 %50 %17 + OpBranchConditional %40 %35 %36 + %35 = OpLabel + 
%44 = OpAccessChain %7 %42 %50 + %45 = OpLoad %6 %44 + %46 = OpAccessChain %7 %24 %50 + OpStore %46 %45 + OpBranch %37 + %37 = OpLabel + %48 = OpIAdd %6 %50 %31 + OpStore %33 %48 + OpBranch %34 + %36 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + // TODO: Loop descriptor doesn't return induction variables but all OpPhi + // in the header and LoopDependenceAnalysis falls over. + // EXPECT_TRUE(fusion.IsLegal()); + + // fusion.Fuse(); + } + + { + // LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + // EXPECT_EQ(ld.NumLoops(), 1u); + + // std::string checks = R"( + // CHECK: [[SUM_0:%\w+]] = OpPhi + // CHECK-NEXT: [[PHI:%\w+]] = OpPhi + // CHECK-NEXT: OpLoopMerge + // CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] + // CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] + // CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[SUM_0]] [[LOAD_RES_0]] + // CHECK-NEXT: OpStore {{%\w+}} [[ADD_RES_0]] + // CHECK-NOT: OpPhi + // CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] + // CHECK-NEXT: [[LOAD_RES_1:%\w+]] = OpLoad {{%\w+}} [[LOAD_1]] + // CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] + // CHECK-NEXT: OpStore [[STORE_1]] [[LOAD_RES_1]] + // )"; + + // Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +int x; +void main() { + int[10] a; + int[10] b; + + // Legal. 
+ for (int i = 0; i < 10; i++) { + x += a[i]; + } + for (int j = 0; j < 10; j++) { + b[j] = b[j]+1; + } +} + +*/ +TEST_F(FusionLegalTest, ReductionAndNonLoopCarriedDependence) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %20 "x" + OpName %25 "a" + OpName %34 "j" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypePointer Private %6 + %20 = OpVariable %19 Private + %21 = OpTypeInt 32 0 + %22 = OpConstant %21 10 + %23 = OpTypeArray %6 %22 + %24 = OpTypePointer Function %23 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %25 = OpVariable %24 Function + %34 = OpVariable %7 Function + %42 = OpVariable %24 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %51 + %28 = OpLoad %6 %27 + %29 = OpLoad %6 %20 + %30 = OpIAdd %6 %29 %28 + OpStore %20 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %42 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %32 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + 
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + std::string checks = R"( +CHECK: OpName [[X:%\w+]] "x" +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[LOAD_0:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: [[LOAD_RES_0:%\w+]] = OpLoad {{%\w+}} [[LOAD_0]] +CHECK-NEXT: [[X_LOAD:%\w+]] = OpLoad {{%\w+}} [[X]] +CHECK-NEXT: [[ADD_RES_0:%\w+]] = OpIAdd {{%\w+}} [[X_LOAD]] [[LOAD_RES_0]] +CHECK-NEXT: OpStore [[X]] [[ADD_RES_0]] +CHECK-NOT: OpPhi +CHECK: [[LOAD_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: {{%\w+}} = OpLoad {{%\w+}} [[LOAD_1]] +CHECK: [[STORE_1:%\w+]] = OpAccessChain {{%\w+}} {{%\w+}} [[PHI]] +CHECK-NEXT: OpStore [[STORE_1]] + )"; + + Match(checks, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +struct TestStruct { + int[10] a; + int b; +}; + +void main() { + TestStruct test_0; + TestStruct test_1; + TestStruct test_2; + + test_1.b = 2; + + for (int i = 0; i < 10; i++) { + test_0.a[i] = i; + } + for (int j = 0; j < 10; j++) { + test_2 = test_1; + } +} + +*/ +TEST_F(FusionLegalTest, ArrayInStruct) { + std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %10 
"TestStruct" + OpMemberName %10 0 "a" + OpMemberName %10 1 "b" + OpName %12 "test_1" + OpName %17 "i" + OpName %28 "test_0" + OpName %34 "j" + OpName %42 "test_2" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeInt 32 0 + %8 = OpConstant %7 10 + %9 = OpTypeArray %6 %8 + %10 = OpTypeStruct %9 %6 + %11 = OpTypePointer Function %10 + %13 = OpConstant %6 1 + %14 = OpConstant %6 2 + %15 = OpTypePointer Function %6 + %18 = OpConstant %6 0 + %25 = OpConstant %6 10 + %26 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + %12 = OpVariable %11 Function + %17 = OpVariable %15 Function + %28 = OpVariable %11 Function + %34 = OpVariable %15 Function + %42 = OpVariable %11 Function + %16 = OpAccessChain %15 %12 %13 + OpStore %16 %14 + OpStore %17 %18 + OpBranch %19 + %19 = OpLabel + %46 = OpPhi %6 %18 %5 %33 %22 + OpLoopMerge %21 %22 None + OpBranch %23 + %23 = OpLabel + %27 = OpSLessThan %26 %46 %25 + OpBranchConditional %27 %20 %21 + %20 = OpLabel + %31 = OpAccessChain %15 %28 %18 %46 + OpStore %31 %46 + OpBranch %22 + %22 = OpLabel + %33 = OpIAdd %6 %46 %13 + OpStore %17 %33 + OpBranch %19 + %21 = OpLabel + OpStore %34 %18 + OpBranch %35 + %35 = OpLabel + %47 = OpPhi %6 %18 %21 %45 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %26 %47 %25 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %43 = OpLoad %10 %12 + OpStore %42 %43 + OpBranch %38 + %38 = OpLabel + %45 = OpIAdd %6 %47 %13 + OpStore %34 %45 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 2u); + + auto loops = ld.GetLoopsInBinaryLayoutOrder(); + + 
LoopFusion fusion(context.get(), loops[0], loops[1]); + EXPECT_TRUE(fusion.AreCompatible()); + EXPECT_TRUE(fusion.IsLegal()); + + fusion.Fuse(); + } + + { + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), 1u); + + // clang-format off + std::string checks = R"( +CHECK: OpName [[TEST_1:%\w+]] "test_1" +CHECK: OpName [[TEST_0:%\w+]] "test_0" +CHECK: OpName [[TEST_2:%\w+]] "test_2" +CHECK: [[PHI:%\w+]] = OpPhi +CHECK-NEXT: OpLoopMerge +CHECK: [[TEST_0_STORE:%\w+]] = OpAccessChain {{%\w+}} [[TEST_0]] {{%\w+}} {{%\w+}} +CHECK-NEXT: OpStore [[TEST_0_STORE]] [[PHI]] +CHECK-NOT: OpPhi +CHECK: [[TEST_1_LOAD:%\w+]] = OpLoad {{%\w+}} [[TEST_1]] +CHECK: OpStore [[TEST_2]] [[TEST_1_LOAD]] + )"; + // clang-format on + + Match(checks, context.get()); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/fusion_pass.cpp b/test/opt/loop_optimizations/fusion_pass.cpp new file mode 100644 index 000000000..949392375 --- /dev/null +++ b/test/opt/loop_optimizations/fusion_pass.cpp @@ -0,0 +1,717 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using FusionPassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + for (int i = 0; i < 10; i++) { + a[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + b[i] = a[i]+2; + } +} + +*/ +TEST_F(FusionPassTest, SimpleFusion) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK: OpLoad +; CHECK: OpStore +; CHECK-NOT: OpPhi +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %34 "i" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %51 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %51 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + 
%41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %28 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i] + 2; + } + for (int i = 0; i < 10; i++) { + b[i] = c[i] + 10; + } +} + +*/ +TEST_F(FusionPassTest, ThreeLoopsFused) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK: OpLoad +; CHECK: OpStore +; CHECK-NOT: OpPhi +; CHECK: OpLoad +; CHECK: OpStore +; CHECK-NOT: OpPhi +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + OpName %52 "i" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %47 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + %52 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %68 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %68 %16 + 
OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %68 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %68 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %68 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %69 = OpPhi %6 %9 %12 %51 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %69 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %69 + %46 = OpLoad %6 %45 + %48 = OpIAdd %6 %46 %47 + %49 = OpAccessChain %7 %42 %69 + OpStore %49 %48 + OpBranch %38 + %38 = OpLabel + %51 = OpIAdd %6 %69 %29 + OpStore %34 %51 + OpBranch %35 + %37 = OpLabel + OpStore %52 %9 + OpBranch %53 + %53 = OpLabel + %70 = OpPhi %6 %9 %37 %67 %56 + OpLoopMerge %55 %56 None + OpBranch %57 + %57 = OpLabel + %59 = OpSLessThan %17 %70 %16 + OpBranchConditional %59 %54 %55 + %54 = OpLabel + %62 = OpAccessChain %7 %42 %70 + %63 = OpLoad %6 %62 + %64 = OpIAdd %6 %63 %16 + %65 = OpAccessChain %7 %25 %70 + OpStore %65 %64 + OpBranch %56 + %56 = OpLabel + %67 = OpIAdd %6 %70 %29 + OpStore %52 %67 + OpBranch %53 + %55 = OpLabel + OpReturn + OpFunctionEnd + + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10][10] a; + int[10][10] b; + int[10][10] c; + // Legal both + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + c[i][j] = a[i][j] + 2; + } + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + b[i][j] = c[i][j] + 10; + } + } +} + +*/ +TEST_F(FusionPassTest, NestedLoopsFused) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK: OpPhi +; CHECK: OpLoad +; CHECK: OpStore +; CHECK-NOT: OpPhi +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + 
OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %19 "j" + OpName %32 "c" + OpName %35 "a" + OpName %48 "i" + OpName %56 "j" + OpName %64 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %27 = OpTypeInt 32 0 + %28 = OpConstant %27 10 + %29 = OpTypeArray %6 %28 + %30 = OpTypeArray %29 %28 + %31 = OpTypePointer Function %30 + %40 = OpConstant %6 2 + %44 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %31 Function + %35 = OpVariable %31 Function + %48 = OpVariable %7 Function + %56 = OpVariable %7 Function + %64 = OpVariable %31 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %77 = OpPhi %6 %9 %5 %47 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %77 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + %81 = OpPhi %6 %9 %11 %45 %23 + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %26 = OpSLessThan %17 %81 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + %38 = OpAccessChain %7 %35 %77 %81 + %39 = OpLoad %6 %38 + %41 = OpIAdd %6 %39 %40 + %42 = OpAccessChain %7 %32 %77 %81 + OpStore %42 %41 + OpBranch %23 + %23 = OpLabel + %45 = OpIAdd %6 %81 %44 + OpStore %19 %45 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %47 = OpIAdd %6 %77 %44 + OpStore %8 %47 + OpBranch %10 + %12 = OpLabel + OpStore %48 %9 + OpBranch %49 + %49 = OpLabel + %78 = OpPhi %6 %9 %12 %76 %52 + OpLoopMerge %51 %52 None + OpBranch %53 + %53 = OpLabel + %55 = OpSLessThan %17 %78 %16 + OpBranchConditional %55 %50 %51 + %50 = OpLabel + OpStore %56 %9 + OpBranch %57 + %57 = OpLabel + %79 = OpPhi %6 %9 %50 %74 %60 + OpLoopMerge %59 %60 None + OpBranch %61 + %61 = OpLabel + %63 = OpSLessThan %17 %79 %16 + 
OpBranchConditional %63 %58 %59 + %58 = OpLabel + %69 = OpAccessChain %7 %32 %78 %79 + %70 = OpLoad %6 %69 + %71 = OpIAdd %6 %70 %16 + %72 = OpAccessChain %7 %64 %78 %79 + OpStore %72 %71 + OpBranch %60 + %60 = OpLabel + %74 = OpIAdd %6 %79 %44 + OpStore %56 %74 + OpBranch %57 + %59 = OpLabel + OpBranch %52 + %52 = OpLabel + %76 = OpIAdd %6 %78 %44 + OpStore %48 %76 + OpBranch %49 + %51 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + // Can't fuse, different step + for (int i = 0; i < 10; i++) {} + for (int j = 0; j < 10; j=j+2) {} +} + +*/ +TEST_F(FusionPassTest, Incompatible) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %22 "j" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %31 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %22 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %6 %33 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpStore %22 %9 + OpBranch %23 + %23 = OpLabel + %34 = OpPhi %6 %9 %12 %32 %26 + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %17 %34 %16 + OpBranchConditional %29 %24 %25 + %24 = OpLabel + OpBranch %26 + %26 = OpLabel + %32 = OpIAdd 
%6 %34 %31 + OpStore %22 %32 + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + int[10] c; + // Illegal, loop-independent dependence will become a + // backward loop-carried antidependence + for (int i = 0; i < 10; i++) { + a[i] = b[i] + 1; + } + for (int i = 0; i < 10; i++) { + c[i] = a[i+1] + 2; + } +} + +*/ +TEST_F(FusionPassTest, Illegal) { + std::string text = R"( +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpLoad +; CHECK: OpStore +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %25 "b" + OpName %34 "i" + OpName %42 "c" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %29 = OpConstant %6 1 + %48 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %25 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %53 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %53 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpAccessChain %7 %25 %53 + %28 = OpLoad %6 %27 + %30 = OpIAdd %6 %28 %29 + %31 = OpAccessChain %7 %23 %53 + OpStore %31 %30 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %53 %29 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + 
OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %54 = OpPhi %6 %9 %12 %52 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %54 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpIAdd %6 %54 %29 + %46 = OpAccessChain %7 %23 %45 + %47 = OpLoad %6 %46 + %49 = OpIAdd %6 %47 %48 + %50 = OpAccessChain %7 %42 %54 + OpStore %50 %49 + OpBranch %38 + %38 = OpLabel + %52 = OpIAdd %6 %54 %29 + OpStore %34 %52 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 20); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +void main() { + int[10] a; + int[10] b; + for (int i = 0; i < 10; i++) { + a[i] = a[i]*2; + } + for (int i = 0; i < 10; i++) { + b[i] = a[i]+2; + } +} + +*/ +TEST_F(FusionPassTest, TooManyRegisters) { + const std::string text = R"( +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpLoad +; CHECK: OpStore +; CHECK: OpPhi +; CHECK-NEXT: OpLoopMerge +; CHECK: OpLoad +; CHECK: OpStore + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + OpName %8 "i" + OpName %23 "a" + OpName %34 "i" + OpName %42 "b" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 10 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %28 = OpConstant %6 2 + %32 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %23 = OpVariable %22 Function + %34 = OpVariable %7 Function + %42 = OpVariable %22 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %33 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %51 %16 + OpBranchConditional 
%18 %11 %12 + %11 = OpLabel + %26 = OpAccessChain %7 %23 %51 + %27 = OpLoad %6 %26 + %29 = OpIMul %6 %27 %28 + %30 = OpAccessChain %7 %23 %51 + OpStore %30 %29 + OpBranch %13 + %13 = OpLabel + %33 = OpIAdd %6 %51 %32 + OpStore %8 %33 + OpBranch %10 + %12 = OpLabel + OpStore %34 %9 + OpBranch %35 + %35 = OpLabel + %52 = OpPhi %6 %9 %12 %50 %38 + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %41 = OpSLessThan %17 %52 %16 + OpBranchConditional %41 %36 %37 + %36 = OpLabel + %45 = OpAccessChain %7 %23 %52 + %46 = OpLoad %6 %45 + %47 = OpIAdd %6 %46 %28 + %48 = OpAccessChain %7 %42 %52 + OpStore %48 %47 + OpBranch %38 + %38 = OpLabel + %50 = OpIAdd %6 %52 %32 + OpStore %34 %50 + OpBranch %35 + %37 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true, 5); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/hoist_all_loop_types.cpp b/test/opt/loop_optimizations/hoist_all_loop_types.cpp new file mode 100644 index 000000000..27e0a0d91 --- /dev/null +++ b/test/opt/loop_optimizations/hoist_all_loop_types.cpp @@ -0,0 +1,285 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Tests that all loop types are handled appropriately by the LICM pass. + + Generated from the following GLSL fragment shader +--eliminate-local-multi-store has also been run on the spv binary +#version 440 core +void main(){ + int i_1 = 0; + for (i_1 = 0; i_1 < 10; i_1++) { + } + int i_2 = 0; + while (i_2 < 10) { + i_2++; + } + int i_3 = 0; + do { + i_3++; + } while (i_3 < 10); + int hoist = 0; + int i_4 = 0; + int i_5 = 0; + int i_6 = 0; + for (i_4 = 0; i_4 < 10; i_4++) { + while (i_5 < 10) { + do { + hoist = i_1 + i_2 + i_3; + i_6++; + } while (i_6 < 10); + i_5++; + } + } +} +*/ +TEST_F(PassClassTest, AllLoopTypes) { + const std::string before_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_1 = OpConstant %int 1 +%main = OpFunction %void None %4 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +%13 = OpPhi %int %int_0 %11 %14 %15 +OpLoopMerge %16 %15 None +OpBranch %17 +%17 = OpLabel +%18 = OpSLessThan %bool %13 %int_10 +OpBranchConditional %18 %19 %16 +%19 = OpLabel +OpBranch %15 +%15 = OpLabel +%14 = OpIAdd %int %13 %int_1 +OpBranch %12 +%16 = OpLabel +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %int %int_0 %16 %22 %23 +OpLoopMerge %24 %23 None +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %bool %21 %int_10 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%22 = OpIAdd %int %21 %int_1 +OpBranch %23 +%23 = OpLabel +OpBranch %20 +%24 = OpLabel 
+OpBranch %28 +%28 = OpLabel +%29 = OpPhi %int %int_0 %24 %30 %31 +OpLoopMerge %32 %31 None +OpBranch %33 +%33 = OpLabel +%30 = OpIAdd %int %29 %int_1 +OpBranch %31 +%31 = OpLabel +%34 = OpSLessThan %bool %30 %int_10 +OpBranchConditional %34 %28 %32 +%32 = OpLabel +OpBranch %35 +%35 = OpLabel +%36 = OpPhi %int %int_0 %32 %37 %38 +%39 = OpPhi %int %int_0 %32 %40 %38 +%41 = OpPhi %int %int_0 %32 %42 %38 +%43 = OpPhi %int %int_0 %32 %44 %38 +OpLoopMerge %45 %38 None +OpBranch %46 +%46 = OpLabel +%47 = OpSLessThan %bool %39 %int_10 +OpBranchConditional %47 %48 %45 +%48 = OpLabel +OpBranch %49 +%49 = OpLabel +%37 = OpPhi %int %36 %48 %50 %51 +%42 = OpPhi %int %41 %48 %52 %51 +%44 = OpPhi %int %43 %48 %53 %51 +OpLoopMerge %54 %51 None +OpBranch %55 +%55 = OpLabel +%56 = OpSLessThan %bool %42 %int_10 +OpBranchConditional %56 %57 %54 +%57 = OpLabel +OpBranch %58 +%58 = OpLabel +%59 = OpPhi %int %37 %57 %50 %60 +%61 = OpPhi %int %44 %57 %53 %60 +OpLoopMerge %62 %60 None +OpBranch %63 +%63 = OpLabel +%64 = OpIAdd %int %13 %21 +%50 = OpIAdd %int %64 %30 +%53 = OpIAdd %int %61 %int_1 +OpBranch %60 +%60 = OpLabel +%65 = OpSLessThan %bool %53 %int_10 +OpBranchConditional %65 %58 %62 +%62 = OpLabel +%52 = OpIAdd %int %42 %int_1 +OpBranch %51 +%51 = OpLabel +OpBranch %49 +%54 = OpLabel +OpBranch %38 +%38 = OpLabel +%40 = OpIAdd %int %39 %int_1 +OpBranch %35 +%45 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_1 = OpConstant %int 1 +%main = OpFunction %void None %4 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +%13 = OpPhi %int %int_0 %11 %14 %15 
+OpLoopMerge %16 %15 None +OpBranch %17 +%17 = OpLabel +%18 = OpSLessThan %bool %13 %int_10 +OpBranchConditional %18 %19 %16 +%19 = OpLabel +OpBranch %15 +%15 = OpLabel +%14 = OpIAdd %int %13 %int_1 +OpBranch %12 +%16 = OpLabel +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %int %int_0 %16 %22 %23 +OpLoopMerge %24 %23 None +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %bool %21 %int_10 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%22 = OpIAdd %int %21 %int_1 +OpBranch %23 +%23 = OpLabel +OpBranch %20 +%24 = OpLabel +OpBranch %28 +%28 = OpLabel +%29 = OpPhi %int %int_0 %24 %30 %31 +OpLoopMerge %32 %31 None +OpBranch %33 +%33 = OpLabel +%30 = OpIAdd %int %29 %int_1 +OpBranch %31 +%31 = OpLabel +%34 = OpSLessThan %bool %30 %int_10 +OpBranchConditional %34 %28 %32 +%32 = OpLabel +%64 = OpIAdd %int %13 %21 +%50 = OpIAdd %int %64 %30 +OpBranch %35 +%35 = OpLabel +%36 = OpPhi %int %int_0 %32 %37 %38 +%39 = OpPhi %int %int_0 %32 %40 %38 +%41 = OpPhi %int %int_0 %32 %42 %38 +%43 = OpPhi %int %int_0 %32 %44 %38 +OpLoopMerge %45 %38 None +OpBranch %46 +%46 = OpLabel +%47 = OpSLessThan %bool %39 %int_10 +OpBranchConditional %47 %48 %45 +%48 = OpLabel +OpBranch %49 +%49 = OpLabel +%37 = OpPhi %int %36 %48 %50 %51 +%42 = OpPhi %int %41 %48 %52 %51 +%44 = OpPhi %int %43 %48 %53 %51 +OpLoopMerge %54 %51 None +OpBranch %55 +%55 = OpLabel +%56 = OpSLessThan %bool %42 %int_10 +OpBranchConditional %56 %57 %54 +%57 = OpLabel +OpBranch %58 +%58 = OpLabel +%59 = OpPhi %int %37 %57 %50 %60 +%61 = OpPhi %int %44 %57 %53 %60 +OpLoopMerge %62 %60 None +OpBranch %63 +%63 = OpLabel +%53 = OpIAdd %int %61 %int_1 +OpBranch %60 +%60 = OpLabel +%65 = OpSLessThan %bool %53 %int_10 +OpBranchConditional %65 %58 %62 +%62 = OpLabel +%52 = OpIAdd %int %42 %int_1 +OpBranch %51 +%51 = OpLabel +OpBranch %49 +%54 = OpLabel +OpBranch %38 +%38 = OpLabel +%40 = OpIAdd %int %39 %int_1 +OpBranch %35 +%45 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_hoist, after_hoist, true); +} + +} 
// namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/hoist_double_nested_loops.cpp b/test/opt/loop_optimizations/hoist_double_nested_loops.cpp new file mode 100644 index 000000000..ea1949658 --- /dev/null +++ b/test/opt/loop_optimizations/hoist_double_nested_loops.cpp @@ -0,0 +1,162 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Tests that the LICM pass will move invariants through multiple loops + + Generated from the following GLSL fragment shader +--eliminate-local-multi-store has also been run on the spv binary +#version 440 core +void main(){ + int a = 2; + int b = 1; + int hoist = 0; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + // hoist 'hoist = a - b' out of both loops + hoist = a - b; + } + } +} +*/ +TEST_F(PassClassTest, NestedDoubleHoist) { + const std::string before_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer 
Function %int +%int_2 = OpConstant %int 2 +%int_1 = OpConstant %int 1 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%12 = OpUndef %int +%main = OpFunction %void None %4 +%13 = OpLabel +OpBranch %14 +%14 = OpLabel +%15 = OpPhi %int %int_0 %13 %16 %17 +%18 = OpPhi %int %int_0 %13 %19 %17 +%20 = OpPhi %int %12 %13 %21 %17 +OpLoopMerge %22 %17 None +OpBranch %23 +%23 = OpLabel +%24 = OpSLessThan %bool %18 %int_10 +OpBranchConditional %24 %25 %22 +%25 = OpLabel +OpBranch %26 +%26 = OpLabel +%16 = OpPhi %int %15 %25 %27 %28 +%21 = OpPhi %int %int_0 %25 %29 %28 +OpLoopMerge %30 %28 None +OpBranch %31 +%31 = OpLabel +%32 = OpSLessThan %bool %21 %int_10 +OpBranchConditional %32 %33 %30 +%33 = OpLabel +%27 = OpISub %int %int_2 %int_1 +OpBranch %28 +%28 = OpLabel +%29 = OpIAdd %int %21 %int_1 +OpBranch %26 +%30 = OpLabel +OpBranch %17 +%17 = OpLabel +%19 = OpIAdd %int %18 %int_1 +OpBranch %14 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_2 = OpConstant %int 2 +%int_1 = OpConstant %int 1 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%12 = OpUndef %int +%main = OpFunction %void None %4 +%13 = OpLabel +%27 = OpISub %int %int_2 %int_1 +OpBranch %14 +%14 = OpLabel +%15 = OpPhi %int %int_0 %13 %16 %17 +%18 = OpPhi %int %int_0 %13 %19 %17 +%20 = OpPhi %int %12 %13 %21 %17 +OpLoopMerge %22 %17 None +OpBranch %23 +%23 = OpLabel +%24 = OpSLessThan %bool %18 %int_10 +OpBranchConditional %24 %25 %22 +%25 = OpLabel +OpBranch %26 +%26 = OpLabel +%16 = OpPhi %int %15 %25 %27 %28 +%21 = OpPhi %int %int_0 %25 %29 %28 +OpLoopMerge %30 %28 None +OpBranch %31 +%31 = 
OpLabel +%32 = OpSLessThan %bool %21 %int_10 +OpBranchConditional %32 %33 %30 +%33 = OpLabel +OpBranch %28 +%28 = OpLabel +%29 = OpIAdd %int %21 %int_1 +OpBranch %26 +%30 = OpLabel +OpBranch %17 +%17 = OpLabel +%19 = OpIAdd %int %18 %int_1 +OpBranch %14 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_hoist, after_hoist, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/hoist_from_independent_loops.cpp b/test/opt/loop_optimizations/hoist_from_independent_loops.cpp new file mode 100644 index 000000000..abc79e37c --- /dev/null +++ b/test/opt/loop_optimizations/hoist_from_independent_loops.cpp @@ -0,0 +1,201 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Tests that the LICM pass will analyse multiple independent loops in a function + + Generated from the following GLSL fragment shader +--eliminate-local-multi-store has also been run on the spv binary +#version 440 core +void main(){ + int a = 1; + int b = 2; + int hoist = 0; + for (int i = 0; i < 10; i++) { + // invariant + hoist = a + b; + } + for (int i = 0; i < 10; i++) { + // invariant + hoist = a + b; + } + int c = 1; + int d = 2; + int hoist2 = 0; + for (int i = 0; i < 10; i++) { + // invariant + hoist2 = c + d; + } +} +*/ +TEST_F(PassClassTest, HoistFromIndependentLoops) { + const std::string before_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%main = OpFunction %void None %4 +%12 = OpLabel +OpBranch %13 +%13 = OpLabel +%14 = OpPhi %int %int_0 %12 %15 %16 +%17 = OpPhi %int %int_0 %12 %18 %16 +OpLoopMerge %19 %16 None +OpBranch %20 +%20 = OpLabel +%21 = OpSLessThan %bool %17 %int_10 +OpBranchConditional %21 %22 %19 +%22 = OpLabel +%15 = OpIAdd %int %int_1 %int_2 +OpBranch %16 +%16 = OpLabel +%18 = OpIAdd %int %17 %int_1 +OpBranch %13 +%19 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %int %14 %19 %25 %26 +%27 = OpPhi %int %int_0 %19 %28 %26 +OpLoopMerge %29 %26 None +OpBranch %30 +%30 = OpLabel +%31 = OpSLessThan %bool %27 %int_10 +OpBranchConditional %31 %32 %29 +%32 = OpLabel 
+%25 = OpIAdd %int %int_1 %int_2 +OpBranch %26 +%26 = OpLabel +%28 = OpIAdd %int %27 %int_1 +OpBranch %23 +%29 = OpLabel +OpBranch %33 +%33 = OpLabel +%34 = OpPhi %int %int_0 %29 %35 %36 +%37 = OpPhi %int %int_0 %29 %38 %36 +OpLoopMerge %39 %36 None +OpBranch %40 +%40 = OpLabel +%41 = OpSLessThan %bool %37 %int_10 +OpBranchConditional %41 %42 %39 +%42 = OpLabel +%35 = OpIAdd %int %int_1 %int_2 +OpBranch %36 +%36 = OpLabel +%38 = OpIAdd %int %37 %int_1 +OpBranch %33 +%39 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%main = OpFunction %void None %4 +%12 = OpLabel +%15 = OpIAdd %int %int_1 %int_2 +OpBranch %13 +%13 = OpLabel +%14 = OpPhi %int %int_0 %12 %15 %16 +%17 = OpPhi %int %int_0 %12 %18 %16 +OpLoopMerge %19 %16 None +OpBranch %20 +%20 = OpLabel +%21 = OpSLessThan %bool %17 %int_10 +OpBranchConditional %21 %22 %19 +%22 = OpLabel +OpBranch %16 +%16 = OpLabel +%18 = OpIAdd %int %17 %int_1 +OpBranch %13 +%19 = OpLabel +%25 = OpIAdd %int %int_1 %int_2 +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %int %14 %19 %25 %26 +%27 = OpPhi %int %int_0 %19 %28 %26 +OpLoopMerge %29 %26 None +OpBranch %30 +%30 = OpLabel +%31 = OpSLessThan %bool %27 %int_10 +OpBranchConditional %31 %32 %29 +%32 = OpLabel +OpBranch %26 +%26 = OpLabel +%28 = OpIAdd %int %27 %int_1 +OpBranch %23 +%29 = OpLabel +%35 = OpIAdd %int %int_1 %int_2 +OpBranch %33 +%33 = OpLabel +%34 = OpPhi %int %int_0 %29 %35 %36 +%37 = OpPhi %int %int_0 %29 %38 %36 +OpLoopMerge %39 %36 None +OpBranch %40 +%40 = OpLabel +%41 = OpSLessThan 
%bool %37 %int_10 +OpBranchConditional %41 %42 %39 +%42 = OpLabel +OpBranch %36 +%36 = OpLabel +%38 = OpIAdd %int %37 %int_1 +OpBranch %33 +%39 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_hoist, after_hoist, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/hoist_simple_case.cpp b/test/opt/loop_optimizations/hoist_simple_case.cpp new file mode 100644 index 000000000..e973d9d29 --- /dev/null +++ b/test/opt/loop_optimizations/hoist_simple_case.cpp @@ -0,0 +1,126 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + A simple test for the LICM pass + + Generated from the following GLSL fragment shader +--eliminate-local-multi-store has also been run on the spv binary +#version 440 core +void main(){ + int a = 1; + int b = 2; + int hoist = 0; + for (int i = 0; i < 10; i++) { + // invariant + hoist = a + b; + } +} +*/ +TEST_F(PassClassTest, SimpleHoist) { + const std::string before_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%main = OpFunction %void None %4 +%12 = OpLabel +OpBranch %13 +%13 = OpLabel +%14 = OpPhi %int %int_0 %12 %15 %16 +%17 = OpPhi %int %int_0 %12 %18 %16 +OpLoopMerge %19 %16 None +OpBranch %20 +%20 = OpLabel +%21 = OpSLessThan %bool %17 %int_10 +OpBranchConditional %21 %22 %19 +%22 = OpLabel +%15 = OpIAdd %int %int_1 %int_2 +OpBranch %16 +%16 = OpLabel +%18 = OpIAdd %int %17 %int_1 +OpBranch %13 +%19 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_0 = OpConstant %int 0 
+%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%main = OpFunction %void None %4 +%12 = OpLabel +%15 = OpIAdd %int %int_1 %int_2 +OpBranch %13 +%13 = OpLabel +%14 = OpPhi %int %int_0 %12 %15 %16 +%17 = OpPhi %int %int_0 %12 %18 %16 +OpLoopMerge %19 %16 None +OpBranch %20 +%20 = OpLabel +%21 = OpSLessThan %bool %17 %int_10 +OpBranchConditional %21 %22 %19 +%22 = OpLabel +OpBranch %16 +%16 = OpLabel +%18 = OpIAdd %int %17 %int_1 +OpBranch %13 +%19 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_hoist, after_hoist, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/hoist_single_nested_loops.cpp b/test/opt/loop_optimizations/hoist_single_nested_loops.cpp new file mode 100644 index 000000000..056f3f025 --- /dev/null +++ b/test/opt/loop_optimizations/hoist_single_nested_loops.cpp @@ -0,0 +1,209 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; + +using PassClassTest = PassTest<::testing::Test>; + +/* + Tests that the LICM pass will detect and move an invariant from a nested loop, + but not its parent loop + + Generated from the following GLSL fragment shader +--eliminate-local-multi-store has also been run on the spv binary +#version 440 core +void main(){ + int a = 2; + int hoist = 0; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + // hoist 'hoist = a - i' out of j loop, but not i loop + hoist = a - i; + } + } +} +*/ +TEST_F(PassClassTest, NestedSingleHoist) { + const std::string before_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_2 = OpConstant %int 2 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_1 = OpConstant %int 1 +%12 = OpUndef %int +%main = OpFunction %void None %4 +%13 = OpLabel +OpBranch %14 +%14 = OpLabel +%15 = OpPhi %int %int_0 %13 %16 %17 +%18 = OpPhi %int %int_0 %13 %19 %17 +%20 = OpPhi %int %12 %13 %21 %17 +OpLoopMerge %22 %17 None +OpBranch %23 +%23 = OpLabel +%24 = OpSLessThan %bool %18 %int_10 +OpBranchConditional %24 %25 %22 +%25 = OpLabel +OpBranch %26 +%26 = OpLabel +%16 = OpPhi %int %15 %25 %27 %28 +%21 = OpPhi %int %int_0 %25 %29 %28 +OpLoopMerge %30 %28 None +OpBranch %31 +%31 = OpLabel +%32 = OpSLessThan %bool %21 %int_10 +OpBranchConditional %32 %33 %30 +%33 = OpLabel +%27 = OpISub %int %int_2 %18 +OpBranch %28 +%28 = OpLabel +%29 = OpIAdd %int %21 %int_1 +OpBranch %26 +%30 = OpLabel +OpBranch %17 +%17 = OpLabel +%19 = OpIAdd
%int %18 %int_1 +OpBranch %14 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after_hoist = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_2 = OpConstant %int 2 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_1 = OpConstant %int 1 +%12 = OpUndef %int +%main = OpFunction %void None %4 +%13 = OpLabel +OpBranch %14 +%14 = OpLabel +%15 = OpPhi %int %int_0 %13 %16 %17 +%18 = OpPhi %int %int_0 %13 %19 %17 +%20 = OpPhi %int %12 %13 %21 %17 +OpLoopMerge %22 %17 None +OpBranch %23 +%23 = OpLabel +%24 = OpSLessThan %bool %18 %int_10 +OpBranchConditional %24 %25 %22 +%25 = OpLabel +%27 = OpISub %int %int_2 %18 +OpBranch %26 +%26 = OpLabel +%16 = OpPhi %int %15 %25 %27 %28 +%21 = OpPhi %int %int_0 %25 %29 %28 +OpLoopMerge %30 %28 None +OpBranch %31 +%31 = OpLabel +%32 = OpSLessThan %bool %21 %int_10 +OpBranchConditional %32 %33 %30 +%33 = OpLabel +OpBranch %28 +%28 = OpLabel +%29 = OpIAdd %int %21 %int_1 +OpBranch %26 +%30 = OpLabel +OpBranch %17 +%17 = OpLabel +%19 = OpIAdd %int %18 %int_1 +OpBranch %14 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before_hoist, after_hoist, true); +} + +TEST_F(PassClassTest, PreHeaderIsAlsoHeader) { + // Move OpSLessThan out of the inner loop. The preheader for the inner loop + // is the header of the outer loop. The loop merge should not be separated + // from the branch in that block. 
+ const std::string text = R"( + ; CHECK: OpFunction + ; CHECK-NEXT: OpLabel + ; CHECK-NEXT: OpBranch [[header:%\w+]] + ; CHECK: [[header]] = OpLabel + ; CHECK-NEXT: OpSLessThan %bool %int_1 %int_1 + ; CHECK-NEXT: OpLoopMerge + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %int_1 = OpConstant %int 1 + %bool = OpTypeBool + %2 = OpFunction %void None %4 + %18 = OpLabel + OpBranch %21 + %21 = OpLabel + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %25 = OpSLessThan %bool %int_1 %int_1 + OpLoopMerge %26 %27 None + OpBranchConditional %25 %27 %26 + %27 = OpLabel + OpBranch %24 + %26 = OpLabel + OpBranch %22 + %23 = OpLabel + OpBranch %21 + %22 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/hoist_without_preheader.cpp b/test/opt/loop_optimizations/hoist_without_preheader.cpp new file mode 100644 index 000000000..2e34b0142 --- /dev/null +++ b/test/opt/loop_optimizations/hoist_without_preheader.cpp @@ -0,0 +1,197 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/opt/licm_pass.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* + Tests that the LICM pass will generate a preheader when one is not present + + Generated from the following GLSL fragment shader +--eliminate-local-multi-store has also been run on the spv binary +#version 440 core +void main(){ + int a = 1; + int b = 2; + int hoist = 0; + for (int i = 0; i < 10; i++) { + if (i == 5) { + break; + } + } + for (int i = 0; i < 10; i++) { + hoist = a + b; + } +} +*/ +TEST_F(PassClassTest, HoistWithoutPreheader) { + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_5 = OpConstant %int 5 +%main = OpFunction %void None %4 +%13 = OpLabel +OpBranch %14 +%14 = OpLabel +%15 = OpPhi %int %int_0 %13 %16 %17 +; CHECK: OpLoopMerge [[preheader:%\w+]] +OpLoopMerge %25 %17 None +OpBranch %19 +%19 = OpLabel +%20 = OpSLessThan %bool %15 %int_10 +OpBranchConditional %20 %21 %25 +%21 = OpLabel +%22 = OpIEqual %bool %15 %int_5 +OpSelectionMerge %23 None +OpBranchConditional %22 %24 %23 +%24 = OpLabel +OpBranch %25 +%23 = OpLabel +OpBranch %17 +%17 = OpLabel +%16 = OpIAdd %int %15 %int_1 +OpBranch %14 +; Check that we hoisted the code to the preheader +; CHECK: [[preheader]] = OpLabel +; CHECK-NEXT: OpPhi +; CHECK-NEXT: OpPhi +; CHECK-NEXT: OpIAdd +; CHECK-NEXT: OpBranch [[header:%\w+]] +; CHECK: [[header]] = OpLabel +; CHECK-NEXT: OpPhi +; CHECK-NEXT: OpPhi +; 
CHECK: OpLoopMerge +%25 = OpLabel +%26 = OpPhi %int %int_0 %24 %int_0 %19 %27 %28 +%29 = OpPhi %int %int_0 %24 %int_0 %19 %30 %28 +OpLoopMerge %31 %28 None +OpBranch %32 +%32 = OpLabel +%33 = OpSLessThan %bool %29 %int_10 +OpBranchConditional %33 %34 %31 +%34 = OpLabel +%27 = OpIAdd %int %int_1 %int_2 +OpBranch %28 +%28 = OpLabel +%30 = OpIAdd %int %29 %int_1 +OpBranch %25 +%31 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(PassClassTest, HoistWithoutPreheaderAtIdBound) { + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +%void = OpTypeVoid +%4 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%int_5 = OpConstant %int 5 +%main = OpFunction %void None %4 +%13 = OpLabel +OpBranch %14 +%14 = OpLabel +%15 = OpPhi %int %int_0 %13 %16 %17 +OpLoopMerge %25 %17 None +OpBranch %19 +%19 = OpLabel +%20 = OpSLessThan %bool %15 %int_10 +OpBranchConditional %20 %21 %25 +%21 = OpLabel +%22 = OpIEqual %bool %15 %int_5 +OpSelectionMerge %23 None +OpBranchConditional %22 %24 %23 +%24 = OpLabel +OpBranch %25 +%23 = OpLabel +OpBranch %17 +%17 = OpLabel +%16 = OpIAdd %int %15 %int_1 +OpBranch %14 +%25 = OpLabel +%26 = OpPhi %int %int_0 %24 %int_0 %19 %27 %28 +%29 = OpPhi %int %int_0 %24 %int_0 %19 %30 %28 +OpLoopMerge %31 %28 None +OpBranch %32 +%32 = OpLabel +%33 = OpSLessThan %bool %29 %int_10 +OpBranchConditional %33 %34 %31 +%34 = OpLabel +%27 = OpIAdd %int %int_1 %int_2 +OpBranch %28 +%28 = OpLabel +%30 = OpIAdd %int %29 %int_1 +OpBranch %25 +%31 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + 
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + uint32_t current_bound = context->module()->id_bound(); + context->set_max_id_bound(current_bound); + + auto pass = MakeUnique(); + auto result = pass->Run(context.get()); + EXPECT_EQ(result, Pass::Status::Failure); + + std::vector binary; + context->module()->ToBinary(&binary, false); + std::string optimized_asm; + SpirvTools tools_(SPV_ENV_UNIVERSAL_1_1); + tools_.Disassemble(binary, &optimized_asm); + std::cout << optimized_asm << std::endl; +} +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/lcssa.cpp b/test/opt/loop_optimizations/lcssa.cpp new file mode 100644 index 000000000..ace6ce196 --- /dev/null +++ b/test/opt/loop_optimizations/lcssa.cpp @@ -0,0 +1,607 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "source/opt/build_module.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "test/opt//assembly_builder.h" +#include "test/opt/function_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +bool Validate(const std::vector& bin) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {bin.data(), bin.size()}; + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + return error == 0; +} + +void Match(const std::string& original, IRContext* context, + bool do_validation = true) { + std::vector bin; + context->module()->ToBinary(&bin, true); + if (do_validation) { + EXPECT_TRUE(Validate(bin)); + } + std::string assembly; + SpirvTools tools(SPV_ENV_UNIVERSAL_1_2); + EXPECT_TRUE( + tools.Disassemble(bin, &assembly, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER)) + << "Disassembling failed for shader:\n" + << assembly << std::endl; + auto match_result = effcee::Match(assembly, original); + EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) + << match_result.message() << "\nChecking result:\n" + << assembly; +} + +using LCSSATest = ::testing::Test; + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + for (; i < 10; i++) { + } + if (i != 0) { + i = 1; + } +} +*/ +TEST_F(LCSSATest, SimpleLCSSA) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] %19 None +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} %30 %20 +; CHECK-NEXT: %27 = OpINotEqual {{%\w+}} [[phi]] %9 + OpCapability Shader + %1 = 
OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %3 "c" + OpDecorate %3 Location 0 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %10 = OpConstant %7 10 + %11 = OpTypeBool + %12 = OpConstant %7 1 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Output %14 + %3 = OpVariable %15 Output + %2 = OpFunction %5 None %6 + %16 = OpLabel + OpBranch %17 + %17 = OpLabel + %30 = OpPhi %7 %9 %16 %25 %19 + OpLoopMerge %18 %19 None + OpBranch %20 + %20 = OpLabel + %22 = OpSLessThan %11 %30 %10 + OpBranchConditional %22 %23 %18 + %23 = OpLabel + OpBranch %19 + %19 = OpLabel + %25 = OpIAdd %7 %30 %12 + OpBranch %17 + %18 = OpLabel + %27 = OpINotEqual %11 %30 %9 + OpSelectionMerge %28 None + OpBranchConditional %27 %29 %28 + %29 = OpLabel + OpBranch %28 + %28 = OpLabel + %31 = OpPhi %7 %30 %18 %12 %29 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + Loop* loop = ld[17]; + EXPECT_FALSE(loop->IsLCSSA()); + LoopUtils Util(context.get(), loop); + Util.MakeLoopClosedSSA(); + EXPECT_TRUE(loop->IsLCSSA()); + Match(text, context.get()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + for (; i < 10; i++) { + } + if (i != 0) { + i = 1; + } +} +*/ +// Same test as above, but should reuse an existing phi. 
+TEST_F(LCSSATest, PhiReuseLCSSA) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] %19 None +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} %30 %20 +; CHECK-NEXT: %27 = OpINotEqual {{%\w+}} [[phi]] %9 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %3 "c" + OpDecorate %3 Location 0 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %10 = OpConstant %7 10 + %11 = OpTypeBool + %12 = OpConstant %7 1 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Output %14 + %3 = OpVariable %15 Output + %2 = OpFunction %5 None %6 + %16 = OpLabel + OpBranch %17 + %17 = OpLabel + %30 = OpPhi %7 %9 %16 %25 %19 + OpLoopMerge %18 %19 None + OpBranch %20 + %20 = OpLabel + %22 = OpSLessThan %11 %30 %10 + OpBranchConditional %22 %23 %18 + %23 = OpLabel + OpBranch %19 + %19 = OpLabel + %25 = OpIAdd %7 %30 %12 + OpBranch %17 + %18 = OpLabel + %32 = OpPhi %7 %30 %20 + %27 = OpINotEqual %11 %30 %9 + OpSelectionMerge %28 None + OpBranchConditional %27 %29 %28 + %29 = OpLabel + OpBranch %28 + %28 = OpLabel + %31 = OpPhi %7 %30 %18 %12 %29 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + Loop* loop = ld[17]; + EXPECT_FALSE(loop->IsLCSSA()); + LoopUtils Util(context.get(), loop); + Util.MakeLoopClosedSSA(); + EXPECT_TRUE(loop->IsLCSSA()); + Match(text, context.get()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 
core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + int j = 0; + for (; i < 10; i++) {} + for (; j < 10; j++) {} + if (j != 0) { + i = 1; + } +} +*/ +TEST_F(LCSSATest, DualLoopLCSSA) { + const std::string text = R"( +; CHECK: %20 = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi %6 %17 %21 +; CHECK: %33 = OpLabel +; CHECK-NEXT: {{%\w+}} = OpPhi {{%\w+}} [[phi]] %28 %11 %34 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %3 "c" + OpDecorate %3 Location 0 + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpConstant %6 0 + %9 = OpConstant %6 10 + %10 = OpTypeBool + %11 = OpConstant %6 1 + %12 = OpTypeFloat 32 + %13 = OpTypeVector %12 4 + %14 = OpTypePointer Output %13 + %3 = OpVariable %14 Output + %2 = OpFunction %4 None %5 + %15 = OpLabel + OpBranch %16 + %16 = OpLabel + %17 = OpPhi %6 %8 %15 %18 %19 + OpLoopMerge %20 %19 None + OpBranch %21 + %21 = OpLabel + %22 = OpSLessThan %10 %17 %9 + OpBranchConditional %22 %23 %20 + %23 = OpLabel + OpBranch %19 + %19 = OpLabel + %18 = OpIAdd %6 %17 %11 + OpBranch %16 + %20 = OpLabel + OpBranch %24 + %24 = OpLabel + %25 = OpPhi %6 %8 %20 %26 %27 + OpLoopMerge %28 %27 None + OpBranch %29 + %29 = OpLabel + %30 = OpSLessThan %10 %25 %9 + OpBranchConditional %30 %31 %28 + %31 = OpLabel + OpBranch %27 + %27 = OpLabel + %26 = OpIAdd %6 %25 %11 + OpBranch %24 + %28 = OpLabel + %32 = OpINotEqual %10 %25 %8 + OpSelectionMerge %33 None + OpBranchConditional %32 %34 %33 + %34 = OpLabel + OpBranch %33 + %33 = OpLabel + %35 = OpPhi %6 %17 %28 %11 %34 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for 
shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + Loop* loop = ld[16]; + EXPECT_FALSE(loop->IsLCSSA()); + LoopUtils Util(context.get(), loop); + Util.MakeLoopClosedSSA(); + EXPECT_TRUE(loop->IsLCSSA()); + Match(text, context.get()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + if (i != 0) { + for (; i < 10; i++) {} + } + if (i != 0) { + i = 1; + } +} +*/ +TEST_F(LCSSATest, PhiUserLCSSA) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] %22 None +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} %20 %24 +; CHECK: %17 = OpLabel +; CHECK-NEXT: {{%\w+}} = OpPhi {{%\w+}} %8 %15 [[phi]] %23 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %3 "c" + OpDecorate %3 Location 0 + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpConstant %6 0 + %9 = OpTypeBool + %10 = OpConstant %6 10 + %11 = OpConstant %6 1 + %12 = OpTypeFloat 32 + %13 = OpTypeVector %12 4 + %14 = OpTypePointer Output %13 + %3 = OpVariable %14 Output + %2 = OpFunction %4 None %5 + %15 = OpLabel + %16 = OpINotEqual %9 %8 %8 + OpSelectionMerge %17 None + OpBranchConditional %16 %18 %17 + %18 = OpLabel + OpBranch %19 + %19 = OpLabel + %20 = OpPhi %6 %8 %18 %21 %22 + OpLoopMerge %23 %22 None + OpBranch %24 + %24 = OpLabel + %25 = OpSLessThan %9 %20 %10 + OpBranchConditional %25 %26 %23 + %26 = OpLabel + OpBranch %22 + %22 = OpLabel + %21 = OpIAdd %6 %20 %11 + OpBranch %19 + %23 = OpLabel + OpBranch %17 + %17 = OpLabel + %27 = OpPhi %6 %8 %15 %20 %23 + %28 = OpINotEqual %9 %27 %8 + OpSelectionMerge %29 None + OpBranchConditional %28 %30 %29 + %30 = OpLabel + 
OpBranch %29 + %29 = OpLabel + %31 = OpPhi %6 %27 %17 %11 %30 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + Loop* loop = ld[19]; + EXPECT_FALSE(loop->IsLCSSA()); + LoopUtils Util(context.get(), loop); + Util.MakeLoopClosedSSA(); + EXPECT_TRUE(loop->IsLCSSA()); + Match(text, context.get()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +void main() { + int i = 0; + if (i != 0) { + for (; i < 10; i++) { + if (i > 5) break; + } + } + if (i != 0) { + i = 1; + } +} +*/ +TEST_F(LCSSATest, LCSSAWithBreak) { + const std::string text = R"( +; CHECK: OpLoopMerge [[merge:%\w+]] %19 None +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} %17 %21 %17 %26 +; CHECK: %14 = OpLabel +; CHECK-NEXT: {{%\w+}} = OpPhi {{%\w+}} %7 %12 [[phi]] [[merge]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeInt 32 1 + %6 = OpTypePointer Function %5 + %7 = OpConstant %5 0 + %8 = OpTypeBool + %9 = OpConstant %5 10 + %10 = OpConstant %5 5 + %11 = OpConstant %5 1 + %2 = OpFunction %3 None %4 + %12 = OpLabel + %13 = OpINotEqual %8 %7 %7 + OpSelectionMerge %14 None + OpBranchConditional %13 %15 %14 + %15 = OpLabel + OpBranch %16 + %16 = OpLabel + %17 = OpPhi %5 %7 %15 %18 %19 + OpLoopMerge %20 %19 None + OpBranch %21 + %21 = OpLabel + %22 = OpSLessThan %8 %17 %9 + OpBranchConditional %22 %23 %20 + %23 = OpLabel + %24 = OpSGreaterThan %8 %17 %10 + OpSelectionMerge %25 None + 
OpBranchConditional %24 %26 %25 + %26 = OpLabel + OpBranch %20 + %25 = OpLabel + OpBranch %19 + %19 = OpLabel + %18 = OpIAdd %5 %17 %11 + OpBranch %16 + %20 = OpLabel + OpBranch %14 + %14 = OpLabel + %27 = OpPhi %5 %7 %12 %17 %20 + %28 = OpINotEqual %8 %27 %7 + OpSelectionMerge %29 None + OpBranchConditional %28 %30 %29 + %30 = OpLabel + OpBranch %29 + %29 = OpLabel + %31 = OpPhi %5 %27 %14 %11 %30 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + Loop* loop = ld[19]; + EXPECT_FALSE(loop->IsLCSSA()); + LoopUtils Util(context.get(), loop); + Util.MakeLoopClosedSSA(); + EXPECT_TRUE(loop->IsLCSSA()); + Match(text, context.get()); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +void main() { + int i = 0; + for (; i < 10; i++) {} + for (int j = i; j < 10;) { j = i + j; } +} +*/ +TEST_F(LCSSATest, LCSSAUseInNonEligiblePhi) { + const std::string text = R"( +; CHECK: %12 = OpLabel +; CHECK-NEXT: [[def_to_close:%\w+]] = OpPhi {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: [[closing_phi:%\w+]] = OpPhi {{%\w+}} [[def_to_close]] %17 +; CHECK: %16 = OpLabel +; CHECK-NEXT: [[use_in_phi:%\w+]] = OpPhi {{%\w+}} %21 %22 [[closing_phi]] [[merge]] +; CHECK: OpIAdd {{%\w+}} [[closing_phi]] [[use_in_phi]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeInt 32 1 + %6 = OpTypePointer 
Function %5 + %7 = OpConstant %5 0 + %8 = OpConstant %5 10 + %9 = OpTypeBool + %10 = OpConstant %5 1 + %2 = OpFunction %3 None %4 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + %13 = OpPhi %5 %7 %11 %14 %15 + OpLoopMerge %16 %15 None + OpBranch %17 + %17 = OpLabel + %18 = OpSLessThan %9 %13 %8 + OpBranchConditional %18 %19 %16 + %19 = OpLabel + OpBranch %15 + %15 = OpLabel + %14 = OpIAdd %5 %13 %10 + OpBranch %12 + %16 = OpLabel + %20 = OpPhi %5 %13 %17 %21 %22 + OpLoopMerge %23 %22 None + OpBranch %24 + %24 = OpLabel + %25 = OpSLessThan %9 %20 %8 + OpBranchConditional %25 %26 %23 + %26 = OpLabel + %21 = OpIAdd %5 %13 %20 + OpBranch %22 + %22 = OpLabel + OpBranch %16 + %23 = OpLabel + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + Loop* loop = ld[12]; + EXPECT_FALSE(loop->IsLCSSA()); + LoopUtils Util(context.get(), loop); + Util.MakeLoopClosedSSA(); + EXPECT_TRUE(loop->IsLCSSA()); + Match(text, context.get()); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/loop_descriptions.cpp b/test/opt/loop_optimizations/loop_descriptions.cpp new file mode 100644 index 000000000..91dbdc6b5 --- /dev/null +++ b/test/opt/loop_optimizations/loop_descriptions.cpp @@ -0,0 +1,384 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + for(; i < 10; ++i) { + } +} +*/ +TEST_F(PassClassTest, BasicVisitFromEntryPoint) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %5 "i" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpConstant %8 1 + %14 = OpTypeFloat 32 + %15 = OpTypeVector %14 4 + %16 = OpTypePointer Output %15 + %3 = OpVariable %16 Output + %2 = OpFunction %6 None %7 + %17 = OpLabel + %5 = OpVariable %9 Function + OpStore %5 %10 + OpBranch %18 + %18 = OpLabel + OpLoopMerge %19 %20 None + OpBranch %21 + %21 = OpLabel + %22 = OpLoad %8 %5 + %23 = OpSLessThan %12 %22 %11 + OpBranchConditional %23 %24 %19 + %24 = OpLabel + OpBranch %20 + %20 = OpLabel + %25 = OpLoad %8 %5 + %26 = 
OpIAdd %8 %25 %13 + OpStore %5 %26 + OpBranch %18 + %19 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + Loop& loop = ld.GetLoopByIndex(0); + EXPECT_EQ(loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 18)); + EXPECT_EQ(loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 20)); + EXPECT_EQ(loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 19)); + + EXPECT_FALSE(loop.HasNestedLoops()); + EXPECT_FALSE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 1u); +} + +/* +Generated from the following GLSL: +#version 330 core +layout(location = 0) out vec4 c; +void main() { + for(int i = 0; i < 10; ++i) {} + for(int i = 0; i < 10; ++i) {} +} + +But it was "hacked" to make the first loop merge block the second loop header. 
+*/ +TEST_F(PassClassTest, LoopWithNoPreHeader) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %5 "i" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpConstant %8 1 + %14 = OpTypeFloat 32 + %15 = OpTypeVector %14 4 + %16 = OpTypePointer Output %15 + %3 = OpVariable %16 Output + %2 = OpFunction %6 None %7 + %17 = OpLabel + %4 = OpVariable %9 Function + %5 = OpVariable %9 Function + OpStore %4 %10 + OpStore %5 %10 + OpBranch %18 + %18 = OpLabel + OpLoopMerge %27 %20 None + OpBranch %21 + %21 = OpLabel + %22 = OpLoad %8 %4 + %23 = OpSLessThan %12 %22 %11 + OpBranchConditional %23 %24 %27 + %24 = OpLabel + OpBranch %20 + %20 = OpLabel + %25 = OpLoad %8 %4 + %26 = OpIAdd %8 %25 %13 + OpStore %4 %26 + OpBranch %18 + %27 = OpLabel + OpLoopMerge %28 %29 None + OpBranch %30 + %30 = OpLabel + %31 = OpLoad %8 %5 + %32 = OpSLessThan %12 %31 %11 + OpBranchConditional %32 %33 %28 + %33 = OpLabel + OpBranch %29 + %29 = OpLabel + %34 = OpLoad %8 %5 + %35 = OpIAdd %8 %34 %13 + OpStore %5 %35 + OpBranch %27 + %28 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + EXPECT_EQ(ld.NumLoops(), 2u); + + Loop* loop = ld[27]; + EXPECT_EQ(loop->GetPreHeaderBlock(), nullptr); + EXPECT_NE(loop->GetOrCreatePreHeaderBlock(), nullptr); +} + 
+/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +in vec4 c; +void main() { + int i = 0; + bool cond = c[0] == 0; + for (; i < 10; i++) { + if (cond) { + return; + } + else { + return; + } + } + bool cond2 = i == 9; +} +*/ +TEST_F(PassClassTest, NoLoop) { + const std::string text = R"(; SPIR-V +; Version: 1.0 +; Generator: Khronos Glslang Reference Front End; 3 +; Bound: 47 +; Schema: 0 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %16 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 330 + OpName %4 "main" + OpName %16 "c" + OpDecorate %16 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %10 = OpTypeBool + %11 = OpTypePointer Function %10 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Input %14 + %16 = OpVariable %15 Input + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 0 + %19 = OpTypePointer Input %13 + %22 = OpConstant %13 0 + %30 = OpConstant %6 10 + %39 = OpConstant %6 1 + %46 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %20 = OpAccessChain %19 %16 %18 + %21 = OpLoad %13 %20 + %23 = OpFOrdEqual %10 %21 %22 + OpBranch %24 + %24 = OpLabel + %45 = OpPhi %6 %9 %5 %40 %27 + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %31 = OpSLessThan %10 %45 %30 + OpBranchConditional %31 %25 %26 + %25 = OpLabel + OpSelectionMerge %34 None + OpBranchConditional %23 %33 %36 + %33 = OpLabel + OpReturn + %36 = OpLabel + OpReturn + %34 = OpLabel + OpBranch %27 + %27 = OpLabel + %40 = OpIAdd %6 %46 %39 + OpBranch %24 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const 
Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor ld{context.get(), f}; + + EXPECT_EQ(ld.NumLoops(), 0u); +} + +/* +Generated from the following GLSL with latch block artificially inserted to be +separate from continue. +#version 430 +void main(void) { + float x[10]; + for (int i = 0; i < 10; ++i) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, LoopLatchNotContinue) { + const std::string text = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "x" + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %10 = OpConstant %7 10 + %11 = OpTypeBool + %12 = OpTypeFloat 32 + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 10 + %15 = OpTypeArray %12 %14 + %16 = OpTypePointer Function %15 + %17 = OpTypePointer Function %12 + %18 = OpConstant %7 1 + %2 = OpFunction %5 None %6 + %19 = OpLabel + %3 = OpVariable %8 Function + %4 = OpVariable %16 Function + OpStore %3 %9 + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %7 %9 %19 %22 %30 + OpLoopMerge %24 %23 None + OpBranch %25 + %25 = OpLabel + %26 = OpSLessThan %11 %21 %10 + OpBranchConditional %26 %27 %24 + %27 = OpLabel + %28 = OpConvertSToF %12 %21 + %29 = OpAccessChain %17 %4 %21 + OpStore %29 %28 + OpBranch %23 + %23 = OpLabel + %22 = OpIAdd %7 %21 %18 + OpStore %3 %22 + OpBranch %30 + %30 = OpLabel + OpBranch %20 + %24 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + EXPECT_EQ(ld.NumLoops(), 1u); + + Loop& loop =
ld.GetLoopByIndex(0u); + + EXPECT_NE(loop.GetLatchBlock(), loop.GetContinueBlock()); + + EXPECT_EQ(loop.GetContinueBlock()->id(), 23u); + EXPECT_EQ(loop.GetLatchBlock()->id(), 30u); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/loop_fission.cpp b/test/opt/loop_optimizations/loop_fission.cpp new file mode 100644 index 000000000..e513f4253 --- /dev/null +++ b/test/opt/loop_optimizations/loop_fission.cpp @@ -0,0 +1,3491 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_fission.h" +#include "source/opt/loop_unroller.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using FissionClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + +#version 430 + +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i] = A[i]; + } +} + +Result should be equivalent to: + +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + } + + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + } +} +*/ +TEST_F(FissionClassTest, SimpleFission) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 
%4 %22 +OpStore %31 %30 +%32 = OpAccessChain %18 %4 %22 +%33 = OpLoad %13 %32 +%34 = OpAccessChain %18 %5 %22 +OpStore %34 %33 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpBranch %35 +%35 = OpLabel +%36 = OpPhi %8 %10 %20 %47 %46 +OpLoopMerge %48 %46 None +OpBranch %37 +%37 = OpLabel +%38 = OpSLessThan %12 %36 %11 +OpBranchConditional %38 %39 %48 +%39 = OpLabel +%40 = OpAccessChain %18 %5 %36 +%41 = OpLoad %13 %40 +%42 = OpAccessChain %18 %4 %36 +OpStore %42 %41 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %8 %36 %19 +OpBranch %35 +%48 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %48 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%32 = OpAccessChain %18 %4 %22 +%33 = OpLoad %13 %32 +%34 = OpAccessChain %18 %5 %22 +OpStore %34 %33 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, 
module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); + + // Check that the loop will NOT be split when provided with a pass-through + // register pressure functor which just returns false. + SinglePassRunAndCheck( + source, source, true, + [](const RegisterLiveness::RegionRegisterLiveness&) { return false; }); +} + +/* +Generated from the following GLSL + +#version 430 + +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i] = A[i+1]; + } +} + +This loop should not be split, as the i+1 dependence would be broken by +splitting the loop. +*/ + +TEST_F(FissionClassTest, FissionInterdependency) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpIAdd %8 %22 %19 +%33 = OpAccessChain %18 %4 %32 +%34 = OpLoad %13 %33 +%35 = OpAccessChain %18 %5 %22 
+OpStore %35 %34 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +Generated from the following GLSL + +#version 430 + +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i+1] = A[i]; + } +} + + +This should not be split as the load B[i] is dependent on the store B[i+1] +*/ +TEST_F(FissionClassTest, FissionInterdependency2) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpIAdd %8 %22 %19 +%33 = OpAccessChain %18 
%4 %22 +%34 = OpLoad %13 %33 +%35 = OpAccessChain %18 %5 %32 +OpStore %35 %34 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + float C[10] + float D[10] + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i] = A[i]; + C[i] = D[i]; + D[i] = C[i]; + } +} + +This should be split into the equivalent of: + + for (int i = 0; i < 10; i++) { + A[i] = B[i]; + B[i] = A[i]; + } + for (int i = 0; i < 10; i++) { + C[i] = D[i]; + D[i] = C[i]; + } + +We then check that the loop is broken into four for loops like so, if the pass +is run twice: + for (int i = 0; i < 10; i++) + A[i] = B[i]; + for (int i = 0; i < 10; i++) + B[i] = A[i]; + for (int i = 0; i < 10; i++) + C[i] = D[i]; + for (int i = 0; i < 10; i++) + D[i] = C[i]; + +*/ + +TEST_F(FissionClassTest, FissionMultipleLoadStores) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "A" + OpName %5 "B" + OpName %6 "C" + OpName %7 "D" + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpConstant %10 10 + %14 = OpTypeBool + %15 = OpTypeFloat 32 + %16 = OpTypeInt 32 0 + %17 = OpConstant %16 10 + %18 = OpTypeArray %15 %17 + %19 = OpTypePointer Function %18 + %20 = 
OpTypePointer Function %15 + %21 = OpConstant %10 1 + %2 = OpFunction %8 None %9 + %22 = OpLabel + %3 = OpVariable %11 Function + %4 = OpVariable %19 Function + %5 = OpVariable %19 Function + %6 = OpVariable %19 Function + %7 = OpVariable %19 Function + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %10 %12 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 + %28 = OpLabel + %29 = OpSLessThan %14 %24 %13 + OpBranchConditional %29 %30 %27 + %30 = OpLabel + %31 = OpAccessChain %20 %5 %24 + %32 = OpLoad %15 %31 + %33 = OpAccessChain %20 %4 %24 + OpStore %33 %32 + %34 = OpAccessChain %20 %4 %24 + %35 = OpLoad %15 %34 + %36 = OpAccessChain %20 %5 %24 + OpStore %36 %35 + %37 = OpAccessChain %20 %7 %24 + %38 = OpLoad %15 %37 + %39 = OpAccessChain %20 %6 %24 + OpStore %39 %38 + %40 = OpAccessChain %20 %6 %24 + %41 = OpLoad %15 %40 + %42 = OpAccessChain %20 %7 %24 + OpStore %42 %41 + OpBranch %26 + %26 = OpLabel + %25 = OpIAdd %10 %24 %21 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +OpName %6 "C" +OpName %7 "D" +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeInt 32 1 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpConstant %10 10 +%14 = OpTypeBool +%15 = OpTypeFloat 32 +%16 = OpTypeInt 32 0 +%17 = OpConstant %16 10 +%18 = OpTypeArray %15 %17 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %15 +%21 = OpConstant %10 1 +%2 = OpFunction %8 None %9 +%22 = OpLabel +%3 = OpVariable %11 Function +%4 = OpVariable %19 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +%7 = OpVariable %19 Function +OpBranch %43 +%43 = OpLabel +%44 = OpPhi %10 %12 %22 %61 %60 +OpLoopMerge %62 %60 None +OpBranch %45 +%45 = OpLabel +%46 = OpSLessThan %14 
%44 %13 +OpBranchConditional %46 %47 %62 +%47 = OpLabel +%48 = OpAccessChain %20 %5 %44 +%49 = OpLoad %15 %48 +%50 = OpAccessChain %20 %4 %44 +OpStore %50 %49 +%51 = OpAccessChain %20 %4 %44 +%52 = OpLoad %15 %51 +%53 = OpAccessChain %20 %5 %44 +OpStore %53 %52 +OpBranch %60 +%60 = OpLabel +%61 = OpIAdd %10 %44 %21 +OpBranch %43 +%62 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %10 %12 %62 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %14 %24 %13 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%37 = OpAccessChain %20 %7 %24 +%38 = OpLoad %15 %37 +%39 = OpAccessChain %20 %6 %24 +OpStore %39 %38 +%40 = OpAccessChain %20 %6 %24 +%41 = OpLoad %15 %40 +%42 = OpAccessChain %20 %7 %24 +OpStore %42 %41 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %10 %24 %21 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + +const std::string expected_multiple_passes = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "A" +OpName %5 "B" +OpName %6 "C" +OpName %7 "D" +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeInt 32 1 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpConstant %10 10 +%14 = OpTypeBool +%15 = OpTypeFloat 32 +%16 = OpTypeInt 32 0 +%17 = OpConstant %16 10 +%18 = OpTypeArray %15 %17 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %15 +%21 = OpConstant %10 1 +%2 = OpFunction %8 None %9 +%22 = OpLabel +%3 = OpVariable %11 Function +%4 = OpVariable %19 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +%7 = OpVariable %19 Function +OpBranch %63 +%63 = OpLabel +%64 = OpPhi %10 %12 %22 %75 %74 +OpLoopMerge %76 %74 None +OpBranch %65 +%65 = OpLabel +%66 = OpSLessThan %14 %64 %13 +OpBranchConditional %66 %67 %76 +%67 = OpLabel +%68 = OpAccessChain %20 %5 %64 +%69 = OpLoad %15 %68 +%70 = OpAccessChain 
%20 %4 %64 +OpStore %70 %69 +OpBranch %74 +%74 = OpLabel +%75 = OpIAdd %10 %64 %21 +OpBranch %63 +%76 = OpLabel +OpBranch %43 +%43 = OpLabel +%44 = OpPhi %10 %12 %76 %61 %60 +OpLoopMerge %62 %60 None +OpBranch %45 +%45 = OpLabel +%46 = OpSLessThan %14 %44 %13 +OpBranchConditional %46 %47 %62 +%47 = OpLabel +%51 = OpAccessChain %20 %4 %44 +%52 = OpLoad %15 %51 +%53 = OpAccessChain %20 %5 %44 +OpStore %53 %52 +OpBranch %60 +%60 = OpLabel +%61 = OpIAdd %10 %44 %21 +OpBranch %43 +%62 = OpLabel +OpBranch %77 +%77 = OpLabel +%78 = OpPhi %10 %12 %62 %89 %88 +OpLoopMerge %90 %88 None +OpBranch %79 +%79 = OpLabel +%80 = OpSLessThan %14 %78 %13 +OpBranchConditional %80 %81 %90 +%81 = OpLabel +%82 = OpAccessChain %20 %7 %78 +%83 = OpLoad %15 %82 +%84 = OpAccessChain %20 %6 %78 +OpStore %84 %83 +OpBranch %88 +%88 = OpLabel +%89 = OpIAdd %10 %78 %21 +OpBranch %77 +%90 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %10 %12 %90 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %14 %24 %13 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%40 = OpAccessChain %20 %6 %24 +%41 = OpLoad %15 %40 +%42 = OpAccessChain %20 %7 %24 +OpStore %42 %41 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %10 %24 %21 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on +std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); +Module* module = context->module(); +EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + +SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); +SinglePassRunAndCheck(source, expected, true); + +// By passing 1 as argument we are using the constructor which makes the +// critera to split the loop be if the registers in the loop exceede 1. By +// using this constructor we are also enabling multiple passes (disabled by +// default). 
+SinglePassRunAndCheck(source, expected_multiple_passes, true, + 1); +} + +/* +#version 430 +void main(void) { + int accumulator = 0; + float X[10]; + float Y[10]; + + for (int i = 0; i < 10; i++) { + X[i] = Y[i]; + Y[i] = X[i]; + accumulator += i; + } +} + +This should be split into the equivalent of: + +#version 430 +void main(void) { + int accumulator = 0; + float X[10]; + float Y[10]; + + for (int i = 0; i < 10; i++) { + X[i] = Y[i]; + } + for (int i = 0; i < 10; i++) { + Y[i] = X[i]; + accumulator += i; + } +} +*/ +TEST_F(FissionClassTest, FissionWithAccumulator) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "accumulator" + OpName %4 "i" + OpName %5 "X" + OpName %6 "Y" + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpTypeFloat 32 + %15 = OpTypeInt 32 0 + %16 = OpConstant %15 10 + %17 = OpTypeArray %14 %16 + %18 = OpTypePointer Function %17 + %19 = OpTypePointer Function %14 + %20 = OpConstant %9 1 + %2 = OpFunction %7 None %8 + %21 = OpLabel + %3 = OpVariable %10 Function + %4 = OpVariable %10 Function + %5 = OpVariable %18 Function + %6 = OpVariable %18 Function + OpBranch %22 + %22 = OpLabel + %23 = OpPhi %9 %11 %21 %24 %25 + %26 = OpPhi %9 %11 %21 %27 %25 + OpLoopMerge %28 %25 None + OpBranch %29 + %29 = OpLabel + %30 = OpSLessThan %13 %26 %12 + OpBranchConditional %30 %31 %28 + %31 = OpLabel + %32 = OpAccessChain %19 %6 %26 + %33 = OpLoad %14 %32 + %34 = OpAccessChain %19 %5 %26 + OpStore %34 %33 + %35 = OpAccessChain %19 %5 %26 + %36 = OpLoad %14 %35 + %37 = OpAccessChain %19 %6 %26 + OpStore %37 %36 + %24 = OpIAdd %9 %23 %26 + OpBranch %25 + %25 = OpLabel + %27 = OpIAdd %9 %26 
%20 + OpBranch %22 + %28 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "accumulator" +OpName %4 "i" +OpName %5 "X" +OpName %6 "Y" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypePointer Function %17 +%19 = OpTypePointer Function %14 +%20 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%21 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %10 Function +%5 = OpVariable %18 Function +%6 = OpVariable %18 Function +OpBranch %38 +%38 = OpLabel +%40 = OpPhi %9 %11 %21 %52 %51 +OpLoopMerge %53 %51 None +OpBranch %41 +%41 = OpLabel +%42 = OpSLessThan %13 %40 %12 +OpBranchConditional %42 %43 %53 +%43 = OpLabel +%44 = OpAccessChain %19 %6 %40 +%45 = OpLoad %14 %44 +%46 = OpAccessChain %19 %5 %40 +OpStore %46 %45 +OpBranch %51 +%51 = OpLabel +%52 = OpIAdd %9 %40 %20 +OpBranch %38 +%53 = OpLabel +OpBranch %22 +%22 = OpLabel +%23 = OpPhi %9 %11 %53 %24 %25 +%26 = OpPhi %9 %11 %53 %27 %25 +OpLoopMerge %28 %25 None +OpBranch %29 +%29 = OpLabel +%30 = OpSLessThan %13 %26 %12 +OpBranchConditional %30 %31 %28 +%31 = OpLabel +%35 = OpAccessChain %19 %5 %26 +%36 = OpLoad %14 %35 +%37 = OpAccessChain %19 %6 %26 +OpStore %37 %36 +%24 = OpIAdd %9 %23 %26 +OpBranch %25 +%25 = OpLabel +%27 = OpIAdd %9 %26 %20 +OpBranch %22 +%28 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << 
std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +Generated from the following glsl: + +#version 430 +layout(location=0) out float x; +layout(location=1) out float y; + +void main(void) { + float accumulator_1 = 0; + float accumulator_2 = 0; + for (int i = 0; i < 10; i++) { + accumulator_1 += i; + accumulator_2 += i; + } + + x = accumulator_1; + y = accumulator_2; +} + +Should be split into equivalent of: + +void main(void) { + float accumulator_1 = 0; + float accumulator_2 = 0; + for (int i = 0; i < 10; i++) { + accumulator_1 += i; + } + + for (int i = 0; i < 10; i++) { + accumulator_2 += i; + } + x = accumulator_1; + y = accumulator_2; +} + +*/ +TEST_F(FissionClassTest, FissionWithPhisUsedOutwithLoop) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 %4 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %5 "accumulator_1" + OpName %6 "accumulator_2" + OpName %7 "i" + OpName %3 "x" + OpName %4 "y" + OpDecorate %3 Location 0 + OpDecorate %4 Location 1 + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeFloat 32 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpTypeInt 32 1 + %14 = OpTypePointer Function %13 + %15 = OpConstant %13 0 + %16 = OpConstant %13 10 + %17 = OpTypeBool + %18 = OpConstant %13 1 + %19 = OpTypePointer Output %10 + %3 = OpVariable %19 Output + %4 = OpVariable %19 Output + %2 = OpFunction %8 None %9 + %20 = OpLabel + %5 = OpVariable %11 Function + %6 = OpVariable %11 Function + %7 = OpVariable %14 Function + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %10 %12 %20 %23 %24 + %25 = OpPhi %10 %12 %20 %26 %24 + %27 = OpPhi %13 %15 %20 %28 %24 + OpLoopMerge %29 %24 None + OpBranch %30 + %30 = OpLabel + %31 = OpSLessThan %17 %27 %16 + 
OpBranchConditional %31 %32 %29 + %32 = OpLabel + %33 = OpConvertSToF %10 %27 + %26 = OpFAdd %10 %25 %33 + %34 = OpConvertSToF %10 %27 + %23 = OpFAdd %10 %22 %34 + OpBranch %24 + %24 = OpLabel + %28 = OpIAdd %13 %27 %18 + OpStore %7 %28 + OpBranch %21 + %29 = OpLabel + OpStore %3 %25 + OpStore %4 %22 + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 %4 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %5 "accumulator_1" +OpName %6 "accumulator_2" +OpName %7 "i" +OpName %3 "x" +OpName %4 "y" +OpDecorate %3 Location 0 +OpDecorate %4 Location 1 +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeFloat 32 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpTypeInt 32 1 +%14 = OpTypePointer Function %13 +%15 = OpConstant %13 0 +%16 = OpConstant %13 10 +%17 = OpTypeBool +%18 = OpConstant %13 1 +%19 = OpTypePointer Output %10 +%3 = OpVariable %19 Output +%4 = OpVariable %19 Output +%2 = OpFunction %8 None %9 +%20 = OpLabel +%5 = OpVariable %11 Function +%6 = OpVariable %11 Function +%7 = OpVariable %14 Function +OpBranch %35 +%35 = OpLabel +%37 = OpPhi %10 %12 %20 %43 %46 +%38 = OpPhi %13 %15 %20 %47 %46 +OpLoopMerge %48 %46 None +OpBranch %39 +%39 = OpLabel +%40 = OpSLessThan %17 %38 %16 +OpBranchConditional %40 %41 %48 +%41 = OpLabel +%42 = OpConvertSToF %10 %38 +%43 = OpFAdd %10 %37 %42 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %13 %38 %18 +OpStore %7 %47 +OpBranch %35 +%48 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %10 %12 %48 %23 %24 +%27 = OpPhi %13 %15 %48 %28 %24 +OpLoopMerge %29 %24 None +OpBranch %30 +%30 = OpLabel +%31 = OpSLessThan %17 %27 %16 +OpBranchConditional %31 %32 %29 +%32 = OpLabel +%34 = OpConvertSToF %10 %27 +%23 = OpFAdd %10 %22 %34 +OpBranch %24 +%24 = OpLabel +%28 = OpIAdd %13 %27 %18 +OpStore %7 %28 +OpBranch %21 +%29 = OpLabel +OpStore %3 %37 
+OpStore %4 %22 +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + A[i][j] = B[i][j]; + B[i][j] = A[i][j]; + } + } +} + +Should be split into equivalent of: + +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + A[i][j] = B[i][j]; + } + for (int j = 0; j < 10; j++) { + B[i][j] = A[i][j]; + } + } +} + + +*/ +TEST_F(FissionClassTest, FissionNested) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "j" + OpName %5 "A" + OpName %6 "B" + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpTypeFloat 32 + %15 = OpTypeInt 32 0 + %16 = OpConstant %15 10 + %17 = OpTypeArray %14 %16 + %18 = OpTypeArray %17 %16 + %19 = OpTypePointer Function %18 + %20 = OpTypePointer Function %14 + %21 = OpConstant %9 1 + %2 = OpFunction %7 None %8 + %22 = OpLabel + %3 = OpVariable %10 Function + %4 = OpVariable %10 Function + %5 = OpVariable %19 Function + %6 = OpVariable %19 Function + OpStore %3 %11 + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %9 %11 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 
+ %28 = OpLabel + %29 = OpSLessThan %13 %24 %12 + OpBranchConditional %29 %30 %27 + %30 = OpLabel + OpStore %4 %11 + OpBranch %31 + %31 = OpLabel + %32 = OpPhi %9 %11 %30 %33 %34 + OpLoopMerge %35 %34 None + OpBranch %36 + %36 = OpLabel + %37 = OpSLessThan %13 %32 %12 + OpBranchConditional %37 %38 %35 + %38 = OpLabel + %39 = OpAccessChain %20 %6 %24 %32 + %40 = OpLoad %14 %39 + %41 = OpAccessChain %20 %5 %24 %32 + OpStore %41 %40 + %42 = OpAccessChain %20 %5 %24 %32 + %43 = OpLoad %14 %42 + %44 = OpAccessChain %20 %6 %24 %32 + OpStore %44 %43 + OpBranch %34 + %34 = OpLabel + %33 = OpIAdd %9 %32 %21 + OpStore %4 %33 + OpBranch %31 + %35 = OpLabel + OpBranch %26 + %26 = OpLabel + %25 = OpIAdd %9 %24 %21 + OpStore %3 %25 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "j" +OpName %5 "A" +OpName %6 "B" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypeArray %17 %16 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %14 +%21 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %10 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +OpStore %3 %11 +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %9 %11 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %13 %24 %12 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +OpStore %4 %11 +OpBranch %45 +%45 = OpLabel +%46 = OpPhi %9 %11 %30 %57 %56 +OpLoopMerge %58 %56 None +OpBranch %47 +%47 = OpLabel +%48 = OpSLessThan %13 %46 %12 
+OpBranchConditional %48 %49 %58 +%49 = OpLabel +%50 = OpAccessChain %20 %6 %24 %46 +%51 = OpLoad %14 %50 +%52 = OpAccessChain %20 %5 %24 %46 +OpStore %52 %51 +OpBranch %56 +%56 = OpLabel +%57 = OpIAdd %9 %46 %21 +OpStore %4 %57 +OpBranch %45 +%58 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpPhi %9 %11 %58 %33 %34 +OpLoopMerge %35 %34 None +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %13 %32 %12 +OpBranchConditional %37 %38 %35 +%38 = OpLabel +%42 = OpAccessChain %20 %5 %24 %32 +%43 = OpLoad %14 %42 +%44 = OpAccessChain %20 %6 %24 %32 +OpStore %44 %43 +OpBranch %34 +%34 = OpLabel +%33 = OpIAdd %9 %32 %21 +OpStore %4 %33 +OpBranch %31 +%35 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %9 %24 %21 +OpStore %3 %25 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +#version 430 +void main(void) { + int accumulator = 0; + float A[10]; + float B[10]; + float C[10]; + + for (int i = 0; i < 10; i++) { + int c = C[i]; + A[i] = B[i]; + B[i] = A[i] + c; + } +} + +This loop should not be split as we would have to break the order of the loads +to do so. 
It would be grouped into two sets: + +1 + int c = C[i]; + B[i] = A[i] + c; + +2 + A[i] = B[i]; + +To keep the load C[i] in the same order we would need to put B[i] ahead of that +*/ +TEST_F(FissionClassTest, FissionLoad) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "c" +OpName %5 "C" +OpName %6 "A" +OpName %7 "B" +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeInt 32 1 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpConstant %10 10 +%14 = OpTypeBool +%15 = OpTypeFloat 32 +%16 = OpTypePointer Function %15 +%17 = OpTypeInt 32 0 +%18 = OpConstant %17 10 +%19 = OpTypeArray %15 %18 +%20 = OpTypePointer Function %19 +%21 = OpConstant %10 1 +%2 = OpFunction %8 None %9 +%22 = OpLabel +%3 = OpVariable %11 Function +%4 = OpVariable %16 Function +%5 = OpVariable %20 Function +%6 = OpVariable %20 Function +%7 = OpVariable %20 Function +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %10 %12 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %14 %24 %13 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%31 = OpAccessChain %16 %5 %24 +%32 = OpLoad %15 %31 +OpStore %4 %32 +%33 = OpAccessChain %16 %7 %24 +%34 = OpLoad %15 %33 +%35 = OpAccessChain %16 %6 %24 +OpStore %35 %34 +%36 = OpAccessChain %16 %6 %24 +%37 = OpLoad %15 %36 +%38 = OpFAdd %15 %37 %32 +%39 = OpAccessChain %16 %7 %24 +OpStore %39 %38 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %10 %24 %21 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << 
source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +layout(location=0) flat in int condition; +void main(void) { + float A[10]; + float B[10]; + + for (int i = 0; i < 10; i++) { + if (condition == 1) + A[i] = B[i]; + else + B[i] = A[i]; + } +} + + +When this is split we leave the condition check and control flow inplace and +leave its removal for dead code elimination. + +#version 430 +layout(location=0) flat in int condition; +void main(void) { + float A[10]; + float B[10]; + + for (int i = 0; i < 10; i++) { + if (condition == 1) + A[i] = B[i]; + else + ; + } + for (int i = 0; i < 10; i++) { + if (condition == 1) + ; + else + B[i] = A[i]; + } +} + + +*/ +TEST_F(FissionClassTest, FissionControlFlow) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %4 "i" + OpName %3 "condition" + OpName %5 "A" + OpName %6 "B" + OpDecorate %3 Flat + OpDecorate %3 Location 0 + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpTypePointer Input %9 + %3 = OpVariable %14 Input + %15 = OpConstant %9 1 + %16 = OpTypeFloat 32 + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 10 + %19 = OpTypeArray %16 %18 + %20 = OpTypePointer Function %19 + %21 = OpTypePointer Function %16 + %2 = OpFunction %7 None %8 + %22 = OpLabel + %4 = OpVariable %10 Function + %5 = OpVariable %20 Function + %6 = OpVariable %20 Function + %31 = OpLoad %9 %3 + OpStore %4 %11 + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %9 %11 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 + %28 = OpLabel + %29 = OpSLessThan %13 %24 %12 + 
OpBranchConditional %29 %30 %27 + %30 = OpLabel + %32 = OpIEqual %13 %31 %15 + OpSelectionMerge %33 None + OpBranchConditional %32 %34 %35 + %34 = OpLabel + %36 = OpAccessChain %21 %6 %24 + %37 = OpLoad %16 %36 + %38 = OpAccessChain %21 %5 %24 + OpStore %38 %37 + OpBranch %33 + %35 = OpLabel + %39 = OpAccessChain %21 %5 %24 + %40 = OpLoad %16 %39 + %41 = OpAccessChain %21 %6 %24 + OpStore %41 %40 + OpBranch %33 + %33 = OpLabel + OpBranch %26 + %26 = OpLabel + %25 = OpIAdd %9 %24 %15 + OpStore %4 %25 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %4 "i" +OpName %3 "condition" +OpName %5 "A" +OpName %6 "B" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypePointer Input %9 +%3 = OpVariable %14 Input +%15 = OpConstant %9 1 +%16 = OpTypeFloat 32 +%17 = OpTypeInt 32 0 +%18 = OpConstant %17 10 +%19 = OpTypeArray %16 %18 +%20 = OpTypePointer Function %19 +%21 = OpTypePointer Function %16 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%4 = OpVariable %10 Function +%5 = OpVariable %20 Function +%6 = OpVariable %20 Function +%23 = OpLoad %9 %3 +OpStore %4 %11 +OpBranch %42 +%42 = OpLabel +%43 = OpPhi %9 %11 %22 %58 %57 +OpLoopMerge %59 %57 None +OpBranch %44 +%44 = OpLabel +%45 = OpSLessThan %13 %43 %12 +OpBranchConditional %45 %46 %59 +%46 = OpLabel +%47 = OpIEqual %13 %23 %15 +OpSelectionMerge %56 None +OpBranchConditional %47 %52 %48 +%48 = OpLabel +OpBranch %56 +%52 = OpLabel +%53 = OpAccessChain %21 %6 %43 +%54 = OpLoad %16 %53 +%55 = OpAccessChain %21 %5 %43 +OpStore %55 %54 +OpBranch %56 +%56 = OpLabel +OpBranch %57 +%57 = OpLabel +%58 = OpIAdd %9 %43 
%15 +OpStore %4 %58 +OpBranch %42 +%59 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpPhi %9 %11 %59 %26 %27 +OpLoopMerge %28 %27 None +OpBranch %29 +%29 = OpLabel +%30 = OpSLessThan %13 %25 %12 +OpBranchConditional %30 %31 %28 +%31 = OpLabel +%32 = OpIEqual %13 %23 %15 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %35 +%34 = OpLabel +OpBranch %33 +%35 = OpLabel +%39 = OpAccessChain %21 %5 %25 +%40 = OpLoad %16 %39 +%41 = OpAccessChain %21 %6 %25 +OpStore %41 %40 +OpBranch %33 +%33 = OpLabel +OpBranch %27 +%27 = OpLabel +%26 = OpIAdd %9 %25 %15 +OpStore %4 %26 +OpBranch %24 +%28 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + if (i == 1) + B[i] = A[i]; + else if (i == 2) + A[i] = B[i]; + else + A[i] = 0; + } +} + +After running the pass with multiple splits enabled (via register threshold of +1) we expect the equivalent of: + +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + if (i == 1) + B[i] = A[i]; + else if (i == 2) + else + } + for (int i = 0; i < 10; i++) { + if (i == 1) + else if (i == 2) + A[i] = B[i]; + else + } + for (int i = 0; i < 10; i++) { + if (i == 1) + else if (i == 2) + else + A[i] = 0; + } + +} + +*/ +TEST_F(FissionClassTest, FissionControlFlow2) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + 
OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "B" + OpName %5 "A" + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpConstant %8 1 + %14 = OpTypeFloat 32 + %15 = OpTypeInt 32 0 + %16 = OpConstant %15 10 + %17 = OpTypeArray %14 %16 + %18 = OpTypePointer Function %17 + %19 = OpTypePointer Function %14 + %20 = OpConstant %8 2 + %21 = OpConstant %14 0 + %2 = OpFunction %6 None %7 + %22 = OpLabel + %3 = OpVariable %9 Function + %4 = OpVariable %18 Function + %5 = OpVariable %18 Function + OpStore %3 %10 + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %8 %10 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 + %28 = OpLabel + %29 = OpSLessThan %12 %24 %11 + OpBranchConditional %29 %30 %27 + %30 = OpLabel + %31 = OpIEqual %12 %24 %13 + OpSelectionMerge %32 None + OpBranchConditional %31 %33 %34 + %33 = OpLabel + %35 = OpAccessChain %19 %5 %24 + %36 = OpLoad %14 %35 + %37 = OpAccessChain %19 %4 %24 + OpStore %37 %36 + OpBranch %32 + %34 = OpLabel + %38 = OpIEqual %12 %24 %20 + OpSelectionMerge %39 None + OpBranchConditional %38 %40 %41 + %40 = OpLabel + %42 = OpAccessChain %19 %4 %24 + %43 = OpLoad %14 %42 + %44 = OpAccessChain %19 %5 %24 + OpStore %44 %43 + OpBranch %39 + %41 = OpLabel + %45 = OpAccessChain %19 %5 %24 + OpStore %45 %21 + OpBranch %39 + %39 = OpLabel + OpBranch %32 + %32 = OpLabel + OpBranch %26 + %26 = OpLabel + %25 = OpIAdd %8 %24 %13 + OpStore %3 %25 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = 
OpConstant %8 10 +%12 = OpTypeBool +%13 = OpConstant %8 1 +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypePointer Function %17 +%19 = OpTypePointer Function %14 +%20 = OpConstant %8 2 +%21 = OpConstant %14 0 +%2 = OpFunction %6 None %7 +%22 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %18 Function +%5 = OpVariable %18 Function +OpStore %3 %10 +OpBranch %46 +%46 = OpLabel +%47 = OpPhi %8 %10 %22 %67 %66 +OpLoopMerge %68 %66 None +OpBranch %48 +%48 = OpLabel +%49 = OpSLessThan %12 %47 %11 +OpBranchConditional %49 %50 %68 +%50 = OpLabel +%51 = OpIEqual %12 %47 %13 +OpSelectionMerge %65 None +OpBranchConditional %51 %61 %52 +%52 = OpLabel +%53 = OpIEqual %12 %47 %20 +OpSelectionMerge %60 None +OpBranchConditional %53 %56 %54 +%54 = OpLabel +OpBranch %60 +%56 = OpLabel +OpBranch %60 +%60 = OpLabel +OpBranch %65 +%61 = OpLabel +%62 = OpAccessChain %19 %5 %47 +%63 = OpLoad %14 %62 +%64 = OpAccessChain %19 %4 %47 +OpStore %64 %63 +OpBranch %65 +%65 = OpLabel +OpBranch %66 +%66 = OpLabel +%67 = OpIAdd %8 %47 %13 +OpStore %3 %67 +OpBranch %46 +%68 = OpLabel +OpBranch %69 +%69 = OpLabel +%70 = OpPhi %8 %10 %68 %87 %86 +OpLoopMerge %88 %86 None +OpBranch %71 +%71 = OpLabel +%72 = OpSLessThan %12 %70 %11 +OpBranchConditional %72 %73 %88 +%73 = OpLabel +%74 = OpIEqual %12 %70 %13 +OpSelectionMerge %85 None +OpBranchConditional %74 %84 %75 +%75 = OpLabel +%76 = OpIEqual %12 %70 %20 +OpSelectionMerge %83 None +OpBranchConditional %76 %79 %77 +%77 = OpLabel +OpBranch %83 +%79 = OpLabel +%80 = OpAccessChain %19 %4 %70 +%81 = OpLoad %14 %80 +%82 = OpAccessChain %19 %5 %70 +OpStore %82 %81 +OpBranch %83 +%83 = OpLabel +OpBranch %85 +%84 = OpLabel +OpBranch %85 +%85 = OpLabel +OpBranch %86 +%86 = OpLabel +%87 = OpIAdd %8 %70 %13 +OpStore %3 %87 +OpBranch %69 +%88 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %8 %10 %88 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %12 %24 
%11 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%31 = OpIEqual %12 %24 %13 +OpSelectionMerge %32 None +OpBranchConditional %31 %33 %34 +%33 = OpLabel +OpBranch %32 +%34 = OpLabel +%38 = OpIEqual %12 %24 %20 +OpSelectionMerge %39 None +OpBranchConditional %38 %40 %41 +%40 = OpLabel +OpBranch %39 +%41 = OpLabel +%45 = OpAccessChain %19 %5 %24 +OpStore %45 %21 +OpBranch %39 +%39 = OpLabel +OpBranch %32 +%32 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %8 %24 %13 +OpStore %3 %25 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true, 1); +} + +/* +#version 430 +layout(location=0) flat in int condition; +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + memoryBarrier(); + A[i] = B[i]; + } +} + +This should not be split due to the memory barrier. 
+*/ +TEST_F(FissionClassTest, FissionBarrier) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +OpName %3 "condition" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypePointer Function %17 +%19 = OpTypePointer Function %14 +%20 = OpConstant %15 1 +%21 = OpConstant %15 4048 +%22 = OpConstant %9 1 +%23 = OpTypePointer Input %9 +%3 = OpVariable %23 Input +%2 = OpFunction %7 None %8 +%24 = OpLabel +%4 = OpVariable %10 Function +%5 = OpVariable %18 Function +%6 = OpVariable %18 Function +OpStore %4 %11 +OpBranch %25 +%25 = OpLabel +%26 = OpPhi %9 %11 %24 %27 %28 +OpLoopMerge %29 %28 None +OpBranch %30 +%30 = OpLabel +%31 = OpSLessThan %13 %26 %12 +OpBranchConditional %31 %32 %29 +%32 = OpLabel +%33 = OpAccessChain %19 %6 %26 +%34 = OpLoad %14 %33 +%35 = OpAccessChain %19 %5 %26 +OpStore %35 %34 +OpMemoryBarrier %20 %21 +%36 = OpAccessChain %19 %5 %26 +%37 = OpLoad %14 %36 +%38 = OpAccessChain %19 %6 %26 +OpStore %38 %37 +OpBranch %28 +%28 = OpLabel +%27 = OpIAdd %9 %26 %22 +OpStore %4 %27 +OpBranch %25 +%29 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); 
+} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + if ( i== 1) + break; + A[i] = B[i]; + } +} + +This should not be split due to the break. +*/ +TEST_F(FissionClassTest, FissionBreak) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpStore %3 %10 +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpIEqual %12 %22 %19 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %33 +%34 = OpLabel +OpBranch %25 +%33 = OpLabel +%35 = OpAccessChain %18 %4 %22 +%36 = OpLoad %13 %35 +%37 = OpAccessChain %18 %5 %22 +OpStore %37 %36 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpStore %3 %23 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + 
<< source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + if ( i== 1) + continue; + A[i] = B[i]; + } +} + +This loop should be split into: + + for (int i = 0; i < 10; i++) { + B[i] = A[i]; + if ( i== 1) + continue; + } + for (int i = 0; i < 10; i++) { + if ( i== 1) + continue; + A[i] = B[i]; + } +The continue block in the first loop is left to DCE. +} + + +*/ +TEST_F(FissionClassTest, FissionContinue) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpStore %3 %10 +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpAccessChain %18 %5 %22 +%30 = OpLoad %13 %29 +%31 = OpAccessChain %18 %4 %22 +OpStore %31 %30 +%32 = OpIEqual %12 %22 %19 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %33 +%34 = OpLabel +OpBranch %24 +%33 = OpLabel +%35 = OpAccessChain %18 %4 %22 +%36 = OpLoad %13 %35 +%37 = OpAccessChain %18 %5 %22 +OpStore %37 %36 +OpBranch %24 +%24 = OpLabel +%23 = 
OpIAdd %8 %22 %19 +OpStore %3 %23 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpStore %3 %10 +OpBranch %38 +%38 = OpLabel +%39 = OpPhi %8 %10 %20 %53 %52 +OpLoopMerge %54 %52 None +OpBranch %40 +%40 = OpLabel +%41 = OpSLessThan %12 %39 %11 +OpBranchConditional %41 %42 %54 +%42 = OpLabel +%43 = OpAccessChain %18 %5 %39 +%44 = OpLoad %13 %43 +%45 = OpAccessChain %18 %4 %39 +OpStore %45 %44 +%46 = OpIEqual %12 %39 %19 +OpSelectionMerge %47 None +OpBranchConditional %46 %51 %47 +%47 = OpLabel +OpBranch %52 +%51 = OpLabel +OpBranch %52 +%52 = OpLabel +%53 = OpIAdd %8 %39 %19 +OpStore %3 %53 +OpBranch %38 +%54 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %54 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%32 = OpIEqual %12 %22 %19 +OpSelectionMerge %33 None +OpBranchConditional %32 %34 %33 +%34 = OpLabel +OpBranch %24 +%33 = OpLabel +%35 = OpAccessChain %18 %4 %22 +%36 = OpLoad %13 %35 +%37 = OpAccessChain %18 %5 %22 +OpStore %37 %36 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpStore %3 %23 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + 
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + int i = 0; + do { + B[i] = A[i]; + A[i] = B[i]; + ++i; + } while (i < 10); +} + + +Check that this is split into: + int i = 0; + do { + B[i] = A[i]; + ++i; + } while (i < 10); + + i = 0; + do { + A[i] = B[i]; + ++i; + } while (i < 10); + + +*/ +TEST_F(FissionClassTest, FissionDoWhile) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 10 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %8 1 +%18 = OpConstant %8 10 +%19 = OpTypeBool +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %15 Function +%5 = OpVariable %15 Function +OpStore %3 %10 +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpAccessChain %16 %5 %22 +%28 = OpLoad %11 %27 +%29 = OpAccessChain %16 %4 %22 +OpStore %29 %28 +%30 = OpAccessChain %16 %4 %22 +%31 = OpLoad %11 %30 +%32 = OpAccessChain %16 %5 %22 +OpStore %32 %31 +%23 = OpIAdd %8 %22 %17 +OpStore %3 %23 +OpBranch %24 +%24 = OpLabel +%33 = OpSLessThan %19 %23 %18 +OpBranchConditional %33 %21 %25 +%25 = OpLabel 
+OpReturn +OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 10 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %8 1 +%18 = OpConstant %8 10 +%19 = OpTypeBool +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %15 Function +%5 = OpVariable %15 Function +OpStore %3 %10 +OpBranch %34 +%34 = OpLabel +%35 = OpPhi %8 %10 %20 %43 %44 +OpLoopMerge %46 %44 None +OpBranch %36 +%36 = OpLabel +%37 = OpAccessChain %16 %5 %35 +%38 = OpLoad %11 %37 +%39 = OpAccessChain %16 %4 %35 +OpStore %39 %38 +%43 = OpIAdd %8 %35 %17 +OpStore %3 %43 +OpBranch %44 +%44 = OpLabel +%45 = OpSLessThan %19 %43 %18 +OpBranchConditional %45 %34 %46 +%46 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %46 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%30 = OpAccessChain %16 %4 %22 +%31 = OpLoad %11 %30 +%32 = OpAccessChain %16 %5 %22 +OpStore %32 %31 +%23 = OpIAdd %8 %22 %17 +OpStore %3 %23 +OpBranch %24 +%24 = OpLabel +%33 = OpSLessThan %19 %23 %18 +OpBranchConditional %33 %21 %25 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +/* + +#version 430 +void 
main(void) { + float A[10][10]; + float B[10][10]; + for (int j = 0; j < 10; ++j) { + for (int i = 0; i < 10; ++i) { + B[i][j] = A[i][i]; + A[i][i] = B[i][j + 1]; + } + } +} + + +This loop can't be split because the load B[i][j + 1] is dependent on the store +B[i][j]. + +*/ +TEST_F(FissionClassTest, FissionNestedDependency) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "j" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypeArray %17 %16 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %14 +%21 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %10 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %9 %11 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %13 %24 %12 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpPhi %9 %11 %30 %33 %34 +OpLoopMerge %35 %34 None +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %13 %32 %12 +OpBranchConditional %37 %38 %35 +%38 = OpLabel +%39 = OpAccessChain %20 %6 %32 %32 +%40 = OpLoad %14 %39 +%41 = OpAccessChain %20 %5 %32 %24 +OpStore %41 %40 +%42 = OpIAdd %9 %24 %21 +%43 = OpAccessChain %20 %5 %32 %42 +%44 = OpLoad %14 %43 +%45 = OpAccessChain %20 %6 %32 %32 +OpStore %45 %44 +OpBranch %34 +%34 = OpLabel +%33 = OpIAdd %9 %32 %21 +OpBranch %31 +%35 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %9 %24 %21 +OpBranch %23 +%27 
= OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int j = 0; j < 10; ++j) { + for (int i = 0; i < 10; ++i) { + B[i][i] = A[i][j]; + A[i][j+1] = B[i][i]; + } + } +} + +This loop should not be split as the load A[i][j+1] would be reading a value +written in the store A[i][j] which would be hit before A[i][j+1] if the loops +were split but would not get hit before the read currently. + +*/ +TEST_F(FissionClassTest, FissionNestedDependency2) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "j" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypeArray %17 %16 +%19 = OpTypePointer Function %18 +%20 = OpTypePointer Function %14 +%21 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %10 Function +%5 = OpVariable %19 Function +%6 = OpVariable %19 Function +OpStore %3 %11 +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %9 %11 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %13 %24 %12 +OpBranchConditional %29 %30 %27 +%30 = 
OpLabel +OpStore %4 %11 +OpBranch %31 +%31 = OpLabel +%32 = OpPhi %9 %11 %30 %33 %34 +OpLoopMerge %35 %34 None +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %13 %32 %12 +OpBranchConditional %37 %38 %35 +%38 = OpLabel +%39 = OpAccessChain %20 %6 %32 %24 +%40 = OpLoad %14 %39 +%41 = OpAccessChain %20 %5 %32 %32 +OpStore %41 %40 +%42 = OpIAdd %9 %24 %21 +%43 = OpAccessChain %20 %5 %32 %32 +%44 = OpLoad %14 %43 +%45 = OpAccessChain %20 %6 %32 %42 +OpStore %45 %44 +OpBranch %34 +%34 = OpLabel +%33 = OpIAdd %9 %32 %21 +OpStore %4 %33 +OpBranch %31 +%35 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %9 %24 %21 +OpStore %3 %25 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int j = 0; j < 10; ++j) { + for (int i = 0; i < 10; ++i) { + B[i][j] = A[i][j]; + A[i][j] = B[i][j]; + } + for (int i = 0; i < 10; ++i) { + B[i][j] = A[i][j]; + A[i][j] = B[i][j]; + } + } +} + + + +Should be split into: + +for (int j = 0; j < 10; ++j) { + for (int i = 0; i < 10; ++i) + B[i][j] = A[i][j]; + for (int i = 0; i < 10; ++i) + A[i][j] = B[i][j]; + for (int i = 0; i < 10; ++i) + B[i][j] = A[i][j]; + for (int i = 0; i < 10; ++i) + A[i][j] = B[i][j]; +*/ +TEST_F(FissionClassTest, FissionMultipleLoopsNested) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 
"j" + OpName %4 "i" + OpName %5 "B" + OpName %6 "A" + OpName %7 "i" + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpConstant %10 10 + %14 = OpTypeBool + %15 = OpTypeFloat 32 + %16 = OpTypeInt 32 0 + %17 = OpConstant %16 10 + %18 = OpTypeArray %15 %17 + %19 = OpTypeArray %18 %17 + %20 = OpTypePointer Function %19 + %21 = OpTypePointer Function %15 + %22 = OpConstant %10 1 + %2 = OpFunction %8 None %9 + %23 = OpLabel + %3 = OpVariable %11 Function + %4 = OpVariable %11 Function + %5 = OpVariable %20 Function + %6 = OpVariable %20 Function + %7 = OpVariable %11 Function + OpStore %3 %12 + OpBranch %24 + %24 = OpLabel + %25 = OpPhi %10 %12 %23 %26 %27 + OpLoopMerge %28 %27 None + OpBranch %29 + %29 = OpLabel + %30 = OpSLessThan %14 %25 %13 + OpBranchConditional %30 %31 %28 + %31 = OpLabel + OpStore %4 %12 + OpBranch %32 + %32 = OpLabel + %33 = OpPhi %10 %12 %31 %34 %35 + OpLoopMerge %36 %35 None + OpBranch %37 + %37 = OpLabel + %38 = OpSLessThan %14 %33 %13 + OpBranchConditional %38 %39 %36 + %39 = OpLabel + %40 = OpAccessChain %21 %6 %33 %25 + %41 = OpLoad %15 %40 + %42 = OpAccessChain %21 %5 %33 %25 + OpStore %42 %41 + %43 = OpAccessChain %21 %5 %33 %25 + %44 = OpLoad %15 %43 + %45 = OpAccessChain %21 %6 %33 %25 + OpStore %45 %44 + OpBranch %35 + %35 = OpLabel + %34 = OpIAdd %10 %33 %22 + OpStore %4 %34 + OpBranch %32 + %36 = OpLabel + OpStore %7 %12 + OpBranch %46 + %46 = OpLabel + %47 = OpPhi %10 %12 %36 %48 %49 + OpLoopMerge %50 %49 None + OpBranch %51 + %51 = OpLabel + %52 = OpSLessThan %14 %47 %13 + OpBranchConditional %52 %53 %50 + %53 = OpLabel + %54 = OpAccessChain %21 %6 %47 %25 + %55 = OpLoad %15 %54 + %56 = OpAccessChain %21 %5 %47 %25 + OpStore %56 %55 + %57 = OpAccessChain %21 %5 %47 %25 + %58 = OpLoad %15 %57 + %59 = OpAccessChain %21 %6 %47 %25 + OpStore %59 %58 + OpBranch %49 + %49 = OpLabel + %48 = OpIAdd %10 %47 %22 + OpStore %7 %48 + OpBranch %46 + %50 = OpLabel 
+ OpBranch %27 + %27 = OpLabel + %26 = OpIAdd %10 %25 %22 + OpStore %3 %26 + OpBranch %24 + %28 = OpLabel + OpReturn + OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "j" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +OpName %7 "i" +%8 = OpTypeVoid +%9 = OpTypeFunction %8 +%10 = OpTypeInt 32 1 +%11 = OpTypePointer Function %10 +%12 = OpConstant %10 0 +%13 = OpConstant %10 10 +%14 = OpTypeBool +%15 = OpTypeFloat 32 +%16 = OpTypeInt 32 0 +%17 = OpConstant %16 10 +%18 = OpTypeArray %15 %17 +%19 = OpTypeArray %18 %17 +%20 = OpTypePointer Function %19 +%21 = OpTypePointer Function %15 +%22 = OpConstant %10 1 +%2 = OpFunction %8 None %9 +%23 = OpLabel +%3 = OpVariable %11 Function +%4 = OpVariable %11 Function +%5 = OpVariable %20 Function +%6 = OpVariable %20 Function +%7 = OpVariable %11 Function +OpStore %3 %12 +OpBranch %24 +%24 = OpLabel +%25 = OpPhi %10 %12 %23 %26 %27 +OpLoopMerge %28 %27 None +OpBranch %29 +%29 = OpLabel +%30 = OpSLessThan %14 %25 %13 +OpBranchConditional %30 %31 %28 +%31 = OpLabel +OpStore %4 %12 +OpBranch %60 +%60 = OpLabel +%61 = OpPhi %10 %12 %31 %72 %71 +OpLoopMerge %73 %71 None +OpBranch %62 +%62 = OpLabel +%63 = OpSLessThan %14 %61 %13 +OpBranchConditional %63 %64 %73 +%64 = OpLabel +%65 = OpAccessChain %21 %6 %61 %25 +%66 = OpLoad %15 %65 +%67 = OpAccessChain %21 %5 %61 %25 +OpStore %67 %66 +OpBranch %71 +%71 = OpLabel +%72 = OpIAdd %10 %61 %22 +OpStore %4 %72 +OpBranch %60 +%73 = OpLabel +OpBranch %32 +%32 = OpLabel +%33 = OpPhi %10 %12 %73 %34 %35 +OpLoopMerge %36 %35 None +OpBranch %37 +%37 = OpLabel +%38 = OpSLessThan %14 %33 %13 +OpBranchConditional %38 %39 %36 +%39 = OpLabel +%43 = OpAccessChain %21 %5 %33 %25 +%44 = OpLoad %15 %43 +%45 = OpAccessChain %21 %6 %33 %25 +OpStore %45 %44 +OpBranch %35 +%35 = OpLabel +%34 = OpIAdd 
%10 %33 %22 +OpStore %4 %34 +OpBranch %32 +%36 = OpLabel +OpStore %7 %12 +OpBranch %74 +%74 = OpLabel +%75 = OpPhi %10 %12 %36 %86 %85 +OpLoopMerge %87 %85 None +OpBranch %76 +%76 = OpLabel +%77 = OpSLessThan %14 %75 %13 +OpBranchConditional %77 %78 %87 +%78 = OpLabel +%79 = OpAccessChain %21 %6 %75 %25 +%80 = OpLoad %15 %79 +%81 = OpAccessChain %21 %5 %75 %25 +OpStore %81 %80 +OpBranch %85 +%85 = OpLabel +%86 = OpIAdd %10 %75 %22 +OpStore %7 %86 +OpBranch %74 +%87 = OpLabel +OpBranch %46 +%46 = OpLabel +%47 = OpPhi %10 %12 %87 %48 %49 +OpLoopMerge %50 %49 None +OpBranch %51 +%51 = OpLabel +%52 = OpSLessThan %14 %47 %13 +OpBranchConditional %52 %53 %50 +%53 = OpLabel +%57 = OpAccessChain %21 %5 %47 %25 +%58 = OpLoad %15 %57 +%59 = OpAccessChain %21 %6 %47 %25 +OpStore %59 %58 +OpBranch %49 +%49 = OpLabel +%48 = OpIAdd %10 %47 %22 +OpStore %7 %48 +OpBranch %46 +%50 = OpLabel +OpBranch %27 +%27 = OpLabel +%26 = OpIAdd %10 %25 %22 +OpStore %3 %26 +OpBranch %24 +%28 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + const Function* function = spvtest::GetFunction(module, 2); + LoopDescriptor& pre_pass_descriptor = *context->GetLoopDescriptor(function); + EXPECT_EQ(pre_pass_descriptor.NumLoops(), 3u); + EXPECT_EQ(pre_pass_descriptor.pre_begin()->NumImmediateChildren(), 2u); + + // Test that the pass transforms the ir into the expected output. + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); + + // Test that the loop descriptor is correctly maintained and updated by the + // pass. 
+ LoopFissionPass loop_fission; + loop_fission.SetContextForTesting(context.get()); + loop_fission.Process(); + + function = spvtest::GetFunction(module, 2); + LoopDescriptor& post_pass_descriptor = *context->GetLoopDescriptor(function); + EXPECT_EQ(post_pass_descriptor.NumLoops(), 5u); + EXPECT_EQ(post_pass_descriptor.pre_begin()->NumImmediateChildren(), 4u); +} + +/* +#version 430 +void main(void) { + float A[10][10]; + float B[10][10]; + for (int i = 0; i < 10; ++i) { + B[i][i] = A[i][i]; + A[i][i] = B[i][i]; + } + for (int i = 0; i < 10; ++i) { + B[i][i] = A[i][i]; + A[i][i] = B[i][i] + } +} + + + +Should be split into: + + for (int i = 0; i < 10; ++i) + B[i][i] = A[i][i]; + for (int i = 0; i < 10; ++i) + A[i][i] = B[i][i]; + for (int i = 0; i < 10; ++i) + B[i][i] = A[i][i]; + for (int i = 0; i < 10; ++i) + A[i][i] = B[i][i]; +*/ +TEST_F(FissionClassTest, FissionMultipleLoops) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "B" + OpName %5 "A" + OpName %6 "i" + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpTypeFloat 32 + %15 = OpTypeInt 32 0 + %16 = OpConstant %15 10 + %17 = OpTypeArray %14 %16 + %18 = OpTypePointer Function %17 + %19 = OpTypePointer Function %14 + %20 = OpConstant %9 1 + %2 = OpFunction %7 None %8 + %21 = OpLabel + %3 = OpVariable %10 Function + %4 = OpVariable %18 Function + %5 = OpVariable %18 Function + %6 = OpVariable %10 Function + OpStore %3 %11 + OpBranch %22 + %22 = OpLabel + %23 = OpPhi %9 %11 %21 %24 %25 + OpLoopMerge %26 %25 None + OpBranch %27 + %27 = OpLabel + %28 = OpSLessThan %13 %23 %12 + OpBranchConditional %28 %29 %26 + %29 = OpLabel + 
%30 = OpAccessChain %19 %5 %23 + %31 = OpLoad %14 %30 + %32 = OpAccessChain %19 %4 %23 + OpStore %32 %31 + %33 = OpAccessChain %19 %4 %23 + %34 = OpLoad %14 %33 + %35 = OpAccessChain %19 %5 %23 + OpStore %35 %34 + OpBranch %25 + %25 = OpLabel + %24 = OpIAdd %9 %23 %20 + OpStore %3 %24 + OpBranch %22 + %26 = OpLabel + OpStore %6 %11 + OpBranch %36 + %36 = OpLabel + %37 = OpPhi %9 %11 %26 %38 %39 + OpLoopMerge %40 %39 None + OpBranch %41 + %41 = OpLabel + %42 = OpSLessThan %13 %37 %12 + OpBranchConditional %42 %43 %40 + %43 = OpLabel + %44 = OpAccessChain %19 %5 %37 + %45 = OpLoad %14 %44 + %46 = OpAccessChain %19 %4 %37 + OpStore %46 %45 + %47 = OpAccessChain %19 %4 %37 + %48 = OpLoad %14 %47 + %49 = OpAccessChain %19 %5 %37 + OpStore %49 %48 + OpBranch %39 + %39 = OpLabel + %38 = OpIAdd %9 %37 %20 + OpStore %6 %38 + OpBranch %36 + %40 = OpLabel + OpReturn + OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +OpName %6 "i" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpTypeFloat 32 +%15 = OpTypeInt 32 0 +%16 = OpConstant %15 10 +%17 = OpTypeArray %14 %16 +%18 = OpTypePointer Function %17 +%19 = OpTypePointer Function %14 +%20 = OpConstant %9 1 +%2 = OpFunction %7 None %8 +%21 = OpLabel +%3 = OpVariable %10 Function +%4 = OpVariable %18 Function +%5 = OpVariable %18 Function +%6 = OpVariable %10 Function +OpStore %3 %11 +OpBranch %64 +%64 = OpLabel +%65 = OpPhi %9 %11 %21 %76 %75 +OpLoopMerge %77 %75 None +OpBranch %66 +%66 = OpLabel +%67 = OpSLessThan %13 %65 %12 +OpBranchConditional %67 %68 %77 +%68 = OpLabel +%69 = OpAccessChain %19 %5 %65 +%70 = OpLoad %14 %69 +%71 = OpAccessChain %19 %4 %65 +OpStore %71 
%70 +OpBranch %75 +%75 = OpLabel +%76 = OpIAdd %9 %65 %20 +OpStore %3 %76 +OpBranch %64 +%77 = OpLabel +OpBranch %22 +%22 = OpLabel +%23 = OpPhi %9 %11 %77 %24 %25 +OpLoopMerge %26 %25 None +OpBranch %27 +%27 = OpLabel +%28 = OpSLessThan %13 %23 %12 +OpBranchConditional %28 %29 %26 +%29 = OpLabel +%33 = OpAccessChain %19 %4 %23 +%34 = OpLoad %14 %33 +%35 = OpAccessChain %19 %5 %23 +OpStore %35 %34 +OpBranch %25 +%25 = OpLabel +%24 = OpIAdd %9 %23 %20 +OpStore %3 %24 +OpBranch %22 +%26 = OpLabel +OpStore %6 %11 +OpBranch %50 +%50 = OpLabel +%51 = OpPhi %9 %11 %26 %62 %61 +OpLoopMerge %63 %61 None +OpBranch %52 +%52 = OpLabel +%53 = OpSLessThan %13 %51 %12 +OpBranchConditional %53 %54 %63 +%54 = OpLabel +%55 = OpAccessChain %19 %5 %51 +%56 = OpLoad %14 %55 +%57 = OpAccessChain %19 %4 %51 +OpStore %57 %56 +OpBranch %61 +%61 = OpLabel +%62 = OpIAdd %9 %51 %20 +OpStore %6 %62 +OpBranch %50 +%63 = OpLabel +OpBranch %36 +%36 = OpLabel +%37 = OpPhi %9 %11 %63 %38 %39 +OpLoopMerge %40 %39 None +OpBranch %41 +%41 = OpLabel +%42 = OpSLessThan %13 %37 %12 +OpBranchConditional %42 %43 %40 +%43 = OpLabel +%47 = OpAccessChain %19 %4 %37 +%48 = OpLoad %14 %47 +%49 = OpAccessChain %19 %5 %37 +OpStore %49 %48 +OpBranch %39 +%39 = OpLabel +%38 = OpIAdd %9 %37 %20 +OpStore %6 %38 +OpBranch %36 +%40 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); + + const Function* function = spvtest::GetFunction(module, 2); + LoopDescriptor& pre_pass_descriptor = *context->GetLoopDescriptor(function); + EXPECT_EQ(pre_pass_descriptor.NumLoops(), 2u); + EXPECT_EQ(pre_pass_descriptor.pre_begin()->NumImmediateChildren(), 
0u); + + // Test that the pass transforms the ir into the expected output. + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); + + // Test that the loop descriptor is correctly maintained and updated by the + // pass. + LoopFissionPass loop_fission; + loop_fission.SetContextForTesting(context.get()); + loop_fission.Process(); + + function = spvtest::GetFunction(module, 2); + LoopDescriptor& post_pass_descriptor = *context->GetLoopDescriptor(function); + EXPECT_EQ(post_pass_descriptor.NumLoops(), 4u); + EXPECT_EQ(post_pass_descriptor.pre_begin()->NumImmediateChildren(), 0u); +} + +/* +#version 430 +int foo() { return 1; } +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; ++i) { + B[i] = A[i]; + foo(); + A[i] = B[i]; + } +} + +This should not be split as it has a function call in it so we can't determine +if it has side effects. +*/ +TEST_F(FissionClassTest, FissionFunctionCall) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "foo(" +OpName %4 "i" +OpName %5 "B" +OpName %6 "A" +%7 = OpTypeVoid +%8 = OpTypeFunction %7 +%9 = OpTypeInt 32 1 +%10 = OpTypeFunction %9 +%11 = OpConstant %9 1 +%12 = OpTypePointer Function %9 +%13 = OpConstant %9 0 +%14 = OpConstant %9 10 +%15 = OpTypeBool +%16 = OpTypeFloat 32 +%17 = OpTypeInt 32 0 +%18 = OpConstant %17 10 +%19 = OpTypeArray %16 %18 +%20 = OpTypePointer Function %19 +%21 = OpTypePointer Function %16 +%2 = OpFunction %7 None %8 +%22 = OpLabel +%4 = OpVariable %12 Function +%5 = OpVariable %20 Function +%6 = OpVariable %20 Function +OpStore %4 %13 +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %9 %13 %22 %25 %26 +OpLoopMerge %27 %26 None +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %15 %24 %14 
+OpBranchConditional %29 %30 %27 +%30 = OpLabel +%31 = OpAccessChain %21 %6 %24 +%32 = OpLoad %16 %31 +%33 = OpAccessChain %21 %5 %24 +OpStore %33 %32 +%34 = OpFunctionCall %9 %3 +%35 = OpAccessChain %21 %5 %24 +%36 = OpLoad %16 %35 +%37 = OpAccessChain %21 %6 %24 +OpStore %37 %36 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %9 %24 %11 +OpStore %4 %25 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +%3 = OpFunction %9 None %10 +%38 = OpLabel +OpReturnValue %11 +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, source, true); +} + +/* +#version 430 +void main(void) { + float A[10]; + float B[10]; + for (int i = 0; i < 10; ++i) { + switch (i) { + case 1: + B[i] = A[i]; + break; + default: + A[i] = B[i]; + } + } +} + +This should be split into: + for (int i = 0; i < 10; ++i) { + switch (i) { + case 1: + break; + default: + A[i] = B[i]; + } + } + + for (int i = 0; i < 10; ++i) { + switch (i) { + case 1: + B[i] = A[i]; + break; + default: + break; + } + } + +*/ +TEST_F(FissionClassTest, FissionSwitchStatement) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string source = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "B" + OpName %5 "A" + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = 
OpTypePointer Function %16 + %18 = OpTypePointer Function %13 + %19 = OpConstant %8 1 + %2 = OpFunction %6 None %7 + %20 = OpLabel + %3 = OpVariable %9 Function + %4 = OpVariable %17 Function + %5 = OpVariable %17 Function + OpStore %3 %10 + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %8 %10 %20 %23 %24 + OpLoopMerge %25 %24 None + OpBranch %26 + %26 = OpLabel + %27 = OpSLessThan %12 %22 %11 + OpBranchConditional %27 %28 %25 + %28 = OpLabel + OpSelectionMerge %29 None + OpSwitch %22 %30 1 %31 + %30 = OpLabel + %32 = OpAccessChain %18 %4 %22 + %33 = OpLoad %13 %32 + %34 = OpAccessChain %18 %5 %22 + OpStore %34 %33 + OpBranch %29 + %31 = OpLabel + %35 = OpAccessChain %18 %5 %22 + %36 = OpLoad %13 %35 + %37 = OpAccessChain %18 %4 %22 + OpStore %37 %36 + OpBranch %29 + %29 = OpLabel + OpBranch %24 + %24 = OpLabel + %23 = OpIAdd %8 %22 %19 + OpStore %3 %23 + OpBranch %21 + %25 = OpLabel + OpReturn + OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "B" +OpName %5 "A" +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %8 10 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 1 +%2 = OpFunction %6 None %7 +%20 = OpLabel +%3 = OpVariable %9 Function +%4 = OpVariable %17 Function +%5 = OpVariable %17 Function +OpStore %3 %10 +OpBranch %38 +%38 = OpLabel +%39 = OpPhi %8 %10 %20 %53 %52 +OpLoopMerge %54 %52 None +OpBranch %40 +%40 = OpLabel +%41 = OpSLessThan %12 %39 %11 +OpBranchConditional %41 %42 %54 +%42 = OpLabel +OpSelectionMerge %51 None +OpSwitch %39 %47 1 %43 +%43 = OpLabel +OpBranch %51 +%47 = OpLabel +%48 = 
OpAccessChain %18 %4 %39 +%49 = OpLoad %13 %48 +%50 = OpAccessChain %18 %5 %39 +OpStore %50 %49 +OpBranch %51 +%51 = OpLabel +OpBranch %52 +%52 = OpLabel +%53 = OpIAdd %8 %39 %19 +OpStore %3 %53 +OpBranch %38 +%54 = OpLabel +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %8 %10 %54 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %22 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +OpSelectionMerge %29 None +OpSwitch %22 %30 1 %31 +%30 = OpLabel +OpBranch %29 +%31 = OpLabel +%35 = OpAccessChain %18 %5 %22 +%36 = OpLoad %13 %35 +%37 = OpAccessChain %18 %4 %22 +OpStore %37 %36 +OpBranch %29 +%29 = OpLabel +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %8 %22 %19 +OpStore %3 %23 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(source, expected, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/nested_loops.cpp b/test/opt/loop_optimizations/nested_loops.cpp new file mode 100644 index 000000000..651cdef44 --- /dev/null +++ b/test/opt/loop_optimizations/nested_loops.cpp @@ -0,0 +1,795 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; + +bool Validate(const std::vector& bin) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {bin.data(), bin.size()}; + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + return error == 0; +} + +using PassClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + for (; i < 10; ++i) { + int j = 0; + int k = 0; + for (; j < 11; ++j) {} + for (; k < 12; ++k) {} + } +} +*/ +TEST_F(PassClassTest, BasicVisitFromEntryPoint) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %5 "j" + OpName %6 "k" + OpName %3 "c" + OpDecorate %3 Location 0 + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpConstant %9 11 + %15 = OpConstant %9 1 + %16 = OpConstant %9 12 + %17 = OpTypeFloat 32 + %18 = OpTypeVector %17 4 + %19 = 
OpTypePointer Output %18 + %3 = OpVariable %19 Output + %2 = OpFunction %7 None %8 + %20 = OpLabel + %4 = OpVariable %10 Function + %5 = OpVariable %10 Function + %6 = OpVariable %10 Function + OpStore %4 %11 + OpBranch %21 + %21 = OpLabel + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %25 = OpLoad %9 %4 + %26 = OpSLessThan %13 %25 %12 + OpBranchConditional %26 %27 %22 + %27 = OpLabel + OpStore %5 %11 + OpStore %6 %11 + OpBranch %28 + %28 = OpLabel + OpLoopMerge %29 %30 None + OpBranch %31 + %31 = OpLabel + %32 = OpLoad %9 %5 + %33 = OpSLessThan %13 %32 %14 + OpBranchConditional %33 %34 %29 + %34 = OpLabel + OpBranch %30 + %30 = OpLabel + %35 = OpLoad %9 %5 + %36 = OpIAdd %9 %35 %15 + OpStore %5 %36 + OpBranch %28 + %29 = OpLabel + OpBranch %37 + %37 = OpLabel + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %41 = OpLoad %9 %6 + %42 = OpSLessThan %13 %41 %16 + OpBranchConditional %42 %43 %38 + %43 = OpLabel + OpBranch %39 + %39 = OpLabel + %44 = OpLoad %9 %6 + %45 = OpIAdd %9 %44 %15 + OpStore %6 %45 + OpBranch %37 + %38 = OpLabel + OpBranch %23 + %23 = OpLabel + %46 = OpLoad %9 %4 + %47 = OpIAdd %9 %46 %15 + OpStore %4 %47 + OpBranch %21 + %22 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + EXPECT_EQ(ld.NumLoops(), 3u); + + // Invalid basic block id. + EXPECT_EQ(ld[0u], nullptr); + // Not a loop header. 
+ EXPECT_EQ(ld[20], nullptr); + + Loop& parent_loop = *ld[21]; + EXPECT_TRUE(parent_loop.HasNestedLoops()); + EXPECT_FALSE(parent_loop.IsNested()); + EXPECT_EQ(parent_loop.GetDepth(), 1u); + EXPECT_EQ(std::distance(parent_loop.begin(), parent_loop.end()), 2u); + EXPECT_EQ(parent_loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 21)); + EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23)); + EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22)); + + Loop& child_loop_1 = *ld[28]; + EXPECT_FALSE(child_loop_1.HasNestedLoops()); + EXPECT_TRUE(child_loop_1.IsNested()); + EXPECT_EQ(child_loop_1.GetDepth(), 2u); + EXPECT_EQ(std::distance(child_loop_1.begin(), child_loop_1.end()), 0u); + EXPECT_EQ(child_loop_1.GetHeaderBlock(), spvtest::GetBasicBlock(f, 28)); + EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30)); + EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29)); + + Loop& child_loop_2 = *ld[37]; + EXPECT_FALSE(child_loop_2.HasNestedLoops()); + EXPECT_TRUE(child_loop_2.IsNested()); + EXPECT_EQ(child_loop_2.GetDepth(), 2u); + EXPECT_EQ(std::distance(child_loop_2.begin(), child_loop_2.end()), 0u); + EXPECT_EQ(child_loop_2.GetHeaderBlock(), spvtest::GetBasicBlock(f, 37)); + EXPECT_EQ(child_loop_2.GetLatchBlock(), spvtest::GetBasicBlock(f, 39)); + EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38)); +} + +static void CheckLoopBlocks(Loop* loop, + std::unordered_set* expected_ids) { + SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id())); + for (uint32_t bb_id : loop->GetBlocks()) { + EXPECT_EQ(expected_ids->count(bb_id), 1u); + expected_ids->erase(bb_id); + } + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_EQ(expected_ids->size(), 0u); +} + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + for (; i < 10; ++i) { + for (int j = 0; j < 11; ++j) { + if (j < 5) { + for (int k = 0; 
k < 12; ++k) {} + } + else {} + for (int k = 0; k < 12; ++k) {} + } + } +}*/ +TEST_F(PassClassTest, TripleNestedLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %5 "j" + OpName %6 "k" + OpName %7 "k" + OpName %3 "c" + OpDecorate %3 Location 0 + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpConstant %10 10 + %14 = OpTypeBool + %15 = OpConstant %10 11 + %16 = OpConstant %10 5 + %17 = OpConstant %10 12 + %18 = OpConstant %10 1 + %19 = OpTypeFloat 32 + %20 = OpTypeVector %19 4 + %21 = OpTypePointer Output %20 + %3 = OpVariable %21 Output + %2 = OpFunction %8 None %9 + %22 = OpLabel + %4 = OpVariable %11 Function + %5 = OpVariable %11 Function + %6 = OpVariable %11 Function + %7 = OpVariable %11 Function + OpStore %4 %12 + OpBranch %23 + %23 = OpLabel + OpLoopMerge %24 %25 None + OpBranch %26 + %26 = OpLabel + %27 = OpLoad %10 %4 + %28 = OpSLessThan %14 %27 %13 + OpBranchConditional %28 %29 %24 + %29 = OpLabel + OpStore %5 %12 + OpBranch %30 + %30 = OpLabel + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %34 = OpLoad %10 %5 + %35 = OpSLessThan %14 %34 %15 + OpBranchConditional %35 %36 %31 + %36 = OpLabel + %37 = OpLoad %10 %5 + %38 = OpSLessThan %14 %37 %16 + OpSelectionMerge %39 None + OpBranchConditional %38 %40 %39 + %40 = OpLabel + OpStore %6 %12 + OpBranch %41 + %41 = OpLabel + OpLoopMerge %42 %43 None + OpBranch %44 + %44 = OpLabel + %45 = OpLoad %10 %6 + %46 = OpSLessThan %14 %45 %17 + OpBranchConditional %46 %47 %42 + %47 = OpLabel + OpBranch %43 + %43 = OpLabel + %48 = OpLoad %10 %6 + %49 = OpIAdd %10 %48 %18 + OpStore %6 %49 + OpBranch %41 + %42 = OpLabel + OpBranch %39 + %39 = OpLabel + OpStore %7 %12 + OpBranch %50 + %50 = OpLabel + OpLoopMerge 
%51 %52 None + OpBranch %53 + %53 = OpLabel + %54 = OpLoad %10 %7 + %55 = OpSLessThan %14 %54 %17 + OpBranchConditional %55 %56 %51 + %56 = OpLabel + OpBranch %52 + %52 = OpLabel + %57 = OpLoad %10 %7 + %58 = OpIAdd %10 %57 %18 + OpStore %7 %58 + OpBranch %50 + %51 = OpLabel + OpBranch %32 + %32 = OpLabel + %59 = OpLoad %10 %5 + %60 = OpIAdd %10 %59 %18 + OpStore %5 %60 + OpBranch %30 + %31 = OpLabel + OpBranch %25 + %25 = OpLabel + %61 = OpLoad %10 %4 + %62 = OpIAdd %10 %61 %18 + OpStore %4 %62 + OpBranch %23 + %24 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + EXPECT_EQ(ld.NumLoops(), 4u); + + // Invalid basic block id. + EXPECT_EQ(ld[0u], nullptr); + // Not in a loop. + EXPECT_EQ(ld[22], nullptr); + + // Check that we can map basic block to the correct loop. + // The following block ids do not belong to a loop. 
+ for (uint32_t bb_id : {22, 24}) EXPECT_EQ(ld[bb_id], nullptr); + + { + std::unordered_set basic_block_in_loop = { + {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43, + 42, 39, 50, 53, 56, 52, 51, 32, 31, 25}}; + Loop* loop = ld[23]; + CheckLoopBlocks(loop, &basic_block_in_loop); + + EXPECT_TRUE(loop->HasNestedLoops()); + EXPECT_FALSE(loop->IsNested()); + EXPECT_EQ(loop->GetDepth(), 1u); + EXPECT_EQ(std::distance(loop->begin(), loop->end()), 1u); + EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 22)); + EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 23)); + EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 25)); + EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 24)); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock())); + } + + { + std::unordered_set basic_block_in_loop = { + {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}}; + Loop* loop = ld[30]; + CheckLoopBlocks(loop, &basic_block_in_loop); + + EXPECT_TRUE(loop->HasNestedLoops()); + EXPECT_TRUE(loop->IsNested()); + EXPECT_EQ(loop->GetDepth(), 2u); + EXPECT_EQ(std::distance(loop->begin(), loop->end()), 2u); + EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 29)); + EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 30)); + EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 32)); + EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 31)); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock())); + } + + { + std::unordered_set basic_block_in_loop = {{41, 44, 47, 43}}; + Loop* loop = ld[41]; + CheckLoopBlocks(loop, &basic_block_in_loop); + + EXPECT_FALSE(loop->HasNestedLoops()); + EXPECT_TRUE(loop->IsNested()); + EXPECT_EQ(loop->GetDepth(), 3u); + EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u); + EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 40)); + 
EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 41)); + EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 43)); + EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 42)); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock())); + } + + { + std::unordered_set basic_block_in_loop = {{50, 53, 56, 52}}; + Loop* loop = ld[50]; + CheckLoopBlocks(loop, &basic_block_in_loop); + + EXPECT_FALSE(loop->HasNestedLoops()); + EXPECT_TRUE(loop->IsNested()); + EXPECT_EQ(loop->GetDepth(), 3u); + EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u); + EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 39)); + EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 50)); + EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 52)); + EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 51)); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock())); + } + + // Make sure LoopDescriptor gives us the inner most loop when we query for + // loops. 
+ for (const BasicBlock& bb : *f) { + if (Loop* loop = ld[&bb]) { + for (Loop& sub_loop : + make_range(++TreeDFIterator(loop), TreeDFIterator())) { + EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id())); + } + } + } +} + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 11; ++j) { + for (int k = 0; k < 11; ++k) {} + } + for (int k = 0; k < 12; ++k) {} + } +} +*/ +TEST_F(PassClassTest, LoopParentTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %5 "j" + OpName %6 "k" + OpName %7 "k" + OpName %3 "c" + OpDecorate %3 Location 0 + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpConstant %10 10 + %14 = OpTypeBool + %15 = OpConstant %10 11 + %16 = OpConstant %10 1 + %17 = OpConstant %10 12 + %18 = OpTypeFloat 32 + %19 = OpTypeVector %18 4 + %20 = OpTypePointer Output %19 + %3 = OpVariable %20 Output + %2 = OpFunction %8 None %9 + %21 = OpLabel + %4 = OpVariable %11 Function + %5 = OpVariable %11 Function + %6 = OpVariable %11 Function + %7 = OpVariable %11 Function + OpStore %4 %12 + OpBranch %22 + %22 = OpLabel + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %26 = OpLoad %10 %4 + %27 = OpSLessThan %14 %26 %13 + OpBranchConditional %27 %28 %23 + %28 = OpLabel + OpStore %5 %12 + OpBranch %29 + %29 = OpLabel + OpLoopMerge %30 %31 None + OpBranch %32 + %32 = OpLabel + %33 = OpLoad %10 %5 + %34 = OpSLessThan %14 %33 %15 + OpBranchConditional %34 %35 %30 + %35 = OpLabel + OpStore %6 %12 + OpBranch %36 + %36 = OpLabel + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %40 = OpLoad %10 %6 + %41 = OpSLessThan %14 %40 %15 + OpBranchConditional %41 %42 %37 + 
%42 = OpLabel + OpBranch %38 + %38 = OpLabel + %43 = OpLoad %10 %6 + %44 = OpIAdd %10 %43 %16 + OpStore %6 %44 + OpBranch %36 + %37 = OpLabel + OpBranch %31 + %31 = OpLabel + %45 = OpLoad %10 %5 + %46 = OpIAdd %10 %45 %16 + OpStore %5 %46 + OpBranch %29 + %30 = OpLabel + OpStore %7 %12 + OpBranch %47 + %47 = OpLabel + OpLoopMerge %48 %49 None + OpBranch %50 + %50 = OpLabel + %51 = OpLoad %10 %7 + %52 = OpSLessThan %14 %51 %17 + OpBranchConditional %52 %53 %48 + %53 = OpLabel + OpBranch %49 + %49 = OpLabel + %54 = OpLoad %10 %7 + %55 = OpIAdd %10 %54 %16 + OpStore %7 %55 + OpBranch %47 + %48 = OpLabel + OpBranch %24 + %24 = OpLabel + %56 = OpLoad %10 %4 + %57 = OpIAdd %10 %56 %16 + OpStore %4 %57 + OpBranch %22 + %23 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + EXPECT_EQ(ld.NumLoops(), 4u); + + { + Loop& loop = *ld[22]; + EXPECT_TRUE(loop.HasNestedLoops()); + EXPECT_FALSE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 1u); + EXPECT_EQ(loop.GetParent(), nullptr); + } + + { + Loop& loop = *ld[29]; + EXPECT_TRUE(loop.HasNestedLoops()); + EXPECT_TRUE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 2u); + EXPECT_EQ(loop.GetParent(), ld[22]); + } + + { + Loop& loop = *ld[36]; + EXPECT_FALSE(loop.HasNestedLoops()); + EXPECT_TRUE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 3u); + EXPECT_EQ(loop.GetParent(), ld[29]); + } + + { + Loop& loop = *ld[47]; + EXPECT_FALSE(loop.HasNestedLoops()); + EXPECT_TRUE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 2u); + EXPECT_EQ(loop.GetParent(), ld[22]); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store +The preheader of 
loop %33 and %41 were removed as well. + +#version 330 core +void main() { + int a = 0; + for (int i = 0; i < 10; ++i) { + if (i == 0) { + a = 1; + } else { + a = 2; + } + for (int j = 0; j < 11; ++j) { + a++; + } + } + for (int k = 0; k < 12; ++k) {} +} +*/ +TEST_F(PassClassTest, CreatePreheaderTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeInt 32 1 + %6 = OpTypePointer Function %5 + %7 = OpConstant %5 0 + %8 = OpConstant %5 10 + %9 = OpTypeBool + %10 = OpConstant %5 1 + %11 = OpConstant %5 2 + %12 = OpConstant %5 11 + %13 = OpConstant %5 12 + %14 = OpUndef %5 + %2 = OpFunction %3 None %4 + %15 = OpLabel + OpBranch %16 + %16 = OpLabel + %17 = OpPhi %5 %7 %15 %18 %19 + %20 = OpPhi %5 %7 %15 %21 %19 + %22 = OpPhi %5 %14 %15 %23 %19 + OpLoopMerge %41 %19 None + OpBranch %25 + %25 = OpLabel + %26 = OpSLessThan %9 %20 %8 + OpBranchConditional %26 %27 %41 + %27 = OpLabel + %28 = OpIEqual %9 %20 %7 + OpSelectionMerge %33 None + OpBranchConditional %28 %30 %31 + %30 = OpLabel + OpBranch %33 + %31 = OpLabel + OpBranch %33 + %33 = OpLabel + %18 = OpPhi %5 %10 %30 %11 %31 %34 %35 + %23 = OpPhi %5 %7 %30 %7 %31 %36 %35 + OpLoopMerge %37 %35 None + OpBranch %38 + %38 = OpLabel + %39 = OpSLessThan %9 %23 %12 + OpBranchConditional %39 %40 %37 + %40 = OpLabel + %34 = OpIAdd %5 %18 %10 + OpBranch %35 + %35 = OpLabel + %36 = OpIAdd %5 %23 %10 + OpBranch %33 + %37 = OpLabel + OpBranch %19 + %19 = OpLabel + %21 = OpIAdd %5 %20 %10 + OpBranch %16 + %41 = OpLabel + %42 = OpPhi %5 %7 %25 %43 %44 + OpLoopMerge %45 %44 None + OpBranch %46 + %46 = OpLabel + %47 = OpSLessThan %9 %42 %13 + OpBranchConditional %47 %48 %45 + %48 = OpLabel + OpBranch %44 + %44 = OpLabel + %43 = OpIAdd %5 %42 %10 + OpBranch %41 + %45 = OpLabel + OpReturn + 
OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + // No invalidation of the cfg should occur during this test. + CFG* cfg = context->cfg(); + + EXPECT_EQ(ld.NumLoops(), 3u); + + { + Loop& loop = *ld[16]; + EXPECT_TRUE(loop.HasNestedLoops()); + EXPECT_FALSE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 1u); + EXPECT_EQ(loop.GetParent(), nullptr); + } + + { + Loop& loop = *ld[33]; + EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr); + EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr); + // Make sure the loop descriptor was properly updated. + EXPECT_EQ(ld[loop.GetPreHeaderBlock()], ld[16]); + { + const std::vector& preds = + cfg->preds(loop.GetPreHeaderBlock()->id()); + std::unordered_set pred_set(preds.begin(), preds.end()); + EXPECT_EQ(pred_set.size(), 2u); + EXPECT_TRUE(pred_set.count(30)); + EXPECT_TRUE(pred_set.count(31)); + // Check the phi instructions. + loop.GetPreHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) { + for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { + EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i))); + } + }); + } + { + const std::vector& preds = + cfg->preds(loop.GetHeaderBlock()->id()); + std::unordered_set pred_set(preds.begin(), preds.end()); + EXPECT_EQ(pred_set.size(), 2u); + EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id())); + EXPECT_TRUE(pred_set.count(35)); + // Check the phi instructions. 
+ loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) { + for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { + EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i))); + } + }); + } + } + + { + Loop& loop = *ld[41]; + EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr); + EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr); + EXPECT_EQ(ld[loop.GetPreHeaderBlock()], nullptr); + EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id()).size(), 1u); + EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id())[0], 25u); + // Check the phi instructions. + loop.GetPreHeaderBlock()->ForEachPhiInst([](Instruction* phi) { + EXPECT_EQ(phi->NumInOperands(), 2u); + EXPECT_EQ(phi->GetSingleWordInOperand(1), 25u); + }); + { + const std::vector& preds = + cfg->preds(loop.GetHeaderBlock()->id()); + std::unordered_set pred_set(preds.begin(), preds.end()); + EXPECT_EQ(pred_set.size(), 2u); + EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id())); + EXPECT_TRUE(pred_set.count(44)); + // Check the phi instructions. + loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) { + for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { + EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i))); + } + }); + } + } + + // Make sure pre-header insertion leaves the module valid. + std::vector bin; + context->module()->ToBinary(&bin, true); + EXPECT_TRUE(Validate(bin)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/pch_test_opt_loop.cpp b/test/opt/loop_optimizations/pch_test_opt_loop.cpp new file mode 100644 index 000000000..f4ac7b298 --- /dev/null +++ b/test/opt/loop_optimizations/pch_test_opt_loop.cpp @@ -0,0 +1,15 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "pch_test_opt_loop.h" diff --git a/test/opt/loop_optimizations/pch_test_opt_loop.h b/test/opt/loop_optimizations/pch_test_opt_loop.h new file mode 100644 index 000000000..4e8106fbf --- /dev/null +++ b/test/opt/loop_optimizations/pch_test_opt_loop.h @@ -0,0 +1,25 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" diff --git a/test/opt/loop_optimizations/peeling.cpp b/test/opt/loop_optimizations/peeling.cpp new file mode 100644 index 000000000..10d8add38 --- /dev/null +++ b/test/opt/loop_optimizations/peeling.cpp @@ -0,0 +1,1186 @@ +// Copyright (c) 2018 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "source/opt/ir_builder.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_peeling.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using PeelingTest = PassTest<::testing::Test>; + +bool Validate(const std::vector& bin) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {bin.data(), bin.size()}; + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + return error == 0; +} + +void Match(const std::string& checks, IRContext* context) { + // Silence unused warnings with !defined(SPIRV_EFFCE) + (void)checks; + + std::vector bin; + context->module()->ToBinary(&bin, true); + EXPECT_TRUE(Validate(bin)); + std::string assembly; + SpirvTools tools(SPV_ENV_UNIVERSAL_1_2); + EXPECT_TRUE( + tools.Disassemble(bin, &assembly, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER)) + << "Disassembling failed for shader:\n" + << assembly << std::endl; + auto match_result = effcee::Match(assembly, checks); + EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) + << match_result.message() << "\nChecking result:\n" + 
<< assembly; +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +First test: +#version 330 core +void main() { + for(int i = 0; i < 10; ++i) { + if (i < 4) + break; + } +} + +Second test (with a common sub-expression elimination): +#version 330 core +void main() { + for(int i = 0; i + 1 < 10; ++i) { + } +} + +Third test: +#version 330 core +void main() { + int a[10]; + for (int i = 0; a[i] != 0; i++) {} +} + +Fourth test: +#version 330 core +void main() { + for (long i = 0; i < 10; i++) {} +} + +Fifth test: +#version 330 core +void main() { + for (float i = 0; i < 10; i++) {} +} + +Sixth test: +#version 450 +layout(location = 0)out float o; +void main() { + o = 0.0; + for( int i = 0; true; i++ ) { + o += 1.0; + if (i > 10) break; + } +} +*/ +TEST_F(PeelingTest, CannotPeel) { + // Build the given SPIR-V program in |text|, take the first loop in the first + // function and test that it is not peelable. |loop_count_id| is the id + // representing the loop count, if equals to 0, then the function build a 10 + // constant as loop count. + auto test_cannot_peel = [](const std::string& text, uint32_t loop_count_id) { + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + Instruction* loop_count = nullptr; + if (loop_count_id) { + loop_count = context->get_def_use_mgr()->GetDef(loop_count_id); + } else { + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition.
+ loop_count = builder.GetSintConstant(10); + } + + LoopPeeling peel(&*ld.begin(), loop_count); + EXPECT_FALSE(peel.CanPeelLoop()); + }; + { + SCOPED_TRACE("loop with break"); + + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_4 = OpConstant %int 4 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %28 = OpPhi %int %int_0 %5 %27 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %bool %28 %int_10 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %21 = OpSLessThan %bool %28 %int_4 + OpSelectionMerge %23 None + OpBranchConditional %21 %22 %23 + %22 = OpLabel + OpBranch %12 + %23 = OpLabel + OpBranch %13 + %13 = OpLabel + %27 = OpIAdd %int %28 %int_1 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + test_cannot_peel(text, 0); + } + + { + SCOPED_TRACE("Ambiguous iterator update"); + + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %23 = OpPhi %int %int_0 %5 %17 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %17 = OpIAdd %int %23 %int_1 + %20 = 
OpSLessThan %bool %17 %int_10 + OpBranchConditional %20 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + test_cannot_peel(text, 0); + } + + { + SCOPED_TRACE("No loop static bounds"); + + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %i "i" + OpName %a "a" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_10 = OpConstant %uint 10 +%_arr_int_uint_10 = OpTypeArray %int %uint_10 +%_ptr_Function__arr_int_uint_10 = OpTypePointer Function %_arr_int_uint_10 + %bool = OpTypeBool + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %i = OpVariable %_ptr_Function_int Function + %a = OpVariable %_ptr_Function__arr_int_uint_10 Function + OpStore %i %int_0 + OpBranch %10 + %10 = OpLabel + %28 = OpPhi %int %int_0 %5 %27 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %21 = OpAccessChain %_ptr_Function_int %a %28 + %22 = OpLoad %int %21 + %24 = OpINotEqual %bool %22 %int_0 + OpBranchConditional %24 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %27 = OpIAdd %int %28 %int_1 + OpStore %i %27 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + test_cannot_peel(text, 22); + } + { + SCOPED_TRACE("Int 64 type for conditions"); + + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginLowerLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + %6 = OpTypeVoid + %3 = OpTypeFunction %6 + %7 = OpTypeInt 64 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %15 = OpConstant %7 
10 + %16 = OpTypeBool + %17 = OpConstant %7 1 + %2 = OpFunction %6 None %3 + %5 = OpLabel + %4 = OpVariable %8 Function + OpStore %4 %9 + OpBranch %10 + %10 = OpLabel + %22 = OpPhi %7 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %16 %22 %15 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %7 %22 %17 + OpStore %4 %21 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + // %15 is a constant for a 64 int. Currently rejected. + test_cannot_peel(text, 15); + } + { + SCOPED_TRACE("Float type for conditions"); + + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginLowerLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + %6 = OpTypeVoid + %3 = OpTypeFunction %6 + %7 = OpTypeFloat 32 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %15 = OpConstant %7 10 + %16 = OpTypeBool + %17 = OpConstant %7 1 + %2 = OpFunction %6 None %3 + %5 = OpLabel + %4 = OpVariable %8 Function + OpStore %4 %9 + OpBranch %10 + %10 = OpLabel + %22 = OpPhi %7 %9 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpFOrdLessThan %16 %22 %15 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpFAdd %7 %22 %17 + OpStore %4 %21 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + // %15 is a constant for a float. Currently rejected. 
+ test_cannot_peel(text, 15); + } + { + SCOPED_TRACE("Side effect before exit"); + + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %o + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 450 + OpName %main "main" + OpName %o "o" + OpName %i "i" + OpDecorate %o Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 +%_ptr_Output_float = OpTypePointer Output %float + %o = OpVariable %_ptr_Output_float Output + %float_0 = OpConstant %float 0 + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %bool = OpTypeBool + %true = OpConstantTrue %bool + %float_1 = OpConstant %float 1 + %int_10 = OpConstant %int 10 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %i = OpVariable %_ptr_Function_int Function + OpStore %o %float_0 + OpStore %i %int_0 + OpBranch %14 + %14 = OpLabel + %33 = OpPhi %int %int_0 %5 %32 %17 + OpLoopMerge %16 %17 None + OpBranch %15 + %15 = OpLabel + %22 = OpLoad %float %o + %23 = OpFAdd %float %22 %float_1 + OpStore %o %23 + %26 = OpSGreaterThan %bool %33 %int_10 + OpSelectionMerge %28 None + OpBranchConditional %26 %27 %28 + %27 = OpLabel + OpBranch %16 + %28 = OpLabel + OpBranch %17 + %17 = OpLabel + %32 = OpIAdd %int %33 %int_1 + OpStore %i %32 + OpBranch %14 + %16 = OpLabel + OpReturn + OpFunctionEnd + )"; + test_cannot_peel(text, 0); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +void main() { + int i = 0; + for (; i < 10; i++) {} +} +*/ +TEST_F(PeelingTest, SimplePeeling) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = 
OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %22 = OpPhi %int %int_0 %5 %21 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %bool %22 %int_10 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %21 = OpIAdd %int %22 %int_1 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + // Peel before. + { + SCOPED_TRACE("Peel before"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition. 
+ Instruction* ten_cst = builder.GetSintConstant(10); + + LoopPeeling peel(&*ld.begin(), ten_cst); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelBefore(2); + + const std::string check = R"( +CHECK: [[CST_TEN:%\w+]] = OpConstant {{%\w+}} 10 +CHECK: [[CST_TWO:%\w+]] = OpConstant {{%\w+}} 2 +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} [[CST_TWO]] [[CST_TEN]] +CHECK-NEXT: [[LOOP_COUNT:%\w+]] = OpSelect {{%\w+}} [[MIN_LOOP_COUNT]] [[CST_TWO]] [[CST_TEN]] +CHECK: [[BEFORE_LOOP:%\w+]] = OpLabel +CHECK-NEXT: [[DUMMY_IT:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[DUMMY_IT_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: [[i:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE]] +CHECK-NEXT: OpLoopMerge [[AFTER_LOOP_PREHEADER:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[AFTER_LOOP_PREHEADER]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[i]] +CHECK-NEXT: [[DUMMY_IT_1]] = OpIAdd {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[AFTER_LOOP_PREHEADER]] = OpLabel +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[AFTER_LOOP:%\w+]] [[IF_MERGE]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[i]] [[AFTER_LOOP_PREHEADER]] +CHECK-NEXT: OpLoopMerge +)"; + + Match(check, context.get()); + } + + // Peel after. 
+ { + SCOPED_TRACE("Peel after"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition. + Instruction* ten_cst = builder.GetSintConstant(10); + + LoopPeeling peel(&*ld.begin(), ten_cst); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelAfter(2); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[BEFORE_LOOP:%\w+]] [[IF_MERGE]] +CHECK: [[BEFORE_LOOP]] = OpLabel +CHECK-NEXT: [[DUMMY_IT:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[DUMMY_IT_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: [[I:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE]] +CHECK-NEXT: OpLoopMerge [[BEFORE_LOOP_MERGE:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[TMP:%\w+]] = OpIAdd {{%\w+}} [[DUMMY_IT]] {{%\w+}} +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[TMP]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[BEFORE_LOOP_MERGE]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[I]] +CHECK-NEXT: [[DUMMY_IT_1]] = OpIAdd {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[IF_MERGE]] = OpLabel +CHECK-NEXT: [[TMP:%\w+]] = OpPhi {{%\w+}} [[I]] [[BEFORE_LOOP_MERGE]] +CHECK-NEXT: OpBranch [[AFTER_LOOP:%\w+]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]] +CHECK-NEXT: OpLoopMerge + +)"; + + Match(check, context.get()); + } + + // Same as above, but reuse the induction variable. 
+ // Peel before. + { + SCOPED_TRACE("Peel before with IV reuse"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition. + Instruction* ten_cst = builder.GetSintConstant(10); + + LoopPeeling peel(&*ld.begin(), ten_cst, + context->get_def_use_mgr()->GetDef(22)); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelBefore(2); + + const std::string check = R"( +CHECK: [[CST_TEN:%\w+]] = OpConstant {{%\w+}} 10 +CHECK: [[CST_TWO:%\w+]] = OpConstant {{%\w+}} 2 +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} [[CST_TWO]] [[CST_TEN]] +CHECK-NEXT: [[LOOP_COUNT:%\w+]] = OpSelect {{%\w+}} [[MIN_LOOP_COUNT]] [[CST_TWO]] [[CST_TEN]] +CHECK: [[BEFORE_LOOP:%\w+]] = OpLabel +CHECK-NEXT: [[i:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: OpLoopMerge [[AFTER_LOOP_PREHEADER:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[i]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[AFTER_LOOP_PREHEADER]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[i]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[AFTER_LOOP_PREHEADER]] = OpLabel +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[AFTER_LOOP:%\w+]] [[IF_MERGE]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[i]] [[AFTER_LOOP_PREHEADER]] +CHECK-NEXT: OpLoopMerge +)"; + + Match(check, context.get()); + } + + // Peel after. 
+ { + SCOPED_TRACE("Peel after IV reuse"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition. + Instruction* ten_cst = builder.GetSintConstant(10); + + LoopPeeling peel(&*ld.begin(), ten_cst, + context->get_def_use_mgr()->GetDef(22)); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelAfter(2); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[BEFORE_LOOP:%\w+]] [[IF_MERGE]] +CHECK: [[BEFORE_LOOP]] = OpLabel +CHECK-NEXT: [[I:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: OpLoopMerge [[BEFORE_LOOP_MERGE:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[TMP:%\w+]] = OpIAdd {{%\w+}} [[I]] {{%\w+}} +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[TMP]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[BEFORE_LOOP_MERGE]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[I]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[IF_MERGE]] = OpLabel +CHECK-NEXT: [[TMP:%\w+]] = OpPhi {{%\w+}} [[I]] [[BEFORE_LOOP_MERGE]] +CHECK-NEXT: OpBranch [[AFTER_LOOP:%\w+]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]] +CHECK-NEXT: OpLoopMerge + +)"; + + Match(check, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +void main() { + int a[10]; + int n = a[0]; + for(int i = 0; i < n; ++i) {} 
+} +*/ +TEST_F(PeelingTest, PeelingUncountable) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %uint = OpTypeInt 32 0 + %uint_10 = OpConstant %uint 10 +%_arr_int_uint_10 = OpTypeArray %int %uint_10 +%_ptr_Function__arr_int_uint_10 = OpTypePointer Function %_arr_int_uint_10 + %int_0 = OpConstant %int 0 + %bool = OpTypeBool + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function__arr_int_uint_10 Function + %15 = OpAccessChain %_ptr_Function_int %a %int_0 + %16 = OpLoad %int %15 + OpBranch %18 + %18 = OpLabel + %30 = OpPhi %int %int_0 %5 %29 %21 + OpLoopMerge %20 %21 None + OpBranch %22 + %22 = OpLabel + %26 = OpSLessThan %bool %30 %16 + OpBranchConditional %26 %19 %20 + %19 = OpLabel + OpBranch %21 + %21 = OpLabel + %29 = OpIAdd %int %30 %int_1 + OpBranch %18 + %20 = OpLabel + OpReturn + OpFunctionEnd + )"; + + // Peel before. 
+ { + SCOPED_TRACE("Peel before"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + Instruction* loop_count = context->get_def_use_mgr()->GetDef(16); + EXPECT_EQ(loop_count->opcode(), SpvOpLoad); + + LoopPeeling peel(&*ld.begin(), loop_count); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelBefore(1); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[LOOP_COUNT:%\w+]] = OpLoad +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} {{%\w+}} [[LOOP_COUNT]] +CHECK-NEXT: [[LOOP_COUNT:%\w+]] = OpSelect {{%\w+}} [[MIN_LOOP_COUNT]] {{%\w+}} [[LOOP_COUNT]] +CHECK: [[BEFORE_LOOP:%\w+]] = OpLabel +CHECK-NEXT: [[DUMMY_IT:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[DUMMY_IT_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: [[i:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE]] +CHECK-NEXT: OpLoopMerge [[AFTER_LOOP_PREHEADER:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[AFTER_LOOP_PREHEADER]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[i]] +CHECK-NEXT: [[DUMMY_IT_1]] = OpIAdd {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[AFTER_LOOP_PREHEADER]] = OpLabel +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[AFTER_LOOP:%\w+]] [[IF_MERGE]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[i]] [[AFTER_LOOP_PREHEADER]] +CHECK-NEXT: OpLoopMerge +)"; + + Match(check, context.get()); + } + + // Peel after. 
+ { + SCOPED_TRACE("Peel after"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + Instruction* loop_count = context->get_def_use_mgr()->GetDef(16); + EXPECT_EQ(loop_count->opcode(), SpvOpLoad); + + LoopPeeling peel(&*ld.begin(), loop_count); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelAfter(1); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[BEFORE_LOOP:%\w+]] [[IF_MERGE]] +CHECK: [[BEFORE_LOOP]] = OpLabel +CHECK-NEXT: [[DUMMY_IT:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[DUMMY_IT_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: [[I:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE]] +CHECK-NEXT: OpLoopMerge [[BEFORE_LOOP_MERGE:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[TMP:%\w+]] = OpIAdd {{%\w+}} [[DUMMY_IT]] {{%\w+}} +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[TMP]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[BEFORE_LOOP_MERGE]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[I]] +CHECK-NEXT: [[DUMMY_IT_1]] = OpIAdd {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[IF_MERGE]] = OpLabel +CHECK-NEXT: [[TMP:%\w+]] = OpPhi {{%\w+}} [[I]] [[BEFORE_LOOP_MERGE]] +CHECK-NEXT: OpBranch [[AFTER_LOOP:%\w+]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]] +CHECK-NEXT: OpLoopMerge + +)"; + + Match(check, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + 
+#version 330 core +void main() { + int i = 0; + do { + i++; + } while (i < 10); +} +*/ +TEST_F(PeelingTest, DoWhilePeeling) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %21 = OpPhi %int %int_0 %5 %16 %13 + OpLoopMerge %12 %13 None + OpBranch %11 + %11 = OpLabel + %16 = OpIAdd %int %21 %int_1 + OpBranch %13 + %13 = OpLabel + %20 = OpSLessThan %bool %16 %int_10 + OpBranchConditional %20 %10 %12 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + // Peel before. + { + SCOPED_TRACE("Peel before"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition. 
+ Instruction* ten_cst = builder.GetUintConstant(10); + + LoopPeeling peel(&*ld.begin(), ten_cst); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelBefore(2); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpULessThan {{%\w+}} +CHECK-NEXT: [[LOOP_COUNT:%\w+]] = OpSelect {{%\w+}} [[MIN_LOOP_COUNT]] +CHECK: [[BEFORE_LOOP:%\w+]] = OpLabel +CHECK-NEXT: [[DUMMY_IT:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[DUMMY_IT_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: [[i:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE]] +CHECK-NEXT: OpLoopMerge [[AFTER_LOOP_PREHEADER:%\w+]] [[BE]] None +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[i]] +CHECK: [[BE]] = OpLabel +CHECK: [[DUMMY_IT_1]] = OpIAdd {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpULessThan {{%\w+}} [[DUMMY_IT_1]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] [[BEFORE_LOOP]] [[AFTER_LOOP_PREHEADER]] + +CHECK: [[AFTER_LOOP_PREHEADER]] = OpLabel +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[AFTER_LOOP:%\w+]] [[IF_MERGE]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[I_1]] [[AFTER_LOOP_PREHEADER]] +CHECK-NEXT: OpLoopMerge +)"; + + Match(check, context.get()); + } + + // Peel after. + { + SCOPED_TRACE("Peel after"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + InstructionBuilder builder(context.get(), &*f.begin()); + // Exit condition. 
+ Instruction* ten_cst = builder.GetUintConstant(10); + + LoopPeeling peel(&*ld.begin(), ten_cst); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelAfter(2); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpULessThan {{%\w+}} +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[BEFORE_LOOP:%\w+]] [[IF_MERGE]] +CHECK: [[BEFORE_LOOP]] = OpLabel +CHECK-NEXT: [[DUMMY_IT:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[DUMMY_IT_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: [[I:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE]] +CHECK-NEXT: OpLoopMerge [[BEFORE_LOOP_MERGE:%\w+]] [[BE]] None +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[I]] +CHECK: [[BE]] = OpLabel +CHECK: [[DUMMY_IT_1]] = OpIAdd {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: [[EXIT_VAL:%\w+]] = OpIAdd {{%\w+}} [[DUMMY_IT_1]] +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpULessThan {{%\w+}} [[EXIT_VAL]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] [[BEFORE_LOOP]] [[BEFORE_LOOP_MERGE]] + +CHECK: [[IF_MERGE]] = OpLabel +CHECK-NEXT: [[TMP:%\w+]] = OpPhi {{%\w+}} [[I_1]] [[BEFORE_LOOP_MERGE]] +CHECK-NEXT: OpBranch [[AFTER_LOOP:%\w+]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]] +CHECK-NEXT: OpLoopMerge +)"; + + Match(check, context.get()); + } +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +void main() { + int a[10]; + int n = a[0]; + for(int i = 0; i < n; ++i) {} +} +*/ +TEST_F(PeelingTest, PeelingLoopWithStore) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %o %n + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 450 + OpName %main "main" + OpName %o "o" + OpName %end "end" + OpName %n "n" + OpName %i "i" + OpDecorate %o Location 0 + OpDecorate %n Flat + OpDecorate %n Location 0 + %void = 
OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 +%_ptr_Output_float = OpTypePointer Output %float + %o = OpVariable %_ptr_Output_float Output + %float_0 = OpConstant %float 0 + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%_ptr_Input_int = OpTypePointer Input %int + %n = OpVariable %_ptr_Input_int Input + %int_0 = OpConstant %int 0 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %end = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + OpStore %o %float_0 + %15 = OpLoad %int %n + OpStore %end %15 + OpStore %i %int_0 + OpBranch %18 + %18 = OpLabel + %33 = OpPhi %int %int_0 %5 %32 %21 + OpLoopMerge %20 %21 None + OpBranch %22 + %22 = OpLabel + %26 = OpSLessThan %bool %33 %15 + OpBranchConditional %26 %19 %20 + %19 = OpLabel + %28 = OpLoad %float %o + %29 = OpFAdd %float %28 %float_1 + OpStore %o %29 + OpBranch %21 + %21 = OpLabel + %32 = OpIAdd %int %33 %int_1 + OpStore %i %32 + OpBranch %18 + %20 = OpLabel + OpReturn + OpFunctionEnd + )"; + + // Peel before. 
+ { + SCOPED_TRACE("Peel before"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + Instruction* loop_count = context->get_def_use_mgr()->GetDef(15); + EXPECT_EQ(loop_count->opcode(), SpvOpLoad); + + LoopPeeling peel(&*ld.begin(), loop_count); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelBefore(1); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[LOOP_COUNT:%\w+]] = OpLoad +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} {{%\w+}} [[LOOP_COUNT]] +CHECK-NEXT: [[LOOP_COUNT:%\w+]] = OpSelect {{%\w+}} [[MIN_LOOP_COUNT]] {{%\w+}} [[LOOP_COUNT]] +CHECK: [[BEFORE_LOOP:%\w+]] = OpLabel +CHECK-NEXT: [[DUMMY_IT:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[DUMMY_IT_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: [[i:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE]] +CHECK-NEXT: OpLoopMerge [[AFTER_LOOP_PREHEADER:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[AFTER_LOOP_PREHEADER]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[i]] +CHECK: [[DUMMY_IT_1]] = OpIAdd {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[AFTER_LOOP_PREHEADER]] = OpLabel +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[AFTER_LOOP:%\w+]] [[IF_MERGE]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[i]] [[AFTER_LOOP_PREHEADER]] +CHECK-NEXT: OpLoopMerge +)"; + + Match(check, context.get()); + } + + // Peel after. 
+ { + SCOPED_TRACE("Peel after"); + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function& f = *module->begin(); + LoopDescriptor& ld = *context->GetLoopDescriptor(&f); + + EXPECT_EQ(ld.NumLoops(), 1u); + + Instruction* loop_count = context->get_def_use_mgr()->GetDef(15); + EXPECT_EQ(loop_count->opcode(), SpvOpLoad); + + LoopPeeling peel(&*ld.begin(), loop_count); + EXPECT_TRUE(peel.CanPeelLoop()); + peel.PeelAfter(1); + + const std::string check = R"( +CHECK: OpFunction +CHECK-NEXT: [[ENTRY:%\w+]] = OpLabel +CHECK: [[MIN_LOOP_COUNT:%\w+]] = OpSLessThan {{%\w+}} +CHECK-NEXT: OpSelectionMerge [[IF_MERGE:%\w+]] +CHECK-NEXT: OpBranchConditional [[MIN_LOOP_COUNT]] [[BEFORE_LOOP:%\w+]] [[IF_MERGE]] +CHECK: [[BEFORE_LOOP]] = OpLabel +CHECK-NEXT: [[DUMMY_IT:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[DUMMY_IT_1:%\w+]] [[BE:%\w+]] +CHECK-NEXT: [[I:%\w+]] = OpPhi {{%\w+}} {{%\w+}} [[ENTRY]] [[I_1:%\w+]] [[BE]] +CHECK-NEXT: OpLoopMerge [[BEFORE_LOOP_MERGE:%\w+]] [[BE]] None +CHECK: [[COND_BLOCK:%\w+]] = OpLabel +CHECK-NEXT: OpSLessThan +CHECK-NEXT: [[TMP:%\w+]] = OpIAdd {{%\w+}} [[DUMMY_IT]] {{%\w+}} +CHECK-NEXT: [[EXIT_COND:%\w+]] = OpSLessThan {{%\w+}} [[TMP]] +CHECK-NEXT: OpBranchConditional [[EXIT_COND]] {{%\w+}} [[BEFORE_LOOP_MERGE]] +CHECK: [[I_1]] = OpIAdd {{%\w+}} [[I]] +CHECK: [[DUMMY_IT_1]] = OpIAdd {{%\w+}} [[DUMMY_IT]] +CHECK-NEXT: OpBranch [[BEFORE_LOOP]] + +CHECK: [[IF_MERGE]] = OpLabel +CHECK-NEXT: [[TMP:%\w+]] = OpPhi {{%\w+}} [[I]] [[BEFORE_LOOP_MERGE]] +CHECK-NEXT: OpBranch [[AFTER_LOOP:%\w+]] + +CHECK: [[AFTER_LOOP]] = OpLabel +CHECK-NEXT: OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[TMP]] [[IF_MERGE]] +CHECK-NEXT: OpLoopMerge + +)"; + + Match(check, context.get()); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git 
a/test/opt/loop_optimizations/peeling_pass.cpp b/test/opt/loop_optimizations/peeling_pass.cpp new file mode 100644 index 000000000..284ad838d --- /dev/null +++ b/test/opt/loop_optimizations/peeling_pass.cpp @@ -0,0 +1,1099 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/ir_builder.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/loop_peeling.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +class PeelingPassTest : public PassTest<::testing::Test> { + public: + // Generic routine to run the loop peeling pass and check + LoopPeelingPass::LoopPeelingStats AssembleAndRunPeelingTest( + const std::string& text_head, const std::string& text_tail, SpvOp opcode, + const std::string& res_id, const std::string& op1, + const std::string& op2) { + std::string opcode_str; + switch (opcode) { + case SpvOpSLessThan: + opcode_str = "OpSLessThan"; + break; + case SpvOpSGreaterThan: + opcode_str = "OpSGreaterThan"; + break; + case SpvOpSLessThanEqual: + opcode_str = "OpSLessThanEqual"; + break; + case SpvOpSGreaterThanEqual: + opcode_str = "OpSGreaterThanEqual"; + break; + case SpvOpIEqual: + opcode_str = "OpIEqual"; + break; + case SpvOpINotEqual: + opcode_str = "OpINotEqual"; + break; + default: + assert(false && "Unhandled"); + break; + } + std::string test_cond = + res_id + " = " + opcode_str 
+ " %bool " + op1 + " " + op2 + "\n"; + + LoopPeelingPass::LoopPeelingStats stats; + SinglePassRunAndDisassemble( + text_head + test_cond + text_tail, true, true, &stats); + + return stats; + } + + // Generic routine to run the loop peeling pass and check + LoopPeelingPass::LoopPeelingStats RunPeelingTest( + const std::string& text_head, const std::string& text_tail, SpvOp opcode, + const std::string& res_id, const std::string& op1, const std::string& op2, + size_t nb_of_loops) { + LoopPeelingPass::LoopPeelingStats stats = AssembleAndRunPeelingTest( + text_head, text_tail, opcode, res_id, op1, op2); + + Function& f = *context()->module()->begin(); + LoopDescriptor& ld = *context()->GetLoopDescriptor(&f); + EXPECT_EQ(ld.NumLoops(), nb_of_loops); + + return stats; + } + + using PeelTraceType = + std::vector>; + + void BuildAndCheckTrace(const std::string& text_head, + const std::string& text_tail, SpvOp opcode, + const std::string& res_id, const std::string& op1, + const std::string& op2, + const PeelTraceType& expected_peel_trace, + size_t expected_nb_of_loops) { + auto stats = RunPeelingTest(text_head, text_tail, opcode, res_id, op1, op2, + expected_nb_of_loops); + + EXPECT_EQ(stats.peeled_loops_.size(), expected_peel_trace.size()); + if (stats.peeled_loops_.size() != expected_peel_trace.size()) { + return; + } + + PeelTraceType::const_iterator expected_trace_it = + expected_peel_trace.begin(); + decltype(stats.peeled_loops_)::const_iterator stats_it = + stats.peeled_loops_.begin(); + + while (expected_trace_it != expected_peel_trace.end()) { + EXPECT_EQ(expected_trace_it->first, std::get<1>(*stats_it)); + EXPECT_EQ(expected_trace_it->second, std::get<2>(*stats_it)); + ++expected_trace_it; + ++stats_it; + } + } +}; + +/* +Test are derivation of the following generated test from the following GLSL + +--eliminate-local-multi-store + +#version 330 core +void main() { + int a = 0; + for(int i = 1; i < 10; i += 2) { + if (i < 3) { + a += 2; + } + } +} + +The condition 
is interchanged to test < > <= >= == and peel before/after +opportunities. +*/ +TEST_F(PeelingPassTest, PeelingPassBasic) { + const std::string text_head = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + OpName %i "i" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %bool = OpTypeBool + %int_20 = OpConstant %int 20 + %int_19 = OpConstant %int 19 + %int_18 = OpConstant %int 18 + %int_17 = OpConstant %int 17 + %int_16 = OpConstant %int 16 + %int_15 = OpConstant %int 15 + %int_14 = OpConstant %int 14 + %int_13 = OpConstant %int 13 + %int_12 = OpConstant %int 12 + %int_11 = OpConstant %int 11 + %int_10 = OpConstant %int 10 + %int_9 = OpConstant %int 9 + %int_8 = OpConstant %int 8 + %int_7 = OpConstant %int 7 + %int_6 = OpConstant %int 6 + %int_5 = OpConstant %int 5 + %int_4 = OpConstant %int 4 + %int_3 = OpConstant %int 3 + %int_2 = OpConstant %int 2 + %int_1 = OpConstant %int 1 + %int_0 = OpConstant %int 0 + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + OpStore %a %int_0 + OpStore %i %int_0 + OpBranch %11 + %11 = OpLabel + %31 = OpPhi %int %int_0 %5 %33 %14 + %32 = OpPhi %int %int_1 %5 %30 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %bool %32 %int_20 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + )"; + const std::string text_tail = R"( + OpSelectionMerge %24 None + OpBranchConditional %22 %23 %24 + %23 = OpLabel + %27 = OpIAdd %int %31 %int_2 + OpStore %a %27 + OpBranch %24 + %24 = OpLabel + %33 = OpPhi %int %31 %12 %27 %23 + OpBranch %14 + %14 = OpLabel + %30 = OpIAdd %int %32 %int_2 + OpStore %i %30 + OpBranch %11 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto 
run_test = [&text_head, &text_tail, this](SpvOp opcode, + const std::string& op1, + const std::string& op2) { + auto stats = + RunPeelingTest(text_head, text_tail, opcode, "%22", op1, op2, 2); + + EXPECT_EQ(stats.peeled_loops_.size(), 1u); + if (stats.peeled_loops_.size() != 1u) + return std::pair{ + LoopPeelingPass::PeelDirection::kNone, 0}; + + return std::pair{ + std::get<1>(*stats.peeled_loops_.begin()), + std::get<2>(*stats.peeled_loops_.begin())}; + }; + + // Test LT + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv < 4"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%32", "%int_4"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 4 > iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%int_4", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before iv < 5"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%32", "%int_5"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 5 > iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%int_5", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + + // Peel after by a factor of 2. 
+ { + SCOPED_TRACE("Peel after iv < 16"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%32", "%int_16"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 16 > iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%int_16", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after iv < 17"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%32", "%int_17"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 17 > iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%int_17", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + + // Test GT + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv > 5"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%32", "%int_5"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 5 < iv"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%int_5", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before iv > 4"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%32", "%int_4"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 4 < iv"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%int_4", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + + // Peel after by a factor of 2. 
+ { + SCOPED_TRACE("Peel after iv > 16"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%32", "%int_16"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 16 < iv"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%int_16", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after iv > 17"); + + std::pair peel_info = + run_test(SpvOpSGreaterThan, "%32", "%int_17"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 17 < iv"); + + std::pair peel_info = + run_test(SpvOpSLessThan, "%int_17", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + + // Test LE + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv <= 4"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%32", "%int_4"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 4 => iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%int_4", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before iv <= 3"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%32", "%int_3"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 3 => iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%int_3", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + + // Peel after by a factor of 2. 
+ { + SCOPED_TRACE("Peel after iv <= 16"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%32", "%int_16"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 16 => iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%int_16", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after iv <= 15"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%32", "%int_15"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 15 => iv"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%int_15", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + + // Test GE + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv >= 5"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%32", "%int_5"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 35 >= iv"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%int_5", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before iv >= 4"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%32", "%int_4"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel before 4 <= iv"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%int_4", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 2u); + } + + // Peel after by a factor of 2. 
+ { + SCOPED_TRACE("Peel after iv >= 17"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%32", "%int_17"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 17 <= iv"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%int_17", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after iv >= 16"); + + std::pair peel_info = + run_test(SpvOpSGreaterThanEqual, "%32", "%int_16"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + { + SCOPED_TRACE("Peel after 16 <= iv"); + + std::pair peel_info = + run_test(SpvOpSLessThanEqual, "%int_16", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 2u); + } + + // Test EQ + // Peel before by a factor of 1. + { + SCOPED_TRACE("Peel before iv == 1"); + + std::pair peel_info = + run_test(SpvOpIEqual, "%32", "%int_1"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 1u); + } + { + SCOPED_TRACE("Peel before 1 == iv"); + + std::pair peel_info = + run_test(SpvOpIEqual, "%int_1", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 1u); + } + + // Peel after by a factor of 1. + { + SCOPED_TRACE("Peel after iv == 19"); + + std::pair peel_info = + run_test(SpvOpIEqual, "%32", "%int_19"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 1u); + } + { + SCOPED_TRACE("Peel after 19 == iv"); + + std::pair peel_info = + run_test(SpvOpIEqual, "%int_19", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 1u); + } + + // Test NE + // Peel before by a factor of 1. 
+ { + SCOPED_TRACE("Peel before iv != 1"); + + std::pair peel_info = + run_test(SpvOpINotEqual, "%32", "%int_1"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 1u); + } + { + SCOPED_TRACE("Peel before 1 != iv"); + + std::pair peel_info = + run_test(SpvOpINotEqual, "%int_1", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kBefore); + EXPECT_EQ(peel_info.second, 1u); + } + + // Peel after by a factor of 1. + { + SCOPED_TRACE("Peel after iv != 19"); + + std::pair peel_info = + run_test(SpvOpINotEqual, "%32", "%int_19"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 1u); + } + { + SCOPED_TRACE("Peel after 19 != iv"); + + std::pair peel_info = + run_test(SpvOpINotEqual, "%int_19", "%32"); + EXPECT_EQ(peel_info.first, LoopPeelingPass::PeelDirection::kAfter); + EXPECT_EQ(peel_info.second, 1u); + } + + // No peel. + { + SCOPED_TRACE("No Peel: 20 => iv"); + + auto stats = RunPeelingTest(text_head, text_tail, SpvOpSLessThanEqual, + "%22", "%int_20", "%32", 1); + + EXPECT_EQ(stats.peeled_loops_.size(), 0u); + } +} + +/* +Test are derivation of the following generated test from the following GLSL + +--eliminate-local-multi-store + +#version 330 core +void main() { + int a = 0; + for(int i = 0; i < 10; ++i) { + if (i < 3) { + a += 2; + } + if (i < 1) { + a += 2; + } + } +} + +The condition is interchanged to test < > <= >= == and peel before/after +opportunities. 
+*/ +TEST_F(PeelingPassTest, MultiplePeelingPass) { + const std::string text_head = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + OpName %i "i" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %bool = OpTypeBool + %int_10 = OpConstant %int 10 + %int_9 = OpConstant %int 9 + %int_8 = OpConstant %int 8 + %int_7 = OpConstant %int 7 + %int_6 = OpConstant %int 6 + %int_5 = OpConstant %int 5 + %int_4 = OpConstant %int 4 + %int_3 = OpConstant %int 3 + %int_2 = OpConstant %int 2 + %int_1 = OpConstant %int 1 + %int_0 = OpConstant %int 0 + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + OpStore %a %int_0 + OpStore %i %int_0 + OpBranch %11 + %11 = OpLabel + %37 = OpPhi %int %int_0 %5 %40 %14 + %38 = OpPhi %int %int_0 %5 %36 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThan %bool %38 %int_10 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + )"; + const std::string text_tail = R"( + OpSelectionMerge %24 None + OpBranchConditional %22 %23 %24 + %23 = OpLabel + %27 = OpIAdd %int %37 %int_2 + OpStore %a %27 + OpBranch %24 + %24 = OpLabel + %39 = OpPhi %int %37 %12 %27 %23 + %30 = OpSLessThan %bool %38 %int_1 + OpSelectionMerge %32 None + OpBranchConditional %30 %31 %32 + %31 = OpLabel + %34 = OpIAdd %int %39 %int_2 + OpStore %a %34 + OpBranch %32 + %32 = OpLabel + %40 = OpPhi %int %39 %24 %34 %31 + OpBranch %14 + %14 = OpLabel + %36 = OpIAdd %int %38 %int_1 + OpStore %i %36 + OpBranch %11 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto run_test = [&text_head, &text_tail, this]( + SpvOp opcode, const std::string& op1, + const std::string& op2, + const PeelTraceType& expected_peel_trace) { 
+ BuildAndCheckTrace(text_head, text_tail, opcode, "%22", op1, op2, + expected_peel_trace, expected_peel_trace.size() + 1); + }; + + // Test LT + // Peel before by a factor of 3. + { + SCOPED_TRACE("Peel before iv < 3"); + + run_test(SpvOpSLessThan, "%38", "%int_3", + {{LoopPeelingPass::PeelDirection::kBefore, 3u}}); + } + { + SCOPED_TRACE("Peel before 3 > iv"); + + run_test(SpvOpSGreaterThan, "%int_3", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 3u}}); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv < 8"); + + run_test(SpvOpSLessThan, "%38", "%int_8", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + { + SCOPED_TRACE("Peel after 8 > iv"); + + run_test(SpvOpSGreaterThan, "%int_8", "%38", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + + // Test GT + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv > 2"); + + run_test(SpvOpSGreaterThan, "%38", "%int_2", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + { + SCOPED_TRACE("Peel before 2 < iv"); + + run_test(SpvOpSLessThan, "%int_2", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + + // Peel after by a factor of 3. + { + SCOPED_TRACE("Peel after iv > 7"); + + run_test(SpvOpSGreaterThan, "%38", "%int_7", + {{LoopPeelingPass::PeelDirection::kAfter, 3u}}); + } + { + SCOPED_TRACE("Peel after 7 < iv"); + + run_test(SpvOpSLessThan, "%int_7", "%38", + {{LoopPeelingPass::PeelDirection::kAfter, 3u}}); + } + + // Test LE + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv <= 1"); + + run_test(SpvOpSLessThanEqual, "%38", "%int_1", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + { + SCOPED_TRACE("Peel before 1 => iv"); + + run_test(SpvOpSGreaterThanEqual, "%int_1", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + + // Peel after by a factor of 2. 
+ { + SCOPED_TRACE("Peel after iv <= 7"); + + run_test(SpvOpSLessThanEqual, "%38", "%int_7", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + { + SCOPED_TRACE("Peel after 7 => iv"); + + run_test(SpvOpSGreaterThanEqual, "%int_7", "%38", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + + // Test GE + // Peel before by a factor of 2. + { + SCOPED_TRACE("Peel before iv >= 2"); + + run_test(SpvOpSGreaterThanEqual, "%38", "%int_2", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + { + SCOPED_TRACE("Peel before 2 <= iv"); + + run_test(SpvOpSLessThanEqual, "%int_2", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 2u}}); + } + + // Peel after by a factor of 2. + { + SCOPED_TRACE("Peel after iv >= 8"); + + run_test(SpvOpSGreaterThanEqual, "%38", "%int_8", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + { + SCOPED_TRACE("Peel after 8 <= iv"); + + run_test(SpvOpSLessThanEqual, "%int_8", "%38", + {{LoopPeelingPass::PeelDirection::kAfter, 2u}}); + } + // Test EQ + // Peel before by a factor of 1. + { + SCOPED_TRACE("Peel before iv == 0"); + + run_test(SpvOpIEqual, "%38", "%int_0", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + { + SCOPED_TRACE("Peel before 0 == iv"); + + run_test(SpvOpIEqual, "%int_0", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + + // Peel after by a factor of 1. + { + SCOPED_TRACE("Peel after iv == 9"); + + run_test(SpvOpIEqual, "%38", "%int_9", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + { + SCOPED_TRACE("Peel after 9 == iv"); + + run_test(SpvOpIEqual, "%int_9", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + + // Test NE + // Peel before by a factor of 1. 
+ { + SCOPED_TRACE("Peel before iv != 0"); + + run_test(SpvOpINotEqual, "%38", "%int_0", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + { + SCOPED_TRACE("Peel before 0 != iv"); + + run_test(SpvOpINotEqual, "%int_0", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + + // Peel after by a factor of 1. + { + SCOPED_TRACE("Peel after iv != 9"); + + run_test(SpvOpINotEqual, "%38", "%int_9", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } + { + SCOPED_TRACE("Peel after 9 != iv"); + + run_test(SpvOpINotEqual, "%int_9", "%38", + {{LoopPeelingPass::PeelDirection::kBefore, 1u}}); + } +} + +/* +These tests are derived from the code generated from the following GLSL + +--eliminate-local-multi-store + +#version 330 core +void main() { + int a = 0; + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (i < 3) { + a += 2; + } + } + } +} +*/ +TEST_F(PeelingPassTest, PeelingNestedPass) { + const std::string text_head = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + OpName %i "i" + OpName %j "j" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_7 = OpConstant %int 7 + %int_3 = OpConstant %int 3 + %int_2 = OpConstant %int 2 + %int_1 = OpConstant %int 1 + %43 = OpUndef %int + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + %j = OpVariable %_ptr_Function_int Function + OpStore %a %int_0 + OpStore %i %int_0 + OpBranch %11 + %11 = OpLabel + %41 = OpPhi %int %int_0 %5 %45 %14 + %42 = OpPhi %int %int_0 %5 %40 %14 + %44 = OpPhi %int %43 %5 %46 %14 + OpLoopMerge %13 %14 None + OpBranch %15 + %15 = OpLabel 
+ %19 = OpSLessThan %bool %42 %int_10 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + OpStore %j %int_0 + OpBranch %21 + %21 = OpLabel + %45 = OpPhi %int %41 %12 %47 %24 + %46 = OpPhi %int %int_0 %12 %38 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %27 = OpSLessThan %bool %46 %int_10 + OpBranchConditional %27 %22 %23 + %22 = OpLabel + )"; + + const std::string text_tail = R"( + OpSelectionMerge %32 None + OpBranchConditional %30 %31 %32 + %31 = OpLabel + %35 = OpIAdd %int %45 %int_2 + OpStore %a %35 + OpBranch %32 + %32 = OpLabel + %47 = OpPhi %int %45 %22 %35 %31 + OpBranch %24 + %24 = OpLabel + %38 = OpIAdd %int %46 %int_1 + OpStore %j %38 + OpBranch %21 + %23 = OpLabel + OpBranch %14 + %14 = OpLabel + %40 = OpIAdd %int %42 %int_1 + OpStore %i %40 + OpBranch %11 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto run_test = + [&text_head, &text_tail, this]( + SpvOp opcode, const std::string& op1, const std::string& op2, + const PeelTraceType& expected_peel_trace, size_t nb_of_loops) { + BuildAndCheckTrace(text_head, text_tail, opcode, "%30", op1, op2, + expected_peel_trace, nb_of_loops); + }; + + // Peeling outer loop before by a factor of 3. + { + SCOPED_TRACE("Peel before iv_i < 3"); + + // Expect peel before by a factor of 3 and 4 loops at the end. + run_test(SpvOpSLessThan, "%42", "%int_3", + {{LoopPeelingPass::PeelDirection::kBefore, 3u}}, 4); + } + // Peeling outer loop after by a factor of 3. + { + SCOPED_TRACE("Peel after iv_i < 7"); + + // Expect peel after by a factor of 3 and 4 loops at the end. + run_test(SpvOpSLessThan, "%42", "%int_7", + {{LoopPeelingPass::PeelDirection::kAfter, 3u}}, 4); + } + + // Peeling inner loop before by a factor of 3. + { + SCOPED_TRACE("Peel before iv_j < 3"); + + // Expect peel before by a factor of 3 and 3 loops at the end. + run_test(SpvOpSLessThan, "%46", "%int_3", + {{LoopPeelingPass::PeelDirection::kBefore, 3u}}, 3); + } + // Peeling inner loop after by a factor of 3. 
+ { + SCOPED_TRACE("Peel after iv_j < 7"); + + // Expect peel after by a factor of 3 and 3 loops at the end. + run_test(SpvOpSLessThan, "%46", "%int_7", + {{LoopPeelingPass::PeelDirection::kAfter, 3u}}, 3); + } + + // Unworkable condition: iv_j (%46) is compared against iv_i (%42). + { + SCOPED_TRACE("No peel"); + + // Expect no peeling and 2 loops at the end. + run_test(SpvOpSLessThan, "%46", "%42", {}, 2); + } + + // Could do a peeling of 3, but that goes over the threshold. + { + SCOPED_TRACE("Over threshold"); + + size_t current_threshold = LoopPeelingPass::GetLoopPeelingThreshold(); + LoopPeelingPass::SetLoopPeelingThreshold(1u); + // Expect no peeling and 2 loops at the end. + run_test(SpvOpSLessThan, "%46", "%int_7", {}, 2); + LoopPeelingPass::SetLoopPeelingThreshold(current_threshold); + } +} +/* +This test is derived from the code generated from the following GLSL + +--eliminate-local-multi-store + +#version 330 core +void main() { + int a = 0; + for (int i = 0, j = 0; i < 10; j++, i++) { + if (i < j) { + a += 2; + } + } +} +*/ +TEST_F(PeelingPassTest, PeelingNoChanges) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 330 + OpName %main "main" + OpName %a "a" + OpName %i "i" + OpName %j "j" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_2 = OpConstant %int 2 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %a = OpVariable %_ptr_Function_int Function + %i = OpVariable %_ptr_Function_int Function + %j = OpVariable %_ptr_Function_int Function + OpStore %a %int_0 + OpStore %i %int_0 + OpStore %j %int_0 + OpBranch %12 + %12 = OpLabel + %34 = OpPhi %int %int_0 %5 %37 %15 + %35 = OpPhi %int %int_0 %5 %33 %15 + %36 = OpPhi %int %int_0 %5 
%31 %15 + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %20 = OpSLessThan %bool %35 %int_10 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + %23 = OpSLessThan %bool %35 %36 + OpSelectionMerge %25 None + OpBranchConditional %23 %24 %25 + %24 = OpLabel + %28 = OpIAdd %int %34 %int_2 + OpStore %a %28 + OpBranch %25 + %25 = OpLabel + %37 = OpPhi %int %34 %13 %28 %24 + OpBranch %15 + %15 = OpLabel + %31 = OpIAdd %int %36 %int_1 + OpStore %j %31 + %33 = OpIAdd %int %35 %int_1 + OpStore %i %33 + OpBranch %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + { + auto result = + SinglePassRunAndDisassemble(text, true, false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/unroll_assumptions.cpp b/test/opt/loop_optimizations/unroll_assumptions.cpp new file mode 100644 index 000000000..62f77d782 --- /dev/null +++ b/test/opt/loop_optimizations/unroll_assumptions.cpp @@ -0,0 +1,1448 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <memory> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "source/opt/loop_unroller.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +template <int factor>  /* |factor| is forwarded to LoopUtils::PartiallyUnroll(). */ +class PartialUnrollerTestPass : public Pass {  /* Test-only pass: tries to partially unroll every loop of every function. */ + public: + PartialUnrollerTestPass() : Pass() {} + + const char* name() const override { return "Loop unroller"; } + + Status Process() override { + bool changed = false; + for (Function& f : *context()->module()) { + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(&f); + for (auto& loop : loop_descriptor) { + LoopUtils loop_utils{context(), &loop}; + if (loop_utils.PartiallyUnroll(factor)) {  /* a true result is reported as SuccessWithChange */ + changed = true; + } + } + } + + if (changed) return Pass::Status::SuccessWithChange; + return Pass::Status::SuccessWithoutChange; + } +}; + +/* +Generated from the following GLSL +#version 410 core +layout(location = 0) flat in int in_upper_bound; +void main() { + for (int i = 0; i < in_upper_bound; ++i) { + x[i] = 1.0f; + } +} +*/ +TEST_F(PassClassTest, CheckUpperBound) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "in_upper_bound" +OpName %4 "x" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpTypePointer Input %7 +%3 = OpVariable %10 Input +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function 
%15 +%17 = OpConstant %12 1 +%18 = OpTypePointer Function %12 +%19 = OpConstant %7 1 +%2 = OpFunction %5 None %6 +%20 = OpLabel +%4 = OpVariable %16 Function +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %7 %9 %20 %23 %24 +OpLoopMerge %25 %24 Unroll +OpBranch %26 +%26 = OpLabel +%27 = OpLoad %7 %3 +%28 = OpSLessThan %11 %22 %27 +OpBranchConditional %28 %29 %25 +%29 = OpLabel +%30 = OpAccessChain %18 %4 %22 +OpStore %30 %17 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %7 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[10]; + for (uint i = 0; i < 2; i++) { + for (float x = 0; x < 5; ++x) { + out_array[x + i*5] = i; + } + } +} +*/ +TEST_F(PassClassTest, UnrollNestedLoopsInvalid) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 0 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 2 +%10 = OpTypeBool +%11 = OpTypeInt 32 1 +%12 = OpTypePointer Function %11 +%13 = OpConstant %11 0 +%14 = OpConstant %11 5 +%15 = OpTypeFloat 32 +%16 = OpConstant %6 10 +%17 = OpTypeArray %15 %16 +%18 = OpTypePointer 
Function %17 +%19 = OpConstant %6 5 +%20 = OpTypePointer Function %15 +%21 = OpConstant %11 1 +%22 = OpUndef %11 +%2 = OpFunction %4 None %5 +%23 = OpLabel +%3 = OpVariable %18 Function +OpBranch %24 +%24 = OpLabel +%25 = OpPhi %6 %8 %23 %26 %27 +%28 = OpPhi %11 %22 %23 %29 %27 +OpLoopMerge %30 %27 Unroll +OpBranch %31 +%31 = OpLabel +%32 = OpULessThan %10 %25 %9 +OpBranchConditional %32 %33 %30 +%33 = OpLabel +OpBranch %34 +%34 = OpLabel +%29 = OpPhi %11 %13 %33 %35 %36 +OpLoopMerge %37 %36 None +OpBranch %38 +%38 = OpLabel +%39 = OpSLessThan %10 %29 %14 +OpBranchConditional %39 %40 %37 +%40 = OpLabel +%41 = OpBitcast %6 %29 +%42 = OpIMul %6 %25 %19 +%43 = OpIAdd %6 %41 %42 +%44 = OpConvertUToF %15 %25 +%45 = OpAccessChain %20 %3 %43 +OpStore %45 %44 +OpBranch %36 +%36 = OpLabel +%35 = OpIAdd %11 %29 %21 +OpBranch %34 +%37 = OpLabel +OpBranch %27 +%27 = OpLabel +%26 = OpIAdd %6 %25 %21 +OpBranch %24 +%30 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main(){ + float x[10]; + for (int i = 0; i < 10; i++) { + if (i == 5) { + break; + } + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, BreakInBody) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +OpName %3 "x" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant 
%6 10 +%10 = OpTypeBool +%11 = OpConstant %6 5 +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpTypePointer Function %12 +%18 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%19 = OpLabel +%3 = OpVariable %16 Function +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %6 %8 %19 %22 %23 +OpLoopMerge %24 %23 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %10 %21 %9 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%28 = OpIEqual %10 %21 %11 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +OpBranch %24 +%29 = OpLabel +%31 = OpConvertSToF %12 %21 +%32 = OpAccessChain %17 %3 %21 +OpStore %32 %31 +OpBranch %23 +%23 = OpLabel +%22 = OpIAdd %6 %21 %18 +OpBranch %20 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main(){ + float x[10]; + for (int i = 0; i < 10; i++) { + if (i == 5) { + continue; + } + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, ContinueInBody) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +OpName %3 "x" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 10 +%10 = OpTypeBool +%11 = OpConstant %6 5 +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant 
%13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpTypePointer Function %12 +%18 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%19 = OpLabel +%3 = OpVariable %16 Function +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %6 %8 %19 %22 %23 +OpLoopMerge %24 %23 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %10 %21 %9 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%28 = OpIEqual %10 %21 %11 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +OpBranch %23 +%29 = OpLabel +%31 = OpConvertSToF %12 %21 +%32 = OpAccessChain %17 %3 %21 +OpStore %32 %31 +OpBranch %23 +%23 = OpLabel +%22 = OpIAdd %6 %21 %18 +OpBranch %20 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main(){ + float x[10]; + for (int i = 0; i < 10; i++) { + if (i == 5) { + return; + } + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, ReturnInBody) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +OpName %3 "x" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 10 +%10 = OpTypeBool +%11 = OpConstant %6 5 +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpTypePointer Function %12 +%18 = 
OpConstant %6 1 +%2 = OpFunction %4 None %5 +%19 = OpLabel +%3 = OpVariable %16 Function +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %6 %8 %19 %22 %23 +OpLoopMerge %24 %23 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %10 %21 %9 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%28 = OpIEqual %10 %21 %11 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +OpReturn +%29 = OpLabel +%31 = OpConvertSToF %12 %21 +%32 = OpAccessChain %17 %3 %21 +OpStore %32 %31 +OpBranch %23 +%23 = OpLabel +%22 = OpIAdd %6 %21 %18 +OpBranch %20 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main() { + int j = 0; + for (int i = 0; i < 10 && i > 0; i++) { + j++; + } +} +*/ +TEST_F(PassClassTest, MultipleConditionsSingleVariable) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeInt 32 1 +%6 = OpTypePointer Function %5 +%7 = OpConstant %5 0 +%8 = OpConstant %5 10 +%9 = OpTypeBool +%10 = OpConstant %5 1 +%2 = OpFunction %3 None %4 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +%13 = OpPhi %5 %7 %11 %14 %15 +%16 = OpPhi %5 %7 %11 %17 %15 +OpLoopMerge %18 %15 Unroll +OpBranch %19 +%19 = OpLabel +%20 = OpSLessThan %9 %16 %8 +%21 = OpSGreaterThan %9 %16 %7 +%22 = OpLogicalAnd %9 %20 %21 +OpBranchConditional %22 %23 %18 +%23 = OpLabel +%14 = OpIAdd %5 %13 %10 +OpBranch %15 +%15 = OpLabel +%17 = OpIAdd %5 %16 %10 +OpBranch %12 +%18 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + 
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main() { + int i = 0; + int j = 0; + int k = 0; + for (; i < 10 && j > 0; i++, j++) { + k++; + } +} +*/ +TEST_F(PassClassTest, MultipleConditionsMultipleVariables) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeInt 32 1 +%6 = OpTypePointer Function %5 +%7 = OpConstant %5 0 +%8 = OpConstant %5 10 +%9 = OpTypeBool +%10 = OpConstant %5 1 +%2 = OpFunction %3 None %4 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +%13 = OpPhi %5 %7 %11 %14 %15 +%16 = OpPhi %5 %7 %11 %17 %15 +%18 = OpPhi %5 %7 %11 %19 %15 +OpLoopMerge %20 %15 Unroll +OpBranch %21 +%21 = OpLabel +%22 = OpSLessThan %9 %13 %8 +%23 = OpSGreaterThan %9 %16 %7 +%24 = OpLogicalAnd %9 %22 %23 +OpBranchConditional %24 %25 %20 +%25 = OpLabel +%19 = OpIAdd %5 %18 %10 +OpBranch %15 +%15 = OpLabel +%14 = OpIAdd %5 %13 %10 +%17 = OpIAdd %5 %16 %10 +OpBranch %12 +%20 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + 
SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main() { + float i = 0.0; + int j = 0; + for (; i < 10; i++) { + j++; + } +} +*/ +TEST_F(PassClassTest, FloatingPointLoop) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeFloat 32 +%6 = OpTypePointer Function %5 +%7 = OpConstant %5 0 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 0 +%11 = OpConstant %5 10 +%12 = OpTypeBool +%13 = OpConstant %8 1 +%14 = OpConstant %5 1 +%2 = OpFunction %3 None %4 +%15 = OpLabel +OpBranch %16 +%16 = OpLabel +%17 = OpPhi %5 %7 %15 %18 %19 +%20 = OpPhi %8 %10 %15 %21 %19 +OpLoopMerge %22 %19 Unroll +OpBranch %23 +%23 = OpLabel +%24 = OpFOrdLessThan %12 %17 %11 +OpBranchConditional %24 %25 %22 +%25 = OpLabel +%21 = OpIAdd %8 %20 %13 +OpBranch %19 +%19 = OpLabel +%18 = OpFAdd %5 %17 %14 +OpBranch %16 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +Generated from the following GLSL +#version 
440 core +void main() { + int i = 2; + int j = 0; + if (j == 0) { i = 5; } + for (; i < 3; ++i) { + j++; + } +} +*/ +TEST_F(PassClassTest, InductionPhiOutsideLoop) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeInt 32 1 +%6 = OpTypePointer Function %5 +%7 = OpConstant %5 2 +%8 = OpConstant %5 0 +%9 = OpTypeBool +%10 = OpConstant %5 5 +%11 = OpConstant %5 3 +%12 = OpConstant %5 1 +%2 = OpFunction %3 None %4 +%13 = OpLabel +%14 = OpIEqual %9 %8 %8 +OpSelectionMerge %15 None +OpBranchConditional %14 %16 %15 +%16 = OpLabel +OpBranch %15 +%15 = OpLabel +%17 = OpPhi %5 %7 %13 %10 %16 +OpBranch %18 +%18 = OpLabel +%19 = OpPhi %5 %17 %15 %20 %21 +%22 = OpPhi %5 %8 %15 %23 %21 +OpLoopMerge %24 %21 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %9 %19 %11 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%23 = OpIAdd %5 %22 %12 +OpBranch %21 +%21 = OpLabel +%20 = OpIAdd %5 %19 %12 +OpBranch %18 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main() { + int j = 0; + for (int i = 0; i == 0; ++i) { + ++j; + } + for (int i = 0; i != 3; ++i) { + ++j; + } + for (int i = 0; i < 3; i *= 2) { + 
++j; + } + for (int i = 10; i > 3; i /= 2) { + ++j; + } + for (int i = 10; i > 3; i |= 2) { + ++j; + } + for (int i = 10; i > 3; i &= 2) { + ++j; + } + for (int i = 10; i > 3; i ^= 2) { + ++j; + } + for (int i = 0; i < 3; i << 2) { + ++j; + } + for (int i = 10; i > 3; i >> 2) { + ++j; + } +} +*/ +TEST_F(PassClassTest, UnsupportedLoopTypes) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeInt 32 1 +%6 = OpTypePointer Function %5 +%7 = OpConstant %5 0 +%8 = OpTypeBool +%9 = OpConstant %5 1 +%10 = OpConstant %5 3 +%11 = OpConstant %5 2 +%12 = OpConstant %5 10 +%2 = OpFunction %3 None %4 +%13 = OpLabel +OpBranch %14 +%14 = OpLabel +%15 = OpPhi %5 %7 %13 %16 %17 +%18 = OpPhi %5 %7 %13 %19 %17 +OpLoopMerge %20 %17 Unroll +OpBranch %21 +%21 = OpLabel +%22 = OpIEqual %8 %18 %7 +OpBranchConditional %22 %23 %20 +%23 = OpLabel +%16 = OpIAdd %5 %15 %9 +OpBranch %17 +%17 = OpLabel +%19 = OpIAdd %5 %18 %9 +OpBranch %14 +%20 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpPhi %5 %15 %20 %26 %27 +%28 = OpPhi %5 %7 %20 %29 %27 +OpLoopMerge %30 %27 Unroll +OpBranch %31 +%31 = OpLabel +%32 = OpINotEqual %8 %28 %10 +OpBranchConditional %32 %33 %30 +%33 = OpLabel +%26 = OpIAdd %5 %25 %9 +OpBranch %27 +%27 = OpLabel +%29 = OpIAdd %5 %28 %9 +OpBranch %24 +%30 = OpLabel +OpBranch %34 +%34 = OpLabel +%35 = OpPhi %5 %25 %30 %36 %37 +%38 = OpPhi %5 %7 %30 %39 %37 +OpLoopMerge %40 %37 Unroll +OpBranch %41 +%41 = OpLabel +%42 = OpSLessThan %8 %38 %10 +OpBranchConditional %42 %43 %40 +%43 = OpLabel +%36 = OpIAdd %5 %35 %9 +OpBranch %37 +%37 = OpLabel +%39 = OpIMul %5 %38 %11 +OpBranch %34 +%40 = OpLabel +OpBranch %44 +%44 = OpLabel +%45 = OpPhi %5 %35 %40 %46 %47 +%48 = OpPhi %5 %12 %40 %49 %47 +OpLoopMerge %50 
%47 Unroll +OpBranch %51 +%51 = OpLabel +%52 = OpSGreaterThan %8 %48 %10 +OpBranchConditional %52 %53 %50 +%53 = OpLabel +%46 = OpIAdd %5 %45 %9 +OpBranch %47 +%47 = OpLabel +%49 = OpSDiv %5 %48 %11 +OpBranch %44 +%50 = OpLabel +OpBranch %54 +%54 = OpLabel +%55 = OpPhi %5 %45 %50 %56 %57 +%58 = OpPhi %5 %12 %50 %59 %57 +OpLoopMerge %60 %57 Unroll +OpBranch %61 +%61 = OpLabel +%62 = OpSGreaterThan %8 %58 %10 +OpBranchConditional %62 %63 %60 +%63 = OpLabel +%56 = OpIAdd %5 %55 %9 +OpBranch %57 +%57 = OpLabel +%59 = OpBitwiseOr %5 %58 %11 +OpBranch %54 +%60 = OpLabel +OpBranch %64 +%64 = OpLabel +%65 = OpPhi %5 %55 %60 %66 %67 +%68 = OpPhi %5 %12 %60 %69 %67 +OpLoopMerge %70 %67 Unroll +OpBranch %71 +%71 = OpLabel +%72 = OpSGreaterThan %8 %68 %10 +OpBranchConditional %72 %73 %70 +%73 = OpLabel +%66 = OpIAdd %5 %65 %9 +OpBranch %67 +%67 = OpLabel +%69 = OpBitwiseAnd %5 %68 %11 +OpBranch %64 +%70 = OpLabel +OpBranch %74 +%74 = OpLabel +%75 = OpPhi %5 %65 %70 %76 %77 +%78 = OpPhi %5 %12 %70 %79 %77 +OpLoopMerge %80 %77 Unroll +OpBranch %81 +%81 = OpLabel +%82 = OpSGreaterThan %8 %78 %10 +OpBranchConditional %82 %83 %80 +%83 = OpLabel +%76 = OpIAdd %5 %75 %9 +OpBranch %77 +%77 = OpLabel +%79 = OpBitwiseXor %5 %78 %11 +OpBranch %74 +%80 = OpLabel +OpBranch %84 +%84 = OpLabel +%85 = OpPhi %5 %75 %80 %86 %87 +OpLoopMerge %88 %87 Unroll +OpBranch %89 +%89 = OpLabel +%90 = OpSLessThan %8 %7 %10 +OpBranchConditional %90 %91 %88 +%91 = OpLabel +%86 = OpIAdd %5 %85 %9 +OpBranch %87 +%87 = OpLabel +%92 = OpShiftLeftLogical %5 %7 %11 +OpBranch %84 +%88 = OpLabel +OpBranch %93 +%93 = OpLabel +%94 = OpPhi %5 %85 %88 %95 %96 +OpLoopMerge %97 %96 Unroll +OpBranch %98 +%98 = OpLabel +%99 = OpSGreaterThan %8 %12 %10 +OpBranchConditional %99 %100 %97 +%100 = OpLabel +%95 = OpIAdd %5 %94 %9 +OpBranch %96 +%96 = OpLabel +%101 = OpShiftRightArithmetic %5 %12 %11 +OpBranch %93 +%97 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + 
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +#version 430 + +layout(location = 0) out float o; + +void main(void) { + for (int j = 2; j < 0; j += 1) { + o += 1.0; + } +} +*/ +TEST_F(PassClassTest, NegativeNumberOfIterations) { + // clang-format off + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "o" +OpDecorate %3 Location 0 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 2 +%9 = OpConstant %6 0 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypePointer Output %11 +%3 = OpVariable %12 Output +%13 = OpConstant %11 1 +%14 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%15 = OpLabel +OpBranch %16 +%16 = OpLabel +%17 = OpPhi %6 %8 %15 %18 %19 +OpLoopMerge %20 %19 None +OpBranch %21 +%21 = OpLabel +%22 = OpSLessThan %10 %17 %9 +OpBranchConditional %22 %23 %20 +%23 = OpLabel +%24 = OpLoad %11 %3 +%25 = OpFAdd %11 %24 %13 +OpStore %3 %25 +OpBranch %19 +%19 = OpLabel +%18 = OpIAdd %6 %17 %14 +OpBranch %16 +%20 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + 
LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +#version 430 + +layout(location = 0) out float o; + +void main(void) { + float s = 0.0; + for (int j = 0; j < 3; j += 1) { + s += 1.0; + j += 1; + } + o = s; +} +*/ +TEST_F(PassClassTest, MultipleStepOperations) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "o" +OpDecorate %3 Location 0 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeFloat 32 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 0 +%12 = OpConstant %9 3 +%13 = OpTypeBool +%14 = OpConstant %6 1 +%15 = OpConstant %9 1 +%16 = OpTypePointer Output %6 +%3 = OpVariable %16 Output +%2 = OpFunction %4 None %5 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpPhi %6 %8 %17 %20 %21 +%22 = OpPhi %9 %11 %17 %23 %21 +OpLoopMerge %24 %21 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %13 %22 %12 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%20 = OpFAdd %6 %19 %14 +%28 = OpIAdd %9 %22 %15 +OpBranch %21 +%21 = OpLabel +%23 = OpIAdd %9 %28 %15 +OpBranch %18 +%24 = OpLabel +OpStore %3 %19 +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass 
doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +#version 430 + +layout(location = 0) out float o; + +void main(void) { + float s = 0.0; + for (int j = 10; j > 20; j -= 1) { + s += 1.0; + } + o = s; +} +*/ + +TEST_F(PassClassTest, ConditionFalseFromStartGreaterThan) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "o" +OpDecorate %3 Location 0 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeFloat 32 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 10 +%12 = OpConstant %9 20 +%13 = OpTypeBool +%14 = OpConstant %6 1 +%15 = OpConstant %9 1 +%16 = OpTypePointer Output %6 +%3 = OpVariable %16 Output +%2 = OpFunction %4 None %5 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpPhi %6 %8 %17 %20 %21 +%22 = OpPhi %9 %11 %17 %23 %21 +OpLoopMerge %24 %21 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSGreaterThan %13 %22 %12 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%20 = OpFAdd %6 %19 %14 +OpBranch %21 +%21 = OpLabel +%23 = OpISub %9 %22 %15 +OpBranch %18 +%24 = OpLabel +OpStore %3 %19 +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, 
text, false); +} + +/* +#version 430 + +layout(location = 0) out float o; + +void main(void) { + float s = 0.0; + for (int j = 10; j >= 20; j -= 1) { + s += 1.0; + } + o = s; +} +*/ +TEST_F(PassClassTest, ConditionFalseFromStartGreaterThanOrEqual) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "o" +OpDecorate %3 Location 0 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeFloat 32 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 10 +%12 = OpConstant %9 20 +%13 = OpTypeBool +%14 = OpConstant %6 1 +%15 = OpConstant %9 1 +%16 = OpTypePointer Output %6 +%3 = OpVariable %16 Output +%2 = OpFunction %4 None %5 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpPhi %6 %8 %17 %20 %21 +%22 = OpPhi %9 %11 %17 %23 %21 +OpLoopMerge %24 %21 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSGreaterThanEqual %13 %22 %12 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%20 = OpFAdd %6 %19 %14 +OpBranch %21 +%21 = OpLabel +%23 = OpISub %9 %22 %15 +OpBranch %18 +%24 = OpLabel +OpStore %3 %19 +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +#version 430 + +layout(location = 0) out float o; + +void main(void) { + float s = 0.0; + for 
(int j = 20; j < 10; j -= 1) { + s += 1.0; + } + o = s; +} +*/ +TEST_F(PassClassTest, ConditionFalseFromStartLessThan) { + // clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "o" +OpDecorate %3 Location 0 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeFloat 32 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 20 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpConstant %6 1 +%15 = OpConstant %9 1 +%16 = OpTypePointer Output %6 +%3 = OpVariable %16 Output +%2 = OpFunction %4 None %5 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpPhi %6 %8 %17 %20 %21 +%22 = OpPhi %9 %11 %17 %23 %21 +OpLoopMerge %24 %21 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %13 %22 %12 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%20 = OpFAdd %6 %19 %14 +OpBranch %21 +%21 = OpLabel +%23 = OpISub %9 %22 %15 +OpBranch %18 +%24 = OpLabel +OpStore %3 %19 +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +#version 430 + +layout(location = 0) out float o; + +void main(void) { + float s = 0.0; + for (int j = 20; j <= 10; j -= 1) { + s += 1.0; + } + o = s; +} +*/ +TEST_F(PassClassTest, ConditionFalseFromStartLessThanEqual) { + // 
clang-format off + // With LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "o" +OpDecorate %3 Location 0 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeFloat 32 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpTypeInt 32 1 +%10 = OpTypePointer Function %9 +%11 = OpConstant %9 20 +%12 = OpConstant %9 10 +%13 = OpTypeBool +%14 = OpConstant %6 1 +%15 = OpConstant %9 1 +%16 = OpTypePointer Output %6 +%3 = OpVariable %16 Output +%2 = OpFunction %4 None %5 +%17 = OpLabel +OpBranch %18 +%18 = OpLabel +%19 = OpPhi %6 %8 %17 %20 %21 +%22 = OpPhi %9 %11 %17 %23 %21 +OpLoopMerge %24 %21 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThanEqual %13 %22 %12 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%20 = OpFAdd %6 %19 %14 +OpBranch %21 +%21 = OpLabel +%23 = OpISub %9 %22 %15 +OpBranch %18 +%24 = OpLabel +OpStore %3 %19 +OpReturn +OpFunctionEnd +)"; + + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/loop_optimizations/unroll_simple.cpp b/test/opt/loop_optimizations/unroll_simple.cpp new file mode 100644 index 000000000..f551e7ca9 --- /dev/null +++ b/test/opt/loop_optimizations/unroll_simple.cpp @@ -0,0 +1,3003 @@ +// Copyright (c) 2018 Google LLC. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/loop_unroller.h" +#include "source/opt/loop_utils.h" +#include "source/opt/pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + float x[4]; + for (int i = 0; i < 4; ++i) { + x[i] = 1.0f; + } +} +*/ +TEST_F(PassClassTest, SimpleFullyUnrollTest) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %5 "x" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 4 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 4 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Function %16 + %18 = OpConstant %13 1 + %19 = OpTypePointer Function %13 + %20 = OpConstant %8 1 + %21 = OpTypeVector %13 4 + %22 
= OpTypePointer Output %21 + %3 = OpVariable %22 Output + %2 = OpFunction %6 None %7 + %23 = OpLabel + %5 = OpVariable %17 Function + OpBranch %24 + %24 = OpLabel + %35 = OpPhi %8 %10 %23 %34 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %12 %35 %11 + OpBranchConditional %29 %30 %25 + %30 = OpLabel + %32 = OpAccessChain %19 %5 %35 + OpStore %32 %18 + OpBranch %26 + %26 = OpLabel + %34 = OpIAdd %8 %35 %20 + OpBranch %24 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 330 +OpName %2 "main" +OpName %4 "x" +OpName %3 "c" +OpDecorate %3 Location 0 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpConstant %7 4 +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 4 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpConstant %12 1 +%18 = OpTypePointer Function %12 +%19 = OpConstant %7 1 +%20 = OpTypeVector %12 4 +%21 = OpTypePointer Output %20 +%3 = OpVariable %21 Output +%2 = OpFunction %5 None %6 +%22 = OpLabel +%4 = OpVariable %16 Function +OpBranch %23 +%23 = OpLabel +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %11 %9 %10 +OpBranch %30 +%30 = OpLabel +%31 = OpAccessChain %18 %4 %9 +OpStore %31 %17 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %7 %9 %19 +OpBranch %32 +%32 = OpLabel +OpBranch %34 +%34 = OpLabel +%35 = OpSLessThan %11 %25 %10 +OpBranch %36 +%36 = OpLabel +%37 = OpAccessChain %18 %4 %25 +OpStore %37 %17 +OpBranch %38 +%38 = OpLabel +%39 = OpIAdd %7 %25 %19 +OpBranch %40 +%40 = OpLabel +OpBranch %42 +%42 = OpLabel +%43 = OpSLessThan %11 %39 %10 +OpBranch %44 +%44 = OpLabel +%45 = OpAccessChain %18 %4 %39 +OpStore %45 %17 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %7 %39 %19 +OpBranch %48 
+%48 = OpLabel +OpBranch %50 +%50 = OpLabel +%51 = OpSLessThan %11 %47 %10 +OpBranch %52 +%52 = OpLabel +%53 = OpAccessChain %18 %4 %47 +OpStore %53 %17 +OpBranch %54 +%54 = OpLabel +%55 = OpIAdd %7 %47 %19 +OpBranch %27 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +template +class PartialUnrollerTestPass : public Pass { + public: + PartialUnrollerTestPass() : Pass() {} + + const char* name() const override { return "Loop unroller"; } + + Status Process() override { + for (Function& f : *context()->module()) { + LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(&f); + for (auto& loop : loop_descriptor) { + LoopUtils loop_utils{context(), &loop}; + loop_utils.PartiallyUnroll(factor); + } + } + + return Pass::Status::SuccessWithChange; + } +}; + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + float x[10]; + for (int i = 0; i < 10; ++i) { + x[i] = 1.0f; + } +} +*/ +TEST_F(PassClassTest, SimplePartialUnroll) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %5 "x" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + 
%17 = OpTypePointer Function %16 + %18 = OpConstant %13 1 + %19 = OpTypePointer Function %13 + %20 = OpConstant %8 1 + %21 = OpTypeVector %13 4 + %22 = OpTypePointer Output %21 + %3 = OpVariable %22 Output + %2 = OpFunction %6 None %7 + %23 = OpLabel + %5 = OpVariable %17 Function + OpBranch %24 + %24 = OpLabel + %35 = OpPhi %8 %10 %23 %34 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %12 %35 %11 + OpBranchConditional %29 %30 %25 + %30 = OpLabel + %32 = OpAccessChain %19 %5 %35 + OpStore %32 %18 + OpBranch %26 + %26 = OpLabel + %34 = OpIAdd %8 %35 %20 + OpBranch %24 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 330 +OpName %2 "main" +OpName %4 "x" +OpName %3 "c" +OpDecorate %3 Location 0 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpConstant %7 10 +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpConstant %12 1 +%18 = OpTypePointer Function %12 +%19 = OpConstant %7 1 +%20 = OpTypeVector %12 4 +%21 = OpTypePointer Output %20 +%3 = OpVariable %21 Output +%2 = OpFunction %5 None %6 +%22 = OpLabel +%4 = OpVariable %16 Function +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %7 %9 %22 %39 %38 +OpLoopMerge %27 %38 DontUnroll +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %11 %24 %10 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%31 = OpAccessChain %18 %4 %24 +OpStore %31 %17 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %7 %24 %19 +OpBranch %32 +%32 = OpLabel +OpBranch %34 +%34 = OpLabel +%35 = OpSLessThan %11 %25 %10 +OpBranch %36 +%36 = OpLabel +%37 = OpAccessChain %18 %4 %25 +OpStore %37 %17 +OpBranch %38 +%38 = OpLabel +%39 = OpIAdd %7 %25 %19 
+OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck>(text, output, false); +} + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + float x[10]; + for (int i = 0; i < 10; ++i) { + x[i] = 1.0f; + } +} +*/ +TEST_F(PassClassTest, SimpleUnevenPartialUnroll) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %5 "x" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Function %16 + %18 = OpConstant %13 1 + %19 = OpTypePointer Function %13 + %20 = OpConstant %8 1 + %21 = OpTypeVector %13 4 + %22 = OpTypePointer Output %21 + %3 = OpVariable %22 Output + %2 = OpFunction %6 None %7 + %23 = OpLabel + %5 = OpVariable %17 Function + OpBranch %24 + %24 = OpLabel + %35 = OpPhi %8 %10 %23 %34 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %12 %35 %11 + OpBranchConditional %29 %30 %25 + %30 = OpLabel + %32 = OpAccessChain %19 %5 %35 + OpStore %32 %18 + OpBranch %26 + %26 = OpLabel + %34 = OpIAdd %8 %35 %20 + OpBranch %24 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string output = 
R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 330 +OpName %2 "main" +OpName %4 "x" +OpName %3 "c" +OpDecorate %3 Location 0 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpConstant %7 10 +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpConstant %12 1 +%18 = OpTypePointer Function %12 +%19 = OpConstant %7 1 +%20 = OpTypeVector %12 4 +%21 = OpTypePointer Output %20 +%3 = OpVariable %21 Output +%58 = OpConstant %13 1 +%2 = OpFunction %5 None %6 +%22 = OpLabel +%4 = OpVariable %16 Function +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %7 %9 %22 %25 %26 +OpLoopMerge %32 %26 Unroll +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %11 %24 %58 +OpBranchConditional %29 %30 %32 +%30 = OpLabel +%31 = OpAccessChain %18 %4 %24 +OpStore %31 %17 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %7 %24 %19 +OpBranch %23 +%32 = OpLabel +OpBranch %33 +%33 = OpLabel +%34 = OpPhi %7 %24 %32 %57 %56 +OpLoopMerge %41 %56 DontUnroll +OpBranch %35 +%35 = OpLabel +%36 = OpSLessThan %11 %34 %10 +OpBranchConditional %36 %37 %41 +%37 = OpLabel +%38 = OpAccessChain %18 %4 %34 +OpStore %38 %17 +OpBranch %39 +%39 = OpLabel +%40 = OpIAdd %7 %34 %19 +OpBranch %42 +%42 = OpLabel +OpBranch %44 +%44 = OpLabel +%45 = OpSLessThan %11 %40 %10 +OpBranch %46 +%46 = OpLabel +%47 = OpAccessChain %18 %4 %40 +OpStore %47 %17 +OpBranch %48 +%48 = OpLabel +%49 = OpIAdd %7 %40 %19 +OpBranch %50 +%50 = OpLabel +OpBranch %52 +%52 = OpLabel +%53 = OpSLessThan %11 %49 %10 +OpBranch %54 +%54 = OpLabel +%55 = OpAccessChain %18 %4 %49 +OpStore %55 %17 +OpBranch %56 +%56 = OpLabel +%57 = OpIAdd %7 %49 %19 +OpBranch %33 +%41 = OpLabel +OpReturn +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + 
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + // By unrolling by a factor that doesn't divide evenly into the number of loop + // iterations we perform an additional transform when partially unrolling to + // account for the remainder. + SinglePassRunAndCheck>(text, output, false); +} + +/* Generated from +#version 410 core +layout(location=0) flat in int upper_bound; +void main() { + float x[10]; + for (int i = 2; i < 8; i+=2) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, SimpleLoopIterationsCheck) { + // With LocalMultiStoreElimPass + const std::string text = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %5 "x" +OpName %3 "upper_bound" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 2 +%11 = OpConstant %8 8 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpTypePointer Input %8 +%3 = OpVariable %19 Input +%2 = OpFunction %6 None %7 +%20 = OpLabel +%5 = OpVariable %17 Function +OpBranch %21 +%21 = OpLabel +%34 = OpPhi %8 %10 %20 %33 %23 +OpLoopMerge %22 %23 Unroll +OpBranch %24 +%24 = OpLabel +%26 = OpSLessThan %12 %34 %11 +OpBranchConditional %26 %27 %22 +%27 = OpLabel +%30 = OpConvertSToF %13 %34 +%31 = OpAccessChain %18 %5 %34 +OpStore %31 %30 +OpBranch %23 +%23 = OpLabel +%33 = OpIAdd %8 %34 %10 +OpBranch %21 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr
context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + Function* f = spvtest::GetFunction(module, 2); + + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + + Loop& loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(loop.HasUnrollLoopControl()); + + BasicBlock* condition = loop.FindConditionBlock(); + EXPECT_EQ(condition->id(), 24u); + + Instruction* induction = loop.FindConditionVariable(condition); + EXPECT_EQ(induction->result_id(), 34u); + + LoopUtils loop_utils{context.get(), &loop}; + EXPECT_TRUE(loop_utils.CanPerformUnroll()); + + size_t iterations = 0; + EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), + &iterations)); + EXPECT_EQ(iterations, 3u); +} + +/* Generated from +#version 410 core +void main() { + float x[10]; + for (int i = -1; i < 6; i+=3) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, SimpleLoopIterationsCheckSignedInit) { + // With LocalMultiStoreElimPass + const std::string text = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %5 "x" +OpName %3 "upper_bound" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 -1 +%11 = OpConstant %8 6 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 3 +%20 = OpTypePointer Input %8 +%3 = OpVariable %20 Input +%2 = OpFunction %6 None %7 +%21 = OpLabel +%5 = OpVariable %17 Function +OpBranch %22 +%22 = OpLabel +%35 
= OpPhi %8 %10 %21 %34 %24 +OpLoopMerge %23 %24 None +OpBranch %25 +%25 = OpLabel +%27 = OpSLessThan %12 %35 %11 +OpBranchConditional %27 %28 %23 +%28 = OpLabel +%31 = OpConvertSToF %13 %35 +%32 = OpAccessChain %18 %5 %35 +OpStore %32 %31 +OpBranch %24 +%24 = OpLabel +%34 = OpIAdd %8 %35 %19 +OpBranch %22 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + Function* f = spvtest::GetFunction(module, 2); + + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + + Loop& loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_FALSE(loop.HasUnrollLoopControl()); + + BasicBlock* condition = loop.FindConditionBlock(); + EXPECT_EQ(condition->id(), 25u); + + Instruction* induction = loop.FindConditionVariable(condition); + EXPECT_EQ(induction->result_id(), 35u); + + LoopUtils loop_utils{context.get(), &loop}; + EXPECT_TRUE(loop_utils.CanPerformUnroll()); + + size_t iterations = 0; + EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), + &iterations)); + EXPECT_EQ(iterations, 3u); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[6]; + for (uint i = 0; i < 2; i++) { + for (int x = 0; x < 3; ++x) { + out_array[x + i*3] = i; + } + } +} +*/ +TEST_F(PassClassTest, UnrollNestedLoops) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %35 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = 
OpConstant %6 2 + %17 = OpTypeBool + %19 = OpTypeInt 32 1 + %20 = OpTypePointer Function %19 + %22 = OpConstant %19 0 + %29 = OpConstant %19 3 + %31 = OpTypeFloat 32 + %32 = OpConstant %6 6 + %33 = OpTypeArray %31 %32 + %34 = OpTypePointer Function %33 + %39 = OpConstant %6 3 + %44 = OpTypePointer Function %31 + %47 = OpConstant %19 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %35 = OpVariable %34 Function + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %50 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpULessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %23 + %23 = OpLabel + %54 = OpPhi %19 %22 %11 %48 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %30 = OpSLessThan %17 %54 %29 + OpBranchConditional %30 %24 %25 + %24 = OpLabel + %37 = OpBitcast %6 %54 + %40 = OpIMul %6 %51 %39 + %41 = OpIAdd %6 %37 %40 + %43 = OpConvertUToF %31 %51 + %45 = OpAccessChain %44 %35 %41 + OpStore %45 %43 + OpBranch %26 + %26 = OpLabel + %48 = OpIAdd %19 %54 %47 + OpBranch %23 + %25 = OpLabel + OpBranch %13 + %13 = OpLabel + %50 = OpIAdd %6 %51 %47 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 0 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 2 +%10 = OpTypeBool +%11 = OpTypeInt 32 1 +%12 = OpTypePointer Function %11 +%13 = OpConstant %11 0 +%14 = OpConstant %11 3 +%15 = OpTypeFloat 32 +%16 = OpConstant %6 6 +%17 = OpTypeArray %15 %16 +%18 = OpTypePointer Function %17 +%19 = OpConstant %6 3 +%20 = OpTypePointer Function %15 +%21 = OpConstant %11 1 +%2 = OpFunction %4 None %5 +%22 = OpLabel +%3 = OpVariable %18 Function +OpBranch %23 +%23 = OpLabel 
+OpBranch %28 +%28 = OpLabel +%29 = OpULessThan %10 %8 %9 +OpBranch %30 +%30 = OpLabel +OpBranch %31 +%31 = OpLabel +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %10 %13 %14 +OpBranch %38 +%38 = OpLabel +%39 = OpBitcast %6 %13 +%40 = OpIMul %6 %8 %19 +%41 = OpIAdd %6 %39 %40 +%42 = OpConvertUToF %15 %8 +%43 = OpAccessChain %20 %3 %41 +OpStore %43 %42 +OpBranch %34 +%34 = OpLabel +%33 = OpIAdd %11 %13 %21 +OpBranch %44 +%44 = OpLabel +OpBranch %46 +%46 = OpLabel +%47 = OpSLessThan %10 %33 %14 +OpBranch %48 +%48 = OpLabel +%49 = OpBitcast %6 %33 +%50 = OpIMul %6 %8 %19 +%51 = OpIAdd %6 %49 %50 +%52 = OpConvertUToF %15 %8 +%53 = OpAccessChain %20 %3 %51 +OpStore %53 %52 +OpBranch %54 +%54 = OpLabel +%55 = OpIAdd %11 %33 %21 +OpBranch %56 +%56 = OpLabel +OpBranch %58 +%58 = OpLabel +%59 = OpSLessThan %10 %55 %14 +OpBranch %60 +%60 = OpLabel +%61 = OpBitcast %6 %55 +%62 = OpIMul %6 %8 %19 +%63 = OpIAdd %6 %61 %62 +%64 = OpConvertUToF %15 %8 +%65 = OpAccessChain %20 %3 %63 +OpStore %65 %64 +OpBranch %66 +%66 = OpLabel +%67 = OpIAdd %11 %55 %21 +OpBranch %35 +%35 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %6 %8 %21 +OpBranch %68 +%68 = OpLabel +OpBranch %70 +%70 = OpLabel +%71 = OpULessThan %10 %25 %9 +OpBranch %72 +%72 = OpLabel +OpBranch %73 +%73 = OpLabel +OpBranch %74 +%74 = OpLabel +%75 = OpSLessThan %10 %13 %14 +OpBranch %76 +%76 = OpLabel +%77 = OpBitcast %6 %13 +%78 = OpIMul %6 %25 %19 +%79 = OpIAdd %6 %77 %78 +%80 = OpConvertUToF %15 %25 +%81 = OpAccessChain %20 %3 %79 +OpStore %81 %80 +OpBranch %82 +%82 = OpLabel +%83 = OpIAdd %11 %13 %21 +OpBranch %84 +%84 = OpLabel +OpBranch %85 +%85 = OpLabel +%86 = OpSLessThan %10 %83 %14 +OpBranch %87 +%87 = OpLabel +%88 = OpBitcast %6 %83 +%89 = OpIMul %6 %25 %19 +%90 = OpIAdd %6 %88 %89 +%91 = OpConvertUToF %15 %25 +%92 = OpAccessChain %20 %3 %90 +OpStore %92 %91 +OpBranch %93 +%93 = OpLabel +%94 = OpIAdd %11 %83 %21 +OpBranch %95 +%95 = OpLabel +OpBranch %96 +%96 = OpLabel +%97 = OpSLessThan %10 %94 %14 
+OpBranch %98 +%98 = OpLabel +%99 = OpBitcast %6 %94 +%100 = OpIMul %6 %25 %19 +%101 = OpIAdd %6 %99 %100 +%102 = OpConvertUToF %15 %25 +%103 = OpAccessChain %20 %3 %101 +OpStore %103 %102 +OpBranch %104 +%104 = OpLabel +%105 = OpIAdd %11 %94 %21 +OpBranch %106 +%106 = OpLabel +OpBranch %107 +%107 = OpLabel +%108 = OpIAdd %6 %25 %21 +OpBranch %27 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + int out_array[2]; + for (int i = -3; i < -1; i++) { + out_array[3 + i] = i; + } +} +*/ +TEST_F(PassClassTest, NegativeConditionAndInit) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %23 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 -3 + %16 = OpConstant %6 -1 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 2 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 3 + %30 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %23 = OpVariable %22 Function + OpBranch %10 + %10 = OpLabel + %32 = OpPhi %6 %9 %5 %31 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %32 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %32 %25 + %28 = OpAccessChain %7 %23 %26 + OpStore %28 %32 + OpBranch
%13 + %13 = OpLabel + %31 = OpIAdd %6 %32 %30 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 -3 +%9 = OpConstant %6 -1 +%10 = OpTypeBool +%11 = OpTypeInt 32 0 +%12 = OpConstant %11 2 +%13 = OpTypeArray %6 %12 +%14 = OpTypePointer Function %13 +%15 = OpConstant %6 3 +%16 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%17 = OpLabel +%3 = OpVariable %14 Function +OpBranch %18 +%18 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpSLessThan %10 %8 %9 +OpBranch %25 +%25 = OpLabel +%26 = OpIAdd %6 %8 %15 +%27 = OpAccessChain %7 %3 %26 +OpStore %27 %8 +OpBranch %21 +%21 = OpLabel +%20 = OpIAdd %6 %8 %16 +OpBranch %28 +%28 = OpLabel +OpBranch %30 +%30 = OpLabel +%31 = OpSLessThan %10 %20 %9 +OpBranch %32 +%32 = OpLabel +%33 = OpIAdd %6 %20 %15 +%34 = OpAccessChain %7 %3 %33 +OpStore %34 %20 +OpBranch %35 +%35 = OpLabel +%36 = OpIAdd %6 %20 %16 +OpBranch %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + // SinglePassRunAndCheck(text, expected, false); + + Function* f = spvtest::GetFunction(module, 4); + + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + + Loop& loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(loop.HasUnrollLoopControl()); + + BasicBlock* condition = 
loop.FindConditionBlock(); + EXPECT_EQ(condition->id(), 14u); + + Instruction* induction = loop.FindConditionVariable(condition); + EXPECT_EQ(induction->result_id(), 32u); + + LoopUtils loop_utils{context.get(), &loop}; + EXPECT_TRUE(loop_utils.CanPerformUnroll()); + + size_t iterations = 0; + EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), + &iterations)); + EXPECT_EQ(iterations, 2u); + SinglePassRunAndCheck(text, expected, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + int out_array[9]; + for (int i = -10; i < -1; i++) { + out_array[10 + i] = i; + } +} +*/ +TEST_F(PassClassTest, NegativeConditionAndInitResidualUnroll) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %23 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 -10 + %16 = OpConstant %6 -1 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 9 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 10 + %30 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %23 = OpVariable %22 Function + OpBranch %10 + %10 = OpLabel + %32 = OpPhi %6 %9 %5 %31 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %32 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %32 %25 + %28 = OpAccessChain %7 %23 %26 + OpStore %28 %32 + OpBranch %13 + %13 = OpLabel + %31 = OpIAdd %6 %32 %30 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410
+OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 -10 +%9 = OpConstant %6 -1 +%10 = OpTypeBool +%11 = OpTypeInt 32 0 +%12 = OpConstant %11 9 +%13 = OpTypeArray %6 %12 +%14 = OpTypePointer Function %13 +%15 = OpConstant %6 10 +%16 = OpConstant %6 1 +%48 = OpConstant %6 -9 +%2 = OpFunction %4 None %5 +%17 = OpLabel +%3 = OpVariable %14 Function +OpBranch %18 +%18 = OpLabel +%19 = OpPhi %6 %8 %17 %20 %21 +OpLoopMerge %28 %21 Unroll +OpBranch %23 +%23 = OpLabel +%24 = OpSLessThan %10 %19 %48 +OpBranchConditional %24 %25 %28 +%25 = OpLabel +%26 = OpIAdd %6 %19 %15 +%27 = OpAccessChain %7 %3 %26 +OpStore %27 %19 +OpBranch %21 +%21 = OpLabel +%20 = OpIAdd %6 %19 %16 +OpBranch %18 +%28 = OpLabel +OpBranch %29 +%29 = OpLabel +%30 = OpPhi %6 %19 %28 %47 %46 +OpLoopMerge %38 %46 DontUnroll +OpBranch %31 +%31 = OpLabel +%32 = OpSLessThan %10 %30 %9 +OpBranchConditional %32 %33 %38 +%33 = OpLabel +%34 = OpIAdd %6 %30 %15 +%35 = OpAccessChain %7 %3 %34 +OpStore %35 %30 +OpBranch %36 +%36 = OpLabel +%37 = OpIAdd %6 %30 %16 +OpBranch %39 +%39 = OpLabel +OpBranch %41 +%41 = OpLabel +%42 = OpSLessThan %10 %37 %9 +OpBranch %43 +%43 = OpLabel +%44 = OpIAdd %6 %37 %15 +%45 = OpAccessChain %7 %3 %44 +OpStore %45 %37 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %6 %37 %16 +OpBranch %29 +%38 = OpLabel +OpReturn +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + Function* f = spvtest::GetFunction(module, 4); + + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + + Loop& loop = 
loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(loop.HasUnrollLoopControl()); + + BasicBlock* condition = loop.FindConditionBlock(); + EXPECT_EQ(condition->id(), 14u); + + Instruction* induction = loop.FindConditionVariable(condition); + EXPECT_EQ(induction->result_id(), 32u); + + LoopUtils loop_utils{context.get(), &loop}; + EXPECT_TRUE(loop_utils.CanPerformUnroll()); + + size_t iterations = 0; + EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), + &iterations)); + EXPECT_EQ(iterations, 9u); + SinglePassRunAndCheck>(text, expected, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[10]; + for (uint i = 0; i < 2; i++) { + for (int x = 0; x < 5; ++x) { + out_array[x + i*5] = i; + } + } +} +*/ +TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %35 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 2 + %17 = OpTypeBool + %19 = OpTypeInt 32 1 + %20 = OpTypePointer Function %19 + %22 = OpConstant %19 0 + %29 = OpConstant %19 5 + %31 = OpTypeFloat 32 + %32 = OpConstant %6 10 + %33 = OpTypeArray %31 %32 + %34 = OpTypePointer Function %33 + %39 = OpConstant %6 5 + %44 = OpTypePointer Function %31 + %47 = OpConstant %19 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %35 = OpVariable %34 Function + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %50 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpULessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %23 + %23 = OpLabel + %54 = OpPhi %19 %22 %11 %48 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = 
OpLabel + %30 = OpSLessThan %17 %54 %29 + OpBranchConditional %30 %24 %25 + %24 = OpLabel + %37 = OpBitcast %6 %54 + %40 = OpIMul %6 %51 %39 + %41 = OpIAdd %6 %37 %40 + %43 = OpConvertUToF %31 %51 + %45 = OpAccessChain %44 %35 %41 + OpStore %45 %43 + OpBranch %26 + %26 = OpLabel + %48 = OpIAdd %19 %54 %47 + OpBranch %23 + %25 = OpLabel + OpBranch %13 + %13 = OpLabel + %50 = OpIAdd %6 %51 %47 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + { // Test fully unroll + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 2u); + + Loop& outer_loop = loop_descriptor.GetLoopByIndex(1); + + EXPECT_TRUE(outer_loop.HasUnrollLoopControl()); + + Loop& inner_loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(inner_loop.HasUnrollLoopControl()); + + EXPECT_EQ(outer_loop.GetBlocks().size(), 9u); + + EXPECT_EQ(inner_loop.GetBlocks().size(), 4u); + EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u); + EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u); + + { + LoopUtils loop_utils{context.get(), &inner_loop}; + loop_utils.FullyUnroll(); + loop_utils.Finalize(); + } + + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + EXPECT_EQ(outer_loop.GetBlocks().size(), 25u); + EXPECT_EQ(outer_loop.NumImmediateChildren(), 0u); + { + LoopUtils loop_utils{context.get(), &outer_loop}; + loop_utils.FullyUnroll(); + loop_utils.Finalize(); + } + EXPECT_EQ(loop_descriptor.NumLoops(), 0u); + } + + { // Test partially unroll + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module 
= context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + Function* f = spvtest::GetFunction(module, 4); + LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 2u); + + Loop& outer_loop = loop_descriptor.GetLoopByIndex(1); + + EXPECT_TRUE(outer_loop.HasUnrollLoopControl()); + + Loop& inner_loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(inner_loop.HasUnrollLoopControl()); + + EXPECT_EQ(outer_loop.GetBlocks().size(), 9u); + + EXPECT_EQ(inner_loop.GetBlocks().size(), 4u); + + EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u); + EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u); + + LoopUtils loop_utils{context.get(), &inner_loop}; + loop_utils.PartiallyUnroll(2); + loop_utils.Finalize(); + + // The number of loops should actually grow. + EXPECT_EQ(loop_descriptor.NumLoops(), 3u); + EXPECT_EQ(outer_loop.GetBlocks().size(), 18u); + EXPECT_EQ(outer_loop.NumImmediateChildren(), 2u); + } +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[3]; + for (int i = 3; i > 0; --i) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, FullyUnrollNegativeStepLoopTest) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %24 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 3 + %16 = OpConstant %6 0 + %17 = OpTypeBool + %19 = OpTypeFloat 32 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 3 + %22 = OpTypeArray %19 %21 + %23 = OpTypePointer Function %22 + %28 = OpTypePointer Function %19 + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %24 = 
OpVariable %23 Function + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %32 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpSGreaterThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpConvertSToF %19 %33 + %29 = OpAccessChain %28 %24 %33 + OpStore %29 %27 + OpBranch %13 + %13 = OpLabel + %32 = OpISub %6 %33 %31 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 3 +%9 = OpConstant %6 0 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 3 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpSGreaterThan %10 %8 %9 +OpBranch %26 +%26 = OpLabel +%27 = OpConvertSToF %11 %8 +%28 = OpAccessChain %16 %3 %8 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpISub %6 %8 %17 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSGreaterThan %10 %21 %9 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %11 %21 +%35 = OpAccessChain %16 %3 %21 +OpStore %35 %34 +OpBranch %36 +%36 = OpLabel +%37 = OpISub %6 %21 %17 +OpBranch %38 +%38 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpSGreaterThan %10 %37 %9 +OpBranch %42 +%42 = OpLabel +%43 = OpConvertSToF %11 %37 +%44 = OpAccessChain %16 %3 %37 +OpStore %44 %43 +OpBranch %45 +%45 = OpLabel +%46 = OpISub %6 %37 %17 +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + 
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[3]; + for (int i = 9; i > 0; i-=3) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, FullyUnrollNegativeNonOneStepLoop) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %24 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 9 + %16 = OpConstant %6 0 + %17 = OpTypeBool + %19 = OpTypeFloat 32 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 3 + %22 = OpTypeArray %19 %21 + %23 = OpTypePointer Function %22 + %28 = OpTypePointer Function %19 + %30 = OpConstant %6 3 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %24 = OpVariable %23 Function + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %32 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpSGreaterThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpConvertSToF %19 %33 + %29 = OpAccessChain %28 %24 %33 + OpStore %29 %27 + OpBranch %13 + %13 = OpLabel + %32 = OpISub %6 %33 %30 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = 
OpTypePointer Function %6 +%8 = OpConstant %6 9 +%9 = OpConstant %6 0 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 3 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 3 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpSGreaterThan %10 %8 %9 +OpBranch %26 +%26 = OpLabel +%27 = OpConvertSToF %11 %8 +%28 = OpAccessChain %16 %3 %8 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpISub %6 %8 %17 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSGreaterThan %10 %21 %9 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %11 %21 +%35 = OpAccessChain %16 %3 %21 +OpStore %35 %34 +OpBranch %36 +%36 = OpLabel +%37 = OpISub %6 %21 %17 +OpBranch %38 +%38 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpSGreaterThan %10 %37 %9 +OpBranch %42 +%42 = OpLabel +%43 = OpConvertSToF %11 %37 +%44 = OpAccessChain %16 %3 %37 +OpStore %44 %43 +OpBranch %45 +%45 = OpLabel +%46 = OpISub %6 %37 %17 +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[3]; + for (int i = 0; i < 7; i+=3) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, FullyUnrollNonDivisibleStepLoop) { + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %4 "main" +OpExecutionMode %4 
OriginUpperLeft +OpSource GLSL 410 +OpName %4 "main" +OpName %24 "out_array" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%9 = OpConstant %6 0 +%16 = OpConstant %6 7 +%17 = OpTypeBool +%19 = OpTypeFloat 32 +%20 = OpTypeInt 32 0 +%21 = OpConstant %20 3 +%22 = OpTypeArray %19 %21 +%23 = OpTypePointer Function %22 +%28 = OpTypePointer Function %19 +%30 = OpConstant %6 3 +%4 = OpFunction %2 None %3 +%5 = OpLabel +%24 = OpVariable %23 Function +OpBranch %10 +%10 = OpLabel +%33 = OpPhi %6 %9 %5 %32 %13 +OpLoopMerge %12 %13 Unroll +OpBranch %14 +%14 = OpLabel +%18 = OpSLessThan %17 %33 %16 +OpBranchConditional %18 %11 %12 +%11 = OpLabel +%27 = OpConvertSToF %19 %33 +%29 = OpAccessChain %28 %24 %33 +OpStore %29 %27 +OpBranch %13 +%13 = OpLabel +%32 = OpIAdd %6 %33 %30 +OpBranch %10 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 7 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 3 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 3 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpSLessThan %10 %8 %9 +OpBranch %26 +%26 = OpLabel +%27 = OpConvertSToF %11 %8 +%28 = OpAccessChain %16 %3 %8 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpIAdd %6 %8 %17 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSLessThan %10 %21 %9 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %11 %21 +%35 = OpAccessChain %16 %3 %21 +OpStore %35 %34 +OpBranch %36 
+%36 = OpLabel +%37 = OpIAdd %6 %21 %17 +OpBranch %38 +%38 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpSLessThan %10 %37 %9 +OpBranch %42 +%42 = OpLabel +%43 = OpConvertSToF %11 %37 +%44 = OpAccessChain %16 %3 %37 +OpStore %44 %43 +OpBranch %45 +%45 = OpLabel +%46 = OpIAdd %6 %37 %17 +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[4]; + for (int i = 11; i > 0; i-=3) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, FullyUnrollNegativeNonDivisibleStepLoop) { + // With LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %4 "main" +OpExecutionMode %4 OriginUpperLeft +OpSource GLSL 410 +OpName %4 "main" +OpName %24 "out_array" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%9 = OpConstant %6 11 +%16 = OpConstant %6 0 +%17 = OpTypeBool +%19 = OpTypeFloat 32 +%20 = OpTypeInt 32 0 +%21 = OpConstant %20 4 +%22 = OpTypeArray %19 %21 +%23 = OpTypePointer Function %22 +%28 = OpTypePointer Function %19 +%30 = OpConstant %6 3 +%4 = OpFunction %2 None %3 +%5 = OpLabel +%24 = OpVariable %23 Function +OpBranch %10 +%10 = OpLabel +%33 = OpPhi %6 %9 %5 %32 %13 +OpLoopMerge %12 %13 Unroll +OpBranch %14 +%14 = OpLabel +%18 = OpSGreaterThan %17 %33 %16 +OpBranchConditional %18 %11 %12 +%11 = OpLabel +%27 = OpConvertSToF %19 %33 +%29 = OpAccessChain %28 %24 %33 +OpStore %29 %27 +OpBranch %13 +%13 = OpLabel +%32 = OpISub 
%6 %33 %30 +OpBranch %10 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 11 +%9 = OpConstant %6 0 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 4 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 3 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpSGreaterThan %10 %8 %9 +OpBranch %26 +%26 = OpLabel +%27 = OpConvertSToF %11 %8 +%28 = OpAccessChain %16 %3 %8 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpISub %6 %8 %17 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSGreaterThan %10 %21 %9 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %11 %21 +%35 = OpAccessChain %16 %3 %21 +OpStore %35 %34 +OpBranch %36 +%36 = OpLabel +%37 = OpISub %6 %21 %17 +OpBranch %38 +%38 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpSGreaterThan %10 %37 %9 +OpBranch %42 +%42 = OpLabel +%43 = OpConvertSToF %11 %37 +%44 = OpAccessChain %16 %3 %37 +OpStore %44 %43 +OpBranch %45 +%45 = OpLabel +%46 = OpISub %6 %37 %17 +OpBranch %47 +%47 = OpLabel +OpBranch %49 +%49 = OpLabel +%50 = OpSGreaterThan %10 %46 %9 +OpBranch %51 +%51 = OpLabel +%52 = OpConvertSToF %11 %46 +%53 = OpAccessChain %16 %3 %46 +OpStore %53 %52 +OpBranch %54 +%54 = OpLabel +%55 = OpISub %6 %46 %17 +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) 
<< "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +// With LocalMultiStoreElimPass +static const std::string multiple_phi_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %8 "foo(" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeFunction %6 + %10 = OpTypePointer Function %6 + %12 = OpConstant %6 0 + %14 = OpConstant %6 3 + %22 = OpConstant %6 6 + %23 = OpTypeBool + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %40 = OpFunctionCall %6 %8 + OpReturn + OpFunctionEnd + %8 = OpFunction %6 None %7 + %9 = OpLabel + OpBranch %16 + %16 = OpLabel + %41 = OpPhi %6 %12 %9 %34 %19 + %42 = OpPhi %6 %14 %9 %29 %19 + %43 = OpPhi %6 %12 %9 %32 %19 + OpLoopMerge %18 %19 Unroll + OpBranch %20 + %20 = OpLabel + %24 = OpSLessThan %23 %43 %22 + OpBranchConditional %24 %17 %18 + %17 = OpLabel + %27 = OpIMul %6 %43 %41 + %29 = OpIAdd %6 %42 %27 + OpBranch %19 + %19 = OpLabel + %32 = OpIAdd %6 %43 %31 + %34 = OpISub %6 %41 %31 + OpBranch %16 + %18 = OpLabel + %37 = OpIAdd %6 %42 %41 + OpReturnValue %37 + OpFunctionEnd + )"; + +TEST_F(PassClassTest, PartiallyUnrollResidualMultipleInductionVariables) { + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "foo(" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypeFunction %6 +%8 = OpTypePointer Function %6 +%9 = OpConstant %6 0 +%10 = OpConstant %6 3 +%11 = OpConstant %6 6 +%12 = OpTypeBool +%13 = OpConstant %6 1 +%82 = OpTypeInt 32 0 +%83 = OpConstant %82 2 +%2 = 
OpFunction %4 None %5 +%14 = OpLabel +%15 = OpFunctionCall %6 %3 +OpReturn +OpFunctionEnd +%3 = OpFunction %6 None %7 +%16 = OpLabel +OpBranch %17 +%17 = OpLabel +%18 = OpPhi %6 %9 %16 %19 %20 +%21 = OpPhi %6 %10 %16 %22 %20 +%23 = OpPhi %6 %9 %16 %24 %20 +OpLoopMerge %31 %20 Unroll +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %23 %83 +OpBranchConditional %27 %28 %31 +%28 = OpLabel +%29 = OpIMul %6 %23 %18 +%22 = OpIAdd %6 %21 %29 +OpBranch %20 +%20 = OpLabel +%24 = OpIAdd %6 %23 %13 +%19 = OpISub %6 %18 %13 +OpBranch %17 +%31 = OpLabel +OpBranch %32 +%32 = OpLabel +%33 = OpPhi %6 %18 %31 %81 %79 +%34 = OpPhi %6 %21 %31 %78 %79 +%35 = OpPhi %6 %23 %31 %80 %79 +OpLoopMerge %44 %79 DontUnroll +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %12 %35 %11 +OpBranchConditional %37 %38 %44 +%38 = OpLabel +%39 = OpIMul %6 %35 %33 +%40 = OpIAdd %6 %34 %39 +OpBranch %41 +%41 = OpLabel +%42 = OpIAdd %6 %35 %13 +%43 = OpISub %6 %33 %13 +OpBranch %46 +%46 = OpLabel +OpBranch %50 +%50 = OpLabel +%51 = OpSLessThan %12 %42 %11 +OpBranch %52 +%52 = OpLabel +%53 = OpIMul %6 %42 %43 +%54 = OpIAdd %6 %40 %53 +OpBranch %55 +%55 = OpLabel +%56 = OpIAdd %6 %42 %13 +%57 = OpISub %6 %43 %13 +OpBranch %58 +%58 = OpLabel +OpBranch %62 +%62 = OpLabel +%63 = OpSLessThan %12 %56 %11 +OpBranch %64 +%64 = OpLabel +%65 = OpIMul %6 %56 %57 +%66 = OpIAdd %6 %54 %65 +OpBranch %67 +%67 = OpLabel +%68 = OpIAdd %6 %56 %13 +%69 = OpISub %6 %57 %13 +OpBranch %70 +%70 = OpLabel +OpBranch %74 +%74 = OpLabel +%75 = OpSLessThan %12 %68 %11 +OpBranch %76 +%76 = OpLabel +%77 = OpIMul %6 %68 %69 +%78 = OpIAdd %6 %66 %77 +OpBranch %79 +%79 = OpLabel +%80 = OpIAdd %6 %68 %13 +%81 = OpISub %6 %69 %13 +OpBranch %32 +%44 = OpLabel +%45 = OpIAdd %6 %34 %33 +OpReturnValue %45 +%25 = OpLabel +%30 = OpIAdd %6 %34 %33 +OpReturnValue %30 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* 
module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << multiple_phi_shader << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck>(multiple_phi_shader, output, + false); +} + +TEST_F(PassClassTest, PartiallyUnrollMultipleInductionVariables) { + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "foo(" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypeFunction %6 +%8 = OpTypePointer Function %6 +%9 = OpConstant %6 0 +%10 = OpConstant %6 3 +%11 = OpConstant %6 6 +%12 = OpTypeBool +%13 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%14 = OpLabel +%15 = OpFunctionCall %6 %3 +OpReturn +OpFunctionEnd +%3 = OpFunction %6 None %7 +%16 = OpLabel +OpBranch %17 +%17 = OpLabel +%18 = OpPhi %6 %9 %16 %42 %40 +%21 = OpPhi %6 %10 %16 %39 %40 +%23 = OpPhi %6 %9 %16 %41 %40 +OpLoopMerge %25 %40 DontUnroll +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %23 %11 +OpBranchConditional %27 %28 %25 +%28 = OpLabel +%29 = OpIMul %6 %23 %18 +%22 = OpIAdd %6 %21 %29 +OpBranch %20 +%20 = OpLabel +%24 = OpIAdd %6 %23 %13 +%19 = OpISub %6 %18 %13 +OpBranch %31 +%31 = OpLabel +OpBranch %35 +%35 = OpLabel +%36 = OpSLessThan %12 %24 %11 +OpBranch %37 +%37 = OpLabel +%38 = OpIMul %6 %24 %19 +%39 = OpIAdd %6 %22 %38 +OpBranch %40 +%40 = OpLabel +%41 = OpIAdd %6 %24 %13 +%42 = OpISub %6 %19 %13 +OpBranch %17 +%25 = OpLabel +%30 = OpIAdd %6 %21 %18 +OpReturnValue %30 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << multiple_phi_shader << std::endl; 
+ + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck>(multiple_phi_shader, output, + false); +} + +TEST_F(PassClassTest, FullyUnrollMultipleInductionVariables) { + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "foo(" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypeFunction %6 +%8 = OpTypePointer Function %6 +%9 = OpConstant %6 0 +%10 = OpConstant %6 3 +%11 = OpConstant %6 6 +%12 = OpTypeBool +%13 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%14 = OpLabel +%15 = OpFunctionCall %6 %3 +OpReturn +OpFunctionEnd +%3 = OpFunction %6 None %7 +%16 = OpLabel +OpBranch %17 +%17 = OpLabel +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %12 %9 %11 +OpBranch %28 +%28 = OpLabel +%29 = OpIMul %6 %9 %9 +%22 = OpIAdd %6 %10 %29 +OpBranch %20 +%20 = OpLabel +%24 = OpIAdd %6 %9 %13 +%19 = OpISub %6 %9 %13 +OpBranch %31 +%31 = OpLabel +OpBranch %35 +%35 = OpLabel +%36 = OpSLessThan %12 %24 %11 +OpBranch %37 +%37 = OpLabel +%38 = OpIMul %6 %24 %19 +%39 = OpIAdd %6 %22 %38 +OpBranch %40 +%40 = OpLabel +%41 = OpIAdd %6 %24 %13 +%42 = OpISub %6 %19 %13 +OpBranch %43 +%43 = OpLabel +OpBranch %47 +%47 = OpLabel +%48 = OpSLessThan %12 %41 %11 +OpBranch %49 +%49 = OpLabel +%50 = OpIMul %6 %41 %42 +%51 = OpIAdd %6 %39 %50 +OpBranch %52 +%52 = OpLabel +%53 = OpIAdd %6 %41 %13 +%54 = OpISub %6 %42 %13 +OpBranch %55 +%55 = OpLabel +OpBranch %59 +%59 = OpLabel +%60 = OpSLessThan %12 %53 %11 +OpBranch %61 +%61 = OpLabel +%62 = OpIMul %6 %53 %54 +%63 = OpIAdd %6 %51 %62 +OpBranch %64 +%64 = OpLabel +%65 = OpIAdd %6 %53 %13 +%66 = OpISub %6 %54 %13 +OpBranch %67 +%67 = OpLabel +OpBranch %71 +%71 = OpLabel +%72 = OpSLessThan %12 %65 %11 +OpBranch %73 +%73 = OpLabel +%74 = OpIMul %6 %65 %66 +%75 = OpIAdd %6 %63 %74 +OpBranch 
%76 +%76 = OpLabel +%77 = OpIAdd %6 %65 %13 +%78 = OpISub %6 %66 %13 +OpBranch %79 +%79 = OpLabel +OpBranch %83 +%83 = OpLabel +%84 = OpSLessThan %12 %77 %11 +OpBranch %85 +%85 = OpLabel +%86 = OpIMul %6 %77 %78 +%87 = OpIAdd %6 %75 %86 +OpBranch %88 +%88 = OpLabel +%89 = OpIAdd %6 %77 %13 +%90 = OpISub %6 %78 %13 +OpBranch %25 +%25 = OpLabel +%30 = OpIAdd %6 %87 %90 +OpReturnValue %30 +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << multiple_phi_shader << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(multiple_phi_shader, output, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main() +{ + int j = 0; + for (int i = 0; i <= 2; ++i) + ++j; + + for (int i = 1; i >= 0; --i) + ++j; +} +*/ +TEST_F(PassClassTest, FullyUnrollEqualToOperations) { + // With LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 440 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %17 = OpConstant %6 2 + %18 = OpTypeBool + %21 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %11 + %11 = OpLabel + %37 = OpPhi %6 %9 %5 %22 %14 + %38 = OpPhi %6 %9 %5 %24 %14 + OpLoopMerge %13 %14 Unroll + OpBranch %15 + %15 = OpLabel + %19 = OpSLessThanEqual %18 %38 %17 + OpBranchConditional %19 %12 %13 + %12 = OpLabel + %22 = OpIAdd %6 %37 %21 + OpBranch %14 + %14 = OpLabel + %24 = OpIAdd %6 %38 %21 + OpBranch %11 + %13 = OpLabel + OpBranch %26 + %26 = OpLabel + %39 = OpPhi %6 %37 %13 
%34 %29 + %40 = OpPhi %6 %21 %13 %36 %29 + OpLoopMerge %28 %29 Unroll + OpBranch %30 + %30 = OpLabel + %32 = OpSGreaterThanEqual %18 %40 %9 + OpBranchConditional %32 %27 %28 + %27 = OpLabel + %34 = OpIAdd %6 %39 %21 + OpBranch %29 + %29 = OpLabel + %36 = OpISub %6 %40 %21 + OpBranch %26 + %28 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeInt 32 1 +%6 = OpTypePointer Function %5 +%7 = OpConstant %5 0 +%8 = OpConstant %5 2 +%9 = OpTypeBool +%10 = OpConstant %5 1 +%2 = OpFunction %3 None %4 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +OpBranch %19 +%19 = OpLabel +%20 = OpSLessThanEqual %9 %7 %8 +OpBranch %21 +%21 = OpLabel +%14 = OpIAdd %5 %7 %10 +OpBranch %15 +%15 = OpLabel +%17 = OpIAdd %5 %7 %10 +OpBranch %41 +%41 = OpLabel +OpBranch %44 +%44 = OpLabel +%45 = OpSLessThanEqual %9 %17 %8 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %5 %14 %10 +OpBranch %48 +%48 = OpLabel +%49 = OpIAdd %5 %17 %10 +OpBranch %50 +%50 = OpLabel +OpBranch %53 +%53 = OpLabel +%54 = OpSLessThanEqual %9 %49 %8 +OpBranch %55 +%55 = OpLabel +%56 = OpIAdd %5 %47 %10 +OpBranch %57 +%57 = OpLabel +%58 = OpIAdd %5 %49 %10 +OpBranch %18 +%18 = OpLabel +OpBranch %22 +%22 = OpLabel +OpBranch %29 +%29 = OpLabel +%30 = OpSGreaterThanEqual %9 %10 %7 +OpBranch %31 +%31 = OpLabel +%24 = OpIAdd %5 %56 %10 +OpBranch %25 +%25 = OpLabel +%27 = OpISub %5 %10 %10 +OpBranch %32 +%32 = OpLabel +OpBranch %35 +%35 = OpLabel +%36 = OpSGreaterThanEqual %9 %27 %7 +OpBranch %37 +%37 = OpLabel +%38 = OpIAdd %5 %24 %10 +OpBranch %39 +%39 = OpLabel +%40 = OpISub %5 %27 %10 +OpBranch %28 +%28 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + 
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +// With LocalMultiStoreElimPass +const std::string condition_in_header = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %o + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpDecorate %o Location 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %int_n2 = OpConstant %int -2 + %int_2 = OpConstant %int 2 + %bool = OpTypeBool + %float = OpTypeFloat 32 +%_ptr_Output_float = OpTypePointer Output %float + %o = OpVariable %_ptr_Output_float Output + %float_1 = OpConstant %float 1 + %main = OpFunction %void None %6 + %15 = OpLabel + OpBranch %16 + %16 = OpLabel + %27 = OpPhi %int %int_n2 %15 %26 %18 + %21 = OpSLessThanEqual %bool %27 %int_2 + OpLoopMerge %17 %18 Unroll + OpBranchConditional %21 %22 %17 + %22 = OpLabel + %23 = OpLoad %float %o + %24 = OpFAdd %float %23 %float_1 + OpStore %o %24 + OpBranch %18 + %18 = OpLabel + %26 = OpIAdd %int %27 %int_2 + OpBranch %16 + %17 = OpLabel + OpReturn + OpFunctionEnd + )"; + +TEST_F(PassClassTest, FullyUnrollConditionIsInHeaderBlock) { + const std::string output = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpSource GLSL 430 +OpDecorate %2 Location 0 +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeInt 32 1 +%6 = OpConstant %5 -2 +%7 = OpConstant %5 2 +%8 = OpTypeBool +%9 = OpTypeFloat 32 +%10 = OpTypePointer Output %9 +%2 = OpVariable %10 Output +%11 = OpConstant %9 1 +%1 = OpFunction %3 None %4 +%12 = OpLabel +OpBranch %13 +%13 = OpLabel +%17 = OpSLessThanEqual %8 %6 %7 +OpBranch %19 +%19 = OpLabel +%20 = OpLoad %9 %2 +%21 = OpFAdd %9 %20 
%11 +OpStore %2 %21 +OpBranch %16 +%16 = OpLabel +%15 = OpIAdd %5 %6 %7 +OpBranch %22 +%22 = OpLabel +%24 = OpSLessThanEqual %8 %15 %7 +OpBranch %25 +%25 = OpLabel +%26 = OpLoad %9 %2 +%27 = OpFAdd %9 %26 %11 +OpStore %2 %27 +OpBranch %28 +%28 = OpLabel +%29 = OpIAdd %5 %15 %7 +OpBranch %30 +%30 = OpLabel +%32 = OpSLessThanEqual %8 %29 %7 +OpBranch %33 +%33 = OpLabel +%34 = OpLoad %9 %2 +%35 = OpFAdd %9 %34 %11 +OpStore %2 %35 +OpBranch %36 +%36 = OpLabel +%37 = OpIAdd %5 %29 %7 +OpBranch %18 +%18 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, condition_in_header, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << condition_in_header << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(condition_in_header, output, false); +} + +TEST_F(PassClassTest, PartiallyUnrollResidualConditionIsInHeaderBlock) { + const std::string output = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +OpSource GLSL 430 +OpDecorate %2 Location 0 +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeInt 32 1 +%6 = OpConstant %5 -2 +%7 = OpConstant %5 2 +%8 = OpTypeBool +%9 = OpTypeFloat 32 +%10 = OpTypePointer Output %9 +%2 = OpVariable %10 Output +%11 = OpConstant %9 1 +%40 = OpTypeInt 32 0 +%41 = OpConstant %40 1 +%1 = OpFunction %3 None %4 +%12 = OpLabel +OpBranch %13 +%13 = OpLabel +%14 = OpPhi %5 %6 %12 %15 %16 +%17 = OpSLessThanEqual %8 %14 %41 +OpLoopMerge %22 %16 Unroll +OpBranchConditional %17 %19 %22 +%19 = OpLabel +%20 = OpLoad %9 %2 +%21 = OpFAdd %9 %20 %11 +OpStore %2 %21 +OpBranch %16 +%16 = OpLabel +%15 = OpIAdd %5 %14 %7 +OpBranch %13 +%22 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %5 %14 %22 %39 %38 +%25 = OpSLessThanEqual %8 %24 %7 
+OpLoopMerge %31 %38 DontUnroll +OpBranchConditional %25 %26 %31 +%26 = OpLabel +%27 = OpLoad %9 %2 +%28 = OpFAdd %9 %27 %11 +OpStore %2 %28 +OpBranch %29 +%29 = OpLabel +%30 = OpIAdd %5 %24 %7 +OpBranch %32 +%32 = OpLabel +%34 = OpSLessThanEqual %8 %30 %7 +OpBranch %35 +%35 = OpLabel +%36 = OpLoad %9 %2 +%37 = OpFAdd %9 %36 %11 +OpStore %2 %37 +OpBranch %38 +%38 = OpLabel +%39 = OpIAdd %5 %30 %7 +OpBranch %23 +%31 = OpLabel +OpReturn +%18 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, condition_in_header, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << condition_in_header << std::endl; + + LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck>(condition_in_header, output, + false); +} + +/* +Generated from following GLSL with latch block artificially inserted to be +seperate from continue. 
+#version 430 +void main(void) { + float x[10]; + for (int i = 0; i < 10; ++i) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, PartiallyUnrollLatchNotContinue) { + const std::string text = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + OpName %4 "x" + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %10 = OpConstant %7 10 + %11 = OpTypeBool + %12 = OpTypeFloat 32 + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 10 + %15 = OpTypeArray %12 %14 + %16 = OpTypePointer Function %15 + %17 = OpTypePointer Function %12 + %18 = OpConstant %7 1 + %2 = OpFunction %5 None %6 + %19 = OpLabel + %3 = OpVariable %8 Function + %4 = OpVariable %16 Function + OpStore %3 %9 + OpBranch %20 + %20 = OpLabel + %21 = OpPhi %7 %9 %19 %22 %30 + OpLoopMerge %24 %23 Unroll + OpBranch %25 + %25 = OpLabel + %26 = OpSLessThan %11 %21 %10 + OpBranchConditional %26 %27 %24 + %27 = OpLabel + %28 = OpConvertSToF %12 %21 + %29 = OpAccessChain %17 %4 %21 + OpStore %29 %28 + OpBranch %23 + %23 = OpLabel + %22 = OpIAdd %7 %21 %18 + OpStore %3 %22 + OpBranch %30 + %30 = OpLabel + OpBranch %20 + %24 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 430 +OpName %2 "main" +OpName %3 "i" +OpName %4 "x" +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpConstant %7 10 +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpTypePointer Function %12 +%18 = OpConstant %7 1 +%63 = OpConstant %13 1 +%2 = 
OpFunction %5 None %6 +%19 = OpLabel +%3 = OpVariable %8 Function +%4 = OpVariable %16 Function +OpStore %3 %9 +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %7 %9 %19 %22 %23 +OpLoopMerge %31 %25 Unroll +OpBranch %26 +%26 = OpLabel +%27 = OpSLessThan %11 %21 %63 +OpBranchConditional %27 %28 %31 +%28 = OpLabel +%29 = OpConvertSToF %12 %21 +%30 = OpAccessChain %17 %4 %21 +OpStore %30 %29 +OpBranch %25 +%25 = OpLabel +%22 = OpIAdd %7 %21 %18 +OpStore %3 %22 +OpBranch %23 +%23 = OpLabel +OpBranch %20 +%31 = OpLabel +OpBranch %32 +%32 = OpLabel +%33 = OpPhi %7 %21 %31 %61 %62 +OpLoopMerge %42 %60 DontUnroll +OpBranch %34 +%34 = OpLabel +%35 = OpSLessThan %11 %33 %10 +OpBranchConditional %35 %36 %42 +%36 = OpLabel +%37 = OpConvertSToF %12 %33 +%38 = OpAccessChain %17 %4 %33 +OpStore %38 %37 +OpBranch %39 +%39 = OpLabel +%40 = OpIAdd %7 %33 %18 +OpStore %3 %40 +OpBranch %41 +%41 = OpLabel +OpBranch %43 +%43 = OpLabel +OpBranch %45 +%45 = OpLabel +%46 = OpSLessThan %11 %40 %10 +OpBranch %47 +%47 = OpLabel +%48 = OpConvertSToF %12 %40 +%49 = OpAccessChain %17 %4 %40 +OpStore %49 %48 +OpBranch %50 +%50 = OpLabel +%51 = OpIAdd %7 %40 %18 +OpStore %3 %51 +OpBranch %52 +%52 = OpLabel +OpBranch %53 +%53 = OpLabel +OpBranch %55 +%55 = OpLabel +%56 = OpSLessThan %11 %51 %10 +OpBranch %57 +%57 = OpLabel +%58 = OpConvertSToF %12 %51 +%59 = OpAccessChain %17 %4 %51 +OpStore %59 %58 +OpBranch %60 +%60 = OpLabel +%61 = OpIAdd %7 %51 %18 +OpStore %3 %61 +OpBranch %62 +%62 = OpLabel +OpBranch %32 +%42 = OpLabel +OpReturn +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck>(text, expected, true); + + // Make sure the latch block information is preserved and propagated correctly + // by the pass. 
+ std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + PartialUnrollerTestPass<3> unroller; + unroller.SetContextForTesting(context.get()); + unroller.Process(); + + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + LoopDescriptor ld{context.get(), f}; + + EXPECT_EQ(ld.NumLoops(), 2u); + + Loop& loop_1 = ld.GetLoopByIndex(0u); + EXPECT_NE(loop_1.GetLatchBlock(), loop_1.GetContinueBlock()); + + Loop& loop_2 = ld.GetLoopByIndex(1u); + EXPECT_NE(loop_2.GetLatchBlock(), loop_2.GetContinueBlock()); +} + +// Test that a loop with a self-referencing OpPhi instruction is handled +// correctly. +TEST_F(PassClassTest, OpPhiSelfReference) { + const std::string text = R"( + ; Find the two adds from the unrolled loop + ; CHECK: OpIAdd + ; CHECK: OpIAdd + ; CHECK: OpIAdd %uint %uint_0 %uint_1 + ; CHECK-NEXT: OpReturn + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %2 "main" + OpExecutionMode %2 LocalSize 8 8 1 + OpSource HLSL 600 + %uint = OpTypeInt 32 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 + %bool = OpTypeBool + %true = OpConstantTrue %bool + %2 = OpFunction %void None %5 + %10 = OpLabel + OpBranch %19 + %19 = OpLabel + %20 = OpPhi %uint %uint_0 %10 %20 %21 + %22 = OpPhi %uint %uint_0 %10 %23 %21 + %24 = OpULessThanEqual %bool %22 %uint_1 + OpLoopMerge %25 %21 Unroll + OpBranchConditional %24 %21 %25 + %21 = OpLabel + %23 = OpIAdd %uint %22 %uint_1 + OpBranch %19 + %25 = OpLabel + %14 = OpIAdd %uint %20 %uint_1 + OpReturn + OpFunctionEnd + )"; + + const bool kFullyUnroll = true; + const uint32_t kUnrollFactor = 0; + SinglePassRunAndMatch(text, true, kFullyUnroll, + kUnrollFactor); +} + +} // namespace +} // namespace opt +} // 
namespace spvtools diff --git a/test/opt/loop_optimizations/unswitch.cpp b/test/opt/loop_optimizations/unswitch.cpp new file mode 100644 index 000000000..dc7073fd5 --- /dev/null +++ b/test/opt/loop_optimizations/unswitch.cpp @@ -0,0 +1,967 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using UnswitchTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 450 core +uniform vec4 c; +void main() { + int i = 0; + int j = 0; + bool cond = c[0] == 0; + for (; i < 10; i++, j++) { + if (cond) { + i++; + } + else { + j++; + } + } +} +*/ +TEST_F(UnswitchTest, SimpleUnswitch) { + const std::string text = R"( +; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual +; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]] + +; Loop specialized for false. 
+; CHECK: [[loop_f]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: [[phi_j:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_j:%\w+]] [[continue]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[loop_body:%\w+]] [[merge]] +; [[loop_body]] = OpLabel +; CHECK: OpSelectionMerge [[sel_merge:%\w+]] None +; CHECK: OpBranchConditional %false [[bb1:%\w+]] [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: [[inc_j:%\w+]] = OpIAdd %int [[phi_j]] %int_1 +; CHECK-NEXT: OpBranch [[sel_merge]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: [[inc_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1 +; CHECK-NEXT: OpBranch [[sel_merge]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK: OpBranch [[if_merge]] + +; Loop specialized for true. +; CHECK: [[loop_t]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: [[phi_j:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_j:%\w+]] [[continue]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[loop_body:%\w+]] [[merge]] +; [[loop_body]] = OpLabel +; CHECK: OpSelectionMerge [[sel_merge:%\w+]] None +; CHECK: OpBranchConditional %true [[bb1:%\w+]] [[bb2:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: [[inc_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1 +; CHECK-NEXT: OpBranch [[sel_merge]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: [[inc_j:%\w+]] = OpIAdd %int [[phi_j]] %int_1 +; CHECK-NEXT: OpBranch [[sel_merge]] +; CHECK: [[sel_merge]] = OpLabel +; CHECK: OpBranch [[if_merge]] + +; CHECK: [[if_merge]] = OpLabel +; CHECK-NEXT: OpReturn + + 
OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 450 + OpName %main "main" + OpName %c "c" + OpDecorate %c Location 0 + OpDecorate %c DescriptorSet 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_UniformConstant_v4float = OpTypePointer UniformConstant %v4float + %c = OpVariable %_ptr_UniformConstant_v4float UniformConstant + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_UniformConstant_float = OpTypePointer UniformConstant %float + %float_0 = OpConstant %float 0 + %int_10 = OpConstant %int 10 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %21 = OpAccessChain %_ptr_UniformConstant_float %c %uint_0 + %22 = OpLoad %float %21 + %24 = OpFOrdEqual %bool %22 %float_0 + OpBranch %25 + %25 = OpLabel + %46 = OpPhi %int %int_0 %5 %43 %28 + %47 = OpPhi %int %int_0 %5 %45 %28 + OpLoopMerge %27 %28 None + OpBranch %29 + %29 = OpLabel + %32 = OpSLessThan %bool %46 %int_10 + OpBranchConditional %32 %26 %27 + %26 = OpLabel + OpSelectionMerge %35 None + OpBranchConditional %24 %34 %39 + %34 = OpLabel + %38 = OpIAdd %int %46 %int_1 + OpBranch %35 + %39 = OpLabel + %41 = OpIAdd %int %47 %int_1 + OpBranch %35 + %35 = OpLabel + %48 = OpPhi %int %38 %34 %46 %39 + %49 = OpPhi %int %47 %34 %41 %39 + OpBranch %28 + %28 = OpLabel + %43 = OpIAdd %int %48 %int_1 + %45 = OpIAdd %int %49 %int_1 + OpBranch %25 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +in vec4 c; +void main() { + int i = 0; + bool cond = c[0] == 0; + for (; i < 10; 
i++) { + if (cond) { + i++; + } + else { + return; + } + } +} +*/ +TEST_F(UnswitchTest, UnswitchExit) { + const std::string text = R"( +; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual +; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]] + +; Loop specialized for false. +; CHECK: [[loop_f]] = OpLabel +; CHECK: OpReturn + +; Loop specialized for true. +; CHECK: [[loop_t]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]] +; Check that we have i+=2. +; CHECK: [[phi_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1 +; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1 +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpBranch [[if_merge]] + +; CHECK: [[if_merge]] = OpLabel +; CHECK-NEXT: OpReturn + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %c + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + OpName %c "c" + OpDecorate %c Location 0 + OpDecorate %23 Uniform + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %c = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float + %float_0 = OpConstant %float 0 + %int_10 = OpConstant %int 10 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %20 = 
OpAccessChain %_ptr_Input_float %c %uint_0 + %21 = OpLoad %float %20 + %23 = OpFOrdEqual %bool %21 %float_0 + OpBranch %24 + %24 = OpLabel + %42 = OpPhi %int %int_0 %5 %41 %27 + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %31 = OpSLessThan %bool %42 %int_10 + OpBranchConditional %31 %25 %26 + %25 = OpLabel + OpSelectionMerge %34 None + OpBranchConditional %23 %33 %38 + %33 = OpLabel + %37 = OpIAdd %int %42 %int_1 + OpBranch %34 + %38 = OpLabel + OpReturn + %34 = OpLabel + OpBranch %27 + %27 = OpLabel + %41 = OpIAdd %int %37 %int_1 + OpBranch %24 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +in vec4 c; +void main() { + int i = 0; + bool cond = c[0] == 0; + for (; i < 10; i++) { + if (cond) { + continue; + } + else { + i++; + } + } +} +*/ +TEST_F(UnswitchTest, UnswitchContinue) { + const std::string text = R"( +; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual +; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]] + +; Loop specialized for false. +; CHECK: [[loop_f]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_f]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[loop_body:%\w+]] [[merge]] +; CHECK: [[loop_body:%\w+]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpBranchConditional %false +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpBranch [[if_merge]] + +; Loop specialized for true. 
+; CHECK: [[loop_t]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[loop_body:%\w+]] [[merge]] +; CHECK: [[loop_body:%\w+]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpBranchConditional %true +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpBranch [[if_merge]] + +; CHECK: [[if_merge]] = OpLabel +; CHECK-NEXT: OpReturn + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %c + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + OpName %c "c" + OpDecorate %c Location 0 + OpDecorate %23 Uniform + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %c = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float + %float_0 = OpConstant %float 0 + %int_10 = OpConstant %int 10 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %20 = OpAccessChain %_ptr_Input_float %c %uint_0 + %21 = OpLoad %float %20 + %23 = OpFOrdEqual %bool %21 %float_0 + OpBranch %24 + %24 = OpLabel + %42 = OpPhi %int %int_0 %5 %41 %27 + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %31 = OpSLessThan %bool %42 %int_10 + OpBranchConditional %31 %25 %26 + %25 = OpLabel + OpSelectionMerge %34 None + OpBranchConditional %23 %33 %36 + %33 = OpLabel + OpBranch %27 + %36 = OpLabel + %39 = 
OpIAdd %int %42 %int_1 + OpBranch %34 + %34 = OpLabel + OpBranch %27 + %27 = OpLabel + %43 = OpPhi %int %42 %33 %39 %34 + %41 = OpIAdd %int %43 %int_1 + OpBranch %24 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +in vec4 c; +void main() { + int i = 0; + bool cond = c[0] == 0; + for (; i < 10; i++) { + if (cond) { + i++; + } + else { + break; + } + } +} +*/ +TEST_F(UnswitchTest, UnswitchKillLoop) { + const std::string text = R"( +; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual +; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]] + +; Loop specialized for false. +; CHECK: [[loop_f]] = OpLabel +; CHECK: OpBranch [[if_merge]] + +; Loop specialized for true. +; CHECK: [[loop_t]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_t]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] {{%\w+}} [[merge]] +; Check that we have i+=2. 
+; CHECK: [[phi_i:%\w+]] = OpIAdd %int [[phi_i]] %int_1 +; CHECK: [[iv_i]] = OpIAdd %int [[phi_i]] %int_1 +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpBranch [[if_merge]] + +; CHECK: [[if_merge]] = OpLabel +; CHECK-NEXT: OpReturn + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %c + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + OpName %c "c" + OpDecorate %c Location 0 + OpDecorate %23 Uniform + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %c = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float + %float_0 = OpConstant %float 0 + %int_10 = OpConstant %int 10 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %20 = OpAccessChain %_ptr_Input_float %c %uint_0 + %21 = OpLoad %float %20 + %23 = OpFOrdEqual %bool %21 %float_0 + OpBranch %24 + %24 = OpLabel + %42 = OpPhi %int %int_0 %5 %41 %27 + OpLoopMerge %26 %27 None + OpBranch %28 + %28 = OpLabel + %31 = OpSLessThan %bool %42 %int_10 + OpBranchConditional %31 %25 %26 + %25 = OpLabel + OpSelectionMerge %34 None + OpBranchConditional %23 %33 %38 + %33 = OpLabel + %37 = OpIAdd %int %42 %int_1 + OpBranch %34 + %38 = OpLabel + OpBranch %26 + %34 = OpLabel + OpBranch %27 + %27 = OpLabel + %41 = OpIAdd %int %37 %int_1 + OpBranch %24 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +in vec4 c; +void main() { + int i = 0; + int cond = int(c[0]); + for (; i < 10; i++) { + 
switch (cond) { + case 0: + return; + case 1: + discard; + case 2: + break; + default: + break; + } + } + bool cond2 = i == 9; +} +*/ +TEST_F(UnswitchTest, UnswitchSwitch) { + const std::string text = R"( +; CHECK: [[cst_cond:%\w+]] = OpConvertFToS +; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None +; CHECK-NEXT: OpSwitch [[cst_cond]] [[default:%\w+]] 0 [[loop_0:%\w+]] 1 [[loop_1:%\w+]] 2 [[loop_2:%\w+]] + +; Loop specialized for 2. +; CHECK: [[loop_2]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_2]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[loop_body:%\w+]] [[merge]] +; CHECK: [[loop_body]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpSwitch %int_2 +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpBranch [[if_merge]] + +; Loop specialized for 1. +; CHECK: [[loop_1]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_1]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[loop_body:%\w+]] [[merge]] +; CHECK: [[loop_body]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpSwitch %int_1 +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpBranch [[if_merge]] + +; Loop specialized for 0. 
+; CHECK: [[loop_0]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[loop_0]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[loop_body:%\w+]] [[merge]] +; CHECK: [[loop_body]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpSwitch %int_0 +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpBranch [[if_merge]] + +; Loop specialized for the default case. +; CHECK: [[default]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: [[phi_i:%\w+]] = OpPhi %int %int_0 [[default]] [[iv_i:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK: [[loop_exit:%\w+]] = OpSLessThan {{%\w+}} [[phi_i]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional [[loop_exit]] [[loop_body:%\w+]] [[merge]] +; CHECK: [[loop_body]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpSwitch %uint_3 +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: OpBranch [[if_merge]] + +; CHECK: [[if_merge]] = OpLabel +; CHECK-NEXT: OpReturn + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %c + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + OpName %c "c" + OpDecorate %c Location 0 + OpDecorate %20 Uniform + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %c = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_1 = 
OpConstant %int 1 +%_ptr_Function_bool = OpTypePointer Function %bool + %main = OpFunction %void None %3 + %5 = OpLabel + %18 = OpAccessChain %_ptr_Input_float %c %uint_0 + %19 = OpLoad %float %18 + %20 = OpConvertFToS %int %19 + OpBranch %21 + %21 = OpLabel + %49 = OpPhi %int %int_0 %5 %43 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSLessThan %bool %49 %int_10 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + OpSelectionMerge %35 None + OpSwitch %20 %34 0 %31 1 %32 2 %33 + %34 = OpLabel + OpBranch %35 + %31 = OpLabel + OpReturn + %32 = OpLabel + OpKill + %33 = OpLabel + OpBranch %35 + %35 = OpLabel + OpBranch %24 + %24 = OpLabel + %43 = OpIAdd %int %49 %int_1 + OpBranch %21 + %23 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 440 core +layout(location = 0)in vec4 c; +void main() { + int i = 0; + int j = 0; + int k = 0; + bool cond = c[0] == 0; + for (; i < 10; i++) { + for (; j < 10; j++) { + if (cond) { + i++; + } else { + j++; + } + } + } +} +*/ +TEST_F(UnswitchTest, UnSwitchNested) { + // Test that an branch can be unswitched out of two nested loops. 
+ const std::string text = R"( +; CHECK: [[cst_cond:%\w+]] = OpFOrdEqual +; CHECK-NEXT: OpSelectionMerge [[if_merge:%\w+]] None +; CHECK-NEXT: OpBranchConditional [[cst_cond]] [[loop_t:%\w+]] [[loop_f:%\w+]] + +; Loop specialized for false +; CHECK: [[loop_f]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: {{%\w+}} = OpPhi %int %int_0 [[loop_f]] {{%\w+}} [[continue:%\w+]] +; CHECK-NEXT: {{%\w+}} = OpPhi %int %int_0 [[loop_f]] {{%\w+}} [[continue]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK-NOT: [[merge]] = OpLabel +; CHECK: OpLoopMerge +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpSLessThan +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpBranchConditional %false +; CHECK: [[merge]] = OpLabel + +; Loop specialized for true. Same as first loop except the branch condition is true. +; CHECK: [[loop_t]] = OpLabel +; CHECK-NEXT: OpBranch [[loop:%\w+]] +; CHECK: [[loop]] = OpLabel +; CHECK-NEXT: {{%\w+}} = OpPhi %int %int_0 [[loop_t]] {{%\w+}} [[continue:%\w+]] +; CHECK-NEXT: {{%\w+}} = OpPhi %int %int_0 [[loop_t]] {{%\w+}} [[continue]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None +; CHECK-NOT: [[merge]] = OpLabel +; CHECK: OpLoopMerge +; CHECK-NEXT: OpBranch [[bb1:%\w+]] +; CHECK: [[bb1]] = OpLabel +; CHECK-NEXT: OpSLessThan +; CHECK-NEXT: OpBranchConditional {{%\w+}} [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK-NEXT: OpSelectionMerge +; CHECK-NEXT: OpBranchConditional %true +; CHECK: [[merge]] = OpLabel + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %c + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 440 + OpName %main "main" + OpName %c "c" + OpDecorate %c Location 0 + OpDecorate %25 Uniform + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 
+%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %c = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float + %float_0 = OpConstant %float 0 + %int_10 = OpConstant %int 10 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %22 = OpAccessChain %_ptr_Input_float %c %uint_0 + %23 = OpLoad %float %22 + %25 = OpFOrdEqual %bool %23 %float_0 + OpBranch %26 + %26 = OpLabel + %67 = OpPhi %int %int_0 %5 %52 %29 + %68 = OpPhi %int %int_0 %5 %70 %29 + OpLoopMerge %28 %29 None + OpBranch %30 + %30 = OpLabel + %33 = OpSLessThan %bool %67 %int_10 + OpBranchConditional %33 %27 %28 + %27 = OpLabel + OpBranch %34 + %34 = OpLabel + %69 = OpPhi %int %67 %27 %46 %37 + %70 = OpPhi %int %68 %27 %50 %37 + OpLoopMerge %36 %37 None + OpBranch %38 + %38 = OpLabel + %40 = OpSLessThan %bool %70 %int_10 + OpBranchConditional %40 %35 %36 + %35 = OpLabel + OpSelectionMerge %43 None + OpBranchConditional %25 %42 %47 + %42 = OpLabel + %46 = OpIAdd %int %69 %int_1 + OpBranch %43 + %47 = OpLabel + OpReturn + %43 = OpLabel + OpBranch %37 + %37 = OpLabel + %50 = OpIAdd %int %70 %int_1 + OpBranch %34 + %36 = OpLabel + OpBranch %29 + %29 = OpLabel + %52 = OpIAdd %int %69 %int_1 + OpBranch %26 + %28 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 330 core +in vec4 c; +void main() { + bool cond = false; + if (c[0] == 0) { + cond = c[1] == 0; + } else { + cond = c[2] == 0; + } + for (int i = 0; i < 10; i++) { + if (cond) { + i++; + } + } +} +*/ +TEST_F(UnswitchTest, UnswitchNotUniform) { + // Check that the unswitch is not triggered (condition loop invariant 
but not + // uniform) + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %c + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 330 + OpName %main "main" + OpName %c "c" + OpDecorate %c Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %c = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float + %float_0 = OpConstant %float 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %int_10 = OpConstant %int 10 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %5 = OpLabel + %17 = OpAccessChain %_ptr_Input_float %c %uint_0 + %18 = OpLoad %float %17 + %20 = OpFOrdEqual %bool %18 %float_0 + OpSelectionMerge %22 None + OpBranchConditional %20 %21 %27 + %21 = OpLabel + %24 = OpAccessChain %_ptr_Input_float %c %uint_1 + %25 = OpLoad %float %24 + %26 = OpFOrdEqual %bool %25 %float_0 + OpBranch %22 + %27 = OpLabel + %29 = OpAccessChain %_ptr_Input_float %c %uint_2 + %30 = OpLoad %float %29 + %31 = OpFOrdEqual %bool %30 %float_0 + OpBranch %22 + %22 = OpLabel + %52 = OpPhi %bool %26 %21 %31 %27 + OpBranch %36 + %36 = OpLabel + %53 = OpPhi %int %int_0 %22 %51 %39 + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %43 = OpSLessThan %bool %53 %int_10 + OpBranchConditional %43 %37 %38 + %37 = OpLabel + OpSelectionMerge %46 None + OpBranchConditional %52 %45 %46 + %45 = OpLabel + %49 = OpIAdd %int %53 %int_1 + OpBranch %46 + %46 = OpLabel + %54 = OpPhi %int %53 %37 %49 %45 + OpBranch %39 + %39 = OpLabel + %51 = OpIAdd %int %54 %int_1 + OpBranch %36 + %38 = 
OpLabel + OpReturn + OpFunctionEnd + )"; + + auto result = + SinglePassRunAndDisassemble(text, true, false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(UnswitchTest, DontUnswitchLatch) { + // Check that the unswitch is not triggered for the latch branch. + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool +%false = OpConstantFalse %bool + %4 = OpFunction %void None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %8 %9 None + OpBranch %7 + %7 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranchConditional %false %6 %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto result = + SinglePassRunAndDisassemble(text, true, false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(UnswitchTest, DontUnswitchConstantCondition) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginLowerLeft + OpSource GLSL 450 + OpName %main "main" + %void = OpTypeVoid + %4 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %bool = OpTypeBool + %true = OpConstantTrue %bool + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %4 + %10 = OpLabel + OpBranch %11 + %11 = OpLabel + %12 = OpPhi %int %int_0 %10 %13 %14 + OpLoopMerge %15 %14 None + OpBranch %16 + %16 = OpLabel + %17 = OpSLessThan %bool %12 %int_1 + OpBranchConditional %17 %18 %15 + %18 = OpLabel + OpSelectionMerge %19 None + OpBranchConditional %true %20 %19 + %20 = OpLabel + %21 = OpIAdd %int %12 %int_1 + OpBranch %19 + %19 = OpLabel + %22 = OpPhi %int %21 %20 %12 %18 + OpBranch %14 + %14 = OpLabel + %13 = OpIAdd %int %22 %int_1 + OpBranch %11 + 
%15 = OpLabel + OpReturn + OpFunctionEnd + )"; + + auto result = + SinglePassRunAndDisassemble(text, true, false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/module_test.cpp b/test/opt/module_test.cpp new file mode 100644 index 000000000..569cf9bcd --- /dev/null +++ b/test/opt/module_test.cpp @@ -0,0 +1,233 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/module.h" +#include "spirv-tools/libspirv.hpp" +#include "test/opt/module_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::Eq; +using spvtest::GetIdBound; + +TEST(ModuleTest, SetIdBound) { + Module m; + // It's initialized to 0. + EXPECT_EQ(0u, GetIdBound(m)); + + m.SetIdBound(19); + EXPECT_EQ(19u, GetIdBound(m)); + + m.SetIdBound(102); + EXPECT_EQ(102u, GetIdBound(m)); +} + +// Returns an IRContext owning the module formed by assembling the given text, +// then loading the result. +inline std::unique_ptr BuildModule(std::string text) { + return spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); +} + +TEST(ModuleTest, ComputeIdBound) { + // Emtpy module case. 
+ EXPECT_EQ(1u, BuildModule("")->module()->ComputeIdBound()); + // Sensitive to result id + EXPECT_EQ(2u, BuildModule("%void = OpTypeVoid")->module()->ComputeIdBound()); + // Sensitive to type id + EXPECT_EQ(1000u, + BuildModule("%a = OpTypeArray !999 3")->module()->ComputeIdBound()); + // Sensitive to a regular Id parameter + EXPECT_EQ(2000u, + BuildModule("OpDecorate !1999 0")->module()->ComputeIdBound()); + // Sensitive to a scope Id parameter. + EXPECT_EQ(3000u, + BuildModule("%f = OpFunction %void None %fntype %a = OpLabel " + "OpMemoryBarrier !2999 %b\n") + ->module() + ->ComputeIdBound()); + // Sensitive to a semantics Id parameter + EXPECT_EQ(4000u, + BuildModule("%f = OpFunction %void None %fntype %a = OpLabel " + "OpMemoryBarrier %b !3999\n") + ->module() + ->ComputeIdBound()); +} + +TEST(ModuleTest, OstreamOperator) { + const std::string text = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %7 "restrict" +OpDecorate %8 Restrict +%9 = OpTypeVoid +%10 = OpTypeInt 32 0 +%11 = OpTypeStruct %10 %10 +%12 = OpTypePointer Function %10 +%13 = OpTypePointer Function %11 +%14 = OpConstant %10 0 +%15 = OpConstant %10 1 +%7 = OpTypeFunction %9 +%1 = OpFunction %9 None %7 +%2 = OpLabel +%8 = OpVariable %13 Function +%3 = OpAccessChain %12 %8 %14 +%4 = OpLoad %10 %3 +%5 = OpAccessChain %12 %8 %15 +%6 = OpLoad %10 %5 +OpReturn +OpFunctionEnd)"; + + std::string s; + std::ostringstream str(s); + str << *BuildModule(text)->module(); + EXPECT_EQ(text, str.str()); +} + +TEST(ModuleTest, OstreamOperatorInt64) { + const std::string text = R"(OpCapability Shader +OpCapability Linkage +OpCapability Int64 +OpMemoryModel Logical GLSL450 +OpName %7 "restrict" +OpDecorate %5 Restrict +%9 = OpTypeVoid +%10 = OpTypeInt 64 0 +%11 = OpTypeStruct %10 %10 +%12 = OpTypePointer Function %10 +%13 = OpTypePointer Function %11 +%14 = OpConstant %10 0 +%15 = OpConstant %10 1 +%16 = OpConstant %10 4294967297 +%7 = OpTypeFunction %9 +%1 = OpFunction %9 None %7 
+%2 = OpLabel +%5 = OpVariable %12 Function +%6 = OpLoad %10 %5 +OpSelectionMerge %3 None +OpSwitch %6 %3 4294967297 %4 +%4 = OpLabel +OpBranch %3 +%3 = OpLabel +OpReturn +OpFunctionEnd)"; + + std::string s; + std::ostringstream str(s); + str << *BuildModule(text)->module(); + EXPECT_EQ(text, str.str()); +} + +TEST(ModuleTest, IdBoundTestAtLimit) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +OpReturn +OpFunctionEnd)"; + + std::unique_ptr context = BuildModule(text); + uint32_t current_bound = context->module()->id_bound(); + context->set_max_id_bound(current_bound); + uint32_t next_id_bound = context->module()->TakeNextIdBound(); + EXPECT_EQ(next_id_bound, 0); + EXPECT_EQ(current_bound, context->module()->id_bound()); + next_id_bound = context->module()->TakeNextIdBound(); + EXPECT_EQ(next_id_bound, 0); +} + +TEST(ModuleTest, IdBoundTestBelowLimit) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +OpReturn +OpFunctionEnd)"; + + std::unique_ptr context = BuildModule(text); + uint32_t current_bound = context->module()->id_bound(); + context->set_max_id_bound(current_bound + 100); + uint32_t next_id_bound = context->module()->TakeNextIdBound(); + EXPECT_EQ(next_id_bound, current_bound); + EXPECT_EQ(current_bound + 1, context->module()->id_bound()); + next_id_bound = context->module()->TakeNextIdBound(); + EXPECT_EQ(next_id_bound, current_bound + 1); +} + +TEST(ModuleTest, IdBoundTestNearLimit) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +OpReturn +OpFunctionEnd)"; + + std::unique_ptr context = BuildModule(text); + uint32_t current_bound = 
context->module()->id_bound(); + context->set_max_id_bound(current_bound + 1); + uint32_t next_id_bound = context->module()->TakeNextIdBound(); + EXPECT_EQ(next_id_bound, current_bound); + EXPECT_EQ(current_bound + 1, context->module()->id_bound()); + next_id_bound = context->module()->TakeNextIdBound(); + EXPECT_EQ(next_id_bound, 0); +} + +TEST(ModuleTest, IdBoundTestUIntMax) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4294967294 = OpLabel ; ID is UINT_MAX-1 +OpReturn +OpFunctionEnd)"; + + std::unique_ptr context = BuildModule(text); + uint32_t current_bound = context->module()->id_bound(); + + // Expecting |BuildModule| to preserve the numeric ids. + EXPECT_EQ(current_bound, std::numeric_limits::max()); + + context->set_max_id_bound(current_bound); + uint32_t next_id_bound = context->module()->TakeNextIdBound(); + EXPECT_EQ(next_id_bound, 0); + EXPECT_EQ(current_bound, context->module()->id_bound()); +} +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/module_utils.h b/test/opt/module_utils.h new file mode 100644 index 000000000..007f132c2 --- /dev/null +++ b/test/opt/module_utils.h @@ -0,0 +1,34 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef TEST_OPT_MODULE_UTILS_H_ +#define TEST_OPT_MODULE_UTILS_H_ + +#include +#include "source/opt/module.h" + +namespace spvtest { + +inline uint32_t GetIdBound(const spvtools::opt::Module& m) { + std::vector binary; + m.ToBinary(&binary, false); + // Serializes |m| to a binary (above) and returns the Id bound stored in the + // binary's header. + // The 5-word header must always exist. + EXPECT_LE(5u, binary.size()); + // The bound is the fourth word of the binary, i.e. index 3. + return binary[3]; +} + +} // namespace spvtest + +#endif // TEST_OPT_MODULE_UTILS_H_ diff --git a/test/opt/optimizer_test.cpp b/test/opt/optimizer_test.cpp new file mode 100644 index 000000000..90abc00d0 --- /dev/null +++ b/test/opt/optimizer_test.cpp @@ -0,0 +1,227 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "spirv-tools/libspirv.hpp" +#include "spirv-tools/optimizer.hpp" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::Eq; + +// Return a string that contains the minimum instructions needed to form +// a valid module. Other instructions can be appended to this string.
+std::string Header() { + return R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; +} + +TEST(Optimizer, CanRunNullPassWithDistinctInputOutputVectors) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); + std::vector binary_in; + tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", + &binary_in); + + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + opt.RegisterPass(CreateNullPass()); + std::vector binary_out; + opt.Run(binary_in.data(), binary_in.size(), &binary_out); + + std::string disassembly; + tools.Disassemble(binary_out.data(), binary_out.size(), &disassembly); + EXPECT_THAT(disassembly, + Eq(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid\n")); +} + +TEST(Optimizer, CanRunTransformingPassWithDistinctInputOutputVectors) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); + std::vector binary_in; + tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", + &binary_in); + + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + opt.RegisterPass(CreateStripDebugInfoPass()); + std::vector binary_out; + opt.Run(binary_in.data(), binary_in.size(), &binary_out); + + std::string disassembly; + tools.Disassemble(binary_out.data(), binary_out.size(), &disassembly); + EXPECT_THAT(disassembly, Eq(Header() + "%void = OpTypeVoid\n")); +} + +TEST(Optimizer, CanRunNullPassWithAliasedVectors) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); + std::vector binary; + tools.Assemble("OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary); + + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + opt.RegisterPass(CreateNullPass()); + opt.Run(binary.data(), binary.size(), &binary); // This is the key: the output vector is the same vector that supplies the input words.
+ + std::string disassembly; + tools.Disassemble(binary.data(), binary.size(), &disassembly); + EXPECT_THAT(disassembly, Eq("OpName %foo \"foo\"\n%foo = OpTypeVoid\n")); +} + +TEST(Optimizer, CanRunNullPassWithAliasedVectorDataButDifferentSize) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); + std::vector binary; + tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary); + + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + opt.RegisterPass(CreateNullPass()); + auto orig_size = binary.size(); + // Now change the size. Add a word that will be ignored + // by the optimizer, because only |orig_size| words are passed to Run(). + binary.push_back(42); + EXPECT_THAT(orig_size + 1, Eq(binary.size())); + opt.Run(binary.data(), orig_size, &binary); // This is the key: the output vector aliases the input data but is one word longer than the input size. + // The binary vector should have been rewritten, dropping the extra word. + EXPECT_THAT(binary.size(), Eq(orig_size)); + + std::string disassembly; + tools.Disassemble(binary.data(), binary.size(), &disassembly); + EXPECT_THAT(disassembly, + Eq(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid\n")); +} + +TEST(Optimizer, CanRunTransformingPassWithAliasedVectors) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); + std::vector binary; + tools.Assemble(Header() + "OpName %foo \"foo\"\n%foo = OpTypeVoid", &binary); + + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + opt.RegisterPass(CreateStripDebugInfoPass()); + opt.Run(binary.data(), binary.size(), &binary); // This is the key: input and output are the same vector. + + std::string disassembly; + tools.Disassemble(binary.data(), binary.size(), &disassembly); + EXPECT_THAT(disassembly, Eq(Header() + "%void = OpTypeVoid\n")); +} + +TEST(Optimizer, CanValidateFlags) { + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + EXPECT_FALSE(opt.FlagHasValidForm("bad-flag")); + EXPECT_TRUE(opt.FlagHasValidForm("-O")); + EXPECT_TRUE(opt.FlagHasValidForm("-Os")); + EXPECT_FALSE(opt.FlagHasValidForm("-O2")); + EXPECT_TRUE(opt.FlagHasValidForm("--this_flag")); +} + +TEST(Optimizer, CanRegisterPassesFromFlags) { + SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); + Optimizer opt(SPV_ENV_UNIVERSAL_1_0); + +
spv_message_level_t msg_level; + const char* msg_fname; + spv_position_t msg_position; + const char* msg; + auto examine_message = [&msg_level, &msg_fname, &msg_position, &msg]( + spv_message_level_t ml, const char* f, + const spv_position_t& p, const char* m) { + msg_level = ml; + msg_fname = f; + msg_position = p; + msg = m; + }; + opt.SetMessageConsumer(examine_message); + + std::vector pass_flags = { + "--strip-debug", + "--strip-reflect", + "--set-spec-const-default-value=23:42 21:12", + "--if-conversion", + "--freeze-spec-const", + "--inline-entry-points-exhaustive", + "--inline-entry-points-opaque", + "--convert-local-access-chains", + "--eliminate-dead-code-aggressive", + "--eliminate-insert-extract", + "--eliminate-local-single-block", + "--eliminate-local-single-store", + "--merge-blocks", + "--merge-return", + "--eliminate-dead-branches", + "--eliminate-dead-functions", + "--eliminate-local-multi-store", + "--eliminate-common-uniform", + "--eliminate-dead-const", + "--eliminate-dead-inserts", + "--eliminate-dead-variables", + "--fold-spec-const-op-composite", + "--loop-unswitch", + "--scalar-replacement=300", + "--scalar-replacement", + "--strength-reduction", + "--unify-const", + "--flatten-decorations", + "--compact-ids", + "--cfg-cleanup", + "--local-redundancy-elimination", + "--loop-invariant-code-motion", + "--reduce-load-size", + "--redundancy-elimination", + "--private-to-local", + "--remove-duplicates", + "--workaround-1209", + "--replace-invalid-opcode", + "--simplify-instructions", + "--ssa-rewrite", + "--copy-propagate-arrays", + "--loop-fission=20", + "--loop-fusion=2", + "--loop-unroll", + "--vector-dce", + "--loop-unroll-partial=3", + "--loop-peeling", + "--ccp", + "-O", + "-Os", + "--legalize-hlsl"}; + EXPECT_TRUE(opt.RegisterPassesFromFlags(pass_flags)); + + // Test some invalid flags. Each one must be rejected, and the message consumer installed above must have recorded SPV_MSG_ERROR. NOTE(review): msg_level is never reset between checks, so only the first check proves that a new error message was emitted.
+ EXPECT_FALSE(opt.RegisterPassFromFlag("-O2")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("-loop-unroll")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--set-spec-const-default-value")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--scalar-replacement=s")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--loop-fission=-4")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--loop-fusion=xx")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); + + EXPECT_FALSE(opt.RegisterPassFromFlag("--loop-unroll-partial")); + EXPECT_EQ(msg_level, SPV_MSG_ERROR); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/pass_fixture.h b/test/opt/pass_fixture.h new file mode 100644 index 000000000..117bc3126 --- /dev/null +++ b/test/opt/pass_fixture.h @@ -0,0 +1,249 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef TEST_OPT_PASS_FIXTURE_H_ +#define TEST_OPT_PASS_FIXTURE_H_ + +#include +#include +#include +#include +#include +#include + +#include "effcee/effcee.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/pass_manager.h" +#include "source/opt/passes.h" +#include "source/spirv_validator_options.h" +#include "source/util/make_unique.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { + +// Template class for testing passes. It contains some handy utility methods for +// running passes and checking results. Unless replaced through +// SetMessageConsumer(), the fixture's message consumer prints each message to +// std::cerr. +// +// To write value-Parameterized tests: +// using ValueParamTest = PassTest<::testing::TestWithParam>; +// To use as normal fixture: +// using FixtureTest = PassTest<::testing::Test>; +template +class PassTest : public TestT { + public: + PassTest() + : consumer_( + [](spv_message_level_t, const char*, const spv_position_t&, + const char* message) { std::cerr << message << std::endl; }), + context_(nullptr), + tools_(SPV_ENV_UNIVERSAL_1_3), + manager_(new PassManager()), + assemble_options_(SpirvTools::kDefaultAssembleOption), + disassemble_options_(SpirvTools::kDefaultDisassembleOption) {} + + // Runs the given |pass| on the binary assembled from the |original|. + // Returns a tuple of the optimized binary and the Pass::Status returned by + // running the pass. If assembling fails, returns an empty binary and + // Pass::Status::Failure.
+ std::tuple, Pass::Status> OptimizeToBinary( + Pass* pass, const std::string& original, bool skip_nop) { + context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_3, consumer_, original, + assemble_options_)); + EXPECT_NE(nullptr, context()) << "Assembling failed for shader:\n" + << original << std::endl; + if (!context()) { + return std::make_tuple(std::vector(), Pass::Status::Failure); + } + + const auto status = pass->Run(context()); + + std::vector binary; + context()->module()->ToBinary(&binary, skip_nop); + return std::make_tuple(binary, status); + } + + // Runs a single pass of class |PassT| on the binary assembled from the + // |assembly|. Returns a tuple of the optimized binary and the Pass::Status + // from running the pass. |args| are forwarded to the pass constructor. + template + std::tuple, Pass::Status> SinglePassRunToBinary( + const std::string& assembly, bool skip_nop, Args&&... args) { + auto pass = MakeUnique(std::forward(args)...); + pass->SetMessageConsumer(consumer_); + return OptimizeToBinary(pass.get(), assembly, skip_nop); + } + + // Runs a single pass of class |PassT| on the binary assembled from the + // |assembly|, disassembles the optimized binary. Returns a tuple of + // disassembly string and the Pass::Status from running the pass. When + // |do_validation| is true, the optimized binary is also validated. + template + std::tuple SinglePassRunAndDisassemble( + const std::string& assembly, bool skip_nop, bool do_validation, + Args&&...
args) { + std::vector optimized_bin; + auto status = Pass::Status::SuccessWithoutChange; + std::tie(optimized_bin, status) = SinglePassRunToBinary( + assembly, skip_nop, std::forward(args)...); + if (do_validation) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_3; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {optimized_bin.data(), optimized_bin.size()}; + spv_result_t error = spvValidateWithOptions( + spvContext, ValidatorOptions(), &binary, &diagnostic); + EXPECT_EQ(error, 0); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + } + std::string optimized_asm; + EXPECT_TRUE( + tools_.Disassemble(optimized_bin, &optimized_asm, disassemble_options_)) + << "Disassembling failed for shader:\n" + << assembly << std::endl; + return std::make_tuple(optimized_asm, status); + } + + // Runs a single pass of class |PassT| on the binary assembled from the + // |original| assembly, and checks whether the optimized binary can be + // disassembled to the |expected| assembly. Optionally will also validate + // the optimized binary. This does *not* involve pass manager. Callers + // are suggested to use SCOPED_TRACE() for better messages. + template + void SinglePassRunAndCheck(const std::string& original, + const std::string& expected, bool skip_nop, + bool do_validation, Args&&... args) { + std::vector optimized_bin; + auto status = Pass::Status::SuccessWithoutChange; + std::tie(optimized_bin, status) = SinglePassRunToBinary( + original, skip_nop, std::forward(args)...); + // Check whether the pass returns the correct modification indication: it must not fail, and it must report SuccessWithoutChange exactly when |original| equals |expected|.
+ EXPECT_NE(Pass::Status::Failure, status); + EXPECT_EQ(original == expected, + status == Pass::Status::SuccessWithoutChange); + if (do_validation) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_3; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {optimized_bin.data(), optimized_bin.size()}; + spv_result_t error = spvValidateWithOptions( + spvContext, ValidatorOptions(), &binary, &diagnostic); + EXPECT_EQ(error, 0); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + } + std::string optimized_asm; + EXPECT_TRUE( + tools_.Disassemble(optimized_bin, &optimized_asm, disassemble_options_)) + << "Disassembling failed for shader:\n" + << original << std::endl; + EXPECT_EQ(expected, optimized_asm); + } + + // Runs a single pass of class |PassT| on the binary assembled from the + // |original| assembly, and checks whether the optimized binary can be + // disassembled to the |expected| assembly. This does *not* involve pass + // manager. Callers are suggested to use SCOPED_TRACE() for better messages. + // This overload never validates the optimized binary. + template + void SinglePassRunAndCheck(const std::string& original, + const std::string& expected, bool skip_nop, + Args&&... args) { + SinglePassRunAndCheck(original, expected, skip_nop, false, + std::forward(args)...); + } + + // Runs a single pass of class |PassT| on the binary assembled from the + // |original| assembly, then runs an Effcee matcher over the disassembled + // result, using checks parsed from |original|. Always skips OpNop. + // This does *not* involve pass manager. Callers are suggested to use + // SCOPED_TRACE() for better messages. + template + void SinglePassRunAndMatch(const std::string& original, bool do_validation, + Args&&...
args) { + const bool skip_nop = true; + auto pass_result = SinglePassRunAndDisassemble( + original, skip_nop, do_validation, std::forward(args)...); + auto disassembly = std::get<0>(pass_result); + auto match_result = effcee::Match(disassembly, original); + EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) + << match_result.message() << "\nChecking result:\n" + << disassembly; + } + + // Adds a pass to be run by the pass manager; |args| are forwarded to it. + template + void AddPass(Args&&... args) { + manager_->AddPass(std::forward(args)...); + } + + // Renews the pass manager, including clearing all previously added passes. + // The new manager is given the fixture's current message consumer. + void RenewPassManger() { + manager_ = MakeUnique(); + manager_->SetMessageConsumer(consumer_); + } + + // Runs the passes added thus far using a pass manager on the binary assembled + // from the |original| assembly, and checks whether the optimized binary can + // be disassembled to the |expected| assembly. Callers are suggested to use + // SCOPED_TRACE() for better messages. At least one pass must have been + // added. OpNop instructions are kept when the module is serialized. + void RunAndCheck(const std::string& original, const std::string& expected) { + assert(manager_->NumPasses()); + + context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, original, + assemble_options_)); + ASSERT_NE(nullptr, context()); + + manager_->Run(context()); + + std::vector binary; + context()->module()->ToBinary(&binary, /* skip_nop = */ false); + + std::string optimized; + EXPECT_TRUE(tools_.Disassemble(binary, &optimized, disassemble_options_)); + EXPECT_EQ(expected, optimized); + } + + void SetAssembleOptions(uint32_t assemble_options) { + assemble_options_ = assemble_options; + } + + void SetDisassembleOptions(uint32_t disassemble_options) { + disassemble_options_ = disassemble_options; + } + + MessageConsumer consumer() { return consumer_; } + IRContext* context() { return context_.get(); } + + void SetMessageConsumer(MessageConsumer msg_consumer) { + consumer_ = msg_consumer; + } + + spv_validator_options ValidatorOptions() { return &validator_options_; } + + private: +
MessageConsumer consumer_; // Message consumer. + std::unique_ptr context_; // IR context built by the most recent run. + SpirvTools tools_; // An instance for calling SPIRV-Tools functionalities. + std::unique_ptr manager_; // The pass manager. + uint32_t assemble_options_; // Options passed to BuildModule() when assembling. + uint32_t disassemble_options_; // Options passed to tools_.Disassemble(). + spv_validator_options_t validator_options_; // Storage behind ValidatorOptions(). NOTE(review): not listed in the constructor's initializer list - confirm the type default-initializes itself. +}; + +} // namespace opt +} // namespace spvtools + +#endif // TEST_OPT_PASS_FIXTURE_H_ diff --git a/test/opt/pass_manager_test.cpp b/test/opt/pass_manager_test.cpp new file mode 100644 index 000000000..22d5e22ec --- /dev/null +++ b/test/opt/pass_manager_test.cpp @@ -0,0 +1,191 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/util/make_unique.h" +#include "test/opt/module_utils.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using spvtest::GetIdBound; +using ::testing::Eq; + +// A null pass whose constructors accept, and ignore, arguments of several +// different types. +class NullPassWithArgs : public NullPass { + public: + NullPassWithArgs(uint32_t) {} + NullPassWithArgs(std::string) {} + NullPassWithArgs(const std::vector&) {} + NullPassWithArgs(const std::vector&, uint32_t) {} + + const char* name() const override { return "null-with-args"; } +}; + +// Checks that passes can be added in several ways and that GetPass() returns +// them in the order in which they were added. +TEST(PassManager, Interface) { + PassManager manager; + EXPECT_EQ(0u, manager.NumPasses()); + + manager.AddPass(); + EXPECT_EQ(1u, manager.NumPasses()); + EXPECT_STREQ("strip-debug", manager.GetPass(0)->name()); + + // Add an already constructed pass. + manager.AddPass(MakeUnique()); + EXPECT_EQ(2u, manager.NumPasses()); + EXPECT_STREQ("strip-debug", manager.GetPass(0)->name()); + EXPECT_STREQ("null", manager.GetPass(1)->name()); + + manager.AddPass(); + EXPECT_EQ(3u, manager.NumPasses()); + EXPECT_STREQ("strip-debug", manager.GetPass(0)->name()); + EXPECT_STREQ("null", manager.GetPass(1)->name()); + EXPECT_STREQ("strip-debug", manager.GetPass(2)->name()); + + // Add four passes using different constructor argument lists; each one is + // expected to report the name "null-with-args". + manager.AddPass(1u); + manager.AddPass("null pass args"); + manager.AddPass(std::initializer_list{1, 2}); + manager.AddPass(std::initializer_list{1, 2}, 3); + EXPECT_EQ(7u, manager.NumPasses()); + EXPECT_STREQ("strip-debug", manager.GetPass(0)->name()); + EXPECT_STREQ("null", manager.GetPass(1)->name()); + EXPECT_STREQ("strip-debug", manager.GetPass(2)->name()); + EXPECT_STREQ("null-with-args", manager.GetPass(3)->name()); + EXPECT_STREQ("null-with-args", manager.GetPass(4)->name()); + EXPECT_STREQ("null-with-args", manager.GetPass(5)->name()); + EXPECT_STREQ("null-with-args", manager.GetPass(6)->name()); +} + +// A pass that appends an OpNop instruction to the debug1 section.
+class AppendOpNopPass : public Pass { + public: + const char* name() const override { return "AppendOpNop"; } + Status Process() override { + context()->AddDebug1Inst(MakeUnique(context())); + return Status::SuccessWithChange; + } +}; + +// A pass that appends the specified number of OpNop instructions to the debug1 +// section. It reports the same name, "AppendOpNop", as AppendOpNopPass. +class AppendMultipleOpNopPass : public Pass { + public: + explicit AppendMultipleOpNopPass(uint32_t num_nop) : num_nop_(num_nop) {} + + const char* name() const override { return "AppendOpNop"; } + Status Process() override { + for (uint32_t i = 0; i < num_nop_; i++) { + context()->AddDebug1Inst(MakeUnique(context())); + } + return Status::SuccessWithChange; + } + + private: + uint32_t num_nop_; // Number of instructions appended by Process(). +}; + +// A pass that duplicates the last instruction in the debug1 section. +// NOTE(review): decrementing debug1_end() presumes the section is not empty. +class DuplicateInstPass : public Pass { + public: + const char* name() const override { return "DuplicateInst"; } + Status Process() override { + auto inst = MakeUnique(*(--context()->debug1_end())); + context()->AddDebug1Inst(std::move(inst)); + return Status::SuccessWithChange; + } +}; + +using PassManagerTest = PassTest<::testing::Test>; + +// Runs several pass combinations through the fixture's pass manager and +// checks the disassembled result of each. +TEST_F(PassManagerTest, Run) { + const std::string text = "OpMemoryModel Logical GLSL450\nOpSource ESSL 310\n"; + + AddPass(); + AddPass(); + RunAndCheck(text, text + "OpNop\nOpNop\n"); + + // The pass manager is renewed before each later stage, so passes added by + // an earlier stage are not run again. + RenewPassManger(); + AddPass(); + AddPass(); + RunAndCheck(text, text + "OpNop\nOpNop\n"); + + RenewPassManger(); + AddPass(); + AddPass(); + RunAndCheck(text, text + "OpSource ESSL 310\nOpNop\n"); + + RenewPassManger(); + AddPass(3); + RunAndCheck(text, text + "OpNop\nOpNop\nOpNop\n"); +} + +// A pass that appends an OpTypeVoid instruction that uses a given id.
+class AppendTypeVoidInstPass : public Pass { + public: + explicit AppendTypeVoidInstPass(uint32_t result_id) : result_id_(result_id) {} + + const char* name() const override { return "AppendTypeVoidInstPass"; } + Status Process() override { + auto inst = MakeUnique(context(), SpvOpTypeVoid, 0, result_id_, + std::vector{}); + context()->AddType(std::move(inst)); + return Status::SuccessWithChange; + } + + private: + uint32_t result_id_; +}; + +TEST(PassManager, RecomputeIdBoundAutomatically) { + PassManager manager; + std::unique_ptr module(new Module()); + IRContext context(SPV_ENV_UNIVERSAL_1_2, std::move(module), + manager.consumer()); + EXPECT_THAT(GetIdBound(*context.module()), Eq(0u)); + + manager.Run(&context); + manager.AddPass(); + // With no ID changes, the ID bound does not change. + EXPECT_THAT(GetIdBound(*context.module()), Eq(0u)); + + // Now we force an Id of 100 to be used. Adding the pass alone leaves the + // bound untouched; it is only expected to change once the manager runs. + manager.AddPass(MakeUnique(100)); + EXPECT_THAT(GetIdBound(*context.module()), Eq(0u)); + manager.Run(&context); + // The Id bound has been updated automatically, even though the pass + // did not update it. + EXPECT_THAT(GetIdBound(*context.module()), Eq(101u)); + + // Try one more time! + manager.AddPass(MakeUnique(200)); + manager.Run(&context); + EXPECT_THAT(GetIdBound(*context.module()), Eq(201u)); + + // Add another pass, but which uses a lower Id. + manager.AddPass(MakeUnique(10)); + manager.Run(&context); + // The Id bound stays high. + EXPECT_THAT(GetIdBound(*context.module()), Eq(201u)); +} + +} // anonymous namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp new file mode 100644 index 000000000..9c8d80d5c --- /dev/null +++ b/test/opt/pass_merge_return_test.cpp @@ -0,0 +1,1373 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "spirv-tools/libspirv.hpp" +#include "spirv-tools/optimizer.hpp" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using MergeReturnPassTest = PassTest<::testing::Test>; + +TEST_F(MergeReturnPassTest, OneReturn) { + const std::string before = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %1 "simple_kernel" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 None %3 +%4 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = before; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(before, after, false, true); +} + +TEST_F(MergeReturnPassTest, TwoReturnsNoValue) { + const std::string before = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %6 "simple_kernel" +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpBranchConditional %4 %8 %9 +%8 = OpLabel +OpReturn +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %6 "simple_kernel" +%2 = 
OpTypeVoid +%3 = OpTypeBool +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpBranchConditional %4 %8 %9 +%8 = OpLabel +OpBranch %10 +%9 = OpLabel +OpBranch %10 +%10 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(before, after, false, true); +} + +TEST_F(MergeReturnPassTest, TwoReturnsWithValues) { + const std::string before = + R"(OpCapability Linkage +OpCapability Kernel +OpMemoryModel Logical OpenCL +OpDecorate %7 LinkageAttributes "simple_kernel" Export +%1 = OpTypeInt 32 0 +%2 = OpTypeBool +%3 = OpConstantFalse %2 +%4 = OpConstant %1 0 +%5 = OpConstant %1 1 +%6 = OpTypeFunction %1 +%7 = OpFunction %1 None %6 +%8 = OpLabel +OpBranchConditional %3 %9 %10 +%9 = OpLabel +OpReturnValue %4 +%10 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Linkage +OpCapability Kernel +OpMemoryModel Logical OpenCL +OpDecorate %7 LinkageAttributes "simple_kernel" Export +%1 = OpTypeInt 32 0 +%2 = OpTypeBool +%3 = OpConstantFalse %2 +%4 = OpConstant %1 0 +%5 = OpConstant %1 1 +%6 = OpTypeFunction %1 +%7 = OpFunction %1 None %6 +%8 = OpLabel +OpBranchConditional %3 %9 %10 +%9 = OpLabel +OpBranch %11 +%10 = OpLabel +OpBranch %11 +%11 = OpLabel +%12 = OpPhi %1 %4 %9 %5 %10 +OpReturnValue %12 +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(before, after, false, true); +} + +TEST_F(MergeReturnPassTest, UnreachableReturnsNoValue) { + const std::string before = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %6 "simple_kernel" +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = 
OpFunction %2 None %1 +%7 = OpLabel +OpReturn +%8 = OpLabel +OpBranchConditional %4 %9 %10 +%9 = OpLabel +OpReturn +%10 = OpLabel +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %6 "simple_kernel" +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpBranch %11 +%8 = OpLabel +OpBranchConditional %4 %9 %10 +%9 = OpLabel +OpBranch %11 +%10 = OpLabel +OpBranch %11 +%11 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(before, after, false, true); +} + +TEST_F(MergeReturnPassTest, UnreachableReturnsWithValues) { + const std::string before = + R"(OpCapability Linkage +OpCapability Kernel +OpMemoryModel Logical OpenCL +OpDecorate %7 LinkageAttributes "simple_kernel" Export +%1 = OpTypeInt 32 0 +%2 = OpTypeBool +%3 = OpConstantFalse %2 +%4 = OpConstant %1 0 +%5 = OpConstant %1 1 +%6 = OpTypeFunction %1 +%7 = OpFunction %1 None %6 +%8 = OpLabel +%9 = OpIAdd %1 %4 %5 +OpReturnValue %9 +%10 = OpLabel +OpBranchConditional %3 %11 %12 +%11 = OpLabel +OpReturnValue %4 +%12 = OpLabel +OpReturnValue %5 +OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Linkage +OpCapability Kernel +OpMemoryModel Logical OpenCL +OpDecorate %7 LinkageAttributes "simple_kernel" Export +%1 = OpTypeInt 32 0 +%2 = OpTypeBool +%3 = OpConstantFalse %2 +%4 = OpConstant %1 0 +%5 = OpConstant %1 1 +%6 = OpTypeFunction %1 +%7 = OpFunction %1 None %6 +%8 = OpLabel +%9 = OpIAdd %1 %4 %5 +OpBranch %13 +%10 = OpLabel +OpBranchConditional %3 %11 %12 +%11 = OpLabel +OpBranch %13 +%12 = OpLabel +OpBranch %13 +%13 = OpLabel +%14 = OpPhi %1 %9 %8 %4 %11 %5 %12 +OpReturnValue %14 +OpFunctionEnd +)"; + + 
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(before, after, false, true); +} + +TEST_F(MergeReturnPassTest, StructuredControlFlowWithUnreachableMerge) { + const std::string before = + R"( +; CHECK: [[false:%\w+]] = OpConstantFalse +; CHECK: [[true:%\w+]] = OpConstantTrue +; CHECK: OpFunction +; CHECK: [[var:%\w+]] = OpVariable [[:%\w+]] Function [[false]] +; CHECK: OpLoopMerge [[return_block:%\w+]] +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] +; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]] +; CHECK: [[if_lab]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[return_block]] +; CHECK: [[then_lab]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[return_block]] +; CHECK: [[merge_lab]] = OpLabel +; CHECK-NEXT: OpBranch [[return_block]] +; CHECK: [[return_block]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %6 "simple_shader" +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %4 %8 %9 +%8 = OpLabel +OpReturn +%9 = OpLabel +OpReturn +%10 = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(before, false); +} + +TEST_F(MergeReturnPassTest, StructuredControlFlowAddPhi) { + const std::string before = + R"( +; CHECK: [[false:%\w+]] = OpConstantFalse +; CHECK: [[true:%\w+]] = OpConstantTrue +; CHECK: OpFunction +; CHECK: [[var:%\w+]] = OpVariable [[:%\w+]] Function [[false]] +; CHECK: OpLoopMerge [[dummy_loop_merge:%\w+]] +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] +; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]] +; CHECK: [[if_lab]] = 
OpLabel +; CHECK-NEXT: [[add:%\w+]] = OpIAdd [[type:%\w+]] +; CHECK-NEXT: OpBranch +; CHECK: [[then_lab]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[dummy_loop_merge]] +; CHECK: [[merge_lab]] = OpLabel +; CHECK: [[dummy_loop_merge]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %6 "simple_shader" +%2 = OpTypeVoid +%3 = OpTypeBool +%int = OpTypeInt 32 0 +%int_0 = OpConstant %int 0 +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %4 %8 %9 +%8 = OpLabel +%11 = OpIAdd %int %int_0 %int_0 +OpBranch %10 +%9 = OpLabel +OpReturn +%10 = OpLabel +%12 = OpIAdd %int %11 %11 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(before, false); +} + +TEST_F(MergeReturnPassTest, StructuredControlDecoration) { + const std::string before = + R"( +; CHECK: OpDecorate [[dec_id:%\w+]] RelaxedPrecision +; CHECK: [[false:%\w+]] = OpConstantFalse +; CHECK: [[true:%\w+]] = OpConstantTrue +; CHECK: OpFunction +; CHECK: [[var:%\w+]] = OpVariable [[:%\w+]] Function [[false]] +; CHECK: OpLoopMerge [[return_block:%\w+]] +; CHECK: OpSelectionMerge [[merge_lab:%\w+]] +; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]] +; CHECK: [[if_lab]] = OpLabel +; CHECK-NEXT: [[dec_id]] = OpIAdd [[type:%\w+]] +; CHECK-NEXT: OpBranch +; CHECK: [[then_lab]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[return_block]] +; CHECK: [[merge_lab]] = OpLabel +; CHECK-NEXT: OpStore [[var]] [[true]] +; CHECK-NEXT: OpBranch [[return_block]] +; CHECK: [[return_block]] = OpLabel +; CHECK-NEXT: OpReturn +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %6 "simple_shader" +OpDecorate %11 
RelaxedPrecision +%2 = OpTypeVoid +%3 = OpTypeBool +%int = OpTypeInt 32 0 +%int_0 = OpConstant %int 0 +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %4 %8 %9 +%8 = OpLabel +%11 = OpIAdd %int %int_0 %int_0 +OpBranch %10 +%9 = OpLabel +OpReturn +%10 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, false); +} + +TEST_F(MergeReturnPassTest, SplitBlockUsedInPhi) { + const std::string before = + R"( +; CHECK: OpFunction +; CHECK: OpLoopMerge [[dummy_loop_merge:%\w+]] +; CHECK: OpLoopMerge [[loop_merge:%\w+]] +; CHECK: [[loop_merge]] = OpLabel +; CHECK: OpBranchConditional {{%\w+}} [[dummy_loop_merge]] [[old_code_path:%\w+]] +; CHECK: [[old_code_path:%\w+]] = OpLabel +; CHECK: OpBranchConditional {{%\w+}} [[side_node:%\w+]] [[phi_block:%\w+]] +; CHECK: [[phi_block]] = OpLabel +; CHECK-NEXT: OpPhi %bool %false [[side_node]] %true [[old_code_path]] + OpCapability Addresses + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "simple_shader" + %void = OpTypeVoid + %bool = OpTypeBool + %false = OpConstantFalse %bool + %true = OpConstantTrue %bool + %6 = OpTypeFunction %void + %1 = OpFunction %void None %6 + %7 = OpLabel + OpLoopMerge %merge %cont None + OpBranchConditional %false %9 %merge + %9 = OpLabel + OpReturn + %cont = OpLabel + OpBranch %7 + %merge = OpLabel + OpSelectionMerge %merge2 None + OpBranchConditional %false %if %merge2 + %if = OpLabel + OpBranch %merge2 + %merge2 = OpLabel + %12 = OpPhi %bool %false %if %true %merge + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, false); +} + +// TODO(#1861): Reenable these test when the breaks from selection constructs +// are reenabled. 
+/* +TEST_F(MergeReturnPassTest, UpdateOrderWhenPredicating) { + const std::string before = + R"( +; CHECK: OpFunction +; CHECK: OpFunction +; CHECK: OpSelectionMerge [[m1:%\w+]] None +; CHECK-NOT: OpReturn +; CHECK: [[m1]] = OpLabel +; CHECK: OpSelectionMerge [[m2:%\w+]] None +; CHECK: OpSelectionMerge [[m3:%\w+]] None +; CHECK: OpSelectionMerge [[m4:%\w+]] None +; CHECK: OpLabel +; CHECK-NEXT: OpStore +; CHECK-NEXT: OpBranch [[m4]] +; CHECK: [[m4]] = OpLabel +; CHECK-NEXT: [[ld4:%\w+]] = OpLoad %bool +; CHECK-NEXT: OpBranchConditional [[ld4]] [[m3]] +; CHECK: [[m3]] = OpLabel +; CHECK-NEXT: [[ld3:%\w+]] = OpLoad %bool +; CHECK-NEXT: OpBranchConditional [[ld3]] [[m2]] +; CHECK: [[m2]] = OpLabel + OpCapability SampledBuffer + OpCapability StorageImageExtendedFormats + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "PS_DebugTiles" + OpExecutionMode %1 OriginUpperLeft + OpSource HLSL 600 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %1 = OpFunction %void None %3 + %5 = OpLabel + %6 = OpFunctionCall %void %7 + OpReturn + OpFunctionEnd + %7 = OpFunction %void None %3 + %8 = OpLabel + %9 = OpUndef %bool + OpSelectionMerge %10 None + OpBranchConditional %9 %11 %10 + %11 = OpLabel + OpReturn + %10 = OpLabel + %12 = OpUndef %bool + OpSelectionMerge %13 None + OpBranchConditional %12 %14 %15 + %15 = OpLabel + %16 = OpUndef %bool + OpSelectionMerge %17 None + OpBranchConditional %16 %18 %17 + %18 = OpLabel + OpReturn + %17 = OpLabel + OpBranch %13 + %14 = OpLabel + OpReturn + %13 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(before, false); +} +*/ + +TEST_F(MergeReturnPassTest, StructuredControlFlowBothMergeAndHeader) { + const std::string test = + R"( +; CHECK: OpFunction +; CHECK: [[ret_flag:%\w+]] = OpVariable %_ptr_Function_bool Function %false +; CHECK: OpLoopMerge [[dummy_loop_merge:%\w+]] +; CHECK: OpLoopMerge [[loop1_merge:%\w+]] {{%\w+}} +; CHECK-NEXT: OpBranchConditional {{%\w+}} 
[[if_lab:%\w+]] {{%\w+}} +; CHECK: [[if_lab]] = OpLabel +; CHECK: OpStore [[ret_flag]] %true +; CHECK-NEXT: OpBranch [[loop1_merge]] +; CHECK: [[loop1_merge]] = OpLabel +; CHECK-NEXT: [[ld:%\w+]] = OpLoad %bool [[ret_flag]] +; CHECK-NOT: OpLabel +; CHECK: OpBranchConditional [[ld]] [[dummy_loop_merge]] [[empty_block:%\w+]] +; CHECK: [[empty_block]] = OpLabel +; CHECK-NEXT: OpBranch [[loop2:%\w+]] +; CHECK: [[loop2]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpLoopMerge + OpCapability Addresses + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "simple_shader" + %void = OpTypeVoid + %bool = OpTypeBool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %false = OpConstantFalse %bool + %7 = OpTypeFunction %void + %1 = OpFunction %void None %7 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %10 %11 None + OpBranchConditional %false %12 %13 + %12 = OpLabel + OpReturn + %13 = OpLabel + OpBranch %10 + %11 = OpLabel + OpBranch %9 + %10 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %15 + %15 = OpLabel + %16 = OpIAdd %uint %uint_0 %uint_0 + OpBranchConditional %false %10 %14 + %14 = OpLabel + %17 = OpIAdd %uint %16 %16 + OpReturn + OpFunctionEnd + +)"; + + const std::string after = + R"(OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "simple_shader" +%void = OpTypeVoid +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%false = OpConstantFalse %bool +%7 = OpTypeFunction %void +%_ptr_Function_bool = OpTypePointer Function %bool +%true = OpConstantTrue %bool +%1 = OpFunction %void None %7 +%8 = OpLabel +%18 = OpVariable %_ptr_Function_bool Function %false +OpSelectionMerge %9 None +OpBranchConditional %false %10 %11 +%10 = OpLabel +OpStore %18 %true +OpBranch %9 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +%23 = OpLoad %bool %18 +OpSelectionMerge %22 None +OpBranchConditional %23 %22 %21 +%21 = OpLabel 
+OpBranch %20 +%20 = OpLabel +OpLoopMerge %12 %13 None +OpBranch %13 +%13 = OpLabel +%14 = OpIAdd %uint %uint_0 %uint_0 +OpBranchConditional %false %20 %12 +%12 = OpLabel +%15 = OpIAdd %uint %14 %14 +OpStore %18 %true +OpBranch %22 +%22 = OpLabel +OpBranch %16 +%16 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(test, false); +} + +// TODO(#1861): Reenable these test when the breaks from selection constructs +// are reenabled. +/* +TEST_F(MergeReturnPassTest, NestedSelectionMerge) { + const std::string before = + R"( + OpCapability Addresses + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "simple_shader" + %void = OpTypeVoid + %bool = OpTypeBool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %false = OpConstantFalse %bool + %7 = OpTypeFunction %void + %1 = OpFunction %void None %7 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %false %10 %11 + %10 = OpLabel + OpReturn + %11 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %false %13 %14 + %13 = OpLabel + %15 = OpIAdd %uint %uint_0 %uint_0 + OpBranch %12 + %14 = OpLabel + OpReturn + %12 = OpLabel + OpBranch %9 + %9 = OpLabel + %16 = OpIAdd %uint %15 %15 + OpReturn + OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "simple_shader" +%void = OpTypeVoid +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%false = OpConstantFalse %bool +%7 = OpTypeFunction %void +%_ptr_Function_bool = OpTypePointer Function %bool +%true = OpConstantTrue %bool +%26 = OpUndef %uint +%1 = OpFunction %void None %7 +%8 = OpLabel +%19 = OpVariable %_ptr_Function_bool Function %false +OpSelectionMerge %9 None +OpBranchConditional %false %10 %11 +%10 = OpLabel +OpStore %19 %true +OpBranch %9 +%11 = OpLabel +OpSelectionMerge %12 None +OpBranchConditional %false %13 %14 +%13 = OpLabel 
+%15 = OpIAdd %uint %uint_0 %uint_0 +OpBranch %12 +%14 = OpLabel +OpStore %19 %true +OpBranch %12 +%12 = OpLabel +%27 = OpPhi %uint %15 %13 %26 %14 +%22 = OpLoad %bool %19 +OpBranchConditional %22 %9 %21 +%21 = OpLabel +OpBranch %9 +%9 = OpLabel +%28 = OpPhi %uint %27 %21 %26 %10 %26 %12 +%25 = OpLoad %bool %19 +OpSelectionMerge %24 None +OpBranchConditional %25 %24 %23 +%23 = OpLabel +%16 = OpIAdd %uint %28 %28 +OpStore %19 %true +OpBranch %24 +%24 = OpLabel +OpBranch %17 +%17 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(before, after, false, true); +} + +// This is essentially the same as NestedSelectionMerge, except +// the order of the first branch is changed. This is to make sure things +// work even if the order of the traversals change. +TEST_F(MergeReturnPassTest, NestedSelectionMerge2) { + const std::string before = + R"( OpCapability Addresses + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "simple_shader" + %void = OpTypeVoid + %bool = OpTypeBool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %false = OpConstantFalse %bool + %7 = OpTypeFunction %void + %1 = OpFunction %void None %7 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %false %10 %11 + %11 = OpLabel + OpReturn + %10 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %false %13 %14 + %13 = OpLabel + %15 = OpIAdd %uint %uint_0 %uint_0 + OpBranch %12 + %14 = OpLabel + OpReturn + %12 = OpLabel + OpBranch %9 + %9 = OpLabel + %16 = OpIAdd %uint %15 %15 + OpReturn + OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "simple_shader" +%void = OpTypeVoid +%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%false = OpConstantFalse %bool +%7 = OpTypeFunction %void 
+%_ptr_Function_bool = OpTypePointer Function %bool +%true = OpConstantTrue %bool +%26 = OpUndef %uint +%1 = OpFunction %void None %7 +%8 = OpLabel +%19 = OpVariable %_ptr_Function_bool Function %false +OpSelectionMerge %9 None +OpBranchConditional %false %10 %11 +%11 = OpLabel +OpStore %19 %true +OpBranch %9 +%10 = OpLabel +OpSelectionMerge %12 None +OpBranchConditional %false %13 %14 +%13 = OpLabel +%15 = OpIAdd %uint %uint_0 %uint_0 +OpBranch %12 +%14 = OpLabel +OpStore %19 %true +OpBranch %12 +%12 = OpLabel +%27 = OpPhi %uint %15 %13 %26 %14 +%25 = OpLoad %bool %19 +OpBranchConditional %25 %9 %24 +%24 = OpLabel +OpBranch %9 +%9 = OpLabel +%28 = OpPhi %uint %27 %24 %26 %11 %26 %12 +%23 = OpLoad %bool %19 +OpSelectionMerge %22 None +OpBranchConditional %23 %22 %21 +%21 = OpLabel +%16 = OpIAdd %uint %28 %28 +OpStore %19 %true +OpBranch %22 +%22 = OpLabel +OpBranch %17 +%17 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, false, true); +} + +TEST_F(MergeReturnPassTest, NestedSelectionMerge3) { + const std::string before = + R"( OpCapability Addresses + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "simple_shader" + %void = OpTypeVoid + %bool = OpTypeBool + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %false = OpConstantFalse %bool + %7 = OpTypeFunction %void + %1 = OpFunction %void None %7 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %false %10 %11 + %11 = OpLabel + OpReturn + %10 = OpLabel + %12 = OpIAdd %uint %uint_0 %uint_0 + OpSelectionMerge %13 None + OpBranchConditional %false %14 %15 + %14 = OpLabel + OpBranch %13 + %15 = OpLabel + OpReturn + %13 = OpLabel + OpBranch %9 + %9 = OpLabel + %16 = OpIAdd %uint %12 %12 + OpReturn + OpFunctionEnd +)"; + + const std::string after = + R"(OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "simple_shader" +%void = OpTypeVoid 
+%bool = OpTypeBool +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%false = OpConstantFalse %bool +%7 = OpTypeFunction %void +%_ptr_Function_bool = OpTypePointer Function %bool +%true = OpConstantTrue %bool +%26 = OpUndef %uint +%1 = OpFunction %void None %7 +%8 = OpLabel +%19 = OpVariable %_ptr_Function_bool Function %false +OpSelectionMerge %9 None +OpBranchConditional %false %10 %11 +%11 = OpLabel +OpStore %19 %true +OpBranch %9 +%10 = OpLabel +%12 = OpIAdd %uint %uint_0 %uint_0 +OpSelectionMerge %13 None +OpBranchConditional %false %14 %15 +%14 = OpLabel +OpBranch %13 +%15 = OpLabel +OpStore %19 %true +OpBranch %13 +%13 = OpLabel +%25 = OpLoad %bool %19 +OpBranchConditional %25 %9 %24 +%24 = OpLabel +OpBranch %9 +%9 = OpLabel +%27 = OpPhi %uint %12 %24 %26 %11 %26 %13 +%23 = OpLoad %bool %19 +OpSelectionMerge %22 None +OpBranchConditional %23 %22 %21 +%21 = OpLabel +%16 = OpIAdd %uint %27 %27 +OpStore %19 %true +OpBranch %22 +%22 = OpLabel +OpBranch %17 +%17 = OpLabel +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(before, after, false, true); +} +*/ + +TEST_F(MergeReturnPassTest, NestedLoopMerge) { + const std::string test = + R"( +; CHECK: OpFunction +; CHECK: OpLoopMerge [[dummy_loop_merge:%\w+]] +; CHECK: OpLoopMerge [[outer_loop_merge:%\w+]] +; CHECK: OpLoopMerge [[inner_loop_merge:%\w+]] +; CHECK: OpSelectionMerge +; CHECK-NEXT: OpBranchConditional %true [[early_exit_block:%\w+]] +; CHECK: [[early_exit_block]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpBranch [[inner_loop_merge]] +; CHECK: [[inner_loop_merge]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpBranchConditional {{%\w+}} [[outer_loop_merge]] +; CHECK: [[outer_loop_merge]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpBranchConditional {{%\w+}} [[dummy_loop_merge]] +; CHECK: [[dummy_loop_merge]] = OpLabel +; CHECK-NOT: OpLabel +; CHECK: OpReturn + OpCapability SampledBuffer + OpCapability StorageImageExtendedFormats + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + 
OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %2 "CS" + OpExecutionMode %2 LocalSize 8 8 1 + OpSource HLSL 600 + %uint = OpTypeInt 32 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %v3uint = OpTypeVector %uint 3 + %bool = OpTypeBool + %true = OpConstantTrue %bool +%_ptr_Function_uint = OpTypePointer Function %uint + %2 = OpFunction %void None %6 + %14 = OpLabel + OpBranch %19 + %19 = OpLabel + %20 = OpPhi %uint %uint_0 %2 %34 %23 + %21 = OpULessThan %bool %20 %uint_1 + OpLoopMerge %22 %23 DontUnroll + OpBranchConditional %21 %24 %22 + %24 = OpLabel + OpBranch %25 + %25 = OpLabel + %27 = OpINotEqual %bool %uint_1 %uint_0 + OpLoopMerge %28 %29 DontUnroll + OpBranchConditional %27 %30 %28 + %30 = OpLabel + OpSelectionMerge %31 None + OpBranchConditional %true %32 %31 + %32 = OpLabel + OpReturn + %31 = OpLabel + OpBranch %29 + %29 = OpLabel + OpBranch %25 + %28 = OpLabel + OpBranch %23 + %23 = OpLabel + %34 = OpIAdd %uint %20 %uint_1 + OpBranch %19 + %22 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(test, false); +} + +TEST_F(MergeReturnPassTest, ReturnValueDecoration) { + const std::string test = + R"( +; CHECK: OpDecorate [[func:%\w+]] RelaxedPrecision +; CHECK: OpDecorate [[ret_val:%\w+]] RelaxedPrecision +; CHECK: [[func]] = OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NOT: OpLabel +; CHECK: [[ret_val]] = OpVariable +OpCapability Linkage +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %11 "simple_shader" +OpDecorate %7 RelaxedPrecision +%12 = OpTypeVoid +%1 = OpTypeInt 32 0 +%2 = OpTypeBool +%3 = OpConstantFalse %2 +%4 = OpConstant %1 0 +%5 = OpConstant %1 1 +%6 = OpTypeFunction %1 +%13 = OpTypeFunction %12 +%11 = OpFunction %12 None %13 +%l1 = OpLabel +%fc = OpFunctionCall %1 %7 +OpReturn +OpFunctionEnd +%7 = OpFunction %1 None %6 +%8 = OpLabel +OpBranchConditional %3 %9 %10 +%9 = OpLabel +OpReturnValue %4 +%10 = OpLabel +OpReturnValue %5 
+OpFunctionEnd +)"; + + SinglePassRunAndMatch(test, false); +} + +TEST_F(MergeReturnPassTest, + StructuredControlFlowWithNonTrivialUnreachableMerge) { + const std::string before = + R"( +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %6 "simple_shader" +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpSelectionMerge %10 None +OpBranchConditional %4 %8 %9 +%8 = OpLabel +OpReturn +%9 = OpLabel +OpReturn +%10 = OpLabel +%11 = OpUndef %3 +OpUnreachable +OpFunctionEnd +)"; + + std::vector messages = { + {SPV_MSG_ERROR, nullptr, 0, 0, + "Module contains unreachable blocks during merge return. Run dead " + "branch elimination before merge return."}}; + SetMessageConsumer(GetTestMessageConsumer(messages)); + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + auto result = SinglePassRunToBinary(before, false); + EXPECT_EQ(Pass::Status::Failure, std::get<1>(result)); + EXPECT_TRUE(messages.empty()); +} + +TEST_F(MergeReturnPassTest, + StructuredControlFlowWithNonTrivialUnreachableContinue) { + const std::string before = + R"( +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %6 "simple_shader" +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpBranch %header +%header = OpLabel +OpLoopMerge %merge %continue None +OpBranchConditional %4 %8 %merge +%8 = OpLabel +OpReturn +%continue = OpLabel +%11 = OpUndef %3 +OpBranch %header +%merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::vector messages = { + {SPV_MSG_ERROR, nullptr, 0, 0, + "Module contains unreachable blocks during merge return. 
Run dead " + "branch elimination before merge return."}}; + SetMessageConsumer(GetTestMessageConsumer(messages)); + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + auto result = SinglePassRunToBinary(before, false); + EXPECT_EQ(Pass::Status::Failure, std::get<1>(result)); + EXPECT_TRUE(messages.empty()); +} + +TEST_F(MergeReturnPassTest, StructuredControlFlowWithUnreachableBlock) { + const std::string before = + R"( +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %6 "simple_shader" +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpConstantFalse %3 +%1 = OpTypeFunction %2 +%6 = OpFunction %2 None %1 +%7 = OpLabel +OpBranch %header +%header = OpLabel +OpLoopMerge %merge %continue None +OpBranchConditional %4 %8 %merge +%8 = OpLabel +OpReturn +%continue = OpLabel +OpBranch %header +%merge = OpLabel +OpReturn +%unreachable = OpLabel +OpUnreachable +OpFunctionEnd +)"; + + std::vector messages = { + {SPV_MSG_ERROR, nullptr, 0, 0, + "Module contains unreachable blocks during merge return. 
Run dead " + "branch elimination before merge return."}}; + SetMessageConsumer(GetTestMessageConsumer(messages)); + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + auto result = SinglePassRunToBinary(before, false); + EXPECT_EQ(Pass::Status::Failure, std::get<1>(result)); + EXPECT_TRUE(messages.empty()); +} + +TEST_F(MergeReturnPassTest, StructuredControlFlowDontChangeEntryPhi) { + const std::string before = + R"( +; CHECK: OpFunction %void +; CHECK: OpLabel +; CHECK: OpLabel +; CHECK: [[pre_header:%\w+]] = OpLabel +; CHECK: [[header:%\w+]] = OpLabel +; CHECK-NEXT: OpPhi %bool {{%\w+}} [[pre_header]] [[iv:%\w+]] [[continue:%\w+]] +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] +; CHECK: [[continue]] = OpLabel +; CHECK-NEXT: [[iv]] = Op +; CHECK: [[merge]] = OpLabel + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %4 = OpTypeFunction %void + %1 = OpFunction %void None %4 + %5 = OpLabel + %6 = OpUndef %bool + OpBranch %7 + %7 = OpLabel + %8 = OpPhi %bool %6 %5 %9 %10 + OpLoopMerge %11 %10 None + OpBranch %12 + %12 = OpLabel + %13 = OpUndef %bool + OpSelectionMerge %10 DontFlatten + OpBranchConditional %13 %10 %14 + %14 = OpLabel + OpReturn + %10 = OpLabel + %9 = OpUndef %bool + OpBranchConditional %13 %7 %11 + %11 = OpLabel + OpReturn + OpFunctionEnd + +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(before, false); +} + +TEST_F(MergeReturnPassTest, StructuredControlFlowPartialReplacePhi) { + const std::string before = + R"( +; CHECK: OpFunction %void +; CHECK: OpLabel +; CHECK: OpLabel +; CHECK: [[pre_header:%\w+]] = OpLabel +; CHECK: [[header:%\w+]] = OpLabel +; CHECK-NEXT: OpPhi +; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] +; CHECK: OpLabel +; CHECK: [[old_ret_block:%\w+]] = OpLabel +; CHECK: [[bb:%\w+]] = OpLabel +; CHECK-NEXT: [[val:%\w+]] = OpUndef %bool +; CHECK: [[merge]] = OpLabel +; CHECK-NEXT: 
[[phi1:%\w+]] = OpPhi %bool [[val]] [[bb]] {{%\w+}} [[old_ret_block]] +; CHECK: OpBranchConditional {{%\w+}} {{%\w+}} [[bb2:%\w+]] +; CHECK: [[bb2]] = OpLabel +; CHECK: OpBranch [[header2:%\w+]] +; CHECK: [[header2]] = OpLabel +; CHECK-NEXT: [[phi2:%\w+]] = OpPhi %bool [[phi1]] [[continue2:%\w+]] [[phi1]] [[bb2]] +; CHECK-NEXT: OpLoopMerge {{%\w+}} [[continue2]] +; CHECK: [[continue2]] = OpLabel +; CHECK-NEXT: OpBranch [[header2]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %void = OpTypeVoid + %bool = OpTypeBool + %4 = OpTypeFunction %void + %1 = OpFunction %void None %4 + %5 = OpLabel + %6 = OpUndef %bool + OpBranch %7 + %7 = OpLabel + %8 = OpPhi %bool %6 %5 %9 %10 + OpLoopMerge %11 %10 None + OpBranch %12 + %12 = OpLabel + %13 = OpUndef %bool + OpSelectionMerge %10 DontFlatten + OpBranchConditional %13 %10 %14 + %14 = OpLabel + OpReturn + %10 = OpLabel + %9 = OpUndef %bool + OpBranchConditional %13 %7 %11 + %11 = OpLabel + %phi = OpPhi %bool %9 %10 %9 %cont + OpLoopMerge %ret %cont None + OpBranch %bb + %bb = OpLabel + OpBranchConditional %13 %ret %cont + %cont = OpLabel + OpBranch %11 + %ret = OpLabel + OpReturn + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(before, false); +} + +TEST_F(MergeReturnPassTest, GeneratePhiInOuterLoop) { + const std::string before = + R"( + ; CHECK: OpLoopMerge + ; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]] + ; CHECK: [[continue]] = OpLabel + ; CHECK-NEXT: [[undef:%\w+]] = OpUndef + ; CHECK: [[merge]] = OpLabel + ; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool [[undef]] [[continue]] {{%\w+}} {{%\w+}} + ; CHECK: OpCopyObject %bool [[phi]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %8 = OpTypeFunction %bool + 
%false = OpConstantFalse %bool + %4 = OpFunction %void None %3 + %5 = OpLabel + %63 = OpFunctionCall %bool %9 + OpReturn + OpFunctionEnd + %9 = OpFunction %bool None %8 + %10 = OpLabel + OpBranch %31 + %31 = OpLabel + OpLoopMerge %33 %34 None + OpBranch %32 + %32 = OpLabel + OpSelectionMerge %34 None + OpBranchConditional %false %46 %34 + %46 = OpLabel + OpLoopMerge %51 %52 None + OpBranch %53 + %53 = OpLabel + OpBranchConditional %false %50 %51 + %50 = OpLabel + OpReturnValue %false + %52 = OpLabel + OpBranch %46 + %51 = OpLabel + OpBranch %34 + %34 = OpLabel + %64 = OpUndef %bool + OpBranchConditional %false %31 %33 + %33 = OpLabel + OpBranch %28 + %28 = OpLabel + %60 = OpCopyObject %bool %64 + OpBranch %17 + %17 = OpLabel + OpReturnValue %false + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(before, false); +} + +TEST_F(MergeReturnPassTest, InnerLoopMergeIsOuterLoopContinue) { + const std::string before = + R"( + ; CHECK: OpLoopMerge + ; CHECK-NEXT: OpBranch [[bb1:%\w+]] + ; CHECK: [[bb1]] = OpLabel + ; CHECK-NEXT: OpBranch [[outer_loop_header:%\w+]] + ; CHECK: [[outer_loop_header]] = OpLabel + ; CHECK-NEXT: OpLoopMerge [[outer_loop_merge:%\w+]] [[outer_loop_continue:%\w+]] None + ; CHECK: [[outer_loop_continue]] = OpLabel + ; CHECK-NEXT: OpBranch [[outer_loop_header]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %bool = OpTypeBool + %6 = OpTypeFunction %bool + %true = OpConstantTrue %bool + %2 = OpFunction %void None %4 + %8 = OpLabel + %9 = OpFunctionCall %bool %10 + OpReturn + OpFunctionEnd + %10 = OpFunction %bool None %6 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + OpLoopMerge %13 %14 None + OpBranchConditional %true %15 %13 + %15 = OpLabel + OpLoopMerge %14 %16 None + OpBranchConditional %true %17 
%14 + %17 = OpLabel + OpReturnValue %true + %16 = OpLabel + OpBranch %15 + %14 = OpLabel + OpBranch %12 + %13 = OpLabel + OpReturnValue %true + OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(before, false); +} +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/pass_remove_duplicates_test.cpp b/test/opt/pass_remove_duplicates_test.cpp new file mode 100644 index 000000000..887fdfdb4 --- /dev/null +++ b/test/opt/pass_remove_duplicates_test.cpp @@ -0,0 +1,646 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass_manager.h" +#include "source/opt/remove_duplicates_pass.h" +#include "source/spirv_constant.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace opt { +namespace { + +class RemoveDuplicatesTest : public ::testing::Test { + public: + RemoveDuplicatesTest() + : tools_(SPV_ENV_UNIVERSAL_1_2), + context_(), + consumer_([this](spv_message_level_t level, const char*, + const spv_position_t& position, const char* message) { + if (!error_message_.empty()) error_message_ += "\n"; + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + error_message_ += "ERROR"; + break; + case SPV_MSG_WARNING: + error_message_ += "WARNING"; + break; + case SPV_MSG_INFO: + error_message_ += "INFO"; + break; + case SPV_MSG_DEBUG: + error_message_ += "DEBUG"; + break; + } + error_message_ += + ": " + std::to_string(position.index) + ": " + message; + }), + disassemble_options_(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER), + error_message_() { + tools_.SetMessageConsumer(consumer_); + } + + void TearDown() override { error_message_.clear(); } + + std::string RunPass(const std::string& text) { + context_ = spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text); + if (!context_.get()) return std::string(); + + PassManager manager; + manager.SetMessageConsumer(consumer_); + manager.AddPass(); + + Pass::Status pass_res = manager.Run(context_.get()); + if (pass_res == Pass::Status::Failure) return std::string(); + + return ModuleToText(); + } + + // Disassembles |binary| and outputs the result in |text|. If |text| is a + // null pointer, SPV_ERROR_INVALID_POINTER is returned. + spv_result_t Disassemble(const std::vector& binary, + std::string* text) { + if (!text) return SPV_ERROR_INVALID_POINTER; + return tools_.Disassemble(binary, text, disassemble_options_) + ? 
SPV_SUCCESS + : SPV_ERROR_INVALID_BINARY; + } + + // Returns the accumulated error messages for the test. + std::string GetErrorMessage() const { return error_message_; } + + std::string ToText(const std::vector& inst) { + std::vector binary = {SpvMagicNumber, 0x10200, 0u, 2u, 0u}; + for (const Instruction* i : inst) + i->ToBinaryWithoutAttachedDebugInsts(&binary); + std::string text; + Disassemble(binary, &text); + return text; + } + + std::string ModuleToText() { + std::vector binary; + context_->module()->ToBinary(&binary, false); + std::string text; + Disassemble(binary, &text); + return text; + } + + private: + spvtools::SpirvTools + tools_; // An instance for calling SPIRV-Tools functionalities. + std::unique_ptr context_; + spvtools::MessageConsumer consumer_; + uint32_t disassemble_options_; + std::string error_message_; +}; + +TEST_F(RemoveDuplicatesTest, DuplicateCapabilities) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Shader +OpMemoryModel Logical GLSL450 +)"; + const std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; + + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +TEST_F(RemoveDuplicatesTest, DuplicateExtInstImports) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "OpenCL.std" +%2 = OpExtInstImport "OpenCL.std" +%3 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +)"; + const std::string after = R"(OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "OpenCL.std" +%3 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +)"; + + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +TEST_F(RemoveDuplicatesTest, DuplicateTypes) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %1 %2 +)"; + const 
std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%3 = OpTypeStruct %1 %1 +)"; + + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +TEST_F(RemoveDuplicatesTest, SameTypeDifferentMemberDecoration) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 GLSLPacked +%2 = OpTypeInt 32 0 +%1 = OpTypeStruct %2 %2 +%3 = OpTypeStruct %2 %2 +)"; + const std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 GLSLPacked +%2 = OpTypeInt 32 0 +%1 = OpTypeStruct %2 %2 +%3 = OpTypeStruct %2 %2 +)"; + + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +TEST_F(RemoveDuplicatesTest, SameTypeAndMemberDecoration) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 GLSLPacked +OpDecorate %2 GLSLPacked +%3 = OpTypeInt 32 0 +%1 = OpTypeStruct %3 %3 +%2 = OpTypeStruct %3 %3 +)"; + const std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 GLSLPacked +%3 = OpTypeInt 32 0 +%1 = OpTypeStruct %3 %3 +)"; + + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +TEST_F(RemoveDuplicatesTest, SameTypeAndDifferentName) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %1 "Type1" +OpName %2 "Type2" +%3 = OpTypeInt 32 0 +%1 = OpTypeStruct %3 %3 +%2 = OpTypeStruct %3 %3 +)"; + const std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %1 "Type1" +%3 = OpTypeInt 32 0 +%1 = OpTypeStruct %3 %3 +)"; + + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Check that #1033 has been fixed. 
+TEST_F(RemoveDuplicatesTest, DoNotRemoveDifferentOpDecorationGroup) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +%1 = OpDecorationGroup +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %3 %1 %2 +%4 = OpTypeInt 32 0 +%3 = OpVariable %4 Uniform +)"; + const std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +%1 = OpDecorationGroup +OpDecorate %2 Restrict +%2 = OpDecorationGroup +OpGroupDecorate %3 %1 %2 +%4 = OpTypeInt 32 0 +%3 = OpVariable %4 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +TEST_F(RemoveDuplicatesTest, DifferentDecorationGroup) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %1 Restrict +%1 = OpDecorationGroup +OpDecorate %2 Constant +%2 = OpDecorationGroup +OpGroupDecorate %1 %3 +OpGroupDecorate %2 %4 +%5 = OpTypeInt 32 0 +%3 = OpVariable %5 Uniform +%4 = OpVariable %5 Uniform +)"; + const std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %1 Constant +OpDecorate %1 Restrict +%1 = OpDecorationGroup +OpDecorate %2 Constant +%2 = OpDecorationGroup +OpGroupDecorate %1 %3 +OpGroupDecorate %2 %4 +%5 = OpTypeInt 32 0 +%3 = OpVariable %5 Uniform +%4 = OpVariable %5 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Test what happens when a type is a resource type. For now we are merging +// them, but, if we want to merge types and make reflection work (issue #1372), +// we will not be able to merge %2 and %3 below. 
+TEST_F(RemoveDuplicatesTest, DontMergeNestedResourceTypes) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpMemberName %3 0 "AdjustXYZ" +OpMemberName %3 1 "AdjustDir" +OpName %4 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpMemberDecorate %3 0 Offset 0 +OpMemberDecorate %3 1 Offset 16 +OpDecorate %3 Block +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%3 = OpTypeStruct %1 %2 +%7 = OpTypePointer Uniform %3 +%4 = OpVariable %7 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpMemberName %3 0 "AdjustXYZ" +OpMemberName %3 1 "AdjustDir" +OpName %4 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %3 0 Offset 0 +OpMemberDecorate %3 1 Offset 16 +OpDecorate %3 Block +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%3 = OpTypeStruct %1 %1 +%7 = OpTypePointer Uniform %3 +%4 = OpVariable %7 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// See comment for DontMergeNestedResourceTypes. 
+TEST_F(RemoveDuplicatesTest, DontMergeResourceTypes) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%7 = OpTypePointer Uniform %1 +%8 = OpTypePointer Uniform %2 +%3 = OpVariable %7 Uniform +%4 = OpVariable %8 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%7 = OpTypePointer Uniform %1 +%3 = OpVariable %7 Uniform +%4 = OpVariable %7 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// See comment for DontMergeNestedResourceTypes. 
+TEST_F(RemoveDuplicatesTest, DontMergeResourceTypesContainingArray) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 4 +%9 = OpTypeArray %1 %8 +%10 = OpTypeArray %2 %8 +%11 = OpTypePointer Uniform %9 +%12 = OpTypePointer Uniform %10 +%3 = OpVariable %11 Uniform +%4 = OpVariable %12 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 4 +%9 = OpTypeArray %1 %8 +%11 = OpTypePointer Uniform %9 +%3 = OpVariable %11 Uniform +%4 = OpVariable %11 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Test that we merge the type of a resource with a type that is not the type +// of a resource. The resource type appears first in this case. We must keep +// the resource type.
+TEST_F(RemoveDuplicatesTest, MergeResourceTypeWithNonresourceType1) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%2 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%7 = OpTypePointer Uniform %2 +%3 = OpVariable %6 Uniform +%8 = OpVariable %7 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%3 = OpVariable %6 Uniform +%8 = OpVariable %6 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Test that we merge the type of a resource with a type that is not the type +// of a resource. The resource type appears second in this case. We must keep +// the resource type. +// +// See comment for DontMergeNestedResourceTypes.
+TEST_F(RemoveDuplicatesTest, MergeResourceTypeWithNonresourceType2) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%2 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%7 = OpTypePointer Uniform %2 +%8 = OpVariable %6 Uniform +%3 = OpVariable %7 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%8 = OpVariable %6 Uniform +%3 = OpVariable %6 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// In this test, %8 and %9 are the same and only %9 is used in a resource. +// However, we cannot merge them unless we also merge %2 and %3, which cannot +// happen because both are used in resources. +// +// If we try to avoid replacing resource types, then remove duplicates should +// make no change in this case. That is not currently implemented.
+TEST_F(RemoveDuplicatesTest, MergeResourceTypeWithNonresourceType3) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource HLSL 600 +OpName %2 "PositionAdjust" +OpMemberName %2 0 "XAdjust" +OpName %3 "NormalAdjust" +OpMemberName %3 0 "XDir" +OpName %4 "Constants" +OpMemberDecorate %2 0 Offset 0 +OpMemberDecorate %3 0 Offset 0 +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +OpDecorate %5 DescriptorSet 1 +OpDecorate %5 Binding 0 +%6 = OpTypeFloat 32 +%7 = OpTypeVector %6 3 +%2 = OpTypeStruct %7 +%3 = OpTypeStruct %7 +%8 = OpTypePointer Uniform %3 +%9 = OpTypePointer Uniform %2 +%10 = OpTypeStruct %3 +%11 = OpTypePointer Uniform %10 +%5 = OpVariable %9 Uniform +%4 = OpVariable %11 Uniform +%12 = OpTypeVoid +%13 = OpTypeFunction %12 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 0 +%1 = OpFunction %12 None %13 +%16 = OpLabel +%17 = OpAccessChain %8 %4 %15 +OpReturn +OpFunctionEnd +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpSource HLSL 600 +OpName %2 "PositionAdjust" +OpMemberName %2 0 "XAdjust" +OpName %4 "Constants" +OpMemberDecorate %2 0 Offset 0 +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +OpDecorate %5 DescriptorSet 1 +OpDecorate %5 Binding 0 +%6 = OpTypeFloat 32 +%7 = OpTypeVector %6 3 +%2 = OpTypeStruct %7 +%8 = OpTypePointer Uniform %2 +%10 = OpTypeStruct %2 +%11 = OpTypePointer Uniform %10 +%5 = OpVariable %8 Uniform +%4 = OpVariable %11 Uniform +%12 = OpTypeVoid +%13 = OpTypeFunction %12 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 0 +%1 = OpFunction %12 None %13 +%16 = OpLabel +%17 = OpAccessChain %8 %4 %15 +OpReturn +OpFunctionEnd +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/pass_utils.cpp b/test/opt/pass_utils.cpp new file mode 100644 index 000000000..4709d0fd1 --- 
/dev/null +++ b/test/opt/pass_utils.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/opt/pass_utils.h" + +#include +#include + +namespace spvtools { +namespace opt { +namespace { + +// Well, this is another place requiring the knowledge of the grammar and can be +// stale when SPIR-V is updated. It would be nice to automatically generate +// this, but the cost is just too high. + +const char* kDebugOpcodes[] = { + // clang-format off + "OpSourceContinued", "OpSource", "OpSourceExtension", + "OpName", "OpMemberName", "OpString", + "OpLine", "OpNoLine", "OpModuleProcessed" + // clang-format on +}; + +} // anonymous namespace + +MessageConsumer GetTestMessageConsumer( + std::vector& expected_messages) { + return [&expected_messages](spv_message_level_t level, const char* source, + const spv_position_t& position, + const char* message) { + EXPECT_TRUE(!expected_messages.empty()); + if (expected_messages.empty()) { + return; + } + + EXPECT_EQ(expected_messages[0].level, level); + EXPECT_EQ(expected_messages[0].line_number, position.line); + EXPECT_EQ(expected_messages[0].column_number, position.column); + EXPECT_STREQ(expected_messages[0].source_file, source); + EXPECT_STREQ(expected_messages[0].message, message); + + expected_messages.erase(expected_messages.begin()); + }; +} + +bool FindAndReplace(std::string* process_str, const std::string find_str, + const std::string replace_str) { + 
if (process_str->empty() || find_str.empty()) { + return false; + } + bool replaced = false; + // Note this algorithm has quadratic time complexity. It is OK for test cases + // with short strings, but might not fit in other contexts. + for (size_t pos = process_str->find(find_str, 0); pos != std::string::npos; + pos = process_str->find(find_str, pos)) { + process_str->replace(pos, find_str.length(), replace_str); + pos += replace_str.length(); + replaced = true; + } + return replaced; +} + +bool ContainsDebugOpcode(const char* inst) { + return std::any_of(std::begin(kDebugOpcodes), std::end(kDebugOpcodes), + [inst](const char* op) { + return std::string(inst).find(op) != std::string::npos; + }); +} + +std::string SelectiveJoin(const std::vector& strings, + const std::function& skip_dictator, + char delimiter) { + std::ostringstream oss; + for (const auto* str : strings) { + if (!skip_dictator(str)) oss << str << delimiter; + } + return oss.str(); +} + +std::string JoinAllInsts(const std::vector& insts) { + return SelectiveJoin(insts, [](const char*) { return false; }); +} + +std::string JoinNonDebugInsts(const std::vector& insts) { + return SelectiveJoin( + insts, [](const char* inst) { return ContainsDebugOpcode(inst); }); +} + +} // namespace opt +} // namespace spvtools diff --git a/test/opt/pass_utils.h b/test/opt/pass_utils.h new file mode 100644 index 000000000..8968f8a64 --- /dev/null +++ b/test/opt/pass_utils.h @@ -0,0 +1,84 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TEST_OPT_PASS_UTILS_H_ +#define TEST_OPT_PASS_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "include/spirv-tools/libspirv.h" +#include "include/spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { + +struct Message { + spv_message_level_t level; + const char* source_file; + uint32_t line_number; + uint32_t column_number; + const char* message; +}; + +// Return a message consumer that can be used to check that the messages produced +// are the messages in |expected_messages|, and in the same order. +MessageConsumer GetTestMessageConsumer(std::vector& expected_messages); + +// In-place substring replacement. Finds the |find_str| in the |process_str| +// and replaces the found substring with |replace_str|. Returns true if at +// least one replacement is done successfully, returns false otherwise. The +// replaced substring won't be processed again, which means: If the +// |replace_str| has |find_str| as its substring, that newly replaced part of +// |process_str| won't be processed again. +bool FindAndReplace(std::string* process_str, const std::string find_str, + const std::string replace_str); + +// Returns true if the given string contains any debug opcode substring. +bool ContainsDebugOpcode(const char* inst); + +// Returns the concatenated string from a vector of |strings|, with postfixing +// each string with the given |delimiter|. If the |skip_dictator| returns true +// for an original string, that string will be omitted. +std::string SelectiveJoin(const std::vector& strings, + const std::function& skip_dictator, + char delimiter = '\n'); + +// Concatenates a vector of strings into one string. Each string is postfixed +// with '\n'.
If a string contains opcode for debug instruction, that string +// will be ignored. +std::string JoinNonDebugInsts(const std::vector& insts); + +// Returns a vector that contains the contents of |a| followed by the contents +// of |b|. +template +std::vector Concat(const std::vector& a, const std::vector& b) { + std::vector ret; + std::copy(a.begin(), a.end(), back_inserter(ret)); + std::copy(b.begin(), b.end(), back_inserter(ret)); + return ret; +} + +} // namespace opt +} // namespace spvtools + +#endif // TEST_OPT_PASS_UTILS_H_ diff --git a/test/opt/pch_test_opt.cpp b/test/opt/pch_test_opt.cpp new file mode 100644 index 000000000..f15812913 --- /dev/null +++ b/test/opt/pch_test_opt.cpp @@ -0,0 +1,15 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "pch_test_opt.h" diff --git a/test/opt/pch_test_opt.h b/test/opt/pch_test_opt.h new file mode 100644 index 000000000..4e8106fbf --- /dev/null +++ b/test/opt/pch_test_opt.h @@ -0,0 +1,25 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_dependence.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" diff --git a/test/opt/private_to_local_test.cpp b/test/opt/private_to_local_test.cpp new file mode 100644 index 000000000..3ec74fae6 --- /dev/null +++ b/test/opt/private_to_local_test.cpp @@ -0,0 +1,313 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/opt/build_module.h" +#include "source/opt/value_number_table.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; +using PrivateToLocalTest = PassTest<::testing::Test>; + +TEST_F(PrivateToLocalTest, ChangeToLocal) { + // Change the private variable to a local, and change the types accordingly. + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 +; CHECK: [[float:%[a-zA-Z_\d]+]] = OpTypeFloat 32 + %5 = OpTypeFloat 32 +; CHECK: [[newtype:%[a-zA-Z_\d]+]] = OpTypePointer Function [[float]] + %6 = OpTypePointer Private %5 +; CHECK-NOT: OpVariable [[.+]] Private + %8 = OpVariable %6 Private +; CHECK: OpFunction + %2 = OpFunction %3 None %4 +; CHECK: OpLabel + %7 = OpLabel +; CHECK-NEXT: [[newvar:%[a-zA-Z_\d]+]] = OpVariable [[newtype]] Function +; CHECK: OpLoad [[float]] [[newvar]] + %9 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +TEST_F(PrivateToLocalTest, ReuseExistingType) { + // Change the private variable to a local, and change the types accordingly. 
+ const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 +; CHECK: [[float:%[a-zA-Z_\d]+]] = OpTypeFloat 32 + %5 = OpTypeFloat 32 + %func_ptr = OpTypePointer Function %5 +; CHECK: [[newtype:%[a-zA-Z_\d]+]] = OpTypePointer Function [[float]] +; CHECK-NOT: [[%[a-zA-Z_\d]+]] = OpTypePointer Function [[float]] + %6 = OpTypePointer Private %5 +; CHECK-NOT: OpVariable [[.+]] Private + %8 = OpVariable %6 Private +; CHECK: OpFunction + %2 = OpFunction %3 None %4 +; CHECK: OpLabel + %7 = OpLabel +; CHECK-NEXT: [[newvar:%[a-zA-Z_\d]+]] = OpVariable [[newtype]] Function +; CHECK: OpLoad [[float]] [[newvar]] + %9 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +TEST_F(PrivateToLocalTest, UpdateAccessChain) { + // Change the private variable to a local, and change the AccessChain. 
+ const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void +; CHECK: [[float:%[a-zA-Z_\d]+]] = OpTypeFloat + %float = OpTypeFloat 32 +; CHECK: [[struct:%[a-zA-Z_\d]+]] = OpTypeStruct + %_struct_8 = OpTypeStruct %float +%_ptr_Private_float = OpTypePointer Private %float +; CHECK: [[new_struct_type:%[a-zA-Z_\d]+]] = OpTypePointer Function [[struct]] +; CHECK: [[new_float_type:%[a-zA-Z_\d]+]] = OpTypePointer Function [[float]] +%_ptr_Private__struct_8 = OpTypePointer Private %_struct_8 +; CHECK-NOT: OpVariable [[.+]] Private + %11 = OpVariable %_ptr_Private__struct_8 Private +; CHECK: OpFunction + %2 = OpFunction %void None %6 +; CHECK: OpLabel + %12 = OpLabel +; CHECK-NEXT: [[newvar:%[a-zA-Z_\d]+]] = OpVariable [[new_struct_type]] Function +; CHECK: [[member:%[a-zA-Z_\d]+]] = OpAccessChain [[new_float_type]] [[newvar]] + %13 = OpAccessChain %_ptr_Private_float %11 %uint_0 +; CHECK: OpLoad [[float]] [[member]] + %14 = OpLoad %float %13 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +TEST_F(PrivateToLocalTest, UseTexelPointer) { + // Change the private variable to a local, and change the OpImageTexelPointer. 
+ const std::string text = R"( +OpCapability SampledBuffer + OpCapability StorageImageExtendedFormats + OpCapability ImageBuffer + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %2 "min" %gl_GlobalInvocationID + OpExecutionMode %2 LocalSize 64 1 1 + OpSource HLSL 600 + OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId + OpDecorate %4 DescriptorSet 4 + OpDecorate %4 Binding 70 + %uint = OpTypeInt 32 0 + %6 = OpTypeImage %uint Buffer 0 0 0 2 R32ui +%_ptr_UniformConstant_6 = OpTypePointer UniformConstant %6 +%_ptr_Private_6 = OpTypePointer Private %6 + %void = OpTypeVoid + %10 = OpTypeFunction %void + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %v3uint = OpTypeVector %uint 3 +%_ptr_Input_v3uint = OpTypePointer Input %v3uint +%_ptr_Image_uint = OpTypePointer Image %uint + %4 = OpVariable %_ptr_UniformConstant_6 UniformConstant + %16 = OpVariable %_ptr_Private_6 Private +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input + %2 = OpFunction %void None %10 + %17 = OpLabel +; Make sure the variable was moved. +; CHECK: OpFunction +; CHECK-NEXT: OpLabel +; CHECK-NEXT: OpVariable %_ptr_Function_6 Function + %18 = OpLoad %6 %4 + OpStore %16 %18 + %19 = OpImageTexelPointer %_ptr_Image_uint %16 %uint_0 %uint_0 + %20 = OpAtomicIAdd %uint %19 %uint_1 %uint_0 %uint_1 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +TEST_F(PrivateToLocalTest, UsedInTwoFunctions) { + // Should not change because it is used in multiple functions. 
+ const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Private %5 + %8 = OpVariable %6 Private + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + %10 = OpFunction %3 None %4 + %11 = OpLabel + %12 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(PrivateToLocalTest, UsedInFunctionCall) { + // Should not change because it is used in a function call. Changing the + // signature of the function would require cloning the function, which is not + // worth it. + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 +%_ptr_Private_float = OpTypePointer Private %float + %7 = OpTypeFunction %void %_ptr_Private_float + %8 = OpVariable %_ptr_Private_float Private + %2 = OpFunction %void None %4 + %9 = OpLabel + %10 = OpFunctionCall %void %11 %8 + OpReturn + OpFunctionEnd + %11 = OpFunction %void None %7 + %12 = OpFunctionParameter %_ptr_Private_float + %13 = OpLabel + %14 = OpLoad %float %12 + OpReturn + OpFunctionEnd + )"; + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(PrivateToLocalTest, CreatePointerToAmbiguousStruct1) { + // Test that the correct pointer type is picked up. 
+ const std::string text = R"( +; CHECK: [[struct1:%[a-zA-Z_\d]+]] = OpTypeStruct +; CHECK: [[struct2:%[a-zA-Z_\d]+]] = OpTypeStruct +; CHECK: [[priv_ptr:%[\w]+]] = OpTypePointer Private [[struct1]] +; CHECK: [[fuct_ptr2:%[\w]+]] = OpTypePointer Function [[struct2]] +; CHECK: [[fuct_ptr1:%[\w]+]] = OpTypePointer Function [[struct1]] +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK-NEXT: [[newvar:%[a-zA-Z_\d]+]] = OpVariable [[fuct_ptr1]] Function +; CHECK: OpLoad [[struct1]] [[newvar]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %struct1 = OpTypeStruct %5 + %struct2 = OpTypeStruct %5 + %6 = OpTypePointer Private %struct1 + %func_ptr2 = OpTypePointer Function %struct2 + %8 = OpVariable %6 Private + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %struct1 %8 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +TEST_F(PrivateToLocalTest, CreatePointerToAmbiguousStruct2) { + // Test that the correct pointer type is picked up. 
+ const std::string text = R"( +; CHECK: [[struct1:%[a-zA-Z_\d]+]] = OpTypeStruct +; CHECK: [[struct2:%[a-zA-Z_\d]+]] = OpTypeStruct +; CHECK: [[priv_ptr:%[\w]+]] = OpTypePointer Private [[struct2]] +; CHECK: [[fuct_ptr1:%[\w]+]] = OpTypePointer Function [[struct1]] +; CHECK: [[fuct_ptr2:%[\w]+]] = OpTypePointer Function [[struct2]] +; CHECK: OpFunction +; CHECK: OpLabel +; CHECK-NEXT: [[newvar:%[a-zA-Z_\d]+]] = OpVariable [[fuct_ptr2]] Function +; CHECK: OpLoad [[struct2]] [[newvar]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %struct1 = OpTypeStruct %5 + %struct2 = OpTypeStruct %5 + %6 = OpTypePointer Private %struct2 + %func_ptr2 = OpTypePointer Function %struct1 + %8 = OpVariable %6 Private + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %struct2 %8 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/process_lines_test.cpp b/test/opt/process_lines_test.cpp new file mode 100644 index 000000000..33ad4be89 --- /dev/null +++ b/test/opt/process_lines_test.cpp @@ -0,0 +1,695 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ProcessLinesTest = PassTest<::testing::Test>; + +TEST_F(ProcessLinesTest, SimplePropagation) { + // Texture2D g_tColor[128]; + // + // layout(push_constant) cbuffer PerViewConstantBuffer_t + // { + // uint g_nDataIdx; + // uint g_nDataIdx2; + // bool g_B; + // }; + // + // SamplerState g_sAniso; + // + // struct PS_INPUT + // { + // float2 vTextureCoords : TEXCOORD2; + // }; + // + // struct PS_OUTPUT + // { + // float4 vColor : SV_Target0; + // }; + // + // PS_OUTPUT MainPs(PS_INPUT i) + // { + // PS_OUTPUT ps_output; + // + // uint u; + // if (g_B) + // u = g_nDataIdx; + // else + // u = g_nDataIdx2; + // ps_output.vColor = g_tColor[u].Sample(g_sAniso, i.vTextureCoords.xy); + // return ps_output; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor +OpExecutionMode %MainPs OriginUpperLeft +%5 = OpString "foo.frag" +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %PS_INPUT "PS_INPUT" +OpMemberName %PS_INPUT 0 "vTextureCoords" +OpName %PS_OUTPUT "PS_OUTPUT" +OpMemberName %PS_OUTPUT 0 "vColor" +OpName %_MainPs_struct_PS_INPUT_vf21_ "@MainPs(struct-PS_INPUT-vf21;" +OpName %i "i" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpMemberName %PerViewConstantBuffer_t 1 "g_nDataIdx2" +OpMemberName %PerViewConstantBuffer_t 2 "g_B" +OpName %_ "" +OpName %u "u" +OpName %ps_output "ps_output" +OpName %g_tColor "g_tColor" +OpName %g_sAniso "g_sAniso" +OpName %i_0 "i" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpName %param "param" +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpMemberDecorate 
%PerViewConstantBuffer_t 1 Offset 4 +OpMemberDecorate %PerViewConstantBuffer_t 2 Offset 8 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %g_tColor DescriptorSet 0 +OpDecorate %g_sAniso DescriptorSet 0 +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +)"; + + const std::string before = + R"(%void = OpTypeVoid +%19 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%PS_INPUT = OpTypeStruct %v2float +%_ptr_Function_PS_INPUT = OpTypePointer Function %PS_INPUT +%v4float = OpTypeVector %float 4 +%PS_OUTPUT = OpTypeStruct %v4float +%24 = OpTypeFunction %PS_OUTPUT %_ptr_Function_PS_INPUT +%uint = OpTypeInt 32 0 +%PerViewConstantBuffer_t = OpTypeStruct %uint %uint %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%int = OpTypeInt 32 1 +%int_2 = OpConstant %int 2 +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%_ptr_Function_PS_OUTPUT = OpTypePointer Function %PS_OUTPUT +%36 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint_128 = OpConstant %uint 128 +%_arr_36_uint_128 = OpTypeArray %36 %uint_128 +%_ptr_UniformConstant__arr_36_uint_128 = OpTypePointer UniformConstant %_arr_36_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_36_uint_128 UniformConstant +%_ptr_UniformConstant_36 = OpTypePointer UniformConstant %36 +%41 = OpTypeSampler +%_ptr_UniformConstant_41 = OpTypePointer UniformConstant %41 +%g_sAniso = OpVariable %_ptr_UniformConstant_41 UniformConstant +%43 = OpTypeSampledImage %36 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable 
%_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%MainPs = OpFunction %void None %19 +%48 = OpLabel +%i_0 = OpVariable %_ptr_Function_PS_INPUT Function +%param = OpVariable %_ptr_Function_PS_INPUT Function +OpLine %5 23 0 +%49 = OpLoad %v2float %i_vTextureCoords +%50 = OpAccessChain %_ptr_Function_v2float %i_0 %int_0 +OpStore %50 %49 +%51 = OpLoad %PS_INPUT %i_0 +OpStore %param %51 +%52 = OpFunctionCall %PS_OUTPUT %_MainPs_struct_PS_INPUT_vf21_ %param +%53 = OpCompositeExtract %v4float %52 0 +OpStore %_entryPointOutput_vColor %53 +OpReturn +OpFunctionEnd +%_MainPs_struct_PS_INPUT_vf21_ = OpFunction %PS_OUTPUT None %24 +%i = OpFunctionParameter %_ptr_Function_PS_INPUT +%54 = OpLabel +%u = OpVariable %_ptr_Function_uint Function +%ps_output = OpVariable %_ptr_Function_PS_OUTPUT Function +OpLine %5 27 0 +%55 = OpAccessChain %_ptr_PushConstant_uint %_ %int_2 +%56 = OpLoad %uint %55 +%57 = OpINotEqual %bool %56 %uint_0 +OpSelectionMerge %58 None +OpBranchConditional %57 %59 %60 +%59 = OpLabel +OpLine %5 28 0 +%61 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%62 = OpLoad %uint %61 +OpStore %u %62 +OpBranch %58 +%60 = OpLabel +OpLine %5 30 0 +%63 = OpAccessChain %_ptr_PushConstant_uint %_ %int_1 +%64 = OpLoad %uint %63 +OpStore %u %64 +OpBranch %58 +%58 = OpLabel +OpLine %5 31 0 +%65 = OpLoad %uint %u +%66 = OpAccessChain %_ptr_UniformConstant_36 %g_tColor %65 +%67 = OpLoad %36 %66 +%68 = OpLoad %41 %g_sAniso +%69 = OpSampledImage %43 %67 %68 +%70 = OpAccessChain %_ptr_Function_v2float %i %int_0 +%71 = OpLoad %v2float %70 +%72 = OpImageSampleImplicitLod %v4float %69 %71 +%73 = OpAccessChain %_ptr_Function_v4float %ps_output %int_0 +OpStore %73 %72 +OpLine %5 32 0 +%74 = OpLoad %PS_OUTPUT %ps_output +OpReturnValue %74 +OpFunctionEnd +)"; + + const std::string after = + R"(OpNoLine +%void = OpTypeVoid +OpNoLine +%19 = OpTypeFunction %void +OpNoLine +%float = OpTypeFloat 
32 +OpNoLine +%v2float = OpTypeVector %float 2 +OpNoLine +%PS_INPUT = OpTypeStruct %v2float +OpNoLine +%_ptr_Function_PS_INPUT = OpTypePointer Function %PS_INPUT +OpNoLine +%v4float = OpTypeVector %float 4 +OpNoLine +%PS_OUTPUT = OpTypeStruct %v4float +OpNoLine +%24 = OpTypeFunction %PS_OUTPUT %_ptr_Function_PS_INPUT +OpNoLine +%uint = OpTypeInt 32 0 +OpNoLine +%PerViewConstantBuffer_t = OpTypeStruct %uint %uint %uint +OpNoLine +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +OpNoLine +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +OpNoLine +%int = OpTypeInt 32 1 +OpNoLine +%int_2 = OpConstant %int 2 +OpNoLine +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +OpNoLine +%bool = OpTypeBool +OpNoLine +%uint_0 = OpConstant %uint 0 +OpNoLine +%_ptr_Function_uint = OpTypePointer Function %uint +OpNoLine +%int_0 = OpConstant %int 0 +OpNoLine +%int_1 = OpConstant %int 1 +OpNoLine +%_ptr_Function_PS_OUTPUT = OpTypePointer Function %PS_OUTPUT +OpNoLine +%36 = OpTypeImage %float 2D 0 0 0 1 Unknown +OpNoLine +%uint_128 = OpConstant %uint 128 +OpNoLine +%_arr_36_uint_128 = OpTypeArray %36 %uint_128 +OpNoLine +%_ptr_UniformConstant__arr_36_uint_128 = OpTypePointer UniformConstant %_arr_36_uint_128 +OpNoLine +%g_tColor = OpVariable %_ptr_UniformConstant__arr_36_uint_128 UniformConstant +OpNoLine +%_ptr_UniformConstant_36 = OpTypePointer UniformConstant %36 +OpNoLine +%41 = OpTypeSampler +OpNoLine +%_ptr_UniformConstant_41 = OpTypePointer UniformConstant %41 +OpNoLine +%g_sAniso = OpVariable %_ptr_UniformConstant_41 UniformConstant +OpNoLine +%43 = OpTypeSampledImage %36 +OpNoLine +%_ptr_Function_v2float = OpTypePointer Function %v2float +OpNoLine +%_ptr_Function_v4float = OpTypePointer Function %v4float +OpNoLine +%_ptr_Input_v2float = OpTypePointer Input %v2float +OpNoLine +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +OpNoLine +%_ptr_Output_v4float = OpTypePointer Output %v4float 
+OpNoLine +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +OpNoLine +%MainPs = OpFunction %void None %19 +OpNoLine +%48 = OpLabel +OpNoLine +%i_0 = OpVariable %_ptr_Function_PS_INPUT Function +OpNoLine +%param = OpVariable %_ptr_Function_PS_INPUT Function +OpLine %5 23 0 +%49 = OpLoad %v2float %i_vTextureCoords +OpLine %5 23 0 +%50 = OpAccessChain %_ptr_Function_v2float %i_0 %int_0 +OpLine %5 23 0 +OpStore %50 %49 +OpLine %5 23 0 +%51 = OpLoad %PS_INPUT %i_0 +OpLine %5 23 0 +OpStore %param %51 +OpLine %5 23 0 +%52 = OpFunctionCall %PS_OUTPUT %_MainPs_struct_PS_INPUT_vf21_ %param +OpLine %5 23 0 +%53 = OpCompositeExtract %v4float %52 0 +OpLine %5 23 0 +OpStore %_entryPointOutput_vColor %53 +OpLine %5 23 0 +OpReturn +OpNoLine +OpFunctionEnd +OpNoLine +%_MainPs_struct_PS_INPUT_vf21_ = OpFunction %PS_OUTPUT None %24 +OpNoLine +%i = OpFunctionParameter %_ptr_Function_PS_INPUT +OpNoLine +%54 = OpLabel +OpNoLine +%u = OpVariable %_ptr_Function_uint Function +OpNoLine +%ps_output = OpVariable %_ptr_Function_PS_OUTPUT Function +OpLine %5 27 0 +%55 = OpAccessChain %_ptr_PushConstant_uint %_ %int_2 +OpLine %5 27 0 +%56 = OpLoad %uint %55 +OpLine %5 27 0 +%57 = OpINotEqual %bool %56 %uint_0 +OpLine %5 27 0 +OpSelectionMerge %58 None +OpBranchConditional %57 %59 %60 +OpNoLine +%59 = OpLabel +OpLine %5 28 0 +%61 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +OpLine %5 28 0 +%62 = OpLoad %uint %61 +OpLine %5 28 0 +OpStore %u %62 +OpLine %5 28 0 +OpBranch %58 +OpNoLine +%60 = OpLabel +OpLine %5 30 0 +%63 = OpAccessChain %_ptr_PushConstant_uint %_ %int_1 +OpLine %5 30 0 +%64 = OpLoad %uint %63 +OpLine %5 30 0 +OpStore %u %64 +OpLine %5 30 0 +OpBranch %58 +OpNoLine +%58 = OpLabel +OpLine %5 31 0 +%65 = OpLoad %uint %u +OpLine %5 31 0 +%66 = OpAccessChain %_ptr_UniformConstant_36 %g_tColor %65 +OpLine %5 31 0 +%67 = OpLoad %36 %66 +OpLine %5 31 0 +%68 = OpLoad %41 %g_sAniso +OpLine %5 31 0 +%69 = OpSampledImage %43 %67 %68 +OpLine %5 31 0 +%70 = OpAccessChain 
%_ptr_Function_v2float %i %int_0 +OpLine %5 31 0 +%71 = OpLoad %v2float %70 +OpLine %5 31 0 +%72 = OpImageSampleImplicitLod %v4float %69 %71 +OpLine %5 31 0 +%73 = OpAccessChain %_ptr_Function_v4float %ps_output %int_0 +OpLine %5 31 0 +OpStore %73 %72 +OpLine %5 32 0 +%74 = OpLoad %PS_OUTPUT %ps_output +OpLine %5 32 0 +OpReturnValue %74 +OpNoLine +OpFunctionEnd +)"; + + SinglePassRunAndCheck(predefs + before, predefs + after, + false, true, kLinesPropagateLines); +} + +TEST_F(ProcessLinesTest, SimpleElimination) { + // Previous test with before and after reversed + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor +OpExecutionMode %MainPs OriginUpperLeft +%5 = OpString "foo.frag" +OpSource HLSL 500 +OpName %MainPs "MainPs" +OpName %PS_INPUT "PS_INPUT" +OpMemberName %PS_INPUT 0 "vTextureCoords" +OpName %PS_OUTPUT "PS_OUTPUT" +OpMemberName %PS_OUTPUT 0 "vColor" +OpName %_MainPs_struct_PS_INPUT_vf21_ "@MainPs(struct-PS_INPUT-vf21;" +OpName %i "i" +OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t" +OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx" +OpMemberName %PerViewConstantBuffer_t 1 "g_nDataIdx2" +OpMemberName %PerViewConstantBuffer_t 2 "g_B" +OpName %_ "" +OpName %u "u" +OpName %ps_output "ps_output" +OpName %g_tColor "g_tColor" +OpName %g_sAniso "g_sAniso" +OpName %i_0 "i" +OpName %i_vTextureCoords "i.vTextureCoords" +OpName %_entryPointOutput_vColor "@entryPointOutput.vColor" +OpName %param "param" +OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0 +OpMemberDecorate %PerViewConstantBuffer_t 1 Offset 4 +OpMemberDecorate %PerViewConstantBuffer_t 2 Offset 8 +OpDecorate %PerViewConstantBuffer_t Block +OpDecorate %g_tColor DescriptorSet 0 +OpDecorate %g_sAniso DescriptorSet 0 +OpDecorate %i_vTextureCoords Location 0 +OpDecorate %_entryPointOutput_vColor Location 0 +)"; + + const std::string before 
= + R"(OpNoLine +%void = OpTypeVoid +OpNoLine +%19 = OpTypeFunction %void +OpNoLine +%float = OpTypeFloat 32 +OpNoLine +%v2float = OpTypeVector %float 2 +OpNoLine +%PS_INPUT = OpTypeStruct %v2float +OpNoLine +%_ptr_Function_PS_INPUT = OpTypePointer Function %PS_INPUT +OpNoLine +%v4float = OpTypeVector %float 4 +OpNoLine +%PS_OUTPUT = OpTypeStruct %v4float +OpNoLine +%24 = OpTypeFunction %PS_OUTPUT %_ptr_Function_PS_INPUT +OpNoLine +%uint = OpTypeInt 32 0 +OpNoLine +%PerViewConstantBuffer_t = OpTypeStruct %uint %uint %uint +OpNoLine +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +OpNoLine +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +OpNoLine +%int = OpTypeInt 32 1 +OpNoLine +%int_2 = OpConstant %int 2 +OpNoLine +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +OpNoLine +%bool = OpTypeBool +OpNoLine +%uint_0 = OpConstant %uint 0 +OpNoLine +%_ptr_Function_uint = OpTypePointer Function %uint +OpNoLine +%int_0 = OpConstant %int 0 +OpNoLine +%int_1 = OpConstant %int 1 +OpNoLine +%_ptr_Function_PS_OUTPUT = OpTypePointer Function %PS_OUTPUT +OpNoLine +%36 = OpTypeImage %float 2D 0 0 0 1 Unknown +OpNoLine +%uint_128 = OpConstant %uint 128 +OpNoLine +%_arr_36_uint_128 = OpTypeArray %36 %uint_128 +OpNoLine +%_ptr_UniformConstant__arr_36_uint_128 = OpTypePointer UniformConstant %_arr_36_uint_128 +OpNoLine +%g_tColor = OpVariable %_ptr_UniformConstant__arr_36_uint_128 UniformConstant +OpNoLine +%_ptr_UniformConstant_36 = OpTypePointer UniformConstant %36 +OpNoLine +%41 = OpTypeSampler +OpNoLine +%_ptr_UniformConstant_41 = OpTypePointer UniformConstant %41 +OpNoLine +%g_sAniso = OpVariable %_ptr_UniformConstant_41 UniformConstant +OpNoLine +%43 = OpTypeSampledImage %36 +OpNoLine +%_ptr_Function_v2float = OpTypePointer Function %v2float +OpNoLine +%_ptr_Function_v4float = OpTypePointer Function %v4float +OpNoLine +%_ptr_Input_v2float = OpTypePointer Input %v2float +OpNoLine 
+%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +OpNoLine +%_ptr_Output_v4float = OpTypePointer Output %v4float +OpNoLine +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +OpNoLine +%MainPs = OpFunction %void None %19 +OpNoLine +%48 = OpLabel +OpNoLine +%i_0 = OpVariable %_ptr_Function_PS_INPUT Function +OpNoLine +%param = OpVariable %_ptr_Function_PS_INPUT Function +OpLine %5 23 0 +%49 = OpLoad %v2float %i_vTextureCoords +OpLine %5 23 0 +%50 = OpAccessChain %_ptr_Function_v2float %i_0 %int_0 +OpLine %5 23 0 +OpStore %50 %49 +OpLine %5 23 0 +%51 = OpLoad %PS_INPUT %i_0 +OpLine %5 23 0 +OpStore %param %51 +OpLine %5 23 0 +%52 = OpFunctionCall %PS_OUTPUT %_MainPs_struct_PS_INPUT_vf21_ %param +OpLine %5 23 0 +%53 = OpCompositeExtract %v4float %52 0 +OpLine %5 23 0 +OpStore %_entryPointOutput_vColor %53 +OpLine %5 23 0 +OpReturn +OpNoLine +OpFunctionEnd +OpNoLine +%_MainPs_struct_PS_INPUT_vf21_ = OpFunction %PS_OUTPUT None %24 +OpNoLine +%i = OpFunctionParameter %_ptr_Function_PS_INPUT +OpNoLine +%54 = OpLabel +OpNoLine +%u = OpVariable %_ptr_Function_uint Function +OpNoLine +%ps_output = OpVariable %_ptr_Function_PS_OUTPUT Function +OpLine %5 27 0 +%55 = OpAccessChain %_ptr_PushConstant_uint %_ %int_2 +OpLine %5 27 0 +%56 = OpLoad %uint %55 +OpLine %5 27 0 +%57 = OpINotEqual %bool %56 %uint_0 +OpLine %5 27 0 +OpSelectionMerge %58 None +OpBranchConditional %57 %59 %60 +OpNoLine +%59 = OpLabel +OpLine %5 28 0 +%61 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +OpLine %5 28 0 +%62 = OpLoad %uint %61 +OpLine %5 28 0 +OpStore %u %62 +OpLine %5 28 0 +OpBranch %58 +OpNoLine +%60 = OpLabel +OpLine %5 30 0 +%63 = OpAccessChain %_ptr_PushConstant_uint %_ %int_1 +OpLine %5 30 0 +%64 = OpLoad %uint %63 +OpLine %5 30 0 +OpStore %u %64 +OpLine %5 30 0 +OpBranch %58 +OpNoLine +%58 = OpLabel +OpLine %5 31 0 +%65 = OpLoad %uint %u +OpLine %5 31 0 +%66 = OpAccessChain %_ptr_UniformConstant_36 %g_tColor %65 +OpLine %5 31 0 +%67 = OpLoad %36 %66 +OpLine 
%5 31 0 +%68 = OpLoad %41 %g_sAniso +OpLine %5 31 0 +%69 = OpSampledImage %43 %67 %68 +OpLine %5 31 0 +%70 = OpAccessChain %_ptr_Function_v2float %i %int_0 +OpLine %5 31 0 +%71 = OpLoad %v2float %70 +OpLine %5 31 0 +%72 = OpImageSampleImplicitLod %v4float %69 %71 +OpLine %5 31 0 +%73 = OpAccessChain %_ptr_Function_v4float %ps_output %int_0 +OpLine %5 31 0 +OpStore %73 %72 +OpLine %5 32 0 +%74 = OpLoad %PS_OUTPUT %ps_output +OpLine %5 32 0 +OpReturnValue %74 +OpNoLine +OpFunctionEnd +)"; + + const std::string after = + R"(%void = OpTypeVoid +%19 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%PS_INPUT = OpTypeStruct %v2float +%_ptr_Function_PS_INPUT = OpTypePointer Function %PS_INPUT +%v4float = OpTypeVector %float 4 +%PS_OUTPUT = OpTypeStruct %v4float +%24 = OpTypeFunction %PS_OUTPUT %_ptr_Function_PS_INPUT +%uint = OpTypeInt 32 0 +%PerViewConstantBuffer_t = OpTypeStruct %uint %uint %uint +%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t +%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant +%int = OpTypeInt 32 1 +%int_2 = OpConstant %int 2 +%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%_ptr_Function_PS_OUTPUT = OpTypePointer Function %PS_OUTPUT +%36 = OpTypeImage %float 2D 0 0 0 1 Unknown +%uint_128 = OpConstant %uint 128 +%_arr_36_uint_128 = OpTypeArray %36 %uint_128 +%_ptr_UniformConstant__arr_36_uint_128 = OpTypePointer UniformConstant %_arr_36_uint_128 +%g_tColor = OpVariable %_ptr_UniformConstant__arr_36_uint_128 UniformConstant +%_ptr_UniformConstant_36 = OpTypePointer UniformConstant %36 +%41 = OpTypeSampler +%_ptr_UniformConstant_41 = OpTypePointer UniformConstant %41 +%g_sAniso = OpVariable %_ptr_UniformConstant_41 UniformConstant +%43 = OpTypeSampledImage %36 +%_ptr_Function_v2float 
= OpTypePointer Function %v2float +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output +%MainPs = OpFunction %void None %19 +%48 = OpLabel +%i_0 = OpVariable %_ptr_Function_PS_INPUT Function +%param = OpVariable %_ptr_Function_PS_INPUT Function +OpLine %5 23 0 +%49 = OpLoad %v2float %i_vTextureCoords +%50 = OpAccessChain %_ptr_Function_v2float %i_0 %int_0 +OpStore %50 %49 +%51 = OpLoad %PS_INPUT %i_0 +OpStore %param %51 +%52 = OpFunctionCall %PS_OUTPUT %_MainPs_struct_PS_INPUT_vf21_ %param +%53 = OpCompositeExtract %v4float %52 0 +OpStore %_entryPointOutput_vColor %53 +OpReturn +OpFunctionEnd +%_MainPs_struct_PS_INPUT_vf21_ = OpFunction %PS_OUTPUT None %24 +%i = OpFunctionParameter %_ptr_Function_PS_INPUT +%54 = OpLabel +%u = OpVariable %_ptr_Function_uint Function +%ps_output = OpVariable %_ptr_Function_PS_OUTPUT Function +OpLine %5 27 0 +%55 = OpAccessChain %_ptr_PushConstant_uint %_ %int_2 +%56 = OpLoad %uint %55 +%57 = OpINotEqual %bool %56 %uint_0 +OpSelectionMerge %58 None +OpBranchConditional %57 %59 %60 +%59 = OpLabel +OpLine %5 28 0 +%61 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0 +%62 = OpLoad %uint %61 +OpStore %u %62 +OpBranch %58 +%60 = OpLabel +OpLine %5 30 0 +%63 = OpAccessChain %_ptr_PushConstant_uint %_ %int_1 +%64 = OpLoad %uint %63 +OpStore %u %64 +OpBranch %58 +%58 = OpLabel +OpLine %5 31 0 +%65 = OpLoad %uint %u +%66 = OpAccessChain %_ptr_UniformConstant_36 %g_tColor %65 +%67 = OpLoad %36 %66 +%68 = OpLoad %41 %g_sAniso +%69 = OpSampledImage %43 %67 %68 +%70 = OpAccessChain %_ptr_Function_v2float %i %int_0 +%71 = OpLoad %v2float %70 +%72 = OpImageSampleImplicitLod %v4float %69 %71 +%73 = OpAccessChain %_ptr_Function_v4float %ps_output %int_0 +OpStore %73 %72 +OpLine %5 32 0 +%74 = OpLoad %PS_OUTPUT 
%ps_output +OpReturnValue %74 +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before, predefs + after, false, true, kLinesEliminateDeadLines); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// TODO(greg-lunarg): Think about other tests :) + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/propagator_test.cpp b/test/opt/propagator_test.cpp new file mode 100644 index 000000000..fb8e487cc --- /dev/null +++ b/test/opt/propagator_test.cpp @@ -0,0 +1,219 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/cfg.h" +#include "source/opt/ir_context.h" +#include "source/opt/pass.h" +#include "source/opt/propagator.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; + +class PropagatorTest : public testing::Test { + protected: + virtual void TearDown() { + ctx_.reset(nullptr); + values_.clear(); + values_vec_.clear(); + } + + void Assemble(const std::string& input) { + ctx_ = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input); + ASSERT_NE(nullptr, ctx_) << "Assembling failed for shader:\n" + << input << "\n"; + } + + bool Propagate(const SSAPropagator::VisitFunction& visit_fn) { + SSAPropagator propagator(ctx_.get(), visit_fn); + bool retval = false; + for (auto& fn : *ctx_->module()) { + retval |= propagator.Run(&fn); + } + return retval; + } + + const std::vector& GetValues() { + values_vec_.clear(); + for (const auto& it : values_) { + values_vec_.push_back(it.second); + } + return values_vec_; + } + + std::unique_ptr ctx_; + std::map values_; + std::vector values_vec_; +}; + +TEST_F(PropagatorTest, LocalPropagate) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %outparm + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %x "x" + OpName %y "y" + OpName %z "z" + OpName %outparm "outparm" + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int + %int_4 = OpConstant %int 4 + %int_3 = OpConstant %int 3 + %int_1 = OpConstant %int 1 +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %3 + %5 = OpLabel + %x = OpVariable %_ptr_Function_int Function + %y 
= OpVariable %_ptr_Function_int Function + %z = OpVariable %_ptr_Function_int Function + OpStore %x %int_4 + OpStore %y %int_3 + OpStore %z %int_1 + %20 = OpLoad %int %z + OpStore %outparm %20 + OpReturn + OpFunctionEnd + )"; + Assemble(spv_asm); + + const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) { + *dest_bb = nullptr; + if (instr->opcode() == SpvOpStore) { + uint32_t lhs_id = instr->GetSingleWordOperand(0); + uint32_t rhs_id = instr->GetSingleWordOperand(1); + Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id); + if (rhs_def->opcode() == SpvOpConstant) { + uint32_t val = rhs_def->GetSingleWordOperand(2); + values_[lhs_id] = val; + return SSAPropagator::kInteresting; + } + } + return SSAPropagator::kVarying; + }; + + EXPECT_TRUE(Propagate(visit_fn)); + EXPECT_THAT(GetValues(), UnorderedElementsAre(4, 3, 1)); +} + +TEST_F(PropagatorTest, PropagateThroughPhis) { + const std::string spv_asm = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %x %outparm + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %x "x" + OpName %outparm "outparm" + OpDecorate %x Flat + OpDecorate %x Location 0 + OpDecorate %outparm Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %bool = OpTypeBool +%_ptr_Function_int = OpTypePointer Function %int + %int_4 = OpConstant %int 4 + %int_3 = OpConstant %int 3 + %int_1 = OpConstant %int 1 +%_ptr_Input_int = OpTypePointer Input %int + %x = OpVariable %_ptr_Input_int Input +%_ptr_Output_int = OpTypePointer Output %int + %outparm = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %3 + %4 = OpLabel + %5 = OpLoad %int %x + %6 = OpSGreaterThan %bool %5 %int_3 + OpSelectionMerge %25 None + OpBranchConditional %6 %22 %23 + %22 = OpLabel + %7 = OpLoad %int %int_4 + OpBranch %25 + %23 = OpLabel + %8 = OpLoad %int %int_4 + OpBranch %25 + %25 = 
OpLabel + %35 = OpPhi %int %7 %22 %8 %23 + OpStore %outparm %35 + OpReturn + OpFunctionEnd + )"; + + Assemble(spv_asm); + + Instruction* phi_instr = nullptr; + const auto visit_fn = [this, &phi_instr](Instruction* instr, + BasicBlock** dest_bb) { + *dest_bb = nullptr; + if (instr->opcode() == SpvOpLoad) { + uint32_t rhs_id = instr->GetSingleWordOperand(2); + Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id); + if (rhs_def->opcode() == SpvOpConstant) { + uint32_t val = rhs_def->GetSingleWordOperand(2); + values_[instr->result_id()] = val; + return SSAPropagator::kInteresting; + } + } else if (instr->opcode() == SpvOpPhi) { + phi_instr = instr; + SSAPropagator::PropStatus retval; + for (uint32_t i = 2; i < instr->NumOperands(); i += 2) { + uint32_t phi_arg_id = instr->GetSingleWordOperand(i); + auto it = values_.find(phi_arg_id); + if (it != values_.end()) { + EXPECT_EQ(it->second, 4u); + retval = SSAPropagator::kInteresting; + values_[instr->result_id()] = it->second; + } else { + retval = SSAPropagator::kNotInteresting; + break; + } + } + return retval; + } + + return SSAPropagator::kVarying; + }; + + EXPECT_TRUE(Propagate(visit_fn)); + + // The propagator should've concluded that the Phi instruction has a constant + // value of 4. + EXPECT_NE(phi_instr, nullptr); + EXPECT_EQ(values_[phi_instr->result_id()], 4u); + + EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 4u, 4u)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/reduce_load_size_test.cpp b/test/opt/reduce_load_size_test.cpp new file mode 100644 index 000000000..7b850e317 --- /dev/null +++ b/test/opt/reduce_load_size_test.cpp @@ -0,0 +1,354 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ReduceLoadSizeTest = PassTest<::testing::Test>; + +TEST_F(ReduceLoadSizeTest, cbuffer_load_extract) { + // Originally from the following HLSL: + // struct S { + // uint f; + // }; + // + // + // cbuffer gBuffer { uint a[32]; }; + // + // RWStructuredBuffer gRWSBuffer; + // + // uint foo(uint p[32]) { + // return p[1]; + // } + // + // [numthreads(1,1,1)] + // void main() { + // gRWSBuffer[0].f = foo(a); + // } + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 600 + OpName %type_gBuffer "type.gBuffer" + OpMemberName %type_gBuffer 0 "a" + OpName %gBuffer "gBuffer" + OpName %S "S" + OpMemberName %S 0 "f" + OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" + OpName %gRWSBuffer "gRWSBuffer" + OpName %main "main" + OpDecorate %_arr_uint_uint_32 ArrayStride 16 + OpMemberDecorate %type_gBuffer 0 Offset 0 + OpDecorate %type_gBuffer Block + OpMemberDecorate %S 0 Offset 0 + OpDecorate %_runtimearr_S ArrayStride 4 + OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 + OpDecorate %type_RWStructuredBuffer_S BufferBlock + OpDecorate %gBuffer DescriptorSet 0 + OpDecorate %gBuffer Binding 0 + OpDecorate %gRWSBuffer DescriptorSet 0 + OpDecorate %gRWSBuffer Binding 1 + %uint = OpTypeInt 32 0 + %uint_32 = OpConstant %uint 32 +%_arr_uint_uint_32 = OpTypeArray %uint %uint_32 
+%type_gBuffer = OpTypeStruct %_arr_uint_uint_32 +%_ptr_Uniform_type_gBuffer = OpTypePointer Uniform %type_gBuffer + %S = OpTypeStruct %uint +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %15 = OpTypeFunction %void + %int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_uint_uint_32 = OpTypePointer Uniform %_arr_uint_uint_32 + %uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint + %gBuffer = OpVariable %_ptr_Uniform_type_gBuffer Uniform + %gRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform + %main = OpFunction %void None %15 + %20 = OpLabel +; CHECK: [[ac1:%\w+]] = OpAccessChain {{%\w+}} %gBuffer %int_0 +; CHECK: [[ac2:%\w+]] = OpAccessChain {{%\w+}} [[ac1]] %uint_1 +; CHECK: [[ld:%\w+]] = OpLoad {{%\w+}} [[ac2]] +; CHECK: OpStore {{%\w+}} [[ld]] + %21 = OpAccessChain %_ptr_Uniform__arr_uint_uint_32 %gBuffer %int_0 + %22 = OpLoad %_arr_uint_uint_32 %21 ; Load of 32-element array. 
+ %23 = OpCompositeExtract %uint %22 1 + %24 = OpAccessChain %_ptr_Uniform_uint %gRWSBuffer %int_0 %uint_0 %int_0 + OpStore %24 %23 + OpReturn + OpFunctionEnd + )"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndMatch(test, false); +} + +TEST_F(ReduceLoadSizeTest, cbuffer_load_extract_vector) { + // Originally from the following HLSL: + // struct S { + // uint f; + // }; + // + // + // cbuffer gBuffer { uint4 a; }; + // + // RWStructuredBuffer gRWSBuffer; + // + // uint foo(uint p[32]) { + // return p[1]; + // } + // + // [numthreads(1,1,1)] + // void main() { + // gRWSBuffer[0].f = foo(a); + // } + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpSource HLSL 600 +OpName %type_gBuffer "type.gBuffer" +OpMemberName %type_gBuffer 0 "a" +OpName %gBuffer "gBuffer" +OpName %S "S" +OpMemberName %S 0 "f" +OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" +OpName %gRWSBuffer "gRWSBuffer" +OpName %main "main" +OpMemberDecorate %type_gBuffer 0 Offset 0 +OpDecorate %type_gBuffer Block +OpMemberDecorate %S 0 Offset 0 +OpDecorate %_runtimearr_S ArrayStride 4 +OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 +OpDecorate %type_RWStructuredBuffer_S BufferBlock +OpDecorate %gBuffer DescriptorSet 0 +OpDecorate %gBuffer Binding 0 +OpDecorate %gRWSBuffer DescriptorSet 0 +OpDecorate %gRWSBuffer Binding 1 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%v4uint = OpTypeVector %uint 4 +%type_gBuffer = OpTypeStruct %v4uint +%_ptr_Uniform_type_gBuffer = OpTypePointer Uniform %type_gBuffer +%S = OpTypeStruct %uint +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S +%int = 
OpTypeInt 32 1 +%void = OpTypeVoid +%15 = OpTypeFunction %void +%int_0 = OpConstant %int 0 +%_ptr_Uniform_v4uint = OpTypePointer Uniform %v4uint +%uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%gBuffer = OpVariable %_ptr_Uniform_type_gBuffer Uniform +%gRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform +%main = OpFunction %void None %15 +%20 = OpLabel +%21 = OpAccessChain %_ptr_Uniform_v4uint %gBuffer %int_0 +%22 = OpLoad %v4uint %21 +%23 = OpCompositeExtract %uint %22 1 +%24 = OpAccessChain %_ptr_Uniform_uint %gRWSBuffer %int_0 %uint_0 %int_0 +OpStore %24 %23 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndCheck(test, test, true, false); +} + +TEST_F(ReduceLoadSizeTest, cbuffer_load_5_extract) { + // All of the elements of the value loaded are used, so we should not + // change the load. 
+ const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpSource HLSL 600 +OpName %type_gBuffer "type.gBuffer" +OpMemberName %type_gBuffer 0 "a" +OpName %gBuffer "gBuffer" +OpName %S "S" +OpMemberName %S 0 "f" +OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" +OpName %gRWSBuffer "gRWSBuffer" +OpName %main "main" +OpDecorate %_arr_uint_uint_5 ArrayStride 16 +OpMemberDecorate %type_gBuffer 0 Offset 0 +OpDecorate %type_gBuffer Block +OpMemberDecorate %S 0 Offset 0 +OpDecorate %_runtimearr_S ArrayStride 4 +OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 +OpDecorate %type_RWStructuredBuffer_S BufferBlock +OpDecorate %gBuffer DescriptorSet 0 +OpDecorate %gBuffer Binding 0 +OpDecorate %gRWSBuffer DescriptorSet 0 +OpDecorate %gRWSBuffer Binding 1 +%uint = OpTypeInt 32 0 +%uint_5 = OpConstant %uint 5 +%_arr_uint_uint_5 = OpTypeArray %uint %uint_5 +%type_gBuffer = OpTypeStruct %_arr_uint_uint_5 +%_ptr_Uniform_type_gBuffer = OpTypePointer Uniform %type_gBuffer +%S = OpTypeStruct %uint +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S +%int = OpTypeInt 32 1 +%void = OpTypeVoid +%15 = OpTypeFunction %void +%int_0 = OpConstant %int 0 +%_ptr_Uniform__arr_uint_uint_5 = OpTypePointer Uniform %_arr_uint_uint_5 +%uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%gBuffer = OpVariable %_ptr_Uniform_type_gBuffer Uniform +%gRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform +%main = OpFunction %void None %15 +%20 = OpLabel +%21 = OpAccessChain %_ptr_Uniform__arr_uint_uint_5 %gBuffer %int_0 +%22 = OpLoad %_arr_uint_uint_5 %21 +%23 = OpCompositeExtract %uint %22 0 +%24 = OpCompositeExtract %uint %22 1 +%25 = OpCompositeExtract %uint %22 2 +%26 = OpCompositeExtract %uint %22 3 +%27 
= OpCompositeExtract %uint %22 4 +%28 = OpIAdd %uint %23 %24 +%29 = OpIAdd %uint %28 %25 +%30 = OpIAdd %uint %29 %26 +%31 = OpIAdd %uint %30 %27 +%32 = OpAccessChain %_ptr_Uniform_uint %gRWSBuffer %int_0 %uint_0 %int_0 +OpStore %32 %31 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndCheck(test, test, true, false); +} + +TEST_F(ReduceLoadSizeTest, cbuffer_load_fully_used) { + // The result of the load (%22) is used in an instruction that uses the whole + // load and has only 1 in operand. This triggers issue #1559. + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpSource HLSL 600 +OpName %type_gBuffer "type.gBuffer" +OpMemberName %type_gBuffer 0 "a" +OpName %gBuffer "gBuffer" +OpName %S "S" +OpMemberName %S 0 "f" +OpName %type_RWStructuredBuffer_S "type.RWStructuredBuffer.S" +OpName %gRWSBuffer "gRWSBuffer" +OpName %main "main" +OpMemberDecorate %type_gBuffer 0 Offset 0 +OpDecorate %type_gBuffer Block +OpMemberDecorate %S 0 Offset 0 +OpDecorate %_runtimearr_S ArrayStride 4 +OpMemberDecorate %type_RWStructuredBuffer_S 0 Offset 0 +OpDecorate %type_RWStructuredBuffer_S BufferBlock +OpDecorate %gBuffer DescriptorSet 0 +OpDecorate %gBuffer Binding 0 +OpDecorate %gRWSBuffer DescriptorSet 0 +OpDecorate %gRWSBuffer Binding 1 +%uint = OpTypeInt 32 0 +%uint_32 = OpConstant %uint 32 +%v4uint = OpTypeVector %uint 4 +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%type_gBuffer = OpTypeStruct %v4uint +%_ptr_Uniform_type_gBuffer = OpTypePointer Uniform %type_gBuffer +%S = OpTypeStruct %uint +%_runtimearr_S = OpTypeRuntimeArray %S +%type_RWStructuredBuffer_S = OpTypeStruct %_runtimearr_S +%_ptr_Uniform_type_RWStructuredBuffer_S = OpTypePointer Uniform %type_RWStructuredBuffer_S +%int = 
OpTypeInt 32 1 +%void = OpTypeVoid +%15 = OpTypeFunction %void +%int_0 = OpConstant %int 0 +%_ptr_Uniform_v4uint = OpTypePointer Uniform %v4uint +%uint_0 = OpConstant %uint 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%gBuffer = OpVariable %_ptr_Uniform_type_gBuffer Uniform +%gRWSBuffer = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform +%main = OpFunction %void None %15 +%20 = OpLabel +%21 = OpAccessChain %_ptr_Uniform_v4uint %gBuffer %int_0 +%22 = OpLoad %v4uint %21 +%23 = OpCompositeExtract %uint %22 1 +%24 = OpConvertUToF %v4float %22 +%25 = OpAccessChain %_ptr_Uniform_uint %gRWSBuffer %int_0 %uint_0 %int_0 +OpStore %25 %23 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + SinglePassRunAndCheck(test, test, true, false); +} + +TEST_F(ReduceLoadSizeTest, extract_with_no_index) { + const std::string test = + R"( + OpCapability ImageGatherExtended + OpExtension "" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "P�Ma'" %12 %17 + OpExecutionMode %4 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %_struct_7 = OpTypeStruct %float %float +%_ptr_Input__struct_7 = OpTypePointer Input %_struct_7 +%_ptr_Output__struct_7 = OpTypePointer Output %_struct_7 + %12 = OpVariable %_ptr_Input__struct_7 Input + %17 = OpVariable %_ptr_Output__struct_7 Output + %4 = OpFunction %void DontInline|Pure|Const %3 + %245 = OpLabel + %13 = OpLoad %_struct_7 %12 + %33 = OpCompositeExtract %_struct_7 %13 + OpReturn + OpFunctionEnd + )"; + + auto result = SinglePassRunAndDisassemble(test, true, true); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/redundancy_elimination_test.cpp b/test/opt/redundancy_elimination_test.cpp new file mode 100644 index 
000000000..7d2abe846 --- /dev/null +++ b/test/opt/redundancy_elimination_test.cpp @@ -0,0 +1,277 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "source/opt/build_module.h" +#include "source/opt/value_number_table.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; +using RedundancyEliminationTest = PassTest<::testing::Test>; + +// Test that the pass handles a simple case of local redundancy elimination. +// The rest of the tests check for extra functionality. +TEST_F(RedundancyEliminationTest, RemoveRedundantLocalAdd) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 +; CHECK: OpFAdd +; CHECK-NOT: OpFAdd + %11 = OpFAdd %5 %9 %9 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +// Remove a redundant add across basic blocks. 
+TEST_F(RedundancyEliminationTest, RemoveRedundantAdd) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 + OpBranch %11 + %11 = OpLabel +; CHECK: OpFAdd +; CHECK-NOT: OpFAdd + %12 = OpFAdd %5 %9 %9 + OpReturn + OpFunctionEnd + )"; + SinglePassRunAndMatch(text, false); +} + +// Remove a redundant add going through multiple basic blocks. +TEST_F(RedundancyEliminationTest, RemoveRedundantAddDiamond) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %7 = OpTypeBool + %8 = OpConstantTrue %7 + %2 = OpFunction %3 None %4 + %9 = OpLabel + %10 = OpVariable %6 Function + %11 = OpLoad %5 %10 + %12 = OpFAdd %5 %11 %11 +; CHECK: OpFAdd +; CHECK-NOT: OpFAdd + OpBranchConditional %8 %13 %14 + %13 = OpLabel + OpBranch %15 + %14 = OpLabel + OpBranch %15 + %15 = OpLabel + %16 = OpFAdd %5 %11 %11 + OpReturn + OpFunctionEnd + + )"; + SinglePassRunAndMatch(text, false); +} + +// Remove a redundant add in a side node. 
+TEST_F(RedundancyEliminationTest, RemoveRedundantAddInSideNode) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %7 = OpTypeBool + %8 = OpConstantTrue %7 + %2 = OpFunction %3 None %4 + %9 = OpLabel + %10 = OpVariable %6 Function + %11 = OpLoad %5 %10 + %12 = OpFAdd %5 %11 %11 +; CHECK: OpFAdd +; CHECK-NOT: OpFAdd + OpBranchConditional %8 %13 %14 + %13 = OpLabel + OpBranch %15 + %14 = OpLabel + %16 = OpFAdd %5 %11 %11 + OpBranch %15 + %15 = OpLabel + OpReturn + OpFunctionEnd + + )"; + SinglePassRunAndMatch(text, false); +} + +// Remove a redundant add whose value is in the result of a phi node. +TEST_F(RedundancyEliminationTest, RemoveRedundantAddWithPhi) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %7 = OpTypeBool + %8 = OpConstantTrue %7 + %2 = OpFunction %3 None %4 + %9 = OpLabel + %10 = OpVariable %6 Function + %11 = OpLoad %5 %10 + OpBranchConditional %8 %13 %14 + %13 = OpLabel + %add1 = OpFAdd %5 %11 %11 +; CHECK: OpFAdd + OpBranch %15 + %14 = OpLabel + %add2 = OpFAdd %5 %11 %11 +; CHECK: OpFAdd + OpBranch %15 + %15 = OpLabel +; CHECK: OpPhi + %phi = OpPhi %5 %add1 %13 %add2 %14 +; CHECK-NOT: OpFAdd + %16 = OpFAdd %5 %11 %11 + OpReturn + OpFunctionEnd + + )"; + SinglePassRunAndMatch(text, false); +} + +// Keep the add because it is redundant on some paths, but not all paths. 
+TEST_F(RedundancyEliminationTest, KeepPartiallyRedundantAdd) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %7 = OpTypeBool + %8 = OpConstantTrue %7 + %2 = OpFunction %3 None %4 + %9 = OpLabel + %10 = OpVariable %6 Function + %11 = OpLoad %5 %10 + OpBranchConditional %8 %13 %14 + %13 = OpLabel + %add = OpFAdd %5 %11 %11 + OpBranch %15 + %14 = OpLabel + OpBranch %15 + %15 = OpLabel + %16 = OpFAdd %5 %11 %11 + OpReturn + OpFunctionEnd + + )"; + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +// Keep the add. Even if it is redundant on all paths, there is no single id +// whose definition dominates the add and contains the same value. 
+TEST_F(RedundancyEliminationTest, KeepRedundantAddWithoutPhi) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %7 = OpTypeBool + %8 = OpConstantTrue %7 + %2 = OpFunction %3 None %4 + %9 = OpLabel + %10 = OpVariable %6 Function + %11 = OpLoad %5 %10 + OpBranchConditional %8 %13 %14 + %13 = OpLabel + %add1 = OpFAdd %5 %11 %11 + OpBranch %15 + %14 = OpLabel + %add2 = OpFAdd %5 %11 %11 + OpBranch %15 + %15 = OpLabel + %16 = OpFAdd %5 %11 %11 + OpReturn + OpFunctionEnd + + )"; + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/register_liveness.cpp b/test/opt/register_liveness.cpp new file mode 100644 index 000000000..cb973d2e6 --- /dev/null +++ b/test/opt/register_liveness.cpp @@ -0,0 +1,1282 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <memory> +#include <string> +#include <unordered_set> +#include <vector> + +#include "gmock/gmock.h" +#include "source/opt/register_pressure.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using PassClassTest = PassTest<::testing::Test>; + +void CompareSets(const std::unordered_set<Instruction*>& computed, + const std::unordered_set<uint32_t>& expected) { + for (Instruction* insn : computed) { + EXPECT_TRUE(expected.count(insn->result_id())) + << "Unexpected instruction in live set: " << *insn; + } + EXPECT_EQ(computed.size(), expected.size()); +} + +/* +Generated from the following GLSL + +#version 330 +in vec4 BaseColor; +flat in int Count; +void main() +{ + vec4 color = BaseColor; + vec4 acc; + if (Count == 0) { + acc = color; + } + else { + acc = color + vec4(0,1,2,0); + } + gl_FragColor = acc + color; +} +*/ +TEST_F(PassClassTest, LivenessWithIf) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %11 %15 %32 + OpExecutionMode %4 OriginLowerLeft + OpSource GLSL 330 + OpName %4 "main" + OpName %11 "BaseColor" + OpName %15 "Count" + OpName %32 "gl_FragColor" + OpDecorate %11 Location 0 + OpDecorate %15 Flat + OpDecorate %15 Location 0 + OpDecorate %32 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %10 = OpTypePointer Input %7 + %11 = OpVariable %10 Input + %13 = OpTypeInt 32 1 + %14 = OpTypePointer Input %13 + %15 = OpVariable %14 Input + %17 = OpConstant %13 0 + %18 = OpTypeBool + %26 = OpConstant %6 0 + %27 = OpConstant %6 1 + %28 = OpConstant %6 2 + %29 = OpConstantComposite %7 %26 %27 %28 %26 + %31 = OpTypePointer Output %7 + %32 = OpVariable %31 Output + %4 = OpFunction %2 None %3 + %5 = OpLabel + %12 = OpLoad %7 %11 + %16 = OpLoad %13 %15 + %19 = 
OpIEqual %18 %16 %17 + OpSelectionMerge %21 None + OpBranchConditional %19 %20 %24 + %20 = OpLabel + OpBranch %21 + %24 = OpLabel + %30 = OpFAdd %7 %12 %29 + OpBranch %21 + %21 = OpLabel + %36 = OpPhi %7 %12 %20 %30 %24 + %35 = OpFAdd %7 %36 %12 + OpStore %32 %35 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function* f = &*module->begin(); + LivenessAnalysis* liveness_analysis = context->GetLivenessAnalysis(); + const RegisterLiveness* register_liveness = liveness_analysis->Get(f); + { + SCOPED_TRACE("Block 5"); + auto live_sets = register_liveness->Get(5); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 15, // %15 = OpVariable %14 Input + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_out_, live_out); + } + { + SCOPED_TRACE("Block 20"); + auto live_sets = register_liveness->Get(20); + std::unordered_set live_inout{ + 12, // %12 = OpLoad %7 %11 + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + } + { + SCOPED_TRACE("Block 24"); + auto live_sets = register_liveness->Get(24); + std::unordered_set live_in{ + 12, // %12 = OpLoad %7 %11 + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 30, // %30 = OpFAdd %7 %12 %29 + 32, // %32 = OpVariable %31 Output + }; + CompareSets(live_sets->live_out_, live_out); + } + { + SCOPED_TRACE("Block 21"); + auto live_sets = register_liveness->Get(21); + std::unordered_set live_in{ + 12, // %12 = OpLoad %7 %11 + 32, 
// %32 = OpVariable %31 Output + 36, // %36 = OpPhi %7 %12 %20 %30 %24 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{}; + CompareSets(live_sets->live_out_, live_out); + } +} + +/* +Generated from the following GLSL +#version 330 +in vec4 bigColor; +in vec4 BaseColor; +in float f; +flat in int Count; +flat in uvec4 v4; +void main() +{ + vec4 color = BaseColor; + for (int i = 0; i < Count; ++i) + color += bigColor; + float sum = 0.0; + for (int i = 0; i < 4; ++i) { + float acc = 0.0; + if (sum == 0.0) { + acc = v4[i]; + } + else { + acc = BaseColor[i]; + } + sum += acc + v4[i]; + } + vec4 tv4; + for (int i = 0; i < 4; ++i) + tv4[i] = v4[i] * 4u; + color += vec4(sum) + tv4; + vec4 r; + r.xyz = BaseColor.xyz; + for (int i = 0; i < Count; ++i) + r.w = f; + color.xyz += r.xyz; + for (int i = 0; i < 16; i += 4) + for (int j = 0; j < 4; j++) + color *= f; + gl_FragColor = color + tv4; +} +*/ +TEST_F(PassClassTest, RegisterLiveness) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %11 %24 %28 %55 %124 %176 + OpExecutionMode %4 OriginLowerLeft + OpSource GLSL 330 + OpName %4 "main" + OpName %11 "BaseColor" + OpName %24 "Count" + OpName %28 "bigColor" + OpName %55 "v4" + OpName %84 "tv4" + OpName %124 "f" + OpName %176 "gl_FragColor" + OpDecorate %11 Location 0 + OpDecorate %24 Flat + OpDecorate %24 Location 0 + OpDecorate %28 Location 0 + OpDecorate %55 Flat + OpDecorate %55 Location 0 + OpDecorate %124 Location 0 + OpDecorate %176 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypePointer Function %7 + %10 = OpTypePointer Input %7 + %11 = OpVariable %10 Input + %13 = OpTypeInt 32 1 + %16 = OpConstant %13 0 + %23 = OpTypePointer Input %13 + %24 = OpVariable %23 Input + %26 = OpTypeBool + %28 = OpVariable %10 Input + %33 = OpConstant %13 1 + %35 = OpTypePointer Function 
%6 + %37 = OpConstant %6 0 + %45 = OpConstant %13 4 + %52 = OpTypeInt 32 0 + %53 = OpTypeVector %52 4 + %54 = OpTypePointer Input %53 + %55 = OpVariable %54 Input + %57 = OpTypePointer Input %52 + %63 = OpTypePointer Input %6 + %89 = OpConstant %52 4 + %102 = OpTypeVector %6 3 + %124 = OpVariable %63 Input + %158 = OpConstant %13 16 + %175 = OpTypePointer Output %7 + %176 = OpVariable %175 Output + %195 = OpUndef %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %84 = OpVariable %8 Function + %12 = OpLoad %7 %11 + OpBranch %17 + %17 = OpLabel + %191 = OpPhi %7 %12 %5 %31 %18 + %184 = OpPhi %13 %16 %5 %34 %18 + %25 = OpLoad %13 %24 + %27 = OpSLessThan %26 %184 %25 + OpLoopMerge %19 %18 None + OpBranchConditional %27 %18 %19 + %18 = OpLabel + %29 = OpLoad %7 %28 + %31 = OpFAdd %7 %191 %29 + %34 = OpIAdd %13 %184 %33 + OpBranch %17 + %19 = OpLabel + OpBranch %39 + %39 = OpLabel + %188 = OpPhi %6 %37 %19 %73 %51 + %185 = OpPhi %13 %16 %19 %75 %51 + %46 = OpSLessThan %26 %185 %45 + OpLoopMerge %41 %51 None + OpBranchConditional %46 %40 %41 + %40 = OpLabel + %49 = OpFOrdEqual %26 %188 %37 + OpSelectionMerge %51 None + OpBranchConditional %49 %50 %61 + %50 = OpLabel + %58 = OpAccessChain %57 %55 %185 + %59 = OpLoad %52 %58 + %60 = OpConvertUToF %6 %59 + OpBranch %51 + %61 = OpLabel + %64 = OpAccessChain %63 %11 %185 + %65 = OpLoad %6 %64 + OpBranch %51 + %51 = OpLabel + %210 = OpPhi %6 %60 %50 %65 %61 + %68 = OpAccessChain %57 %55 %185 + %69 = OpLoad %52 %68 + %70 = OpConvertUToF %6 %69 + %71 = OpFAdd %6 %210 %70 + %73 = OpFAdd %6 %188 %71 + %75 = OpIAdd %13 %185 %33 + OpBranch %39 + %41 = OpLabel + OpBranch %77 + %77 = OpLabel + %186 = OpPhi %13 %16 %41 %94 %78 + %83 = OpSLessThan %26 %186 %45 + OpLoopMerge %79 %78 None + OpBranchConditional %83 %78 %79 + %78 = OpLabel + %87 = OpAccessChain %57 %55 %186 + %88 = OpLoad %52 %87 + %90 = OpIMul %52 %88 %89 + %91 = OpConvertUToF %6 %90 + %92 = OpAccessChain %35 %84 %186 + OpStore %92 %91 + %94 = OpIAdd %13 %186 %33 + OpBranch 
%77 + %79 = OpLabel + %96 = OpCompositeConstruct %7 %188 %188 %188 %188 + %97 = OpLoad %7 %84 + %98 = OpFAdd %7 %96 %97 + %100 = OpFAdd %7 %191 %98 + %104 = OpVectorShuffle %102 %12 %12 0 1 2 + %106 = OpVectorShuffle %7 %195 %104 4 5 6 3 + OpBranch %108 + %108 = OpLabel + %197 = OpPhi %7 %106 %79 %208 %133 + %196 = OpPhi %13 %16 %79 %143 %133 + %115 = OpSLessThan %26 %196 %25 + OpLoopMerge %110 %133 None + OpBranchConditional %115 %109 %110 + %109 = OpLabel + OpBranch %117 + %117 = OpLabel + %209 = OpPhi %7 %197 %109 %181 %118 + %204 = OpPhi %13 %16 %109 %129 %118 + %123 = OpSLessThan %26 %204 %45 + OpLoopMerge %119 %118 None + OpBranchConditional %123 %118 %119 + %118 = OpLabel + %125 = OpLoad %6 %124 + %181 = OpCompositeInsert %7 %125 %209 3 + %129 = OpIAdd %13 %204 %33 + OpBranch %117 + %119 = OpLabel + OpBranch %131 + %131 = OpLabel + %208 = OpPhi %7 %209 %119 %183 %132 + %205 = OpPhi %13 %16 %119 %141 %132 + %137 = OpSLessThan %26 %205 %45 + OpLoopMerge %133 %132 None + OpBranchConditional %137 %132 %133 + %132 = OpLabel + %138 = OpLoad %6 %124 + %183 = OpCompositeInsert %7 %138 %208 3 + %141 = OpIAdd %13 %205 %33 + OpBranch %131 + %133 = OpLabel + %143 = OpIAdd %13 %196 %33 + OpBranch %108 + %110 = OpLabel + %145 = OpVectorShuffle %102 %197 %197 0 1 2 + %147 = OpVectorShuffle %102 %100 %100 0 1 2 + %148 = OpFAdd %102 %147 %145 + %150 = OpVectorShuffle %7 %100 %148 4 5 6 3 + OpBranch %152 + %152 = OpLabel + %200 = OpPhi %7 %150 %110 %203 %163 + %199 = OpPhi %13 %16 %110 %174 %163 + %159 = OpSLessThan %26 %199 %158 + OpLoopMerge %154 %163 None + OpBranchConditional %159 %153 %154 + %153 = OpLabel + OpBranch %161 + %161 = OpLabel + %203 = OpPhi %7 %200 %153 %170 %162 + %201 = OpPhi %13 %16 %153 %172 %162 + %167 = OpSLessThan %26 %201 %45 + OpLoopMerge %163 %162 None + OpBranchConditional %167 %162 %163 + %162 = OpLabel + %168 = OpLoad %6 %124 + %170 = OpVectorTimesScalar %7 %203 %168 + %172 = OpIAdd %13 %201 %33 + OpBranch %161 + %163 = OpLabel + %174 = OpIAdd 
%13 %199 %45 + OpBranch %152 + %154 = OpLabel + %178 = OpLoad %7 %84 + %179 = OpFAdd %7 %200 %178 + OpStore %176 %179 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + Function* f = &*module->begin(); + LivenessAnalysis* liveness_analysis = context->GetLivenessAnalysis(); + const RegisterLiveness* register_liveness = liveness_analysis->Get(f); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + + { + SCOPED_TRACE("Block 5"); + auto live_sets = register_liveness->Get(5); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 17"); + auto live_sets = register_liveness->Get(17); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 184, // %184 = OpPhi %13 %16 %5 %34 %18 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + 
std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 184, // %184 = OpPhi %13 %16 %5 %34 %18 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 11u); + } + { + SCOPED_TRACE("Block 18"); + auto live_sets = register_liveness->Get(18); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 184, // %184 = OpPhi %13 %16 %5 %34 %18 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 28, // %28 = OpVariable %10 Input + 31, // %31 = OpFAdd %7 %191 %29 + 34, // %34 = OpIAdd %13 %184 %33 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 12u); + } + { + SCOPED_TRACE("Block 19"); + auto live_sets = register_liveness->Get(19); + std::unordered_set live_inout{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + 
CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 39"); + auto live_sets = register_liveness->Get(39); + std::unordered_set live_inout{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 11u); + } + { + SCOPED_TRACE("Block 40"); + auto live_sets = register_liveness->Get(40); + std::unordered_set live_inout{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 11u); + } + { + SCOPED_TRACE("Block 50"); + auto live_sets = register_liveness->Get(50); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad 
%7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 60, // %60 = OpConvertUToF %6 %59 + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 12u); + } + { + SCOPED_TRACE("Block 61"); + auto live_sets = register_liveness->Get(61); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 65, // %65 = OpLoad %6 %64 + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 12u); + } + { + SCOPED_TRACE("Block 51"); + auto live_sets = register_liveness->Get(51); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 
%31 %18 + 210, // %210 = OpPhi %6 %60 %50 %65 %61 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 73, // %73 = OpFAdd %6 %188 %71 + 75, // %75 = OpIAdd %13 %185 %33 + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 13u); + } + { + SCOPED_TRACE("Block 41"); + auto live_sets = register_liveness->Get(41); + std::unordered_set live_inout{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 77"); + auto live_sets = register_liveness->Get(77); + std::unordered_set live_inout{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 186, // %186 = OpPhi %13 %16 %41 %94 %78 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 10u); + } + { + SCOPED_TRACE("Block 78"); + auto live_sets = register_liveness->Get(78); + std::unordered_set live_in{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 
124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 186, // %186 = OpPhi %13 %16 %41 %94 %78 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 94, // %94 = OpIAdd %13 %186 %33 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 11u); + } + { + SCOPED_TRACE("Block 79"); + auto live_sets = register_liveness->Get(79); + std::unordered_set live_in{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 106, // %106 = OpVectorShuffle %7 %195 %104 4 5 6 3 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 9u); + } + { + SCOPED_TRACE("Block 108"); + auto live_sets = register_liveness->Get(108); + std::unordered_set live_in{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 197, // %197 = OpPhi %7 %106 %79 %208 %133 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 84, // %84 = OpVariable 
%8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 197, // %197 = OpPhi %7 %106 %79 %208 %133 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 109"); + auto live_sets = register_liveness->Get(109); + std::unordered_set live_inout{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 197, // %197 = OpPhi %7 %106 %79 %208 %133 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 7u); + } + { + SCOPED_TRACE("Block 117"); + auto live_sets = register_liveness->Get(117); + std::unordered_set live_inout{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 204, // %204 = OpPhi %13 %16 %109 %129 %118 + 209, // %209 = OpPhi %7 %197 %109 %181 %118 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 9u); + } + { + SCOPED_TRACE("Block 118"); + auto live_sets = register_liveness->Get(118); + std::unordered_set live_in{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 204, // %204 = OpPhi %13 %16 %109 %129 %118 + 209, // %209 = OpPhi %7 %197 %109 %181 %118 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 
= OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 129, // %129 = OpIAdd %13 %204 %33 + 176, // %176 = OpVariable %175 Output + 181, // %181 = OpCompositeInsert %7 %125 %209 3 + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 10u); + } + { + SCOPED_TRACE("Block 119"); + auto live_sets = register_liveness->Get(119); + std::unordered_set live_inout{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 209, // %209 = OpPhi %7 %197 %109 %181 %118 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 7u); + } + { + SCOPED_TRACE("Block 131"); + auto live_sets = register_liveness->Get(131); + std::unordered_set live_inout{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 205, // %205 = OpPhi %13 %16 %119 %141 %132 + 208, // %208 = OpPhi %7 %209 %119 %183 %132 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 9u); + } + { + SCOPED_TRACE("Block 132"); + auto live_sets = register_liveness->Get(132); + std::unordered_set live_in{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 205, // %205 = OpPhi %13 %16 %119 %141 %132 + 208, // %208 = OpPhi %7 %209 %119 %183 %132 + }; + CompareSets(live_sets->live_in_, live_in); + + 
std::unordered_set live_out{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 141, // %141 = OpIAdd %13 %205 %33 + 176, // %176 = OpVariable %175 Output + 183, // %183 = OpCompositeInsert %7 %138 %208 3 + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 10u); + } + { + SCOPED_TRACE("Block 133"); + auto live_sets = register_liveness->Get(133); + std::unordered_set live_in{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 196, // %196 = OpPhi %13 %16 %79 %143 %133 + 208, // %208 = OpPhi %7 %209 %119 %183 %132 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 25, // %25 = OpLoad %13 %24 + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 143, // %143 = OpIAdd %13 %196 %33 + 176, // %176 = OpVariable %175 Output + 208, // %208 = OpPhi %7 %209 %119 %183 %132 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 110"); + auto live_sets = register_liveness->Get(110); + std::unordered_set live_in{ + 84, // %84 = OpVariable %8 Function + 100, // %100 = OpFAdd %7 %191 %98 + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 197, // %197 = OpPhi %7 %106 %79 %208 %133 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 150, // %150 = OpVectorShuffle %7 %100 %148 4 5 6 3 + 176, // %176 = OpVariable %175 Output + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 7u); + } + { + SCOPED_TRACE("Block 152"); + auto live_sets = 
register_liveness->Get(152); + std::unordered_set live_inout{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 200, // %200 = OpPhi %7 %150 %110 %203 %163 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 6u); + } + { + SCOPED_TRACE("Block 153"); + auto live_sets = register_liveness->Get(153); + std::unordered_set live_inout{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 200, // %200 = OpPhi %7 %150 %110 %203 %163 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 5u); + } + { + SCOPED_TRACE("Block 161"); + auto live_sets = register_liveness->Get(161); + std::unordered_set live_inout{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 201, // %201 = OpPhi %13 %16 %153 %172 %162 + 203, // %203 = OpPhi %7 %200 %153 %170 %162 + }; + CompareSets(live_sets->live_in_, live_inout); + CompareSets(live_sets->live_out_, live_inout); + + EXPECT_EQ(live_sets->used_registers_, 7u); + } + { + SCOPED_TRACE("Block 162"); + auto live_sets = register_liveness->Get(162); + std::unordered_set live_in{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 201, // %201 = OpPhi %13 %16 %153 %172 %162 + 203, // %203 = OpPhi %7 %200 %153 %170 %162 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 170, // %170 = OpVectorTimesScalar %7 %203 %168 + 172, 
// %172 = OpIAdd %13 %201 %33 + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 8u); + } + { + SCOPED_TRACE("Block 163"); + auto live_sets = register_liveness->Get(163); + std::unordered_set live_in{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 199, // %199 = OpPhi %13 %16 %110 %174 %163 + 203, // %203 = OpPhi %7 %200 %153 %170 %162 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{ + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 174, // %174 = OpIAdd %13 %199 %45 + 176, // %176 = OpVariable %175 Output + 203, // %203 = OpPhi %7 %200 %153 %170 %162 + }; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 6u); + } + { + SCOPED_TRACE("Block 154"); + auto live_sets = register_liveness->Get(154); + std::unordered_set live_in{ + 84, // %84 = OpVariable %8 Function + 176, // %176 = OpVariable %175 Output + 200, // %200 = OpPhi %7 %150 %110 %203 %163 + }; + CompareSets(live_sets->live_in_, live_in); + + std::unordered_set live_out{}; + CompareSets(live_sets->live_out_, live_out); + + EXPECT_EQ(live_sets->used_registers_, 4u); + } + + { + SCOPED_TRACE("Compute loop pressure"); + RegisterLiveness::RegionRegisterLiveness loop_reg_pressure; + register_liveness->ComputeLoopRegisterPressure(*ld[39], &loop_reg_pressure); + // Generate(*context->cfg()->block(39), &loop_reg_pressure); + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; 
+ CompareSets(loop_reg_pressure.live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(loop_reg_pressure.live_out_, live_out); + + EXPECT_EQ(loop_reg_pressure.used_registers_, 13u); + } + + { + SCOPED_TRACE("Loop Fusion simulation"); + RegisterLiveness::RegionRegisterLiveness simulation_resut; + register_liveness->SimulateFusion(*ld[17], *ld[39], &simulation_resut); + + std::unordered_set live_in{ + 11, // %11 = OpVariable %10 Input + 12, // %12 = OpLoad %7 %11 + 24, // %24 = OpVariable %23 Input + 25, // %25 = OpLoad %13 %24 + 28, // %28 = OpVariable %10 Input + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 184, // %184 = OpPhi %13 %16 %5 %34 %18 + 185, // %185 = OpPhi %13 %16 %19 %75 %51 + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(simulation_resut.live_in_, live_in); + + std::unordered_set live_out{ + 12, // %12 = OpLoad %7 %11 + 25, // %25 = OpLoad %13 %24 + 55, // %55 = OpVariable %54 Input + 84, // %84 = OpVariable %8 Function + 124, // %124 = OpVariable %63 Input + 176, // %176 = OpVariable %175 Output + 188, // %188 = OpPhi %6 %37 %19 %73 %51 + 191, // %191 = OpPhi %7 %12 %5 %31 %18 + }; + CompareSets(simulation_resut.live_out_, live_out); + + EXPECT_EQ(simulation_resut.used_registers_, 17u); + } +} + +TEST_F(PassClassTest, FissionSimulation) { + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "i" + 
OpName %4 "A" + OpName %5 "B" + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Function %16 + %18 = OpTypePointer Function %13 + %19 = OpConstant %8 1 + %2 = OpFunction %6 None %7 + %20 = OpLabel + %3 = OpVariable %9 Function + %4 = OpVariable %17 Function + %5 = OpVariable %17 Function + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %8 %10 %20 %23 %24 + OpLoopMerge %25 %24 None + OpBranch %26 + %26 = OpLabel + %27 = OpSLessThan %12 %22 %11 + OpBranchConditional %27 %28 %25 + %28 = OpLabel + %29 = OpAccessChain %18 %5 %22 + %30 = OpLoad %13 %29 + %31 = OpAccessChain %18 %4 %22 + OpStore %31 %30 + %32 = OpAccessChain %18 %4 %22 + %33 = OpLoad %13 %32 + %34 = OpAccessChain %18 %5 %22 + OpStore %34 %33 + OpBranch %24 + %24 = OpLabel + %23 = OpIAdd %8 %22 %19 + OpBranch %21 + %25 = OpLabel + OpStore %3 %22 + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, source, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << source << std::endl; + Function* f = &*module->begin(); + LivenessAnalysis* liveness_analysis = context->GetLivenessAnalysis(); + const RegisterLiveness* register_liveness = liveness_analysis->Get(f); + LoopDescriptor& ld = *context->GetLoopDescriptor(f); + analysis::DefUseManager& def_use_mgr = *context->get_def_use_mgr(); + + { + RegisterLiveness::RegionRegisterLiveness l1_sim_resut; + RegisterLiveness::RegionRegisterLiveness l2_sim_resut; + std::unordered_set moved_instructions{ + def_use_mgr.GetDef(29), def_use_mgr.GetDef(30), def_use_mgr.GetDef(31), + def_use_mgr.GetDef(31)->NextNode()}; + std::unordered_set copied_instructions{ + def_use_mgr.GetDef(22), 
def_use_mgr.GetDef(27), + def_use_mgr.GetDef(27)->NextNode(), def_use_mgr.GetDef(23)}; + + register_liveness->SimulateFission(*ld[21], moved_instructions, + copied_instructions, &l1_sim_resut, + &l2_sim_resut); + { + SCOPED_TRACE("L1 simulation"); + std::unordered_set live_in{ + 3, // %3 = OpVariable %9 Function + 4, // %4 = OpVariable %17 Function + 5, // %5 = OpVariable %17 Function + 22, // %22 = OpPhi %8 %10 %20 %23 %24 + }; + CompareSets(l1_sim_resut.live_in_, live_in); + + std::unordered_set live_out{ + 3, // %3 = OpVariable %9 Function + 4, // %4 = OpVariable %17 Function + 5, // %5 = OpVariable %17 Function + 22, // %22 = OpPhi %8 %10 %20 %23 %24 + }; + CompareSets(l1_sim_resut.live_out_, live_out); + + EXPECT_EQ(l1_sim_resut.used_registers_, 6u); + } + { + SCOPED_TRACE("L2 simulation"); + std::unordered_set live_in{ + 3, // %3 = OpVariable %9 Function + 4, // %4 = OpVariable %17 Function + 5, // %5 = OpVariable %17 Function + 22, // %22 = OpPhi %8 %10 %20 %23 %24 + }; + CompareSets(l2_sim_resut.live_in_, live_in); + + std::unordered_set live_out{ + 3, // %3 = OpVariable %9 Function + 22, // %22 = OpPhi %8 %10 %20 %23 %24 + }; + CompareSets(l2_sim_resut.live_out_, live_out); + + EXPECT_EQ(l2_sim_resut.used_registers_, 6u); + } + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/replace_invalid_opc_test.cpp b/test/opt/replace_invalid_opc_test.cpp new file mode 100644 index 000000000..1be904b4e --- /dev/null +++ b/test/opt/replace_invalid_opc_test.cpp @@ -0,0 +1,566 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "pass_utils.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ReplaceInvalidOpcodeTest = PassTest<::testing::Test>; + +TEST_F(ReplaceInvalidOpcodeTest, ReplaceInstruction) { + const std::string text = R"( +; CHECK: [[special_const:%\w+]] = OpConstant %float -6.2598534e+18 +; CHECK: [[constant:%\w+]] = OpConstantComposite %v4float [[special_const]] [[special_const]] [[special_const]] [[special_const]] +; CHECK-NOT: OpImageSampleImplicitLod +; CHECK: OpStore [[:%\w+]] [[constant]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %3 %gl_VertexIndex %5 + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main" + OpDecorate %3 Location 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %_struct_6 0 BuiltIn Position + OpDecorate %_struct_6 Block + %void = OpTypeVoid + %8 = OpTypeFunction %void + %float = OpTypeFloat 32 + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10 + %12 = OpTypeSampler +%_ptr_UniformConstant_12 = OpTypePointer UniformConstant %12 + %14 = OpTypeSampledImage %10 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %3 = OpVariable %_ptr_Output_v4float Output + %int 
= OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input + %_struct_6 = OpTypeStruct %v4float +%_ptr_Output__struct_6 = OpTypePointer Output %_struct_6 + %5 = OpVariable %_ptr_Output__struct_6 Output + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %23 = OpConstantComposite %v2float %float_0 %float_0 + %24 = OpVariable %_ptr_UniformConstant_10 UniformConstant + %25 = OpVariable %_ptr_UniformConstant_12 UniformConstant + %main = OpFunction %void None %8 + %26 = OpLabel + %27 = OpLoad %12 %25 + %28 = OpLoad %10 %24 + %29 = OpSampledImage %14 %28 %27 + %30 = OpImageSampleImplicitLod %v4float %29 %23 + %31 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %31 %30 + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(ReplaceInvalidOpcodeTest, ReplaceInstructionInNonEntryPoint) { + const std::string text = R"( +; CHECK: [[special_const:%\w+]] = OpConstant %float -6.2598534e+18 +; CHECK: [[constant:%\w+]] = OpConstantComposite %v4float [[special_const]] [[special_const]] [[special_const]] [[special_const]] +; CHECK-NOT: OpImageSampleImplicitLod +; CHECK: OpStore [[:%\w+]] [[constant]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %3 %gl_VertexIndex %5 + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main" + OpDecorate %3 Location 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %_struct_6 0 BuiltIn Position + OpDecorate %_struct_6 Block + %void = OpTypeVoid + %8 = OpTypeFunction %void + %float = OpTypeFloat 32 + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10 + %12 = OpTypeSampler +%_ptr_UniformConstant_12 = OpTypePointer UniformConstant %12 + %14 = OpTypeSampledImage %10 + %v4float = OpTypeVector %float 4 + %v2float 
= OpTypeVector %float 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %3 = OpVariable %_ptr_Output_v4float Output + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input + %_struct_6 = OpTypeStruct %v4float +%_ptr_Output__struct_6 = OpTypePointer Output %_struct_6 + %5 = OpVariable %_ptr_Output__struct_6 Output + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %23 = OpConstantComposite %v2float %float_0 %float_0 + %24 = OpVariable %_ptr_UniformConstant_10 UniformConstant + %25 = OpVariable %_ptr_UniformConstant_12 UniformConstant + %main = OpFunction %void None %8 + %26 = OpLabel + %27 = OpFunctionCall %void %28 + OpReturn + OpFunctionEnd + %28 = OpFunction %void None %8 + %29 = OpLabel + %30 = OpLoad %12 %25 + %31 = OpLoad %10 %24 + %32 = OpSampledImage %14 %31 %30 + %33 = OpImageSampleImplicitLod %v4float %32 %23 + %34 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %34 %33 + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(ReplaceInvalidOpcodeTest, ReplaceInstructionMultipleEntryPoints) { + const std::string text = R"( +; CHECK: [[special_const:%\w+]] = OpConstant %float -6.2598534e+18 +; CHECK: [[constant:%\w+]] = OpConstantComposite %v4float [[special_const]] [[special_const]] [[special_const]] [[special_const]] +; CHECK-NOT: OpImageSampleImplicitLod +; CHECK: OpStore [[:%\w+]] [[constant]] +; CHECK-NOT: OpImageSampleImplicitLod +; CHECK: OpStore [[:%\w+]] [[constant]] + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %3 %gl_VertexIndex %5 + OpEntryPoint Vertex %main2 "main2" %3 %gl_VertexIndex %5 + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main" + OpName %main2 "main2" + OpDecorate %3 Location 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + 
OpMemberDecorate %_struct_6 0 BuiltIn Position + OpDecorate %_struct_6 Block + %void = OpTypeVoid + %8 = OpTypeFunction %void + %float = OpTypeFloat 32 + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10 + %12 = OpTypeSampler +%_ptr_UniformConstant_12 = OpTypePointer UniformConstant %12 + %14 = OpTypeSampledImage %10 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %3 = OpVariable %_ptr_Output_v4float Output + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input + %_struct_6 = OpTypeStruct %v4float +%_ptr_Output__struct_6 = OpTypePointer Output %_struct_6 + %5 = OpVariable %_ptr_Output__struct_6 Output + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %23 = OpConstantComposite %v2float %float_0 %float_0 + %24 = OpVariable %_ptr_UniformConstant_10 UniformConstant + %25 = OpVariable %_ptr_UniformConstant_12 UniformConstant + %main = OpFunction %void None %8 + %26 = OpLabel + %27 = OpLoad %12 %25 + %28 = OpLoad %10 %24 + %29 = OpSampledImage %14 %28 %27 + %30 = OpImageSampleImplicitLod %v4float %29 %23 + %31 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %31 %30 + OpReturn + OpFunctionEnd + %main2 = OpFunction %void None %8 + %46 = OpLabel + %47 = OpLoad %12 %25 + %48 = OpLoad %10 %24 + %49 = OpSampledImage %14 %48 %47 + %50 = OpImageSampleImplicitLod %v4float %49 %23 + %51 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %51 %50 + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(text, false); +} +TEST_F(ReplaceInvalidOpcodeTest, DontReplaceInstruction) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %3 %gl_VertexIndex %5 + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension 
"GL_ARB_shading_language_420pack" + OpName %main "main" + OpDecorate %3 Location 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %_struct_6 0 BuiltIn Position + OpDecorate %_struct_6 Block + %void = OpTypeVoid + %8 = OpTypeFunction %void + %float = OpTypeFloat 32 + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10 + %12 = OpTypeSampler +%_ptr_UniformConstant_12 = OpTypePointer UniformConstant %12 + %14 = OpTypeSampledImage %10 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %3 = OpVariable %_ptr_Output_v4float Output + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input + %_struct_6 = OpTypeStruct %v4float +%_ptr_Output__struct_6 = OpTypePointer Output %_struct_6 + %5 = OpVariable %_ptr_Output__struct_6 Output + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %23 = OpConstantComposite %v2float %float_0 %float_0 + %24 = OpVariable %_ptr_UniformConstant_10 UniformConstant + %25 = OpVariable %_ptr_UniformConstant_12 UniformConstant + %main = OpFunction %void None %8 + %26 = OpLabel + %27 = OpLoad %12 %25 + %28 = OpLoad %10 %24 + %29 = OpSampledImage %14 %28 %27 + %30 = OpImageSampleImplicitLod %v4float %29 %23 + %31 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %31 %30 + OpReturn + OpFunctionEnd)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(ReplaceInvalidOpcodeTest, MultipleEntryPointsDifferentStage) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %3 %gl_VertexIndex %5 + OpEntryPoint Fragment %main2 "main2" %3 %gl_VertexIndex %5 + OpSource GLSL 400 + OpSourceExtension 
"GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main" + OpName %main2 "main2" + OpDecorate %3 Location 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %_struct_6 0 BuiltIn Position + OpDecorate %_struct_6 Block + %void = OpTypeVoid + %8 = OpTypeFunction %void + %float = OpTypeFloat 32 + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10 + %12 = OpTypeSampler +%_ptr_UniformConstant_12 = OpTypePointer UniformConstant %12 + %14 = OpTypeSampledImage %10 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %3 = OpVariable %_ptr_Output_v4float Output + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input + %_struct_6 = OpTypeStruct %v4float +%_ptr_Output__struct_6 = OpTypePointer Output %_struct_6 + %5 = OpVariable %_ptr_Output__struct_6 Output + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %23 = OpConstantComposite %v2float %float_0 %float_0 + %24 = OpVariable %_ptr_UniformConstant_10 UniformConstant + %25 = OpVariable %_ptr_UniformConstant_12 UniformConstant + %main = OpFunction %void None %8 + %26 = OpLabel + %27 = OpLoad %12 %25 + %28 = OpLoad %10 %24 + %29 = OpSampledImage %14 %28 %27 + %30 = OpImageSampleImplicitLod %v4float %29 %23 + %31 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %31 %30 + OpReturn + OpFunctionEnd + %main2 = OpFunction %void None %8 + %46 = OpLabel + %47 = OpLoad %12 %25 + %48 = OpLoad %10 %24 + %49 = OpSampledImage %14 %48 %47 + %50 = OpImageSampleImplicitLod %v4float %49 %23 + %51 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %51 %50 + OpReturn + OpFunctionEnd)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); 
+} + +TEST_F(ReplaceInvalidOpcodeTest, DontReplaceLinkage) { + const std::string text = R"( + OpCapability Shader + OpCapability Linkage + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %3 %gl_VertexIndex %5 + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main" + OpDecorate %3 Location 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %_struct_6 0 BuiltIn Position + OpDecorate %_struct_6 Block + %void = OpTypeVoid + %8 = OpTypeFunction %void + %float = OpTypeFloat 32 + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10 + %12 = OpTypeSampler +%_ptr_UniformConstant_12 = OpTypePointer UniformConstant %12 + %14 = OpTypeSampledImage %10 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %3 = OpVariable %_ptr_Output_v4float Output + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input + %_struct_6 = OpTypeStruct %v4float +%_ptr_Output__struct_6 = OpTypePointer Output %_struct_6 + %5 = OpVariable %_ptr_Output__struct_6 Output + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %23 = OpConstantComposite %v2float %float_0 %float_0 + %24 = OpVariable %_ptr_UniformConstant_10 UniformConstant + %25 = OpVariable %_ptr_UniformConstant_12 UniformConstant + %main = OpFunction %void None %8 + %26 = OpLabel + %27 = OpLoad %12 %25 + %28 = OpLoad %10 %24 + %29 = OpSampledImage %14 %28 %27 + %30 = OpImageSampleImplicitLod %v4float %29 %23 + %31 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %31 %30 + OpReturn + OpFunctionEnd)"; + + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} 
+ +TEST_F(ReplaceInvalidOpcodeTest, BarrierDontReplace) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%uint_264 = OpConstant %uint 264 + %main = OpFunction %void None %3 + %5 = OpLabel + OpControlBarrier %uint_2 %uint_2 %uint_264 + OpReturn + OpFunctionEnd)"; + + auto result = SinglePassRunAndDisassemble<ReplaceInvalidOpcodePass>( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +TEST_F(ReplaceInvalidOpcodeTest, BarrierReplace) { + const std::string text = R"( +; CHECK-NOT: OpControlBarrier + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%uint_264 = OpConstant %uint 264 + %main = OpFunction %void None %3 + %5 = OpLabel + OpControlBarrier %uint_2 %uint_2 %uint_264 + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch<ReplaceInvalidOpcodePass>(text, false); +} + +TEST_F(ReplaceInvalidOpcodeTest, MessageTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %3 %gl_VertexIndex %5 + OpSource GLSL 400 + %6 = OpString "test.hlsl" + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main"
+ OpDecorate %3 Location 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %_struct_7 0 BuiltIn Position + OpDecorate %_struct_7 Block + %void = OpTypeVoid + %9 = OpTypeFunction %void + %float = OpTypeFloat 32 + %11 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_11 = OpTypePointer UniformConstant %11 + %13 = OpTypeSampler +%_ptr_UniformConstant_13 = OpTypePointer UniformConstant %13 + %15 = OpTypeSampledImage %11 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %3 = OpVariable %_ptr_Output_v4float Output + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input + %_struct_7 = OpTypeStruct %v4float +%_ptr_Output__struct_7 = OpTypePointer Output %_struct_7 + %5 = OpVariable %_ptr_Output__struct_7 Output + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %24 = OpConstantComposite %v2float %float_0 %float_0 + %25 = OpVariable %_ptr_UniformConstant_11 UniformConstant + %26 = OpVariable %_ptr_UniformConstant_13 UniformConstant + %main = OpFunction %void None %9 + %27 = OpLabel + OpLine %6 2 4 + %28 = OpLoad %13 %26 + %29 = OpLoad %11 %25 + %30 = OpSampledImage %15 %29 %28 + %31 = OpImageSampleImplicitLod %v4float %30 %24 + %32 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %32 %31 + OpReturn + OpFunctionEnd)"; + + std::vector<Message> messages = { + {SPV_MSG_WARNING, "test.hlsl", 2, 4, + "Removing ImageSampleImplicitLod instruction because of incompatible " + "execution model."}}; + SetMessageConsumer(GetTestMessageConsumer(messages)); + auto result = SinglePassRunAndDisassemble<ReplaceInvalidOpcodePass>( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); +} + +TEST_F(ReplaceInvalidOpcodeTest, MultipleMessageTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical
GLSL450 + OpEntryPoint Vertex %main "main" %3 %gl_VertexIndex %5 + OpSource GLSL 400 + %6 = OpString "test.hlsl" + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main" + OpDecorate %3 Location 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %_struct_7 0 BuiltIn Position + OpDecorate %_struct_7 Block + %void = OpTypeVoid + %9 = OpTypeFunction %void + %float = OpTypeFloat 32 + %11 = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_11 = OpTypePointer UniformConstant %11 + %13 = OpTypeSampler +%_ptr_UniformConstant_13 = OpTypePointer UniformConstant %13 + %15 = OpTypeSampledImage %11 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %3 = OpVariable %_ptr_Output_v4float Output + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input + %_struct_7 = OpTypeStruct %v4float +%_ptr_Output__struct_7 = OpTypePointer Output %_struct_7 + %5 = OpVariable %_ptr_Output__struct_7 Output + %int_0 = OpConstant %int 0 + %float_0 = OpConstant %float 0 + %24 = OpConstantComposite %v2float %float_0 %float_0 + %25 = OpVariable %_ptr_UniformConstant_11 UniformConstant + %26 = OpVariable %_ptr_UniformConstant_13 UniformConstant + %main = OpFunction %void None %9 + %27 = OpLabel + OpLine %6 2 4 + %28 = OpLoad %13 %26 + %29 = OpLoad %11 %25 + %30 = OpSampledImage %15 %29 %28 + %31 = OpImageSampleImplicitLod %v4float %30 %24 + OpLine %6 12 4 + %41 = OpImageSampleProjImplicitLod %v4float %30 %24 + %32 = OpAccessChain %_ptr_Output_v4float %5 %int_0 + OpStore %32 %31 + OpReturn + OpFunctionEnd)"; + + std::vector<Message> messages = { + {SPV_MSG_WARNING, "test.hlsl", 2, 4, + "Removing ImageSampleImplicitLod instruction because of incompatible " + "execution model."}, + {SPV_MSG_WARNING, "test.hlsl", 12, 4, + "Removing ImageSampleProjImplicitLod instruction because
of " + "incompatible " + "execution model."}}; + SetMessageConsumer(GetTestMessageConsumer(messages)); + auto result = SinglePassRunAndDisassemble<ReplaceInvalidOpcodePass>( + text, /* skip_nop = */ true, /* do_validation = */ false); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/scalar_analysis.cpp b/test/opt/scalar_analysis.cpp new file mode 100644 index 000000000..598d8c7b7 --- /dev/null +++ b/test/opt/scalar_analysis.cpp @@ -0,0 +1,1221 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include <memory> +#include <string> +#include <unordered_set> +#include <vector> + +#include "gmock/gmock.h" +#include "source/opt/iterator.h" +#include "source/opt/loop_descriptor.h" +#include "source/opt/pass.h" +#include "source/opt/scalar_analysis.h" +#include "source/opt/tree_iterator.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/function_utils.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::UnorderedElementsAre; +using ScalarAnalysisTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 410 core +layout (location = 1) out float array[10]; +void main() { + for (int i = 0; i < 10; ++i) { + array[i] = array[i+1]; + } +} +*/ +TEST_F(ScalarAnalysisTest, BasicEvolutionTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %24 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %24 "array" + OpDecorate %24 Location 1 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 10 + %17 = OpTypeBool + %19 = OpTypeFloat 32 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 10 + %22 = OpTypeArray %19 %21 + %23 = OpTypePointer Output %22 + %24 = OpVariable %23 Output + %27 = OpConstant %6 1 + %29 = OpTypePointer Output %19 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %35 = OpPhi %6 %9 %5 %34 %13 + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %35 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %28 = OpIAdd %6 %35 %27 + %30 = OpAccessChain %29 %24 %28 + %31 = OpLoad %19 %30 + %32 = OpAccessChain %29 %24 %35 + OpStore %32 %31 + OpBranch %13 + %13 = OpLabel + %34 = OpIAdd %6 %35 %27 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )";
+ // clang-format on + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + ScalarEvolutionAnalysis analysis{context.get()}; + + const Instruction* store = nullptr; + const Instruction* load = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) { + if (inst.opcode() == SpvOp::SpvOpStore) { + store = &inst; + } + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(load, nullptr); + EXPECT_NE(store, nullptr); + + Instruction* access_chain = + context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0)); + + Instruction* child = context->get_def_use_mgr()->GetDef( + access_chain->GetSingleWordInOperand(1)); + const SENode* node = analysis.AnalyzeInstruction(child); + + EXPECT_NE(node, nullptr); + + // Unsimplified node should have the form of ADD(REC(0,1), 1) + EXPECT_EQ(node->GetType(), SENode::Add); + + const SENode* child_1 = node->GetChild(0); + EXPECT_TRUE(child_1->GetType() == SENode::Constant || + child_1->GetType() == SENode::RecurrentAddExpr); + + const SENode* child_2 = node->GetChild(1); + EXPECT_TRUE(child_2->GetType() == SENode::Constant || + child_2->GetType() == SENode::RecurrentAddExpr); + + SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node)); + // Simplified should be in the form of REC(1,1) + EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr); + + EXPECT_EQ(simplified->GetChild(0)->GetType(), SENode::Constant); + EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(), + 1); + + EXPECT_EQ(simplified->GetChild(1)->GetType(), SENode::Constant); + EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(), + 1); + + EXPECT_EQ(simplified->GetChild(0), simplified->GetChild(1)); +} + +/*
+Generated from the following GLSL + --eliminate-local-multi-store + +#version 410 core +layout (location = 1) out float array[10]; +layout (location = 2) flat in int loop_invariant; +void main() { + for (int i = 0; i < 10; ++i) { + array[i] = array[i+loop_invariant]; + } +} + +*/ +TEST_F(ScalarAnalysisTest, LoadTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 %4 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "array" + OpName %4 "loop_invariant" + OpDecorate %3 Location 1 + OpDecorate %4 Flat + OpDecorate %4 Location 2 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 0 + %10 = OpConstant %7 10 + %11 = OpTypeBool + %12 = OpTypeFloat 32 + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 10 + %15 = OpTypeArray %12 %14 + %16 = OpTypePointer Output %15 + %3 = OpVariable %16 Output + %17 = OpTypePointer Input %7 + %4 = OpVariable %17 Input + %18 = OpTypePointer Output %12 + %19 = OpConstant %7 1 + %2 = OpFunction %5 None %6 + %20 = OpLabel + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %7 %9 %20 %23 %24 + OpLoopMerge %25 %24 None + OpBranch %26 + %26 = OpLabel + %27 = OpSLessThan %11 %22 %10 + OpBranchConditional %27 %28 %25 + %28 = OpLabel + %29 = OpLoad %7 %4 + %30 = OpIAdd %7 %22 %29 + %31 = OpAccessChain %18 %3 %30 + %32 = OpLoad %12 %31 + %33 = OpAccessChain %18 %3 %22 + OpStore %33 %32 + OpBranch %24 + %24 = OpLabel + %23 = OpIAdd %7 %22 %19 + OpBranch %21 + %25 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); +
ScalarEvolutionAnalysis analysis{context.get()}; + + const Instruction* load = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 28)) { + if (inst.opcode() == SpvOp::SpvOpLoad) { + load = &inst; + } + } + + EXPECT_NE(load, nullptr); + + Instruction* access_chain = + context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0)); + + Instruction* child = context->get_def_use_mgr()->GetDef( + access_chain->GetSingleWordInOperand(1)); + // const SENode* node = + // analysis.GetNodeFromInstruction(child->unique_id()); + + const SENode* node = analysis.AnalyzeInstruction(child); + + EXPECT_NE(node, nullptr); + + // Unsimplified node should have the form of ADD(REC(0,1), X) + EXPECT_EQ(node->GetType(), SENode::Add); + + const SENode* child_1 = node->GetChild(0); + EXPECT_TRUE(child_1->GetType() == SENode::ValueUnknown || + child_1->GetType() == SENode::RecurrentAddExpr); + + const SENode* child_2 = node->GetChild(1); + EXPECT_TRUE(child_2->GetType() == SENode::ValueUnknown || + child_2->GetType() == SENode::RecurrentAddExpr); + + SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node)); + EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr); + + const SERecurrentNode* rec = simplified->AsSERecurrentNode(); + + EXPECT_NE(rec->GetChild(0), rec->GetChild(1)); + + EXPECT_EQ(rec->GetOffset()->GetType(), SENode::ValueUnknown); + + EXPECT_EQ(rec->GetCoefficient()->GetType(), SENode::Constant); + EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 410 core +layout (location = 1) out float array[10]; +layout (location = 2) flat in int loop_invariant; +void main() { + array[0] = array[loop_invariant * 2 + 4 + 5 - 24 - loop_invariant - +loop_invariant+ 16 * 3]; +} + +*/ +TEST_F(ScalarAnalysisTest, SimplifySimple) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical
GLSL450 + OpEntryPoint Fragment %2 "main" %3 %4 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %3 "array" + OpName %4 "loop_invariant" + OpDecorate %3 Location 1 + OpDecorate %4 Flat + OpDecorate %4 Location 2 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeFloat 32 + %8 = OpTypeInt 32 0 + %9 = OpConstant %8 10 + %10 = OpTypeArray %7 %9 + %11 = OpTypePointer Output %10 + %3 = OpVariable %11 Output + %12 = OpTypeInt 32 1 + %13 = OpConstant %12 0 + %14 = OpTypePointer Input %12 + %4 = OpVariable %14 Input + %15 = OpConstant %12 2 + %16 = OpConstant %12 4 + %17 = OpConstant %12 5 + %18 = OpConstant %12 24 + %19 = OpConstant %12 48 + %20 = OpTypePointer Output %7 + %2 = OpFunction %5 None %6 + %21 = OpLabel + %22 = OpLoad %12 %4 + %23 = OpIMul %12 %22 %15 + %24 = OpIAdd %12 %23 %16 + %25 = OpIAdd %12 %24 %17 + %26 = OpISub %12 %25 %18 + %28 = OpISub %12 %26 %22 + %30 = OpISub %12 %28 %22 + %31 = OpIAdd %12 %30 %19 + %32 = OpAccessChain %20 %3 %31 + %33 = OpLoad %7 %32 + %34 = OpAccessChain %20 %3 %13 + OpStore %34 %33 + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; + + const Instruction* load = nullptr; + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) { + if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 33) { + load = &inst; + } + } + + EXPECT_NE(load, nullptr); + + Instruction* access_chain = + context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0)); + + Instruction* child = context->get_def_use_mgr()->GetDef( + access_chain->GetSingleWordInOperand(1)); + + const SENode* node = analysis.AnalyzeInstruction(child); + + //
Unsimplified is a very large graph with an add at the top. + EXPECT_NE(node, nullptr); + EXPECT_EQ(node->GetType(), SENode::Add); + + // Simplified node should resolve down to a constant expression as the loads + // will eliminate themselves. + SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node)); + + EXPECT_EQ(simplified->GetType(), SENode::Constant); + EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33u); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 410 core +layout(location = 0) in vec4 c; +layout (location = 1) out float array[10]; +void main() { + int N = int(c.x); + for (int i = 0; i < 10; ++i) { + array[i] = array[i]; + array[i] = array[i-1]; + array[i] = array[i+1]; + array[i+1] = array[i+1]; + array[i+N] = array[i+N]; + array[i] = array[i+N]; + } +} + +*/ +TEST_F(ScalarAnalysisTest, Simplify) { + const std::string text = R"( OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %12 %33 + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %8 "N" + OpName %12 "c" + OpName %19 "i" + OpName %33 "array" + OpDecorate %12 Location 0 + OpDecorate %33 Location 1 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpTypeFloat 32 + %10 = OpTypeVector %9 4 + %11 = OpTypePointer Input %10 + %12 = OpVariable %11 Input + %13 = OpTypeInt 32 0 + %14 = OpConstant %13 0 + %15 = OpTypePointer Input %9 + %20 = OpConstant %6 0 + %27 = OpConstant %6 10 + %28 = OpTypeBool + %30 = OpConstant %13 10 + %31 = OpTypeArray %9 %30 + %32 = OpTypePointer Output %31 + %33 = OpVariable %32 Output + %36 = OpTypePointer Output %9 + %42 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %16 = OpAccessChain %15 %12 %14 + %17 = OpLoad %9 %16 + %18 = OpConvertFToS %6 %17 + OpStore %8 %18 + OpStore %19
%20 + OpBranch %21 + %21 = OpLabel + %78 = OpPhi %6 %20 %5 %77 %24 + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %29 = OpSLessThan %28 %78 %27 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + %37 = OpAccessChain %36 %33 %78 + %38 = OpLoad %9 %37 + %39 = OpAccessChain %36 %33 %78 + OpStore %39 %38 + %43 = OpISub %6 %78 %42 + %44 = OpAccessChain %36 %33 %43 + %45 = OpLoad %9 %44 + %46 = OpAccessChain %36 %33 %78 + OpStore %46 %45 + %49 = OpIAdd %6 %78 %42 + %50 = OpAccessChain %36 %33 %49 + %51 = OpLoad %9 %50 + %52 = OpAccessChain %36 %33 %78 + OpStore %52 %51 + %54 = OpIAdd %6 %78 %42 + %56 = OpIAdd %6 %78 %42 + %57 = OpAccessChain %36 %33 %56 + %58 = OpLoad %9 %57 + %59 = OpAccessChain %36 %33 %54 + OpStore %59 %58 + %62 = OpIAdd %6 %78 %18 + %65 = OpIAdd %6 %78 %18 + %66 = OpAccessChain %36 %33 %65 + %67 = OpLoad %9 %66 + %68 = OpAccessChain %36 %33 %62 + OpStore %68 %67 + %72 = OpIAdd %6 %78 %18 + %73 = OpAccessChain %36 %33 %72 + %74 = OpLoad %9 %73 + %75 = OpAccessChain %36 %33 %78 + OpStore %75 %74 + OpBranch %24 + %24 = OpLabel + %77 = OpIAdd %6 %78 %42 + OpStore %19 %77 + OpBranch %21 + %23 = OpLabel + OpReturn + OpFunctionEnd +)"; + // clang-format on + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 4); + ScalarEvolutionAnalysis analysis{context.get()}; + + const Instruction* loads[6]; + const Instruction* stores[6]; + int load_count = 0; + int store_count = 0; + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) { + if (inst.opcode() == SpvOp::SpvOpLoad) { + loads[load_count] = &inst; + ++load_count; + } + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[store_count] = &inst; + ++store_count; + } + } + + EXPECT_EQ(load_count, 6); + EXPECT_EQ(store_count, 6); + +
Instruction* load_access_chain; + Instruction* store_access_chain; + Instruction* load_child; + Instruction* store_child; + SENode* load_node; + SENode* store_node; + SENode* subtract_node; + SENode* simplified_node; + + // Testing [i] - [i] == 0 + load_access_chain = + context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0)); + store_access_chain = + context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0)); + + load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + load_node = analysis.AnalyzeInstruction(load_child); + store_node = analysis.AnalyzeInstruction(store_child); + + subtract_node = analysis.CreateSubtraction(store_node, load_node); + simplified_node = analysis.SimplifyExpression(subtract_node); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); + EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u); + + // Testing [i] - [i-1] == 1 + load_access_chain = + context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0)); + store_access_chain = + context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0)); + + load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + load_node = analysis.AnalyzeInstruction(load_child); + store_node = analysis.AnalyzeInstruction(store_child); + + subtract_node = analysis.CreateSubtraction(store_node, load_node); + simplified_node = analysis.SimplifyExpression(subtract_node); + + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); + EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 1u); + + // Testing [i] - [i+1] == -1 + load_access_chain = + context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0)); + store_access_chain =
+ context->get_def_use_mgr()->GetDef(stores[2]->GetSingleWordInOperand(0)); + + load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + load_node = analysis.AnalyzeInstruction(load_child); + store_node = analysis.AnalyzeInstruction(store_child); + + subtract_node = analysis.CreateSubtraction(store_node, load_node); + simplified_node = analysis.SimplifyExpression(subtract_node); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); + EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), -1); + + // Testing [i+1] - [i+1] == 0 + load_access_chain = + context->get_def_use_mgr()->GetDef(loads[3]->GetSingleWordInOperand(0)); + store_access_chain = + context->get_def_use_mgr()->GetDef(stores[3]->GetSingleWordInOperand(0)); + + load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + load_node = analysis.AnalyzeInstruction(load_child); + store_node = analysis.AnalyzeInstruction(store_child); + + subtract_node = analysis.CreateSubtraction(store_node, load_node); + simplified_node = analysis.SimplifyExpression(subtract_node); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); + EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u); + + // Testing [i+N] - [i+N] == 0 + load_access_chain = + context->get_def_use_mgr()->GetDef(loads[4]->GetSingleWordInOperand(0)); + store_access_chain = + context->get_def_use_mgr()->GetDef(stores[4]->GetSingleWordInOperand(0)); + + load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + load_node = analysis.AnalyzeInstruction(load_child); + store_node =
analysis.AnalyzeInstruction(store_child); + + subtract_node = analysis.CreateSubtraction(store_node, load_node); + + simplified_node = analysis.SimplifyExpression(subtract_node); + EXPECT_EQ(simplified_node->GetType(), SENode::Constant); + EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u); + + // Testing [i] - [i+N] == -N + load_access_chain = + context->get_def_use_mgr()->GetDef(loads[5]->GetSingleWordInOperand(0)); + store_access_chain = + context->get_def_use_mgr()->GetDef(stores[5]->GetSingleWordInOperand(0)); + + load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + load_node = analysis.AnalyzeInstruction(load_child); + store_node = analysis.AnalyzeInstruction(store_child); + + subtract_node = analysis.CreateSubtraction(store_node, load_node); + simplified_node = analysis.SimplifyExpression(subtract_node); + EXPECT_EQ(simplified_node->GetType(), SENode::Negative); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 430 +layout(location = 1) out float array[10]; +layout(location = 2) flat in int loop_invariant; +void main(void) { + for (int i = 0; i < 10; ++i) { + array[i * 2 + i * 5] = array[i * i * 2]; + array[i * 2] = array[i * 5]; + } +} + +*/ + +TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 %4 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %5 "i" + OpName %3 "array" + OpName %4 "loop_invariant" + OpDecorate %3 Location 1 + OpDecorate %4 Flat + OpDecorate %4 Location 2 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 =
OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Output %16 + %3 = OpVariable %17 Output + %18 = OpConstant %8 2 + %19 = OpConstant %8 5 + %20 = OpTypePointer Output %13 + %21 = OpConstant %8 1 + %22 = OpTypePointer Input %8 + %4 = OpVariable %22 Input + %2 = OpFunction %6 None %7 + %23 = OpLabel + %5 = OpVariable %9 Function + OpStore %5 %10 + OpBranch %24 + %24 = OpLabel + %25 = OpPhi %8 %10 %23 %26 %27 + OpLoopMerge %28 %27 None + OpBranch %29 + %29 = OpLabel + %30 = OpSLessThan %12 %25 %11 + OpBranchConditional %30 %31 %28 + %31 = OpLabel + %32 = OpIMul %8 %25 %18 + %33 = OpIMul %8 %25 %19 + %34 = OpIAdd %8 %32 %33 + %35 = OpIMul %8 %25 %25 + %36 = OpIMul %8 %35 %18 + %37 = OpAccessChain %20 %3 %36 + %38 = OpLoad %13 %37 + %39 = OpAccessChain %20 %3 %34 + OpStore %39 %38 + %40 = OpIMul %8 %25 %18 + %41 = OpIMul %8 %25 %19 + %42 = OpAccessChain %20 %3 %41 + %43 = OpLoad %13 %42 + %44 = OpAccessChain %20 %3 %40 + OpStore %44 %43 + OpBranch %27 + %27 = OpLabel + %26 = OpIAdd %8 %25 %21 + OpStore %5 %26 + OpBranch %24 + %28 = OpLabel + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; + + const Instruction* loads[2] = {nullptr, nullptr}; + const Instruction* stores[2] = {nullptr, nullptr}; + int load_count = 0; + int store_count = 0; + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 31)) { + if (inst.opcode() == SpvOp::SpvOpLoad) { + loads[load_count] = &inst; + ++load_count; + } + if (inst.opcode() == SpvOp::SpvOpStore) { + stores[store_count] = &inst; + ++store_count; + } + } + + EXPECT_EQ(load_count, 2); + EXPECT_EQ(store_count, 2); + + 
Instruction* load_access_chain = + context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0)); + Instruction* store_access_chain = + context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0)); + + Instruction* load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + Instruction* store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + SENode* store_node = analysis.AnalyzeInstruction(store_child); + + SENode* store_simplified = analysis.SimplifyExpression(store_node); + + load_access_chain = + context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0)); + store_access_chain = + context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0)); + load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + SENode* second_store = + analysis.SimplifyExpression(analysis.AnalyzeInstruction(store_child)); + SENode* second_load = + analysis.SimplifyExpression(analysis.AnalyzeInstruction(load_child)); + SENode* combined_add = analysis.SimplifyExpression( + analysis.CreateAddNode(second_load, second_store)); + + // We're checking that the two recurrent expressions have been correctly + // folded. In store_simplified they will have been folded as the entire + // expression was simplified as one. In combined_add the two expressions have + // been simplified one after the other which means the recurrent expressions + // aren't exactly the same but should still be folded as they are with respect + // to the same loop. 
+ EXPECT_EQ(combined_add, store_simplified); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store + +#version 430 +void main(void) { + for (int i = 0; i < 10; --i) { + array[i] = array[i]; + } +} + +*/ + +TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 %4 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %5 "i" + OpName %3 "array" + OpName %4 "loop_invariant" + OpDecorate %3 Location 1 + OpDecorate %4 Flat + OpDecorate %4 Location 2 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Output %16 + %3 = OpVariable %17 Output + %18 = OpTypePointer Output %13 + %19 = OpConstant %8 1 + %20 = OpTypePointer Input %8 + %4 = OpVariable %20 Input + %2 = OpFunction %6 None %7 + %21 = OpLabel + %5 = OpVariable %9 Function + OpStore %5 %10 + OpBranch %22 + %22 = OpLabel + %23 = OpPhi %8 %10 %21 %24 %25 + OpLoopMerge %26 %25 None + OpBranch %27 + %27 = OpLabel + %28 = OpSLessThan %12 %23 %11 + OpBranchConditional %28 %29 %26 + %29 = OpLabel + %30 = OpAccessChain %18 %3 %23 + %31 = OpLoad %13 %30 + %32 = OpAccessChain %18 %3 %23 + OpStore %32 %31 + OpBranch %25 + %25 = OpLabel + %24 = OpISub %8 %23 %19 + OpStore %5 %24 + OpBranch %22 + %26 = OpLabel + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis 
analysis{context.get()}; + + const Instruction* loads[1] = {nullptr}; + int load_count = 0; + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) { + if (inst.opcode() == SpvOp::SpvOpLoad) { + loads[load_count] = &inst; + ++load_count; + } + } + + EXPECT_EQ(load_count, 1); + + Instruction* load_access_chain = + context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0)); + Instruction* load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + + SENode* load_node = analysis.AnalyzeInstruction(load_child); + + EXPECT_TRUE(load_node); + EXPECT_EQ(load_node->GetType(), SENode::RecurrentAddExpr); + EXPECT_TRUE(load_node->AsSERecurrentNode()); + + SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient(); + SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset(); + + EXPECT_EQ(child_1->GetType(), SENode::Constant); + EXPECT_EQ(child_2->GetType(), SENode::Constant); + + EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1); + EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0u); + + SERecurrentNode* load_simplified = + analysis.SimplifyExpression(load_node)->AsSERecurrentNode(); + + EXPECT_TRUE(load_simplified); + EXPECT_EQ(load_node, load_simplified); + + EXPECT_EQ(load_simplified->GetType(), SENode::RecurrentAddExpr); + EXPECT_TRUE(load_simplified->AsSERecurrentNode()); + + SENode* simplified_child_1 = + load_simplified->AsSERecurrentNode()->GetCoefficient(); + SENode* simplified_child_2 = + load_simplified->AsSERecurrentNode()->GetOffset(); + + EXPECT_EQ(child_1, simplified_child_1); + EXPECT_EQ(child_2, simplified_child_2); +} + +/* +Generated from the following GLSL + --eliminate-local-multi-store +NOTE(review): shader reconstructed from the SPIR-V in this test; confirm +against the original source. + +#version 430 +layout(location = 1) out float array[10]; +layout(location = 2) flat in int N; +void main(void) { + for (int i = 0; i < 10; --i) { + array[i + 2 * N] = array[i + N]; + array[2 * i + 2 * N + 1] = array[2 * i + N + 1]; + } +} + +*/ + +TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel 
Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 %4 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %5 "i" + OpName %3 "array" + OpName %4 "N" + OpDecorate %3 Location 1 + OpDecorate %4 Flat + OpDecorate %4 Location 2 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Output %16 + %3 = OpVariable %17 Output + %18 = OpConstant %8 2 + %19 = OpTypePointer Input %8 + %4 = OpVariable %19 Input + %20 = OpTypePointer Output %13 + %21 = OpConstant %8 1 + %2 = OpFunction %6 None %7 + %22 = OpLabel + %5 = OpVariable %9 Function + OpStore %5 %10 + OpBranch %23 + %23 = OpLabel + %24 = OpPhi %8 %10 %22 %25 %26 + OpLoopMerge %27 %26 None + OpBranch %28 + %28 = OpLabel + %29 = OpSLessThan %12 %24 %11 + OpBranchConditional %29 %30 %27 + %30 = OpLabel + %31 = OpLoad %8 %4 + %32 = OpIMul %8 %18 %31 + %33 = OpIAdd %8 %24 %32 + %35 = OpIAdd %8 %24 %31 + %36 = OpAccessChain %20 %3 %35 + %37 = OpLoad %13 %36 + %38 = OpAccessChain %20 %3 %33 + OpStore %38 %37 + %39 = OpIMul %8 %18 %24 + %41 = OpIMul %8 %18 %31 + %42 = OpIAdd %8 %39 %41 + %43 = OpIAdd %8 %42 %21 + %44 = OpIMul %8 %18 %24 + %46 = OpIAdd %8 %44 %31 + %47 = OpIAdd %8 %46 %21 + %48 = OpAccessChain %20 %3 %47 + %49 = OpLoad %13 %48 + %50 = OpAccessChain %20 %3 %43 + OpStore %50 %49 + OpBranch %26 + %26 = OpLabel + %25 = OpISub %8 %24 %21 + OpStore %5 %25 + OpBranch %23 + %27 = OpLabel + OpReturn + OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis 
analysis{context.get()}; + + std::vector loads{}; + std::vector stores{}; + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) { + if (inst.opcode() == SpvOp::SpvOpLoad) { + loads.push_back(&inst); + } + if (inst.opcode() == SpvOp::SpvOpStore) { + stores.push_back(&inst); + } + } + + EXPECT_EQ(loads.size(), 3u); + EXPECT_EQ(stores.size(), 2u); + { + Instruction* store_access_chain = context->get_def_use_mgr()->GetDef( + stores[0]->GetSingleWordInOperand(0)); + + Instruction* store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + + SENode* store_node = analysis.AnalyzeInstruction(store_child); + + SENode* store_simplified = analysis.SimplifyExpression(store_node); + + Instruction* load_access_chain = + context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0)); + + Instruction* load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + + SENode* load_node = analysis.AnalyzeInstruction(load_child); + + SENode* load_simplified = analysis.SimplifyExpression(load_node); + + SENode* difference = + analysis.CreateSubtraction(store_simplified, load_simplified); + + SENode* difference_simplified = analysis.SimplifyExpression(difference); + + // Check that (i + 2*N) - (i + N) turns into just N when both sides have + // already been simplified into a single recurrent expression. + EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown); + + // Check that the inverse, (i + N) - (i + 2*N), turns into -N. 
+ SENode* difference_inverse = analysis.SimplifyExpression( + analysis.CreateSubtraction(load_simplified, store_simplified)); + + EXPECT_EQ(difference_inverse->GetType(), SENode::Negative); + EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown); + EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified); + } + + { + Instruction* store_access_chain = context->get_def_use_mgr()->GetDef( + stores[1]->GetSingleWordInOperand(0)); + + Instruction* store_child = context->get_def_use_mgr()->GetDef( + store_access_chain->GetSingleWordInOperand(1)); + SENode* store_node = analysis.AnalyzeInstruction(store_child); + SENode* store_simplified = analysis.SimplifyExpression(store_node); + + Instruction* load_access_chain = + context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0)); + + Instruction* load_child = context->get_def_use_mgr()->GetDef( + load_access_chain->GetSingleWordInOperand(1)); + + SENode* load_node = analysis.AnalyzeInstruction(load_child); + + SENode* load_simplified = analysis.SimplifyExpression(load_node); + + SENode* difference = + analysis.CreateSubtraction(store_simplified, load_simplified); + SENode* difference_simplified = analysis.SimplifyExpression(difference); + + // Check that (2*i + 2*N + 1) - (2*i + N + 1) turns into just N when both + // sides have already been simplified into a single recurrent expression. + EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown); + + // Check that the inverse, (2*i + N + 1) - (2*i + 2*N + 1) turns into -N. 
+ SENode* difference_inverse = analysis.SimplifyExpression( + analysis.CreateSubtraction(load_simplified, store_simplified)); + + EXPECT_EQ(difference_inverse->GetType(), SENode::Negative); + EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown); + EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified); + } +} + +/* Generated from the following GLSL + --eliminate-local-multi-store + + #version 430 + layout(location = 1) out float array[10]; + layout(location = 2) flat in int N; + void main(void) { + int step = 0; + for (int i = 0; i < N; i += step) { + step++; + } + } +*/ +TEST_F(ScalarAnalysisTest, InductionWithVariantStep) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 %4 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + OpName %2 "main" + OpName %5 "step" + OpName %6 "i" + OpName %3 "N" + OpName %4 "array" + OpDecorate %3 Flat + OpDecorate %3 Location 2 + OpDecorate %4 Location 1 + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpTypePointer Input %9 + %3 = OpVariable %12 Input + %13 = OpTypeBool + %14 = OpConstant %9 1 + %15 = OpTypeFloat 32 + %16 = OpTypeInt 32 0 + %17 = OpConstant %16 10 + %18 = OpTypeArray %15 %17 + %19 = OpTypePointer Output %18 + %4 = OpVariable %19 Output + %2 = OpFunction %7 None %8 + %20 = OpLabel + %5 = OpVariable %10 Function + %6 = OpVariable %10 Function + OpStore %5 %11 + OpStore %6 %11 + OpBranch %21 + %21 = OpLabel + %22 = OpPhi %9 %11 %20 %23 %24 + %25 = OpPhi %9 %11 %20 %26 %24 + OpLoopMerge %27 %24 None + OpBranch %28 + %28 = OpLabel + %29 = OpLoad %9 %3 + %30 = OpSLessThan %13 %25 %29 + OpBranchConditional %30 %31 %27 + %31 = OpLabel + %23 = OpIAdd %9 %22 %14 + OpStore %5 %23 + OpBranch %24 + %24 = OpLabel + %26 = OpIAdd %9 %25 %23 + OpStore %6 %26 + OpBranch %21 + %27 = OpLabel + OpReturn + 
OpFunctionEnd + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const Function* f = spvtest::GetFunction(module, 2); + ScalarEvolutionAnalysis analysis{context.get()}; + + std::vector phis{}; + + for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) { + if (inst.opcode() == SpvOp::SpvOpPhi) { + phis.push_back(&inst); + } + } + + EXPECT_EQ(phis.size(), 2u); + SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]); + SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]); + phi_node_1->DumpDot(std::cout, true); + EXPECT_NE(phi_node_1, nullptr); + EXPECT_NE(phi_node_2, nullptr); + + EXPECT_EQ(phi_node_1->GetType(), SENode::RecurrentAddExpr); + EXPECT_EQ(phi_node_2->GetType(), SENode::CanNotCompute); + + SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1); + SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2); + + EXPECT_EQ(simplified_1->GetType(), SENode::RecurrentAddExpr); + EXPECT_EQ(simplified_2->GetType(), SENode::CanNotCompute); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp new file mode 100644 index 000000000..a53f09dfd --- /dev/null +++ b/test/opt/scalar_replacement_test.cpp @@ -0,0 +1,1625 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ScalarReplacementTest = PassTest<::testing::Test>; + +TEST_F(ScalarReplacementTest, SimpleStruct) { + const std::string text = R"( +; +; CHECK: [[struct:%\w+]] = OpTypeStruct [[elem:%\w+]] +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: [[elem_ptr:%\w+]] = OpTypePointer Function [[elem]] +; CHECK: OpConstantNull [[struct]] +; CHECK: [[null:%\w+]] = OpConstantNull [[elem]] +; CHECK-NOT: OpVariable [[struct_ptr]] +; CHECK: [[one:%\w+]] = OpVariable [[elem_ptr]] Function [[null]] +; CHECK-NEXT: [[two:%\w+]] = OpVariable [[elem_ptr]] Function [[null]] +; CHECK-NOT: OpVariable [[elem_ptr]] Function [[null]] +; CHECK-NOT: OpVariable [[struct_ptr]] +; CHECK-NOT: OpInBoundsAccessChain +; CHECK: [[l1:%\w+]] = OpLoad [[elem]] [[two]] +; CHECK-NOT: OpAccessChain +; CHECK: [[l2:%\w+]] = OpLoad [[elem]] [[one]] +; CHECK: OpIAdd [[elem]] [[l1]] [[l2]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "simple_struct" +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %2 %2 %2 %2 +%4 = OpTypePointer Function %3 +%5 = OpTypePointer Function %2 +%6 = OpTypeFunction %2 +%7 = OpConstantNull %3 +%8 = OpConstant %2 0 +%9 = OpConstant %2 1 +%10 = OpConstant %2 2 +%11 = OpConstant %2 3 +%12 = OpFunction %2 None %6 +%13 = OpLabel +%14 = OpVariable %4 Function %7 +%15 = OpInBoundsAccessChain %5 %14 %8 +%16 = OpLoad %2 %15 +%17 = OpAccessChain %5 %14 %10 +%18 = OpLoad %2 %17 +%19 = OpIAdd %2 %16 %18 +OpReturnValue %19 +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, StructInitialization) { + const std::string text = R"( +; +; CHECK: [[elem:%\w+]] = OpTypeInt 32 
0 +; CHECK: [[struct:%\w+]] = OpTypeStruct [[elem]] [[elem]] [[elem]] [[elem]] +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: [[elem_ptr:%\w+]] = OpTypePointer Function [[elem]] +; CHECK: [[zero:%\w+]] = OpConstant [[elem]] 0 +; CHECK: [[undef:%\w+]] = OpUndef [[elem]] +; CHECK: [[two:%\w+]] = OpConstant [[elem]] 2 +; CHECK: [[null:%\w+]] = OpConstantNull [[elem]] +; CHECK-NOT: OpVariable [[struct_ptr]] +; CHECK: OpVariable [[elem_ptr]] Function [[null]] +; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[two]] +; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]] +; CHECK-NEXT: OpVariable [[elem_ptr]] Function +; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[zero]] +; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "struct_init" +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %2 %2 %2 %2 +%4 = OpTypePointer Function %3 +%20 = OpTypePointer Function %2 +%6 = OpTypeFunction %1 +%7 = OpConstant %2 0 +%8 = OpUndef %2 +%9 = OpConstant %2 2 +%30 = OpConstant %2 1 +%31 = OpConstant %2 3 +%10 = OpConstantNull %2 +%11 = OpConstantComposite %3 %7 %8 %9 %10 +%12 = OpFunction %1 None %6 +%13 = OpLabel +%14 = OpVariable %4 Function %11 +%15 = OpAccessChain %20 %14 %7 +OpStore %15 %10 +%16 = OpAccessChain %20 %14 %9 +OpStore %16 %10 +%17 = OpAccessChain %20 %14 %30 +OpStore %17 %10 +%18 = OpAccessChain %20 %14 %31 +OpStore %18 %10 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, SpecConstantInitialization) { + const std::string text = R"( +; +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct:%\w+]] = OpTypeStruct [[int]] [[int]] +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: [[int_ptr:%\w+]] = OpTypePointer Function [[int]] +; CHECK: [[spec_comp:%\w+]] = OpSpecConstantComposite [[struct]] +; CHECK: [[ex0:%\w+]] = OpSpecConstantOp [[int]] CompositeExtract 
[[spec_comp]] 0 +; CHECK: [[ex1:%\w+]] = OpSpecConstantOp [[int]] CompositeExtract [[spec_comp]] 1 +; CHECK-NOT: OpVariable [[struct]] +; CHECK: OpVariable [[int_ptr]] Function [[ex1]] +; CHECK-NEXT: OpVariable [[int_ptr]] Function [[ex0]] +; CHECK-NOT: OpVariable [[struct]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "spec_const" +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %2 %2 +%4 = OpTypePointer Function %3 +%20 = OpTypePointer Function %2 +%5 = OpTypeFunction %1 +%6 = OpConstant %2 0 +%30 = OpConstant %2 1 +%7 = OpSpecConstant %2 0 +%8 = OpSpecConstantOp %2 IAdd %7 %7 +%9 = OpSpecConstantComposite %3 %7 %8 +%10 = OpFunction %1 None %5 +%11 = OpLabel +%12 = OpVariable %4 Function %9 +%13 = OpAccessChain %20 %12 %6 +%14 = OpLoad %2 %13 +%15 = OpAccessChain %20 %12 %30 +%16 = OpLoad %2 %15 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +// TODO(alanbaker): Re-enable when vector and matrix scalarization is supported. 
+// TEST_F(ScalarReplacementTest, VectorInitialization) { +// const std::string text = R"( +// ; +// ; CHECK: [[elem:%\w+]] = OpTypeInt 32 0 +// ; CHECK: [[vector:%\w+]] = OpTypeVector [[elem]] 4 +// ; CHECK: [[vector_ptr:%\w+]] = OpTypePointer Function [[vector]] +// ; CHECK: [[elem_ptr:%\w+]] = OpTypePointer Function [[elem]] +// ; CHECK: [[zero:%\w+]] = OpConstant [[elem]] 0 +// ; CHECK: [[undef:%\w+]] = OpUndef [[elem]] +// ; CHECK: [[two:%\w+]] = OpConstant [[elem]] 2 +// ; CHECK: [[null:%\w+]] = OpConstantNull [[elem]] +// ; CHECK-NOT: OpVariable [[vector_ptr]] +// ; CHECK: OpVariable [[elem_ptr]] Function [[zero]] +// ; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]] +// ; CHECK-NEXT: OpVariable [[elem_ptr]] Function +// ; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[two]] +// ; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[null]] +// ; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]] +// ; +// OpCapability Shader +// OpCapability Linkage +// OpMemoryModel Logical GLSL450 +// OpName %6 "vector_init" +// %1 = OpTypeVoid +// %2 = OpTypeInt 32 0 +// %3 = OpTypeVector %2 4 +// %4 = OpTypePointer Function %3 +// %20 = OpTypePointer Function %2 +// %6 = OpTypeFunction %1 +// %7 = OpConstant %2 0 +// %8 = OpUndef %2 +// %9 = OpConstant %2 2 +// %30 = OpConstant %2 1 +// %31 = OpConstant %2 3 +// %10 = OpConstantNull %2 +// %11 = OpConstantComposite %3 %10 %9 %8 %7 +// %12 = OpFunction %1 None %6 +// %13 = OpLabel +// %14 = OpVariable %4 Function %11 +// %15 = OpAccessChain %20 %14 %7 +// OpStore %15 %10 +// %16 = OpAccessChain %20 %14 %9 +// OpStore %16 %10 +// %17 = OpAccessChain %20 %14 %30 +// OpStore %17 %10 +// %18 = OpAccessChain %20 %14 %31 +// OpStore %18 %10 +// OpReturn +// OpFunctionEnd +// )"; +// +// SinglePassRunAndMatch(text, true); +// } +// +// TEST_F(ScalarReplacementTest, MatrixInitialization) { +// const std::string text = R"( +// ; +// ; CHECK: [[float:%\w+]] = OpTypeFloat 32 +// ; CHECK: [[vector:%\w+]] = OpTypeVector [[float]] 2 
+// ; CHECK: [[matrix:%\w+]] = OpTypeMatrix [[vector]] 2 +// ; CHECK: [[matrix_ptr:%\w+]] = OpTypePointer Function [[matrix]] +// ; CHECK: [[float_ptr:%\w+]] = OpTypePointer Function [[float]] +// ; CHECK: [[vec_ptr:%\w+]] = OpTypePointer Function [[vector]] +// ; CHECK: [[zerof:%\w+]] = OpConstant [[float]] 0 +// ; CHECK: [[onef:%\w+]] = OpConstant [[float]] 1 +// ; CHECK: [[one_zero:%\w+]] = OpConstantComposite [[vector]] [[onef]] +// [[zerof]] ; CHECK: [[zero_one:%\w+]] = OpConstantComposite [[vector]] +// [[zerof]] [[onef]] ; CHECK: [[const_mat:%\w+]] = OpConstantComposite +// [[matrix]] [[one_zero]] +// [[zero_one]] ; CHECK-NOT: OpVariable [[matrix]] ; CHECK-NOT: OpVariable +// [[vector]] Function [[one_zero]] ; CHECK: [[f1:%\w+]] = OpVariable +// [[float_ptr]] Function [[zerof]] ; CHECK-NEXT: [[f2:%\w+]] = OpVariable +// [[float_ptr]] Function [[onef]] ; CHECK-NEXT: [[vec_var:%\w+]] = OpVariable +// [[vec_ptr]] Function [[zero_one]] ; CHECK-NOT: OpVariable [[matrix]] ; +// CHECK-NOT: OpVariable [[vector]] Function [[one_zero]] +// ; +// OpCapability Shader +// OpCapability Linkage +// OpMemoryModel Logical GLSL450 +// OpName %7 "matrix_init" +// %1 = OpTypeVoid +// %2 = OpTypeFloat 32 +// %3 = OpTypeVector %2 2 +// %4 = OpTypeMatrix %3 2 +// %5 = OpTypePointer Function %4 +// %6 = OpTypePointer Function %2 +// %30 = OpTypePointer Function %3 +// %10 = OpTypeInt 32 0 +// %7 = OpTypeFunction %1 %10 +// %8 = OpConstant %2 0.0 +// %9 = OpConstant %2 1.0 +// %11 = OpConstant %10 0 +// %12 = OpConstant %10 1 +// %13 = OpConstantComposite %3 %9 %8 +// %14 = OpConstantComposite %3 %8 %9 +// %15 = OpConstantComposite %4 %13 %14 +// %16 = OpFunction %1 None %7 +// %31 = OpFunctionParameter %10 +// %17 = OpLabel +// %18 = OpVariable %5 Function %15 +// %19 = OpAccessChain %6 %18 %11 %12 +// OpStore %19 %8 +// %20 = OpAccessChain %6 %18 %11 %11 +// OpStore %20 %8 +// %21 = OpAccessChain %30 %18 %12 +// OpStore %21 %14 +// OpReturn +// OpFunctionEnd +// )"; +// +// 
SinglePassRunAndMatch(text, true); +// } + +TEST_F(ScalarReplacementTest, ElideAccessChain) { + const std::string text = R"( +; +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NOT: OpAccessChain +; CHECK: OpStore [[var]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "elide_access_chain" +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %2 %2 %2 %2 +%4 = OpTypePointer Function %3 +%20 = OpTypePointer Function %2 +%6 = OpTypeFunction %1 +%7 = OpConstant %2 0 +%8 = OpUndef %2 +%9 = OpConstant %2 2 +%10 = OpConstantNull %2 +%11 = OpConstantComposite %3 %7 %8 %9 %10 +%12 = OpFunction %1 None %6 +%13 = OpLabel +%14 = OpVariable %4 Function %11 +%15 = OpAccessChain %20 %14 %7 +OpStore %15 %10 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ElideMultipleAccessChains) { + const std::string text = R"( +; +; CHECK: [[var:%\w+]] = OpVariable +; CHECK-NOT: OpInBoundsAccessChain +; CHECK: OpStore [[var]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "elide_two_access_chains" +%1 = OpTypeVoid +%2 = OpTypeFloat 32 +%3 = OpTypeStruct %2 %2 +%4 = OpTypeStruct %3 %3 +%5 = OpTypePointer Function %4 +%6 = OpTypePointer Function %2 +%7 = OpTypeFunction %1 +%8 = OpConstant %2 0.0 +%9 = OpConstant %2 1.0 +%10 = OpTypeInt 32 0 +%11 = OpConstant %10 0 +%12 = OpConstant %10 1 +%13 = OpConstantComposite %3 %9 %8 +%14 = OpConstantComposite %3 %8 %9 +%15 = OpConstantComposite %4 %13 %14 +%16 = OpFunction %1 None %7 +%17 = OpLabel +%18 = OpVariable %5 Function %15 +%19 = OpInBoundsAccessChain %6 %18 %11 %12 +OpStore %19 %8 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ReplaceAccessChain) { + const std::string text = R"( +; +; CHECK: [[param:%\w+]] = OpFunctionParameter +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: [[access:%\w+]] = OpAccessChain {{%\w+}} [[var]] [[param]] +; CHECK: OpStore 
[[access]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %7 "replace_access_chain" +%1 = OpTypeVoid +%2 = OpTypeFloat 32 +%10 = OpTypeInt 32 0 +%uint_2 = OpConstant %10 2 +%3 = OpTypeArray %2 %uint_2 +%4 = OpTypeStruct %3 %3 +%5 = OpTypePointer Function %4 +%20 = OpTypePointer Function %3 +%6 = OpTypePointer Function %2 +%7 = OpTypeFunction %1 %10 +%8 = OpConstant %2 0.0 +%9 = OpConstant %2 1.0 +%11 = OpConstant %10 0 +%12 = OpConstant %10 1 +%13 = OpConstantComposite %3 %9 %8 +%14 = OpConstantComposite %3 %8 %9 +%15 = OpConstantComposite %4 %13 %14 +%16 = OpFunction %1 None %7 +%32 = OpFunctionParameter %10 +%17 = OpLabel +%18 = OpVariable %5 Function %15 +%19 = OpAccessChain %6 %18 %11 %32 +OpStore %19 %8 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ArrayInitialization) { + const std::string text = R"( +; +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[array:%\w+]] = OpTypeArray +; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]] +; CHECK: [[float_ptr:%\w+]] = OpTypePointer Function [[float]] +; CHECK: [[float0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[float1:%\w+]] = OpConstant [[float]] 1 +; CHECK: [[float2:%\w+]] = OpConstant [[float]] 2 +; CHECK-NOT: OpVariable [[array_ptr]] +; CHECK: [[var0:%\w+]] = OpVariable [[float_ptr]] Function [[float0]] +; CHECK-NEXT: [[var1:%\w+]] = OpVariable [[float_ptr]] Function [[float1]] +; CHECK-NEXT: [[var2:%\w+]] = OpVariable [[float_ptr]] Function [[float2]] +; CHECK-NOT: OpVariable [[array_ptr]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "array_init" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%uint_3 = OpConstant %uint 3 +%float_array = OpTypeArray %float %uint_3 +%array_ptr = OpTypePointer Function %float_array +%float_ptr = OpTypePointer 
Function %float +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%const_array = OpConstantComposite %float_array %float_2 %float_1 %float_0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%3 = OpVariable %array_ptr Function %const_array +%4 = OpInBoundsAccessChain %float_ptr %3 %uint_0 +OpStore %4 %float_0 +%5 = OpInBoundsAccessChain %float_ptr %3 %uint_1 +OpStore %5 %float_0 +%6 = OpInBoundsAccessChain %float_ptr %3 %uint_2 +OpStore %6 %float_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, NonUniformCompositeInitialization) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[long:%\w+]] = OpTypeInt 64 1 +; CHECK: [[dvector:%\w+]] = OpTypeVector +; CHECK: [[vector:%\w+]] = OpTypeVector +; CHECK: [[array:%\w+]] = OpTypeArray +; CHECK: [[matrix:%\w+]] = OpTypeMatrix +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[vector]] +; CHECK: [[struct2:%\w+]] = OpTypeStruct [[struct1]] [[matrix]] [[array]] [[uint]] +; CHECK: [[struct1_ptr:%\w+]] = OpTypePointer Function [[struct1]] +; CHECK: [[matrix_ptr:%\w+]] = OpTypePointer Function [[matrix]] +; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[struct2_ptr:%\w+]] = OpTypePointer Function [[struct2]] +; CHECK: [[const_array:%\w+]] = OpConstantComposite [[array]] +; CHECK: [[const_matrix:%\w+]] = OpConstantNull [[matrix]] +; CHECK: [[const_struct1:%\w+]] = OpConstantComposite [[struct1]] +; CHECK: OpConstantNull [[uint]] +; CHECK: OpConstantNull [[vector]] +; CHECK: OpConstantNull [[long]] +; CHECK: OpFunction +; CHECK-NOT: OpVariable [[struct2_ptr]] Function +; CHECK: OpVariable [[uint_ptr]] Function +; CHECK-NEXT: OpVariable [[matrix_ptr]] Function [[const_matrix]] +; CHECK-NOT: OpVariable [[struct1_ptr]] Function [[const_struct1]] +; CHECK-NOT: OpVariable 
[[struct2_ptr]] Function +; +OpCapability Shader +OpCapability Linkage +OpCapability Int64 +OpCapability Float64 +OpMemoryModel Logical GLSL450 +OpName %func "non_uniform_composite_init" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%int64 = OpTypeInt 64 1 +%float = OpTypeFloat 32 +%double = OpTypeFloat 64 +%double2 = OpTypeVector %double 2 +%float4 = OpTypeVector %float 4 +%int64_0 = OpConstant %int64 0 +%int64_1 = OpConstant %int64 1 +%int64_2 = OpConstant %int64 2 +%int64_3 = OpConstant %int64 3 +%int64_array3 = OpTypeArray %int64 %int64_3 +%matrix_double2 = OpTypeMatrix %double2 2 +%struct1 = OpTypeStruct %uint %float4 +%struct2 = OpTypeStruct %struct1 %matrix_double2 %int64_array3 %uint +%struct1_ptr = OpTypePointer Function %struct1 +%matrix_double2_ptr = OpTypePointer Function %matrix_double2 +%int64_array_ptr = OpTypePointer Function %int64_array3 +%uint_ptr = OpTypePointer Function %uint +%struct2_ptr = OpTypePointer Function %struct2 +%const_uint = OpConstant %uint 0 +%const_int64_array = OpConstantComposite %int64_array3 %int64_0 %int64_1 %int64_2 +%const_double2 = OpConstantNull %double2 +%const_matrix_double2 = OpConstantNull %matrix_double2 +%undef_float4 = OpUndef %float4 +%const_struct1 = OpConstantComposite %struct1 %const_uint %undef_float4 +%const_struct2 = OpConstantComposite %struct2 %const_struct1 %const_matrix_double2 %const_int64_array %const_uint +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct2_ptr Function %const_struct2 +%3 = OpAccessChain %struct1_ptr %var %int64_0 +OpStore %3 %const_struct1 +%4 = OpAccessChain %matrix_double2_ptr %var %int64_1 +OpStore %4 %const_matrix_double2 +%5 = OpAccessChain %int64_array_ptr %var %int64_2 +OpStore %5 %const_int64_array +%6 = OpAccessChain %uint_ptr %var %int64_3 +OpStore %6 %const_uint +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ElideUncombinedAccessChains) { + const std::string 
text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[var:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpAccessChain +; CHECK: OpStore [[var]] [[const]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "elide_uncombined_access_chains" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint +%struct2 = OpTypeStruct %struct1 +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%struct2_ptr = OpTypePointer Function %struct2 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct2_ptr Function +%3 = OpAccessChain %struct1_ptr %var %uint_0 +%4 = OpAccessChain %uint_ptr %3 %uint_0 +OpStore %4 %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ElideSingleUncombinedAccessChains) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[array:%\w+]] = OpTypeArray [[uint]] +; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[param:%\w+]] = OpFunctionParameter [[uint]] +; CHECK: [[var:%\w+]] = OpVariable [[array_ptr]] Function +; CHECK: [[access:%\w+]] = OpAccessChain {{.*}} [[var]] [[param]] +; CHECK: OpStore [[access]] [[const]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "elide_single_uncombined_access_chains" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%array = OpTypeArray %uint %uint_1 +%struct2 = OpTypeStruct %array +%uint_ptr = OpTypePointer Function %uint +%array_ptr = OpTypePointer Function %array +%struct2_ptr = OpTypePointer Function %struct2 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void %uint +%1 = 
OpFunction %void None %func +%param = OpFunctionParameter %uint +%2 = OpLabel +%var = OpVariable %struct2_ptr Function +%3 = OpAccessChain %array_ptr %var %uint_0 +%4 = OpAccessChain %uint_ptr %3 %param +OpStore %4 %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ReplaceWholeLoad) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[l0:%\w+]] = OpLoad [[uint]] [[var0]] +; CHECK: OpCompositeConstruct [[struct1]] [[l0]] [[l1]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%load = OpLoad %struct1 %var +%3 = OpAccessChain %uint_ptr %var %uint_0 +OpStore %3 %uint_0 +%4 = OpAccessChain %uint_ptr %var %uint_1 +OpStore %4 %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ReplaceWholeLoadCopyMemoryAccess) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[null:%\w+]] = OpConstantNull [[uint]] +; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: 
[[l0:%\w+]] = OpLoad [[uint]] [[var0]] Nontemporal +; CHECK: OpCompositeConstruct [[struct1]] [[l0]] [[null]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load_copy_memory_access" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%load = OpLoad %struct1 %var Nontemporal +%3 = OpAccessChain %uint_ptr %var %uint_0 +OpStore %3 %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ReplaceWholeStore) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[const_struct:%\w+]] = OpConstantComposite [[struct1]] [[const]] [[const]] +; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 0 +; CHECK: OpStore [[var0]] [[ex0]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_store" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%const_struct = OpConstantComposite %struct1 %uint_0 %uint_0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +OpStore %var %const_struct +%3 = OpAccessChain %uint_ptr %var %uint_0 +%4 = OpLoad %uint %3 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, 
ReplaceWholeStoreCopyMemoryAccess) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[const_struct:%\w+]] = OpConstantComposite [[struct1]] [[const]] [[const]] +; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpVariable +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 0 +; CHECK: OpStore [[var0]] [[ex0]] Aligned 4 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_store_copy_memory_access" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%const_struct = OpConstantComposite %struct1 %uint_0 %uint_0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +OpStore %var %const_struct Aligned 4 +%3 = OpAccessChain %uint_ptr %var %uint_0 +%4 = OpLoad %uint %3 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, DontTouchVolatileLoad) { + const std::string text = R"( +; +; CHECK: [[struct:%\w+]] = OpTypeStruct +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[struct_ptr]] +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "dont_touch_volatile_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%3 = OpAccessChain 
%uint_ptr %var %uint_0 +%4 = OpLoad %uint %3 Volatile +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, DontTouchVolatileStore) { + const std::string text = R"( +; +; CHECK: [[struct:%\w+]] = OpTypeStruct +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[struct_ptr]] +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "dont_touch_volatile_store" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%3 = OpAccessChain %uint_ptr %var %uint_0 +OpStore %3 %uint_0 Volatile +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, DontTouchSpecNonFunctionVariable) { + const std::string text = R"( +; +; CHECK: [[struct:%\w+]] = OpTypeStruct +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Uniform [[struct]] +; CHECK: OpConstant +; CHECK-NEXT: OpVariable [[struct_ptr]] +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "dont_touch_spec_constant_access_chain" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint +%uint_ptr = OpTypePointer Uniform %uint +%struct1_ptr = OpTypePointer Uniform %struct1 +%uint_0 = OpConstant %uint 0 +%var = OpVariable %struct1_ptr Uniform +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%3 = OpAccessChain %uint_ptr %var %uint_0 +OpStore %3 %uint_0 Volatile +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, DontTouchSpecConstantAccessChain) { + const std::string text = R"( +; +; CHECK: [[array:%\w+]] 
= OpTypeArray +; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[array_ptr]] +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "dont_touch_spec_constant_access_chain" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%array = OpTypeArray %uint %uint_1 +%uint_ptr = OpTypePointer Function %uint +%array_ptr = OpTypePointer Function %array +%uint_0 = OpConstant %uint 0 +%spec_const = OpSpecConstant %uint 0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %array_ptr Function +%3 = OpAccessChain %uint_ptr %var %spec_const +OpStore %3 %uint_0 Volatile +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, NoPartialAccesses) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: OpLabel +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "no_partial_accesses" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%const = OpConstantNull %struct1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +OpStore %var %const +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, DontTouchPtrAccessChain) { + const std::string text = R"( +; +; CHECK: [[struct:%\w+]] = OpTypeStruct +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[struct_ptr]] +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "dont_touch_ptr_access_chain" +%void = 
OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%3 = OpPtrAccessChain %uint_ptr %var %uint_0 %uint_0 +OpStore %3 %uint_0 +%4 = OpAccessChain %uint_ptr %var %uint_0 +OpStore %4 %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(ScalarReplacementTest, DontTouchInBoundsPtrAccessChain) { + const std::string text = R"( +; +; CHECK: [[struct:%\w+]] = OpTypeStruct +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[struct_ptr]] +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "dont_touch_in_bounds_ptr_access_chain" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%3 = OpInBoundsPtrAccessChain %uint_ptr %var %uint_0 %uint_0 +OpStore %3 %uint_0 +%4 = OpInBoundsAccessChain %uint_ptr %var %uint_0 +OpStore %4 %uint_0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(ScalarReplacementTest, DonTouchAliasedDecoration) { + const std::string text = R"( +; +; CHECK: [[struct:%\w+]] = OpTypeStruct +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[struct_ptr]] +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "aliased" +OpDecorate %var Aliased +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint +%uint_ptr = OpTypePointer 
Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%3 = OpAccessChain %uint_ptr %var %uint_0 +%4 = OpLoad %uint %3 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, CopyRestrictDecoration) { + const std::string text = R"( +; +; CHECK: OpName +; CHECK-NEXT: OpDecorate [[var0:%\w+]] Restrict +; CHECK-NEXT: OpDecorate [[var1:%\w+]] Restrict +; CHECK: [[int:%\w+]] = OpTypeInt +; CHECK: [[struct:%\w+]] = OpTypeStruct +; CHECK: [[int_ptr:%\w+]] = OpTypePointer Function [[int]] +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: OpLabel +; CHECK-NEXT: [[var1]] = OpVariable [[int_ptr]] +; CHECK-NEXT: [[var0]] = OpVariable [[int_ptr]] +; CHECK-NOT: OpVariable [[struct_ptr]] +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "restrict" +OpDecorate %var Restrict +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%3 = OpAccessChain %uint_ptr %var %uint_0 +%4 = OpLoad %uint %3 +%5 = OpAccessChain %uint_ptr %var %uint_1 +%6 = OpLoad %uint %5 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, DontClobberDecoratesOnSubtypes) { + const std::string text = R"( +; +; CHECK: OpDecorate [[array:%\w+]] ArrayStride 1 +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[array]] = OpTypeArray [[uint]] +; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[array_ptr]] Function +; CHECK-NOT: OpVariable +; 
+OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "array_stride" +OpDecorate %array ArrayStride 1 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%array = OpTypeArray %uint %uint_1 +%struct1 = OpTypeStruct %array +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void %uint +%1 = OpFunction %void None %func +%param = OpFunctionParameter %uint +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%3 = OpAccessChain %uint_ptr %var %uint_0 %param +%4 = OpLoad %uint %3 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, DontCopyMemberDecorate) { + const std::string text = R"( +; +; CHECK-NOT: OpDecorate +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct:%\w+]] = OpTypeStruct [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpVariable +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "member_decorate" +OpMemberDecorate %struct1 0 Offset 1 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%struct1 = OpTypeStruct %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%func = OpTypeFunction %void %uint +%1 = OpFunction %void None %func +%2 = OpLabel +%var = OpVariable %struct1_ptr Function +%3 = OpAccessChain %uint_ptr %var %uint_0 +%4 = OpLoad %uint %3 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, NoPartialAccesses2) { + const std::string text = R"( +; +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_ptr:%\w+]] = OpTypePointer Function [[float]] +; CHECK: OpVariable [[float_ptr]] 
Function +; CHECK: OpVariable [[float_ptr]] Function +; CHECK: OpVariable [[float_ptr]] Function +; CHECK: OpVariable [[float_ptr]] Function +; CHECK: OpVariable [[float_ptr]] Function +; CHECK: OpVariable [[float_ptr]] Function +; CHECK: OpVariable [[float_ptr]] Function +; CHECK-NOT: OpVariable +; +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %fo +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 430 +OpName %main "main" +OpName %S "S" +OpMemberName %S 0 "x" +OpMemberName %S 1 "y" +OpName %ts1 "ts1" +OpName %S_0 "S" +OpMemberName %S_0 0 "x" +OpMemberName %S_0 1 "y" +OpName %U_t "U_t" +OpMemberName %U_t 0 "g_s1" +OpMemberName %U_t 1 "g_s2" +OpMemberName %U_t 2 "g_s3" +OpName %_ "" +OpName %ts2 "ts2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpName %__0 "" +OpName %ts3 "ts3" +OpName %ts4 "ts4" +OpName %fo "fo" +OpMemberDecorate %S_0 0 Offset 0 +OpMemberDecorate %S_0 1 Offset 4 +OpMemberDecorate %U_t 0 Offset 0 +OpMemberDecorate %U_t 1 Offset 8 +OpMemberDecorate %U_t 2 Offset 16 +OpDecorate %U_t BufferBlock +OpDecorate %_ DescriptorSet 0 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %__0 DescriptorSet 0 +OpDecorate %__0 Binding 0 +OpDecorate %fo Location 0 +%void = OpTypeVoid +%15 = OpTypeFunction %void +%float = OpTypeFloat 32 +%S = OpTypeStruct %float %float +%_ptr_Function_S = OpTypePointer Function %S +%S_0 = OpTypeStruct %float %float +%U_t = OpTypeStruct %S_0 %S_0 %S_0 +%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t +%_ = OpVariable %_ptr_Uniform_U_t Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_S_0 = OpTypePointer Uniform %S_0 +%_ptr_Function_float = OpTypePointer Function %float +%int_1 = OpConstant %int 1 +%uint = OpTypeInt 32 0 +%_Globals_ = OpTypeStruct %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%__0 = OpVariable %_ptr_Uniform__Globals_ Uniform +%_ptr_Uniform_uint = 
OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%_ptr_Output_float = OpTypePointer Output %float +%fo = OpVariable %_ptr_Output_float Output +%main = OpFunction %void None %15 +%30 = OpLabel +%ts1 = OpVariable %_ptr_Function_S Function +%ts2 = OpVariable %_ptr_Function_S Function +%ts3 = OpVariable %_ptr_Function_S Function +%ts4 = OpVariable %_ptr_Function_S Function +%31 = OpAccessChain %_ptr_Uniform_S_0 %_ %int_0 +%32 = OpLoad %S_0 %31 +%33 = OpCompositeExtract %float %32 0 +%34 = OpAccessChain %_ptr_Function_float %ts1 %int_0 +OpStore %34 %33 +%35 = OpCompositeExtract %float %32 1 +%36 = OpAccessChain %_ptr_Function_float %ts1 %int_1 +OpStore %36 %35 +%37 = OpAccessChain %_ptr_Uniform_S_0 %_ %int_1 +%38 = OpLoad %S_0 %37 +%39 = OpCompositeExtract %float %38 0 +%40 = OpAccessChain %_ptr_Function_float %ts2 %int_0 +OpStore %40 %39 +%41 = OpCompositeExtract %float %38 1 +%42 = OpAccessChain %_ptr_Function_float %ts2 %int_1 +OpStore %42 %41 +%43 = OpAccessChain %_ptr_Uniform_uint %__0 %int_0 +%44 = OpLoad %uint %43 +%45 = OpINotEqual %bool %44 %uint_0 +OpSelectionMerge %46 None +OpBranchConditional %45 %47 %48 +%47 = OpLabel +%49 = OpLoad %S %ts1 +OpStore %ts3 %49 +OpBranch %46 +%48 = OpLabel +%50 = OpLoad %S %ts2 +OpStore %ts3 %50 +OpBranch %46 +%46 = OpLabel +%51 = OpLoad %S %ts3 +OpStore %ts4 %51 +%52 = OpAccessChain %_ptr_Function_float %ts4 %int_1 +%53 = OpLoad %float %52 +OpStore %fo %53 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ReplaceWholeLoadAndStore) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[null:%\w+]] = OpConstantNull [[uint]] +; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; 
CHECK-NOT: OpVariable +; CHECK: [[l0:%\w+]] = OpLoad [[uint]] [[var0]] +; CHECK: [[c0:%\w+]] = OpCompositeConstruct [[struct1]] [[l0]] [[null]] +; CHECK: [[e0:%\w+]] = OpCompositeExtract [[uint]] [[c0]] 0 +; CHECK: OpStore [[var1]] [[e0]] +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]] +; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var2 = OpVariable %struct1_ptr Function +%var1 = OpVariable %struct1_ptr Function +%load1 = OpLoad %struct1 %var1 +OpStore %var2 %load1 +%load2 = OpLoad %struct1 %var2 +%3 = OpCompositeExtract %uint %load2 0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, ReplaceWholeLoadAndStore2) { + // TODO: We can improve this case by ensuring that |var2| is processed first. 
+ const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[null:%\w+]] = OpConstantNull [[uint]] +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpVariable +; CHECK: [[l0a:%\w+]] = OpLoad [[uint]] [[var0a]] +; CHECK: [[l0b:%\w+]] = OpLoad [[uint]] [[var0b]] +; CHECK: [[c0:%\w+]] = OpCompositeConstruct [[struct1]] [[l0b]] [[l0a]] +; CHECK: [[e0:%\w+]] = OpCompositeExtract [[uint]] [[c0]] 0 +; CHECK: OpStore [[var1]] [[e0]] +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]] +; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct1 = OpTypeStruct %uint %uint +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var1 = OpVariable %struct1_ptr Function +%var2 = OpVariable %struct1_ptr Function +%load1 = OpLoad %struct1 %var1 +OpStore %var2 %load1 +%load2 = OpLoad %struct1 %var2 +%3 = OpCompositeExtract %uint %load2 0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, CreateAmbiguousNullConstant1) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[struct_member:%\w+]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: 
[[null:%\w+]] = OpConstantNull [[struct_member]] +; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK-NOT: OpVariable +; CHECK: OpStore [[var1]] +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]] +; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct2 = OpTypeStruct %uint +%struct3 = OpTypeStruct %uint +%struct1 = OpTypeStruct %uint %struct2 +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var1 = OpVariable %struct1_ptr Function +%var2 = OpVariable %struct1_ptr Function +%load1 = OpLoad %struct1 %var1 +OpStore %var2 %load1 +%load2 = OpLoad %struct1 %var2 +%3 = OpCompositeExtract %uint %load2 0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, SpecConstantArray) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt +; CHECK: [[spec_const:%\w+]] = OpSpecConstant [[int]] 4 +; CHECK: [[spec_op:%\w+]] = OpSpecConstantOp [[int]] IAdd [[spec_const]] [[spec_const]] +; CHECK: [[array1:%\w+]] = OpTypeArray [[int]] [[spec_const]] +; CHECK: [[array2:%\w+]] = OpTypeArray [[int]] [[spec_op]] +; CHECK: [[ptr_array1:%\w+]] = OpTypePointer Function [[array1]] +; CHECK: [[ptr_array2:%\w+]] = OpTypePointer Function [[array2]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[ptr_array1]] Function +; CHECK-NEXT: OpVariable [[ptr_array2]] Function +; CHECK-NOT: OpVariable +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%void_fn 
= OpTypeFunction %void +%int = OpTypeInt 32 0 +%spec_const = OpSpecConstant %int 4 +%spec_op = OpSpecConstantOp %int IAdd %spec_const %spec_const +%array_1 = OpTypeArray %int %spec_const +%array_2 = OpTypeArray %int %spec_op +%ptr_array_1_Function = OpTypePointer Function %array_1 +%ptr_array_2_Function = OpTypePointer Function %array_2 +%func = OpFunction %void None %void_fn +%1 = OpLabel +%var_1 = OpVariable %ptr_array_1_Function Function +%var_2 = OpVariable %ptr_array_2_Function Function +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(ScalarReplacementTest, CreateAmbiguousNullConstant2) { + const std::string text = R"( +; +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[struct_member:%\w+]] +; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]] +; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0 +; CHECK: [[null:%\w+]] = OpConstantNull [[struct_member]] +; CHECK: [[var0a:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: [[var0b:%\w+]] = OpVariable [[uint_ptr]] Function +; CHECK: OpStore [[var1]] +; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] +; CHECK: [[c1:%\w+]] = OpCompositeConstruct [[struct1]] [[l1]] [[null]] +; CHECK: [[e1:%\w+]] = OpCompositeExtract [[uint]] [[c1]] 0 +; +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %func "replace_whole_load" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%struct3 = OpTypeStruct %uint +%struct2 = OpTypeStruct %uint +%struct1 = OpTypeStruct %uint %struct2 +%uint_ptr = OpTypePointer Function %uint +%struct1_ptr = OpTypePointer Function %struct1 +%uint_0 = OpConstant %uint 0 +%uint_1 = OpConstant %uint 1 +%func = OpTypeFunction %void +%1 = OpFunction %void None %func +%2 = OpLabel +%var1 = OpVariable %struct1_ptr Function +%var2 = OpVariable %struct1_ptr Function +%load1 = OpLoad %struct1 %var1 +OpStore %var2 %load1 +%load2 = OpLoad %struct1 %var2 
+%3 = OpCompositeExtract %uint %load2 0 +OpReturn +OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +// Test that a struct of size 4 is not replaced when there is a limit of 2. +TEST_F(ScalarReplacementTest, TestLimit) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "simple_struct" +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %2 %2 %2 %2 +%4 = OpTypePointer Function %3 +%5 = OpTypePointer Function %2 +%6 = OpTypeFunction %2 +%7 = OpConstantNull %3 +%8 = OpConstant %2 0 +%9 = OpConstant %2 1 +%10 = OpConstant %2 2 +%11 = OpConstant %2 3 +%12 = OpFunction %2 None %6 +%13 = OpLabel +%14 = OpVariable %4 Function %7 +%15 = OpInBoundsAccessChain %5 %14 %8 +%16 = OpLoad %2 %15 +%17 = OpAccessChain %5 %14 %10 +%18 = OpLoad %2 %17 +%19 = OpIAdd %2 %16 %18 +OpReturnValue %19 +OpFunctionEnd + )"; + + auto result = + SinglePassRunAndDisassemble(text, true, false, 2); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +// Test that a struct of size 4 is replaced when there is a limit of 0 (no +// limit). This is the same spir-v as a test above, so we do not check that it +// is correctly transformed. We leave that to the test above. 
+TEST_F(ScalarReplacementTest, TestUnimited) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpName %6 "simple_struct" +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeStruct %2 %2 %2 %2 +%4 = OpTypePointer Function %3 +%5 = OpTypePointer Function %2 +%6 = OpTypeFunction %2 +%7 = OpConstantNull %3 +%8 = OpConstant %2 0 +%9 = OpConstant %2 1 +%10 = OpConstant %2 2 +%11 = OpConstant %2 3 +%12 = OpFunction %2 None %6 +%13 = OpLabel +%14 = OpVariable %4 Function %7 +%15 = OpInBoundsAccessChain %5 %14 %8 +%16 = OpLoad %2 %15 +%17 = OpAccessChain %5 %14 %10 +%18 = OpLoad %2 %17 +%19 = OpIAdd %2 %16 %18 +OpReturnValue %19 +OpFunctionEnd + )"; + + auto result = + SinglePassRunAndDisassemble(text, true, false, 0); + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); +} + +TEST_F(ScalarReplacementTest, AmbigousPointer) { + const std::string text = R"( +; CHECK: [[s1:%\w+]] = OpTypeStruct %uint +; CHECK: [[s2:%\w+]] = OpTypeStruct %uint +; CHECK: [[s3:%\w+]] = OpTypeStruct [[s2]] +; CHECK: [[s3_const:%\w+]] = OpConstantComposite [[s3]] +; CHECK: [[s2_ptr:%\w+]] = OpTypePointer Function [[s2]] +; CHECK: OpCompositeExtract [[s2]] [[s3_const]] + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %_struct_7 = OpTypeStruct %uint + %_struct_8 = OpTypeStruct %uint + %_struct_9 = OpTypeStruct %_struct_8 + %uint_1 = OpConstant %uint 1 + %11 = OpConstantComposite %_struct_8 %uint_1 + %12 = OpConstantComposite %_struct_9 %11 +%_ptr_Function__struct_9 = OpTypePointer Function %_struct_9 +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %2 = OpFunction %void None %5 + %15 = OpLabel + %var = OpVariable %_ptr_Function__struct_9 Function + OpStore %var %12 + %ld = OpLoad %_struct_9 %var + %ex = 
OpCompositeExtract %_struct_8 %ld 0 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, true); +} + +// Test that scalar replacement does not crash when there is an OpAccessChain +// with no indexes; the pass currently reports SuccessWithoutChange. If we +// choose to handle this case in the future, then the result can change. +TEST_F(ScalarReplacementTest, TestAccessChainWithNoIndexes) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginLowerLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %_struct_5 = OpTypeStruct %float +%_ptr_Function__struct_5 = OpTypePointer Function %_struct_5 + %1 = OpFunction %void None %3 + %7 = OpLabel + %8 = OpVariable %_ptr_Function__struct_5 Function + %9 = OpAccessChain %_ptr_Function__struct_5 %8 + OpReturn + OpFunctionEnd + )"; + + auto result = + SinglePassRunAndDisassemble(text, true, false); + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/set_spec_const_default_value_test.cpp b/test/opt/set_spec_const_default_value_test.cpp new file mode 100644 index 000000000..161674fe0 --- /dev/null +++ b/test/opt/set_spec_const_default_value_test.cpp @@ -0,0 +1,1077 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using testing::Eq; +using SpecIdToValueStrMap = + SetSpecConstantDefaultValuePass::SpecIdToValueStrMap; +using SpecIdToValueBitPatternMap = + SetSpecConstantDefaultValuePass::SpecIdToValueBitPatternMap; + +struct DefaultValuesStringParsingTestCase { + const char* default_values_str; + bool expect_success; + SpecIdToValueStrMap expected_map; +}; + +using DefaultValuesStringParsingTest = + ::testing::TestWithParam; + +TEST_P(DefaultValuesStringParsingTest, TestCase) { + const auto& tc = GetParam(); + auto actual_map = SetSpecConstantDefaultValuePass::ParseDefaultValuesString( + tc.default_values_str); + if (tc.expect_success) { + EXPECT_NE(nullptr, actual_map); + if (actual_map) { + EXPECT_THAT(*actual_map, Eq(tc.expected_map)); + } + } else { + EXPECT_EQ(nullptr, actual_map); + } +} + +INSTANTIATE_TEST_CASE_P( + ValidString, DefaultValuesStringParsingTest, + ::testing::ValuesIn(std::vector{ + // 0. empty map + {"", true, SpecIdToValueStrMap{}}, + // 1. one pair + {"100:1024", true, SpecIdToValueStrMap{{100, "1024"}}}, + // 2. two pairs + {"100:1024 200:2048", true, + SpecIdToValueStrMap{{100, "1024"}, {200, "2048"}}}, + // 3. spaces between entries + {"100:1024 \n \r \t \v \f 200:2048", true, + SpecIdToValueStrMap{{100, "1024"}, {200, "2048"}}}, + // 4. \t, \n, \r and spaces before spec id + {" \n \r\t \t \v \f 100:1024", true, + SpecIdToValueStrMap{{100, "1024"}}}, + // 5. \t, \n, \r and spaces after value string + {"100:1024 \n \r\t \t \v \f ", true, + SpecIdToValueStrMap{{100, "1024"}}}, + // 6. maximum spec id + {"4294967295:0", true, SpecIdToValueStrMap{{4294967295, "0"}}}, + // 7. minimum spec id + {"0:100", true, SpecIdToValueStrMap{{0, "100"}}}, + // 8. random content without spaces are allowed + {"200:random_stuff", true, SpecIdToValueStrMap{{200, "random_stuff"}}}, + // 9. 
support hex format spec id (just because we use the + // ParseNumber() utility) + {"0x100:1024", true, SpecIdToValueStrMap{{256, "1024"}}}, + // 10. multiple entries + {"101:1 102:2 103:3 104:4 200:201 9999:1000 0x100:333", true, + SpecIdToValueStrMap{{101, "1"}, + {102, "2"}, + {103, "3"}, + {104, "4"}, + {200, "201"}, + {9999, "1000"}, + {256, "333"}}}, + // 11. default value in hex float format + {"100:0x0.3p10", true, SpecIdToValueStrMap{{100, "0x0.3p10"}}}, + // 12. default value in decimal float format + {"100:1.5e-13", true, SpecIdToValueStrMap{{100, "1.5e-13"}}}, + })); + +INSTANTIATE_TEST_CASE_P( + InvalidString, DefaultValuesStringParsingTest, + ::testing::ValuesIn(std::vector{ + // 0. missing default value + {"100:", false, SpecIdToValueStrMap{}}, + // 1. spec id is not an integer + {"100.0:200", false, SpecIdToValueStrMap{}}, + // 2. spec id is not a number + {"something_not_a_number:1", false, SpecIdToValueStrMap{}}, + // 3. only spec id number + {"100", false, SpecIdToValueStrMap{}}, + // 4. same spec id defined multiple times + {"100:20 100:21", false, SpecIdToValueStrMap{}}, + // 5. Multiple definition of an identical spec id in different forms + // is not allowed + {"0x100:100 256:200", false, SpecIdToValueStrMap{}}, + // 6. empty spec id + {":3", false, SpecIdToValueStrMap{}}, + // 7. only colon + {":", false, SpecIdToValueStrMap{}}, + // 8. spec id overflow + {"4294967296:200", false, SpecIdToValueStrMap{}}, + // 9. spec id less than 0 + {"-1:200", false, SpecIdToValueStrMap{}}, + // 10. nullptr + {nullptr, false, SpecIdToValueStrMap{}}, + // 11. only a number is invalid + {"1234", false, SpecIdToValueStrMap{}}, + // 12. invalid entry separator + {"12:34;23:14", false, SpecIdToValueStrMap{}}, + // 13. invalid spec id and default value separator + {"12@34", false, SpecIdToValueStrMap{}}, + // 14. spaces before colon + {"100 :1024", false, SpecIdToValueStrMap{}}, + // 15. spaces after colon + {"100: 1024", false, SpecIdToValueStrMap{}}, + // 16. 
spec id represented in hex float format is invalid + {"0x3p10:200", false, SpecIdToValueStrMap{}}, + })); + +struct SetSpecConstantDefaultValueInStringFormTestCase { + const char* code; + SpecIdToValueStrMap default_values; + const char* expected; +}; + +using SetSpecConstantDefaultValueInStringFormParamTest = PassTest< + ::testing::TestWithParam>; + +TEST_P(SetSpecConstantDefaultValueInStringFormParamTest, TestCase) { + const auto& tc = GetParam(); + SinglePassRunAndCheck( + tc.code, tc.expected, /* skip_nop = */ false, tc.default_values); +} + +INSTANTIATE_TEST_CASE_P( + ValidCases, SetSpecConstantDefaultValueInStringFormParamTest, + ::testing::ValuesIn(std::vector< + SetSpecConstantDefaultValueInStringFormTestCase>{ + // 0. Empty. + {"", SpecIdToValueStrMap{}, ""}, + // 1. Empty with non-empty values to set. + {"", SpecIdToValueStrMap{{1, "100"}, {2, "200"}}, ""}, + // 2. Bool type. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n", + // default values + SpecIdToValueStrMap{{100, "false"}, {101, "true"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantFalse %bool\n" + "%2 = OpSpecConstantTrue %bool\n", + }, + // 3. 32-bit int type. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 10\n" + "%2 = OpSpecConstant %int 11\n" + "%3 = OpSpecConstant %int 11\n", + // default values + SpecIdToValueStrMap{ + {100, "2147483647"}, {101, "0xffffffff"}, {102, "-42"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 2147483647\n" + "%2 = OpSpecConstant %int -1\n" + "%3 = OpSpecConstant %int -42\n", + }, + // 4. 64-bit uint type. 
+ { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%ulong = OpTypeInt 64 0\n" + "%1 = OpSpecConstant %ulong 10\n" + "%2 = OpSpecConstant %ulong 11\n", + // default values + SpecIdToValueStrMap{{100, "18446744073709551614"}, {101, "0x100"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%ulong = OpTypeInt 64 0\n" + "%1 = OpSpecConstant %ulong 18446744073709551614\n" + "%2 = OpSpecConstant %ulong 256\n", + }, + // 5. 32-bit float type. + { + // code + "OpDecorate %1 SpecId 101\n" + "OpDecorate %2 SpecId 102\n" + "%float = OpTypeFloat 32\n" + "%1 = OpSpecConstant %float 200\n" + "%2 = OpSpecConstant %float 201\n", + // default values + SpecIdToValueStrMap{{101, "-0x1.fffffep+128"}, {102, "2.5"}}, + // expected + "OpDecorate %1 SpecId 101\n" + "OpDecorate %2 SpecId 102\n" + "%float = OpTypeFloat 32\n" + "%1 = OpSpecConstant %float -0x1.fffffep+128\n" + "%2 = OpSpecConstant %float 2.5\n", + }, + // 6. 64-bit float type. + { + // code + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 3.14159265358979\n" + "%2 = OpSpecConstant %double 0.14285\n", + // default values + SpecIdToValueStrMap{{201, "0x1.fffffffffffffp+1024"}, + {202, "-32.5"}}, + // expected + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 0x1.fffffffffffffp+1024\n" + "%2 = OpSpecConstant %double -32.5\n", + }, + // 7. SpecId not found, expect no modification. + { + // code + "OpDecorate %1 SpecId 201\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 3.14159265358979\n", + // default values + SpecIdToValueStrMap{{8888, "0.0"}}, + // expected + "OpDecorate %1 SpecId 201\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 3.14159265358979\n", + }, + // 8. Multiple types of spec constants. 
+ { + // code + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "OpDecorate %3 SpecId 203\n" + "%bool = OpTypeBool\n" + "%int = OpTypeInt 32 1\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 3.14159265358979\n" + "%2 = OpSpecConstant %int 1024\n" + "%3 = OpSpecConstantTrue %bool\n", + // default values + SpecIdToValueStrMap{ + {201, "0x1.fffffffffffffp+1024"}, + {202, "2048"}, + {203, "false"}, + }, + // expected + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "OpDecorate %3 SpecId 203\n" + "%bool = OpTypeBool\n" + "%int = OpTypeInt 32 1\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 0x1.fffffffffffffp+1024\n" + "%2 = OpSpecConstant %int 2048\n" + "%3 = OpSpecConstantFalse %bool\n", + }, + // 9. Ignore other decorations. + { + // code + "OpDecorate %1 ArrayStride 4\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueStrMap{{4, "0x7fffffff"}}, + // expected + "OpDecorate %1 ArrayStride 4\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 100\n", + }, + // 10. Distinguish from other decorations. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %1 ArrayStride 4\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueStrMap{{4, "0x7fffffff"}, {100, "0xffffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %1 ArrayStride 4\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int -1\n", + }, + // 11. Decorate through decoration group. + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueStrMap{{100, "0x7fffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 2147483647\n", + }, + // 12. 
Ignore other decorations in decoration group. + { + // code + "OpDecorate %1 ArrayStride 4\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueStrMap{{4, "0x7fffffff"}}, + // expected + "OpDecorate %1 ArrayStride 4\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + }, + // 13. Distinguish from other decorations in decoration group. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %1 ArrayStride 4\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueStrMap{{100, "0x7fffffff"}, {4, "0x00000001"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %1 ArrayStride 4\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 2147483647\n", + }, + // 14. Unchanged bool default value + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n", + // default values + SpecIdToValueStrMap{{100, "true"}, {101, "false"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n", + }, + // 15. 
Unchanged int default values + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%int = OpTypeInt 32 1\n" + "%ulong = OpTypeInt 64 0\n" + "%1 = OpSpecConstant %int 10\n" + "%2 = OpSpecConstant %ulong 11\n", + // default values + SpecIdToValueStrMap{{100, "10"}, {101, "11"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%int = OpTypeInt 32 1\n" + "%ulong = OpTypeInt 64 0\n" + "%1 = OpSpecConstant %int 10\n" + "%2 = OpSpecConstant %ulong 11\n", + }, + // 16. Unchanged float default values + { + // code + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "%float = OpTypeFloat 32\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %float 3.1415\n" + "%2 = OpSpecConstant %double 0.14285\n", + // default values + SpecIdToValueStrMap{{201, "3.1415"}, {202, "0.14285"}}, + // expected + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "%float = OpTypeFloat 32\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %float 3.1415\n" + "%2 = OpSpecConstant %double 0.14285\n", + }, + // 17. OpGroupDecorate may have multiple target ids defined by the same + // eligible spec constant + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2 %2 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueStrMap{{100, "0xffffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2 %2 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int -1\n", + }, + })); + +INSTANTIATE_TEST_CASE_P( + InvalidCases, SetSpecConstantDefaultValueInStringFormParamTest, + ::testing::ValuesIn(std::vector< + SetSpecConstantDefaultValueInStringFormTestCase>{ + // 0. Do not crash when decoration group is not used. 
+ { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n" + "%3 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueStrMap{{100, "0x7fffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n" + "%3 = OpSpecConstant %int 100\n", + }, + // 1. Do not crash when target does not exist. + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n", + // default values + SpecIdToValueStrMap{{100, "0x7fffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n", + }, + // 2. Do nothing when SpecId decoration is not attached to a + // non-spec-constant instruction. + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n" + "%int_101 = OpConstant %int 101\n", + // default values + SpecIdToValueStrMap{{100, "0x7fffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n" + "%int_101 = OpConstant %int 101\n", + }, + // 3. Do nothing when SpecId decoration is not attached to a + // OpSpecConstant{|True|False} instruction. + { + // code + "OpDecorate %1 SpecId 100\n" + "%int = OpTypeInt 32 1\n" + "%3 = OpSpecConstant %int 101\n" + "%1 = OpSpecConstantOp %int IAdd %3 %3\n", + // default values + SpecIdToValueStrMap{{100, "0x7fffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%int = OpTypeInt 32 1\n" + "%3 = OpSpecConstant %int 101\n" + "%1 = OpSpecConstantOp %int IAdd %3 %3\n", + }, + // 4. Do not crash and do nothing when SpecId decoration is applied to + // multiple spec constants. 
+ { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2 %3 %4\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n" + "%3 = OpSpecConstant %int 200\n" + "%4 = OpSpecConstant %int 300\n", + // default values + SpecIdToValueStrMap{{100, "0xffffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2 %3 %4\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n" + "%3 = OpSpecConstant %int 200\n" + "%4 = OpSpecConstant %int 300\n", + }, + // 5. Do not crash and do nothing when SpecId decoration is attached to + // non-spec-constants (invalid case). + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%2 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%int_100 = OpConstant %int 100\n", + // default values + SpecIdToValueStrMap{{100, "0xffffffff"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%2 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%int_100 = OpConstant %int 100\n", + }, + // 6. Boolean type spec constant cannot be set with numeric values in + // string form. i.e. only 'true' and 'false' are acceptable for setting + // boolean type spec constants. Nothing should be done if numeric values + // in string form are provided. 
+ { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "OpDecorate %4 SpecId 103\n" + "OpDecorate %5 SpecId 104\n" + "OpDecorate %6 SpecId 105\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n" + "%3 = OpSpecConstantTrue %bool\n" + "%4 = OpSpecConstantTrue %bool\n" + "%5 = OpSpecConstantTrue %bool\n" + "%6 = OpSpecConstantFalse %bool\n", + // default values + SpecIdToValueStrMap{{100, "0"}, + {101, "1"}, + {102, "0x0"}, + {103, "0.0"}, + {104, "-0.0"}, + {105, "0x12345678"}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "OpDecorate %4 SpecId 103\n" + "OpDecorate %5 SpecId 104\n" + "OpDecorate %6 SpecId 105\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n" + "%3 = OpSpecConstantTrue %bool\n" + "%4 = OpSpecConstantTrue %bool\n" + "%5 = OpSpecConstantTrue %bool\n" + "%6 = OpSpecConstantFalse %bool\n", + }, + })); + +struct SetSpecConstantDefaultValueInBitPatternFormTestCase { + const char* code; + SpecIdToValueBitPatternMap default_values; + const char* expected; +}; + +using SetSpecConstantDefaultValueInBitPatternFormParamTest = + PassTest<::testing::TestWithParam< + SetSpecConstantDefaultValueInBitPatternFormTestCase>>; + +TEST_P(SetSpecConstantDefaultValueInBitPatternFormParamTest, TestCase) { + const auto& tc = GetParam(); + SinglePassRunAndCheck( + tc.code, tc.expected, /* skip_nop = */ false, tc.default_values); +} + +INSTANTIATE_TEST_CASE_P( + ValidCases, SetSpecConstantDefaultValueInBitPatternFormParamTest, + ::testing::ValuesIn(std::vector< + SetSpecConstantDefaultValueInBitPatternFormTestCase>{ + // 0. Empty. + {"", SpecIdToValueBitPatternMap{}, ""}, + // 1. Empty with non-empty values to set. + {"", SpecIdToValueBitPatternMap{{1, {100}}, {2, {200}}}, ""}, + // 2. Basic bool type. 
+ { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n", + // default values + SpecIdToValueBitPatternMap{{100, {0x0}}, {101, {0x1}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantFalse %bool\n" + "%2 = OpSpecConstantTrue %bool\n", + }, + // 3. 32-bit int type. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 10\n" + "%2 = OpSpecConstant %int 11\n" + "%3 = OpSpecConstant %int 11\n", + // default values + SpecIdToValueBitPatternMap{ + {100, {2147483647}}, {101, {0xffffffff}}, {102, {0xffffffd6}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 2147483647\n" + "%2 = OpSpecConstant %int -1\n" + "%3 = OpSpecConstant %int -42\n", + }, + // 4. 64-bit uint type. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%ulong = OpTypeInt 64 0\n" + "%1 = OpSpecConstant %ulong 10\n" + "%2 = OpSpecConstant %ulong 11\n", + // default values + SpecIdToValueBitPatternMap{{100, {0xFFFFFFFE, 0xFFFFFFFF}}, + {101, {0x100, 0x0}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%ulong = OpTypeInt 64 0\n" + "%1 = OpSpecConstant %ulong 18446744073709551614\n" + "%2 = OpSpecConstant %ulong 256\n", + }, + // 5. 32-bit float type. 
+ { + // code + "OpDecorate %1 SpecId 101\n" + "OpDecorate %2 SpecId 102\n" + "%float = OpTypeFloat 32\n" + "%1 = OpSpecConstant %float 200\n" + "%2 = OpSpecConstant %float 201\n", + // default values + SpecIdToValueBitPatternMap{{101, {0xffffffff}}, + {102, {0x40200000}}}, + // expected + "OpDecorate %1 SpecId 101\n" + "OpDecorate %2 SpecId 102\n" + "%float = OpTypeFloat 32\n" + "%1 = OpSpecConstant %float -0x1.fffffep+128\n" + "%2 = OpSpecConstant %float 2.5\n", + }, + // 6. 64-bit float type. + { + // code + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 3.14159265358979\n" + "%2 = OpSpecConstant %double 0.14285\n", + // default values + SpecIdToValueBitPatternMap{{201, {0xffffffff, 0x7fffffff}}, + {202, {0x00000000, 0xc0404000}}}, + // expected + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 0x1.fffffffffffffp+1024\n" + "%2 = OpSpecConstant %double -32.5\n", + }, + // 7. SpecId not found, expect no modification. + { + // code + "OpDecorate %1 SpecId 201\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 3.14159265358979\n", + // default values + SpecIdToValueBitPatternMap{{8888, {0x0}}}, + // expected + "OpDecorate %1 SpecId 201\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 3.14159265358979\n", + }, + // 8. Multiple types of spec constants. 
+ { + // code + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "OpDecorate %3 SpecId 203\n" + "%bool = OpTypeBool\n" + "%int = OpTypeInt 32 1\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 3.14159265358979\n" + "%2 = OpSpecConstant %int 1024\n" + "%3 = OpSpecConstantTrue %bool\n", + // default values + SpecIdToValueBitPatternMap{ + {201, {0xffffffff, 0x7fffffff}}, + {202, {0x00000800}}, + {203, {0x0}}, + }, + // expected + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "OpDecorate %3 SpecId 203\n" + "%bool = OpTypeBool\n" + "%int = OpTypeInt 32 1\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %double 0x1.fffffffffffffp+1024\n" + "%2 = OpSpecConstant %int 2048\n" + "%3 = OpSpecConstantFalse %bool\n", + }, + // 9. Ignore other decorations. + { + // code + "OpDecorate %1 ArrayStride 4\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueBitPatternMap{{4, {0x7fffffff}}}, + // expected + "OpDecorate %1 ArrayStride 4\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 100\n", + }, + // 10. Distinguish from other decorations. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %1 ArrayStride 4\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueBitPatternMap{{4, {0x7fffffff}}, {100, {0xffffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %1 ArrayStride 4\n" + "%int = OpTypeInt 32 1\n" + "%1 = OpSpecConstant %int -1\n", + }, + // 11. Decorate through decoration group. + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueBitPatternMap{{100, {0x7fffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 2147483647\n", + }, + // 12. 
Ignore other decorations in decoration group. + { + // code + "OpDecorate %1 ArrayStride 4\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueBitPatternMap{{4, {0x7fffffff}}}, + // expected + "OpDecorate %1 ArrayStride 4\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + }, + // 13. Distinguish from other decorations in decoration group. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %1 ArrayStride 4\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueBitPatternMap{{100, {0x7fffffff}}, {4, {0x00000001}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %1 ArrayStride 4\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 2147483647\n", + }, + // 14. Unchanged bool default value + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n", + // default values + SpecIdToValueBitPatternMap{{100, {0x1}}, {101, {0x0}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n", + }, + // 15. 
Unchanged int default values + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%int = OpTypeInt 32 1\n" + "%ulong = OpTypeInt 64 0\n" + "%1 = OpSpecConstant %int 10\n" + "%2 = OpSpecConstant %ulong 11\n", + // default values + SpecIdToValueBitPatternMap{{100, {10}}, {101, {11, 0}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "%int = OpTypeInt 32 1\n" + "%ulong = OpTypeInt 64 0\n" + "%1 = OpSpecConstant %int 10\n" + "%2 = OpSpecConstant %ulong 11\n", + }, + // 16. Unchanged float default values + { + // code + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "%float = OpTypeFloat 32\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %float 3.25\n" + "%2 = OpSpecConstant %double 1.25\n", + // default values + SpecIdToValueBitPatternMap{{201, {0x40500000}}, + {202, {0x00000000, 0x3ff40000}}}, + // expected + "OpDecorate %1 SpecId 201\n" + "OpDecorate %2 SpecId 202\n" + "%float = OpTypeFloat 32\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %float 3.25\n" + "%2 = OpSpecConstant %double 1.25\n", + }, + // 17. OpGroupDecorate may have multiple target ids defined by the same + // eligible spec constant + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2 %2 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueBitPatternMap{{100, {0xffffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2 %2 %2\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int -1\n", + }, + // 18. For Boolean type spec constants, if any word in the bit pattern + // is not zero, it can be considered as a 'true', otherwise, it can be + // considered as a 'false'. 
+ { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantTrue %bool\n" + "%2 = OpSpecConstantFalse %bool\n" + "%3 = OpSpecConstantFalse %bool\n", + // default values + SpecIdToValueBitPatternMap{ + {100, {0x0, 0x0, 0x0, 0x0}}, + {101, {0x10101010}}, + {102, {0x0, 0x0, 0x0, 0x2}}, + }, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%bool = OpTypeBool\n" + "%1 = OpSpecConstantFalse %bool\n" + "%2 = OpSpecConstantTrue %bool\n" + "%3 = OpSpecConstantTrue %bool\n", + }, + })); + +INSTANTIATE_TEST_CASE_P( + InvalidCases, SetSpecConstantDefaultValueInBitPatternFormParamTest, + ::testing::ValuesIn(std::vector< + SetSpecConstantDefaultValueInBitPatternFormTestCase>{ + // 0. Do not crash when decoration group is not used. + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n" + "%3 = OpSpecConstant %int 100\n", + // default values + SpecIdToValueBitPatternMap{{100, {0x7fffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n" + "%3 = OpSpecConstant %int 100\n", + }, + // 1. Do not crash when target does not exist. + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n", + // default values + SpecIdToValueBitPatternMap{{100, {0x7fffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n", + }, + // 2. Do nothing when SpecId decoration is not attached to a + // non-spec-constant instruction. 
+ { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n" + "%int_101 = OpConstant %int 101\n", + // default values + SpecIdToValueBitPatternMap{{100, {0x7fffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%int = OpTypeInt 32 1\n" + "%int_101 = OpConstant %int 101\n", + }, + // 3. Do nothing when SpecId decoration is not attached to a + // OpSpecConstant{|True|False} instruction. + { + // code + "OpDecorate %1 SpecId 100\n" + "%int = OpTypeInt 32 1\n" + "%3 = OpSpecConstant %int 101\n" + "%1 = OpSpecConstantOp %int IAdd %3 %3\n", + // default values + SpecIdToValueBitPatternMap{{100, {0x7fffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%int = OpTypeInt 32 1\n" + "%3 = OpSpecConstant %int 101\n" + "%1 = OpSpecConstantOp %int IAdd %3 %3\n", + }, + // 4. Do not crash and do nothing when SpecId decoration is applied to + // multiple spec constants. + { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2 %3 %4\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n" + "%3 = OpSpecConstant %int 200\n" + "%4 = OpSpecConstant %int 300\n", + // default values + SpecIdToValueBitPatternMap{{100, {0xffffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2 %3 %4\n" + "%int = OpTypeInt 32 1\n" + "%2 = OpSpecConstant %int 100\n" + "%3 = OpSpecConstant %int 200\n" + "%4 = OpSpecConstant %int 300\n", + }, + // 5. Do not crash and do nothing when SpecId decoration is attached to + // non-spec-constants (invalid case). 
+ { + // code + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%2 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%int_100 = OpConstant %int 100\n", + // default values + SpecIdToValueBitPatternMap{{100, {0xffffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "%1 = OpDecorationGroup\n" + "%2 = OpDecorationGroup\n" + "OpGroupDecorate %1 %2\n" + "%int = OpTypeInt 32 1\n" + "%int_100 = OpConstant %int 100\n", + }, + // 6. Incompatible input bit pattern with the type. Nothing should be + // done in such a case. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%int = OpTypeInt 32 1\n" + "%ulong = OpTypeInt 64 0\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %int 100\n" + "%2 = OpSpecConstant %ulong 200\n" + "%3 = OpSpecConstant %double 3.141592653\n", + // default values + SpecIdToValueBitPatternMap{ + {100, {10, 0}}, {101, {11}}, {102, {0xffffffff}}}, + // expected + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%int = OpTypeInt 32 1\n" + "%ulong = OpTypeInt 64 0\n" + "%double = OpTypeFloat 64\n" + "%1 = OpSpecConstant %int 100\n" + "%2 = OpSpecConstant %ulong 200\n" + "%3 = OpSpecConstant %double 3.141592653\n", + }, + })); + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/simplification_test.cpp b/test/opt/simplification_test.cpp new file mode 100644 index 000000000..b7d6f18c6 --- /dev/null +++ b/test/opt/simplification_test.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "source/opt/simplification_pass.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using SimplificationTest = PassTest<::testing::Test>; + +TEST_F(SimplificationTest, StraightLineTest) { + // Testing that folding rules are combined in simple straight line code. + const std::string text = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %i %o + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %i "i" + OpName %o "o" + OpDecorate %i Flat + OpDecorate %i Location 0 + OpDecorate %o Location 0 + %void = OpTypeVoid + %8 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %v4int = OpTypeVector %int 4 + %int_0 = OpConstant %int 0 + %13 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 + %int_1 = OpConstant %int 1 +%_ptr_Input_v4int = OpTypePointer Input %v4int + %i = OpVariable %_ptr_Input_v4int Input +%_ptr_Output_int = OpTypePointer Output %int + %o = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %8 + %21 = OpLabel + %31 = OpCompositeInsert %v4int %int_1 %13 0 +; CHECK: [[load:%[a-zA-Z_\d]+]] = OpLoad + %23 = OpLoad %v4int %i + %33 = OpCompositeInsert %v4int %int_0 %23 0 + %35 = OpCompositeExtract %int %31 0 +; CHECK: [[extract:%[a-zA-Z_\d]+]] = OpCompositeExtract %int 
[[load]] 1 + %37 = OpCompositeExtract %int %33 1 +; CHECK: [[add:%[a-zA-Z_\d]+]] = OpIAdd %int %int_1 [[extract]] + %29 = OpIAdd %int %35 %37 + OpStore %o %29 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(SimplificationTest, AcrossBasicBlocks) { + // Testing that folding rules are combined across basic blocks. + const std::string text = R"(OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %i %o + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %i "i" + OpName %o "o" + OpDecorate %i Flat + OpDecorate %i Location 0 + OpDecorate %o Location 0 + %void = OpTypeVoid + %8 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %v4int = OpTypeVector %int 4 + %int_0 = OpConstant %int 0 +%_ptr_Input_v4int = OpTypePointer Input %v4int + %i = OpVariable %_ptr_Input_v4int Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_int = OpTypePointer Input %int + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %int_1 = OpConstant %int 1 +%_ptr_Output_int = OpTypePointer Output %int + %o = OpVariable %_ptr_Output_int Output + %main = OpFunction %void None %8 + %24 = OpLabel +; CHECK: [[load:%[a-zA-Z_\d]+]] = OpLoad %v4int %i + %25 = OpLoad %v4int %i + %41 = OpCompositeInsert %v4int %int_0 %25 0 + %27 = OpAccessChain %_ptr_Input_int %i %uint_0 + %28 = OpLoad %int %27 + %29 = OpSGreaterThan %bool %28 %int_10 + OpSelectionMerge %30 None + OpBranchConditional %29 %31 %32 + %31 = OpLabel + %43 = OpCopyObject %v4int %25 + OpBranch %30 + %32 = OpLabel + %45 = OpCopyObject %v4int %25 + OpBranch %30 + %30 = OpLabel + %50 = OpPhi %v4int %43 %31 %45 %32 +; CHECK: [[extract1:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 0 + %47 = OpCompositeExtract %int %50 0 +; CHECK: [[extract2:%[a-zA-Z_\d]+]] = 
OpCompositeExtract %int [[load]] 1 + %49 = OpCompositeExtract %int %41 1 +; CHECK: [[add:%[a-zA-Z_\d]+]] = OpIAdd %int [[extract1]] [[extract2]] + %39 = OpIAdd %int %47 %49 + OpStore %o %39 + OpReturn + OpFunctionEnd + +)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(SimplificationTest, ThroughLoops) { + // Testing that folding rules are applied multiple times to instructions + // to be able to propagate across loop iterations. + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %o %i + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 430 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %o "o" + OpName %i "i" + OpDecorate %o Location 0 + OpDecorate %i Flat + OpDecorate %i Location 0 + %void = OpTypeVoid + %8 = OpTypeFunction %void + %int = OpTypeInt 32 1 + %v4int = OpTypeVector %int 4 + %int_0 = OpConstant %int 0 +; CHECK: [[constant:%[a-zA-Z_\d]+]] = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 + %13 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0 + %bool = OpTypeBool +%_ptr_Output_int = OpTypePointer Output %int + %o = OpVariable %_ptr_Output_int Output +%_ptr_Input_v4int = OpTypePointer Input %v4int + %i = OpVariable %_ptr_Input_v4int Input + %68 = OpUndef %v4int + %main = OpFunction %void None %8 + %23 = OpLabel +; CHECK: [[load:%[a-zA-Z_\d]+]] = OpLoad %v4int %i + %load = OpLoad %v4int %i + OpBranch %24 + %24 = OpLabel + %67 = OpPhi %v4int %load %23 %64 %26 +; CHECK: OpLoopMerge [[merge_lab:%[a-zA-Z_\d]+]] + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %48 = OpCompositeExtract %int %67 0 + %30 = OpIEqual %bool %48 %int_0 + OpBranchConditional %30 %31 %25 + %31 = OpLabel + %50 = OpCompositeExtract %int %67 0 + %54 = OpCompositeExtract %int %67 1 + %58 = OpCompositeExtract %int %67 2 + %62 = OpCompositeExtract %int %67 3 + 
%64 = OpCompositeConstruct %v4int %50 %54 %58 %62 + OpBranch %26 + %26 = OpLabel + OpBranch %24 + %25 = OpLabel +; CHECK: [[merge_lab]] = OpLabel +; CHECK: [[extract:%[a-zA-Z_\d]+]] = OpCompositeExtract %int [[load]] 0 + %66 = OpCompositeExtract %int %67 0 +; CHECK-NEXT: OpStore %o [[extract]] + OpStore %o %66 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, false); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/strength_reduction_test.cpp b/test/opt/strength_reduction_test.cpp new file mode 100644 index 000000000..31d050360 --- /dev/null +++ b/test/opt/strength_reduction_test.cpp @@ -0,0 +1,438 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; +using StrengthReductionBasicTest = PassTest<::testing::Test>; + +// Test to make sure we replace 5*8. 
+TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %main \"main\"", + "%void = OpTypeVoid", + "%4 = OpTypeFunction %void", + "%uint = OpTypeInt 32 0", + "%uint_5 = OpConstant %uint 5", + "%uint_8 = OpConstant %uint 8", + "%main = OpFunction %void None %4", + "%8 = OpLabel", + "%9 = OpIMul %uint %uint_5 %uint_8", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + auto result = SinglePassRunAndDisassemble( + JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); + const std::string& output = std::get<0>(result); + EXPECT_THAT(output, Not(HasSubstr("OpIMul"))); + EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3")); +} + +// TODO(dneto): Add Effcee as required dependency, and make this unconditional. +// Test to make sure we replace 16*5 +// Also demonstrate use of Effcee matching. +TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpName %main "main" + %void = OpTypeVoid + %4 = OpTypeFunction %void +; We know disassembly will produce %uint here, but +; CHECK: %uint = OpTypeInt 32 0 +; CHECK-DAG: [[five:%[a-zA-Z_\d]+]] = OpConstant %uint 5 + +; We have RE2 regular expressions, so \w matches [_a-zA-Z0-9]. +; This shows the preferred pattern for matching SPIR-V identifiers. +; (We could have cheated in this case since we know the disassembler will +; generate the 'nice' name of "%uint_4". 
+; CHECK-DAG: [[four:%\w+]] = OpConstant %uint 4 + %uint = OpTypeInt 32 0 + %uint_5 = OpConstant %uint 5 + %uint_16 = OpConstant %uint 16 + %main = OpFunction %void None %4 +; CHECK: OpLabel + %8 = OpLabel +; CHECK-NEXT: OpShiftLeftLogical %uint [[five]] [[four]] +; The multiplication disappears. +; CHECK-NOT: OpIMul + %9 = OpIMul %uint %uint_16 %uint_5 + OpReturn +; CHECK: OpFunctionEnd + OpFunctionEnd)"; + + SinglePassRunAndMatch(text, false); +} + +// Test to make sure we replace a multiple of 32 and 4. +TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) { + // In this case, we have two powers of 2. Need to make sure we replace only + // one of them for the bit shift. + // clang-format off + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpName %main "main" + %void = OpTypeVoid + %4 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%int_32 = OpConstant %int 32 + %int_4 = OpConstant %int 4 + %main = OpFunction %void None %4 + %8 = OpLabel + %9 = OpIMul %int %int_32 %int_4 + OpReturn + OpFunctionEnd +)"; + // clang-format on + auto result = SinglePassRunAndDisassemble( + text, /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); + const std::string& output = std::get<0>(result); + EXPECT_THAT(output, Not(HasSubstr("OpIMul"))); + EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5")); +} + +// Test to make sure we don't replace 0*5. 
+TEST_F(StrengthReductionBasicTest, BasicDontReplace0) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %main \"main\"", + "%void = OpTypeVoid", + "%4 = OpTypeFunction %void", + "%int = OpTypeInt 32 1", + "%int_0 = OpConstant %int 0", + "%int_5 = OpConstant %int 5", + "%main = OpFunction %void None %4", + "%8 = OpLabel", + "%9 = OpIMul %int %int_0 %int_5", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + auto result = SinglePassRunAndDisassemble( + JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +// Test to make sure we do not replace a multiple of 5 and 7. +TEST_F(StrengthReductionBasicTest, BasicNoChange) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %2 \"main\"", + "OpName %2 \"main\"", + "%3 = OpTypeVoid", + "%4 = OpTypeFunction %3", + "%5 = OpTypeInt 32 1", + "%6 = OpTypeInt 32 0", + "%7 = OpConstant %5 5", + "%8 = OpConstant %5 7", + "%2 = OpFunction %3 None %4", + "%9 = OpLabel", + "%10 = OpIMul %5 %7 %8", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + auto result = SinglePassRunAndDisassemble( + JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result)); +} + +// Test to make sure constants and types are reused and not duplicated. 
+TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %main \"main\"", + "%void = OpTypeVoid", + "%4 = OpTypeFunction %void", + "%uint = OpTypeInt 32 0", + "%uint_8 = OpConstant %uint 8", + "%uint_3 = OpConstant %uint 3", + "%main = OpFunction %void None %4", + "%8 = OpLabel", + "%9 = OpIMul %uint %uint_8 %uint_3", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + auto result = SinglePassRunAndDisassemble( + JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); + const std::string& output = std::get<0>(result); + EXPECT_THAT(output, + Not(MatchesRegex(".*OpConstant %uint 3.*OpConstant %uint 3.*"))); + EXPECT_THAT(output, Not(MatchesRegex(".*OpTypeInt 32 0.*OpTypeInt 32 0.*"))); +} + +// Test to make sure we generate the constants only once +TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) { + const std::vector text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %main \"main\"", + "OpName %main \"main\"", + "%void = OpTypeVoid", + "%4 = OpTypeFunction %void", + "%uint = OpTypeInt 32 0", + "%uint_5 = OpConstant %uint 5", + "%uint_9 = OpConstant %uint 9", + "%uint_128 = OpConstant %uint 128", + "%main = OpFunction %void None %4", + "%8 = OpLabel", + "%9 = OpIMul %uint %uint_5 %uint_128", + "%10 = OpIMul %uint %uint_9 %uint_128", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + auto result = SinglePassRunAndDisassemble( + JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false); + + EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result)); + const std::string& output = std::get<0>(result); + EXPECT_THAT(output, 
Not(HasSubstr("OpIMul"))); + EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_7")); + EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_9 %uint_7")); +} + +// Test to make sure we generate the instructions in the correct position and +// that the uses get replaced as well. Here we check that the use in the return +// is replaced, we also check that we can replace two OpIMuls when one feeds the +// other. +TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) { + // This is just the preamble to set up the test. + const std::vector common_text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpName %main \"main\"", + "OpName %foo_i1_ \"foo(i1;\"", + "OpName %n \"n\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "OpName %param \"param\"", + "OpDecorate %gl_FragColor Location 0", + "%void = OpTypeVoid", + "%3 = OpTypeFunction %void", + "%int = OpTypeInt 32 1", +"%_ptr_Function_int = OpTypePointer Function %int", + "%8 = OpTypeFunction %int %_ptr_Function_int", + "%int_256 = OpConstant %int 256", + "%int_2 = OpConstant %int 2", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + "%float_1 = OpConstant %float 1", + "%int_10 = OpConstant %int 10", + "%float_0_375 = OpConstant %float 0.375", + "%float_0_75 = OpConstant %float 0.75", + "%uint = OpTypeInt 32 0", + "%uint_8 = OpConstant %uint 8", + "%uint_1 = OpConstant %uint 1", + "%main = OpFunction %void None %3", + "%5 = OpLabel", + "%param = OpVariable %_ptr_Function_int Function", + "OpStore %param %int_10", + "%26 = OpFunctionCall %int %foo_i1_ %param", + "%27 = OpConvertSToF %float %26", + "%28 = OpFDiv %float %float_1 %27", + "%31 = OpCompositeConstruct 
%v4float %28 %float_0_375 %float_0_75 %float_1", + "OpStore %gl_FragColor %31", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + // This is the real test. The two OpIMul should be replaced. The expected + // output is in |foo_after|. + const std::vector foo_before = { + // clang-format off + "%foo_i1_ = OpFunction %int None %8", + "%n = OpFunctionParameter %_ptr_Function_int", + "%11 = OpLabel", + "%12 = OpLoad %int %n", + "%14 = OpIMul %int %12 %int_256", + "%16 = OpIMul %int %14 %int_2", + "OpReturnValue %16", + "OpFunctionEnd", + + // clang-format on + }; + + const std::vector foo_after = { + // clang-format off + "%foo_i1_ = OpFunction %int None %8", + "%n = OpFunctionParameter %_ptr_Function_int", + "%11 = OpLabel", + "%12 = OpLoad %int %n", + "%33 = OpShiftLeftLogical %int %12 %uint_8", + "%34 = OpShiftLeftLogical %int %33 %uint_1", + "OpReturnValue %34", + "OpFunctionEnd", + // clang-format on + }; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + JoinAllInsts(Concat(common_text, foo_before)), + JoinAllInsts(Concat(common_text, foo_after)), + /* skip_nop = */ true, /* do_validate = */ true); +} + +// Test that, when the result of an OpIMul instruction has more than 1 use, and +// the instruction is replaced, all of the uses of the results are replaced with +// the new result. +TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) { + // This is just the preamble to set up the test.
+ const std::vector common_text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpName %main \"main\"", + "OpName %foo_i1_ \"foo(i1;\"", + "OpName %n \"n\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "OpName %param \"param\"", + "OpDecorate %gl_FragColor Location 0", + "%void = OpTypeVoid", + "%3 = OpTypeFunction %void", + "%int = OpTypeInt 32 1", +"%_ptr_Function_int = OpTypePointer Function %int", + "%8 = OpTypeFunction %int %_ptr_Function_int", + "%int_256 = OpConstant %int 256", + "%int_2 = OpConstant %int 2", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + "%float_1 = OpConstant %float 1", + "%int_10 = OpConstant %int 10", + "%float_0_375 = OpConstant %float 0.375", + "%float_0_75 = OpConstant %float 0.75", + "%uint = OpTypeInt 32 0", + "%uint_8 = OpConstant %uint 8", + "%uint_1 = OpConstant %uint 1", + "%main = OpFunction %void None %3", + "%5 = OpLabel", + "%param = OpVariable %_ptr_Function_int Function", + "OpStore %param %int_10", + "%26 = OpFunctionCall %int %foo_i1_ %param", + "%27 = OpConvertSToF %float %26", + "%28 = OpFDiv %float %float_1 %27", + "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1", + "OpStore %gl_FragColor %31", + "OpReturn", + "OpFunctionEnd" + // clang-format on + }; + + // This is the real test. The two OpIMul instructions should be replaced. In + // particular, we want to be sure that both uses of %14 are changed to use the + // new result.
+ const std::vector foo_before = { + // clang-format off + "%foo_i1_ = OpFunction %int None %8", + "%n = OpFunctionParameter %_ptr_Function_int", + "%11 = OpLabel", + "%12 = OpLoad %int %n", + "%14 = OpIMul %int %12 %int_256", + "%16 = OpIMul %int %14 %int_2", + "%17 = OpIAdd %int %14 %16", + "OpReturnValue %17", + "OpFunctionEnd", + + // clang-format on + }; + + const std::vector foo_after = { + // clang-format off + "%foo_i1_ = OpFunction %int None %8", + "%n = OpFunctionParameter %_ptr_Function_int", + "%11 = OpLabel", + "%12 = OpLoad %int %n", + "%34 = OpShiftLeftLogical %int %12 %uint_8", + "%35 = OpShiftLeftLogical %int %34 %uint_1", + "%17 = OpIAdd %int %34 %35", + "OpReturnValue %17", + "OpFunctionEnd", + // clang-format on + }; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck( + JoinAllInsts(Concat(common_text, foo_before)), + JoinAllInsts(Concat(common_text, foo_after)), + /* skip_nop = */ true, /* do_validate = */ true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/strip_debug_info_test.cpp b/test/opt/strip_debug_info_test.cpp new file mode 100644 index 000000000..f40ed382a --- /dev/null +++ b/test/opt/strip_debug_info_test.cpp @@ -0,0 +1,107 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using StripLineDebugInfoTest = PassTest<::testing::Test>; + +TEST_F(StripLineDebugInfoTest, LineNoLine) { + std::vector text = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Vertex %2 \"main\"", + "%3 = OpString \"minimal.vert\"", + "OpModuleProcessed \"42\"", + "OpModuleProcessed \"43\"", + "OpModuleProcessed \"44\"", + "OpNoLine", + "OpLine %3 10 10", + "%void = OpTypeVoid", + "OpLine %3 100 100", + "%5 = OpTypeFunction %void", + "%2 = OpFunction %void None %5", + "OpLine %3 1 1", + "OpNoLine", + "OpLine %3 2 2", + "OpLine %3 3 3", + "%6 = OpLabel", + "OpLine %3 4 4", + "OpNoLine", + "OpReturn", + "OpLine %3 4 4", + "OpNoLine", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck(JoinAllInsts(text), + JoinNonDebugInsts(text), + /* skip_nop = */ false); + + // Let's add more debug instruction before the "OpString" instruction. + const std::vector more_text = { + "OpSourceContinued \"I'm a happy shader! Yay! ;)\"", + "OpSourceContinued \"wahahaha\"", + "OpSource ESSL 310", + "OpSource ESSL 310", + "OpSourceContinued \"wahahaha\"", + "OpSourceContinued \"wahahaha\"", + "OpSourceExtension \"save-the-world-extension\"", + "OpName %2 \"main\"", + }; + text.insert(text.begin() + 4, more_text.cbegin(), more_text.cend()); + SinglePassRunAndCheck(JoinAllInsts(text), + JoinNonDebugInsts(text), + /* skip_nop = */ false); +} + +using StripDebugInfoTest = PassTest<::testing::TestWithParam>; + +TEST_P(StripDebugInfoTest, Kind) { + std::vector text = { + "OpCapability Shader", + "OpMemoryModel Logical GLSL450", + GetParam(), + }; + SinglePassRunAndCheck(JoinAllInsts(text), + JoinNonDebugInsts(text), + /* skip_nop = */ false); +} + +// Test each possible non-line debug instruction. 
+// clang-format off +INSTANTIATE_TEST_CASE_P( + SingleKindDebugInst, StripDebugInfoTest, + ::testing::ValuesIn(std::vector({ + "OpSourceContinued \"I'm a happy shader! Yay! ;)\"", + "OpSource ESSL 310", + "OpSourceExtension \"save-the-world-extension\"", + "OpName %main \"main\"", + "OpMemberName %struct 0 \"field\"", + "%1 = OpString \"name.vert\"", + "OpModuleProcessed \"42\"", + }))); +// clang-format on + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/strip_reflect_info_test.cpp b/test/opt/strip_reflect_info_test.cpp new file mode 100644 index 000000000..a9cfd4763 --- /dev/null +++ b/test/opt/strip_reflect_info_test.cpp @@ -0,0 +1,90 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using StripLineReflectInfoTest = PassTest<::testing::Test>; + +TEST_F(StripLineReflectInfoTest, StripHlslSemantic) { + // This is a non-sensical example, but exercises the instructions. 
+ std::string before = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_decorate_string" +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpMemoryModel Logical Simple +OpDecorateStringGOOGLE %float HlslSemanticGOOGLE "foobar" +OpDecorateStringGOOGLE %void HlslSemanticGOOGLE "my goodness" +%void = OpTypeVoid +%float = OpTypeFloat 32 +)"; + std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical Simple +%void = OpTypeVoid +%float = OpTypeFloat 32 +)"; + + SinglePassRunAndCheck(before, after, false); +} + +TEST_F(StripLineReflectInfoTest, StripHlslCounterBuffer) { + std::string before = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpMemoryModel Logical Simple +OpDecorateId %void HlslCounterBufferGOOGLE %float +%void = OpTypeVoid +%float = OpTypeFloat 32 +)"; + std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical Simple +%void = OpTypeVoid +%float = OpTypeFloat 32 +)"; + + SinglePassRunAndCheck(before, after, false); +} + +TEST_F(StripLineReflectInfoTest, StripHlslSemanticOnMember) { + // This is a non-sensical example, but exercises the instructions. 
+ std::string before = R"(OpCapability Shader +OpCapability Linkage +OpExtension "SPV_GOOGLE_decorate_string" +OpExtension "SPV_GOOGLE_hlsl_functionality1" +OpMemoryModel Logical Simple +OpMemberDecorateStringGOOGLE %struct 0 HlslSemanticGOOGLE "foobar" +%float = OpTypeFloat 32 +%_struct_3 = OpTypeStruct %float +)"; + std::string after = R"(OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical Simple +%float = OpTypeFloat 32 +%_struct_3 = OpTypeStruct %float +)"; + + SinglePassRunAndCheck(before, after, false); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/struct_cfg_analysis_test.cpp b/test/opt/struct_cfg_analysis_test.cpp new file mode 100644 index 000000000..13f9022d0 --- /dev/null +++ b/test/opt/struct_cfg_analysis_test.cpp @@ -0,0 +1,466 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/opt/struct_cfg_analysis.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using StructCFGAnalysisTest = PassTest<::testing::Test>; + +TEST_F(StructCFGAnalysisTest, BBInSelection) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%bool_undef = OpUndef %bool +%uint = OpTypeInt 32 0 +%uint_undef = OpUndef %uint +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %undef_bool %2 %3 +%2 = OpLabel +OpBranch %3 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + StructuredCFGAnalysis analysis(context.get()); + + // The header is not in the construct. + EXPECT_EQ(analysis.ContainingConstruct(1), 0); + EXPECT_EQ(analysis.ContainingLoop(1), 0); + EXPECT_EQ(analysis.MergeBlock(1), 0); + EXPECT_EQ(analysis.LoopMergeBlock(1), 0); + + // BB2 is in the construct. + EXPECT_EQ(analysis.ContainingConstruct(2), 1); + EXPECT_EQ(analysis.ContainingLoop(2), 0); + EXPECT_EQ(analysis.MergeBlock(2), 3); + EXPECT_EQ(analysis.LoopMergeBlock(2), 0); + + // The merge node is not in the construct. 
+ EXPECT_EQ(analysis.ContainingConstruct(3), 0); + EXPECT_EQ(analysis.ContainingLoop(3), 0); + EXPECT_EQ(analysis.MergeBlock(3), 0); + EXPECT_EQ(analysis.LoopMergeBlock(3), 0); +} + +TEST_F(StructCFGAnalysisTest, BBInLoop) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%bool_undef = OpUndef %bool +%uint = OpTypeInt 32 0 +%uint_undef = OpUndef %uint +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%entry_lab = OpLabel +OpBranch %1 +%1 = OpLabel +OpLoopMerge %3 %4 None +OpBranchConditional %undef_bool %2 %3 +%2 = OpLabel +OpBranch %3 +%4 = OpLabel +OpBranch %1 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + StructuredCFGAnalysis analysis(context.get()); + + // The header is not in the construct. + EXPECT_EQ(analysis.ContainingConstruct(1), 0); + EXPECT_EQ(analysis.ContainingLoop(1), 0); + EXPECT_EQ(analysis.MergeBlock(1), 0); + EXPECT_EQ(analysis.LoopMergeBlock(1), 0); + + // BB2 is in the construct. + EXPECT_EQ(analysis.ContainingConstruct(2), 1); + EXPECT_EQ(analysis.ContainingLoop(2), 1); + EXPECT_EQ(analysis.MergeBlock(2), 3); + EXPECT_EQ(analysis.LoopMergeBlock(2), 3); + + // The merge node is not in the construct. + EXPECT_EQ(analysis.ContainingConstruct(3), 0); + EXPECT_EQ(analysis.ContainingLoop(3), 0); + EXPECT_EQ(analysis.MergeBlock(3), 0); + EXPECT_EQ(analysis.LoopMergeBlock(3), 0); + + // The continue block is in the construct. 
+ EXPECT_EQ(analysis.ContainingConstruct(4), 1); + EXPECT_EQ(analysis.ContainingLoop(4), 1); + EXPECT_EQ(analysis.MergeBlock(4), 3); + EXPECT_EQ(analysis.LoopMergeBlock(4), 3); +} + +TEST_F(StructCFGAnalysisTest, SelectionInLoop) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%bool_undef = OpUndef %bool +%uint = OpTypeInt 32 0 +%uint_undef = OpUndef %uint +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%entry_lab = OpLabel +OpBranch %1 +%1 = OpLabel +OpLoopMerge %3 %4 None +OpBranchConditional %undef_bool %2 %3 +%2 = OpLabel +OpSelectionMerge %6 None +OpBranchConditional %undef_bool %5 %6 +%5 = OpLabel +OpBranch %6 +%6 = OpLabel +OpBranch %3 +%4 = OpLabel +OpBranch %1 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + StructuredCFGAnalysis analysis(context.get()); + + // The loop header is not in either construct. + EXPECT_EQ(analysis.ContainingConstruct(1), 0); + EXPECT_EQ(analysis.ContainingLoop(1), 0); + EXPECT_EQ(analysis.MergeBlock(1), 0); + EXPECT_EQ(analysis.LoopMergeBlock(1), 0); + + // Selection header is in the loop only. + EXPECT_EQ(analysis.ContainingConstruct(2), 1); + EXPECT_EQ(analysis.ContainingLoop(2), 1); + EXPECT_EQ(analysis.MergeBlock(2), 3); + EXPECT_EQ(analysis.LoopMergeBlock(2), 3); + + // The loop merge node is not in either construct. + EXPECT_EQ(analysis.ContainingConstruct(3), 0); + EXPECT_EQ(analysis.ContainingLoop(3), 0); + EXPECT_EQ(analysis.MergeBlock(3), 0); + EXPECT_EQ(analysis.LoopMergeBlock(3), 0); + + // The continue block is in the loop only. 
+ EXPECT_EQ(analysis.ContainingConstruct(4), 1); + EXPECT_EQ(analysis.ContainingLoop(4), 1); + EXPECT_EQ(analysis.MergeBlock(4), 3); + EXPECT_EQ(analysis.LoopMergeBlock(4), 3); + + // BB5 is in the selection first and the loop. + EXPECT_EQ(analysis.ContainingConstruct(5), 2); + EXPECT_EQ(analysis.ContainingLoop(5), 1); + EXPECT_EQ(analysis.MergeBlock(5), 6); + EXPECT_EQ(analysis.LoopMergeBlock(5), 3); + + // The selection merge is in the loop only. + EXPECT_EQ(analysis.ContainingConstruct(6), 1); + EXPECT_EQ(analysis.ContainingLoop(6), 1); + EXPECT_EQ(analysis.MergeBlock(6), 3); + EXPECT_EQ(analysis.LoopMergeBlock(6), 3); +} + +TEST_F(StructCFGAnalysisTest, LoopInSelection) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%bool_undef = OpUndef %bool +%uint = OpTypeInt 32 0 +%uint_undef = OpUndef %uint +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%entry_lab = OpLabel +OpBranch %1 +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %undef_bool %2 %3 +%2 = OpLabel +OpLoopMerge %4 %5 None +OpBranchConditional %undef_bool %4 %6 +%5 = OpLabel +OpBranch %2 +%6 = OpLabel +OpBranch %4 +%4 = OpLabel +OpBranch %3 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + StructuredCFGAnalysis analysis(context.get()); + + // The selection header is not in either construct. + EXPECT_EQ(analysis.ContainingConstruct(1), 0); + EXPECT_EQ(analysis.ContainingLoop(1), 0); + EXPECT_EQ(analysis.MergeBlock(1), 0); + EXPECT_EQ(analysis.LoopMergeBlock(1), 0); + + // Loop header is in the selection only.
+ EXPECT_EQ(analysis.ContainingConstruct(2), 1); + EXPECT_EQ(analysis.ContainingLoop(2), 0); + EXPECT_EQ(analysis.MergeBlock(2), 3); + EXPECT_EQ(analysis.LoopMergeBlock(2), 0); + + // The selection merge node is not in either construct. + EXPECT_EQ(analysis.ContainingConstruct(3), 0); + EXPECT_EQ(analysis.ContainingLoop(3), 0); + EXPECT_EQ(analysis.MergeBlock(3), 0); + EXPECT_EQ(analysis.LoopMergeBlock(3), 0); + + // The loop merge is in the selection only. + EXPECT_EQ(analysis.ContainingConstruct(4), 1); + EXPECT_EQ(analysis.ContainingLoop(4), 0); + EXPECT_EQ(analysis.MergeBlock(4), 3); + EXPECT_EQ(analysis.LoopMergeBlock(4), 0); + + // The loop continue target is in the loop. + EXPECT_EQ(analysis.ContainingConstruct(5), 2); + EXPECT_EQ(analysis.ContainingLoop(5), 2); + EXPECT_EQ(analysis.MergeBlock(5), 4); + EXPECT_EQ(analysis.LoopMergeBlock(5), 4); + + // BB6 is in the loop. + EXPECT_EQ(analysis.ContainingConstruct(6), 2); + EXPECT_EQ(analysis.ContainingLoop(6), 2); + EXPECT_EQ(analysis.MergeBlock(6), 4); + EXPECT_EQ(analysis.LoopMergeBlock(6), 4); +} + +TEST_F(StructCFGAnalysisTest, SelectionInSelection) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%bool_undef = OpUndef %bool +%uint = OpTypeInt 32 0 +%uint_undef = OpUndef %uint +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%entry_lab = OpLabel +OpBranch %1 +%1 = OpLabel +OpSelectionMerge %3 None +OpBranchConditional %undef_bool %2 %3 +%2 = OpLabel +OpSelectionMerge %4 None +OpBranchConditional %undef_bool %4 %5 +%5 = OpLabel +OpBranch %4 +%4 = OpLabel +OpBranch %3 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + StructuredCFGAnalysis analysis(context.get()); + + // The outer selection header is not in either construct. 
+ EXPECT_EQ(analysis.ContainingConstruct(1), 0); + EXPECT_EQ(analysis.ContainingLoop(1), 0); + EXPECT_EQ(analysis.MergeBlock(1), 0); + EXPECT_EQ(analysis.LoopMergeBlock(1), 0); + + // The inner header is in the outer selection. + EXPECT_EQ(analysis.ContainingConstruct(2), 1); + EXPECT_EQ(analysis.ContainingLoop(2), 0); + EXPECT_EQ(analysis.MergeBlock(2), 3); + EXPECT_EQ(analysis.LoopMergeBlock(2), 0); + + // The outer merge node is not in either construct. + EXPECT_EQ(analysis.ContainingConstruct(3), 0); + EXPECT_EQ(analysis.ContainingLoop(3), 0); + EXPECT_EQ(analysis.MergeBlock(3), 0); + EXPECT_EQ(analysis.LoopMergeBlock(3), 0); + + // The inner merge is in the outer selection. + EXPECT_EQ(analysis.ContainingConstruct(4), 1); + EXPECT_EQ(analysis.ContainingLoop(4), 0); + EXPECT_EQ(analysis.MergeBlock(4), 3); + EXPECT_EQ(analysis.LoopMergeBlock(4), 0); + + // BB5 is in the inner selection. + EXPECT_EQ(analysis.ContainingConstruct(5), 2); + EXPECT_EQ(analysis.ContainingLoop(5), 0); + EXPECT_EQ(analysis.MergeBlock(5), 4); + EXPECT_EQ(analysis.LoopMergeBlock(5), 0); +} + +TEST_F(StructCFGAnalysisTest, LoopInLoop) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%bool_undef = OpUndef %bool +%uint = OpTypeInt 32 0 +%uint_undef = OpUndef %uint +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%entry_lab = OpLabel +OpBranch %1 +%1 = OpLabel +OpLoopMerge %3 %7 None +OpBranchConditional %undef_bool %2 %3 +%2 = OpLabel +OpLoopMerge %4 %5 None +OpBranchConditional %undef_bool %4 %6 +%5 = OpLabel +OpBranch %2 +%6 = OpLabel +OpBranch %4 +%4 = OpLabel +OpBranch %3 +%7 = OpLabel +OpBranch %1 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + StructuredCFGAnalysis analysis(context.get()); + + // The outer loop 
header is not in either construct. + EXPECT_EQ(analysis.ContainingConstruct(1), 0); + EXPECT_EQ(analysis.ContainingLoop(1), 0); + EXPECT_EQ(analysis.MergeBlock(1), 0); + EXPECT_EQ(analysis.LoopMergeBlock(1), 0); + + // The inner loop header is in the outer loop. + EXPECT_EQ(analysis.ContainingConstruct(2), 1); + EXPECT_EQ(analysis.ContainingLoop(2), 1); + EXPECT_EQ(analysis.MergeBlock(2), 3); + EXPECT_EQ(analysis.LoopMergeBlock(2), 3); + + // The outer merge node is not in either construct. + EXPECT_EQ(analysis.ContainingConstruct(3), 0); + EXPECT_EQ(analysis.ContainingLoop(3), 0); + EXPECT_EQ(analysis.MergeBlock(3), 0); + EXPECT_EQ(analysis.LoopMergeBlock(3), 0); + + // The inner merge is in the outer loop. + EXPECT_EQ(analysis.ContainingConstruct(4), 1); + EXPECT_EQ(analysis.ContainingLoop(4), 1); + EXPECT_EQ(analysis.MergeBlock(4), 3); + EXPECT_EQ(analysis.LoopMergeBlock(4), 3); + + // The inner continue target is in the inner loop. + EXPECT_EQ(analysis.ContainingConstruct(5), 2); + EXPECT_EQ(analysis.ContainingLoop(5), 2); + EXPECT_EQ(analysis.MergeBlock(5), 4); + EXPECT_EQ(analysis.LoopMergeBlock(5), 4); + + // BB6 is in the loop. + EXPECT_EQ(analysis.ContainingConstruct(6), 2); + EXPECT_EQ(analysis.ContainingLoop(6), 2); + EXPECT_EQ(analysis.MergeBlock(6), 4); + EXPECT_EQ(analysis.LoopMergeBlock(6), 4); + + // The outer continue target is in the outer loop. 
+ EXPECT_EQ(analysis.ContainingConstruct(7), 1); + EXPECT_EQ(analysis.ContainingLoop(7), 1); + EXPECT_EQ(analysis.MergeBlock(7), 3); + EXPECT_EQ(analysis.LoopMergeBlock(7), 3); +} + +TEST_F(StructCFGAnalysisTest, KernelTest) { + const std::string text = R"( +OpCapability Kernel +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +%void = OpTypeVoid +%bool = OpTypeBool +%bool_undef = OpUndef %bool +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%1 = OpLabel +OpBranchConditional %undef_bool %2 %3 +%2 = OpLabel +OpBranch %3 +%3 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + + StructuredCFGAnalysis analysis(context.get()); + + // No structured control flow, so none of the basic blocks are in any + // construct. + for (uint32_t i = 1; i <= 3; i++) { + EXPECT_EQ(analysis.ContainingConstruct(i), 0); + EXPECT_EQ(analysis.ContainingLoop(i), 0); + EXPECT_EQ(analysis.MergeBlock(i), 0); + EXPECT_EQ(analysis.LoopMergeBlock(i), 0); + } +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp new file mode 100644 index 000000000..1072c365c --- /dev/null +++ b/test/opt/type_manager_test.cpp @@ -0,0 +1,1148 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include +#include +#include +#include + +#include "effcee/effcee.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "source/opt/build_module.h" +#include "source/opt/instruction.h" +#include "source/opt/type_manager.h" +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +bool Validate(const std::vector& bin) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2; + spv_context spvContext = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {bin.data(), bin.size()}; + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); + if (error != 0) spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(spvContext); + return error == 0; +} + +void Match(const std::string& original, IRContext* context, + bool do_validation = true) { + std::vector bin; + context->module()->ToBinary(&bin, true); + if (do_validation) { + EXPECT_TRUE(Validate(bin)); + } + std::string assembly; + SpirvTools tools(SPV_ENV_UNIVERSAL_1_2); + EXPECT_TRUE( + tools.Disassemble(bin, &assembly, SpirvTools::kDefaultDisassembleOption)) + << "Disassembling failed for shader:\n" + << assembly << std::endl; + auto match_result = effcee::Match(assembly, original); + EXPECT_EQ(effcee::Result::Status::Ok, match_result.status()) + << match_result.message() << "\nChecking result:\n" + << assembly; +} + +std::vector> GenerateAllTypes() { + // Types in this test case are only equal to themselves, nothing else. 
+ std::vector> types; + + // Void, Bool + types.emplace_back(new Void()); + auto* voidt = types.back().get(); + types.emplace_back(new Bool()); + auto* boolt = types.back().get(); + + // Integer + types.emplace_back(new Integer(32, true)); + auto* s32 = types.back().get(); + types.emplace_back(new Integer(32, false)); + types.emplace_back(new Integer(64, true)); + types.emplace_back(new Integer(64, false)); + auto* u64 = types.back().get(); + + // Float + types.emplace_back(new Float(32)); + auto* f32 = types.back().get(); + types.emplace_back(new Float(64)); + + // Vector + types.emplace_back(new Vector(s32, 2)); + types.emplace_back(new Vector(s32, 3)); + auto* v3s32 = types.back().get(); + types.emplace_back(new Vector(u64, 4)); + types.emplace_back(new Vector(f32, 3)); + auto* v3f32 = types.back().get(); + + // Matrix + types.emplace_back(new Matrix(v3s32, 3)); + types.emplace_back(new Matrix(v3s32, 4)); + types.emplace_back(new Matrix(v3f32, 4)); + + // Images + types.emplace_back(new Image(s32, SpvDim2D, 0, 0, 0, 0, SpvImageFormatRg8, + SpvAccessQualifierReadOnly)); + auto* image1 = types.back().get(); + types.emplace_back(new Image(s32, SpvDim2D, 0, 1, 0, 0, SpvImageFormatRg8, + SpvAccessQualifierReadOnly)); + types.emplace_back(new Image(s32, SpvDim3D, 0, 1, 0, 0, SpvImageFormatRg8, + SpvAccessQualifierReadOnly)); + types.emplace_back(new Image(voidt, SpvDim3D, 0, 1, 0, 1, SpvImageFormatRg8, + SpvAccessQualifierReadWrite)); + auto* image2 = types.back().get(); + + // Sampler + types.emplace_back(new Sampler()); + + // Sampled Image + types.emplace_back(new SampledImage(image1)); + types.emplace_back(new SampledImage(image2)); + + // Array + types.emplace_back(new Array(f32, 100)); + types.emplace_back(new Array(f32, 42)); + auto* a42f32 = types.back().get(); + types.emplace_back(new Array(u64, 24)); + + // RuntimeArray + types.emplace_back(new RuntimeArray(v3f32)); + types.emplace_back(new RuntimeArray(v3s32)); + auto* rav3s32 = types.back().get(); + + // 
Struct + types.emplace_back(new Struct(std::vector{s32})); + types.emplace_back(new Struct(std::vector{s32, f32})); + auto* sts32f32 = types.back().get(); + types.emplace_back( + new Struct(std::vector{u64, a42f32, rav3s32})); + + // Opaque + types.emplace_back(new Opaque("")); + types.emplace_back(new Opaque("hello")); + types.emplace_back(new Opaque("world")); + + // Pointer + types.emplace_back(new Pointer(f32, SpvStorageClassInput)); + types.emplace_back(new Pointer(sts32f32, SpvStorageClassFunction)); + types.emplace_back(new Pointer(a42f32, SpvStorageClassFunction)); + + // Function + types.emplace_back(new Function(voidt, {})); + types.emplace_back(new Function(voidt, {boolt})); + types.emplace_back(new Function(voidt, {boolt, s32})); + types.emplace_back(new Function(s32, {boolt, s32})); + + // Event, Device Event, Reserve Id, Queue, + types.emplace_back(new Event()); + types.emplace_back(new DeviceEvent()); + types.emplace_back(new ReserveId()); + types.emplace_back(new Queue()); + + // Pipe, Forward Pointer, PipeStorage, NamedBarrier, AccelerationStructureNV + types.emplace_back(new Pipe(SpvAccessQualifierReadWrite)); + types.emplace_back(new Pipe(SpvAccessQualifierReadOnly)); + types.emplace_back(new ForwardPointer(1, SpvStorageClassInput)); + types.emplace_back(new ForwardPointer(2, SpvStorageClassInput)); + types.emplace_back(new ForwardPointer(2, SpvStorageClassUniform)); + types.emplace_back(new PipeStorage()); + types.emplace_back(new NamedBarrier()); + types.emplace_back(new AccelerationStructureNV()); + + return types; +} + +TEST(TypeManager, TypeStrings) { + const std::string text = R"( + OpTypeForwardPointer !20 !2 ; id for %p is 20, Uniform is 2 + %void = OpTypeVoid + %bool = OpTypeBool + %u32 = OpTypeInt 32 0 + %id4 = OpConstant %u32 4 + %s32 = OpTypeInt 32 1 + %f64 = OpTypeFloat 64 + %v3u32 = OpTypeVector %u32 3 + %m3x3 = OpTypeMatrix %v3u32 3 + %img1 = OpTypeImage %s32 Cube 0 1 1 0 R32f ReadWrite + %img2 = OpTypeImage %s32 Cube 0 1 1 0 R32f 
+ %sampler = OpTypeSampler + %si1 = OpTypeSampledImage %img1 + %si2 = OpTypeSampledImage %img2 + %a5u32 = OpTypeArray %u32 %id4 + %af64 = OpTypeRuntimeArray %f64 + %st1 = OpTypeStruct %u32 + %st2 = OpTypeStruct %f64 %s32 %v3u32 + %opaque1 = OpTypeOpaque "" + %opaque2 = OpTypeOpaque "opaque" + %p = OpTypePointer Uniform %st1 + %f = OpTypeFunction %void %u32 %u32 + %event = OpTypeEvent + %de = OpTypeDeviceEvent + %ri = OpTypeReserveId + %queue = OpTypeQueue + %pipe = OpTypePipe ReadOnly + %ps = OpTypePipeStorage + %nb = OpTypeNamedBarrier + %rtacc = OpTypeAccelerationStructureNV + )"; + + std::vector> type_id_strs = { + {1, "void"}, + {2, "bool"}, + {3, "uint32"}, + // Id 4 is used by the constant. + {5, "sint32"}, + {6, "float64"}, + {7, ""}, + {8, "<, 3>"}, + {9, "image(sint32, 3, 0, 1, 1, 0, 3, 2)"}, + {10, "image(sint32, 3, 0, 1, 1, 0, 3, 0)"}, + {11, "sampler"}, + {12, "sampled_image(image(sint32, 3, 0, 1, 1, 0, 3, 2))"}, + {13, "sampled_image(image(sint32, 3, 0, 1, 1, 0, 3, 0))"}, + {14, "[uint32, id(4)]"}, + {15, "[float64]"}, + {16, "{uint32}"}, + {17, "{float64, sint32, }"}, + {18, "opaque('')"}, + {19, "opaque('opaque')"}, + {20, "{uint32}*"}, + {21, "(uint32, uint32) -> void"}, + {22, "event"}, + {23, "device_event"}, + {24, "reserve_id"}, + {25, "queue"}, + {26, "pipe(0)"}, + {27, "pipe_storage"}, + {28, "named_barrier"}, + {29, "accelerationStructureNV"}, + }; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + TypeManager manager(nullptr, context.get()); + + EXPECT_EQ(type_id_strs.size(), manager.NumTypes()); + + for (const auto& p : type_id_strs) { + EXPECT_EQ(p.second, manager.GetType(p.first)->str()); + EXPECT_EQ(p.first, manager.GetId(manager.GetType(p.first))); + } +} + +TEST(TypeManager, StructWithFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + 
OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + %void = OpTypeVoid + %150 = OpTypeStruct %100 +%100 = OpTypePointer CrossWorkgroup %150 + %6 = OpTypeFunction %void %100 %100 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %100 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + Type* p100 = manager.GetType(100); + Type* s150 = manager.GetType(150); + + EXPECT_TRUE(p100->AsPointer()); + EXPECT_EQ(p100->AsPointer()->pointee_type(), s150); + + EXPECT_TRUE(s150->AsStruct()); + EXPECT_EQ(s150->AsStruct()->element_types()[0], p100); +} + +TEST(TypeManager, CircularFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + OpTypeForwardPointer %200 CrossWorkgroup + %void = OpTypeVoid + %int = OpTypeInt 32 0 + %float = OpTypeFloat 32 + %150 = OpTypeStruct %200 %int + %250 = OpTypeStruct %100 %float +%100 = OpTypePointer CrossWorkgroup %150 +%200 = OpTypePointer CrossWorkgroup %250 + %6 = OpTypeFunction %void %100 %200 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %200 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + Type* p100 = manager.GetType(100); + Type* s150 = manager.GetType(150); + Type* p200 = manager.GetType(200); + Type* s250 = 
manager.GetType(250); + + EXPECT_TRUE(p100->AsPointer()); + EXPECT_EQ(p100->AsPointer()->pointee_type(), s150); + + EXPECT_TRUE(p200->AsPointer()); + EXPECT_EQ(p200->AsPointer()->pointee_type(), s250); + + EXPECT_TRUE(s150->AsStruct()); + EXPECT_EQ(s150->AsStruct()->element_types()[0], p200); + + EXPECT_TRUE(s250->AsStruct()); + EXPECT_EQ(s250->AsStruct()->element_types()[0], p100); +} + +TEST(TypeManager, IsomorphicStructWithFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + OpTypeForwardPointer %200 CrossWorkgroup + %void = OpTypeVoid + %_struct_1 = OpTypeStruct %100 + %_struct_2 = OpTypeStruct %200 +%100 = OpTypePointer CrossWorkgroup %_struct_1 +%200 = OpTypePointer CrossWorkgroup %_struct_2 + %6 = OpTypeFunction %void %100 %200 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %200 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + EXPECT_EQ(manager.GetType(100), manager.GetType(200)); +} + +TEST(TypeManager, IsomorphicCircularFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + OpTypeForwardPointer %200 CrossWorkgroup + OpTypeForwardPointer %300 CrossWorkgroup + OpTypeForwardPointer %400 CrossWorkgroup + %void = OpTypeVoid + %int = OpTypeInt 
32 0 + %float = OpTypeFloat 32 + %150 = OpTypeStruct %200 %int + %250 = OpTypeStruct %100 %float + %350 = OpTypeStruct %400 %int + %450 = OpTypeStruct %300 %float +%100 = OpTypePointer CrossWorkgroup %150 +%200 = OpTypePointer CrossWorkgroup %250 +%300 = OpTypePointer CrossWorkgroup %350 +%400 = OpTypePointer CrossWorkgroup %450 + %6 = OpTypeFunction %void %100 %200 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %200 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + Type* p100 = manager.GetType(100); + Type* p300 = manager.GetType(300); + EXPECT_EQ(p100, p300); + Type* p200 = manager.GetType(200); + Type* p400 = manager.GetType(400); + EXPECT_EQ(p200, p400); + + Type* p150 = manager.GetType(150); + Type* p350 = manager.GetType(350); + EXPECT_EQ(p150, p350); + Type* p250 = manager.GetType(250); + Type* p450 = manager.GetType(450); + EXPECT_EQ(p250, p450); +} + +TEST(TypeManager, PartialIsomorphicFwdPtr) { + const std::string text = R"( + OpCapability Addresses + OpCapability Kernel + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %7 "test" + OpSource OpenCL_C 102000 + OpDecorate %11 FuncParamAttr NoCapture + %11 = OpDecorationGroup + OpGroupDecorate %11 %8 %9 + OpTypeForwardPointer %100 CrossWorkgroup + OpTypeForwardPointer %200 CrossWorkgroup + %void = OpTypeVoid + %int = OpTypeInt 32 0 + %float = OpTypeFloat 32 + %150 = OpTypeStruct %200 %int + %250 = OpTypeStruct %200 %int +%100 = OpTypePointer CrossWorkgroup %150 +%200 = OpTypePointer CrossWorkgroup %250 + %6 = OpTypeFunction %void %100 %200 + %7 = OpFunction %void Pure %6 + %8 = OpFunctionParameter %100 + %9 = OpFunctionParameter %200 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr context = + 
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + TypeManager manager(nullptr, context.get()); + + Type* p100 = manager.GetType(100); + Type* p200 = manager.GetType(200); + EXPECT_EQ(p100->AsPointer()->pointee_type(), + p200->AsPointer()->pointee_type()); +} + +TEST(TypeManager, DecorationOnStruct) { + const std::string text = R"( + OpDecorate %struct1 Block + OpDecorate %struct2 Block + OpDecorate %struct3 Block + OpDecorate %struct4 Block + + %u32 = OpTypeInt 32 0 ; id: 5 + %f32 = OpTypeFloat 32 ; id: 6 + %struct1 = OpTypeStruct %u32 %f32 ; base + %struct2 = OpTypeStruct %f32 %u32 ; different member order + %struct3 = OpTypeStruct %f32 ; different member list + %struct4 = OpTypeStruct %u32 %f32 ; the same + %struct7 = OpTypeStruct %f32 ; no decoration + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + TypeManager manager(nullptr, context.get()); + + ASSERT_EQ(7u, manager.NumTypes()); + // Make sure we get ids correct. + ASSERT_EQ("uint32", manager.GetType(5)->str()); + ASSERT_EQ("float32", manager.GetType(6)->str()); + + // Try all combinations of pairs. Expect to be the same type only when the + // same id or (1, 4). 
+ for (const auto id1 : {1, 2, 3, 4, 7}) { + for (const auto id2 : {1, 2, 3, 4, 7}) { + if (id1 == id2 || (id1 == 1 && id2 == 4) || (id1 == 4 && id2 == 1)) { + EXPECT_TRUE(manager.GetType(id1)->IsSame(manager.GetType(id2))) + << "%struct" << id1 << " is expected to be the same as %struct" + << id2; + } else { + EXPECT_FALSE(manager.GetType(id1)->IsSame(manager.GetType(id2))) + << "%struct" << id1 << " is expected to be different with %struct" + << id2; + } + } + } +} + +TEST(TypeManager, DecorationOnMember) { + const std::string text = R"( + OpMemberDecorate %struct1 0 Offset 0 + OpMemberDecorate %struct2 0 Offset 0 + OpMemberDecorate %struct3 0 Offset 0 + OpMemberDecorate %struct4 0 Offset 0 + OpMemberDecorate %struct5 1 Offset 0 + OpMemberDecorate %struct6 0 Offset 4 + + OpDecorate %struct7 Block + OpMemberDecorate %struct7 0 Offset 0 + + %u32 = OpTypeInt 32 0 ; id: 8 + %f32 = OpTypeFloat 32 ; id: 9 + %struct1 = OpTypeStruct %u32 %f32 ; base + %struct2 = OpTypeStruct %f32 %u32 ; different member order + %struct3 = OpTypeStruct %f32 ; different member list + %struct4 = OpTypeStruct %u32 %f32 ; the same + %struct5 = OpTypeStruct %u32 %f32 ; member decorate different field + %struct6 = OpTypeStruct %u32 %f32 ; different member decoration parameter + %struct7 = OpTypeStruct %u32 %f32 ; extra decoration on the struct + %struct10 = OpTypeStruct %u32 %f32 ; no member decoration + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + TypeManager manager(nullptr, context.get()); + + ASSERT_EQ(10u, manager.NumTypes()); + // Make sure we get ids correct. + ASSERT_EQ("uint32", manager.GetType(8)->str()); + ASSERT_EQ("float32", manager.GetType(9)->str()); + + // Try all combinations of pairs. Expect to be the same type only when the + // same id or (1, 4). 
+ for (const auto id1 : {1, 2, 3, 4, 5, 6, 7, 10}) { + for (const auto id2 : {1, 2, 3, 4, 5, 6, 7, 10}) { + if (id1 == id2 || (id1 == 1 && id2 == 4) || (id1 == 4 && id2 == 1)) { + EXPECT_TRUE(manager.GetType(id1)->IsSame(manager.GetType(id2))) + << "%struct" << id1 << " is expected to be the same as %struct" + << id2; + } else { + EXPECT_FALSE(manager.GetType(id1)->IsSame(manager.GetType(id2))) + << "%struct" << id1 << " is expected to be different with %struct" + << id2; + } + } + } +} + +TEST(TypeManager, DecorationEmpty) { + const std::string text = R"( + OpDecorate %struct1 Block + OpMemberDecorate %struct2 0 Offset 0 + + %u32 = OpTypeInt 32 0 ; id: 3 + %f32 = OpTypeFloat 32 ; id: 4 + %struct1 = OpTypeStruct %u32 %f32 + %struct2 = OpTypeStruct %f32 %u32 + %struct5 = OpTypeStruct %f32 + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + TypeManager manager(nullptr, context.get()); + + ASSERT_EQ(5u, manager.NumTypes()); + // Make sure we get ids correct. 
+ ASSERT_EQ("uint32", manager.GetType(3)->str()); + ASSERT_EQ("float32", manager.GetType(4)->str()); + + // %struct1 with decoration on itself + EXPECT_FALSE(manager.GetType(1)->decoration_empty()); + // %struct2 with decoration on its member + EXPECT_FALSE(manager.GetType(2)->decoration_empty()); + EXPECT_TRUE(manager.GetType(3)->decoration_empty()); + EXPECT_TRUE(manager.GetType(4)->decoration_empty()); + // %struct5 has no decorations + EXPECT_TRUE(manager.GetType(5)->decoration_empty()); +} + +TEST(TypeManager, BeginEndForEmptyModule) { + const std::string text = ""; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + TypeManager manager(nullptr, context.get()); + ASSERT_EQ(0u, manager.NumTypes()); + + EXPECT_EQ(manager.begin(), manager.end()); +} + +TEST(TypeManager, BeginEnd) { + const std::string text = R"( + %void1 = OpTypeVoid + %void2 = OpTypeVoid + %bool = OpTypeBool + %u32 = OpTypeInt 32 0 + %f64 = OpTypeFloat 64 + )"; + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + TypeManager manager(nullptr, context.get()); + ASSERT_EQ(5u, manager.NumTypes()); + + EXPECT_NE(manager.begin(), manager.end()); + for (const auto& t : manager) { + switch (t.first) { + case 1: + case 2: + EXPECT_EQ("void", t.second->str()); + break; + case 3: + EXPECT_EQ("bool", t.second->str()); + break; + case 4: + EXPECT_EQ("uint32", t.second->str()); + break; + case 5: + EXPECT_EQ("float64", t.second->str()); + break; + default: + EXPECT_TRUE(false && "unreachable"); + break; + } + } +} + +TEST(TypeManager, LookupType) { + const std::string text = R"( +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%int = OpTypeInt 32 1 +%vec2 = OpTypeVector %int 2 +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + EXPECT_NE(context, nullptr); + TypeManager manager(nullptr, context.get()); + + Void voidTy; + EXPECT_EQ(manager.GetId(&voidTy), 1u); + + Integer uintTy(32, false); + 
EXPECT_EQ(manager.GetId(&uintTy), 2u); + + Integer intTy(32, true); + EXPECT_EQ(manager.GetId(&intTy), 3u); + + Integer intTy2(32, true); + Vector vecTy(&intTy2, 2u); + EXPECT_EQ(manager.GetId(&vecTy), 4u); +} + +TEST(TypeManager, RemoveId) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeInt 32 1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + context->get_type_mgr()->RemoveId(1u); + ASSERT_EQ(context->get_type_mgr()->GetType(1u), nullptr); + ASSERT_NE(context->get_type_mgr()->GetType(2u), nullptr); + + context->get_type_mgr()->RemoveId(2u); + ASSERT_EQ(context->get_type_mgr()->GetType(1u), nullptr); + ASSERT_EQ(context->get_type_mgr()->GetType(2u), nullptr); +} + +TEST(TypeManager, RemoveIdNonDuplicateAmbiguousType) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + Integer u32(32, false); + Struct st({&u32}); + ASSERT_EQ(context->get_type_mgr()->GetId(&st), 2u); + context->get_type_mgr()->RemoveId(2u); + ASSERT_EQ(context->get_type_mgr()->GetType(2u), nullptr); + ASSERT_EQ(context->get_type_mgr()->GetId(&st), 0u); +} + +TEST(TypeManager, RemoveIdDuplicateAmbiguousType) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 +%3 = OpTypeStruct %1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + Integer u32(32, false); + Struct st({&u32}); + uint32_t id = 
context->get_type_mgr()->GetId(&st); + ASSERT_NE(id, 0u); + uint32_t toRemove = id == 2u ? 2u : 3u; + uint32_t toStay = id == 2u ? 3u : 2u; + context->get_type_mgr()->RemoveId(toRemove); + ASSERT_EQ(context->get_type_mgr()->GetType(toRemove), nullptr); + ASSERT_EQ(context->get_type_mgr()->GetId(&st), toStay); +} + +TEST(TypeManager, RemoveIdDoesntUnmapOtherTypes) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 +%3 = OpTypeStruct %1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + Integer u32(32, false); + Struct st({&u32}); + + EXPECT_EQ(1u, context->get_type_mgr()->GetId(&u32)); + uint32_t id = context->get_type_mgr()->GetId(&st); + ASSERT_NE(id, 0u); + uint32_t toRemove = id == 2u ? 3u : 2u; + uint32_t toStay = id == 2u ? 2u : 3u; + context->get_type_mgr()->RemoveId(toRemove); + ASSERT_EQ(context->get_type_mgr()->GetType(toRemove), nullptr); + ASSERT_EQ(context->get_type_mgr()->GetId(&st), toStay); +} + +TEST(TypeManager, GetTypeAndPointerType) { + const std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + Integer u32(32, false); + Pointer u32Ptr(&u32, SpvStorageClassFunction); + Struct st({&u32}); + Pointer stPtr(&st, SpvStorageClassInput); + + auto pair = context->get_type_mgr()->GetTypeAndPointerType( + 3u, SpvStorageClassFunction); + ASSERT_EQ(nullptr, pair.first); + ASSERT_EQ(nullptr, pair.second); + + pair = context->get_type_mgr()->GetTypeAndPointerType( + 1u, SpvStorageClassFunction); + ASSERT_TRUE(pair.first->IsSame(&u32)); + ASSERT_TRUE(pair.second->IsSame(&u32Ptr)); + + pair = + 
context->get_type_mgr()->GetTypeAndPointerType(2u, SpvStorageClassInput); + ASSERT_TRUE(pair.first->IsSame(&st)); + ASSERT_TRUE(pair.second->IsSame(&stPtr)); +} + +TEST(TypeManager, DuplicateType) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeInt 32 0 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + const Type* type1 = context->get_type_mgr()->GetType(1u); + const Type* type2 = context->get_type_mgr()->GetType(2u); + EXPECT_NE(type1, nullptr); + EXPECT_NE(type2, nullptr); + EXPECT_EQ(*type1, *type2); +} + +TEST(TypeManager, MultipleStructs) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpDecorate %3 Constant +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 +%3 = OpTypeStruct %1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + const Type* type1 = context->get_type_mgr()->GetType(2u); + const Type* type2 = context->get_type_mgr()->GetType(3u); + EXPECT_NE(type1, nullptr); + EXPECT_NE(type2, nullptr); + EXPECT_FALSE(type1->IsSame(type2)); +} + +TEST(TypeManager, RemovingIdAvoidsUseAfterFree) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + Integer u32(32, false); + Struct st({&u32}); + const Type* type = context->get_type_mgr()->GetType(2u); + EXPECT_NE(type, nullptr); + context->get_type_mgr()->RemoveId(1u); + EXPECT_TRUE(type->IsSame(&st)); +} + +TEST(TypeManager, RegisterAndRemoveId) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 
+%1 = OpTypeInt 32 0 +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + uint32_t id = 2u; + { + // Ensure that u32 goes out of scope. + Integer u32(32, false); + Struct st({&u32}); + context->get_type_mgr()->RegisterType(id, st); + } + + context->get_type_mgr()->RemoveId(id); + EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id)); +} + +TEST(TypeManager, RegisterAndRemoveIdAllTypes) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + std::vector> types = GenerateAllTypes(); + uint32_t id = 1u; + for (auto& t : types) { + context->get_type_mgr()->RegisterType(id, *t); + EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id)); + } + types.clear(); + + for (; id > 0; --id) { + context->get_type_mgr()->RemoveId(id); + EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id)); + } +} + +TEST(TypeManager, RegisterAndRemoveIdWithDecorations) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 0 +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + uint32_t id = 2u; + { + Integer u32(32, false); + Struct st({&u32, &u32}); + st.AddDecoration({10}); + st.AddDecoration({11}); + st.AddMemberDecoration(0, {{35, 4}}); + st.AddMemberDecoration(1, {{35, 4}}); + st.AddMemberDecoration(1, {{36, 5}}); + context->get_type_mgr()->RegisterType(id, st); + EXPECT_EQ(st, *context->get_type_mgr()->GetType(id)); + } + + context->get_type_mgr()->RemoveId(id); + EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id)); +} + +TEST(TypeManager, GetTypeInstructionInt) { + const std::string text = R"( +; 
CHECK: OpTypeInt 32 0 +; CHECK: OpTypeInt 16 1 +OpCapability Shader +OpCapability Int16 +OpCapability Linkage +OpMemoryModel Logical GLSL450 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + EXPECT_NE(context, nullptr); + + Integer uint_32(32, false); + context->get_type_mgr()->GetTypeInstruction(&uint_32); + + Integer int_16(16, true); + context->get_type_mgr()->GetTypeInstruction(&int_16); + + Match(text, context.get()); +} + +TEST(TypeManager, GetTypeInstructionDuplicateInts) { + const std::string text = R"( +; CHECK: OpTypeInt 32 0 +; CHECK-NOT: OpType +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + EXPECT_NE(context, nullptr); + + Integer uint_32(32, false); + uint32_t id = context->get_type_mgr()->GetTypeInstruction(&uint_32); + + Integer other(32, false); + EXPECT_EQ(context->get_type_mgr()->GetTypeInstruction(&other), id); + + Match(text, context.get()); +} + +TEST(TypeManager, GetTypeInstructionAllTypes) { + const std::string text = R"( +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]] +; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]] +; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24 +; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42 +; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100 +; CHECK: [[void:%\w+]] = OpTypeVoid +; CHECK: [[bool:%\w+]] = OpTypeBool +; CHECK: [[s32:%\w+]] = OpTypeInt 32 1 +; CHECK: OpTypeInt 64 1 +; CHECK: [[u64:%\w+]] = OpTypeInt 64 0 +; CHECK: [[f32:%\w+]] = OpTypeFloat 32 +; CHECK: OpTypeFloat 64 +; CHECK: OpTypeVector [[s32]] 2 +; CHECK: [[v3s32:%\w+]] = OpTypeVector [[s32]] 3 +; CHECK: OpTypeVector [[u64]] 4 +; CHECK: [[v3f32:%\w+]] = OpTypeVector [[f32]] 3 +; CHECK: OpTypeMatrix [[v3s32]] 3 +; CHECK: OpTypeMatrix [[v3s32]] 4 +; CHECK: OpTypeMatrix [[v3f32]] 4 +; CHECK: [[image1:%\w+]] = OpTypeImage 
[[s32]] 2D 0 0 0 0 Rg8 ReadOnly +; CHECK: OpTypeImage [[s32]] 2D 0 1 0 0 Rg8 ReadOnly +; CHECK: OpTypeImage [[s32]] 3D 0 1 0 0 Rg8 ReadOnly +; CHECK: [[image2:%\w+]] = OpTypeImage [[void]] 3D 0 1 0 1 Rg8 ReadWrite +; CHECK: OpTypeSampler +; CHECK: OpTypeSampledImage [[image1]] +; CHECK: OpTypeSampledImage [[image2]] +; CHECK: OpTypeArray [[f32]] [[uint100]] +; CHECK: [[a42f32:%\w+]] = OpTypeArray [[f32]] [[uint42]] +; CHECK: OpTypeArray [[u64]] [[uint24]] +; CHECK: OpTypeRuntimeArray [[v3f32]] +; CHECK: [[rav3s32:%\w+]] = OpTypeRuntimeArray [[v3s32]] +; CHECK: OpTypeStruct [[s32]] +; CHECK: [[sts32f32:%\w+]] = OpTypeStruct [[s32]] [[f32]] +; CHECK: OpTypeStruct [[u64]] [[a42f32]] [[rav3s32]] +; CHECK: OpTypeOpaque "" +; CHECK: OpTypeOpaque "hello" +; CHECK: OpTypeOpaque "world" +; CHECK: OpTypePointer Input [[f32]] +; CHECK: OpTypePointer Function [[sts32f32]] +; CHECK: OpTypePointer Function [[a42f32]] +; CHECK: OpTypeFunction [[void]] +; CHECK: OpTypeFunction [[void]] [[bool]] +; CHECK: OpTypeFunction [[void]] [[bool]] [[s32]] +; CHECK: OpTypeFunction [[s32]] [[bool]] [[s32]] +; CHECK: OpTypeEvent +; CHECK: OpTypeDeviceEvent +; CHECK: OpTypeReserveId +; CHECK: OpTypeQueue +; CHECK: OpTypePipe ReadWrite +; CHECK: OpTypePipe ReadOnly +; CHECK: OpTypeForwardPointer [[input_ptr]] Input +; CHECK: OpTypeForwardPointer [[uniform_ptr]] Input +; CHECK: OpTypeForwardPointer [[uniform_ptr]] Uniform +; CHECK: OpTypePipeStorage +; CHECK: OpTypeNamedBarrier +; CHECK: OpTypeAccelerationStructureNV +OpCapability Shader +OpCapability Int64 +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%uint = OpTypeInt 32 0 +%1 = OpTypePointer Input %uint +%2 = OpTypePointer Uniform %uint +%24 = OpConstant %uint 24 +%42 = OpConstant %uint 42 +%100 = OpConstant %uint 100 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + std::vector> types = GenerateAllTypes(); + for 
(auto& t : types) { + context->get_type_mgr()->GetTypeInstruction(t.get()); + } + + Match(text, context.get(), false); +} + +TEST(TypeManager, GetTypeInstructionWithDecorations) { + const std::string text = R"( +; CHECK: OpDecorate [[struct:%\w+]] CPacked +; CHECK: OpMemberDecorate [[struct]] 1 Offset 4 +; CHECK: [[uint:%\w+]] = OpTypeInt 32 0 +; CHECK: [[struct]] = OpTypeStruct [[uint]] [[uint]] +OpCapability Shader +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%uint = OpTypeInt 32 0 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + Integer u32(32, false); + Struct st({&u32, &u32}); + st.AddDecoration({10}); + st.AddMemberDecoration(1, {{35, 4}}); + (void)context->get_def_use_mgr(); + context->get_type_mgr()->GetTypeInstruction(&st); + + Match(text, context.get()); +} + +TEST(TypeManager, GetPointerToAmbiguousType1) { + const std::string text = R"( +; CHECK: [[struct1:%\w+]] = OpTypeStruct +; CHECK: [[struct2:%\w+]] = OpTypeStruct +; CHECK: OpTypePointer Function [[struct2]] +; CHECK: OpTypePointer Function [[struct1]] +OpCapability Shader +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%uint = OpTypeInt 32 0 +%1 = OpTypeStruct %uint +%2 = OpTypeStruct %uint +%3 = OpTypePointer Function %2 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + context->get_type_mgr()->FindPointerToType(1, SpvStorageClassFunction); + Match(text, context.get()); +} + +TEST(TypeManager, GetPointerToAmbiguousType2) { + const std::string text = R"( +; CHECK: [[struct1:%\w+]] = OpTypeStruct +; CHECK: [[struct2:%\w+]] = OpTypeStruct +; CHECK: OpTypePointer Function [[struct1]] +; CHECK: OpTypePointer Function [[struct2]] +OpCapability Shader +OpCapability Kernel +OpCapability Linkage 
+OpMemoryModel Logical GLSL450 +%uint = OpTypeInt 32 0 +%1 = OpTypeStruct %uint +%2 = OpTypeStruct %uint +%3 = OpTypePointer Function %1 + )"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + EXPECT_NE(context, nullptr); + + context->get_type_mgr()->FindPointerToType(2, SpvStorageClassFunction); + Match(text, context.get()); +} + +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/test/opt/types_test.cpp b/test/opt/types_test.cpp new file mode 100644 index 000000000..7426ed799 --- /dev/null +++ b/test/opt/types_test.cpp @@ -0,0 +1,345 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gtest/gtest.h" +#include "source/opt/types.h" +#include "source/util/make_unique.h" + +namespace spvtools { +namespace opt { +namespace analysis { +namespace { + +// Fixture class providing some element types. +class SameTypeTest : public ::testing::Test { + protected: + void SetUp() override { + void_t_ = MakeUnique(); + u32_t_ = MakeUnique(32, false); + f64_t_ = MakeUnique(64); + v3u32_t_ = MakeUnique(u32_t_.get(), 3); + image_t_ = + MakeUnique(f64_t_.get(), SpvDim2D, 1, 1, 0, 0, SpvImageFormatR16, + SpvAccessQualifierReadWrite); + } + + // Element types to be used for constructing other types for testing. 
+ std::unique_ptr void_t_; + std::unique_ptr u32_t_; + std::unique_ptr f64_t_; + std::unique_ptr v3u32_t_; + std::unique_ptr image_t_; +}; + +#define TestMultipleInstancesOfTheSameType(ty, ...) \ + TEST_F(SameTypeTest, MultiSame##ty) { \ + std::vector> types; \ + for (int i = 0; i < 10; ++i) types.emplace_back(new ty(__VA_ARGS__)); \ + for (size_t i = 0; i < types.size(); ++i) { \ + for (size_t j = 0; j < types.size(); ++j) { \ + EXPECT_TRUE(types[i]->IsSame(types[j].get())) \ + << "expected '" << types[i]->str() << "' is the same as '" \ + << types[j]->str() << "'"; \ + EXPECT_TRUE(*types[i] == *types[j]) \ + << "expected '" << types[i]->str() << "' is the same as '" \ + << types[j]->str() << "'"; \ + } \ + } \ + } +TestMultipleInstancesOfTheSameType(Void); +TestMultipleInstancesOfTheSameType(Bool); +TestMultipleInstancesOfTheSameType(Integer, 32, true); +TestMultipleInstancesOfTheSameType(Float, 64); +TestMultipleInstancesOfTheSameType(Vector, u32_t_.get(), 3); +TestMultipleInstancesOfTheSameType(Matrix, v3u32_t_.get(), 4); +TestMultipleInstancesOfTheSameType(Image, f64_t_.get(), SpvDimCube, 0, 0, 1, 1, + SpvImageFormatRgb10A2, + SpvAccessQualifierWriteOnly); +TestMultipleInstancesOfTheSameType(Sampler); +TestMultipleInstancesOfTheSameType(SampledImage, image_t_.get()); +TestMultipleInstancesOfTheSameType(Array, u32_t_.get(), 10); +TestMultipleInstancesOfTheSameType(RuntimeArray, u32_t_.get()); +TestMultipleInstancesOfTheSameType(Struct, std::vector{ + u32_t_.get(), f64_t_.get()}); +TestMultipleInstancesOfTheSameType(Opaque, "testing rocks"); +TestMultipleInstancesOfTheSameType(Pointer, u32_t_.get(), SpvStorageClassInput); +TestMultipleInstancesOfTheSameType(Function, u32_t_.get(), + {f64_t_.get(), f64_t_.get()}); +TestMultipleInstancesOfTheSameType(Event); +TestMultipleInstancesOfTheSameType(DeviceEvent); +TestMultipleInstancesOfTheSameType(ReserveId); +TestMultipleInstancesOfTheSameType(Queue); +TestMultipleInstancesOfTheSameType(Pipe, 
SpvAccessQualifierReadWrite); +TestMultipleInstancesOfTheSameType(ForwardPointer, 10, SpvStorageClassUniform); +TestMultipleInstancesOfTheSameType(PipeStorage); +TestMultipleInstancesOfTheSameType(NamedBarrier); +TestMultipleInstancesOfTheSameType(AccelerationStructureNV); +#undef TestMultipleInstanceOfTheSameType + +std::vector> GenerateAllTypes() { + // Types in this test case are only equal to themselves, nothing else. + std::vector> types; + + // Forward Pointer + types.emplace_back(new ForwardPointer(10000, SpvStorageClassInput)); + types.emplace_back(new ForwardPointer(20000, SpvStorageClassInput)); + + // Void, Bool + types.emplace_back(new Void()); + auto* voidt = types.back().get(); + types.emplace_back(new Bool()); + auto* boolt = types.back().get(); + + // Integer + types.emplace_back(new Integer(32, true)); + auto* s32 = types.back().get(); + types.emplace_back(new Integer(32, false)); + types.emplace_back(new Integer(64, true)); + types.emplace_back(new Integer(64, false)); + auto* u64 = types.back().get(); + + // Float + types.emplace_back(new Float(32)); + auto* f32 = types.back().get(); + types.emplace_back(new Float(64)); + + // Vector + types.emplace_back(new Vector(s32, 2)); + types.emplace_back(new Vector(s32, 3)); + auto* v3s32 = types.back().get(); + types.emplace_back(new Vector(u64, 4)); + types.emplace_back(new Vector(f32, 3)); + auto* v3f32 = types.back().get(); + + // Matrix + types.emplace_back(new Matrix(v3s32, 3)); + types.emplace_back(new Matrix(v3s32, 4)); + types.emplace_back(new Matrix(v3f32, 4)); + + // Images + types.emplace_back(new Image(s32, SpvDim2D, 0, 0, 0, 0, SpvImageFormatRg8, + SpvAccessQualifierReadOnly)); + auto* image1 = types.back().get(); + types.emplace_back(new Image(s32, SpvDim2D, 0, 1, 0, 0, SpvImageFormatRg8, + SpvAccessQualifierReadOnly)); + types.emplace_back(new Image(s32, SpvDim3D, 0, 1, 0, 0, SpvImageFormatRg8, + SpvAccessQualifierReadOnly)); + types.emplace_back(new Image(voidt, SpvDim3D, 0, 1, 0, 1, 
SpvImageFormatRg8, + SpvAccessQualifierReadWrite)); + auto* image2 = types.back().get(); + + // Sampler + types.emplace_back(new Sampler()); + + // Sampled Image + types.emplace_back(new SampledImage(image1)); + types.emplace_back(new SampledImage(image2)); + + // Array + types.emplace_back(new Array(f32, 100)); + types.emplace_back(new Array(f32, 42)); + auto* a42f32 = types.back().get(); + types.emplace_back(new Array(u64, 24)); + + // RuntimeArray + types.emplace_back(new RuntimeArray(v3f32)); + types.emplace_back(new RuntimeArray(v3s32)); + auto* rav3s32 = types.back().get(); + + // Struct + types.emplace_back(new Struct(std::vector{s32})); + types.emplace_back(new Struct(std::vector{s32, f32})); + auto* sts32f32 = types.back().get(); + types.emplace_back( + new Struct(std::vector{u64, a42f32, rav3s32})); + + // Opaque + types.emplace_back(new Opaque("")); + types.emplace_back(new Opaque("hello")); + types.emplace_back(new Opaque("world")); + + // Pointer + types.emplace_back(new Pointer(f32, SpvStorageClassInput)); + types.emplace_back(new Pointer(sts32f32, SpvStorageClassFunction)); + types.emplace_back(new Pointer(a42f32, SpvStorageClassFunction)); + types.emplace_back(new Pointer(voidt, SpvStorageClassFunction)); + + // Function + types.emplace_back(new Function(voidt, {})); + types.emplace_back(new Function(voidt, {boolt})); + types.emplace_back(new Function(voidt, {boolt, s32})); + types.emplace_back(new Function(s32, {boolt, s32})); + + // Event, Device Event, Reserve Id, Queue, + types.emplace_back(new Event()); + types.emplace_back(new DeviceEvent()); + types.emplace_back(new ReserveId()); + types.emplace_back(new Queue()); + + // Pipe, Forward Pointer, PipeStorage, NamedBarrier + types.emplace_back(new Pipe(SpvAccessQualifierReadWrite)); + types.emplace_back(new Pipe(SpvAccessQualifierReadOnly)); + types.emplace_back(new ForwardPointer(1, SpvStorageClassInput)); + types.emplace_back(new ForwardPointer(2, SpvStorageClassInput)); + 
types.emplace_back(new ForwardPointer(2, SpvStorageClassUniform)); + types.emplace_back(new PipeStorage()); + types.emplace_back(new NamedBarrier()); + + return types; +} + +TEST(Types, AllTypes) { + // Types in this test case are only equal to themselves, nothing else. + std::vector> types = GenerateAllTypes(); + + for (size_t i = 0; i < types.size(); ++i) { + for (size_t j = 0; j < types.size(); ++j) { + if (i == j) { + EXPECT_TRUE(types[i]->IsSame(types[j].get())) + << "expected '" << types[i]->str() << "' is the same as '" + << types[j]->str() << "'"; + } else { + EXPECT_FALSE(types[i]->IsSame(types[j].get())) + << "expected '" << types[i]->str() << "' is different to '" + << types[j]->str() << "'"; + } + } + } +} + +TEST(Types, IntSignedness) { + std::vector signednesses = {true, false, false, true}; + std::vector> types; + for (bool s : signednesses) { + types.emplace_back(new Integer(32, s)); + } + for (size_t i = 0; i < signednesses.size(); i++) { + EXPECT_EQ(signednesses[i], types[i]->IsSigned()); + } +} + +TEST(Types, IntWidth) { + std::vector widths = {1, 2, 4, 8, 16, 32, 48, 64, 128}; + std::vector> types; + for (uint32_t w : widths) { + types.emplace_back(new Integer(w, true)); + } + for (size_t i = 0; i < widths.size(); i++) { + EXPECT_EQ(widths[i], types[i]->width()); + } +} + +TEST(Types, FloatWidth) { + std::vector widths = {1, 2, 4, 8, 16, 32, 48, 64, 128}; + std::vector> types; + for (uint32_t w : widths) { + types.emplace_back(new Float(w)); + } + for (size_t i = 0; i < widths.size(); i++) { + EXPECT_EQ(widths[i], types[i]->width()); + } +} + +TEST(Types, VectorElementCount) { + auto s32 = MakeUnique(32, true); + for (uint32_t c : {2, 3, 4}) { + auto s32v = MakeUnique(s32.get(), c); + EXPECT_EQ(c, s32v->element_count()); + } +} + +TEST(Types, MatrixElementCount) { + auto s32 = MakeUnique(32, true); + auto s32v4 = MakeUnique(s32.get(), 4); + for (uint32_t c : {1, 2, 3, 4, 10, 100}) { + auto s32m = MakeUnique(s32v4.get(), c); + EXPECT_EQ(c, 
s32m->element_count()); + } +} + +TEST(Types, IsUniqueType) { + std::vector> types = GenerateAllTypes(); + + for (auto& t : types) { + bool expectation = true; + // Disallowing variable pointers. + switch (t->kind()) { + case Type::kArray: + case Type::kRuntimeArray: + case Type::kStruct: + expectation = false; + break; + default: + break; + } + EXPECT_EQ(t->IsUniqueType(false), expectation) + << "expected '" << t->str() << "' to be a " + << (expectation ? "" : "non-") << "unique type"; + + // Allowing variables pointers. + if (t->AsPointer()) expectation = false; + EXPECT_EQ(t->IsUniqueType(true), expectation) + << "expected '" << t->str() << "' to be a " + << (expectation ? "" : "non-") << "unique type"; + } +} + +std::vector> GenerateAllTypesWithDecorations() { + std::vector> types = GenerateAllTypes(); + uint32_t elems = 1; + uint32_t decs = 1; + for (auto& t : types) { + for (uint32_t i = 0; i < (decs % 10); ++i) { + std::vector decoration; + for (uint32_t j = 0; j < (elems % 4) + 1; ++j) { + decoration.push_back(j); + } + t->AddDecoration(std::move(decoration)); + ++elems; + ++decs; + } + } + + return types; +} + +TEST(Types, Clone) { + std::vector> types = GenerateAllTypesWithDecorations(); + for (auto& t : types) { + auto clone = t->Clone(); + EXPECT_TRUE(*t == *clone); + EXPECT_TRUE(t->HasSameDecorations(clone.get())); + EXPECT_NE(clone.get(), t.get()); + } +} + +TEST(Types, RemoveDecorations) { + std::vector> types = GenerateAllTypesWithDecorations(); + for (auto& t : types) { + auto decorationless = t->RemoveDecorations(); + EXPECT_EQ(*t == *decorationless, t->decoration_empty()); + EXPECT_EQ(t->HasSameDecorations(decorationless.get()), + t->decoration_empty()); + EXPECT_NE(t.get(), decorationless.get()); + } +} + +} // namespace +} // namespace analysis +} // namespace opt +} // namespace spvtools diff --git a/test/opt/unify_const_test.cpp b/test/opt/unify_const_test.cpp new file mode 100644 index 000000000..37728cc23 --- /dev/null +++ 
b/test/opt/unify_const_test.cpp @@ -0,0 +1,990 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +// Returns the types defining instructions commonly used in many tests. +std::vector CommonTypes() { + return std::vector{ + // clang-format off + // scalar types + "%bool = OpTypeBool", + "%uint = OpTypeInt 32 0", + "%int = OpTypeInt 32 1", + "%uint64 = OpTypeInt 64 0", + "%int64 = OpTypeInt 64 1", + "%float = OpTypeFloat 32", + "%double = OpTypeFloat 64", + // vector types + "%v2bool = OpTypeVector %bool 2", + "%v2uint = OpTypeVector %uint 2", + "%v2int = OpTypeVector %int 2", + "%v3int = OpTypeVector %int 3", + "%v4int = OpTypeVector %int 4", + "%v2float = OpTypeVector %float 2", + "%v3float = OpTypeVector %float 3", + "%v2double = OpTypeVector %double 2", + // struct types + "%inner_struct = OpTypeStruct %bool %float", + "%outer_struct = OpTypeStruct %inner_struct %int %double", + "%flat_struct = OpTypeStruct %bool %int %float %double", + // variable pointer types + "%_pf_bool = OpTypePointer Function %bool", + "%_pf_uint = OpTypePointer Function %uint", + "%_pf_int = OpTypePointer Function %int", + "%_pf_uint64 = OpTypePointer Function %uint64", + "%_pf_int64 = OpTypePointer Function %int64", + 
"%_pf_float = OpTypePointer Function %float", + "%_pf_double = OpTypePointer Function %double", + "%_pf_v2int = OpTypePointer Function %v2int", + "%_pf_v3int = OpTypePointer Function %v3int", + "%_pf_v4int = OpTypePointer Function %v4int", + "%_pf_v2float = OpTypePointer Function %v2float", + "%_pf_v3float = OpTypePointer Function %v3float", + "%_pf_v2double = OpTypePointer Function %v2double", + "%_pf_inner_struct = OpTypePointer Function %inner_struct", + "%_pf_outer_struct = OpTypePointer Function %outer_struct", + "%_pf_flat_struct = OpTypePointer Function %flat_struct", + // clang-format on + }; +} + +// A helper function to strip OpName instructions from the given string of +// disassembly code and put those debug instructions to a set. Returns the +// string with all OpName instruction stripped and a set of OpName +// instructions. +std::tuple> +StripOpNameInstructionsToSet(const std::string& str) { + std::stringstream ss(str); + std::ostringstream oss; + std::string inst_str; + std::unordered_set opname_instructions; + while (std::getline(ss, inst_str, '\n')) { + if (inst_str.find("OpName %") == std::string::npos) { + oss << inst_str << '\n'; + } else { + opname_instructions.insert(inst_str); + } + } + return std::make_tuple(oss.str(), std::move(opname_instructions)); +} + +// The test fixture for all tests of UnifyConstantPass. This fixture defines +// the rule of checking: all the optimized code should be exactly the same as +// the expected code, except the OpName instructions, which can be different in +// order. +template +class UnifyConstantTest : public PassTest { + protected: + // Runs UnifyConstantPass on the code built from the given |test_builder|, + // and checks whether the optimization result matches with the code built + // from |expected_builder|. 
+ void Check(const AssemblyBuilder& expected_builder, + const AssemblyBuilder& test_builder) { + // unoptimized code + const std::string original_before_strip = test_builder.GetCode(); + std::string original_without_opnames; + std::unordered_set original_opnames; + std::tie(original_without_opnames, original_opnames) = + StripOpNameInstructionsToSet(original_before_strip); + + // expected code + std::string expected_without_opnames; + std::unordered_set expected_opnames; + std::tie(expected_without_opnames, expected_opnames) = + StripOpNameInstructionsToSet(expected_builder.GetCode()); + + // optimized code + std::string optimized_before_strip; + auto status = Pass::Status::SuccessWithoutChange; + std::tie(optimized_before_strip, status) = + this->template SinglePassRunAndDisassemble( + test_builder.GetCode(), + /* skip_nop = */ true, /* do_validation = */ false); + std::string optimized_without_opnames; + std::unordered_set optimized_opnames; + std::tie(optimized_without_opnames, optimized_opnames) = + StripOpNameInstructionsToSet(optimized_before_strip); + + // Flag "status" should be returned correctly. + EXPECT_NE(Pass::Status::Failure, status); + EXPECT_EQ(expected_without_opnames == original_without_opnames, + status == Pass::Status::SuccessWithoutChange); + // Code except OpName instructions should be exactly the same. + EXPECT_EQ(expected_without_opnames, optimized_without_opnames); + // OpName instructions can be in different order, but the content must be + // the same. 
+ EXPECT_EQ(expected_opnames, optimized_opnames); + } +}; + +using UnifyFrontEndConstantSingleTest = + UnifyConstantTest>; + +TEST_F(UnifyFrontEndConstantSingleTest, Basic) { + AssemblyBuilder test_builder; + AssemblyBuilder expected_builder; + + test_builder + .AppendTypesConstantsGlobals({ + "%uint = OpTypeInt 32 0", "%_pf_uint = OpTypePointer Function %uint", + "%unsigned_1 = OpConstant %uint 1", + "%unsigned_1_duplicate = OpConstant %uint 1", // duplicated constant + }) + .AppendInMain({ + "%uint_var = OpVariable %_pf_uint Function", + "OpStore %uint_var %unsigned_1_duplicate", + }); + + expected_builder + .AppendTypesConstantsGlobals({ + "%uint = OpTypeInt 32 0", + "%_pf_uint = OpTypePointer Function %uint", + "%unsigned_1 = OpConstant %uint 1", + }) + .AppendInMain({ + "%uint_var = OpVariable %_pf_uint Function", + "OpStore %uint_var %unsigned_1", + }) + .AppendNames({ + "OpName %unsigned_1 \"unsigned_1_duplicate\"", // the OpName + // instruction of the + // removed duplicated + // constant won't be + // erased. + }); + Check(expected_builder, test_builder); +} + +TEST_F(UnifyFrontEndConstantSingleTest, SkipWhenResultIdHasDecorations) { + AssemblyBuilder test_builder; + AssemblyBuilder expected_builder; + + test_builder + .AppendAnnotations({ + // So far we don't have valid decorations for constants. This is + // preparing for the future updates of SPIR-V. + // TODO(qining): change to a valid decoration once they are available. + "OpDecorate %f_1 RelaxedPrecision", + "OpDecorate %f_2_dup RelaxedPrecision", + }) + .AppendTypesConstantsGlobals({ + // clang-format off + "%float = OpTypeFloat 32", + "%_pf_float = OpTypePointer Function %float", + "%f_1 = OpConstant %float 1", + // %f_1 has decoration, so %f_1 will not be used to replace %f_1_dup. + "%f_1_dup = OpConstant %float 1", + "%f_2 = OpConstant %float 2", + // %_2_dup has decoration, so %f_2 will not replace %f_2_dup. 
+ "%f_2_dup = OpConstant %float 2", + // no decoration for %f_3 or %f_3_dup, %f_3_dup should be replaced. + "%f_3 = OpConstant %float 3", + "%f_3_dup = OpConstant %float 3", + // clang-format on + }) + .AppendInMain({ + // clang-format off + "%f_var = OpVariable %_pf_float Function", + "OpStore %f_var %f_1_dup", + "OpStore %f_var %f_2_dup", + "OpStore %f_var %f_3_dup", + // clang-format on + }); + + expected_builder + .AppendAnnotations({ + "OpDecorate %f_1 RelaxedPrecision", + "OpDecorate %f_2_dup RelaxedPrecision", + }) + .AppendTypesConstantsGlobals({ + // clang-format off + "%float = OpTypeFloat 32", + "%_pf_float = OpTypePointer Function %float", + "%f_1 = OpConstant %float 1", + "%f_1_dup = OpConstant %float 1", + "%f_2 = OpConstant %float 2", + "%f_2_dup = OpConstant %float 2", + "%f_3 = OpConstant %float 3", + // clang-format on + }) + .AppendInMain({ + // clang-format off + "%f_var = OpVariable %_pf_float Function", + "OpStore %f_var %f_1_dup", + "OpStore %f_var %f_2_dup", + "OpStore %f_var %f_3", + // clang-format on + }) + .AppendNames({ + "OpName %f_3 \"f_3_dup\"", + }); + + Check(expected_builder, test_builder); +} + +TEST_F(UnifyFrontEndConstantSingleTest, UnifyWithDecorationOnTypes) { + AssemblyBuilder test_builder; + AssemblyBuilder expected_builder; + + test_builder + .AppendAnnotations({ + "OpMemberDecorate %flat_d 1 RelaxedPrecision", + }) + .AppendTypesConstantsGlobals({ + // clang-format off + "%int = OpTypeInt 32 1", + "%float = OpTypeFloat 32", + "%flat = OpTypeStruct %int %float", + "%_pf_flat = OpTypePointer Function %flat", + // decorated flat struct + "%flat_d = OpTypeStruct %int %float", + "%_pf_flat_d = OpTypePointer Function %flat_d", + // perserved contants. %flat_1 and %flat_d has same members, but + // their type are different in decorations, so they should not be + // used to replace each other. 
+ "%int_1 = OpConstant %int 1", + "%float_1 = OpConstant %float 1", + "%flat_1 = OpConstantComposite %flat %int_1 %float_1", + "%flat_d_1 = OpConstantComposite %flat_d %int_1 %float_1", + // duplicated constants. + "%flat_1_dup = OpConstantComposite %flat %int_1 %float_1", + "%flat_d_1_dup = OpConstantComposite %flat_d %int_1 %float_1", + // clang-format on + }) + .AppendInMain({ + "%flat_var = OpVariable %_pf_flat Function", + "OpStore %flat_var %flat_1_dup", + "%flat_d_var = OpVariable %_pf_flat_d Function", + "OpStore %flat_d_var %flat_d_1_dup", + }); + + expected_builder + .AppendAnnotations({ + "OpMemberDecorate %flat_d 1 RelaxedPrecision", + }) + .AppendTypesConstantsGlobals({ + // clang-format off + "%int = OpTypeInt 32 1", + "%float = OpTypeFloat 32", + "%flat = OpTypeStruct %int %float", + "%_pf_flat = OpTypePointer Function %flat", + // decorated flat struct + "%flat_d = OpTypeStruct %int %float", + "%_pf_flat_d = OpTypePointer Function %flat_d", + "%int_1 = OpConstant %int 1", + "%float_1 = OpConstant %float 1", + "%flat_1 = OpConstantComposite %flat %int_1 %float_1", + "%flat_d_1 = OpConstantComposite %flat_d %int_1 %float_1", + // clang-format on + }) + .AppendInMain({ + "%flat_var = OpVariable %_pf_flat Function", + "OpStore %flat_var %flat_1", + "%flat_d_var = OpVariable %_pf_flat_d Function", + "OpStore %flat_d_var %flat_d_1", + }) + .AppendNames({ + "OpName %flat_1 \"flat_1_dup\"", + "OpName %flat_d_1 \"flat_d_1_dup\"", + }); + + Check(expected_builder, test_builder); +} + +struct UnifyConstantTestCase { + // preserved constants. + std::vector preserved_consts; + // expected uses of the preserved constants. + std::vector use_preserved_consts; + // duplicated constants of the preserved constants. + std::vector duplicate_consts; + // uses of the duplicated constants, expected to be updated to use the + // preserved constants. + std::vector use_duplicate_consts; + // The updated OpName instructions that originally refer to duplicated + // constants. 
+ std::vector remapped_names; +}; + +using UnifyFrontEndConstantParamTest = UnifyConstantTest< + PassTest<::testing::TestWithParam>>; + +TEST_P(UnifyFrontEndConstantParamTest, TestCase) { + auto& tc = GetParam(); + AssemblyBuilder test_builder; + AssemblyBuilder expected_builder; + test_builder.AppendTypesConstantsGlobals(CommonTypes()); + expected_builder.AppendTypesConstantsGlobals(CommonTypes()); + + test_builder.AppendTypesConstantsGlobals(tc.preserved_consts) + .AppendTypesConstantsGlobals(tc.duplicate_consts) + .AppendInMain(tc.use_duplicate_consts); + + // Duplicated constants are killed in the expected output, and the debug + // instructions attached to those duplicated instructions will be migrated to + // the corresponding preserved constants. + expected_builder.AppendTypesConstantsGlobals(tc.preserved_consts) + .AppendInMain(tc.use_preserved_consts) + .AppendNames(tc.remapped_names); + + Check(expected_builder, test_builder); +} + +INSTANTIATE_TEST_CASE_P(Case, UnifyFrontEndConstantParamTest, + ::testing::ValuesIn(std::vector({ + // clang-format off + // basic tests for scalar constants + { + // preserved constants + { + "%bool_true = OpConstantTrue %bool", + "%signed_1 = OpConstant %int 1", + "%signed_minus_1 = OpConstant %int64 -1", + "%unsigned_max = OpConstant %uint64 18446744073709551615", + "%float_1 = OpConstant %float 1", + "%double_1 = OpConstant %double 1", + }, + // use preserved constants in main + { + "%bool_var = OpVariable %_pf_bool Function", + "OpStore %bool_var %bool_true", + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %signed_1", + "%int64_var = OpVariable %_pf_int64 Function", + "OpStore %int64_var %signed_minus_1", + "%uint64_var = OpVariable %_pf_uint64 Function", + "OpStore %uint64_var %unsigned_max", + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %float_1", + "%double_var = OpVariable %_pf_double Function", + "OpStore %double_var %double_1", + }, + // duplicated constants + { + 
"%bool_true_duplicate = OpConstantTrue %bool", + "%signed_1_duplicate = OpConstant %int 1", + "%signed_minus_1_duplicate = OpConstant %int64 -1", + "%unsigned_max_duplicate = OpConstant %uint64 18446744073709551615", + "%float_1_duplicate = OpConstant %float 1", + "%double_1_duplicate = OpConstant %double 1", + }, + // use duplicated constants in main + { + "%bool_var = OpVariable %_pf_bool Function", + "OpStore %bool_var %bool_true_duplicate", + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %signed_1_duplicate", + "%int64_var = OpVariable %_pf_int64 Function", + "OpStore %int64_var %signed_minus_1_duplicate", + "%uint64_var = OpVariable %_pf_uint64 Function", + "OpStore %uint64_var %unsigned_max_duplicate", + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %float_1_duplicate", + "%double_var = OpVariable %_pf_double Function", + "OpStore %double_var %double_1_duplicate", + }, + // remapped names + { + "OpName %bool_true \"bool_true_duplicate\"", + "OpName %signed_1 \"signed_1_duplicate\"", + "OpName %signed_minus_1 \"signed_minus_1_duplicate\"", + "OpName %unsigned_max \"unsigned_max_duplicate\"", + "OpName %float_1 \"float_1_duplicate\"", + "OpName %double_1 \"double_1_duplicate\"", + }, + }, + // NaN in different bit patterns should not be unified, but the ones + // using same bit pattern should be unified. 
+ { + // preserved constants + { + "%float_nan_1 = OpConstant %float 0x1.8p+128", // !2143289344, 7FC00000 + "%float_nan_2 = OpConstant %float 0x1.800002p+128",// !2143289345 7FC00001 + }, + // use preserved constants in main + { + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %float_nan_1", + "OpStore %float_var %float_nan_2", + }, + // duplicated constants + { + "%float_nan_1_duplicate = OpConstant %float 0x1.8p+128", // !2143289344, 7FC00000 + "%float_nan_2_duplicate = OpConstant %float 0x1.800002p+128",// !2143289345, 7FC00001 + }, + // use duplicated constants in main + { + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %float_nan_1_duplicate", + "OpStore %float_var %float_nan_2_duplicate", + }, + // remapped names + { + "OpName %float_nan_1 \"float_nan_1_duplicate\"", + "OpName %float_nan_2 \"float_nan_2_duplicate\"", + }, + }, + // null values + { + // preserved constants + { + "%bool_null = OpConstantNull %bool", + "%signed_null = OpConstantNull %int", + "%signed_64_null = OpConstantNull %int64", + "%float_null = OpConstantNull %float", + "%double_null = OpConstantNull %double", + // zero-valued constants will not be unified with the equivalent + // null constants. 
+ "%signed_zero = OpConstant %int 0", + }, + // use preserved constants in main + { + "%bool_var = OpVariable %_pf_bool Function", + "OpStore %bool_var %bool_null", + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %signed_null", + "%int64_var = OpVariable %_pf_int64 Function", + "OpStore %int64_var %signed_64_null", + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %float_null", + "%double_var = OpVariable %_pf_double Function", + "OpStore %double_var %double_null", + }, + // duplicated constants + { + "%bool_null_duplicate = OpConstantNull %bool", + "%signed_null_duplicate = OpConstantNull %int", + "%signed_64_null_duplicate = OpConstantNull %int64", + "%float_null_duplicate = OpConstantNull %float", + "%double_null_duplicate = OpConstantNull %double", + }, + // use duplicated constants in main + { + "%bool_var = OpVariable %_pf_bool Function", + "OpStore %bool_var %bool_null_duplicate", + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %signed_null_duplicate", + "%int64_var = OpVariable %_pf_int64 Function", + "OpStore %int64_var %signed_64_null_duplicate", + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %float_null_duplicate", + "%double_var = OpVariable %_pf_double Function", + "OpStore %double_var %double_null_duplicate", + }, + // remapped names + { + "OpName %bool_null \"bool_null_duplicate\"", + "OpName %signed_null \"signed_null_duplicate\"", + "OpName %signed_64_null \"signed_64_null_duplicate\"", + "OpName %float_null \"float_null_duplicate\"", + "OpName %double_null \"double_null_duplicate\"", + }, + }, + // constant sampler + { + // preserved constants + { + "%sampler = OpTypeSampler", + "%_pf_sampler = OpTypePointer Function %sampler", + "%sampler_1 = OpConstantSampler %sampler Repeat 0 Linear", + }, + // use preserved constants in main + { + "%sampler_var = OpVariable %_pf_sampler Function", + "OpStore %sampler_var %sampler_1", + }, + // duplicated constants + { + 
"%sampler_1_duplicate = OpConstantSampler %sampler Repeat 0 Linear", + }, + // use duplicated constants in main + { + "%sampler_var = OpVariable %_pf_sampler Function", + "OpStore %sampler_var %sampler_1_duplicate", + }, + // remapped names + { + "OpName %sampler_1 \"sampler_1_duplicate\"", + }, + }, + // duplicate vector built from same ids. + { + // preserved constants + { + "%signed_1 = OpConstant %int 1", + "%signed_2 = OpConstant %int 2", + "%signed_3 = OpConstant %int 3", + "%signed_4 = OpConstant %int 4", + "%vec = OpConstantComposite %v4int %signed_1 %signed_2 %signed_3 %signed_4", + }, + // use preserved constants in main + { + "%vec_var = OpVariable %_pf_v4int Function", + "OpStore %vec_var %vec", + }, + // duplicated constants + { + "%vec_duplicate = OpConstantComposite %v4int %signed_1 %signed_2 %signed_3 %signed_4", + }, + // use duplicated constants in main + { + "%vec_var = OpVariable %_pf_v4int Function", + "OpStore %vec_var %vec_duplicate", + }, + // remapped names + { + "OpName %vec \"vec_duplicate\"", + } + }, + // duplicate vector built from duplicated ids. 
+ { + // preserved constants + { + "%signed_1 = OpConstant %int 1", + "%signed_2 = OpConstant %int 2", + "%signed_3 = OpConstant %int 3", + "%signed_4 = OpConstant %int 4", + "%vec = OpConstantComposite %v4int %signed_1 %signed_2 %signed_3 %signed_4", + }, + // use preserved constants in main + { + "%vec_var = OpVariable %_pf_v4int Function", + "OpStore %vec_var %vec", + }, + // duplicated constants + { + "%signed_3_duplicate = OpConstant %int 3", + "%signed_4_duplicate = OpConstant %int 4", + "%vec_duplicate = OpConstantComposite %v4int %signed_1 %signed_2 %signed_3_duplicate %signed_4_duplicate", + }, + // use duplicated constants in main + { + "%vec_var = OpVariable %_pf_v4int Function", + "OpStore %vec_var %vec_duplicate", + }, + // remapped names + { + "OpName %signed_3 \"signed_3_duplicate\"", + "OpName %signed_4 \"signed_4_duplicate\"", + "OpName %vec \"vec_duplicate\"", + }, + }, + // flat struct + { + // preserved constants + { + "%bool_true = OpConstantTrue %bool", + "%signed_1 = OpConstant %int 1", + "%float_1 = OpConstant %float 1", + "%double_1 = OpConstant %double 1", + "%s = OpConstantComposite %flat_struct %bool_true %signed_1 %float_1 %double_1", + }, + // use preserved constants in main + { + "%s_var = OpVariable %_pf_flat_struct Function", + "OpStore %s_var %s", + }, + // duplicated constants + { + "%float_1_duplicate = OpConstant %float 1", + "%double_1_duplicate = OpConstant %double 1", + "%s_duplicate = OpConstantComposite %flat_struct %bool_true %signed_1 %float_1_duplicate %double_1_duplicate", + }, + // use duplicated constants in main + { + "%s_var = OpVariable %_pf_flat_struct Function", + "OpStore %s_var %s_duplicate", + }, + // remapped names + { + "OpName %float_1 \"float_1_duplicate\"", + "OpName %double_1 \"double_1_duplicate\"", + "OpName %s \"s_duplicate\"", + }, + }, + // nested struct + { + // preserved constants + { + "%bool_true = OpConstantTrue %bool", + "%signed_1 = OpConstant %int 1", + "%float_1 = OpConstant %float 1", + 
"%double_1 = OpConstant %double 1", + "%inner = OpConstantComposite %inner_struct %bool_true %float_1", + "%outer = OpConstantComposite %outer_struct %inner %signed_1 %double_1", + }, + // use preserved constants in main + { + "%outer_var = OpVariable %_pf_outer_struct Function", + "OpStore %outer_var %outer", + }, + // duplicated constants + { + "%float_1_duplicate = OpConstant %float 1", + "%double_1_duplicate = OpConstant %double 1", + "%inner_duplicate = OpConstantComposite %inner_struct %bool_true %float_1_duplicate", + "%outer_duplicate = OpConstantComposite %outer_struct %inner_duplicate %signed_1 %double_1_duplicate", + }, + // use duplicated constants in main + { + "%outer_var = OpVariable %_pf_outer_struct Function", + "OpStore %outer_var %outer_duplicate", + }, + // remapped names + { + "OpName %float_1 \"float_1_duplicate\"", + "OpName %double_1 \"double_1_duplicate\"", + "OpName %inner \"inner_duplicate\"", + "OpName %outer \"outer_duplicate\"", + }, + }, + // composite type null constants. Null constants and zero-valued + // constants should not be used to replace each other. + { + // preserved constants + { + "%bool_zero = OpConstantFalse %bool", + "%float_zero = OpConstant %float 0", + "%int_null = OpConstantNull %int", + "%double_null = OpConstantNull %double", + // inner_struct type null constant. + "%null_inner = OpConstantNull %inner_struct", + // zero-valued composite constant built from zero-valued constant + // component. inner_zero should not be replace by null_inner. + "%inner_zero = OpConstantComposite %inner_struct %bool_zero %float_zero", + // zero-valued composite contant built from zero-valued constants + // and null constants. + "%outer_zero = OpConstantComposite %outer_struct %inner_zero %int_null %double_null", + // outer_struct type null constant, it should not be replaced by + // outer_zero. 
+ "%null_outer = OpConstantNull %outer_struct", + }, + // use preserved constants in main + { + "%inner_var = OpVariable %_pf_inner_struct Function", + "OpStore %inner_var %inner_zero", + "OpStore %inner_var %null_inner", + "%outer_var = OpVariable %_pf_outer_struct Function", + "OpStore %outer_var %outer_zero", + "OpStore %outer_var %null_outer", + }, + // duplicated constants + { + "%null_inner_dup = OpConstantNull %inner_struct", + "%null_outer_dup = OpConstantNull %outer_struct", + "%inner_zero_dup = OpConstantComposite %inner_struct %bool_zero %float_zero", + "%outer_zero_dup = OpConstantComposite %outer_struct %inner_zero_dup %int_null %double_null", + }, + // use duplicated constants in main + { + "%inner_var = OpVariable %_pf_inner_struct Function", + "OpStore %inner_var %inner_zero_dup", + "OpStore %inner_var %null_inner_dup", + "%outer_var = OpVariable %_pf_outer_struct Function", + "OpStore %outer_var %outer_zero_dup", + "OpStore %outer_var %null_outer_dup", + }, + // remapped names + { + "OpName %null_inner \"null_inner_dup\"", + "OpName %null_outer \"null_outer_dup\"", + "OpName %inner_zero \"inner_zero_dup\"", + "OpName %outer_zero \"outer_zero_dup\"", + }, + }, + // Spec Constants with SpecId decoration should be skipped. + { + // preserved constants + { + // Assembly builder will add OpDecorate SpecId instruction for the + // following spec constant instructions automatically. + "%spec_bool_1 = OpSpecConstantTrue %bool", + "%spec_bool_2 = OpSpecConstantTrue %bool", + "%spec_int_1 = OpSpecConstant %int 1", + "%spec_int_2 = OpSpecConstant %int 1", + }, + // use preserved constants in main + { + "%bool_var = OpVariable %_pf_bool Function", + "OpStore %bool_var %spec_bool_1", + "OpStore %bool_var %spec_bool_2", + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %spec_int_1", + "OpStore %int_var %spec_int_2", + }, + // duplicated constants. No duplicated instruction to remove in this + // case. + {}, + // use duplicated constants in main. 
Same as the above 'use preserved + // constants in main' defined above, as no instruction should be + // removed in this case. + { + "%bool_var = OpVariable %_pf_bool Function", + "OpStore %bool_var %spec_bool_1", + "OpStore %bool_var %spec_bool_2", + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %spec_int_1", + "OpStore %int_var %spec_int_2", + }, + // remapped names. No duplicated instruction removed, so this is + // empty. + {} + }, + // spec constant composite + { + // preserved constants + { + "%spec_bool_true = OpSpecConstantTrue %bool", + "%spec_signed_1 = OpSpecConstant %int 1", + "%float_1 = OpConstant %float 1", + "%double_1 = OpConstant %double 1", + "%spec_inner = OpSpecConstantComposite %inner_struct %spec_bool_true %float_1", + "%spec_outer = OpSpecConstantComposite %outer_struct %spec_inner %spec_signed_1 %double_1", + "%spec_vec2 = OpSpecConstantComposite %v2float %float_1 %float_1", + }, + // use preserved constants in main + { + "%outer_var = OpVariable %_pf_outer_struct Function", + "OpStore %outer_var %spec_outer", + "%v2float_var = OpVariable %_pf_v2float Function", + "OpStore %v2float_var %spec_vec2", + }, + // duplicated constants + { + "%float_1_duplicate = OpConstant %float 1", + "%double_1_duplicate = OpConstant %double 1", + "%spec_inner_duplicate = OpSpecConstantComposite %inner_struct %spec_bool_true %float_1_duplicate", + "%spec_outer_duplicate = OpSpecConstantComposite %outer_struct %spec_inner_duplicate %spec_signed_1 %double_1_duplicate", + "%spec_vec2_duplicate = OpSpecConstantComposite %v2float %float_1 %float_1_duplicate", + }, + // use duplicated constants in main + { + "%outer_var = OpVariable %_pf_outer_struct Function", + "OpStore %outer_var %spec_outer_duplicate", + "%v2float_var = OpVariable %_pf_v2float Function", + "OpStore %v2float_var %spec_vec2_duplicate", + }, + // remapped names + { + "OpName %float_1 \"float_1_duplicate\"", + "OpName %double_1 \"double_1_duplicate\"", + "OpName %spec_inner 
\"spec_inner_duplicate\"", + "OpName %spec_outer \"spec_outer_duplicate\"", + "OpName %spec_vec2 \"spec_vec2_duplicate\"", + }, + }, + // spec constant op with int scalar + { + // preserved constants + { + "%spec_signed_1 = OpSpecConstant %int 1", + "%spec_signed_2 = OpSpecConstant %int 2", + "%spec_signed_add = OpSpecConstantOp %int IAdd %spec_signed_1 %spec_signed_2", + }, + // use preserved constants in main + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %spec_signed_add", + }, + // duplicated constants + { + "%spec_signed_add_duplicate = OpSpecConstantOp %int IAdd %spec_signed_1 %spec_signed_2", + }, + // use duplicated contants in main + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %spec_signed_add_duplicate", + }, + // remapped names + { + "OpName %spec_signed_add \"spec_signed_add_duplicate\"", + }, + }, + // spec constant op composite extract + { + // preserved constants + { + "%float_1 = OpConstant %float 1", + "%spec_vec2 = OpSpecConstantComposite %v2float %float_1 %float_1", + "%spec_extract = OpSpecConstantOp %float CompositeExtract %spec_vec2 1", + }, + // use preserved constants in main + { + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %spec_extract", + }, + // duplicated constants + { + "%spec_extract_duplicate = OpSpecConstantOp %float CompositeExtract %spec_vec2 1", + }, + // use duplicated constants in main + { + "%float_var = OpVariable %_pf_float Function", + "OpStore %float_var %spec_extract_duplicate", + }, + // remapped names + { + "OpName %spec_extract \"spec_extract_duplicate\"", + }, + }, + // spec constant op vector shuffle + { + // preserved constants + { + "%float_1 = OpConstant %float 1", + "%float_2 = OpConstant %float 2", + "%spec_vec2_1 = OpSpecConstantComposite %v2float %float_1 %float_1", + "%spec_vec2_2 = OpSpecConstantComposite %v2float %float_2 %float_2", + "%spec_vector_shuffle = OpSpecConstantOp %v2float VectorShuffle %spec_vec2_1 %spec_vec2_2 1 2", + }, + 
// use preserved constants in main + { + "%v2float_var = OpVariable %_pf_v2float Function", + "OpStore %v2float_var %spec_vector_shuffle", + }, + // duplicated constants + { + "%spec_vector_shuffle_duplicate = OpSpecConstantOp %v2float VectorShuffle %spec_vec2_1 %spec_vec2_2 1 2", + }, + // use duplicated constants in main + { + "%v2float_var = OpVariable %_pf_v2float Function", + "OpStore %v2float_var %spec_vector_shuffle_duplicate", + }, + // remapped names + { + "OpName %spec_vector_shuffle \"spec_vector_shuffle_duplicate\"", + }, + }, + // long dependency chain + { + // preserved constants + { + "%array_size = OpConstant %int 4", + "%type_arr_int_4 = OpTypeArray %int %array_size", + "%signed_0 = OpConstant %int 100", + "%signed_1 = OpConstant %int 1", + "%signed_2 = OpSpecConstantOp %int IAdd %signed_0 %signed_1", + "%signed_3 = OpSpecConstantOp %int ISub %signed_0 %signed_2", + "%signed_4 = OpSpecConstantOp %int IAdd %signed_0 %signed_3", + "%signed_5 = OpSpecConstantOp %int ISub %signed_0 %signed_4", + "%signed_6 = OpSpecConstantOp %int IAdd %signed_0 %signed_5", + "%signed_7 = OpSpecConstantOp %int ISub %signed_0 %signed_6", + "%signed_8 = OpSpecConstantOp %int IAdd %signed_0 %signed_7", + "%signed_9 = OpSpecConstantOp %int ISub %signed_0 %signed_8", + "%signed_10 = OpSpecConstantOp %int IAdd %signed_0 %signed_9", + "%signed_11 = OpSpecConstantOp %int ISub %signed_0 %signed_10", + "%signed_12 = OpSpecConstantOp %int IAdd %signed_0 %signed_11", + "%signed_13 = OpSpecConstantOp %int ISub %signed_0 %signed_12", + "%signed_14 = OpSpecConstantOp %int IAdd %signed_0 %signed_13", + "%signed_15 = OpSpecConstantOp %int ISub %signed_0 %signed_14", + "%signed_16 = OpSpecConstantOp %int ISub %signed_0 %signed_15", + "%signed_17 = OpSpecConstantOp %int IAdd %signed_0 %signed_16", + "%signed_18 = OpSpecConstantOp %int ISub %signed_0 %signed_17", + "%signed_19 = OpSpecConstantOp %int IAdd %signed_0 %signed_18", + "%signed_20 = OpSpecConstantOp %int ISub %signed_0 
%signed_19", + "%signed_vec_a = OpSpecConstantComposite %v2int %signed_18 %signed_19", + "%signed_vec_b = OpSpecConstantOp %v2int IMul %signed_vec_a %signed_vec_a", + "%signed_21 = OpSpecConstantOp %int CompositeExtract %signed_vec_b 0", + "%signed_array = OpConstantComposite %type_arr_int_4 %signed_20 %signed_20 %signed_21 %signed_21", + "%signed_22 = OpSpecConstantOp %int CompositeExtract %signed_array 0", + }, + // use preserved constants in main + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %signed_22", + }, + // duplicated constants + { + "%signed_0_dup = OpConstant %int 100", + "%signed_1_dup = OpConstant %int 1", + "%signed_2_dup = OpSpecConstantOp %int IAdd %signed_0_dup %signed_1_dup", + "%signed_3_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_2_dup", + "%signed_4_dup = OpSpecConstantOp %int IAdd %signed_0_dup %signed_3_dup", + "%signed_5_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_4_dup", + "%signed_6_dup = OpSpecConstantOp %int IAdd %signed_0_dup %signed_5_dup", + "%signed_7_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_6_dup", + "%signed_8_dup = OpSpecConstantOp %int IAdd %signed_0_dup %signed_7_dup", + "%signed_9_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_8_dup", + "%signed_10_dup = OpSpecConstantOp %int IAdd %signed_0_dup %signed_9_dup", + "%signed_11_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_10_dup", + "%signed_12_dup = OpSpecConstantOp %int IAdd %signed_0_dup %signed_11_dup", + "%signed_13_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_12_dup", + "%signed_14_dup = OpSpecConstantOp %int IAdd %signed_0_dup %signed_13_dup", + "%signed_15_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_14_dup", + "%signed_16_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_15_dup", + "%signed_17_dup = OpSpecConstantOp %int IAdd %signed_0_dup %signed_16_dup", + "%signed_18_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_17_dup", + "%signed_19_dup = OpSpecConstantOp 
%int IAdd %signed_0_dup %signed_18_dup", + "%signed_20_dup = OpSpecConstantOp %int ISub %signed_0_dup %signed_19_dup", + "%signed_vec_a_dup = OpSpecConstantComposite %v2int %signed_18_dup %signed_19_dup", + "%signed_vec_b_dup = OpSpecConstantOp %v2int IMul %signed_vec_a_dup %signed_vec_a_dup", + "%signed_21_dup = OpSpecConstantOp %int CompositeExtract %signed_vec_b_dup 0", + "%signed_array_dup = OpConstantComposite %type_arr_int_4 %signed_20_dup %signed_20_dup %signed_21_dup %signed_21_dup", + "%signed_22_dup = OpSpecConstantOp %int CompositeExtract %signed_array_dup 0", + }, + // use duplicated constants in main + { + "%int_var = OpVariable %_pf_int Function", + "OpStore %int_var %signed_22_dup", + }, + // remapped names + { + "OpName %signed_0 \"signed_0_dup\"", + "OpName %signed_1 \"signed_1_dup\"", + "OpName %signed_2 \"signed_2_dup\"", + "OpName %signed_3 \"signed_3_dup\"", + "OpName %signed_4 \"signed_4_dup\"", + "OpName %signed_5 \"signed_5_dup\"", + "OpName %signed_6 \"signed_6_dup\"", + "OpName %signed_7 \"signed_7_dup\"", + "OpName %signed_8 \"signed_8_dup\"", + "OpName %signed_9 \"signed_9_dup\"", + "OpName %signed_10 \"signed_10_dup\"", + "OpName %signed_11 \"signed_11_dup\"", + "OpName %signed_12 \"signed_12_dup\"", + "OpName %signed_13 \"signed_13_dup\"", + "OpName %signed_14 \"signed_14_dup\"", + "OpName %signed_15 \"signed_15_dup\"", + "OpName %signed_16 \"signed_16_dup\"", + "OpName %signed_17 \"signed_17_dup\"", + "OpName %signed_18 \"signed_18_dup\"", + "OpName %signed_19 \"signed_19_dup\"", + "OpName %signed_20 \"signed_20_dup\"", + "OpName %signed_vec_a \"signed_vec_a_dup\"", + "OpName %signed_vec_b \"signed_vec_b_dup\"", + "OpName %signed_21 \"signed_21_dup\"", + "OpName %signed_array \"signed_array_dup\"", + "OpName %signed_22 \"signed_22_dup\"", + }, + }, + // clang-format on + }))); + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/upgrade_memory_model_test.cpp b/test/opt/upgrade_memory_model_test.cpp new 
file mode 100644 index 000000000..9d2d7627d --- /dev/null +++ b/test/opt/upgrade_memory_model_test.cpp @@ -0,0 +1,1716 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "assembly_builder.h" +#include "gmock/gmock.h" +#include "pass_fixture.h" +#include "pass_utils.h" + +namespace { + +using namespace spvtools; + +using UpgradeMemoryModelTest = opt::PassTest<::testing::Test>; + +TEST_F(UpgradeMemoryModelTest, InvalidMemoryModelOpenCL) { + const std::string text = R"( +; CHECK: OpMemoryModel Logical OpenCL +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, InvalidMemoryModelVulkanKHR) { + const std::string text = R"( +; CHECK: OpMemoryModel Logical VulkanKHR +OpCapability Shader +OpCapability Linkage +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, JustMemoryModel) { + const std::string text = R"( +; CHECK: OpCapability VulkanMemoryModelKHR +; CHECK: OpExtension "SPV_KHR_vulkan_memory_model" +; CHECK: OpMemoryModel Logical VulkanKHR +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, RemoveDecorations) { + const std::string text = R"( +; CHECK-NOT: OpDecorate 
+OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %var Volatile +OpDecorate %var Coherent +%int = OpTypeInt 32 0 +%ptr_int_Uniform = OpTypePointer Uniform %int +%var = OpVariable %ptr_int_Uniform Uniform +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, WorkgroupVariable) { + const std::string text = R"( +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_Workgroup Workgroup +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %int %var +%st = OpStore %var %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, WorkgroupFunctionParameter) { + const std::string text = R"( +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Workgroup = OpTypePointer Workgroup %int +%func_ty = OpTypeFunction %void %ptr_int_Workgroup +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_int_Workgroup +%1 = OpLabel +%ld = OpLoad %int %param +%st = OpStore %param %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, SimpleUniformVariable) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} 
Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %var Coherent +OpDecorate %var Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Uniform = OpTypePointer Uniform %int +%var = OpVariable %ptr_int_Uniform Uniform +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %int %var +OpStore %var %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, SimpleUniformFunctionParameter) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %param Coherent +OpDecorate %param Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Uniform = OpTypePointer Uniform %int +%func_ty = OpTypeFunction %void %ptr_int_Uniform +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_int_Uniform +%1 = OpLabel +%ld = OpLoad %int %param +OpStore %param %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, SimpleUniformVariableOnlyVolatile) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK-NOT: OpConstant +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %var Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Uniform = OpTypePointer Uniform %int +%var = OpVariable %ptr_int_Uniform Uniform +%func_ty = OpTypeFunction %void +%func = 
OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %int %var +OpStore %var %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, SimpleUniformVariableCopied) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %var Coherent +OpDecorate %var Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Uniform = OpTypePointer Uniform %int +%var = OpVariable %ptr_int_Uniform Uniform +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%copy = OpCopyObject %ptr_int_Uniform %var +%ld = OpLoad %int %copy +OpStore %copy %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, SimpleUniformFunctionParameterCopied) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %param Coherent +OpDecorate %param Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Uniform = OpTypePointer Uniform %int +%func_ty = OpTypeFunction %void %ptr_int_Uniform +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_int_Uniform +%1 = OpLabel +%copy = OpCopyObject %ptr_int_Uniform %param +%ld = OpLoad %int %copy +%copy2 = OpCopyObject %ptr_int_Uniform %param +OpStore %copy2 %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + 
+TEST_F(UpgradeMemoryModelTest, SimpleUniformVariableAccessChain) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %var Coherent +OpDecorate %var Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int3 = OpConstant %int 3 +%int_array_3 = OpTypeArray %int %int3 +%ptr_intarray_Uniform = OpTypePointer Uniform %int_array_3 +%ptr_int_Uniform = OpTypePointer Uniform %int +%var = OpVariable %ptr_intarray_Uniform Uniform +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%gep = OpAccessChain %ptr_int_Uniform %var %int0 +%ld = OpLoad %int %gep +OpStore %gep %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, SimpleUniformFunctionParameterAccessChain) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %param Coherent +OpDecorate %param Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int3 = OpConstant %int 3 +%int_array_3 = OpTypeArray %int %int3 +%ptr_intarray_Uniform = OpTypePointer Uniform %int_array_3 +%ptr_int_Uniform = OpTypePointer Uniform %int +%func_ty = OpTypeFunction %void %ptr_intarray_Uniform +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_intarray_Uniform +%1 = OpLabel +%ld_gep = OpAccessChain %ptr_int_Uniform 
%param %int0 +%ld = OpLoad %int %ld_gep +%st_gep = OpAccessChain %ptr_int_Uniform %param %int0 +OpStore %st_gep %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, VariablePointerSelect) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpDecorate %var Coherent +OpDecorate %var Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%null = OpConstantNull %ptr_int_StorageBuffer +%var = OpVariable %ptr_int_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%select = OpSelect %ptr_int_StorageBuffer %true %var %null +%ld = OpLoad %int %select +OpStore %var %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, VariablePointerSelectConservative) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpDecorate %var1 Coherent +OpDecorate %var2 Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%var1 = 
OpVariable %ptr_int_StorageBuffer StorageBuffer +%var2 = OpVariable %ptr_int_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%select = OpSelect %ptr_int_StorageBuffer %true %var1 %var2 +%ld = OpLoad %int %select +OpStore %select %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, VariablePointerIncrement) { + const std::string text = R"( +; CHECK-NOT: OpDecorate {{%\w+}} Coherent +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +OpDecorate %param Coherent +OpDecorate %param ArrayStride 4 +%void = OpTypeVoid +%bool = OpTypeBool +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int1 = OpConstant %int 1 +%int10 = OpConstant %int 10 +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_int_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_int_StorageBuffer +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +%phi = OpPhi %ptr_int_StorageBuffer %param %1 %ptr_next %2 +%iv = OpPhi %int %int0 %1 %inc %2 +%inc = OpIAdd %int %iv %int1 +%ptr_next = OpPtrAccessChain %ptr_int_StorageBuffer %phi %int1 +%cmp = OpIEqual %bool %iv %int10 +OpLoopMerge %3 %2 None +OpBranchConditional %cmp %3 %2 +%3 = OpLabel +%ld = OpLoad %int %phi +OpStore %phi %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentStructElement) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR 
[[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %struct 0 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%struct = OpTypeStruct %int +%ptr_struct_StorageBuffer = OpTypePointer StorageBuffer %struct +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_struct_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_struct_StorageBuffer +%1 = OpLabel +%gep = OpAccessChain %ptr_int_StorageBuffer %param %int0 +%ld = OpLoad %int %gep +OpStore %gep %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentElementFullStructAccess) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %struct 0 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%struct = OpTypeStruct %int +%ptr_struct_StorageBuffer = OpTypePointer StorageBuffer %struct +%func_ty = OpTypeFunction %void %ptr_struct_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_struct_StorageBuffer +%1 = OpLabel +%ld = OpLoad %struct %param +OpStore %param %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentElementNotAccessed) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK-NOT: MakePointerAvailableKHR +; CHECK-NOT: NonPrivatePointerKHR +; CHECK-NOT: 
MakePointerVisibleKHR +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %struct 1 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%struct = OpTypeStruct %int %int +%ptr_struct_StorageBuffer = OpTypePointer StorageBuffer %struct +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_struct_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_struct_StorageBuffer +%1 = OpLabel +%gep = OpAccessChain %ptr_int_StorageBuffer %param %int0 +%ld = OpLoad %int %gep +OpStore %gep %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, MultiIndexAccessCoherent) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %inner 1 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int1 = OpConstant %int 1 +%inner = OpTypeStruct %int %int +%middle = OpTypeStruct %inner +%outer = OpTypeStruct %middle %middle +%ptr_outer_StorageBuffer = OpTypePointer StorageBuffer %outer +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_outer_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_outer_StorageBuffer +%1 = OpLabel +%ld_gep = OpInBoundsAccessChain %ptr_int_StorageBuffer %param %int0 %int0 %int1 +%ld = OpLoad %int %ld_gep +%st_gep = OpInBoundsAccessChain %ptr_int_StorageBuffer %param %int1 %int0 %int1 +OpStore %st_gep %ld +OpReturn +OpFunctionEnd +)"; + 
+ SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, MultiIndexAccessNonCoherent) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK-NOT: MakePointerAvailableKHR +; CHECK-NOT: NonPrivatePointerKHR +; CHECK-NOT: MakePointerVisibleKHR +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %inner 1 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int1 = OpConstant %int 1 +%inner = OpTypeStruct %int %int +%middle = OpTypeStruct %inner +%outer = OpTypeStruct %middle %middle +%ptr_outer_StorageBuffer = OpTypePointer StorageBuffer %outer +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_outer_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_outer_StorageBuffer +%1 = OpLabel +%ld_gep = OpInBoundsAccessChain %ptr_int_StorageBuffer %param %int0 %int0 %int0 +%ld = OpLoad %int %ld_gep +%st_gep = OpInBoundsAccessChain %ptr_int_StorageBuffer %param %int1 %int0 %int0 +OpStore %st_gep %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, ConsecutiveAccessChainCoherent) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %inner 1 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int1 = OpConstant %int 1 +%inner = OpTypeStruct %int %int +%middle = OpTypeStruct %inner +%outer = OpTypeStruct %middle %middle +%ptr_outer_StorageBuffer = OpTypePointer StorageBuffer %outer 
+%ptr_middle_StorageBuffer = OpTypePointer StorageBuffer %middle +%ptr_inner_StorageBuffer = OpTypePointer StorageBuffer %inner +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_outer_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_outer_StorageBuffer +%1 = OpLabel +%ld_gep1 = OpInBoundsAccessChain %ptr_middle_StorageBuffer %param %int0 +%ld_gep2 = OpInBoundsAccessChain %ptr_inner_StorageBuffer %ld_gep1 %int0 +%ld_gep3 = OpInBoundsAccessChain %ptr_int_StorageBuffer %ld_gep2 %int1 +%ld = OpLoad %int %ld_gep3 +%st_gep1 = OpInBoundsAccessChain %ptr_middle_StorageBuffer %param %int1 +%st_gep2 = OpInBoundsAccessChain %ptr_inner_StorageBuffer %st_gep1 %int0 +%st_gep3 = OpInBoundsAccessChain %ptr_int_StorageBuffer %st_gep2 %int1 +OpStore %st_gep3 %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, ConsecutiveAccessChainNonCoherent) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK-NOT: MakePointerAvailableKHR +; CHECK-NOT: NonPrivatePointerKHR +; CHECK-NOT: MakePointerVisibleKHR +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %inner 1 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int1 = OpConstant %int 1 +%inner = OpTypeStruct %int %int +%middle = OpTypeStruct %inner +%outer = OpTypeStruct %middle %middle +%ptr_outer_StorageBuffer = OpTypePointer StorageBuffer %outer +%ptr_middle_StorageBuffer = OpTypePointer StorageBuffer %middle +%ptr_inner_StorageBuffer = OpTypePointer StorageBuffer %inner +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_outer_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_outer_StorageBuffer +%1 = OpLabel +%ld_gep1 = OpInBoundsAccessChain %ptr_middle_StorageBuffer 
%param %int0 +%ld_gep2 = OpInBoundsAccessChain %ptr_inner_StorageBuffer %ld_gep1 %int0 +%ld_gep3 = OpInBoundsAccessChain %ptr_int_StorageBuffer %ld_gep2 %int0 +%ld = OpLoad %int %ld_gep3 +%st_gep1 = OpInBoundsAccessChain %ptr_middle_StorageBuffer %param %int1 +%st_gep2 = OpInBoundsAccessChain %ptr_inner_StorageBuffer %st_gep1 %int0 +%st_gep3 = OpInBoundsAccessChain %ptr_int_StorageBuffer %st_gep2 %int0 +OpStore %st_gep3 %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentStructElementAccess) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %middle 0 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int1 = OpConstant %int 1 +%inner = OpTypeStruct %int %int +%middle = OpTypeStruct %inner +%outer = OpTypeStruct %middle %middle +%ptr_outer_StorageBuffer = OpTypePointer StorageBuffer %outer +%ptr_middle_StorageBuffer = OpTypePointer StorageBuffer %middle +%ptr_inner_StorageBuffer = OpTypePointer StorageBuffer %inner +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_outer_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_outer_StorageBuffer +%1 = OpLabel +%ld_gep1 = OpInBoundsAccessChain %ptr_middle_StorageBuffer %param %int0 +%ld_gep2 = OpInBoundsAccessChain %ptr_inner_StorageBuffer %ld_gep1 %int0 +%ld_gep3 = OpInBoundsAccessChain %ptr_int_StorageBuffer %ld_gep2 %int1 +%ld = OpLoad %int %ld_gep3 +%st_gep1 = OpInBoundsAccessChain %ptr_middle_StorageBuffer %param %int1 +%st_gep2 = OpInBoundsAccessChain 
%ptr_inner_StorageBuffer %st_gep1 %int0 +%st_gep3 = OpInBoundsAccessChain %ptr_int_StorageBuffer %st_gep2 %int1 +OpStore %st_gep3 %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, NonCoherentLoadCoherentStore) { + const std::string text = R"( +; CHECK-NOT: OpMemberDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK-NOT: MakePointerVisibleKHR +; CHECK: OpStore {{%\w+}} {{%\w+}} MakePointerAvailableKHR|NonPrivatePointerKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpMemberDecorate %outer 1 Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%int1 = OpConstant %int 1 +%inner = OpTypeStruct %int %int +%middle = OpTypeStruct %inner +%outer = OpTypeStruct %middle %middle +%ptr_outer_StorageBuffer = OpTypePointer StorageBuffer %outer +%ptr_middle_StorageBuffer = OpTypePointer StorageBuffer %middle +%ptr_inner_StorageBuffer = OpTypePointer StorageBuffer %inner +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_outer_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_outer_StorageBuffer +%1 = OpLabel +%ld_gep1 = OpInBoundsAccessChain %ptr_middle_StorageBuffer %param %int0 +%ld_gep2 = OpInBoundsAccessChain %ptr_inner_StorageBuffer %ld_gep1 %int0 +%ld_gep3 = OpInBoundsAccessChain %ptr_int_StorageBuffer %ld_gep2 %int1 +%ld = OpLoad %int %ld_gep3 +%st_gep1 = OpInBoundsAccessChain %ptr_middle_StorageBuffer %param %int1 +%st_gep2 = OpInBoundsAccessChain %ptr_inner_StorageBuffer %st_gep1 %int0 +%st_gep3 = OpInBoundsAccessChain %ptr_int_StorageBuffer %st_gep2 %int1 +OpStore %st_gep3 %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CopyMemory) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[queuefamily:%\w+]] = OpConstant 
{{%\w+}} 5 +; CHECK: OpCopyMemory {{%\w+}} {{%\w+}} Volatile|MakePointerVisibleKHR|NonPrivatePointerKHR [[queuefamily]] +; CHECK-NOT: [[queuefamily]] +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %in_var Coherent +OpDecorate %out_var Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%in_var = OpVariable %ptr_int_StorageBuffer StorageBuffer +%out_var = OpVariable %ptr_int_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpCopyMemory %out_var %in_var +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CopyMemorySized) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[queuefamily:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpCopyMemorySized {{%\w+}} {{%\w+}} {{%\w+}} Volatile|MakePointerAvailableKHR|NonPrivatePointerKHR [[queuefamily]] +; CHECK-NOT: [[queuefamily]] +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %out_param Coherent +OpDecorate %in_param Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%int4 = OpConstant %int 4 +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%func_ty = OpTypeFunction %void %ptr_int_StorageBuffer %ptr_int_StorageBuffer +%func = OpFunction %void None %func_ty +%in_param = OpFunctionParameter %ptr_int_StorageBuffer +%out_param = OpFunctionParameter %ptr_int_StorageBuffer +%1 = OpLabel +OpCopyMemorySized %out_param %in_param %int4 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CopyMemoryTwoScopes) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK-DAG: [[queuefamily:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK-DAG: [[workgroup:%\w+]] = OpConstant {{%\w+}} 2 
+; CHECK: OpCopyMemory {{%\w+}} {{%\w+}} MakePointerAvailableKHR|MakePointerVisibleKHR|NonPrivatePointerKHR [[workgroup]] [[queuefamily]] +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %in_var Coherent +OpDecorate %out_var Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Workgroup = OpTypePointer Workgroup %int +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%in_var = OpVariable %ptr_int_StorageBuffer StorageBuffer +%out_var = OpVariable %ptr_int_Workgroup Workgroup +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpCopyMemory %out_var %in_var +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, VolatileImageRead) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile +; CHECK: OpImageRead {{%\w+}} {{%\w+}} {{%\w+}} VolatileTexelKHR +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageReadWithoutFormat +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %var Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%int0 = OpConstant %int 0 +%v2int_0 = OpConstantComposite %v2int %int0 %int0 +%image = OpTypeImage %float 2D 0 0 0 2 Unknown +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%var = OpVariable %ptr_image_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %image %var +%rd = OpImageRead %float %ld %v2int_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentImageRead) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} 
MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpImageRead {{%\w+}} {{%\w+}} {{%\w+}} MakeTexelVisibleKHR|NonPrivateTexelKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageReadWithoutFormat +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %var Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%int0 = OpConstant %int 0 +%v2int_0 = OpConstantComposite %v2int %int0 %int0 +%image = OpTypeImage %float 2D 0 0 0 2 Unknown +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%var = OpVariable %ptr_image_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %image %var +%rd = OpImageRead %float %ld %v2int_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentImageReadExtractedFromSampledImage) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[image:%\w+]] = OpTypeImage +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad [[image]] {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK-NOT: NonPrivatePointerKHR +; CHECK: OpImageRead {{%\w+}} {{%\w+}} {{%\w+}} MakeTexelVisibleKHR|NonPrivateTexelKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageReadWithoutFormat +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %var Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%int0 = OpConstant %int 0 +%v2int_0 = OpConstantComposite %v2int %int0 %int0 +%image = OpTypeImage %float 2D 0 0 0 0 Unknown +%sampled_image = OpTypeSampledImage %image +%sampler = OpTypeSampler +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%ptr_sampler_StorageBuffer = OpTypePointer StorageBuffer %sampler +%var = 
OpVariable %ptr_image_StorageBuffer StorageBuffer +%sampler_var = OpVariable %ptr_sampler_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %image %var +%ld_sampler = OpLoad %sampler %sampler_var +%sample = OpSampledImage %sampled_image %ld %ld_sampler +%extract = OpImage %image %sample +%rd = OpImageRead %float %extract %v2int_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, VolatileImageWrite) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile +; CHECK: OpImageWrite {{%\w+}} {{%\w+}} {{%\w+}} VolatileTexelKHR +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageWriteWithoutFormat +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %param Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%float0 = OpConstant %float 0 +%v2int_null = OpConstantNull %v2int +%image = OpTypeImage %float 2D 0 0 0 0 Unknown +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%func_ty = OpTypeFunction %void %ptr_image_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_image_StorageBuffer +%1 = OpLabel +%ld = OpLoad %image %param +OpImageWrite %ld %v2int_null %float0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentImageWrite) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR +; CHECK: OpImageWrite {{%\w+}} {{%\w+}} {{%\w+}} MakeTexelAvailableKHR|NonPrivateTexelKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageWriteWithoutFormat +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 
+OpDecorate %param Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%float0 = OpConstant %float 0 +%v2int_null = OpConstantNull %v2int +%image = OpTypeImage %float 2D 0 0 0 0 Unknown +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%func_ty = OpTypeFunction %void %ptr_image_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_image_StorageBuffer +%1 = OpLabel +%ld = OpLoad %image %param +OpImageWrite %ld %v2int_null %float0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentImageWriteExtractFromSampledImage) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR +; CHECK-NOT: NonPrivatePointerKHR +; CHECK: OpImageWrite {{%\w+}} {{%\w+}} {{%\w+}} MakeTexelAvailableKHR|NonPrivateTexelKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageWriteWithoutFormat +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %param Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%float0 = OpConstant %float 0 +%v2int_null = OpConstantNull %v2int +%image = OpTypeImage %float 2D 0 0 0 0 Unknown +%sampled_image = OpTypeSampledImage %image +%sampler = OpTypeSampler +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%ptr_sampler_StorageBuffer = OpTypePointer StorageBuffer %sampler +%func_ty = OpTypeFunction %void %ptr_image_StorageBuffer %ptr_sampler_StorageBuffer +%func = OpFunction %void None %func_ty +%param = OpFunctionParameter %ptr_image_StorageBuffer +%sampler_param = OpFunctionParameter %ptr_sampler_StorageBuffer +%1 = OpLabel +%ld = OpLoad %image %param +%ld_sampler = OpLoad %sampler %sampler_param +%sample = OpSampledImage %sampled_image %ld 
%ld_sampler +%extract = OpImage %image %sample +OpImageWrite %extract %v2int_null %float0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, VolatileImageSparseRead) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: OpLoad {{%\w+}} {{%\w+}} Volatile +; CHECK: OpImageSparseRead {{%\w+}} {{%\w+}} {{%\w+}} VolatileTexelKHR +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageReadWithoutFormat +OpCapability SparseResidency +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %var Volatile +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%int0 = OpConstant %int 0 +%v2int_0 = OpConstantComposite %v2int %int0 %int0 +%image = OpTypeImage %float 2D 0 0 0 2 Unknown +%struct = OpTypeStruct %int %float +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%var = OpVariable %ptr_image_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %image %var +%rd = OpImageSparseRead %struct %ld %v2int_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, CoherentImageSparseRead) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad {{%\w+}} {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK: OpImageSparseRead {{%\w+}} {{%\w+}} {{%\w+}} MakeTexelVisibleKHR|NonPrivateTexelKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageReadWithoutFormat +OpCapability SparseResidency +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %var Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%int0 = OpConstant %int 0 +%v2int_0 = OpConstantComposite %v2int %int0 %int0 
+%image = OpTypeImage %float 2D 0 0 0 2 Unknown +%struct = OpTypeStruct %int %float +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%var = OpVariable %ptr_image_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %image %var +%rd = OpImageSparseRead %struct %ld %v2int_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, + CoherentImageSparseReadExtractedFromSampledImage) { + const std::string text = R"( +; CHECK-NOT: OpDecorate +; CHECK: [[image:%\w+]] = OpTypeImage +; CHECK: [[scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpLoad [[image]] {{%\w+}} MakePointerVisibleKHR|NonPrivatePointerKHR [[scope]] +; CHECK-NOT: NonPrivatePointerKHR +; CHECK: OpImageSparseRead {{%\w+}} {{%\w+}} {{%\w+}} MakeTexelVisibleKHR|NonPrivateTexelKHR [[scope]] +OpCapability Shader +OpCapability Linkage +OpCapability StorageImageReadWithoutFormat +OpCapability SparseResidency +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpDecorate %var Coherent +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%v2int = OpTypeVector %int 2 +%float = OpTypeFloat 32 +%int0 = OpConstant %int 0 +%v2int_0 = OpConstantComposite %v2int %int0 %int0 +%image = OpTypeImage %float 2D 0 0 0 0 Unknown +%struct = OpTypeStruct %int %float +%sampled_image = OpTypeSampledImage %image +%sampler = OpTypeSampler +%ptr_image_StorageBuffer = OpTypePointer StorageBuffer %image +%ptr_sampler_StorageBuffer = OpTypePointer StorageBuffer %sampler +%var = OpVariable %ptr_image_StorageBuffer StorageBuffer +%sampler_var = OpVariable %ptr_sampler_StorageBuffer StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %image %var +%ld_sampler = OpLoad %sampler %sampler_var +%sample = OpSampledImage %sampled_image %ld %ld_sampler +%extract = OpImage %image %sample +%rd = OpImageSparseRead %struct %extract 
%v2int_0 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, TessellationControlBarrierNoChange) { + const std::string text = R"( +; CHECK: [[none:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[workgroup:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: OpControlBarrier [[workgroup]] [[workgroup]] [[none]] +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %func "func" +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%none = OpConstant %int 0 +%workgroup = OpConstant %int 2 +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpControlBarrier %workgroup %workgroup %none +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, TessellationControlBarrierAddOutput) { + const std::string text = R"( +; CHECK: [[workgroup:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: [[output:%\w+]] = OpConstant {{%\w+}} 4096 +; CHECK: OpControlBarrier [[workgroup]] [[workgroup]] [[output]] +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %func "func" %var +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%none = OpConstant %int 0 +%workgroup = OpConstant %int 2 +%ptr_int_Output = OpTypePointer Output %int +%var = OpVariable %ptr_int_Output Output +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %int %var +OpControlBarrier %workgroup %workgroup %none +OpStore %var %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, TessellationMemoryBarrierNoChange) { + const std::string text = R"( +; CHECK: [[none:%\w+]] = OpConstant {{%\w+}} 0 +; CHECK: [[workgroup:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: OpMemoryBarrier [[workgroup]] [[none]] +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %func "func" %var +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%none = 
OpConstant %int 0 +%workgroup = OpConstant %int 2 +%ptr_int_Output = OpTypePointer Output %int +%var = OpVariable %ptr_int_Output Output +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpLoad %int %var +OpMemoryBarrier %workgroup %none +OpStore %var %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, TessellationControlBarrierAddOutputSubFunction) { + const std::string text = R"( +; CHECK: [[workgroup:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: [[output:%\w+]] = OpConstant {{%\w+}} 4096 +; CHECK: OpControlBarrier [[workgroup]] [[workgroup]] [[output]] +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %func "func" %var +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%none = OpConstant %int 0 +%workgroup = OpConstant %int 2 +%ptr_int_Output = OpTypePointer Output %int +%var = OpVariable %ptr_int_Output Output +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%call = OpFunctionCall %void %sub_func +OpReturn +OpFunctionEnd +%sub_func = OpFunction %void None %func_ty +%2 = OpLabel +%ld = OpLoad %int %var +OpControlBarrier %workgroup %workgroup %none +OpStore %var %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, + TessellationControlBarrierAddOutputDifferentFunctions) { + const std::string text = R"( +; CHECK: [[workgroup:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: [[output:%\w+]] = OpConstant {{%\w+}} 4096 +; CHECK: OpControlBarrier [[workgroup]] [[workgroup]] [[output]] +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %func "func" %var +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%none = OpConstant %int 0 +%workgroup = OpConstant %int 2 +%ptr_int_Output = OpTypePointer Output %int +%var = OpVariable %ptr_int_Output Output +%func_ty = OpTypeFunction %void +%ld_func_ty = OpTypeFunction %int 
+%st_func_ty = OpTypeFunction %void %int +%func = OpFunction %void None %func_ty +%1 = OpLabel +%call_ld = OpFunctionCall %int %ld_func +%call_barrier = OpFunctionCall %void %barrier_func +%call_st = OpFunctionCall %void %st_func %call_ld +OpReturn +OpFunctionEnd +%ld_func = OpFunction %int None %ld_func_ty +%2 = OpLabel +%ld = OpLoad %int %var +OpReturnValue %ld +OpFunctionEnd +%barrier_func = OpFunction %void None %func_ty +%3 = OpLabel +OpControlBarrier %workgroup %workgroup %none +OpReturn +OpFunctionEnd +%st_func = OpFunction %void None %st_func_ty +%param = OpFunctionParameter %int +%4 = OpLabel +OpStore %var %param +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, ChangeControlBarrierMemoryScope) { + std::string text = R"( +; CHECK: [[workgroup:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: [[queuefamily:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpControlBarrier [[workgroup]] [[queuefamily]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %func "func" +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%none = OpConstant %int 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 2 +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpControlBarrier %workgroup %device %none +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, ChangeMemoryBarrierMemoryScope) { + std::string text = R"( +; CHECK: [[queuefamily:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: OpMemoryBarrier [[queuefamily]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %func "func" +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%none = OpConstant %int 0 +%device = OpConstant %int 1 +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpMemoryBarrier %device %none +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, 
ChangeAtomicMemoryScope) { + std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt +; CHECK: [[var:%\w+]] = OpVariable +; CHECK: [[qf:%\w+]] = OpConstant [[int]] 5 +; CHECK: OpAtomicLoad [[int]] [[var]] [[qf]] +; CHECK: OpAtomicStore [[var]] [[qf]] +; CHECK: OpAtomicExchange [[int]] [[var]] [[qf]] +; CHECK: OpAtomicCompareExchange [[int]] [[var]] [[qf]] +; CHECK: OpAtomicIIncrement [[int]] [[var]] [[qf]] +; CHECK: OpAtomicIDecrement [[int]] [[var]] [[qf]] +; CHECK: OpAtomicIAdd [[int]] [[var]] [[qf]] +; CHECK: OpAtomicISub [[int]] [[var]] [[qf]] +; CHECK: OpAtomicSMin [[int]] [[var]] [[qf]] +; CHECK: OpAtomicSMax [[int]] [[var]] [[qf]] +; CHECK: OpAtomicUMin [[int]] [[var]] [[qf]] +; CHECK: OpAtomicUMax [[int]] [[var]] [[qf]] +; CHECK: OpAtomicAnd [[int]] [[var]] [[qf]] +; CHECK: OpAtomicOr [[int]] [[var]] [[qf]] +; CHECK: OpAtomicXor [[int]] [[var]] [[qf]] +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %func "func" +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%none = OpConstant %int 0 +%device = OpConstant %int 1 +%func_ty = OpTypeFunction %void +%ptr_int_StorageBuffer = OpTypePointer StorageBuffer %int +%var = OpVariable %ptr_int_StorageBuffer StorageBuffer +%func = OpFunction %void None %func_ty +%1 = OpLabel +%ld = OpAtomicLoad %int %var %device %none +OpAtomicStore %var %device %none %ld +%ex = OpAtomicExchange %int %var %device %none %ld +%cmp_ex = OpAtomicCompareExchange %int %var %device %none %none %ld %ld +%inc = OpAtomicIIncrement %int %var %device %none +%dec = OpAtomicIDecrement %int %var %device %none +%add = OpAtomicIAdd %int %var %device %none %ld +%sub = OpAtomicISub %int %var %device %none %ld +%smin = OpAtomicSMin %int %var %device %none %ld +%smax = OpAtomicSMax %int %var %device %none %ld +%umin = OpAtomicUMin %int %var %device %none %ld +%umax = OpAtomicUMax %int %var %device %none %ld +%and = OpAtomicAnd %int %var %device %none %ld +%or = OpAtomicOr %int %var 
%device %none %ld +%xor = OpAtomicXor %int %var %device %none %ld +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, UpgradeModfNoFlags) { + const std::string text = R"( +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer StorageBuffer [[float]] +; CHECK: [[var:%\w+]] = OpVariable [[ptr]] StorageBuffer +; CHECK: [[struct:%\w+]] = OpTypeStruct [[float]] [[float]] +; CHECK: [[modfstruct:%\w+]] = OpExtInst [[struct]] {{%\w+}} ModfStruct [[float_0]] +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 0 +; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 1 +; CHECK: OpStore [[var]] [[ex1]] +; CHECK-NOT: NonPrivatePointerKHR +; CHECK: OpFAdd [[float]] [[float_0]] [[ex0]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +%import = OpExtInstImport "GLSL.std.450" +OpEntryPoint GLCompute %func "func" +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%ptr_ssbo_float = OpTypePointer StorageBuffer %float +%ssbo_var = OpVariable %ptr_ssbo_float StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%2 = OpExtInst %float %import Modf %float_0 %ssbo_var +%3 = OpFAdd %float %float_0 %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, UpgradeModfWorkgroupCoherent) { + const std::string text = R"( +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer Workgroup [[float]] +; CHECK: [[var:%\w+]] = OpVariable [[ptr]] Workgroup +; CHECK: [[struct:%\w+]] = OpTypeStruct [[float]] [[float]] +; CHECK: [[wg_scope:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: [[modfstruct:%\w+]] = OpExtInst [[struct]] {{%\w+}} ModfStruct [[float_0]] +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 0 +; CHECK: 
[[ex1:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 1 +; CHECK: OpStore [[var]] [[ex1]] MakePointerAvailableKHR|NonPrivatePointerKHR [[wg_scope]] +; CHECK: OpFAdd [[float]] [[float_0]] [[ex0]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +%import = OpExtInstImport "GLSL.std.450" +OpEntryPoint GLCompute %func "func" +OpDecorate %wg_var Coherent +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%ptr_wg_float = OpTypePointer Workgroup %float +%wg_var = OpVariable %ptr_wg_float Workgroup +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%2 = OpExtInst %float %import Modf %float_0 %wg_var +%3 = OpFAdd %float %float_0 %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, UpgradeModfSSBOCoherent) { + const std::string text = R"( +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer StorageBuffer [[float]] +; CHECK: [[var:%\w+]] = OpVariable [[ptr]] StorageBuffer +; CHECK: [[struct:%\w+]] = OpTypeStruct [[float]] [[float]] +; CHECK: [[qf_scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: [[modfstruct:%\w+]] = OpExtInst [[struct]] {{%\w+}} ModfStruct [[float_0]] +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 0 +; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 1 +; CHECK: OpStore [[var]] [[ex1]] MakePointerAvailableKHR|NonPrivatePointerKHR [[qf_scope]] +; CHECK: OpFAdd [[float]] [[float_0]] [[ex0]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +%import = OpExtInstImport "GLSL.std.450" +OpEntryPoint GLCompute %func "func" +OpDecorate %ssbo_var Coherent +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%ptr_ssbo_float = OpTypePointer StorageBuffer %float +%ssbo_var = OpVariable %ptr_ssbo_float StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%2 
= OpExtInst %float %import Modf %float_0 %ssbo_var +%3 = OpFAdd %float %float_0 %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, UpgradeModfSSBOVolatile) { + const std::string text = R"( +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer StorageBuffer [[float]] +; CHECK: [[var:%\w+]] = OpVariable [[ptr]] StorageBuffer +; CHECK: [[struct:%\w+]] = OpTypeStruct [[float]] [[float]] +; CHECK: [[modfstruct:%\w+]] = OpExtInst [[struct]] {{%\w+}} ModfStruct [[float_0]] +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 0 +; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 1 +; CHECK: OpStore [[var]] [[ex1]] Volatile +; CHECK: OpFAdd [[float]] [[float_0]] [[ex0]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +%import = OpExtInstImport "GLSL.std.450" +OpEntryPoint GLCompute %func "func" +OpDecorate %wg_var Volatile +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%ptr_ssbo_float = OpTypePointer StorageBuffer %float +%wg_var = OpVariable %ptr_ssbo_float StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%2 = OpExtInst %float %import Modf %float_0 %wg_var +%3 = OpFAdd %float %float_0 %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, UpgradeFrexpNoFlags) { + const std::string text = R"( +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer StorageBuffer [[int]] +; CHECK: [[var:%\w+]] = OpVariable [[ptr]] StorageBuffer +; CHECK: [[struct:%\w+]] = OpTypeStruct [[float]] [[int]] +; CHECK: [[modfstruct:%\w+]] = OpExtInst [[struct]] {{%\w+}} FrexpStruct [[float_0]] +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 0 +; CHECK: 
[[ex1:%\w+]] = OpCompositeExtract [[int]] [[modfstruct]] 1 +; CHECK: OpStore [[var]] [[ex1]] +; CHECK-NOT: NonPrivatePointerKHR +; CHECK: OpFAdd [[float]] [[float_0]] [[ex0]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +%import = OpExtInstImport "GLSL.std.450" +OpEntryPoint GLCompute %func "func" +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 0 +%ptr_ssbo_int = OpTypePointer StorageBuffer %int +%ssbo_var = OpVariable %ptr_ssbo_int StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%2 = OpExtInst %float %import Frexp %float_0 %ssbo_var +%3 = OpFAdd %float %float_0 %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, UpgradeFrexpWorkgroupCoherent) { + const std::string text = R"( +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer Workgroup [[int]] +; CHECK: [[var:%\w+]] = OpVariable [[ptr]] Workgroup +; CHECK: [[struct:%\w+]] = OpTypeStruct [[float]] [[int]] +; CHECK: [[wg_scope:%\w+]] = OpConstant {{%\w+}} 2 +; CHECK: [[modfstruct:%\w+]] = OpExtInst [[struct]] {{%\w+}} FrexpStruct [[float_0]] +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 0 +; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[int]] [[modfstruct]] 1 +; CHECK: OpStore [[var]] [[ex1]] MakePointerAvailableKHR|NonPrivatePointerKHR [[wg_scope]] +; CHECK: OpFAdd [[float]] [[float_0]] [[ex0]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +%import = OpExtInstImport "GLSL.std.450" +OpEntryPoint GLCompute %func "func" +OpDecorate %wg_var Coherent +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 0 +%ptr_wg_int = OpTypePointer Workgroup %int +%wg_var = OpVariable %ptr_wg_int Workgroup +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 
= OpLabel +%2 = OpExtInst %float %import Frexp %float_0 %wg_var +%3 = OpFAdd %float %float_0 %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, UpgradeFrexpSSBOCoherent) { + const std::string text = R"( +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer StorageBuffer [[int]] +; CHECK: [[var:%\w+]] = OpVariable [[ptr]] StorageBuffer +; CHECK: [[struct:%\w+]] = OpTypeStruct [[float]] [[int]] +; CHECK: [[qf_scope:%\w+]] = OpConstant {{%\w+}} 5 +; CHECK: [[modfstruct:%\w+]] = OpExtInst [[struct]] {{%\w+}} FrexpStruct [[float_0]] +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 0 +; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[int]] [[modfstruct]] 1 +; CHECK: OpStore [[var]] [[ex1]] MakePointerAvailableKHR|NonPrivatePointerKHR [[qf_scope]] +; CHECK: OpFAdd [[float]] [[float_0]] [[ex0]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +%import = OpExtInstImport "GLSL.std.450" +OpEntryPoint GLCompute %func "func" +OpDecorate %ssbo_var Coherent +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 0 +%ptr_ssbo_int = OpTypePointer StorageBuffer %int +%ssbo_var = OpVariable %ptr_ssbo_int StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%2 = OpExtInst %float %import Frexp %float_0 %ssbo_var +%3 = OpFAdd %float %float_0 %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(UpgradeMemoryModelTest, UpgradeFrexpSSBOVolatile) { + const std::string text = R"( +; CHECK: [[float:%\w+]] = OpTypeFloat 32 +; CHECK: [[float_0:%\w+]] = OpConstant [[float]] 0 +; CHECK: [[int:%\w+]] = OpTypeInt 32 0 +; CHECK: [[ptr:%\w+]] = OpTypePointer StorageBuffer [[int]] +; CHECK: [[var:%\w+]] = OpVariable [[ptr]] StorageBuffer +; CHECK: [[struct:%\w+]] = OpTypeStruct [[float]] 
[[int]] +; CHECK: [[modfstruct:%\w+]] = OpExtInst [[struct]] {{%\w+}} FrexpStruct [[float_0]] +; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[float]] [[modfstruct]] 0 +; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[int]] [[modfstruct]] 1 +; CHECK: OpStore [[var]] [[ex1]] Volatile +; CHECK: OpFAdd [[float]] [[float_0]] [[ex0]] +OpCapability Shader +OpMemoryModel Logical GLSL450 +%import = OpExtInstImport "GLSL.std.450" +OpEntryPoint GLCompute %func "func" +OpDecorate %wg_var Volatile +%void = OpTypeVoid +%float = OpTypeFloat 32 +%float_0 = OpConstant %float 0 +%int = OpTypeInt 32 0 +%ptr_ssbo_int = OpTypePointer StorageBuffer %int +%wg_var = OpVariable %ptr_ssbo_int StorageBuffer +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +%2 = OpExtInst %float %import Frexp %float_0 %wg_var +%3 = OpFAdd %float %float_0 %2 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +} // namespace diff --git a/test/opt/utils_test.cpp b/test/opt/utils_test.cpp new file mode 100644 index 000000000..9bb82a367 --- /dev/null +++ b/test/opt/utils_test.cpp @@ -0,0 +1,110 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gtest/gtest.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +TEST(JoinAllInsts, Cases) { + EXPECT_EQ("", JoinAllInsts({})); + EXPECT_EQ("a\n", JoinAllInsts({"a"})); + EXPECT_EQ("a\nb\n", JoinAllInsts({"a", "b"})); + EXPECT_EQ("a\nb\nc\n", JoinAllInsts({"a", "b", "c"})); + EXPECT_EQ("hello,\nworld!\n\n\n", JoinAllInsts({"hello,", "world!", "\n"})); +} + +TEST(JoinNonDebugInsts, Cases) { + EXPECT_EQ("", JoinNonDebugInsts({})); + EXPECT_EQ("a\n", JoinNonDebugInsts({"a"})); + EXPECT_EQ("", JoinNonDebugInsts({"OpName"})); + EXPECT_EQ("a\nb\n", JoinNonDebugInsts({"a", "b"})); + EXPECT_EQ("", JoinNonDebugInsts({"OpName", "%1 = OpString \"42\""})); + EXPECT_EQ("Opstring\n", JoinNonDebugInsts({"OpName", "Opstring"})); + EXPECT_EQ("the only remaining string\n", + JoinNonDebugInsts( + {"OpSourceContinued", "OpSource", "OpSourceExtension", + "lgtm OpName", "hello OpMemberName", "this is a OpString", + "lonely OpLine", "happy OpNoLine", "OpModuleProcessed", + "the only remaining string"})); +} + +struct SubstringReplacementTestCase { + const char* orig_str; + const char* find_substr; + const char* replace_substr; + const char* expected_str; + bool replace_should_succeed; +}; + +using FindAndReplaceTest = + ::testing::TestWithParam; + +TEST_P(FindAndReplaceTest, SubstringReplacement) { + auto process = std::string(GetParam().orig_str); + EXPECT_EQ(GetParam().replace_should_succeed, + FindAndReplace(&process, GetParam().find_substr, + GetParam().replace_substr)) + << "Original string: " << GetParam().orig_str + << " replace: " << GetParam().find_substr + << " to: " << GetParam().replace_substr + << " should returns: " << GetParam().replace_should_succeed; + EXPECT_STREQ(GetParam().expected_str, process.c_str()) + << "Original string: " << GetParam().orig_str + << " replace: " << GetParam().find_substr + << " to: " << GetParam().replace_substr + << " expected string: " << GetParam().expected_str; +} + 
+INSTANTIATE_TEST_CASE_P( + SubstringReplacement, FindAndReplaceTest, + ::testing::ValuesIn(std::vector({ + // orig string, find substring, replace substring, expected string, + // replacement happened + {"", "", "", "", false}, + {"", "b", "", "", false}, + {"", "", "c", "", false}, + {"", "a", "b", "", false}, + + {"a", "", "c", "a", false}, + {"a", "b", "c", "a", false}, + {"a", "b", "", "a", false}, + {"a", "a", "", "", true}, + {"a", "a", "b", "b", true}, + + {"ab", "a", "b", "bb", true}, + {"ab", "a", "", "b", true}, + {"ab", "b", "", "a", true}, + {"ab", "ab", "", "", true}, + {"ab", "ab", "cd", "cd", true}, + {"bc", "abc", "efg", "bc", false}, + + {"abc", "ab", "bc", "bcc", true}, + {"abc", "ab", "", "c", true}, + {"abc", "bc", "", "a", true}, + {"abc", "bc", "d", "ad", true}, + {"abc", "a", "123", "123bc", true}, + {"abc", "ab", "a", "ac", true}, + {"abc", "a", "aab", "aabbc", true}, + {"abc", "abcd", "efg", "abc", false}, + }))); + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/value_table_test.cpp b/test/opt/value_table_test.cpp new file mode 100644 index 000000000..ef338ae7e --- /dev/null +++ b/test/opt/value_table_test.cpp @@ -0,0 +1,591 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "source/opt/build_module.h" +#include "source/opt/value_number_table.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" + +namespace spvtools { +namespace opt { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; +using ValueTableTest = PassTest<::testing::Test>; + +TEST_F(ValueTableTest, SameInstructionSameValue) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst = context->get_def_use_mgr()->GetDef(10); + EXPECT_EQ(vtable.GetValueNumber(inst), vtable.GetValueNumber(inst)); +} + +TEST_F(ValueTableTest, DifferentInstructionSameValue) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 + %11 = OpFAdd %5 %9 %9 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(11); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + 
+TEST_F(ValueTableTest, SameValueDifferentBlock) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 + OpBranch %11 + %11 = OpLabel + %12 = OpFAdd %5 %9 %9 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, DifferentValue) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 + %11 = OpFAdd %5 %9 %10 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(11); + EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, DifferentValueDifferentBlock) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + 
%3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpFAdd %5 %9 %9 + OpBranch %11 + %11 = OpLabel + %12 = OpFAdd %5 %9 %10 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); + EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, SameLoad) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst = context->get_def_use_mgr()->GetDef(9); + EXPECT_EQ(vtable.GetValueNumber(inst), vtable.GetValueNumber(inst)); +} + +// Two different loads, even from the same memory, must given different value +// numbers if the memory is not read-only. 
+TEST_F(ValueTableTest, DifferentFunctionLoad) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, DifferentUniformLoad) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Uniform %5 + %8 = OpVariable %6 Uniform + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %5 %8 + %10 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, DifferentInputLoad) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Input %5 + 
%8 = OpVariable %6 Input + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %5 %8 + %10 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, DifferentUniformConstantLoad) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer UniformConstant %5 + %8 = OpVariable %6 UniformConstant + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %5 %8 + %10 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, DifferentPushConstantLoad) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer PushConstant %5 + %8 = OpVariable %6 PushConstant + %2 = OpFunction %3 None %4 + %7 = OpLabel + %9 = OpLoad %5 %8 + %10 = OpLoad %5 %8 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + 
Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, SameCall) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypeFunction %5 + %7 = OpTypePointer Function %5 + %8 = OpVariable %7 Private + %2 = OpFunction %3 None %4 + %9 = OpLabel + %10 = OpFunctionCall %5 %11 + OpReturn + OpFunctionEnd + %11 = OpFunction %5 None %6 + %12 = OpLabel + %13 = OpLoad %5 %8 + OpReturnValue %13 + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst = context->get_def_use_mgr()->GetDef(10); + EXPECT_EQ(vtable.GetValueNumber(inst), vtable.GetValueNumber(inst)); +} + +// Function calls should be given a new value number, even if they are the same. 
+TEST_F(ValueTableTest, DifferentCall) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypeFunction %5 + %7 = OpTypePointer Function %5 + %8 = OpVariable %7 Private + %2 = OpFunction %3 None %4 + %9 = OpLabel + %10 = OpFunctionCall %5 %11 + %12 = OpFunctionCall %5 %11 + OpReturn + OpFunctionEnd + %11 = OpFunction %5 None %6 + %13 = OpLabel + %14 = OpLoad %5 %8 + OpReturnValue %14 + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(10); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); + EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +// It is possible to have two instruction that compute the same numerical value, +// but with different types. They should have different value numbers. 
+TEST_F(ValueTableTest, DifferentTypes) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeInt 32 0 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %8 = OpLabel + %9 = OpVariable %7 Function + %10 = OpLoad %5 %9 + %11 = OpIAdd %5 %10 %10 + %12 = OpIAdd %6 %10 %10 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(11); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(12); + EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +TEST_F(ValueTableTest, CopyObject) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Function %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + %8 = OpVariable %6 Function + %9 = OpLoad %5 %8 + %10 = OpCopyObject %5 %9 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(9); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(10); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); +} + +// Test that a phi where the operands have the same value assigned that value +// to the result of the phi. 
+TEST_F(ValueTableTest, PhiTest1) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Uniform %5 + %7 = OpTypeBool + %8 = OpConstantTrue %7 + %9 = OpVariable %6 Uniform + %2 = OpFunction %3 None %4 + %10 = OpLabel + OpBranchConditional %8 %11 %12 + %11 = OpLabel + %13 = OpLoad %5 %9 + OpBranch %14 + %12 = OpLabel + %15 = OpLoad %5 %9 + OpBranch %14 + %14 = OpLabel + %16 = OpPhi %5 %13 %11 %15 %12 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(13); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(15); + Instruction* phi = context->get_def_use_mgr()->GetDef(16); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(phi)); +} + +// When the values for the inputs to a phi do not match, then the phi should +// have its own value number. 
+TEST_F(ValueTableTest, PhiTest2) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Uniform %5 + %7 = OpTypeBool + %8 = OpConstantTrue %7 + %9 = OpVariable %6 Uniform + %10 = OpVariable %6 Uniform + %2 = OpFunction %3 None %4 + %11 = OpLabel + OpBranchConditional %8 %12 %13 + %12 = OpLabel + %14 = OpLoad %5 %9 + OpBranch %15 + %13 = OpLabel + %16 = OpLoad %5 %10 + OpBranch %15 + %15 = OpLabel + %17 = OpPhi %5 %14 %12 %16 %13 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(14); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(16); + Instruction* phi = context->get_def_use_mgr()->GetDef(17); + // %14 and %16 load from different variables (%9 and %10), so the phi that + // merges them is expected to get a value number of its own. + EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); + EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(phi)); + EXPECT_NE(vtable.GetValueNumber(inst2), vtable.GetValueNumber(phi)); +} + +// Test that a phi node in a loop header gets a new value because one of its +// inputs comes from later in the loop.
+TEST_F(ValueTableTest, PhiLoopTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeFloat 32 + %6 = OpTypePointer Uniform %5 + %7 = OpTypeBool + %8 = OpConstantTrue %7 + %9 = OpVariable %6 Uniform + %10 = OpVariable %6 Uniform + %2 = OpFunction %3 None %4 + %11 = OpLabel + %12 = OpLoad %5 %9 + OpSelectionMerge %13 None + OpBranchConditional %8 %14 %13 + %14 = OpLabel + %15 = OpPhi %5 %12 %11 %16 %14 + %16 = OpLoad %5 %9 + OpLoopMerge %17 %14 None + OpBranchConditional %8 %14 %17 + %17 = OpLabel + OpBranch %13 + %13 = OpLabel + %18 = OpPhi %5 %12 %11 %16 %17 + OpReturn + OpFunctionEnd + )"; + auto context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); + ValueNumberTable vtable(context.get()); + Instruction* inst1 = context->get_def_use_mgr()->GetDef(12); + Instruction* inst2 = context->get_def_use_mgr()->GetDef(16); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(inst2)); + + Instruction* phi1 = context->get_def_use_mgr()->GetDef(15); + EXPECT_NE(vtable.GetValueNumber(inst1), vtable.GetValueNumber(phi1)); + + Instruction* phi2 = context->get_def_use_mgr()->GetDef(18); + EXPECT_EQ(vtable.GetValueNumber(inst1), vtable.GetValueNumber(phi2)); + EXPECT_NE(vtable.GetValueNumber(phi1), vtable.GetValueNumber(phi2)); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/vector_dce_test.cpp b/test/opt/vector_dce_test.cpp new file mode 100644 index 000000000..a978a07f9 --- /dev/null +++ b/test/opt/vector_dce_test.cpp @@ -0,0 +1,1231 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <string> + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using VectorDCETest = PassTest<::testing::Test>; + +TEST_F(VectorDCETest, InsertAfterInsertElim) { + // With two insertions to the same offset, the first is dead. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in float In0; + // layout (location=1) in float In1; + // layout (location=2) in vec2 In2; + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec2 v = In2; + // v.x = In0 + In1; // dead + // v.x = 0.0; + // OutColor = v.xyxy; + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In2 %In0 %In1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In2 "In2" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %OutColor "OutColor" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_n" +OpName %_ "" +OpDecorate %In2 Location 2 +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %OutColor Location 0 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void
+%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%In2 = OpVariable %_ptr_Input_v2float Input +%_ptr_Input_float = OpTypePointer Input %float +%In0 = OpVariable %_ptr_Input_float Input +%In1 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%int = OpTypeInt 32 1 +%_Globals_ = OpTypeStruct %uint %int +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In2 %In0 %In1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In2 "In2" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %OutColor "OutColor" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_n" +OpName %_ "" +OpDecorate %In2 Location 2 +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %OutColor Location 0 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%_ptr_Input_v2float = OpTypePointer Input %v2float +%In2 = OpVariable %_ptr_Input_v2float Input +%_ptr_Input_float = OpTypePointer Input %float +%In0 = OpVariable %_ptr_Input_float Input +%In1 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float 
+%float_0 = OpConstant %float 0 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%int = OpTypeInt 32 1 +%_Globals_ = OpTypeStruct %uint %int +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%25 = OpLabel +%26 = OpLoad %v2float %In2 +%27 = OpLoad %float %In0 +%28 = OpLoad %float %In1 +%29 = OpFAdd %float %27 %28 +%35 = OpCompositeInsert %v2float %29 %26 0 +%37 = OpCompositeInsert %v2float %float_0 %35 0 +%33 = OpVectorShuffle %v4float %37 %37 0 1 0 1 +OpStore %OutColor %33 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%23 = OpLabel +%24 = OpLoad %v2float %In2 +%25 = OpLoad %float %In0 +%26 = OpLoad %float %In1 +%27 = OpFAdd %float %25 %26 +%28 = OpCompositeInsert %v2float %27 %24 0 +%29 = OpCompositeInsert %v2float %float_0 %24 0 +%30 = OpVectorShuffle %v4float %29 %29 0 1 0 1 +OpStore %OutColor %30 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck<VectorDCE>(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(VectorDCETest, DeadInsertInChainWithPhi) { + // Dead insert eliminated with phi in insertion chain. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other.
+ // + // #version 450 + // + // layout (location=0) in vec4 In0; + // layout (location=1) in float In1; + // layout (location=2) in float In2; + // layout (location=0) out vec4 OutColor; + // + // layout(std140, binding = 0 ) uniform _Globals_ + // { + // bool g_b; + // }; + // + // void main() + // { + // vec4 v = In0; + // v.z = In1 + In2; + // if (g_b) v.w = 1.0; + // OutColor = vec4(v.x,v.y,0.0,v.w); + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%11 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%_Globals_ = OpTypeStruct %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output 
+%float_0 = OpConstant %float 0 +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_ptr_Function_float = OpTypePointer Function %float +%_Globals_ = OpTypeStruct %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%float_0 = OpConstant %float 0 +)"; + + const std::string before = + R"(%main = OpFunction %void None %11 +%31 = OpLabel +%32 = OpLoad %v4float %In0 +%33 = OpLoad %float %In1 +%34 = OpLoad %float %In2 +%35 = OpFAdd %float %33 %34 +%51 = OpCompositeInsert %v4float %35 %32 2 +%37 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%38 = OpLoad %uint %37 +%39 = OpINotEqual %bool %38 
%uint_0 +OpSelectionMerge %40 None +OpBranchConditional %39 %41 %40 +%41 = OpLabel +%53 = OpCompositeInsert %v4float %float_1 %51 3 +OpBranch %40 +%40 = OpLabel +%60 = OpPhi %v4float %51 %31 %53 %41 +%55 = OpCompositeExtract %float %60 0 +%57 = OpCompositeExtract %float %60 1 +%59 = OpCompositeExtract %float %60 3 +%49 = OpCompositeConstruct %v4float %55 %57 %float_0 %59 +OpStore %OutColor %49 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%27 = OpLabel +%28 = OpLoad %v4float %In0 +%29 = OpLoad %float %In1 +%30 = OpLoad %float %In2 +%31 = OpFAdd %float %29 %30 +%32 = OpCompositeInsert %v4float %31 %28 2 +%33 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%34 = OpLoad %uint %33 +%35 = OpINotEqual %bool %34 %uint_0 +OpSelectionMerge %36 None +OpBranchConditional %35 %37 %36 +%37 = OpLabel +%38 = OpCompositeInsert %v4float %float_1 %28 3 +OpBranch %36 +%36 = OpLabel +%39 = OpPhi %v4float %28 %27 %38 %37 +%40 = OpCompositeExtract %float %39 0 +%41 = OpCompositeExtract %float %39 1 +%42 = OpCompositeExtract %float %39 3 +%43 = OpCompositeConstruct %v4float %40 %41 %float_0 %42 +OpStore %OutColor %43 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck<VectorDCE>(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(VectorDCETest, DeadInsertWithScalars) { + // Dead insert which requires two passes to eliminate + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other.
+ // + // #version 450 + // + // layout (location=0) in vec4 In0; + // layout (location=1) in float In1; + // layout (location=2) in float In2; + // layout (location=0) out vec4 OutColor; + // + // layout(std140, binding = 0 ) uniform _Globals_ + // { + // bool g_b; + // bool g_b2; + // }; + // + // void main() + // { + // vec4 v1, v2; + // v1 = In0; + // v1.y = In1 + In2; // dead, second pass + // if (g_b) v1.x = 1.0; + // v2.x = v1.x; + // v2.y = v1.y; // dead, first pass + // if (g_b2) v2.x = 0.0; + // OutColor = vec4(v2.x,v2.x,0.0,1.0); + // } + + const std::string before_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_b2" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_Globals_ = OpTypeStruct %uint %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool 
+%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%float_0 = OpConstant %float 0 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%27 = OpUndef %v4float +)"; + + const std::string after_predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %In2 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_b" +OpMemberName %_Globals_ 1 "g_b2" +OpName %_ "" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +OpMemberDecorate %_Globals_ 0 Offset 0 +OpMemberDecorate %_Globals_ 1 Offset 4 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%uint = OpTypeInt 32 0 +%_Globals_ = OpTypeStruct %uint %uint +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%_ptr_Uniform_uint = OpTypePointer Uniform %uint +%bool = OpTypeBool +%uint_0 = OpConstant %uint 0 +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%float_0 = OpConstant %float 0 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%27 = OpUndef %v4float +%55 = OpUndef %v4float +)"; + + const std::string 
before = + R"(%main = OpFunction %void None %10 +%28 = OpLabel +%29 = OpLoad %v4float %In0 +%30 = OpLoad %float %In1 +%31 = OpLoad %float %In2 +%32 = OpFAdd %float %30 %31 +%33 = OpCompositeInsert %v4float %32 %29 1 +%34 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%35 = OpLoad %uint %34 +%36 = OpINotEqual %bool %35 %uint_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %37 +%38 = OpLabel +%39 = OpCompositeInsert %v4float %float_1 %33 0 +OpBranch %37 +%37 = OpLabel +%40 = OpPhi %v4float %33 %28 %39 %38 +%41 = OpCompositeExtract %float %40 0 +%42 = OpCompositeInsert %v4float %41 %27 0 +%43 = OpCompositeExtract %float %40 1 +%44 = OpCompositeInsert %v4float %43 %42 1 +%45 = OpAccessChain %_ptr_Uniform_uint %_ %int_1 +%46 = OpLoad %uint %45 +%47 = OpINotEqual %bool %46 %uint_0 +OpSelectionMerge %48 None +OpBranchConditional %47 %49 %48 +%49 = OpLabel +%50 = OpCompositeInsert %v4float %float_0 %44 0 +OpBranch %48 +%48 = OpLabel +%51 = OpPhi %v4float %44 %37 %50 %49 +%52 = OpCompositeExtract %float %51 0 +%53 = OpCompositeExtract %float %51 0 +%54 = OpCompositeConstruct %v4float %52 %53 %float_0 %float_1 +OpStore %OutColor %54 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %10 +%28 = OpLabel +%29 = OpLoad %v4float %In0 +%30 = OpLoad %float %In1 +%31 = OpLoad %float %In2 +%32 = OpFAdd %float %30 %31 +%33 = OpCompositeInsert %v4float %32 %29 1 +%34 = OpAccessChain %_ptr_Uniform_uint %_ %int_0 +%35 = OpLoad %uint %34 +%36 = OpINotEqual %bool %35 %uint_0 +OpSelectionMerge %37 None +OpBranchConditional %36 %38 %37 +%38 = OpLabel +%39 = OpCompositeInsert %v4float %float_1 %55 0 +OpBranch %37 +%37 = OpLabel +%40 = OpPhi %v4float %29 %28 %39 %38 +%41 = OpCompositeExtract %float %40 0 +%42 = OpCompositeInsert %v4float %41 %55 0 +%43 = OpCompositeExtract %float %40 1 +%44 = OpCompositeInsert %v4float %43 %42 1 +%45 = OpAccessChain %_ptr_Uniform_uint %_ %int_1 +%46 = OpLoad %uint %45 +%47 = OpINotEqual %bool %46 %uint_0 
+OpSelectionMerge %48 None +OpBranchConditional %47 %49 %48 +%49 = OpLabel +%50 = OpCompositeInsert %v4float %float_0 %55 0 +OpBranch %48 +%48 = OpLabel +%51 = OpPhi %v4float %42 %37 %50 %49 +%52 = OpCompositeExtract %float %51 0 +%53 = OpCompositeExtract %float %51 0 +%54 = OpCompositeConstruct %v4float %52 %53 %float_0 %float_1 +OpStore %OutColor %54 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck<VectorDCE>(before_predefs + before, + after_predefs + after, true, true); +} + +TEST_F(VectorDCETest, InsertObjectLive) { + // Make sure that the object being inserted in an OpCompositeInsert + // is not removed when it is used later on. + const std::string before = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %In1 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%28 = OpLabel +%29 = OpLoad %v4float %In0 +%30 = OpLoad %float %In1 +%33 = OpCompositeInsert %v4float %30 %29 1 +OpStore %OutColor %33 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck<VectorDCE>(before, before, true, true); +} + +TEST_F(VectorDCETest, DeadInsertInCycle) { + // Dead insert in chain with cycle. Demonstrates analysis can handle + // cycles in chains going through scalar intermediate values.
+ // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) in vec4 In0; + // layout (location=1) in float In1; + // layout (location=2) in float In2; + // layout (location=0) out vec4 OutColor; + // + // layout(std140, binding = 0 ) uniform _Globals_ + // { + // int g_n ; + // }; + // + // void main() + // { + // vec2 v = vec2(0.0, 1.0); + // for (int i = 0; i < g_n; i++) { + // v.x = v.x + 1; + // v.y = v.y * 0.9; // dead + // } + // OutColor = vec4(v.x); + // } + + const std::string assembly = + R"( +; CHECK: [[init_val:%\w+]] = OpConstantComposite %v2float %float_0 %float_1 +; CHECK: [[undef:%\w+]] = OpUndef %v2float +; CHECK: OpFunction +; CHECK: [[entry_lab:%\w+]] = OpLabel +; CHECK: [[loop_header:%\w+]] = OpLabel +; CHECK: OpPhi %v2float [[init_val]] [[entry_lab]] [[x_insert:%\w+]] {{%\w+}} +; CHECK: [[x_insert:%\w+]] = OpCompositeInsert %v2float %43 [[undef]] 0 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %OutColor %In0 %In1 %In2 +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpName %main "main" +OpName %_Globals_ "_Globals_" +OpMemberName %_Globals_ 0 "g_n" +OpName %_ "" +OpName %OutColor "OutColor" +OpName %In0 "In0" +OpName %In1 "In1" +OpName %In2 "In2" +OpMemberDecorate %_Globals_ 0 Offset 0 +OpDecorate %_Globals_ Block +OpDecorate %_ DescriptorSet 0 +OpDecorate %_ Binding 0 +OpDecorate %OutColor Location 0 +OpDecorate %In0 Location 0 +OpDecorate %In1 Location 1 +OpDecorate %In2 Location 2 +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%16 = OpConstantComposite %v2float %float_0 %float_1 +%int = OpTypeInt 32 1 +%_ptr_Function_int = 
OpTypePointer Function %int +%int_0 = OpConstant %int 0 +%_Globals_ = OpTypeStruct %int +%_ptr_Uniform__Globals_ = OpTypePointer Uniform %_Globals_ +%_ = OpVariable %_ptr_Uniform__Globals_ Uniform +%_ptr_Uniform_int = OpTypePointer Uniform %int +%bool = OpTypeBool +%float_0_75 = OpConstant %float 0.75 +%int_1 = OpConstant %int 1 +%v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%_ptr_Input_float = OpTypePointer Input %float +%In1 = OpVariable %_ptr_Input_float Input +%In2 = OpVariable %_ptr_Input_float Input +%main = OpFunction %void None %10 +%29 = OpLabel +OpBranch %30 +%30 = OpLabel +%31 = OpPhi %v2float %16 %29 %32 %33 +%34 = OpPhi %int %int_0 %29 %35 %33 +OpLoopMerge %36 %33 None +OpBranch %37 +%37 = OpLabel +%38 = OpAccessChain %_ptr_Uniform_int %_ %int_0 +%39 = OpLoad %int %38 +%40 = OpSLessThan %bool %34 %39 +OpBranchConditional %40 %41 %36 +%41 = OpLabel +%42 = OpCompositeExtract %float %31 0 +%43 = OpFAdd %float %42 %float_1 +%44 = OpCompositeInsert %v2float %43 %31 0 +%45 = OpCompositeExtract %float %44 1 +%46 = OpFMul %float %45 %float_0_75 +%32 = OpCompositeInsert %v2float %46 %44 1 +OpBranch %33 +%33 = OpLabel +%35 = OpIAdd %int %34 %int_1 +OpBranch %30 +%36 = OpLabel +%47 = OpCompositeExtract %float %31 0 +%48 = OpCompositeConstruct %v4float %47 %47 %47 %47 +OpStore %OutColor %48 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<VectorDCE>(assembly, true); +} + +TEST_F(VectorDCETest, DeadLoadFeedingCompositeConstruct) { + // Detach the loads feeding the CompositeConstruct for the unused elements. + // TODO: Implement the rewrite for CompositeConstruct.
+ + const std::string assembly = + R"( +; CHECK: [[undef:%\w+]] = OpUndef %float +; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Input_float %In0 %uint_2 +; CHECK: [[load:%\w+]] = OpLoad %float [[ac]] +; CHECK: OpCompositeConstruct %v3float [[load]] [[undef]] [[undef]] +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %In0 %OutColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 450 +OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" +OpSourceExtension "GL_GOOGLE_include_directive" +OpName %main "main" +OpName %In0 "In0" +OpName %OutColor "OutColor" +OpDecorate %In0 Location 0 +OpDecorate %OutColor Location 0 +%void = OpTypeVoid +%6 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%In0 = OpVariable %_ptr_Input_v4float Input +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 +%v3float = OpTypeVector %float 3 +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_20 = OpConstant %int 20 +%bool = OpTypeBool +%float_1 = OpConstant %float 1 +%int_1 = OpConstant %int 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float +%OutColor = OpVariable %_ptr_Output_v4float Output +%23 = OpUndef %v3float +%main = OpFunction %void None %6 +%24 = OpLabel +%25 = OpAccessChain %_ptr_Input_float %In0 %uint_0 +%26 = OpLoad %float %25 +%27 = OpAccessChain %_ptr_Input_float %In0 %uint_1 +%28 = OpLoad %float %27 +%29 = OpAccessChain %_ptr_Input_float %In0 %uint_2 +%30 = OpLoad %float %29 +%31 = OpCompositeConstruct %v3float %30 %28 %26 +OpBranch %32 +%32 = OpLabel +%33 = OpPhi %v3float %31 %24 %34 %35 +%36 = OpPhi %int %int_0 %24 %37 %35 +OpLoopMerge %38 %35 None +OpBranch %39 +%39 = OpLabel +%40 = OpSLessThan %bool %36 %int_20 +OpBranchConditional %40 %41 %38 +%41 = OpLabel +%42 = OpCompositeExtract %float %33 
0 +%43 = OpFAdd %float %42 %float_1 +%34 = OpCompositeInsert %v3float %43 %33 0 +OpBranch %35 +%35 = OpLabel +%37 = OpIAdd %int %36 %int_1 +OpBranch %32 +%38 = OpLabel +%44 = OpCompositeExtract %float %33 0 +%45 = OpCompositeConstruct %v4float %44 %44 %44 %44 +OpStore %OutColor %45 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<VectorDCE>(assembly, true); +} + +TEST_F(VectorDCETest, DeadLoadFeedingVectorShuffle) { + // Detach the loads feeding the CompositeConstruct for the unused elements. + // TODO: Implement the rewrite for CompositeConstruct. + + const std::string assembly = + R"( +; MemPass Type2Undef does not reuse and already existing undef. +; CHECK: {{%\w+}} = OpUndef %v3float +; CHECK: [[undef:%\w+]] = OpUndef %v3float +; CHECK: OpFunction +; CHECK: OpVectorShuffle %v3float {{%\w+}} [[undef]] 0 4 5 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %In0 %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %In0 "In0" + OpName %OutColor "OutColor" + OpDecorate %In0 Location 0 + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_v4float = OpTypePointer Input %v4float + %In0 = OpVariable %_ptr_Input_v4float Input + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 +%_ptr_Input_float = OpTypePointer Input %float + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_20 = OpConstant %int 20 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 + %vec_const = OpConstantComposite %v3float %float_1 %float_1 %float_1 + %int_1 = OpConstant %int 1 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %OutColor = OpVariable
%_ptr_Output_v4float Output + %23 = OpUndef %v3float + %main = OpFunction %void None %6 + %24 = OpLabel + %25 = OpAccessChain %_ptr_Input_float %In0 %uint_0 + %26 = OpLoad %float %25 + %27 = OpAccessChain %_ptr_Input_float %In0 %uint_1 + %28 = OpLoad %float %27 + %29 = OpAccessChain %_ptr_Input_float %In0 %uint_2 + %30 = OpLoad %float %29 + %31 = OpCompositeConstruct %v3float %30 %28 %26 + %sh = OpVectorShuffle %v3float %vec_const %31 0 4 5 + OpBranch %32 + %32 = OpLabel + %33 = OpPhi %v3float %sh %24 %34 %35 + %36 = OpPhi %int %int_0 %24 %37 %35 + OpLoopMerge %38 %35 None + OpBranch %39 + %39 = OpLabel + %40 = OpSLessThan %bool %36 %int_20 + OpBranchConditional %40 %41 %38 + %41 = OpLabel + %42 = OpCompositeExtract %float %33 0 + %43 = OpFAdd %float %42 %float_1 + %34 = OpCompositeInsert %v3float %43 %33 0 + OpBranch %35 + %35 = OpLabel + %37 = OpIAdd %int %36 %int_1 + OpBranch %32 + %38 = OpLabel + %44 = OpCompositeExtract %float %33 0 + %45 = OpCompositeConstruct %v4float %44 %44 %44 %44 + OpStore %OutColor %45 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch<VectorDCE>(assembly, true); +} + +TEST_F(VectorDCETest, DeadInstThroughShuffle) { + // Dead insert in chain with cycle. Demonstrates analysis can handle + // cycles in chains. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other.
+ // + // #version 450 + // + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec2 v; + // v.x = 0.0; + // v.y = 0.1; // dead + // for (int i = 0; i < 20; i++) { + // v.x = v.x + 1; + // v = v * 0.9; + // } + // OutColor = vec4(v.x); + // } + + const std::string assembly = + R"( +; CHECK: OpFunction +; CHECK-NOT: OpCompositeInsert %v2float {{%\w+}} 1 +; CHECK: OpFunctionEnd + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %OutColor "OutColor" + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %float_0 = OpConstant %float 0 +%float_0_100000001 = OpConstant %float 0.100000001 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_20 = OpConstant %int 20 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 +%float_0_899999976 = OpConstant %float 0.899999976 + %int_1 = OpConstant %int 1 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %OutColor = OpVariable %_ptr_Output_v4float Output + %58 = OpUndef %v2float + %main = OpFunction %void None %3 + %5 = OpLabel + %49 = OpCompositeInsert %v2float %float_0 %58 0 + %51 = OpCompositeInsert %v2float %float_0_100000001 %49 1 + OpBranch %22 + %22 = OpLabel + %60 = OpPhi %v2float %51 %5 %38 %25 + %59 = OpPhi %int %int_0 %5 %41 %25 + OpLoopMerge %24 %25 None + OpBranch %26 + %26 = OpLabel + %30 = OpSLessThan %bool %59 %int_20 + OpBranchConditional %30 %23 %24 + %23 = OpLabel + %53 = OpCompositeExtract %float %60 0 + %34 = OpFAdd %float %53 %float_1 + %55 = OpCompositeInsert %v2float %34 %60 0 + %38 = OpVectorTimesScalar %v2float %55 %float_0_899999976 + OpBranch %25 + %25 = 
OpLabel + %41 = OpIAdd %int %59 %int_1 + OpBranch %22 + %24 = OpLabel + %57 = OpCompositeExtract %float %60 0 + %47 = OpCompositeConstruct %v4float %57 %57 %57 %57 + OpStore %OutColor %47 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(assembly, true); +} + +TEST_F(VectorDCETest, DeadInsertThroughOtherInst) { + // Dead insert in chain with cycle. Demonstrates analysis can handle + // cycles in chains. + // + // Note: The SPIR-V assembly has had store/load elimination + // performed to allow the inserts and extracts to directly + // reference each other. + // + // #version 450 + // + // layout (location=0) out vec4 OutColor; + // + // void main() + // { + // vec2 v; + // v.x = 0.0; + // v.y = 0.1; // dead + // for (int i = 0; i < 20; i++) { + // v.x = v.x + 1; + // v = v * 0.9; + // } + // OutColor = vec4(v.x); + // } + + const std::string assembly = + R"( +; CHECK: OpFunction +; CHECK-NOT: OpCompositeInsert %v2float {{%\w+}} 1 +; CHECK: OpFunctionEnd + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %OutColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %OutColor "OutColor" + OpDecorate %OutColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %float_0 = OpConstant %float 0 +%float_0_100000001 = OpConstant %float 0.100000001 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_20 = OpConstant %int 20 + %bool = OpTypeBool + %float_1 = OpConstant %float 1 +%float_0_899999976 = OpConstant %float 0.899999976 + %int_1 = OpConstant %int 1 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %OutColor = OpVariable %_ptr_Output_v4float Output + %58 = OpUndef %v2float + %main = OpFunction %void None %3 + %5 = OpLabel + 
%49 = OpCompositeInsert %v2float %float_0 %58 0 + %51 = OpCompositeInsert %v2float %float_0_100000001 %49 1 + OpBranch %22 + %22 = OpLabel + %60 = OpPhi %v2float %51 %5 %38 %25 + %59 = OpPhi %int %int_0 %5 %41 %25 + OpLoopMerge %24 %25 None + OpBranch %26 + %26 = OpLabel + %30 = OpSLessThan %bool %59 %int_20 + OpBranchConditional %30 %23 %24 + %23 = OpLabel + %53 = OpCompositeExtract %float %60 0 + %34 = OpFAdd %float %53 %float_1 + %55 = OpCompositeInsert %v2float %34 %60 0 + %38 = OpVectorTimesScalar %v2float %55 %float_0_899999976 + OpBranch %25 + %25 = OpLabel + %41 = OpIAdd %int %59 %int_1 + OpBranch %22 + %24 = OpLabel + %57 = OpCompositeExtract %float %60 0 + %47 = OpCompositeConstruct %v4float %57 %57 %57 %57 + OpStore %OutColor %47 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(assembly, true); +} + +TEST_F(VectorDCETest, VectorIntoCompositeConstruct) { + const std::string text = R"(OpCapability Linkage +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "EntryPoint_Main" %2 %3 +OpExecutionMode %1 OriginUpperLeft +OpDecorate %2 Location 0 +OpDecorate %_struct_4 Block +OpDecorate %3 Location 0 +%float = OpTypeFloat 32 +%v2float = OpTypeVector %float 2 +%_ptr_Function_v2float = OpTypePointer Function %v2float +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%mat4v4float = OpTypeMatrix %v4float 4 +%_ptr_Function_mat4v4float = OpTypePointer Function %mat4v4float +%v3float = OpTypeVector %float 3 +%_ptr_Function_v3float = OpTypePointer Function %v3float +%_struct_14 = OpTypeStruct %v2float %mat4v4float %v3float %v2float %v4float +%_ptr_Function__struct_14 = OpTypePointer Function %_struct_14 +%void = OpTypeVoid +%int = OpTypeInt 32 1 +%int_2 = OpConstant %int 2 +%int_1 = OpConstant %int 1 +%int_4 = OpConstant %int 4 +%int_0 = OpConstant %int 0 +%int_3 = OpConstant %int 3 +%float_0 = OpConstant %float 0 +%float_1 = OpConstant %float 1 +%_ptr_Input_v2float = OpTypePointer Input 
%v2float +%2 = OpVariable %_ptr_Input_v2float Input +%_ptr_Output_v2float = OpTypePointer Output %v2float +%_struct_4 = OpTypeStruct %v2float +%_ptr_Output__struct_4 = OpTypePointer Output %_struct_4 +%3 = OpVariable %_ptr_Output__struct_4 Output +%28 = OpTypeFunction %void +%29 = OpConstantComposite %v2float %float_0 %float_0 +%30 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%31 = OpConstantComposite %mat4v4float %30 %30 %30 %30 +%32 = OpConstantComposite %v3float %float_0 %float_0 %float_0 +%1 = OpFunction %void None %28 +%33 = OpLabel +%34 = OpVariable %_ptr_Function_v4float Function +%35 = OpVariable %_ptr_Function__struct_14 Function +%36 = OpAccessChain %_ptr_Function_v2float %35 %int_0 +OpStore %36 %29 +%37 = OpAccessChain %_ptr_Function_mat4v4float %35 %int_1 +OpStore %37 %31 +%38 = OpAccessChain %_ptr_Function_v3float %35 %int_2 +OpStore %38 %32 +%39 = OpAccessChain %_ptr_Function_v2float %35 %int_3 +OpStore %39 %29 +%40 = OpAccessChain %_ptr_Function_v4float %35 %int_4 +OpStore %40 %30 +%41 = OpLoad %v2float %2 +OpStore %36 %41 +%42 = OpLoad %v3float %38 +%43 = OpCompositeConstruct %v4float %42 %float_1 +%44 = OpLoad %mat4v4float %37 +%45 = OpVectorTimesMatrix %v4float %43 %44 +OpStore %34 %45 +OpCopyMemory %40 %34 +OpCopyMemory %36 %39 +%46 = OpAccessChain %_ptr_Output_v2float %3 %int_0 +%47 = OpLoad %v2float %36 +OpStore %46 %47 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, true, true); +} + +TEST_F(VectorDCETest, InsertWithNoIndices) { + const std::string text = R"( +; CHECK: OpEntryPoint Fragment {{%\w+}} "PSMain" [[in1:%\w+]] [[in2:%\w+]] [[out:%\w+]] +; CHECK: OpFunction +; CHECK: [[ld:%\w+]] = OpLoad %v4float [[in2]] +; CHECK: OpStore [[out]] [[ld]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "PSMain" %2 %14 %3 + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 
+%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float + %2 = OpVariable %_ptr_Input_v4float Input + %14 = OpVariable %_ptr_Input_v4float Input + %3 = OpVariable %_ptr_Output_v4float Output + %1 = OpFunction %void None %5 + %10 = OpLabel + %13 = OpLoad %v4float %14 + %11 = OpLoad %v4float %2 + %12 = OpCompositeInsert %v4float %13 %11 + OpStore %3 %12 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +TEST_F(VectorDCETest, ExtractWithNoIndices) { + const std::string text = R"( +; CHECK: OpLoad %float +; CHECK: [[ld:%\w+]] = OpLoad %v4float +; CHECK: [[ex1:%\w+]] = OpCompositeExtract %v4float [[ld]] +; CHECK: [[ex2:%\w+]] = OpCompositeExtract %float [[ex1]] 1 +; CHECK: OpStore {{%\w+}} [[ex2]] + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "PSMain" %2 %14 %3 + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Input_float = OpTypePointer Input %float +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_float = OpTypePointer Output %float + %2 = OpVariable %_ptr_Input_v4float Input + %14 = OpVariable %_ptr_Input_float Input + %3 = OpVariable %_ptr_Output_float Output + %1 = OpFunction %void None %5 + %10 = OpLabel + %13 = OpLoad %float %14 + %11 = OpLoad %v4float %2 + %12 = OpCompositeInsert %v4float %13 %11 0 + %20 = OpCompositeExtract %v4float %12 + %21 = OpCompositeExtract %float %20 1 + OpStore %3 %21 + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(text, true); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/opt/workaround1209_test.cpp b/test/opt/workaround1209_test.cpp new file mode 100644 index 000000000..50d3c0915 --- /dev/null +++ b/test/opt/workaround1209_test.cpp @@ -0,0 +1,423 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/opt/assembly_builder.h" +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using Workaround1209Test = PassTest<::testing::Test>; + +TEST_F(Workaround1209Test, RemoveOpUnreachableInLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %texcoord %gl_VertexIndex %_ + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main" + OpName %texcoord "texcoord" + OpName %buf "buf" + OpMemberName %buf 0 "MVP" + OpMemberName %buf 1 "position" + OpMemberName %buf 2 "attr" + OpName %ubuf "ubuf" + OpName %gl_VertexIndex "gl_VertexIndex" + OpName %gl_PerVertex "gl_PerVertex" + OpMemberName %gl_PerVertex 0 "gl_Position" + OpName %_ "" + OpDecorate %texcoord Location 0 + OpDecorate %_arr_v4float_uint_72 ArrayStride 16 + OpDecorate %_arr_v4float_uint_72_0 ArrayStride 16 + OpMemberDecorate %buf 0 ColMajor + OpMemberDecorate %buf 0 Offset 0 + OpMemberDecorate %buf 0 MatrixStride 16 + OpMemberDecorate %buf 1 Offset 64 + OpMemberDecorate %buf 2 Offset 1216 + OpDecorate %buf Block + OpDecorate %ubuf DescriptorSet 0 + OpDecorate %ubuf Binding 0 + OpDecorate 
%gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %gl_PerVertex 0 BuiltIn Position + OpDecorate %gl_PerVertex Block + %void = OpTypeVoid + %12 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %texcoord = OpVariable %_ptr_Output_v4float Output +%mat4v4float = OpTypeMatrix %v4float 4 + %uint = OpTypeInt 32 0 + %uint_72 = OpConstant %uint 72 +%_arr_v4float_uint_72 = OpTypeArray %v4float %uint_72 +%_arr_v4float_uint_72_0 = OpTypeArray %v4float %uint_72 + %buf = OpTypeStruct %mat4v4float %_arr_v4float_uint_72 %_arr_v4float_uint_72_0 +%_ptr_Uniform_buf = OpTypePointer Uniform %buf + %ubuf = OpVariable %_ptr_Uniform_buf Uniform + %int = OpTypeInt 32 1 + %int_2 = OpConstant %int 2 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%gl_PerVertex = OpTypeStruct %v4float +%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex + %_ = OpVariable %_ptr_Output_gl_PerVertex Output + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %float_1 = OpConstant %float 1 + %28 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 + %main = OpFunction %void None %12 + %29 = OpLabel + OpBranch %30 + %30 = OpLabel +; CHECK: OpLoopMerge [[merge:%[a-zA-Z_\d]+]] + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel +; CHECK: OpSelectionMerge [[sel_merge:%[a-zA-Z_\d]+]] + OpSelectionMerge %34 None + OpSwitch %int_1 %35 + %35 = OpLabel + %36 = OpLoad %int %gl_VertexIndex + %37 = OpAccessChain %_ptr_Uniform_v4float %ubuf %int_2 %36 + %38 = OpLoad %v4float %37 + OpStore %texcoord %38 + %39 = OpAccessChain %_ptr_Output_v4float %_ %int_0 + OpStore %39 %28 + OpBranch %31 +; CHECK: [[sel_merge]] = OpLabel + %34 = OpLabel +; CHECK-NEXT: OpBranch [[merge]] + OpUnreachable + %32 = OpLabel + OpBranch %30 + %31 = OpLabel + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(text, 
false); +} + +TEST_F(Workaround1209Test, RemoveOpUnreachableInNestedLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %2 "main" %3 %4 %5 + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %2 "main" + OpName %3 "texcoord" + OpName %6 "buf" + OpMemberName %6 0 "MVP" + OpMemberName %6 1 "position" + OpMemberName %6 2 "attr" + OpName %7 "ubuf" + OpName %4 "gl_VertexIndex" + OpName %8 "gl_PerVertex" + OpMemberName %8 0 "gl_Position" + OpName %5 "" + OpDecorate %3 Location 0 + OpDecorate %9 ArrayStride 16 + OpDecorate %10 ArrayStride 16 + OpMemberDecorate %6 0 ColMajor + OpMemberDecorate %6 0 Offset 0 + OpMemberDecorate %6 0 MatrixStride 16 + OpMemberDecorate %6 1 Offset 64 + OpMemberDecorate %6 2 Offset 1216 + OpDecorate %6 Block + OpDecorate %7 DescriptorSet 0 + OpDecorate %7 Binding 0 + OpDecorate %4 BuiltIn VertexIndex + OpMemberDecorate %8 0 BuiltIn Position + OpDecorate %8 Block + %11 = OpTypeVoid + %12 = OpTypeFunction %11 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Output %14 + %3 = OpVariable %15 Output + %16 = OpTypeMatrix %14 4 + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 72 + %9 = OpTypeArray %14 %18 + %10 = OpTypeArray %14 %18 + %6 = OpTypeStruct %16 %9 %10 + %19 = OpTypePointer Uniform %6 + %7 = OpVariable %19 Uniform + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 2 + %22 = OpTypePointer Input %20 + %4 = OpVariable %22 Input + %23 = OpTypePointer Uniform %14 + %8 = OpTypeStruct %14 + %24 = OpTypePointer Output %8 + %5 = OpVariable %24 Output + %25 = OpConstant %20 0 + %26 = OpConstant %20 1 + %27 = OpConstant %13 1 + %28 = OpConstantComposite %14 %27 %27 %27 %27 + %2 = OpFunction %11 None %12 + %29 = OpLabel + OpBranch %31 + %31 = OpLabel +; CHECK: OpLoopMerge + OpLoopMerge %32 %33 None + OpBranch %30 + %30 = OpLabel +; CHECK: OpLoopMerge 
[[merge:%[a-zA-Z_\d]+]] + OpLoopMerge %34 %35 None + OpBranch %36 + %36 = OpLabel +; CHECK: OpSelectionMerge [[sel_merge:%[a-zA-Z_\d]+]] + OpSelectionMerge %37 None + OpSwitch %26 %38 + %38 = OpLabel + %39 = OpLoad %20 %4 + %40 = OpAccessChain %23 %7 %21 %39 + %41 = OpLoad %14 %40 + OpStore %3 %41 + %42 = OpAccessChain %15 %5 %25 + OpStore %42 %28 + OpBranch %34 +; CHECK: [[sel_merge]] = OpLabel + %37 = OpLabel +; CHECK-NEXT: OpBranch [[merge]] + OpUnreachable + %35 = OpLabel + OpBranch %30 + %34 = OpLabel + OpBranch %32 + %33 = OpLabel + OpBranch %31 + %32 = OpLabel + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(Workaround1209Test, RemoveOpUnreachableInAdjacentLoops) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %2 "main" %3 %4 %5 + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %2 "main" + OpName %3 "texcoord" + OpName %6 "buf" + OpMemberName %6 0 "MVP" + OpMemberName %6 1 "position" + OpMemberName %6 2 "attr" + OpName %7 "ubuf" + OpName %4 "gl_VertexIndex" + OpName %8 "gl_PerVertex" + OpMemberName %8 0 "gl_Position" + OpName %5 "" + OpDecorate %3 Location 0 + OpDecorate %9 ArrayStride 16 + OpDecorate %10 ArrayStride 16 + OpMemberDecorate %6 0 ColMajor + OpMemberDecorate %6 0 Offset 0 + OpMemberDecorate %6 0 MatrixStride 16 + OpMemberDecorate %6 1 Offset 64 + OpMemberDecorate %6 2 Offset 1216 + OpDecorate %6 Block + OpDecorate %7 DescriptorSet 0 + OpDecorate %7 Binding 0 + OpDecorate %4 BuiltIn VertexIndex + OpMemberDecorate %8 0 BuiltIn Position + OpDecorate %8 Block + %11 = OpTypeVoid + %12 = OpTypeFunction %11 + %13 = OpTypeFloat 32 + %14 = OpTypeVector %13 4 + %15 = OpTypePointer Output %14 + %3 = OpVariable %15 Output + %16 = OpTypeMatrix %14 4 + %17 = OpTypeInt 32 0 + %18 = OpConstant %17 72 + %9 = OpTypeArray %14 %18 + %10 = OpTypeArray 
%14 %18 + %6 = OpTypeStruct %16 %9 %10 + %19 = OpTypePointer Uniform %6 + %7 = OpVariable %19 Uniform + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 2 + %22 = OpTypePointer Input %20 + %4 = OpVariable %22 Input + %23 = OpTypePointer Uniform %14 + %8 = OpTypeStruct %14 + %24 = OpTypePointer Output %8 + %5 = OpVariable %24 Output + %25 = OpConstant %20 0 + %26 = OpConstant %20 1 + %27 = OpConstant %13 1 + %28 = OpConstantComposite %14 %27 %27 %27 %27 + %2 = OpFunction %11 None %12 + %29 = OpLabel + OpBranch %30 + %30 = OpLabel +; CHECK: OpLoopMerge [[merge1:%[a-zA-Z_\d]+]] + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel +; CHECK: OpSelectionMerge [[sel_merge1:%[a-zA-Z_\d]+]] + OpSelectionMerge %34 None + OpSwitch %26 %35 + %35 = OpLabel + %36 = OpLoad %20 %4 + %37 = OpAccessChain %23 %7 %21 %36 + %38 = OpLoad %14 %37 + OpStore %3 %38 + %39 = OpAccessChain %15 %5 %25 + OpStore %39 %28 + OpBranch %31 +; CHECK: [[sel_merge1]] = OpLabel + %34 = OpLabel +; CHECK-NEXT: OpBranch [[merge1]] + OpUnreachable + %32 = OpLabel + OpBranch %30 + %31 = OpLabel +; CHECK: OpLoopMerge [[merge2:%[a-zA-Z_\d]+]] + OpLoopMerge %40 %41 None + OpBranch %42 + %42 = OpLabel +; CHECK: OpSelectionMerge [[sel_merge2:%[a-zA-Z_\d]+]] + OpSelectionMerge %43 None + OpSwitch %26 %44 + %44 = OpLabel + %45 = OpLoad %20 %4 + %46 = OpAccessChain %23 %7 %21 %45 + %47 = OpLoad %14 %46 + OpStore %3 %47 + %48 = OpAccessChain %15 %5 %25 + OpStore %48 %28 + OpBranch %40 +; CHECK: [[sel_merge2]] = OpLabel + %43 = OpLabel +; CHECK-NEXT: OpBranch [[merge2]] + OpUnreachable + %41 = OpLabel + OpBranch %31 + %40 = OpLabel + OpReturn + OpFunctionEnd)"; + + SinglePassRunAndMatch(text, false); +} + +TEST_F(Workaround1209Test, LeaveUnreachableNotInLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %texcoord %gl_VertexIndex %_ + OpSource GLSL 400 + OpSourceExtension "GL_ARB_separate_shader_objects" + 
OpSourceExtension "GL_ARB_shading_language_420pack" + OpName %main "main" + OpName %texcoord "texcoord" + OpName %buf "buf" + OpMemberName %buf 0 "MVP" + OpMemberName %buf 1 "position" + OpMemberName %buf 2 "attr" + OpName %ubuf "ubuf" + OpName %gl_VertexIndex "gl_VertexIndex" + OpName %gl_PerVertex "gl_PerVertex" + OpMemberName %gl_PerVertex 0 "gl_Position" + OpName %_ "" + OpDecorate %texcoord Location 0 + OpDecorate %_arr_v4float_uint_72 ArrayStride 16 + OpDecorate %_arr_v4float_uint_72_0 ArrayStride 16 + OpMemberDecorate %buf 0 ColMajor + OpMemberDecorate %buf 0 Offset 0 + OpMemberDecorate %buf 0 MatrixStride 16 + OpMemberDecorate %buf 1 Offset 64 + OpMemberDecorate %buf 2 Offset 1216 + OpDecorate %buf Block + OpDecorate %ubuf DescriptorSet 0 + OpDecorate %ubuf Binding 0 + OpDecorate %gl_VertexIndex BuiltIn VertexIndex + OpMemberDecorate %gl_PerVertex 0 BuiltIn Position + OpDecorate %gl_PerVertex Block + %void = OpTypeVoid + %12 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %texcoord = OpVariable %_ptr_Output_v4float Output +%mat4v4float = OpTypeMatrix %v4float 4 + %uint = OpTypeInt 32 0 + %uint_72 = OpConstant %uint 72 +%_arr_v4float_uint_72 = OpTypeArray %v4float %uint_72 +%_arr_v4float_uint_72_0 = OpTypeArray %v4float %uint_72 + %buf = OpTypeStruct %mat4v4float %_arr_v4float_uint_72 %_arr_v4float_uint_72_0 +%_ptr_Uniform_buf = OpTypePointer Uniform %buf + %ubuf = OpVariable %_ptr_Uniform_buf Uniform + %int = OpTypeInt 32 1 + %int_2 = OpConstant %int 2 +%_ptr_Input_int = OpTypePointer Input %int +%gl_VertexIndex = OpVariable %_ptr_Input_int Input +%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float +%gl_PerVertex = OpTypeStruct %v4float +%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex + %_ = OpVariable %_ptr_Output_gl_PerVertex Output + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %float_1 = OpConstant %float 1 + %28 = 
OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1 + %main = OpFunction %void None %12 + %29 = OpLabel + OpBranch %30 + %30 = OpLabel + OpSelectionMerge %34 None + OpSwitch %int_1 %35 + %35 = OpLabel + %36 = OpLoad %int %gl_VertexIndex + %37 = OpAccessChain %_ptr_Uniform_v4float %ubuf %int_2 %36 + %38 = OpLoad %v4float %37 + OpStore %texcoord %38 + %39 = OpAccessChain %_ptr_Output_v4float %_ %int_0 + OpStore %39 %28 + OpReturn + %34 = OpLabel +; CHECK: OpUnreachable + OpUnreachable + OpFunctionEnd)"; + + SinglePassRunAndMatch(text, false); +} + +} // namespace +} // namespace opt +} // namespace spvtools diff --git a/test/parse_number_test.cpp b/test/parse_number_test.cpp new file mode 100644 index 000000000..c99205cf5 --- /dev/null +++ b/test/parse_number_test.cpp @@ -0,0 +1,970 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/util/parse_number.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace utils { +namespace { + +using testing::Eq; +using testing::IsNull; +using testing::NotNull; + +TEST(ParseNarrowSignedIntegers, Sample) { + int16_t i16; + + EXPECT_FALSE(ParseNumber(nullptr, &i16)); + EXPECT_FALSE(ParseNumber("", &i16)); + EXPECT_FALSE(ParseNumber("0=", &i16)); + + EXPECT_TRUE(ParseNumber("0", &i16)); + EXPECT_EQ(0, i16); + EXPECT_TRUE(ParseNumber("32767", &i16)); + EXPECT_EQ(32767, i16); + EXPECT_TRUE(ParseNumber("-32768", &i16)); + EXPECT_EQ(-32768, i16); + EXPECT_TRUE(ParseNumber("-0", &i16)); + EXPECT_EQ(0, i16); + + // These are out of range, so they should return an error. + // The error code depends on whether this is an optional value. + EXPECT_FALSE(ParseNumber("32768", &i16)); + EXPECT_FALSE(ParseNumber("65535", &i16)); + + // Check hex parsing. + EXPECT_TRUE(ParseNumber("0x7fff", &i16)); + EXPECT_EQ(32767, i16); + // This is out of range. + EXPECT_FALSE(ParseNumber("0xffff", &i16)); +} + +TEST(ParseNarrowUnsignedIntegers, Sample) { + uint16_t u16; + + EXPECT_FALSE(ParseNumber(nullptr, &u16)); + EXPECT_FALSE(ParseNumber("", &u16)); + EXPECT_FALSE(ParseNumber("0=", &u16)); + + EXPECT_TRUE(ParseNumber("0", &u16)); + EXPECT_EQ(0, u16); + EXPECT_TRUE(ParseNumber("65535", &u16)); + EXPECT_EQ(65535, u16); + EXPECT_FALSE(ParseNumber("65536", &u16)); + + // We don't care about -0 since it's rejected at a higher level. + EXPECT_FALSE(ParseNumber("-1", &u16)); + EXPECT_TRUE(ParseNumber("0xffff", &u16)); + EXPECT_EQ(0xffff, u16); + EXPECT_FALSE(ParseNumber("0x10000", &u16)); +} + +TEST(ParseSignedIntegers, Sample) { + int32_t i32; + + // Invalid parse. + EXPECT_FALSE(ParseNumber(nullptr, &i32)); + EXPECT_FALSE(ParseNumber("", &i32)); + EXPECT_FALSE(ParseNumber("0=", &i32)); + + // Decimal values. 
+ EXPECT_TRUE(ParseNumber("0", &i32)); + EXPECT_EQ(0, i32); + EXPECT_TRUE(ParseNumber("2147483647", &i32)); + EXPECT_EQ(std::numeric_limits::max(), i32); + EXPECT_FALSE(ParseNumber("2147483648", &i32)); + EXPECT_TRUE(ParseNumber("-0", &i32)); + EXPECT_EQ(0, i32); + EXPECT_TRUE(ParseNumber("-1", &i32)); + EXPECT_EQ(-1, i32); + EXPECT_TRUE(ParseNumber("-2147483648", &i32)); + EXPECT_EQ(std::numeric_limits::min(), i32); + + // Hex values. + EXPECT_TRUE(ParseNumber("0x7fffffff", &i32)); + EXPECT_EQ(std::numeric_limits::max(), i32); + EXPECT_FALSE(ParseNumber("0x80000000", &i32)); + EXPECT_TRUE(ParseNumber("-0x000", &i32)); + EXPECT_EQ(0, i32); + EXPECT_TRUE(ParseNumber("-0x001", &i32)); + EXPECT_EQ(-1, i32); + EXPECT_TRUE(ParseNumber("-0x80000000", &i32)); + EXPECT_EQ(std::numeric_limits::min(), i32); +} + +TEST(ParseUnsignedIntegers, Sample) { + uint32_t u32; + + // Invalid parse. + EXPECT_FALSE(ParseNumber(nullptr, &u32)); + EXPECT_FALSE(ParseNumber("", &u32)); + EXPECT_FALSE(ParseNumber("0=", &u32)); + + // Valid values. + EXPECT_TRUE(ParseNumber("0", &u32)); + EXPECT_EQ(0u, u32); + EXPECT_TRUE(ParseNumber("4294967295", &u32)); + EXPECT_EQ(std::numeric_limits::max(), u32); + EXPECT_FALSE(ParseNumber("4294967296", &u32)); + + // Hex values. + EXPECT_TRUE(ParseNumber("0xffffffff", &u32)); + EXPECT_EQ(std::numeric_limits::max(), u32); + + // We don't care about -0 since it's rejected at a higher level. 
+ EXPECT_FALSE(ParseNumber("-1", &u32)); +} + +TEST(ParseWideSignedIntegers, Sample) { + int64_t i64; + EXPECT_FALSE(ParseNumber(nullptr, &i64)); + EXPECT_FALSE(ParseNumber("", &i64)); + EXPECT_FALSE(ParseNumber("0=", &i64)); + EXPECT_TRUE(ParseNumber("0", &i64)); + EXPECT_EQ(0, i64); + EXPECT_TRUE(ParseNumber("0x7fffffffffffffff", &i64)); + EXPECT_EQ(0x7fffffffffffffff, i64); + EXPECT_TRUE(ParseNumber("-0", &i64)); + EXPECT_EQ(0, i64); + EXPECT_TRUE(ParseNumber("-1", &i64)); + EXPECT_EQ(-1, i64); +} + +TEST(ParseWideUnsignedIntegers, Sample) { + uint64_t u64; + EXPECT_FALSE(ParseNumber(nullptr, &u64)); + EXPECT_FALSE(ParseNumber("", &u64)); + EXPECT_FALSE(ParseNumber("0=", &u64)); + EXPECT_TRUE(ParseNumber("0", &u64)); + EXPECT_EQ(0u, u64); + EXPECT_TRUE(ParseNumber("0xffffffffffffffff", &u64)); + EXPECT_EQ(0xffffffffffffffffULL, u64); + + // We don't care about -0 since it's rejected at a higher level. + EXPECT_FALSE(ParseNumber("-1", &u64)); +} + +TEST(ParseFloat, Sample) { + float f; + + EXPECT_FALSE(ParseNumber(nullptr, &f)); + EXPECT_FALSE(ParseNumber("", &f)); + EXPECT_FALSE(ParseNumber("0=", &f)); + + // These values are exactly representatble. + EXPECT_TRUE(ParseNumber("0", &f)); + EXPECT_EQ(0.0f, f); + EXPECT_TRUE(ParseNumber("42", &f)); + EXPECT_EQ(42.0f, f); + EXPECT_TRUE(ParseNumber("2.5", &f)); + EXPECT_EQ(2.5f, f); + EXPECT_TRUE(ParseNumber("-32.5", &f)); + EXPECT_EQ(-32.5f, f); + EXPECT_TRUE(ParseNumber("1e38", &f)); + EXPECT_EQ(1e38f, f); + EXPECT_TRUE(ParseNumber("-1e38", &f)); + EXPECT_EQ(-1e38f, f); +} + +TEST(ParseFloat, Overflow) { + // The assembler parses using HexFloat>. Make + // sure that succeeds for in-range values, and fails for out of + // range values. When it does overflow, the value is set to the + // nearest finite value, matching C++11 behavior for operator>> + // on floating point. 
+ HexFloat> f(0.0f); + + EXPECT_TRUE(ParseNumber("1e38", &f)); + EXPECT_EQ(1e38f, f.value().getAsFloat()); + EXPECT_TRUE(ParseNumber("-1e38", &f)); + EXPECT_EQ(-1e38f, f.value().getAsFloat()); + EXPECT_FALSE(ParseNumber("1e40", &f)); + EXPECT_FALSE(ParseNumber("-1e40", &f)); + EXPECT_FALSE(ParseNumber("1e400", &f)); + EXPECT_FALSE(ParseNumber("-1e400", &f)); +} + +TEST(ParseDouble, Sample) { + double f; + + EXPECT_FALSE(ParseNumber(nullptr, &f)); + EXPECT_FALSE(ParseNumber("", &f)); + EXPECT_FALSE(ParseNumber("0=", &f)); + + // These values are exactly representatble. + EXPECT_TRUE(ParseNumber("0", &f)); + EXPECT_EQ(0.0, f); + EXPECT_TRUE(ParseNumber("42", &f)); + EXPECT_EQ(42.0, f); + EXPECT_TRUE(ParseNumber("2.5", &f)); + EXPECT_EQ(2.5, f); + EXPECT_TRUE(ParseNumber("-32.5", &f)); + EXPECT_EQ(-32.5, f); + EXPECT_TRUE(ParseNumber("1e38", &f)); + EXPECT_EQ(1e38, f); + EXPECT_TRUE(ParseNumber("-1e38", &f)); + EXPECT_EQ(-1e38, f); + // These are out of range for 32-bit float, but in range for 64-bit float. + EXPECT_TRUE(ParseNumber("1e40", &f)); + EXPECT_EQ(1e40, f); + EXPECT_TRUE(ParseNumber("-1e40", &f)); + EXPECT_EQ(-1e40, f); +} + +TEST(ParseDouble, Overflow) { + // The assembler parses using HexFloat>. Make + // sure that succeeds for in-range values, and fails for out of + // range values. When it does overflow, the value is set to the + // nearest finite value, matching C++11 behavior for operator>> + // on floating point. + HexFloat> f(0.0); + + EXPECT_TRUE(ParseNumber("1e38", &f)); + EXPECT_EQ(1e38, f.value().getAsFloat()); + EXPECT_TRUE(ParseNumber("-1e38", &f)); + EXPECT_EQ(-1e38, f.value().getAsFloat()); + EXPECT_TRUE(ParseNumber("1e40", &f)); + EXPECT_EQ(1e40, f.value().getAsFloat()); + EXPECT_TRUE(ParseNumber("-1e40", &f)); + EXPECT_EQ(-1e40, f.value().getAsFloat()); + EXPECT_FALSE(ParseNumber("1e400", &f)); + EXPECT_FALSE(ParseNumber("-1e400", &f)); +} + +TEST(ParseFloat16, Overflow) { + // The assembler parses using HexFloat>. 
Make + // sure that succeeds for in-range values, and fails for out of + // range values. When it does overflow, the value is set to the + // nearest finite value, matching C++11 behavior for operator>> + // on floating point. + HexFloat> f(0); + + EXPECT_FALSE(ParseNumber(nullptr, &f)); + EXPECT_TRUE(ParseNumber("-0.0", &f)); + EXPECT_EQ(uint16_t{0x8000}, f.value().getAsFloat().get_value()); + EXPECT_TRUE(ParseNumber("1.0", &f)); + EXPECT_EQ(uint16_t{0x3c00}, f.value().getAsFloat().get_value()); + + // Overflows 16-bit but not 32-bit + EXPECT_FALSE(ParseNumber("1e38", &f)); + EXPECT_FALSE(ParseNumber("-1e38", &f)); + + // Overflows 32-bit but not 64-bit + EXPECT_FALSE(ParseNumber("1e40", &f)); + EXPECT_FALSE(ParseNumber("-1e40", &f)); + + // Overflows 64-bit + EXPECT_FALSE(ParseNumber("1e400", &f)); + EXPECT_FALSE(ParseNumber("-1e400", &f)); +} + +void AssertEmitFunc(uint32_t) { + ASSERT_FALSE(true) + << "Should not call emit() function when the number can not be parsed."; + return; +} + +TEST(ParseAndEncodeNarrowSignedIntegers, Invalid) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {16, SPV_NUMBER_SIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber(nullptr, type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("The given text is a nullptr", err_msg); + rc = ParseAndEncodeIntegerNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: ", err_msg); + rc = ParseAndEncodeIntegerNumber("=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: =", err_msg); + rc = ParseAndEncodeIntegerNumber("-", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid signed integer literal: -", err_msg); + rc = ParseAndEncodeIntegerNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: 0=", err_msg); +} + +TEST(ParseAndEncodeNarrowSignedIntegers, Overflow) { + // The error message should be overwritten after each parsing call. + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {16, SPV_NUMBER_SIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber("32768", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Integer 32768 does not fit in a 16-bit signed integer", err_msg); + rc = ParseAndEncodeIntegerNumber("-32769", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Integer -32769 does not fit in a 16-bit signed integer", err_msg); +} + +TEST(ParseAndEncodeNarrowSignedIntegers, Success) { + // Don't care the error message in this case. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kInvalidText; + NumberType type = {16, SPV_NUMBER_SIGNED_INT}; + + // Zero, maximum, and minimum value + rc = ParseAndEncodeIntegerNumber( + "0", type, [](uint32_t word) { EXPECT_EQ(0u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "-0", type, [](uint32_t word) { EXPECT_EQ(0u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "32767", type, [](uint32_t word) { EXPECT_EQ(0x00007fffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "-32768", type, [](uint32_t word) { EXPECT_EQ(0xffff8000u, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + + // Hex parsing + rc = ParseAndEncodeIntegerNumber( + "0x7fff", type, [](uint32_t word) { EXPECT_EQ(0x00007fffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "0xffff", type, [](uint32_t word) { EXPECT_EQ(0xffffffffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); +} + +TEST(ParseAndEncodeNarrowUnsignedIntegers, Invalid) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {16, SPV_NUMBER_UNSIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber(nullptr, type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("The given text is a nullptr", err_msg); + rc = ParseAndEncodeIntegerNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: ", err_msg); + rc = ParseAndEncodeIntegerNumber("=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: =", err_msg); + rc = ParseAndEncodeIntegerNumber("-", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); + rc = ParseAndEncodeIntegerNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: 0=", err_msg); + rc = ParseAndEncodeIntegerNumber("-0", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); + rc = ParseAndEncodeIntegerNumber("-1", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); +} + +TEST(ParseAndEncodeNarrowUnsignedIntegers, Overflow) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg("random content"); + NumberType type = {16, SPV_NUMBER_UNSIGNED_INT}; + + // Overflow + rc = ParseAndEncodeIntegerNumber("65536", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Integer 65536 does not fit in a 16-bit unsigned integer", err_msg); +} + +TEST(ParseAndEncodeNarrowUnsignedIntegers, Success) { + // Don't care the error message in this case. + EncodeNumberStatus rc = EncodeNumberStatus::kInvalidText; + NumberType type = {16, SPV_NUMBER_UNSIGNED_INT}; + + // Zero, maximum, and minimum value + rc = ParseAndEncodeIntegerNumber( + "0", type, [](uint32_t word) { EXPECT_EQ(0u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "65535", type, [](uint32_t word) { EXPECT_EQ(0x0000ffffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + + // Hex parsing + rc = ParseAndEncodeIntegerNumber( + "0xffff", type, [](uint32_t word) { EXPECT_EQ(0x0000ffffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); +} + +TEST(ParseAndEncodeSignedIntegers, Invalid) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {32, SPV_NUMBER_SIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber(nullptr, type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("The given text is a nullptr", err_msg); + rc = ParseAndEncodeIntegerNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: ", err_msg); + rc = ParseAndEncodeIntegerNumber("=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: =", err_msg); + rc = ParseAndEncodeIntegerNumber("-", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid signed integer literal: -", err_msg); + rc = ParseAndEncodeIntegerNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: 0=", err_msg); +} + +TEST(ParseAndEncodeSignedIntegers, Overflow) { + // The error message should be overwritten after each parsing call. + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {32, SPV_NUMBER_SIGNED_INT}; + + rc = + ParseAndEncodeIntegerNumber("2147483648", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Integer 2147483648 does not fit in a 32-bit signed integer", + err_msg); + rc = ParseAndEncodeIntegerNumber("-2147483649", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Integer -2147483649 does not fit in a 32-bit signed integer", + err_msg); +} + +TEST(ParseAndEncodeSignedIntegers, Success) { + // Don't care the error message in this case. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kInvalidText; + NumberType type = {32, SPV_NUMBER_SIGNED_INT}; + + // Zero, maximum, and minimum value + rc = ParseAndEncodeIntegerNumber( + "0", type, [](uint32_t word) { EXPECT_EQ(0u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "-0", type, [](uint32_t word) { EXPECT_EQ(0u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "2147483647", type, [](uint32_t word) { EXPECT_EQ(0x7fffffffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "-2147483648", type, [](uint32_t word) { EXPECT_EQ(0x80000000u, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + + // Hex parsing + rc = ParseAndEncodeIntegerNumber( + "0x7fffffff", type, [](uint32_t word) { EXPECT_EQ(0x7fffffffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "0xffffffff", type, [](uint32_t word) { EXPECT_EQ(0xffffffffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); +} + +TEST(ParseAndEncodeUnsignedIntegers, Invalid) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {32, SPV_NUMBER_UNSIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber(nullptr, type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("The given text is a nullptr", err_msg); + rc = ParseAndEncodeIntegerNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: ", err_msg); + rc = ParseAndEncodeIntegerNumber("=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: =", err_msg); + rc = ParseAndEncodeIntegerNumber("-", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); + rc = ParseAndEncodeIntegerNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: 0=", err_msg); + rc = ParseAndEncodeIntegerNumber("-0", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); + rc = ParseAndEncodeIntegerNumber("-1", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); +} + +TEST(ParseAndEncodeUnsignedIntegers, Overflow) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg("random content"); + NumberType type = {32, SPV_NUMBER_UNSIGNED_INT}; + + // Overflow + rc = + ParseAndEncodeIntegerNumber("4294967296", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Integer 4294967296 does not fit in a 32-bit unsigned integer", + err_msg); +} + +TEST(ParseAndEncodeUnsignedIntegers, Success) { + // Don't care the error message in this case. + EncodeNumberStatus rc = EncodeNumberStatus::kInvalidText; + NumberType type = {32, SPV_NUMBER_UNSIGNED_INT}; + + // Zero, maximum, and minimum value + rc = ParseAndEncodeIntegerNumber( + "0", type, [](uint32_t word) { EXPECT_EQ(0u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeIntegerNumber( + "4294967295", type, [](uint32_t word) { EXPECT_EQ(0xffffffffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + + // Hex parsing + rc = ParseAndEncodeIntegerNumber( + "0xffffffff", type, [](uint32_t word) { EXPECT_EQ(0xffffffffu, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); +} + +TEST(ParseAndEncodeWideSignedIntegers, Invalid) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {64, SPV_NUMBER_SIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber(nullptr, type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("The given text is a nullptr", err_msg); + rc = ParseAndEncodeIntegerNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: ", err_msg); + rc = ParseAndEncodeIntegerNumber("=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: =", err_msg); + rc = ParseAndEncodeIntegerNumber("-", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid signed integer literal: -", err_msg); + rc = ParseAndEncodeIntegerNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: 0=", err_msg); +} + +TEST(ParseAndEncodeWideSignedIntegers, Overflow) { + // The error message should be overwritten after each parsing call. + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {64, SPV_NUMBER_SIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber("9223372036854775808", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ( + "Integer 9223372036854775808 does not fit in a 64-bit signed integer", + err_msg); + rc = ParseAndEncodeIntegerNumber("-9223372036854775809", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid signed integer literal: -9223372036854775809", err_msg); +} + +TEST(ParseAndEncodeWideSignedIntegers, Success) { + // Don't care the error message in this case. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kInvalidText; + NumberType type = {64, SPV_NUMBER_SIGNED_INT}; + std::vector word_buffer; + auto emit = [&word_buffer](uint32_t word) { + if (word_buffer.size() == 2) word_buffer.clear(); + word_buffer.push_back(word); + }; + + // Zero, maximum, and minimum value + rc = ParseAndEncodeIntegerNumber("0", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 0u})); + rc = ParseAndEncodeIntegerNumber("-0", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 0u})); + rc = ParseAndEncodeIntegerNumber("9223372036854775807", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0xffffffffu, 0x7fffffffu})); + rc = ParseAndEncodeIntegerNumber("-9223372036854775808", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 0x80000000u})); + rc = ParseAndEncodeIntegerNumber("-1", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0xffffffffu, 0xffffffffu})); + + // Hex parsing + rc = ParseAndEncodeIntegerNumber("0x7fffffffffffffff", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0xffffffffu, 0x7fffffffu})); + rc = ParseAndEncodeIntegerNumber("0xffffffffffffffff", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0xffffffffu, 0xffffffffu})); +} + +TEST(ParseAndEncodeWideUnsignedIntegers, Invalid) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {64, SPV_NUMBER_UNSIGNED_INT}; + + // Invalid + rc = ParseAndEncodeIntegerNumber(nullptr, type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("The given text is a nullptr", err_msg); + rc = ParseAndEncodeIntegerNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: ", err_msg); + rc = ParseAndEncodeIntegerNumber("=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: =", err_msg); + rc = ParseAndEncodeIntegerNumber("-", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); + rc = ParseAndEncodeIntegerNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: 0=", err_msg); + rc = ParseAndEncodeIntegerNumber("-0", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); + rc = ParseAndEncodeIntegerNumber("-1", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("Cannot put a negative number in an unsigned literal", err_msg); +} + +TEST(ParseAndEncodeWideUnsignedIntegers, Overflow) { + // The error message should be overwritten after each parsing call. 
+ EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {64, SPV_NUMBER_UNSIGNED_INT}; + + // Overflow + rc = ParseAndEncodeIntegerNumber("18446744073709551616", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: 18446744073709551616", err_msg); +} + +TEST(ParseAndEncodeWideUnsignedIntegers, Success) { + // Don't care the error message in this case. + EncodeNumberStatus rc = EncodeNumberStatus::kInvalidText; + NumberType type = {64, SPV_NUMBER_UNSIGNED_INT}; + std::vector word_buffer; + auto emit = [&word_buffer](uint32_t word) { + if (word_buffer.size() == 2) word_buffer.clear(); + word_buffer.push_back(word); + }; + + // Zero, maximum, and minimum value + rc = ParseAndEncodeIntegerNumber("0", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 0u})); + rc = ParseAndEncodeIntegerNumber("18446744073709551615", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0xffffffffu, 0xffffffffu})); + + // Hex parsing + rc = ParseAndEncodeIntegerNumber("0xffffffffffffffff", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0xffffffffu, 0xffffffffu})); +} + +TEST(ParseAndEncodeIntegerNumber, TypeNone) { + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {32, SPV_NUMBER_NONE}; + + rc = ParseAndEncodeIntegerNumber( + "0.0", type, [](uint32_t word) { EXPECT_EQ(0x0u, word); }, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("The expected type is not a integer type", err_msg); +} + +TEST(ParseAndEncodeIntegerNumber, InvalidCaseWithoutErrorMessageString) { + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + NumberType type = {32, SPV_NUMBER_SIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber("invalid", 
type, AssertEmitFunc, nullptr); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); +} + +TEST(ParseAndEncodeIntegerNumber, DoNotTouchErrorMessageStringOnSuccess) { + EncodeNumberStatus rc = EncodeNumberStatus::kInvalidText; + std::string err_msg("random content"); + NumberType type = {32, SPV_NUMBER_SIGNED_INT}; + + rc = ParseAndEncodeIntegerNumber( + "100", type, [](uint32_t word) { EXPECT_EQ(100u, word); }, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_EQ("random content", err_msg); +} + +TEST(ParseAndEncodeFloat, Sample) { + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {32, SPV_NUMBER_FLOATING}; + + // Invalid + rc = ParseAndEncodeFloatingPointNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 32-bit float literal: ", err_msg); + rc = ParseAndEncodeFloatingPointNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 32-bit float literal: 0=", err_msg); + + // Representative samples + rc = ParseAndEncodeFloatingPointNumber( + "0.0", type, [](uint32_t word) { EXPECT_EQ(0x0u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "-0.0", type, [](uint32_t word) { EXPECT_EQ(0x80000000u, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "42", type, [](uint32_t word) { EXPECT_EQ(0x42280000u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "2.5", type, [](uint32_t word) { EXPECT_EQ(0x40200000u, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "-32.5", type, [](uint32_t word) { EXPECT_EQ(0xc2020000u, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "1e38", type, 
[](uint32_t word) { EXPECT_EQ(0x7e967699u, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "-1e38", type, [](uint32_t word) { EXPECT_EQ(0xfe967699u, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + + // Overflow + rc = + ParseAndEncodeFloatingPointNumber("1e40", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 32-bit float literal: 1e40", err_msg); + rc = ParseAndEncodeFloatingPointNumber("-1e40", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 32-bit float literal: -1e40", err_msg); + rc = ParseAndEncodeFloatingPointNumber("1e400", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 32-bit float literal: 1e400", err_msg); + rc = ParseAndEncodeFloatingPointNumber("-1e400", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 32-bit float literal: -1e400", err_msg); +} + +TEST(ParseAndEncodeDouble, Sample) { + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {64, SPV_NUMBER_FLOATING}; + std::vector word_buffer; + auto emit = [&word_buffer](uint32_t word) { + if (word_buffer.size() == 2) word_buffer.clear(); + word_buffer.push_back(word); + }; + + // Invalid + rc = ParseAndEncodeFloatingPointNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 64-bit float literal: ", err_msg); + rc = ParseAndEncodeFloatingPointNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 64-bit float literal: 0=", err_msg); + + // Representative samples + rc = ParseAndEncodeFloatingPointNumber("0.0", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 
0u})); + rc = ParseAndEncodeFloatingPointNumber("-0.0", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 0x80000000u})); + rc = ParseAndEncodeFloatingPointNumber("42", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 0x40450000u})); + rc = ParseAndEncodeFloatingPointNumber("2.5", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 0x40040000u})); + rc = ParseAndEncodeFloatingPointNumber("32.5", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0u, 0x40404000u})); + rc = ParseAndEncodeFloatingPointNumber("1e38", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0x2a16a1b1u, 0x47d2ced3u})); + rc = ParseAndEncodeFloatingPointNumber("-1e38", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0x2a16a1b1u, 0xc7d2ced3u})); + rc = ParseAndEncodeFloatingPointNumber("1e40", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0xf1c35ca5u, 0x483d6329u})); + rc = ParseAndEncodeFloatingPointNumber("-1e40", type, emit, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_THAT(word_buffer, Eq(std::vector{0xf1c35ca5u, 0xc83d6329u})); + + // Overflow + rc = ParseAndEncodeFloatingPointNumber("1e400", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 64-bit float literal: 1e400", err_msg); + rc = ParseAndEncodeFloatingPointNumber("-1e400", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 64-bit float literal: -1e400", err_msg); +} + +TEST(ParseAndEncodeFloat16, Sample) { + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + 
std::string err_msg; + NumberType type = {16, SPV_NUMBER_FLOATING}; + + // Invalid + rc = ParseAndEncodeFloatingPointNumber("", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 16-bit float literal: ", err_msg); + rc = ParseAndEncodeFloatingPointNumber("0=", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 16-bit float literal: 0=", err_msg); + + // Representative samples + rc = ParseAndEncodeFloatingPointNumber( + "0.0", type, [](uint32_t word) { EXPECT_EQ(0x0u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "-0.0", type, [](uint32_t word) { EXPECT_EQ(0x8000u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "1.0", type, [](uint32_t word) { EXPECT_EQ(0x3c00u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "2.5", type, [](uint32_t word) { EXPECT_EQ(0x4100u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + rc = ParseAndEncodeFloatingPointNumber( + "32.5", type, [](uint32_t word) { EXPECT_EQ(0x5010u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + + // Overflow + rc = + ParseAndEncodeFloatingPointNumber("1e38", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 16-bit float literal: 1e38", err_msg); + rc = ParseAndEncodeFloatingPointNumber("-1e38", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 16-bit float literal: -1e38", err_msg); + rc = + ParseAndEncodeFloatingPointNumber("1e40", type, AssertEmitFunc, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 16-bit float literal: 1e40", err_msg); + rc = ParseAndEncodeFloatingPointNumber("-1e40", type, AssertEmitFunc, + &err_msg); + 
EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 16-bit float literal: -1e40", err_msg); + rc = ParseAndEncodeFloatingPointNumber("1e400", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 16-bit float literal: 1e400", err_msg); + rc = ParseAndEncodeFloatingPointNumber("-1e400", type, AssertEmitFunc, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid 16-bit float literal: -1e400", err_msg); +} + +TEST(ParseAndEncodeFloatingPointNumber, TypeNone) { + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {32, SPV_NUMBER_NONE}; + + rc = ParseAndEncodeFloatingPointNumber( + "0.0", type, [](uint32_t word) { EXPECT_EQ(0x0u, word); }, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kInvalidUsage, rc); + EXPECT_EQ("The expected type is not a float type", err_msg); +} + +TEST(ParseAndEncodeFloatingPointNumber, InvalidCaseWithoutErrorMessageString) { + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + NumberType type = {32, SPV_NUMBER_FLOATING}; + + rc = ParseAndEncodeFloatingPointNumber("invalid", type, AssertEmitFunc, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); +} + +TEST(ParseAndEncodeFloatingPointNumber, DoNotTouchErrorMessageStringOnSuccess) { + EncodeNumberStatus rc = EncodeNumberStatus::kInvalidText; + std::string err_msg("random content"); + NumberType type = {32, SPV_NUMBER_FLOATING}; + + rc = ParseAndEncodeFloatingPointNumber( + "0.0", type, [](uint32_t word) { EXPECT_EQ(0x0u, word); }, &err_msg); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_EQ("random content", err_msg); +} + +TEST(ParseAndEncodeNumber, Sample) { + EncodeNumberStatus rc = EncodeNumberStatus::kSuccess; + std::string err_msg; + NumberType type = {32, SPV_NUMBER_SIGNED_INT}; + + // Invalid with error message string + rc = ParseAndEncodeNumber("something wrong", type, AssertEmitFunc, &err_msg); + 
EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + EXPECT_EQ("Invalid unsigned integer literal: something wrong", err_msg); + + // Invalid without error message string + rc = ParseAndEncodeNumber("something wrong", type, AssertEmitFunc, nullptr); + EXPECT_EQ(EncodeNumberStatus::kInvalidText, rc); + + // Signed integer, should not touch the error message string. + err_msg = "random content"; + rc = ParseAndEncodeNumber("-1", type, + [](uint32_t word) { EXPECT_EQ(0xffffffffu, word); }, + &err_msg); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + EXPECT_EQ("random content", err_msg); + + // Unsigned integer + type = {32, SPV_NUMBER_UNSIGNED_INT}; + rc = ParseAndEncodeNumber( + "1", type, [](uint32_t word) { EXPECT_EQ(1u, word); }, nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); + + // Float + type = {32, SPV_NUMBER_FLOATING}; + rc = ParseAndEncodeNumber("-1.0", type, + [](uint32_t word) { EXPECT_EQ(0xbf800000, word); }, + nullptr); + EXPECT_EQ(EncodeNumberStatus::kSuccess, rc); +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/test/pch_test.cpp b/test/pch_test.cpp new file mode 100644 index 000000000..3b06a0aa2 --- /dev/null +++ b/test/pch_test.cpp @@ -0,0 +1,15 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "pch_test.h" diff --git a/test/pch_test.h b/test/pch_test.h new file mode 100644 index 000000000..7dac06acf --- /dev/null +++ b/test/pch_test.h @@ -0,0 +1,18 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "source/spirv_constant.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" diff --git a/test/preserve_numeric_ids_test.cpp b/test/preserve_numeric_ids_test.cpp new file mode 100644 index 000000000..1c3354d55 --- /dev/null +++ b/test/preserve_numeric_ids_test.cpp @@ -0,0 +1,159 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for preserving numeric ids in a text/binary round trip. 
+ +#include <string> + +#include "source/text.h" +#include "source/text_handler.h" +#include "test/test_fixture.h" + +namespace spvtools { +namespace { + +using spvtest::ScopedContext; + +// Converts code to binary and back to text; sets |after| only on success. +spv_result_t ToBinaryAndBack( + const std::string& before, std::string* after, + uint32_t text_to_binary_options = SPV_TEXT_TO_BINARY_OPTION_NONE, + uint32_t binary_to_text_options = SPV_BINARY_TO_TEXT_OPTION_NONE, + spv_target_env env = SPV_ENV_UNIVERSAL_1_0) { + ScopedContext ctx(env); + spv_binary binary; + spv_text text; + + spv_result_t result = + spvTextToBinaryWithOptions(ctx.context, before.c_str(), before.size(), + text_to_binary_options, &binary, nullptr); + if (result != SPV_SUCCESS) { + return result; + } + + result = spvBinaryToText(ctx.context, binary->code, binary->wordCount, + binary_to_text_options, &text, nullptr); + spvBinaryDestroy(binary);  // Free before the early return so it cannot leak. + if (result != SPV_SUCCESS) { + return result; + } + + *after = std::string(text->str, text->length); + + spvTextDestroy(text); + + return SPV_SUCCESS; +} + +TEST(ToBinaryAndBack, DontPreserveNumericIds) { + const std::string before = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +%i32 = OpTypeInt 32 1 +%u32 = OpTypeInt 32 0 +%f32 = OpTypeFloat 32 +%200 = OpTypeVoid +%300 = OpTypeFunction %200 +%main = OpFunction %200 None %300 +%entry = OpLabel +%100 = OpConstant %u32 100 +%1 = OpConstant %u32 200 +%2 = OpConstant %u32 300 +OpReturn +OpFunctionEnd +)"; + + const std::string expected = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +%1 = OpTypeInt 32 1 +%2 = OpTypeInt 32 0 +%3 = OpTypeFloat 32 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +%8 = OpConstant %2 100 +%9 = OpConstant %2 200 +%10 = OpConstant %2 300 +OpReturn +OpFunctionEnd +)"; + + std::string after; +
EXPECT_EQ(SPV_SUCCESS, + ToBinaryAndBack(before, &after, SPV_TEXT_TO_BINARY_OPTION_NONE, + SPV_BINARY_TO_TEXT_OPTION_NO_HEADER)); + + EXPECT_EQ(expected, after); +} + +TEST(TextHandler, PreserveNumericIds) { + const std::string before = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +%i32 = OpTypeInt 32 1 +%u32 = OpTypeInt 32 0 +%f32 = OpTypeFloat 32 +%200 = OpTypeVoid +%300 = OpTypeFunction %200 +%main = OpFunction %200 None %300 +%entry = OpLabel +%100 = OpConstant %u32 100 +%1 = OpConstant %u32 200 +%2 = OpConstant %u32 300 +OpReturn +OpFunctionEnd +)"; + + const std::string expected = + R"(OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +%3 = OpTypeInt 32 1 +%4 = OpTypeInt 32 0 +%5 = OpTypeFloat 32 +%200 = OpTypeVoid +%300 = OpTypeFunction %200 +%6 = OpFunction %200 None %300 +%7 = OpLabel +%100 = OpConstant %4 100 +%1 = OpConstant %4 200 +%2 = OpConstant %4 300 +OpReturn +OpFunctionEnd +)"; + + std::string after; + EXPECT_EQ(SPV_SUCCESS, + ToBinaryAndBack(before, &after, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS, + SPV_BINARY_TO_TEXT_OPTION_NO_HEADER)); + + EXPECT_EQ(expected, after); +} + +} // namespace +} // namespace spvtools diff --git a/test/reduce/CMakeLists.txt b/test/reduce/CMakeLists.txt new file mode 100644 index 000000000..b35cdb260 --- /dev/null +++ b/test/reduce/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +add_spvtools_unittest(TARGET reduce + SRCS operand_to_constant_reduction_pass_test.cpp + operand_to_undef_reduction_pass_test.cpp + operand_to_dominating_id_reduction_pass_test.cpp + reduce_test_util.cpp + reduce_test_util.h + reducer_test.cpp + remove_opname_instruction_reduction_pass_test.cpp + remove_unreferenced_instruction_reduction_pass_test.cpp + structured_loop_to_selection_reduction_pass_test.cpp + validation_during_reduction_test.cpp + LIBS SPIRV-Tools-reduce # NOTE(review): presumably linked into the test binary; confirm in add_spvtools_unittest + ) + +diff --git a/test/reduce/operand_to_constant_reduction_pass_test.cpp b/test/reduce/operand_to_constant_reduction_pass_test.cpp new file mode 100644 index 000000000..34cc4a117 --- /dev/null +++ b/test/reduce/operand_to_constant_reduction_pass_test.cpp @@ -0,0 +1,156 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "reduce_test_util.h" +#include "source/opt/build_module.h" +#include "source/reduce/operand_to_const_reduction_pass.h" + +namespace spvtools { +namespace reduce { +namespace { + +TEST(OperandToConstantReductionPassTest, BasicCheck) { + std::string prologue = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %37 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %9 "buf1" + OpMemberName %9 0 "f" + OpName %11 "" + OpName %24 "buf2" + OpMemberName %24 0 "i" + OpName %26 "" + OpName %37 "_GLF_color" + OpMemberDecorate %9 0 Offset 0 + OpDecorate %9 Block + OpDecorate %11 DescriptorSet 0 + OpDecorate %11 Binding 1 + OpMemberDecorate %24 0 Offset 0 + OpDecorate %24 Block + OpDecorate %26 DescriptorSet 0 + OpDecorate %26 Binding 2 + OpDecorate %37 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %9 = OpTypeStruct %6 + %10 = OpTypePointer Uniform %9 + %11 = OpVariable %10 Uniform + %12 = OpTypeInt 32 1 + %13 = OpConstant %12 0 + %14 = OpTypePointer Uniform %6 + %20 = OpConstant %6 2 + %24 = OpTypeStruct %12 + %25 = OpTypePointer Uniform %24 + %26 = OpVariable %25 Uniform + %27 = OpTypePointer Uniform %12 + %33 = OpConstant %12 3 + %35 = OpTypeVector %6 4 + %36 = OpTypePointer Output %35 + %37 = OpVariable %36 Output + %4 = OpFunction %2 None %3 + %5 = OpLabel + %15 = OpAccessChain %14 %11 %13 + %16 = OpLoad %6 %15 + %19 = OpFAdd %6 %16 %16 + %21 = OpFAdd %6 %19 %20 + %28 = OpAccessChain %27 %26 %13 + %29 = OpLoad %12 %28 + )"; + + std::string epilogue = R"( + %45 = OpConvertSToF %6 %34 + %46 = OpCompositeConstruct %35 %16 %21 %43 %45 + OpStore %37 %46 + OpReturn + OpFunctionEnd + )"; + + std::string original = prologue + R"( + %32 = OpIAdd %12 %29 %29 + %34 = OpIAdd %12 %32 %33 + %43 = OpConvertSToF %6 %29 + )" + epilogue; + + std::string expected = prologue + R"( + %32 = OpIAdd %12 %13 %13 ; %29 -> %13 x 2 + %34 = OpIAdd
%12 %13 %33 ; %32 -> %13 + %43 = OpConvertSToF %6 %13 ; %29 -> %13 + )" + epilogue; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, original, kReduceAssembleOption); + const auto pass = TestSubclass<OperandToConstReductionPass>(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(17, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + ASSERT_TRUE(ops[2]->PreconditionHolds()); + ops[2]->TryToApply(); + ASSERT_TRUE(ops[3]->PreconditionHolds()); + ops[3]->TryToApply(); + + CheckEqual(env, expected, context.get()); +} + +TEST(OperandToConstantReductionPassTest, WithCalledFunction) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %10 %12 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypeFunction %7 + %9 = OpTypePointer Output %7 + %10 = OpVariable %9 Output + %11 = OpTypePointer Input %7 + %12 = OpVariable %11 Input + %13 = OpConstant %6 0 + %14 = OpConstantComposite %7 %13 %13 %13 %13 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %15 = OpFunctionCall %7 %16 + OpReturn + OpFunctionEnd + %16 = OpFunction %7 None %8 + %17 = OpLabel + OpReturnValue %14 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, shader, kReduceAssembleOption); + const auto pass = TestSubclass<OperandToConstReductionPass>(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(0, ops.size()); +} + +} // namespace +} // namespace reduce +} // namespace spvtools diff --git a/test/reduce/operand_to_dominating_id_reduction_pass_test.cpp b/test/reduce/operand_to_dominating_id_reduction_pass_test.cpp new file
mode 100644 index 000000000..cc0de65cd --- /dev/null +++ b/test/reduce/operand_to_dominating_id_reduction_pass_test.cpp @@ -0,0 +1,196 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/reduce/operand_to_dominating_id_reduction_pass.h" +#include "reduce_test_util.h" +#include "source/opt/build_module.h" + +namespace spvtools { +namespace reduce { +namespace { + +TEST(OperandToDominatingIdReductionPassTest, BasicCheck) { + std::string original = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %14 = OpVariable %7 Function + OpStore %8 %9 + %11 = OpLoad %6 %8 + %12 = OpLoad %6 %8 + %13 = OpIAdd %6 %11 %12 + OpStore %10 %13 + %15 = OpLoad %6 %10 + OpStore %14 %15 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, original, kReduceAssembleOption); + const auto pass = TestSubclass<OperandToDominatingIdReductionPass>(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(10, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); +
ops[0]->TryToApply(); + + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %14 = OpVariable %7 Function + OpStore %8 %9 + %11 = OpLoad %6 %8 + %12 = OpLoad %6 %8 + %13 = OpIAdd %6 %11 %12 + OpStore %8 %13 ; %10 -> %8 + %15 = OpLoad %6 %10 + OpStore %14 %15 + OpReturn + OpFunctionEnd + )"; + + CheckEqual(env, after_op_0, context.get()); + + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + + std::string after_op_1 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %14 = OpVariable %7 Function + OpStore %8 %9 + %11 = OpLoad %6 %8 + %12 = OpLoad %6 %8 + %13 = OpIAdd %6 %11 %12 + OpStore %8 %13 ; %10 -> %8 + %15 = OpLoad %6 %8 ; %10 -> %8 + OpStore %14 %15 + OpReturn + OpFunctionEnd + )"; + + CheckEqual(env, after_op_1, context.get()); + + ASSERT_TRUE(ops[2]->PreconditionHolds()); + ops[2]->TryToApply(); + + std::string after_op_2 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function +
%10 = OpVariable %7 Function + %14 = OpVariable %7 Function + OpStore %8 %9 + %11 = OpLoad %6 %8 + %12 = OpLoad %6 %8 + %13 = OpIAdd %6 %11 %12 + OpStore %8 %13 ; %10 -> %8 + %15 = OpLoad %6 %8 ; %10 -> %8 + OpStore %8 %15 ; %14 -> %8 + OpReturn + OpFunctionEnd + )"; + + CheckEqual(env, after_op_2, context.get()); + + // The precondition has been disabled by an earlier opportunity's application. + ASSERT_FALSE(ops[3]->PreconditionHolds()); + + ASSERT_TRUE(ops[4]->PreconditionHolds()); + ops[4]->TryToApply(); + + std::string after_op_4 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %14 = OpVariable %7 Function + OpStore %8 %9 + %11 = OpLoad %6 %8 + %12 = OpLoad %6 %8 + %13 = OpIAdd %6 %11 %11 ; %12 -> %11 + OpStore %8 %13 ; %10 -> %8 + %15 = OpLoad %6 %8 ; %10 -> %8 + OpStore %8 %15 ; %14 -> %8 + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_4, context.get()); +} + +} // namespace +} // namespace reduce +} // namespace spvtools diff --git a/test/reduce/operand_to_undef_reduction_pass_test.cpp b/test/reduce/operand_to_undef_reduction_pass_test.cpp new file mode 100644 index 000000000..71bf96cf4 --- /dev/null +++ b/test/reduce/operand_to_undef_reduction_pass_test.cpp @@ -0,0 +1,226 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/reduce/operand_to_undef_reduction_pass.h" +#include "source/opt/build_module.h" +#include "test/reduce/reduce_test_util.h" + +namespace spvtools { +namespace reduce { +namespace { + +TEST(OperandToUndefReductionPassTest, BasicCheck) { + // The following shader has 10 opportunities for replacing with undef. + + // #version 310 es + // + // precision highp float; + // + // layout(location=0) out vec4 _GLF_color; + // + // layout(set = 0, binding = 0) uniform buf0 { + // vec2 uniform1; + // }; + // + // void main() + // { + // _GLF_color = + // vec4( // opportunity + // uniform1.x / 2.0, // opportunity x2 (2.0 is const) + // uniform1.y / uniform1.x, // opportunity x3 + // uniform1.x + uniform1.x, // opportunity x3 + // uniform1.y); // opportunity + // } + + std::string original = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %9 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %9 "_GLF_color" + OpName %11 "buf0" + OpMemberName %11 0 "uniform1" + OpName %13 "" + OpDecorate %9 Location 0 + OpMemberDecorate %11 0 Offset 0 + OpDecorate %11 Block + OpDecorate %13 DescriptorSet 0 + OpDecorate %13 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypePointer Output %7 + %9 = OpVariable %8 Output + %10 = OpTypeVector %6 2 + %11 = OpTypeStruct %10 + %12 = OpTypePointer Uniform %11 + %13 = OpVariable %12 Uniform + %14 = OpTypeInt 32 1 + %15 = OpConstant %14 0 + %16 =
OpTypeInt 32 0 + %17 = OpConstant %16 0 + %18 = OpTypePointer Uniform %6 + %21 = OpConstant %6 2 + %23 = OpConstant %16 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %19 = OpAccessChain %18 %13 %15 %17 + %20 = OpLoad %6 %19 + %22 = OpFDiv %6 %20 %21 ; opportunity %20 (%21 is const) + %24 = OpAccessChain %18 %13 %15 %23 + %25 = OpLoad %6 %24 + %26 = OpAccessChain %18 %13 %15 %17 + %27 = OpLoad %6 %26 + %28 = OpFDiv %6 %25 %27 ; opportunity %25 %27 + %29 = OpAccessChain %18 %13 %15 %17 + %30 = OpLoad %6 %29 + %31 = OpAccessChain %18 %13 %15 %17 + %32 = OpLoad %6 %31 + %33 = OpFAdd %6 %30 %32 ; opportunity %30 %32 + %34 = OpAccessChain %18 %13 %15 %23 + %35 = OpLoad %6 %34 + %36 = OpCompositeConstruct %7 %22 %28 %33 %35 ; opportunity %22 %28 %33 %35 + OpStore %9 %36 ; opportunity %36 + OpReturn + OpFunctionEnd + )"; + + // This is the same as original, except where noted. + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %9 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %9 "_GLF_color" + OpName %11 "buf0" + OpMemberName %11 0 "uniform1" + OpName %13 "" + OpDecorate %9 Location 0 + OpMemberDecorate %11 0 Offset 0 + OpDecorate %11 Block + OpDecorate %13 DescriptorSet 0 + OpDecorate %13 Binding 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypePointer Output %7 + %9 = OpVariable %8 Output + %10 = OpTypeVector %6 2 + %11 = OpTypeStruct %10 + %12 = OpTypePointer Uniform %11 + %13 = OpVariable %12 Uniform + %14 = OpTypeInt 32 1 + %15 = OpConstant %14 0 + %16 = OpTypeInt 32 0 + %17 = OpConstant %16 0 + %18 = OpTypePointer Uniform %6 + %21 = OpConstant %6 2 + %23 = OpConstant %16 1 + %37 = OpUndef %6 ; Added undef float as %37 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %19 = OpAccessChain %18 %13 %15 %17 + %20 = OpLoad %6 %19 + %22 = OpFDiv %6 %37 %21 ; Replaced with %37 + %24
= OpAccessChain %18 %13 %15 %23 + %25 = OpLoad %6 %24 + %26 = OpAccessChain %18 %13 %15 %17 + %27 = OpLoad %6 %26 + %28 = OpFDiv %6 %37 %37 ; Replaced with %37 twice + %29 = OpAccessChain %18 %13 %15 %17 + %30 = OpLoad %6 %29 + %31 = OpAccessChain %18 %13 %15 %17 + %32 = OpLoad %6 %31 + %33 = OpFAdd %6 %30 %32 + %34 = OpAccessChain %18 %13 %15 %23 + %35 = OpLoad %6 %34 + %36 = OpCompositeConstruct %7 %22 %28 %33 %35 + OpStore %9 %36 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, original, kReduceAssembleOption); + const auto pass = TestSubclass<OperandToUndefReductionPass>(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(10, ops.size()); + + // Apply first three opportunities. + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + ASSERT_TRUE(ops[2]->PreconditionHolds()); + ops[2]->TryToApply(); + + CheckEqual(env, expected, context.get()); +} + +TEST(OperandToUndefReductionPassTest, WithCalledFunction) { + // The following shader has no opportunities. + // Most importantly, the noted function operand is not changed.
+ + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %10 %12 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypeFunction %7 + %9 = OpTypePointer Output %7 + %10 = OpVariable %9 Output + %11 = OpTypePointer Input %7 + %12 = OpVariable %11 Input + %13 = OpConstant %6 0 + %14 = OpConstantComposite %7 %13 %13 %13 %13 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %15 = OpFunctionCall %7 %16 ; do not replace %16 with undef + OpReturn + OpFunctionEnd + %16 = OpFunction %7 None %8 + %17 = OpLabel + OpReturnValue %14 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, shader, kReduceAssembleOption); + const auto pass = TestSubclass<OperandToUndefReductionPass>(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(0, ops.size()); +} + +} // namespace +} // namespace reduce +} // namespace spvtools diff --git a/test/reduce/reduce_test_util.cpp b/test/reduce/reduce_test_util.cpp new file mode 100644 index 000000000..19ef74989 --- /dev/null +++ b/test/reduce/reduce_test_util.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "reduce_test_util.h" + +namespace spvtools { +namespace reduce { + +void CheckEqual(const spv_target_env env, + const std::vector<uint32_t>& expected_binary, + const std::vector<uint32_t>& actual_binary) { + if (expected_binary != actual_binary) { + SpirvTools t(env); + std::string expected_disassembled; + std::string actual_disassembled; + ASSERT_TRUE(t.Disassemble(expected_binary, &expected_disassembled, + kReduceDisassembleOption)); + ASSERT_TRUE(t.Disassemble(actual_binary, &actual_disassembled, + kReduceDisassembleOption)); + ASSERT_EQ(expected_disassembled, actual_disassembled); + } +} + +void CheckEqual(const spv_target_env env, const std::string& expected_text, + const std::vector<uint32_t>& actual_binary) { + std::vector<uint32_t> expected_binary; + SpirvTools t(env); + ASSERT_TRUE( + t.Assemble(expected_text, &expected_binary, kReduceAssembleOption)); + CheckEqual(env, expected_binary, actual_binary); +} + +void CheckEqual(const spv_target_env env, const std::string& expected_text, + const opt::IRContext* actual_ir) { + std::vector<uint32_t> actual_binary; + actual_ir->module()->ToBinary(&actual_binary, false); + CheckEqual(env, expected_text, actual_binary); +} + +void CheckValid(spv_target_env env, const opt::IRContext* ir) { + std::vector<uint32_t> binary; + ir->module()->ToBinary(&binary, false); + SpirvTools t(env); + ASSERT_TRUE(t.Validate(binary)); +} + +std::string ToString(spv_target_env env, const opt::IRContext* ir) { + std::vector<uint32_t> binary; + ir->module()->ToBinary(&binary, false); + SpirvTools t(env); + std::string result; + t.Disassemble(binary, &result, kReduceDisassembleOption); + return result; +} + +void NopDiagnostic(spv_message_level_t /*level*/, const char* /*source*/, + const spv_position_t& /*position*/, + const char* /*message*/) {} + +} // namespace reduce +} // namespace spvtools diff --git a/test/reduce/reduce_test_util.h b/test/reduce/reduce_test_util.h new file mode 100644 index 000000000..499c77475 --- /dev/null +++ b/test/reduce/reduce_test_util.h @@ -0,0 +1,82 @@ +//
Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TEST_REDUCE_REDUCE_TEST_UTIL_H_ +#define TEST_REDUCE_REDUCE_TEST_UTIL_H_ + +#include "gtest/gtest.h" + +#include "source/opt/ir_context.h" +#include "source/reduce/reduction_opportunity.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace reduce { + +// A helper class that subclasses a given reduction pass class in order to +// provide a wrapper for its protected methods. +template <typename ReductionPassT> +class TestSubclass : public ReductionPassT { + public: + // Creates an instance of the reduction pass subclass with respect to target + // environment |env|. + explicit TestSubclass(const spv_target_env env) : ReductionPassT(env) {} + ~TestSubclass() = default; + + // A wrapper for GetAvailableOpportunities(...) + std::vector<std::unique_ptr<ReductionOpportunity>> + WrapGetAvailableOpportunities(opt::IRContext* context) const { + return ReductionPassT::GetAvailableOpportunities(context); + } +}; + +// Checks whether the given binaries are bit-wise equal. +void CheckEqual(spv_target_env env, + const std::vector<uint32_t>& expected_binary, + const std::vector<uint32_t>& actual_binary); + +// Assembles the given text and checks whether the resulting binary is bit-wise +// equal to the given binary.
+void CheckEqual(spv_target_env env, const std::string& expected_text, + const std::vector<uint32_t>& actual_binary); + +// Assembles the given text and turns the given IR into binary, then checks +// whether the resulting binaries are bit-wise equal. +void CheckEqual(spv_target_env env, const std::string& expected_text, + const opt::IRContext* actual_ir); + +// Turns the given IR context into binary and checks whether that binary is +// valid. +void CheckValid(spv_target_env env, const opt::IRContext* ir); + +// Turns the given IR context into binary, then returns its disassembly. +// Useful for debugging. +std::string ToString(spv_target_env env, const opt::IRContext* ir); + +// Assembly options for writing reduction tests. It simplifies matters if +// numeric ids do not change. +const uint32_t kReduceAssembleOption = + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS; +// Disassembly options for writing reduction tests. +const uint32_t kReduceDisassembleOption = + SPV_BINARY_TO_TEXT_OPTION_NO_HEADER | SPV_BINARY_TO_TEXT_OPTION_INDENT; + +// Don't print reducer info during testing. +void NopDiagnostic(spv_message_level_t /*level*/, const char* /*source*/, + const spv_position_t& /*position*/, const char* /*message*/); + +} // namespace reduce +} // namespace spvtools + +#endif // TEST_REDUCE_REDUCE_TEST_UTIL_H_ diff --git a/test/reduce/reducer_test.cpp b/test/reduce/reducer_test.cpp new file mode 100644 index 000000000..88fc5e44e --- /dev/null +++ b/test/reduce/reducer_test.cpp @@ -0,0 +1,310 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reduce_test_util.h" + +#include "source/reduce/operand_to_const_reduction_pass.h" +#include "source/reduce/reducer.h" +#include "source/reduce/remove_opname_instruction_reduction_pass.h" +#include "source/reduce/remove_unreferenced_instruction_reduction_pass.h" + +namespace spvtools { +namespace reduce { +namespace { + +// This changes its mind each time IsInteresting is invoked as to whether the +// binary is interesting, until some limit is reached after which the binary is +// always deemed interesting. This is useful to test that reduction passes +// interleave in interesting ways for a while, and then always succeed after +// some point; the latter is important to end up with a predictable final +// reduced binary for tests. 
+class PingPongInteresting { + public: + explicit PingPongInteresting(uint32_t always_interesting_after) + : is_interesting_(true), + always_interesting_after_(always_interesting_after), + count_(0) {} + + bool IsInteresting(const std::vector<uint32_t>&) { + bool result; + if (count_ > always_interesting_after_) { + result = true; + } else { + result = is_interesting_; + is_interesting_ = !is_interesting_; + } + count_++; + return result; + } + + private: + bool is_interesting_; + const uint32_t always_interesting_after_; + uint32_t count_; +}; + +TEST(ReducerTest, ExprToConstantAndRemoveUnreferenced) { + std::string original = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %60 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %16 "buf2" + OpMemberName %16 0 "i" + OpName %18 "" + OpName %25 "buf1" + OpMemberName %25 0 "f" + OpName %27 "" + OpName %60 "_GLF_color" + OpMemberDecorate %16 0 Offset 0 + OpDecorate %16 Block + OpDecorate %18 DescriptorSet 0 + OpDecorate %18 Binding 2 + OpMemberDecorate %25 0 Offset 0 + OpDecorate %25 Block + OpDecorate %27 DescriptorSet 0 + OpDecorate %27 Binding 1 + OpDecorate %60 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %9 = OpConstant %6 0 + %16 = OpTypeStruct %6 + %17 = OpTypePointer Uniform %16 + %18 = OpVariable %17 Uniform + %19 = OpTypePointer Uniform %6 + %22 = OpTypeBool + %100 = OpConstantTrue %22 + %24 = OpTypeFloat 32 + %25 = OpTypeStruct %24 + %26 = OpTypePointer Uniform %25 + %27 = OpVariable %26 Uniform + %28 = OpTypePointer Uniform %24 + %31 = OpConstant %24 2 + %56 = OpConstant %6 1 + %58 = OpTypeVector %24 4 + %59 = OpTypePointer Output %58 + %60 = OpVariable %59 Output + %72 = OpUndef %24 + %74 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %73 = OpPhi %6 %74 %5 %77 %34 + %71 = OpPhi %24 %72 %5 %76 %34 + %70 = OpPhi %6 %9 %5 %57 %34
+ %20 = OpAccessChain %19 %18 %9 + %21 = OpLoad %6 %20 + %23 = OpSLessThan %22 %70 %21 + OpLoopMerge %12 %34 None + OpBranchConditional %23 %11 %12 + %11 = OpLabel + %29 = OpAccessChain %28 %27 %9 + %30 = OpLoad %24 %29 + %32 = OpFOrdGreaterThan %22 %30 %31 + OpSelectionMerge %34 None + OpBranchConditional %32 %33 %46 + %33 = OpLabel + %40 = OpFAdd %24 %71 %30 + %45 = OpISub %6 %73 %21 + OpBranch %34 + %46 = OpLabel + %50 = OpFMul %24 %71 %30 + %54 = OpSDiv %6 %73 %21 + OpBranch %34 + %34 = OpLabel + %77 = OpPhi %6 %45 %33 %54 %46 + %76 = OpPhi %24 %40 %33 %50 %46 + %57 = OpIAdd %6 %70 %56 + OpBranch %10 + %12 = OpLabel + %61 = OpAccessChain %28 %27 %9 + %62 = OpLoad %24 %61 + %66 = OpConvertSToF %24 %21 + %68 = OpConvertSToF %24 %73 + %69 = OpCompositeConstruct %58 %62 %71 %66 %68 + OpStore %60 %69 + OpReturn + OpFunctionEnd + )"; + + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %60 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %16 "buf2" + OpMemberName %16 0 "i" + OpName %18 "" + OpName %25 "buf1" + OpMemberName %25 0 "f" + OpName %27 "" + OpName %60 "_GLF_color" + OpMemberDecorate %16 0 Offset 0 + OpDecorate %16 Block + OpDecorate %18 DescriptorSet 0 + OpDecorate %18 Binding 2 + OpMemberDecorate %25 0 Offset 0 + OpDecorate %25 Block + OpDecorate %27 DescriptorSet 0 + OpDecorate %27 Binding 1 + OpDecorate %60 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %9 = OpConstant %6 0 + %16 = OpTypeStruct %6 + %17 = OpTypePointer Uniform %16 + %18 = OpVariable %17 Uniform + %19 = OpTypePointer Uniform %6 + %22 = OpTypeBool + %100 = OpConstantTrue %22 + %24 = OpTypeFloat 32 + %25 = OpTypeStruct %24 + %26 = OpTypePointer Uniform %25 + %27 = OpVariable %26 Uniform + %28 = OpTypePointer Uniform %24 + %31 = OpConstant %24 2 + %56 = OpConstant %6 1 + %58 = OpTypeVector %24 4 + %59 = OpTypePointer Output
%58 + %60 = OpVariable %59 Output + %72 = OpUndef %24 + %74 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %34 None + OpBranchConditional %100 %11 %12 + %11 = OpLabel + OpSelectionMerge %34 None + OpBranchConditional %100 %33 %46 + %33 = OpLabel + OpBranch %34 + %46 = OpLabel + OpBranch %34 + %34 = OpLabel + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + spv_target_env env = SPV_ENV_UNIVERSAL_1_3; + Reducer reducer(env); + PingPongInteresting ping_pong_interesting(10); + reducer.SetMessageConsumer(NopDiagnostic); + reducer.SetInterestingnessFunction( + [&](const std::vector<uint32_t>& binary, uint32_t) -> bool { + return ping_pong_interesting.IsInteresting(binary); + }); + reducer.AddReductionPass(MakeUnique<OperandToConstReductionPass>(env)); + reducer.AddReductionPass( + MakeUnique<RemoveUnreferencedInstructionReductionPass>(env)); + + std::vector<uint32_t> binary_in; + SpirvTools t(env); + + ASSERT_TRUE(t.Assemble(original, &binary_in, kReduceAssembleOption)); + std::vector<uint32_t> binary_out; + spvtools::ReducerOptions reducer_options; + reducer_options.set_step_limit(500); + + reducer.Run(std::move(binary_in), &binary_out, reducer_options); + + CheckEqual(env, expected, binary_out); +} + +TEST(ReducerTest, RemoveOpnameAndRemoveUnreferenced) { + const std::string original = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + OpName %2 "main" + OpName %3 "a" + OpName %4 "this-name-counts-as-usage-for-load-instruction" + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeFloat 32 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 1 + %2 = OpFunction %5 None %6 + %10 = OpLabel + %3 = OpVariable %8 Function + %4 = OpLoad %7 %3 + OpStore %3 %7 + OpReturn + OpFunctionEnd + )"; + + const std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" +
OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeFloat 32 + %8 = OpTypePointer Function %7 + %9 = OpConstant %7 1 + %2 = OpFunction %5 None %6 + %10 = OpLabel + OpReturn + OpFunctionEnd + )"; + + spv_target_env env = SPV_ENV_UNIVERSAL_1_3; + Reducer reducer(env); + // Make ping-pong interesting very quickly, as there are not many + // opportunities. + PingPongInteresting ping_pong_interesting(1); + reducer.SetMessageConsumer(NopDiagnostic); + reducer.SetInterestingnessFunction( + [&](const std::vector<uint32_t>& binary, uint32_t) -> bool { + return ping_pong_interesting.IsInteresting(binary); + }); + reducer.AddReductionPass( + MakeUnique<RemoveOpNameInstructionReductionPass>(env)); + reducer.AddReductionPass( + MakeUnique<RemoveUnreferencedInstructionReductionPass>(env)); + + std::vector<uint32_t> binary_in; + SpirvTools t(env); + + ASSERT_TRUE(t.Assemble(original, &binary_in, kReduceAssembleOption)); + std::vector<uint32_t> binary_out; + spvtools::ReducerOptions reducer_options; + reducer_options.set_step_limit(500); + + reducer.Run(std::move(binary_in), &binary_out, reducer_options); + + CheckEqual(env, expected, binary_out); +} + +} // namespace +} // namespace reduce +} // namespace spvtools \ No newline at end of file diff --git a/test/reduce/remove_opname_instruction_reduction_pass_test.cpp b/test/reduce/remove_opname_instruction_reduction_pass_test.cpp new file mode 100644 index 000000000..38a2d7f77 --- /dev/null +++ b/test/reduce/remove_opname_instruction_reduction_pass_test.cpp @@ -0,0 +1,216 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reduce_test_util.h" + +#include "source/opt/build_module.h" +#include "source/reduce/reduction_opportunity.h" +#include "source/reduce/remove_opname_instruction_reduction_pass.h" + +namespace spvtools { +namespace reduce { +namespace { + +TEST(RemoveOpnameInstructionReductionPassTest, NothingToRemove) { + const std::string source = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, source, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(0, ops.size()); +} + +TEST(RemoveOpnameInstructionReductionPassTest, RemoveSingleOpName) { + const std::string prologue = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + )"; + + const std::string epilogue = R"( + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string original = prologue + R"( + OpName %4 "main" + )" + epilogue; + + const std::string expected = prologue + epilogue; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, original, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + 
ops[0]->TryToApply(); + + CheckEqual(env, expected, context.get()); +} + +TEST(RemoveOpnameInstructionReductionPassTest, TryApplyRemovesAllOpName) { + const std::string prologue = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + )"; + + const std::string epilogue = R"( + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %11 = OpVariable %7 Function + %12 = OpVariable %7 Function + OpStore %8 %9 + OpStore %10 %9 + OpStore %11 %9 + OpStore %12 %9 + OpReturn + OpFunctionEnd + )"; + + const std::string original = prologue + R"( + OpName %4 "main" + OpName %8 "a" + OpName %10 "b" + OpName %11 "c" + OpName %12 "d" + )" + epilogue; + + const std::string expected = prologue + epilogue; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + auto pass = TestSubclass(env); + + { + // Check the right number of opportunities is detected + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, original, kReduceAssembleOption); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(5, ops.size()); + } + + { + // The reduction should remove all OpName + std::vector binary; + SpirvTools t(env); + ASSERT_TRUE(t.Assemble(original, &binary, kReduceAssembleOption)); + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected, reduced_binary); + } +} + +TEST(RemoveOpnameInstructionReductionPassTest, + TryApplyRemovesAllOpNameAndOpMemberName) { + const std::string prologue = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + )"; + + const std::string epilogue = R"( + %2 = 
OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeInt 32 1 + %8 = OpTypeVector %6 3 + %9 = OpTypeStruct %6 %7 %8 + %10 = OpTypePointer Function %9 + %12 = OpConstant %7 0 + %13 = OpConstant %6 1 + %14 = OpTypePointer Function %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %11 = OpVariable %10 Function + %15 = OpAccessChain %14 %11 %12 + OpStore %15 %13 + OpReturn + OpFunctionEnd + )"; + + const std::string original = prologue + R"( + OpName %4 "main" + OpName %9 "S" + OpMemberName %9 0 "f" + OpMemberName %9 1 "i" + OpMemberName %9 2 "v" + OpName %11 "s" + )" + epilogue; + + const std::string expected = prologue + epilogue; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + auto pass = TestSubclass(env); + + { + // Check the right number of opportunities is detected + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, original, kReduceAssembleOption); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(6, ops.size()); + } + + { + // The reduction should remove all OpName + std::vector binary; + SpirvTools t(env); + ASSERT_TRUE(t.Assemble(original, &binary, kReduceAssembleOption)); + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected, reduced_binary); + } +} + +} // namespace +} // namespace reduce +} // namespace spvtools diff --git a/test/reduce/remove_unreferenced_instruction_reduction_pass_test.cpp b/test/reduce/remove_unreferenced_instruction_reduction_pass_test.cpp new file mode 100644 index 000000000..a002fa36a --- /dev/null +++ b/test/reduce/remove_unreferenced_instruction_reduction_pass_test.cpp @@ -0,0 +1,230 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reduce_test_util.h" + +#include "source/opt/build_module.h" +#include "source/reduce/reduction_opportunity.h" +#include "source/reduce/remove_unreferenced_instruction_reduction_pass.h" + +namespace spvtools { +namespace reduce { +namespace { + +TEST(RemoveUnreferencedInstructionReductionPassTest, RemoveStores) { + const std::string prologue = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %8 "a" + OpName %10 "b" + OpName %12 "c" + OpName %14 "d" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 10 + %11 = OpConstant %6 20 + %13 = OpConstant %6 30 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %12 = OpVariable %7 Function + %14 = OpVariable %7 Function + )"; + + const std::string epilogue = R"( + OpReturn + OpFunctionEnd + )"; + + const std::string original = prologue + R"( + OpStore %8 %9 + OpStore %10 %11 + OpStore %12 %13 + %15 = OpLoad %6 %8 + OpStore %14 %15 + )" + epilogue; + + const std::string expected = prologue + R"( + OpStore %12 %13 + %15 = OpLoad %6 %8 + OpStore %14 %15 + )" + epilogue; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = + BuildModule(env, consumer, original, kReduceAssembleOption); + const auto pass = + TestSubclass(env); + const auto ops = 
pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(4, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + + CheckEqual(env, expected, context.get()); +} + +TEST(RemoveUnreferencedInstructionReductionPassTest, ApplyReduction) { + const std::string prologue = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %8 "a" + OpName %10 "b" + OpName %12 "c" + OpName %14 "d" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 10 + %11 = OpConstant %6 20 + %13 = OpConstant %6 30 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %10 = OpVariable %7 Function + %12 = OpVariable %7 Function + %14 = OpVariable %7 Function + )"; + + const std::string epilogue = R"( + OpReturn + OpFunctionEnd + )"; + + const std::string original = prologue + R"( + OpStore %8 %9 + OpStore %10 %11 + OpStore %12 %13 + %15 = OpLoad %6 %8 + OpStore %14 %15 + )" + epilogue; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + + std::vector binary; + SpirvTools t(env); + ASSERT_TRUE(t.Assemble(original, &binary, kReduceAssembleOption)); + + auto pass = TestSubclass(env); + + { + // Attempt 1 should remove everything removable. + const std::string expected_reduced = prologue + R"( + %15 = OpLoad %6 %8 + )" + epilogue; + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected_reduced, reduced_binary); + } + + // Attempt 2 should fail as pass with granularity 4 got to end. + ASSERT_EQ(0, pass.TryApplyReduction(binary).size()); + + { + // Attempt 3 should remove first two removable statements. 
+ const std::string expected_reduced = prologue + R"( + OpStore %12 %13 + %15 = OpLoad %6 %8 + OpStore %14 %15 + )" + epilogue; + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected_reduced, reduced_binary); + } + + { + // Attempt 4 should remove last two removable statements. + const std::string expected_reduced = prologue + R"( + OpStore %8 %9 + OpStore %10 %11 + %15 = OpLoad %6 %8 + )" + epilogue; + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected_reduced, reduced_binary); + } + + // Attempt 5 should fail as pass with granularity 2 got to end. + ASSERT_EQ(0, pass.TryApplyReduction(binary).size()); + + { + // Attempt 6 should remove first removable statement. + const std::string expected_reduced = prologue + R"( + OpStore %10 %11 + OpStore %12 %13 + %15 = OpLoad %6 %8 + OpStore %14 %15 + )" + epilogue; + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected_reduced, reduced_binary); + } + + { + // Attempt 7 should remove second removable statement. + const std::string expected_reduced = prologue + R"( + OpStore %8 %9 + OpStore %12 %13 + %15 = OpLoad %6 %8 + OpStore %14 %15 + )" + epilogue; + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected_reduced, reduced_binary); + } + + { + // Attempt 8 should remove third removable statement. + const std::string expected_reduced = prologue + R"( + OpStore %8 %9 + OpStore %10 %11 + %15 = OpLoad %6 %8 + OpStore %14 %15 + )" + epilogue; + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected_reduced, reduced_binary); + } + + { + // Attempt 9 should remove fourth removable statement. 
+ const std::string expected_reduced = prologue + R"( + OpStore %8 %9 + OpStore %10 %11 + OpStore %12 %13 + %15 = OpLoad %6 %8 + )" + epilogue; + auto reduced_binary = pass.TryApplyReduction(binary); + CheckEqual(env, expected_reduced, reduced_binary); + } + + // Attempt 10 should fail as pass with granularity 1 got to end. + ASSERT_EQ(0, pass.TryApplyReduction(binary).size()); + + ASSERT_TRUE(pass.ReachedMinimumGranularity()); +} + +} // namespace +} // namespace reduce +} // namespace spvtools diff --git a/test/reduce/structured_loop_to_selection_reduction_pass_test.cpp b/test/reduce/structured_loop_to_selection_reduction_pass_test.cpp new file mode 100644 index 000000000..8388cb2e2 --- /dev/null +++ b/test/reduce/structured_loop_to_selection_reduction_pass_test.cpp @@ -0,0 +1,3440 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/reduce/structured_loop_to_selection_reduction_pass.h" +#include "reduce_test_util.h" +#include "source/opt/build_module.h" + +namespace spvtools { +namespace reduce { +namespace { + +TEST(StructuredLoopToSelectionReductionPassTest, LoopyShader1) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 100 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %13 + %13 = OpLabel + %19 = OpLoad %6 %8 + %21 = OpIAdd %6 %19 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 100 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %22 = OpConstantTrue %17 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + OpStore %8 %9 + 
OpBranch %10 + %10 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %22 %14 %12 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %12 + %13 = OpLabel + %19 = OpLoad %6 %8 + %21 = OpIAdd %6 %19 %20 + OpStore %8 %21 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, LoopyShader2) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 100 + %17 = OpTypeBool + %28 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %7 Function + %40 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %25 = OpLoad %6 %19 + %26 = OpSLessThan %17 %25 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + OpBranch %23 + %23 = OpLabel + %27 = OpLoad %6 %19 + %29 = OpIAdd %6 %27 %28 + OpStore %19 %29 + OpBranch %20 + %22 = OpLabel + OpBranch %13 + %13 = OpLabel + %30 = OpLoad %6 %8 + %31 = OpIAdd %6 %30 %28 + OpStore %8 %31 + OpBranch %10 + %12 = OpLabel + OpStore %32 %9 + OpBranch %33 + %33 = OpLabel + OpLoopMerge %35 %36 None + OpBranch %37 + %37 = OpLabel + %38 = OpLoad %6 %32 + %39 = OpSLessThan %17 %38 %16 + OpBranchConditional %39 %34 %35 + %34 = OpLabel + OpStore %40 %9 + OpBranch %41 + %41 = OpLabel + OpLoopMerge 
%43 %44 None + OpBranch %45 + %45 = OpLabel + %46 = OpLoad %6 %40 + %47 = OpSLessThan %17 %46 %16 + OpBranchConditional %47 %42 %43 + %42 = OpLabel + OpBranch %44 + %44 = OpLabel + %48 = OpLoad %6 %40 + %49 = OpIAdd %6 %48 %28 + OpStore %40 %49 + OpBranch %41 + %43 = OpLabel + OpBranch %36 + %36 = OpLabel + %50 = OpLoad %6 %32 + %51 = OpIAdd %6 %50 %28 + OpStore %32 %51 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(4, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 100 + %17 = OpTypeBool + %28 = OpConstant %6 1 + %52 = OpConstantTrue %17 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %7 Function + %40 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %52 %14 %12 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %25 = OpLoad %6 %19 + %26 = OpSLessThan %17 %25 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + OpBranch %23 + %23 = OpLabel + %27 = OpLoad %6 %19 + %29 = OpIAdd %6 %27 %28 + OpStore %19 %29 + OpBranch %20 + %22 = OpLabel + OpBranch %12 + %13 = OpLabel + 
%30 = OpLoad %6 %8 + %31 = OpIAdd %6 %30 %28 + OpStore %8 %31 + OpBranch %10 + %12 = OpLabel + OpStore %32 %9 + OpBranch %33 + %33 = OpLabel + OpLoopMerge %35 %36 None + OpBranch %37 + %37 = OpLabel + %38 = OpLoad %6 %32 + %39 = OpSLessThan %17 %38 %16 + OpBranchConditional %39 %34 %35 + %34 = OpLabel + OpStore %40 %9 + OpBranch %41 + %41 = OpLabel + OpLoopMerge %43 %44 None + OpBranch %45 + %45 = OpLabel + %46 = OpLoad %6 %40 + %47 = OpSLessThan %17 %46 %16 + OpBranchConditional %47 %42 %43 + %42 = OpLabel + OpBranch %44 + %44 = OpLabel + %48 = OpLoad %6 %40 + %49 = OpIAdd %6 %48 %28 + OpStore %40 %49 + OpBranch %41 + %43 = OpLabel + OpBranch %36 + %36 = OpLabel + %50 = OpLoad %6 %32 + %51 = OpIAdd %6 %50 %28 + OpStore %32 %51 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); + + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_1 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 100 + %17 = OpTypeBool + %28 = OpConstant %6 1 + %52 = OpConstantTrue %17 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %7 Function + %40 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %52 %14 %12 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + OpSelectionMerge %22 None + OpBranchConditional %52 %24 %22 + %24 = OpLabel + %25 = OpLoad %6 %19 + %26 = OpSLessThan %17 %25 %16 + OpBranchConditional %26 %21 %22 + %21 = 
OpLabel + OpBranch %22 + %23 = OpLabel + %27 = OpLoad %6 %19 + %29 = OpIAdd %6 %27 %28 + OpStore %19 %29 + OpBranch %20 + %22 = OpLabel + OpBranch %12 + %13 = OpLabel + %30 = OpLoad %6 %8 + %31 = OpIAdd %6 %30 %28 + OpStore %8 %31 + OpBranch %10 + %12 = OpLabel + OpStore %32 %9 + OpBranch %33 + %33 = OpLabel + OpLoopMerge %35 %36 None + OpBranch %37 + %37 = OpLabel + %38 = OpLoad %6 %32 + %39 = OpSLessThan %17 %38 %16 + OpBranchConditional %39 %34 %35 + %34 = OpLabel + OpStore %40 %9 + OpBranch %41 + %41 = OpLabel + OpLoopMerge %43 %44 None + OpBranch %45 + %45 = OpLabel + %46 = OpLoad %6 %40 + %47 = OpSLessThan %17 %46 %16 + OpBranchConditional %47 %42 %43 + %42 = OpLabel + OpBranch %44 + %44 = OpLabel + %48 = OpLoad %6 %40 + %49 = OpIAdd %6 %48 %28 + OpStore %40 %49 + OpBranch %41 + %43 = OpLabel + OpBranch %36 + %36 = OpLabel + %50 = OpLoad %6 %32 + %51 = OpIAdd %6 %50 %28 + OpStore %32 %51 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_1, context.get()); + + ASSERT_TRUE(ops[2]->PreconditionHolds()); + ops[2]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_2 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 100 + %17 = OpTypeBool + %28 = OpConstant %6 1 + %52 = OpConstantTrue %17 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %7 Function + %40 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %52 %14 %12 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + 
OpSelectionMerge %22 None + OpBranchConditional %52 %24 %22 + %24 = OpLabel + %25 = OpLoad %6 %19 + %26 = OpSLessThan %17 %25 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + OpBranch %22 + %23 = OpLabel + %27 = OpLoad %6 %19 + %29 = OpIAdd %6 %27 %28 + OpStore %19 %29 + OpBranch %20 + %22 = OpLabel + OpBranch %12 + %13 = OpLabel + %30 = OpLoad %6 %8 + %31 = OpIAdd %6 %30 %28 + OpStore %8 %31 + OpBranch %10 + %12 = OpLabel + OpStore %32 %9 + OpBranch %33 + %33 = OpLabel + OpSelectionMerge %35 None + OpBranchConditional %52 %37 %35 + %37 = OpLabel + %38 = OpLoad %6 %32 + %39 = OpSLessThan %17 %38 %16 + OpBranchConditional %39 %34 %35 + %34 = OpLabel + OpStore %40 %9 + OpBranch %41 + %41 = OpLabel + OpLoopMerge %43 %44 None + OpBranch %45 + %45 = OpLabel + %46 = OpLoad %6 %40 + %47 = OpSLessThan %17 %46 %16 + OpBranchConditional %47 %42 %43 + %42 = OpLabel + OpBranch %44 + %44 = OpLabel + %48 = OpLoad %6 %40 + %49 = OpIAdd %6 %48 %28 + OpStore %40 %49 + OpBranch %41 + %43 = OpLabel + OpBranch %35 + %36 = OpLabel + %50 = OpLoad %6 %32 + %51 = OpIAdd %6 %50 %28 + OpStore %32 %51 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_2, context.get()); + + ASSERT_TRUE(ops[3]->PreconditionHolds()); + ops[3]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_3 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 100 + %17 = OpTypeBool + %28 = OpConstant %6 1 + %52 = OpConstantTrue %17 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = OpVariable %7 Function + %19 = OpVariable %7 Function + %32 = OpVariable %7 Function + %40 = OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpSelectionMerge %12 None + 
OpBranchConditional %52 %14 %12 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSLessThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpStore %19 %9 + OpBranch %20 + %20 = OpLabel + OpSelectionMerge %22 None + OpBranchConditional %52 %24 %22 + %24 = OpLabel + %25 = OpLoad %6 %19 + %26 = OpSLessThan %17 %25 %16 + OpBranchConditional %26 %21 %22 + %21 = OpLabel + OpBranch %22 + %23 = OpLabel + %27 = OpLoad %6 %19 + %29 = OpIAdd %6 %27 %28 + OpStore %19 %29 + OpBranch %20 + %22 = OpLabel + OpBranch %12 + %13 = OpLabel + %30 = OpLoad %6 %8 + %31 = OpIAdd %6 %30 %28 + OpStore %8 %31 + OpBranch %10 + %12 = OpLabel + OpStore %32 %9 + OpBranch %33 + %33 = OpLabel + OpSelectionMerge %35 None + OpBranchConditional %52 %37 %35 + %37 = OpLabel + %38 = OpLoad %6 %32 + %39 = OpSLessThan %17 %38 %16 + OpBranchConditional %39 %34 %35 + %34 = OpLabel + OpStore %40 %9 + OpBranch %41 + %41 = OpLabel + OpSelectionMerge %43 None + OpBranchConditional %52 %45 %43 + %45 = OpLabel + %46 = OpLoad %6 %40 + %47 = OpSLessThan %17 %46 %16 + OpBranchConditional %47 %42 %43 + %42 = OpLabel + OpBranch %43 + %44 = OpLabel + %48 = OpLoad %6 %40 + %49 = OpIAdd %6 %48 %28 + OpStore %40 %49 + OpBranch %41 + %43 = OpLabel + OpBranch %35 + %36 = OpLabel + %50 = OpLoad %6 %32 + %51 = OpIAdd %6 %50 %28 + OpStore %32 %51 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_3, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, LoopyShader3) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 10 + %16 = OpConstant %6 0 + %17 = OpTypeBool + %20 = OpConstant %6 1 + %23 = OpConstant %6 3 + %40 = OpConstant %6 5 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %8 = 
OpVariable %7 Function + OpStore %8 %9 + OpBranch %10 + %10 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %14 + %14 = OpLabel + %15 = OpLoad %6 %8 + %18 = OpSGreaterThan %17 %15 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %19 = OpLoad %6 %8 + %21 = OpISub %6 %19 %20 + OpStore %8 %21 + %22 = OpLoad %6 %8 + %24 = OpSLessThan %17 %22 %23 + OpSelectionMerge %26 None + OpBranchConditional %24 %25 %26 + %25 = OpLabel + OpBranch %13 + %26 = OpLabel + OpBranch %28 + %28 = OpLabel + OpLoopMerge %30 %31 None + OpBranch %29 + %29 = OpLabel + %32 = OpLoad %6 %8 + %33 = OpISub %6 %32 %20 + OpStore %8 %33 + %34 = OpLoad %6 %8 + %35 = OpIEqual %17 %34 %20 + OpSelectionMerge %37 None + OpBranchConditional %35 %36 %37 + %36 = OpLabel + OpReturn ; This return spoils everything: it means the merge does not post-dominate the header. + %37 = OpLabel + OpBranch %31 + %31 = OpLabel + %39 = OpLoad %6 %8 + %41 = OpSGreaterThan %17 %39 %40 + OpBranchConditional %41 %28 %30 + %30 = OpLabel + OpBranch %13 + %13 = OpLabel + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(0, ops.size()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, LoopyShader4) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %6 %7 + %13 = OpConstant %6 0 + %22 = OpTypeBool + %25 = OpConstant %6 1 + %39 = OpConstant %6 100 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %45 = OpVariable %7 Function + %46 = OpVariable %7 Function + %47 = OpVariable %7 Function + %32 = 
OpVariable %7 Function + %42 = OpVariable %7 Function + OpStore %32 %13 + OpBranch %33 + %33 = OpLabel + OpLoopMerge %35 %36 None + OpBranch %37 + %37 = OpLabel + %38 = OpLoad %6 %32 + %40 = OpSLessThan %22 %38 %39 + OpBranchConditional %40 %34 %35 + %34 = OpLabel + OpBranch %36 + %36 = OpLabel + %41 = OpLoad %6 %32 + OpStore %42 %25 + OpStore %45 %13 + OpStore %46 %13 + OpBranch %48 + %48 = OpLabel + OpLoopMerge %49 %50 None + OpBranch %51 + %51 = OpLabel + %52 = OpLoad %6 %46 + %53 = OpLoad %6 %42 + %54 = OpSLessThan %22 %52 %53 + OpBranchConditional %54 %55 %49 + %55 = OpLabel + %56 = OpLoad %6 %45 + %57 = OpIAdd %6 %56 %25 + OpStore %45 %57 + OpBranch %50 + %50 = OpLabel + %58 = OpLoad %6 %46 + %59 = OpIAdd %6 %58 %25 + OpStore %46 %59 + OpBranch %48 + %49 = OpLabel + %60 = OpLoad %6 %45 + OpStore %47 %60 + %43 = OpLoad %6 %47 + %44 = OpIAdd %6 %41 %43 + OpStore %32 %44 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + // Initially there are two opportunities. 
+ ASSERT_EQ(2, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %6 %7 + %13 = OpConstant %6 0 + %22 = OpTypeBool + %25 = OpConstant %6 1 + %39 = OpConstant %6 100 + %61 = OpConstantTrue %22 + %62 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %45 = OpVariable %7 Function + %46 = OpVariable %7 Function + %47 = OpVariable %7 Function + %32 = OpVariable %7 Function + %42 = OpVariable %7 Function + OpStore %32 %13 + OpBranch %33 + %33 = OpLabel + OpSelectionMerge %35 None + OpBranchConditional %61 %37 %35 + %37 = OpLabel + %38 = OpLoad %6 %32 + %40 = OpSLessThan %22 %38 %39 + OpBranchConditional %40 %34 %35 + %34 = OpLabel + OpBranch %35 + %36 = OpLabel + %41 = OpLoad %6 %32 + OpStore %42 %25 + OpStore %45 %13 + OpStore %46 %13 + OpBranch %48 + %48 = OpLabel + OpLoopMerge %49 %50 None + OpBranch %51 + %51 = OpLabel + %52 = OpLoad %6 %46 + %53 = OpLoad %6 %42 + %54 = OpSLessThan %22 %52 %53 + OpBranchConditional %54 %55 %49 + %55 = OpLabel + %56 = OpLoad %6 %45 + %57 = OpIAdd %6 %56 %25 + OpStore %45 %57 + OpBranch %50 + %50 = OpLabel + %58 = OpLoad %6 %46 + %59 = OpIAdd %6 %58 %25 + OpStore %46 %59 + OpBranch %48 + %49 = OpLabel + %60 = OpLoad %6 %45 + OpStore %47 %60 + %43 = OpLoad %6 %47 + %44 = OpIAdd %6 %62 %43 + OpStore %32 %44 + OpBranch %33 + %35 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); + + // Applying the first opportunity has killed the second opportunity, because + // there was a loop embedded in the continue target of the loop we have just + // eliminated; the continue-embedded loop is now unreachable. 
+ ASSERT_FALSE(ops[1]->PreconditionHolds()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, ConditionalBreak1) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeBool + %11 = OpConstantFalse %10 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %8 %9 None + OpBranch %7 + %7 = OpLabel + OpSelectionMerge %13 None + OpBranchConditional %11 %12 %13 + %12 = OpLabel + OpBranch %8 + %13 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranchConditional %11 %6 %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeBool + %11 = OpConstantFalse %10 + %14 = OpConstantTrue %10 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpSelectionMerge %8 None + OpBranchConditional %14 %7 %8 + %7 = OpLabel + OpSelectionMerge %13 None + OpBranchConditional %11 %12 %13 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpBranch %8 + %9 = OpLabel + OpBranchConditional %11 %6 %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, ConditionalBreak2) { + 
std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeBool + %11 = OpConstantFalse %10 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %8 %9 None + OpBranch %7 + %7 = OpLabel + OpSelectionMerge %13 None + OpBranchConditional %11 %8 %13 + %13 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranchConditional %11 %6 %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeBool + %11 = OpConstantFalse %10 + %14 = OpConstantTrue %10 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpSelectionMerge %8 None + OpBranchConditional %14 %7 %8 + %7 = OpLabel + OpSelectionMerge %13 None + OpBranchConditional %11 %13 %13 + %13 = OpLabel + OpBranch %8 + %9 = OpLabel + OpBranchConditional %11 %6 %8 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, UnconditionalBreak) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + 
OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %8 %9 None + OpBranch %7 + %7 = OpLabel + OpBranch %8 + %9 = OpLabel + OpBranch %6 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %10 = OpTypeBool + %11 = OpConstantTrue %10 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpSelectionMerge %8 None + OpBranchConditional %11 %7 %8 + %7 = OpLabel + OpBranch %8 + %9 = OpLabel + OpBranch %6 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, Complex) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate %4 1 Offset 4 + OpMemberDecorate %4 2 Offset 8 + OpMemberDecorate %4 3 Offset 12 + OpDecorate %4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeBool + %9 = OpTypePointer Function %8 + %10 = OpTypeInt 32 1 + %4 = OpTypeStruct %10 %10 %10 %10 + %11 = OpTypePointer Uniform %4 + %5 = OpVariable %11 
Uniform + %12 = OpConstant %10 0 + %13 = OpTypePointer Uniform %10 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 0 + %16 = OpConstant %10 1 + %17 = OpConstant %10 2 + %18 = OpConstant %10 3 + %19 = OpTypePointer Function %10 + %20 = OpConstantFalse %8 + %21 = OpTypeFloat 32 + %22 = OpTypeVector %21 4 + %23 = OpTypePointer Output %22 + %3 = OpVariable %23 Output + %2 = OpFunction %6 None %7 + %24 = OpLabel + %25 = OpVariable %9 Function + %26 = OpVariable %9 Function + %27 = OpVariable %9 Function + %28 = OpVariable %9 Function + %29 = OpVariable %9 Function + %30 = OpVariable %19 Function + %31 = OpAccessChain %13 %5 %12 + %32 = OpLoad %10 %31 + %33 = OpINotEqual %8 %32 %15 + OpStore %25 %33 + %34 = OpAccessChain %13 %5 %16 + %35 = OpLoad %10 %34 + %36 = OpINotEqual %8 %35 %15 + OpStore %26 %36 + %37 = OpAccessChain %13 %5 %17 + %38 = OpLoad %10 %37 + %39 = OpINotEqual %8 %38 %15 + OpStore %27 %39 + %40 = OpAccessChain %13 %5 %18 + %41 = OpLoad %10 %40 + %42 = OpINotEqual %8 %41 %15 + OpStore %28 %42 + %43 = OpLoad %8 %25 + OpStore %29 %43 + OpStore %30 %12 + OpBranch %44 + %44 = OpLabel + OpLoopMerge %45 %46 None + OpBranch %47 + %47 = OpLabel + %48 = OpLoad %8 %29 + OpBranchConditional %48 %49 %45 + %49 = OpLabel + %50 = OpLoad %8 %25 + OpSelectionMerge %51 None + OpBranchConditional %50 %52 %51 + %52 = OpLabel + %53 = OpLoad %8 %26 + OpStore %29 %53 + %54 = OpLoad %10 %30 + %55 = OpIAdd %10 %54 %16 + OpStore %30 %55 + OpBranch %51 + %51 = OpLabel + %56 = OpLoad %8 %26 + OpSelectionMerge %57 None + OpBranchConditional %56 %58 %57 + %58 = OpLabel + %59 = OpLoad %10 %30 + %60 = OpIAdd %10 %59 %16 + OpStore %30 %60 + %61 = OpLoad %8 %29 + %62 = OpLoad %8 %25 + %63 = OpLogicalOr %8 %61 %62 + OpStore %29 %63 + %64 = OpLoad %8 %27 + OpSelectionMerge %65 None + OpBranchConditional %64 %66 %65 + %66 = OpLabel + %67 = OpLoad %10 %30 + %68 = OpIAdd %10 %67 %17 + OpStore %30 %68 + %69 = OpLoad %8 %29 + %70 = OpLogicalNot %8 %69 + OpStore %29 %70 + OpBranch %46 + %65 = 
OpLabel + %71 = OpLoad %8 %29 + %72 = OpLogicalOr %8 %71 %20 + OpStore %29 %72 + OpBranch %46 + %57 = OpLabel + OpBranch %73 + %73 = OpLabel + OpLoopMerge %74 %75 None + OpBranch %76 + %76 = OpLabel + %77 = OpLoad %8 %28 + OpSelectionMerge %78 None + OpBranchConditional %77 %79 %80 + %79 = OpLabel + %81 = OpLoad %10 %30 + OpSelectionMerge %82 None + OpSwitch %81 %83 1 %84 2 %85 + %83 = OpLabel + OpBranch %82 + %84 = OpLabel + %86 = OpLoad %8 %29 + %87 = OpSelect %10 %86 %16 %17 + %88 = OpLoad %10 %30 + %89 = OpIAdd %10 %88 %87 + OpStore %30 %89 + OpBranch %82 + %85 = OpLabel + OpBranch %75 + %82 = OpLabel + %90 = OpLoad %8 %27 + OpSelectionMerge %91 None + OpBranchConditional %90 %92 %91 + %92 = OpLabel + OpBranch %75 + %91 = OpLabel + OpBranch %78 + %80 = OpLabel + OpBranch %74 + %78 = OpLabel + OpBranch %75 + %75 = OpLabel + %93 = OpLoad %8 %29 + OpBranchConditional %93 %73 %74 + %74 = OpLabel + OpBranch %46 + %46 = OpLabel + OpBranch %44 + %45 = OpLabel + %94 = OpLoad %10 %30 + %95 = OpConvertSToF %21 %94 + %96 = OpCompositeConstruct %22 %95 %95 %95 %95 + OpStore %3 %96 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(2, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate %4 1 Offset 4 + OpMemberDecorate %4 2 Offset 8 + OpMemberDecorate %4 3 Offset 12 + OpDecorate %4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = 
OpTypeBool + %9 = OpTypePointer Function %8 + %10 = OpTypeInt 32 1 + %4 = OpTypeStruct %10 %10 %10 %10 + %11 = OpTypePointer Uniform %4 + %5 = OpVariable %11 Uniform + %12 = OpConstant %10 0 + %13 = OpTypePointer Uniform %10 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 0 + %16 = OpConstant %10 1 + %17 = OpConstant %10 2 + %18 = OpConstant %10 3 + %19 = OpTypePointer Function %10 + %20 = OpConstantFalse %8 + %21 = OpTypeFloat 32 + %22 = OpTypeVector %21 4 + %23 = OpTypePointer Output %22 + %3 = OpVariable %23 Output + %97 = OpConstantTrue %8 + %2 = OpFunction %6 None %7 + %24 = OpLabel + %25 = OpVariable %9 Function + %26 = OpVariable %9 Function + %27 = OpVariable %9 Function + %28 = OpVariable %9 Function + %29 = OpVariable %9 Function + %30 = OpVariable %19 Function + %31 = OpAccessChain %13 %5 %12 + %32 = OpLoad %10 %31 + %33 = OpINotEqual %8 %32 %15 + OpStore %25 %33 + %34 = OpAccessChain %13 %5 %16 + %35 = OpLoad %10 %34 + %36 = OpINotEqual %8 %35 %15 + OpStore %26 %36 + %37 = OpAccessChain %13 %5 %17 + %38 = OpLoad %10 %37 + %39 = OpINotEqual %8 %38 %15 + OpStore %27 %39 + %40 = OpAccessChain %13 %5 %18 + %41 = OpLoad %10 %40 + %42 = OpINotEqual %8 %41 %15 + OpStore %28 %42 + %43 = OpLoad %8 %25 + OpStore %29 %43 + OpStore %30 %12 + OpBranch %44 + %44 = OpLabel + OpSelectionMerge %45 None ; Was OpLoopMerge %45 %46 None + OpBranchConditional %97 %47 %45 ; Was OpBranch %47 + %47 = OpLabel + %48 = OpLoad %8 %29 + OpBranchConditional %48 %49 %45 + %49 = OpLabel + %50 = OpLoad %8 %25 + OpSelectionMerge %51 None + OpBranchConditional %50 %52 %51 + %52 = OpLabel + %53 = OpLoad %8 %26 + OpStore %29 %53 + %54 = OpLoad %10 %30 + %55 = OpIAdd %10 %54 %16 + OpStore %30 %55 + OpBranch %51 + %51 = OpLabel + %56 = OpLoad %8 %26 + OpSelectionMerge %57 None + OpBranchConditional %56 %58 %57 + %58 = OpLabel + %59 = OpLoad %10 %30 + %60 = OpIAdd %10 %59 %16 + OpStore %30 %60 + %61 = OpLoad %8 %29 + %62 = OpLoad %8 %25 + %63 = OpLogicalOr %8 %61 %62 + OpStore %29 %63 + %64 = 
OpLoad %8 %27 + OpSelectionMerge %65 None + OpBranchConditional %64 %66 %65 + %66 = OpLabel + %67 = OpLoad %10 %30 + %68 = OpIAdd %10 %67 %17 + OpStore %30 %68 + %69 = OpLoad %8 %29 + %70 = OpLogicalNot %8 %69 + OpStore %29 %70 + OpBranch %65 ; Was OpBranch %46 + %65 = OpLabel + %71 = OpLoad %8 %29 + %72 = OpLogicalOr %8 %71 %20 + OpStore %29 %72 + OpBranch %57 ; Was OpBranch %46 + %57 = OpLabel + OpBranch %73 + %73 = OpLabel + OpLoopMerge %74 %75 None + OpBranch %76 + %76 = OpLabel + %77 = OpLoad %8 %28 + OpSelectionMerge %78 None + OpBranchConditional %77 %79 %80 + %79 = OpLabel + %81 = OpLoad %10 %30 + OpSelectionMerge %82 None + OpSwitch %81 %83 1 %84 2 %85 + %83 = OpLabel + OpBranch %82 + %84 = OpLabel + %86 = OpLoad %8 %29 + %87 = OpSelect %10 %86 %16 %17 + %88 = OpLoad %10 %30 + %89 = OpIAdd %10 %88 %87 + OpStore %30 %89 + OpBranch %82 + %85 = OpLabel + OpBranch %75 + %82 = OpLabel + %90 = OpLoad %8 %27 + OpSelectionMerge %91 None + OpBranchConditional %90 %92 %91 + %92 = OpLabel + OpBranch %75 + %91 = OpLabel + OpBranch %78 + %80 = OpLabel + OpBranch %74 + %78 = OpLabel + OpBranch %75 + %75 = OpLabel + %93 = OpLoad %8 %29 + OpBranchConditional %93 %73 %74 + %74 = OpLabel + OpBranch %45 ; Was OpBranch %46 + %46 = OpLabel + OpBranch %44 + %45 = OpLabel + %94 = OpLoad %10 %30 + %95 = OpConvertSToF %21 %94 + %96 = OpCompositeConstruct %22 %95 %95 %95 %95 + OpStore %3 %96 + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + CheckValid(env, context.get()); + + std::string after_op_1 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate %4 1 Offset 4 + OpMemberDecorate %4 2 Offset 8 + OpMemberDecorate %4 3 Offset 12 + OpDecorate %4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate 
%5 Binding 0 + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeBool + %9 = OpTypePointer Function %8 + %10 = OpTypeInt 32 1 + %4 = OpTypeStruct %10 %10 %10 %10 + %11 = OpTypePointer Uniform %4 + %5 = OpVariable %11 Uniform + %12 = OpConstant %10 0 + %13 = OpTypePointer Uniform %10 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 0 + %16 = OpConstant %10 1 + %17 = OpConstant %10 2 + %18 = OpConstant %10 3 + %19 = OpTypePointer Function %10 + %20 = OpConstantFalse %8 + %21 = OpTypeFloat 32 + %22 = OpTypeVector %21 4 + %23 = OpTypePointer Output %22 + %3 = OpVariable %23 Output + %97 = OpConstantTrue %8 + %2 = OpFunction %6 None %7 + %24 = OpLabel + %25 = OpVariable %9 Function + %26 = OpVariable %9 Function + %27 = OpVariable %9 Function + %28 = OpVariable %9 Function + %29 = OpVariable %9 Function + %30 = OpVariable %19 Function + %31 = OpAccessChain %13 %5 %12 + %32 = OpLoad %10 %31 + %33 = OpINotEqual %8 %32 %15 + OpStore %25 %33 + %34 = OpAccessChain %13 %5 %16 + %35 = OpLoad %10 %34 + %36 = OpINotEqual %8 %35 %15 + OpStore %26 %36 + %37 = OpAccessChain %13 %5 %17 + %38 = OpLoad %10 %37 + %39 = OpINotEqual %8 %38 %15 + OpStore %27 %39 + %40 = OpAccessChain %13 %5 %18 + %41 = OpLoad %10 %40 + %42 = OpINotEqual %8 %41 %15 + OpStore %28 %42 + %43 = OpLoad %8 %25 + OpStore %29 %43 + OpStore %30 %12 + OpBranch %44 + %44 = OpLabel + OpSelectionMerge %45 None ; Was OpLoopMerge %45 %46 None + OpBranchConditional %97 %47 %45 ; Was OpBranch %47 + %47 = OpLabel + %48 = OpLoad %8 %29 + OpBranchConditional %48 %49 %45 + %49 = OpLabel + %50 = OpLoad %8 %25 + OpSelectionMerge %51 None + OpBranchConditional %50 %52 %51 + %52 = OpLabel + %53 = OpLoad %8 %26 + OpStore %29 %53 + %54 = OpLoad %10 %30 + %55 = OpIAdd %10 %54 %16 + OpStore %30 %55 + OpBranch %51 + %51 = OpLabel + %56 = OpLoad %8 %26 + OpSelectionMerge %57 None + OpBranchConditional %56 %58 %57 + %58 = OpLabel + %59 = OpLoad %10 %30 + %60 = OpIAdd %10 %59 %16 + OpStore %30 %60 + %61 = 
OpLoad %8 %29 + %62 = OpLoad %8 %25 + %63 = OpLogicalOr %8 %61 %62 + OpStore %29 %63 + %64 = OpLoad %8 %27 + OpSelectionMerge %65 None + OpBranchConditional %64 %66 %65 + %66 = OpLabel + %67 = OpLoad %10 %30 + %68 = OpIAdd %10 %67 %17 + OpStore %30 %68 + %69 = OpLoad %8 %29 + %70 = OpLogicalNot %8 %69 + OpStore %29 %70 + OpBranch %65 ; Was OpBranch %46 + %65 = OpLabel + %71 = OpLoad %8 %29 + %72 = OpLogicalOr %8 %71 %20 + OpStore %29 %72 + OpBranch %57 ; Was OpBranch %46 + %57 = OpLabel + OpBranch %73 + %73 = OpLabel + OpSelectionMerge %74 None ; Was OpLoopMerge %74 %75 None + OpBranchConditional %97 %76 %74 ; Was OpBranch %76 + %76 = OpLabel + %77 = OpLoad %8 %28 + OpSelectionMerge %78 None + OpBranchConditional %77 %79 %80 + %79 = OpLabel + %81 = OpLoad %10 %30 + OpSelectionMerge %82 None + OpSwitch %81 %83 1 %84 2 %85 + %83 = OpLabel + OpBranch %82 + %84 = OpLabel + %86 = OpLoad %8 %29 + %87 = OpSelect %10 %86 %16 %17 + %88 = OpLoad %10 %30 + %89 = OpIAdd %10 %88 %87 + OpStore %30 %89 + OpBranch %82 + %85 = OpLabel + OpBranch %82 + %82 = OpLabel + %90 = OpLoad %8 %27 + OpSelectionMerge %91 None + OpBranchConditional %90 %92 %91 + %92 = OpLabel + OpBranch %91 + %91 = OpLabel + OpBranch %78 + %80 = OpLabel + OpBranch %78 ; Was OpBranch %74 + %78 = OpLabel + OpBranch %74 + %75 = OpLabel + %93 = OpLoad %8 %29 + OpBranchConditional %93 %73 %74 + %74 = OpLabel + OpBranch %45 ; Was OpBranch %46 + %46 = OpLabel + OpBranch %44 + %45 = OpLabel + %94 = OpLoad %10 %30 + %95 = OpConvertSToF %21 %94 + %96 = OpCompositeConstruct %22 %95 %95 %95 %95 + OpStore %3 %96 + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_1, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, ComplexOptimized) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate 
%4 1 Offset 4 + OpMemberDecorate %4 2 Offset 8 + OpMemberDecorate %4 3 Offset 12 + OpDecorate %4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeBool + %10 = OpTypeInt 32 1 + %4 = OpTypeStruct %10 %10 %10 %10 + %11 = OpTypePointer Uniform %4 + %5 = OpVariable %11 Uniform + %12 = OpConstant %10 0 + %13 = OpTypePointer Uniform %10 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 0 + %16 = OpConstant %10 1 + %17 = OpConstant %10 2 + %18 = OpConstant %10 3 + %20 = OpConstantFalse %8 + %21 = OpTypeFloat 32 + %22 = OpTypeVector %21 4 + %23 = OpTypePointer Output %22 + %3 = OpVariable %23 Output + %2 = OpFunction %6 None %7 + %24 = OpLabel + %31 = OpAccessChain %13 %5 %12 + %32 = OpLoad %10 %31 + %33 = OpINotEqual %8 %32 %15 + %34 = OpAccessChain %13 %5 %16 + %35 = OpLoad %10 %34 + %36 = OpINotEqual %8 %35 %15 + %37 = OpAccessChain %13 %5 %17 + %38 = OpLoad %10 %37 + %39 = OpINotEqual %8 %38 %15 + %40 = OpAccessChain %13 %5 %18 + %41 = OpLoad %10 %40 + %42 = OpINotEqual %8 %41 %15 + OpBranch %44 + %44 = OpLabel + %98 = OpPhi %10 %12 %24 %107 %46 + %97 = OpPhi %8 %33 %24 %105 %46 + OpLoopMerge %45 %46 None + OpBranchConditional %97 %49 %45 + %49 = OpLabel + OpSelectionMerge %51 None + OpBranchConditional %33 %52 %51 + %52 = OpLabel + %55 = OpIAdd %10 %98 %16 + OpBranch %51 + %51 = OpLabel + %100 = OpPhi %10 %98 %49 %55 %52 + %113 = OpSelect %8 %33 %36 %97 + OpSelectionMerge %57 None + OpBranchConditional %36 %58 %57 + %58 = OpLabel + %60 = OpIAdd %10 %100 %16 + %63 = OpLogicalOr %8 %113 %33 + OpSelectionMerge %65 None + OpBranchConditional %39 %66 %65 + %66 = OpLabel + %68 = OpIAdd %10 %100 %18 + %70 = OpLogicalNot %8 %63 + OpBranch %46 + %65 = OpLabel + %72 = OpLogicalOr %8 %63 %20 + OpBranch %46 + %57 = OpLabel + OpBranch %73 + %73 = OpLabel + %99 = OpPhi %10 %100 %57 %109 %75 + OpLoopMerge %74 %75 None + OpBranch %76 + %76 = OpLabel + OpSelectionMerge %78 None + 
OpBranchConditional %42 %79 %80 + %79 = OpLabel + OpSelectionMerge %82 None + OpSwitch %99 %83 1 %84 2 %85 + %83 = OpLabel + OpBranch %82 + %84 = OpLabel + %87 = OpSelect %10 %113 %16 %17 + %89 = OpIAdd %10 %99 %87 + OpBranch %82 + %85 = OpLabel + OpBranch %75 + %82 = OpLabel + %110 = OpPhi %10 %99 %83 %89 %84 + OpSelectionMerge %91 None + OpBranchConditional %39 %92 %91 + %92 = OpLabel + OpBranch %75 + %91 = OpLabel + OpBranch %78 + %80 = OpLabel + OpBranch %74 + %78 = OpLabel + OpBranch %75 + %75 = OpLabel + %109 = OpPhi %10 %99 %85 %110 %92 %110 %78 + OpBranchConditional %113 %73 %74 + %74 = OpLabel + %108 = OpPhi %10 %99 %80 %109 %75 + OpBranch %46 + %46 = OpLabel + %107 = OpPhi %10 %68 %66 %60 %65 %108 %74 + %105 = OpPhi %8 %70 %66 %72 %65 %113 %74 + OpBranch %44 + %45 = OpLabel + %95 = OpConvertSToF %21 %98 + %96 = OpCompositeConstruct %22 %95 %95 %95 %95 + OpStore %3 %96 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(2, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate %4 1 Offset 4 + OpMemberDecorate %4 2 Offset 8 + OpMemberDecorate %4 3 Offset 12 + OpDecorate %4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeBool + %10 = OpTypeInt 32 1 + %4 = OpTypeStruct %10 %10 %10 %10 + %11 = OpTypePointer Uniform %4 + %5 = OpVariable %11 Uniform + %12 = OpConstant %10 0 + %13 = OpTypePointer Uniform %10 + %14 = 
OpTypeInt 32 0 + %15 = OpConstant %14 0 + %16 = OpConstant %10 1 + %17 = OpConstant %10 2 + %18 = OpConstant %10 3 + %20 = OpConstantFalse %8 + %21 = OpTypeFloat 32 + %22 = OpTypeVector %21 4 + %23 = OpTypePointer Output %22 + %3 = OpVariable %23 Output + %114 = OpUndef %10 + %115 = OpUndef %8 + %2 = OpFunction %6 None %7 + %24 = OpLabel + %31 = OpAccessChain %13 %5 %12 + %32 = OpLoad %10 %31 + %33 = OpINotEqual %8 %32 %15 + %34 = OpAccessChain %13 %5 %16 + %35 = OpLoad %10 %34 + %36 = OpINotEqual %8 %35 %15 + %37 = OpAccessChain %13 %5 %17 + %38 = OpLoad %10 %37 + %39 = OpINotEqual %8 %38 %15 + %40 = OpAccessChain %13 %5 %18 + %41 = OpLoad %10 %40 + %42 = OpINotEqual %8 %41 %15 + OpBranch %44 + %44 = OpLabel + %98 = OpPhi %10 %12 %24 %114 %46 + %97 = OpPhi %8 %33 %24 %115 %46 + OpSelectionMerge %45 None ; Was OpLoopMerge %45 %46 None + OpBranchConditional %97 %49 %45 + %49 = OpLabel + OpSelectionMerge %51 None + OpBranchConditional %33 %52 %51 + %52 = OpLabel + %55 = OpIAdd %10 %98 %16 + OpBranch %51 + %51 = OpLabel + %100 = OpPhi %10 %98 %49 %55 %52 + %113 = OpSelect %8 %33 %36 %97 + OpSelectionMerge %57 None + OpBranchConditional %36 %58 %57 + %58 = OpLabel + %60 = OpIAdd %10 %100 %16 + %63 = OpLogicalOr %8 %113 %33 + OpSelectionMerge %65 None + OpBranchConditional %39 %66 %65 + %66 = OpLabel + %68 = OpIAdd %10 %100 %18 + %70 = OpLogicalNot %8 %63 + OpBranch %65 ; Was OpBranch %46 + %65 = OpLabel + %72 = OpLogicalOr %8 %63 %20 + OpBranch %57 ; Was OpBranch %46 + %57 = OpLabel + OpBranch %73 + %73 = OpLabel + %99 = OpPhi %10 %100 %57 %109 %75 + OpLoopMerge %74 %75 None + OpBranch %76 + %76 = OpLabel + OpSelectionMerge %78 None + OpBranchConditional %42 %79 %80 + %79 = OpLabel + OpSelectionMerge %82 None + OpSwitch %99 %83 1 %84 2 %85 + %83 = OpLabel + OpBranch %82 + %84 = OpLabel + %87 = OpSelect %10 %113 %16 %17 + %89 = OpIAdd %10 %99 %87 + OpBranch %82 + %85 = OpLabel + OpBranch %75 + %82 = OpLabel + %110 = OpPhi %10 %99 %83 %89 %84 + OpSelectionMerge %91 None 
+ OpBranchConditional %39 %92 %91 + %92 = OpLabel + OpBranch %75 + %91 = OpLabel + OpBranch %78 + %80 = OpLabel + OpBranch %74 + %78 = OpLabel + OpBranch %75 + %75 = OpLabel + %109 = OpPhi %10 %99 %85 %110 %92 %110 %78 + OpBranchConditional %113 %73 %74 + %74 = OpLabel + %108 = OpPhi %10 %99 %80 %109 %75 + OpBranch %45 ; Was OpBranch %46 + %46 = OpLabel + %107 = OpPhi %10 ; Was OpPhi %10 %68 %66 %60 %65 %108 %74 + %105 = OpPhi %8 ; Was OpPhi %8 %70 %66 %72 %65 %113 %74 + OpBranch %44 + %45 = OpLabel + %95 = OpConvertSToF %21 %98 + %96 = OpCompositeConstruct %22 %95 %95 %95 %95 + OpStore %3 %96 + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); + + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + CheckValid(env, context.get()); + std::string after_op_1 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate %4 1 Offset 4 + OpMemberDecorate %4 2 Offset 8 + OpMemberDecorate %4 3 Offset 12 + OpDecorate %4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeBool + %10 = OpTypeInt 32 1 + %4 = OpTypeStruct %10 %10 %10 %10 + %11 = OpTypePointer Uniform %4 + %5 = OpVariable %11 Uniform + %12 = OpConstant %10 0 + %13 = OpTypePointer Uniform %10 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 0 + %16 = OpConstant %10 1 + %17 = OpConstant %10 2 + %18 = OpConstant %10 3 + %20 = OpConstantFalse %8 + %21 = OpTypeFloat 32 + %22 = OpTypeVector %21 4 + %23 = OpTypePointer Output %22 + %3 = OpVariable %23 Output + %114 = OpUndef %10 + %115 = OpUndef %8 + %116 = OpConstantTrue %8 + %2 = OpFunction %6 None %7 + %24 = OpLabel + %31 = OpAccessChain %13 %5 %12 + %32 = OpLoad %10 %31 + %33 = OpINotEqual %8 %32 %15 + %34 = OpAccessChain %13 %5 %16 + %35 = OpLoad %10 
%34 + %36 = OpINotEqual %8 %35 %15 + %37 = OpAccessChain %13 %5 %17 + %38 = OpLoad %10 %37 + %39 = OpINotEqual %8 %38 %15 + %40 = OpAccessChain %13 %5 %18 + %41 = OpLoad %10 %40 + %42 = OpINotEqual %8 %41 %15 + OpBranch %44 + %44 = OpLabel + %98 = OpPhi %10 %12 %24 %114 %46 + %97 = OpPhi %8 %33 %24 %115 %46 + OpSelectionMerge %45 None ; Was OpLoopMerge %45 %46 None + OpBranchConditional %97 %49 %45 + %49 = OpLabel + OpSelectionMerge %51 None + OpBranchConditional %33 %52 %51 + %52 = OpLabel + %55 = OpIAdd %10 %98 %16 + OpBranch %51 + %51 = OpLabel + %100 = OpPhi %10 %98 %49 %55 %52 + %113 = OpSelect %8 %33 %36 %97 + OpSelectionMerge %57 None + OpBranchConditional %36 %58 %57 + %58 = OpLabel + %60 = OpIAdd %10 %100 %16 + %63 = OpLogicalOr %8 %113 %33 + OpSelectionMerge %65 None + OpBranchConditional %39 %66 %65 + %66 = OpLabel + %68 = OpIAdd %10 %100 %18 + %70 = OpLogicalNot %8 %63 + OpBranch %65 ; Was OpBranch %46 + %65 = OpLabel + %72 = OpLogicalOr %8 %63 %20 + OpBranch %57 ; Was OpBranch %46 + %57 = OpLabel + OpBranch %73 + %73 = OpLabel + %99 = OpPhi %10 %100 %57 %114 %75 + OpSelectionMerge %74 None ; Was OpLoopMerge %74 %75 None + OpBranchConditional %116 %76 %74 + %76 = OpLabel + OpSelectionMerge %78 None + OpBranchConditional %42 %79 %80 + %79 = OpLabel + OpSelectionMerge %82 None + OpSwitch %99 %83 1 %84 2 %85 + %83 = OpLabel + OpBranch %82 + %84 = OpLabel + %87 = OpSelect %10 %113 %16 %17 + %89 = OpIAdd %10 %99 %87 + OpBranch %82 + %85 = OpLabel + OpBranch %82 ; Was OpBranch %75 + %82 = OpLabel + %110 = OpPhi %10 %99 %83 %89 %84 %114 %85 ; Was OpPhi %10 %99 %83 %89 %84 + OpSelectionMerge %91 None + OpBranchConditional %39 %92 %91 + %92 = OpLabel + OpBranch %91 ; OpBranch %75 + %91 = OpLabel + OpBranch %78 + %80 = OpLabel + OpBranch %78 ; Was OpBranch %74 + %78 = OpLabel + OpBranch %74 ; Was OpBranch %75 + %75 = OpLabel + %109 = OpPhi %10 ; Was OpPhi %10 %99 %85 %110 %92 %110 %78 + OpBranchConditional %115 %73 %74 + %74 = OpLabel + %108 = OpPhi %10 %114 %75 
%114 %78 %114 %73 ; Was OpPhi %10 %99 %80 %109 %75 + OpBranch %45 ; Was OpBranch %46 + %46 = OpLabel + %107 = OpPhi %10 ; Was OpPhi %10 %68 %66 %60 %65 %108 %74 + %105 = OpPhi %8 ; Was OpPhi %8 %70 %66 %72 %65 %113 %74 + OpBranch %44 + %45 = OpLabel + %95 = OpConvertSToF %21 %98 + %96 = OpCompositeConstruct %22 %95 %95 %95 %95 + OpStore %3 %96 + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_1, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, DominanceIssue) { + // Exposes a scenario where redirecting edges results in uses of ids being + // non-dominated. We replace such uses with OpUndef to account for this. + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %5 = OpTypeInt 32 1 + %7 = OpTypePointer Function %5 + %6 = OpTypeBool + %8 = OpConstantTrue %6 + %9 = OpConstant %5 10 + %10 = OpConstant %5 20 + %11 = OpConstant %5 30 + %4 = OpFunction %2 None %3 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + OpSelectionMerge %17 None + OpBranchConditional %8 %18 %19 + %18 = OpLabel + OpBranch %14 + %19 = OpLabel + %20 = OpIAdd %5 %9 %10 + OpBranch %17 + %17 = OpLabel + %21 = OpIAdd %5 %20 %11 + OpBranchConditional %8 %14 %15 + %15 = OpLabel + OpBranch %13 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical 
GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %5 = OpTypeInt 32 1 + %7 = OpTypePointer Function %5 + %6 = OpTypeBool + %8 = OpConstantTrue %6 + %9 = OpConstant %5 10 + %10 = OpConstant %5 20 + %11 = OpConstant %5 30 + %22 = OpUndef %5 + %4 = OpFunction %2 None %3 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %8 %16 %14 + %16 = OpLabel + OpSelectionMerge %17 None + OpBranchConditional %8 %18 %19 + %18 = OpLabel + OpBranch %17 + %19 = OpLabel + %20 = OpIAdd %5 %9 %10 + OpBranch %17 + %17 = OpLabel + %21 = OpIAdd %5 %22 %11 + OpBranchConditional %8 %14 %14 + %15 = OpLabel + OpBranch %13 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, AccessChainIssue) { + // Exposes a scenario where redirecting edges results in a use of an id + // generated by an access chain being non-dominated. 
+ std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %56 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %28 0 Offset 0 + OpDecorate %28 Block + OpDecorate %30 DescriptorSet 0 + OpDecorate %30 Binding 0 + OpDecorate %56 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 2 + %8 = OpTypePointer Function %7 + %60 = OpTypePointer Private %7 + %10 = OpConstant %6 0 + %11 = OpConstantComposite %7 %10 %10 + %12 = OpTypePointer Function %6 + %59 = OpTypePointer Private %6 + %14 = OpTypeInt 32 1 + %15 = OpTypePointer Function %14 + %17 = OpConstant %14 0 + %24 = OpConstant %14 100 + %25 = OpTypeBool + %28 = OpTypeStruct %6 + %29 = OpTypePointer Uniform %28 + %30 = OpVariable %29 Uniform + %31 = OpTypePointer Uniform %6 + %39 = OpTypeInt 32 0 + %40 = OpConstant %39 1 + %45 = OpConstant %39 0 + %52 = OpConstant %14 1 + %54 = OpTypeVector %6 4 + %55 = OpTypePointer Output %54 + %56 = OpVariable %55 Output + %9 = OpVariable %60 Private + %4 = OpFunction %2 None %3 + %5 = OpLabel + %13 = OpVariable %12 Function + %16 = OpVariable %15 Function + %38 = OpVariable %12 Function + OpStore %9 %11 + OpStore %13 %10 + OpStore %16 %17 + OpBranch %18 + %18 = OpLabel + OpLoopMerge %20 %21 None + OpBranch %22 + %22 = OpLabel + %23 = OpLoad %14 %16 + %26 = OpSLessThan %25 %23 %24 + OpBranchConditional %26 %19 %20 + %19 = OpLabel + %27 = OpLoad %14 %16 + %32 = OpAccessChain %31 %30 %17 + %33 = OpLoad %6 %32 + %34 = OpConvertFToS %14 %33 + %35 = OpSLessThan %25 %27 %34 + OpSelectionMerge %37 None + OpBranchConditional %35 %36 %44 + %36 = OpLabel + %41 = OpAccessChain %59 %9 %40 + %42 = OpLoad %6 %41 + OpStore %38 %42 + OpBranch %20 + %44 = OpLabel + %46 = OpAccessChain %59 %9 %45 + OpBranch %37 + %37 = OpLabel + %47 = OpLoad %6 %46 + OpStore %38 %47 + %48 = OpLoad %6 %38 + %49 = OpLoad %6 %13 + %50 = OpFAdd %6 %49 %48 + 
OpStore %13 %50 + OpBranch %21 + %21 = OpLabel + %51 = OpLoad %14 %16 + %53 = OpIAdd %14 %51 %52 + OpStore %16 %53 + OpBranch %18 + %20 = OpLabel + %57 = OpLoad %6 %13 + %58 = OpCompositeConstruct %54 %57 %57 %57 %57 + OpStore %56 %58 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %56 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpMemberDecorate %28 0 Offset 0 + OpDecorate %28 Block + OpDecorate %30 DescriptorSet 0 + OpDecorate %30 Binding 0 + OpDecorate %56 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 2 + %8 = OpTypePointer Function %7 + %60 = OpTypePointer Private %7 + %10 = OpConstant %6 0 + %11 = OpConstantComposite %7 %10 %10 + %12 = OpTypePointer Function %6 + %59 = OpTypePointer Private %6 + %14 = OpTypeInt 32 1 + %15 = OpTypePointer Function %14 + %17 = OpConstant %14 0 + %24 = OpConstant %14 100 + %25 = OpTypeBool + %28 = OpTypeStruct %6 + %29 = OpTypePointer Uniform %28 + %30 = OpVariable %29 Uniform + %31 = OpTypePointer Uniform %6 + %39 = OpTypeInt 32 0 + %40 = OpConstant %39 1 + %45 = OpConstant %39 0 + %52 = OpConstant %14 1 + %54 = OpTypeVector %6 4 + %55 = OpTypePointer Output %54 + %56 = OpVariable %55 Output + %9 = OpVariable %60 Private + %61 = OpConstantTrue %25 + %62 = OpVariable %59 Private + %4 = OpFunction %2 None %3 + %5 = OpLabel + %13 = OpVariable %12 Function + %16 = OpVariable %15 Function + %38 = OpVariable %12 Function + OpStore %9 %11 + OpStore %13 %10 + OpStore %16 
%17 + OpBranch %18 + %18 = OpLabel + OpSelectionMerge %20 None + OpBranchConditional %61 %22 %20 + %22 = OpLabel + %23 = OpLoad %14 %16 + %26 = OpSLessThan %25 %23 %24 + OpBranchConditional %26 %19 %20 + %19 = OpLabel + %27 = OpLoad %14 %16 + %32 = OpAccessChain %31 %30 %17 + %33 = OpLoad %6 %32 + %34 = OpConvertFToS %14 %33 + %35 = OpSLessThan %25 %27 %34 + OpSelectionMerge %37 None + OpBranchConditional %35 %36 %44 + %36 = OpLabel + %41 = OpAccessChain %59 %9 %40 + %42 = OpLoad %6 %41 + OpStore %38 %42 + OpBranch %37 + %44 = OpLabel + %46 = OpAccessChain %59 %9 %45 + OpBranch %37 + %37 = OpLabel + %47 = OpLoad %6 %62 + OpStore %38 %47 + %48 = OpLoad %6 %38 + %49 = OpLoad %6 %13 + %50 = OpFAdd %6 %49 %48 + OpStore %13 %50 + OpBranch %20 + %21 = OpLabel + %51 = OpLoad %14 %16 + %53 = OpIAdd %14 %51 %52 + OpStore %16 %53 + OpBranch %18 + %20 = OpLabel + %57 = OpLoad %6 %13 + %58 = OpCompositeConstruct %54 %57 %57 %57 %57 + OpStore %56 %58 + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, DominanceAndPhiIssue) { + // Exposes an interesting scenario where a use in a phi stops being dominated + // by the block with which it is associated in the phi. 
+ std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %17 = OpTypeBool + %18 = OpConstantTrue %17 + %19 = OpConstantFalse %17 + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %22 = OpConstant %20 6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %16 %15 None + OpBranch %7 + %7 = OpLabel + OpSelectionMerge %13 None + OpBranchConditional %18 %8 %9 + %8 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %18 %10 %11 + %9 = OpLabel + OpBranch %16 + %10 = OpLabel + OpBranch %16 + %11 = OpLabel + %23 = OpIAdd %20 %21 %22 + OpBranch %12 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpBranch %14 + %14 = OpLabel + %24 = OpPhi %20 %23 %13 + OpBranchConditional %19 %15 %16 + %15 = OpLabel + OpBranch %6 + %16 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %17 = OpTypeBool + %18 = OpConstantTrue %17 + %19 = OpConstantFalse %17 + %20 = OpTypeInt 32 1 + %21 = OpConstant %20 5 + %22 = OpConstant %20 6 + %25 = OpUndef %20 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpSelectionMerge %16 None + OpBranchConditional %18 %7 %16 + %7 = OpLabel + OpSelectionMerge %13 None + OpBranchConditional %18 %8 
%9 + %8 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %18 %10 %11 + %9 = OpLabel + OpBranch %13 + %10 = OpLabel + OpBranch %12 + %11 = OpLabel + %23 = OpIAdd %20 %21 %22 + OpBranch %12 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpBranch %14 + %14 = OpLabel + %24 = OpPhi %20 %25 %13 + OpBranchConditional %19 %16 %16 + %15 = OpLabel + OpBranch %6 + %16 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, OpLineBeforeOpPhi) { + // Test to ensure the pass knows OpLine and OpPhi instructions can be + // interleaved. + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpString "somefile" + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeInt 32 1 + %7 = OpConstant %6 10 + %8 = OpConstant %6 20 + %9 = OpConstant %6 30 + %10 = OpTypeBool + %11 = OpConstantTrue %10 + %2 = OpFunction %4 None %5 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + OpSelectionMerge %17 None + OpBranchConditional %11 %18 %19 + %18 = OpLabel + %20 = OpIAdd %6 %7 %8 + %21 = OpIAdd %6 %7 %9 + OpBranch %17 + %19 = OpLabel + OpBranch %14 + %17 = OpLabel + %22 = OpPhi %6 %20 %18 + OpLine %3 0 0 + %23 = OpPhi %6 %21 %18 + OpBranch %15 + %15 = OpLabel + OpBranch %13 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + ASSERT_EQ(1, ops.size()); + + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + 
OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpString "somefile" + %4 = OpTypeVoid + %5 = OpTypeFunction %4 + %6 = OpTypeInt 32 1 + %7 = OpConstant %6 10 + %8 = OpConstant %6 20 + %9 = OpConstant %6 30 + %10 = OpTypeBool + %11 = OpConstantTrue %10 + %24 = OpUndef %6 + %2 = OpFunction %4 None %5 + %12 = OpLabel + OpBranch %13 + %13 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %11 %16 %14 + %16 = OpLabel + OpSelectionMerge %17 None + OpBranchConditional %11 %18 %19 + %18 = OpLabel + %20 = OpIAdd %6 %7 %8 + %21 = OpIAdd %6 %7 %9 + OpBranch %17 + %19 = OpLabel + OpBranch %17 + %17 = OpLabel + %22 = OpPhi %6 %20 %18 %24 %19 + OpLine %3 0 0 + %23 = OpPhi %6 %21 %18 %24 %19 + OpBranch %14 + %15 = OpLabel + OpBranch %13 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, + SelectionMergeIsContinueTarget) { + // Example where a loop's continue target is also the target of a selection. + // In this scenario we cautiously do not apply the transformation. 
+ std::string shader = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %3 = OpTypeBool + %4 = OpTypeFunction %2 + %1 = OpFunction %2 None %4 + %5 = OpLabel + %6 = OpUndef %3 + OpBranch %7 + %7 = OpLabel + %8 = OpPhi %3 %6 %5 %9 %10 + OpLoopMerge %11 %10 None + OpBranch %12 + %12 = OpLabel + %13 = OpUndef %3 + OpSelectionMerge %10 None + OpBranchConditional %13 %14 %10 + %14 = OpLabel + OpBranch %10 + %10 = OpLabel + %9 = OpUndef %3 + OpBranchConditional %9 %7 %11 + %11 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + // There should be no opportunities. + ASSERT_EQ(0, ops.size()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, + SwitchSelectionMergeIsContinueTarget) { + // Another example where a loop's continue target is also the target of a + // selection; this time a selection associated with an OpSwitch. We + // cautiously do not apply the transformation. 
+ std::string shader = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %3 = OpTypeBool + %5 = OpTypeInt 32 1 + %4 = OpTypeFunction %2 + %6 = OpConstant %5 2 + %7 = OpConstantTrue %3 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %14 %15 None + OpBranchConditional %7 %10 %14 + %10 = OpLabel + OpSelectionMerge %15 None + OpSwitch %6 %12 1 %11 2 %11 3 %15 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + OpBranch %15 + %15 = OpLabel + OpBranch %9 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + // There should be no opportunities. + ASSERT_EQ(0, ops.size()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, ContinueTargetIsSwitchTarget) { + std::string shader = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %3 = OpTypeBool + %5 = OpTypeInt 32 1 + %4 = OpTypeFunction %2 + %6 = OpConstant %5 2 + %7 = OpConstantTrue %3 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %14 %12 None + OpBranchConditional %7 %10 %14 + %10 = OpLabel + OpSelectionMerge %15 None + OpSwitch %6 %12 1 %11 2 %11 3 %15 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + OpBranch %9 + %15 = OpLabel + OpBranch %14 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability 
Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %3 = OpTypeBool + %5 = OpTypeInt 32 1 + %4 = OpTypeFunction %2 + %6 = OpConstant %5 2 + %7 = OpConstantTrue %3 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %7 %10 %14 + %10 = OpLabel + OpSelectionMerge %15 None + OpSwitch %6 %15 1 %11 2 %11 3 %15 + %11 = OpLabel + OpBranch %15 + %12 = OpLabel + OpBranch %9 + %15 = OpLabel + OpBranch %14 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, + MultipleSwitchTargetsAreContinueTarget) { + std::string shader = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %3 = OpTypeBool + %5 = OpTypeInt 32 1 + %4 = OpTypeFunction %2 + %6 = OpConstant %5 2 + %7 = OpConstantTrue %3 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %14 %12 None + OpBranchConditional %7 %10 %14 + %10 = OpLabel + OpSelectionMerge %15 None + OpSwitch %6 %11 1 %12 2 %12 3 %15 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + OpBranch %9 + %15 = OpLabel + OpBranch %14 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %3 = OpTypeBool + %5 = OpTypeInt 32 1 + %4 = OpTypeFunction %2 + %6 = OpConstant %5 2 + %7 = OpConstantTrue %3 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + 
OpSelectionMerge %14 None + OpBranchConditional %7 %10 %14 + %10 = OpLabel + OpSelectionMerge %15 None + OpSwitch %6 %11 1 %15 2 %15 3 %15 + %11 = OpLabel + OpBranch %15 + %12 = OpLabel + OpBranch %9 + %15 = OpLabel + OpBranch %14 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, LoopBranchesStraightToMerge) { + /* The loop header branches unconditionally to the loop's merge block. The + * module has no bool type, so the expected output gains OpTypeBool and + * OpConstantTrue to feed the new OpBranchConditional. */ + std::string shader = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %4 = OpTypeFunction %2 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %14 %12 None + OpBranch %14 + %12 = OpLabel + OpBranch %9 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %4 = OpTypeFunction %2 + %15 = OpTypeBool + %16 = OpConstantTrue %15 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %16 %14 %14 + %12 = OpLabel + OpBranch %9 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, + LoopConditionallyJumpsToMergeOrContinue) { + std::string shader = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %3 = OpTypeBool + %4 = OpTypeFunction %2 + %7 = OpConstantTrue %3 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpLoopMerge %14 %12 None +
OpBranchConditional %7 %14 %12 + %12 = OpLabel + OpBranch %9 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + %2 = OpTypeVoid + %3 = OpTypeBool + %4 = OpTypeFunction %2 + %7 = OpConstantTrue %3 + %1 = OpFunction %2 None %4 + %8 = OpLabel + OpBranch %9 + %9 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %7 %14 %14 + %12 = OpLabel + OpBranch %9 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, MultipleAccessChains) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeStruct %6 + %8 = OpTypeStruct %7 + %9 = OpTypePointer Function %8 + %11 = OpConstant %6 3 + %12 = OpConstantComposite %7 %11 + %13 = OpConstantComposite %8 %12 + %14 = OpTypePointer Function %7 + %16 = OpConstant %6 0 + %19 = OpTypePointer Function %6 + %15 = OpTypeBool + %18 = OpConstantTrue %15 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %10 = OpVariable %9 Function + %20 = OpVariable %19 Function + OpStore %10 %13 + OpBranch %23 + %23 = OpLabel + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + OpSelectionMerge %28 None + OpBranchConditional %18 %29 %25 + %29 = OpLabel + %17 = OpAccessChain %14 %10 %16 + OpBranch %28 + %28 = OpLabel + %21 = OpAccessChain %19 %17 %16 + %22 
= OpLoad %6 %21 + %24 = OpAccessChain %19 %10 %16 %16 + OpStore %24 %22 + OpBranch %25 + %26 = OpLabel + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeStruct %6 + %8 = OpTypeStruct %7 + %9 = OpTypePointer Function %8 + %11 = OpConstant %6 3 + %12 = OpConstantComposite %7 %11 + %13 = OpConstantComposite %8 %12 + %14 = OpTypePointer Function %7 + %16 = OpConstant %6 0 + %19 = OpTypePointer Function %6 + %15 = OpTypeBool + %18 = OpConstantTrue %15 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %10 = OpVariable %9 Function + %20 = OpVariable %19 Function + %30 = OpVariable %14 Function + OpStore %10 %13 + OpBranch %23 + %23 = OpLabel + OpSelectionMerge %25 None + OpBranchConditional %18 %27 %25 + %27 = OpLabel + OpSelectionMerge %28 None + OpBranchConditional %18 %29 %28 + %29 = OpLabel + %17 = OpAccessChain %14 %10 %16 + OpBranch %28 + %28 = OpLabel + %21 = OpAccessChain %19 %30 %16 + %22 = OpLoad %6 %21 + %24 = OpAccessChain %19 %10 %16 %16 + OpStore %24 %22 + OpBranch %25 + %26 = OpLabel + OpBranch %23 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, expected, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, + UnreachableInnerLoopContinueBranchingToOuterLoopMerge) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + 
OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpLoopMerge %9 %10 None + OpBranch %11 + %11 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %12 + %13 = OpLabel + OpBranchConditional %6 %9 %11 + %12 = OpLabel + OpBranch %10 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(2, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %6 %11 %9 + %11 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %12 + %13 = OpLabel + OpBranchConditional %6 %9 %11 + %12 = OpLabel + OpBranch %9 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); + + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + + CheckValid(env, context.get()); + + std::string after_op_1 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + 
%5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %6 %11 %9 + %11 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %6 %12 %12 + %13 = OpLabel + OpBranchConditional %6 %9 %11 + %12 = OpLabel + OpBranch %9 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_1, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, + UnreachableInnerLoopContinueBranchingToOuterLoopMerge2) { + // In this test, the branch to the outer loop merge from the inner loop's + // continue is part of a structured selection. + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpLoopMerge %9 %10 None + OpBranch %11 + %11 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %12 + %13 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %6 %9 %14 + %14 = OpLabel + OpBranch %11 + %12 = OpLabel + OpBranch %10 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + const auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(2, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 
= OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %6 %11 %9 + %11 = OpLabel + OpLoopMerge %12 %13 None + OpBranch %12 + %13 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %6 %9 %14 + %14 = OpLabel + OpBranch %11 + %12 = OpLabel + OpBranch %9 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); + + ASSERT_TRUE(ops[1]->PreconditionHolds()); + ops[1]->TryToApply(); + + CheckValid(env, context.get()); + + std::string after_op_1 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %6 %11 %9 + %11 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %6 %12 %12 + %13 = OpLabel + OpSelectionMerge %14 None + OpBranchConditional %6 %9 %14 + %14 = OpLabel + OpBranch %11 + %12 = OpLabel + OpBranch %9 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_1, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, + InnerLoopHeaderBranchesToOuterLoopMerge) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpLoopMerge %9 %10 None + OpBranch %11 + %11 = OpLabel + OpLoopMerge %12 %13 None + 
OpBranchConditional %6 %9 %13 + %13 = OpLabel + OpBranchConditional %6 %11 %12 + %12 = OpLabel + OpBranch %10 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + // We cannot transform the inner loop due to its header jumping straight to + // the outer loop merge (the inner loop's merge does not post-dominate its + // header). + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string after_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %6 %11 %9 + %11 = OpLabel + OpLoopMerge %12 %13 None + OpBranchConditional %6 %12 %13 + %13 = OpLabel + OpBranchConditional %6 %11 %12 + %12 = OpLabel + OpBranch %9 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_op_0, context.get()); + + // Now look again for more opportunities. + ops = pass.WrapGetAvailableOpportunities(context.get()); + + // What was the inner loop should now be transformable, as the jump to the + // outer loop's merge has been redirected. 
+ ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + std::string after_another_op_0 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeBool + %6 = OpConstantTrue %5 + %2 = OpFunction %3 None %4 + %7 = OpLabel + OpBranch %8 + %8 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %6 %11 %9 + %11 = OpLabel + OpSelectionMerge %12 None + OpBranchConditional %6 %12 %12 + %13 = OpLabel + OpBranchConditional %6 %11 %12 + %12 = OpLabel + OpBranch %9 + %10 = OpLabel + OpBranchConditional %6 %9 %8 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + CheckEqual(env, after_another_op_0, context.get()); +} + +TEST(StructuredLoopToSelectionReductionPassTest, LongAccessChains) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %5 = OpTypeInt 32 1 + %6 = OpTypeInt 32 0 + %7 = OpConstant %6 5 + %8 = OpTypeArray %5 %7 + %9 = OpTypeStruct %8 + %10 = OpTypeStruct %9 %9 + %11 = OpConstant %6 2 + %12 = OpTypeArray %10 %11 + %13 = OpTypeStruct %12 + %14 = OpTypePointer Function %13 + %15 = OpConstant %5 0 + %16 = OpConstant %5 1 + %17 = OpConstant %5 2 + %18 = OpConstant %5 3 + %19 = OpConstant %5 4 + %20 = OpConstantComposite %8 %15 %16 %17 %18 %19 + %21 = OpConstantComposite %9 %20 + %22 = OpConstant %5 5 + %23 = OpConstant %5 6 + %24 = OpConstant %5 7 + %25 = OpConstant %5 8 + %26 = OpConstant %5 9 + %27 = OpConstantComposite %8 %22 %23 %24 %25 %26 + %28 = OpConstantComposite %9 %27 + %29 = OpConstantComposite %10 %21 %28 + %30 = OpConstant %5 10 + %31 = OpConstant %5 11 + %32 = OpConstant %5 12 + %33 = 
OpConstant %5 13 + %34 = OpConstant %5 14 + %35 = OpConstantComposite %8 %30 %31 %32 %33 %34 + %36 = OpConstantComposite %9 %35 + %37 = OpConstant %5 15 + %38 = OpConstant %5 16 + %39 = OpConstant %5 17 + %40 = OpConstant %5 18 + %41 = OpConstant %5 19 + %42 = OpConstantComposite %8 %37 %38 %39 %40 %41 + %43 = OpConstantComposite %9 %42 + %44 = OpConstantComposite %10 %36 %43 + %45 = OpConstantComposite %12 %29 %44 + %46 = OpConstantComposite %13 %45 + %47 = OpTypePointer Function %12 + %48 = OpTypePointer Function %10 + %49 = OpTypePointer Function %9 + %50 = OpTypePointer Function %8 + %51 = OpTypePointer Function %5 + %52 = OpTypeBool + %53 = OpConstantTrue %52 + %2 = OpFunction %3 None %4 + %54 = OpLabel + %55 = OpVariable %14 Function + OpStore %55 %46 + OpBranch %56 + %56 = OpLabel + OpLoopMerge %57 %58 None + OpBranchConditional %53 %57 %59 + %59 = OpLabel + OpSelectionMerge %60 None + OpBranchConditional %53 %61 %57 + %61 = OpLabel + %62 = OpAccessChain %47 %55 %15 + OpBranch %63 + %63 = OpLabel + OpSelectionMerge %64 None + OpBranchConditional %53 %65 %57 + %65 = OpLabel + %66 = OpAccessChain %48 %62 %16 + OpBranch %67 + %67 = OpLabel + OpSelectionMerge %68 None + OpBranchConditional %53 %69 %57 + %69 = OpLabel + %70 = OpAccessChain %49 %66 %16 + OpBranch %71 + %71 = OpLabel + OpSelectionMerge %72 None + OpBranchConditional %53 %73 %57 + %73 = OpLabel + %74 = OpAccessChain %50 %70 %15 + OpBranch %75 + %75 = OpLabel + OpSelectionMerge %76 None + OpBranchConditional %53 %77 %57 + %77 = OpLabel + %78 = OpAccessChain %51 %74 %17 + OpBranch %79 + %79 = OpLabel + OpSelectionMerge %80 None + OpBranchConditional %53 %81 %57 + %81 = OpLabel + %82 = OpLoad %5 %78 + OpBranch %80 + %80 = OpLabel + OpBranch %76 + %76 = OpLabel + OpBranch %72 + %72 = OpLabel + OpBranch %68 + %68 = OpLabel + OpBranch %64 + %64 = OpLabel + OpBranch %60 + %60 = OpLabel + OpBranch %58 + %58 = OpLabel + OpBranch %56 + %57 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = 
SPV_ENV_UNIVERSAL_1_3; + const auto context = BuildModule(env, nullptr, shader, kReduceAssembleOption); + const auto pass = TestSubclass(env); + auto ops = pass.WrapGetAvailableOpportunities(context.get()); + + ASSERT_EQ(1, ops.size()); + ASSERT_TRUE(ops[0]->PreconditionHolds()); + ops[0]->TryToApply(); + + CheckValid(env, context.get()); + + // TODO(2183): When we have a more general solution for handling access + // chains, write an expected result for this test. + // std::string expected = R"( + // Expected text for transformed shader + //)"; + // CheckEqual(env, expected, context.get()); +} + +} // namespace +} // namespace reduce +} // namespace spvtools diff --git a/test/reduce/validation_during_reduction_test.cpp b/test/reduce/validation_during_reduction_test.cpp new file mode 100644 index 000000000..bb7d14e10 --- /dev/null +++ b/test/reduce/validation_during_reduction_test.cpp @@ -0,0 +1,376 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "reduce_test_util.h" + +#include "source/reduce/reducer.h" +#include "source/reduce/reduction_pass.h" +#include "source/reduce/remove_instruction_reduction_opportunity.h" + +namespace spvtools { +namespace reduce { +namespace { + +// A dumb reduction pass that removes global values regardless of whether they +// are referenced. This is very likely to make the resulting module invalid. 
We +// use this to test the reducer's behavior in the scenario where a bad reduction +// pass leads to an invalid module. +class BlindlyRemoveGlobalValuesPass : public ReductionPass { + public: + // Creates the reduction pass in the context of the given target environment + // |target_env| + explicit BlindlyRemoveGlobalValuesPass(const spv_target_env target_env) + : ReductionPass(target_env) {} + + ~BlindlyRemoveGlobalValuesPass() override = default; + + // The name of this pass. + std::string GetName() const final { return "BlindlyRemoveGlobalValuesPass"; }; + + protected: + // Adds opportunities to remove all global values. Assuming they are all + // referenced (directly or indirectly) from elsewhere in the module, each such + // opportunity will make the module invalid. + std::vector> GetAvailableOpportunities( + opt::IRContext* context) const final { + std::vector> result; + for (auto& inst : context->module()->types_values()) { + if (inst.HasResultId()) { + result.push_back( + MakeUnique(&inst)); + } + } + return result; + } +}; + +TEST(ValidationDuringReductionTest, CheckInvalidPassMakesNoProgress) { + // A module whose global values are all referenced, so that any application of + // BlindlyRemoveGlobalValuesPass will make the module invalid. 
+ std::string original = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %60 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %16 "buf2" + OpMemberName %16 0 "i" + OpName %18 "" + OpName %25 "buf1" + OpMemberName %25 0 "f" + OpName %27 "" + OpName %60 "_GLF_color" + OpMemberDecorate %16 0 Offset 0 + OpDecorate %16 Block + OpDecorate %18 DescriptorSet 0 + OpDecorate %18 Binding 2 + OpMemberDecorate %25 0 Offset 0 + OpDecorate %25 Block + OpDecorate %27 DescriptorSet 0 + OpDecorate %27 Binding 1 + OpDecorate %60 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %9 = OpConstant %6 0 + %16 = OpTypeStruct %6 + %17 = OpTypePointer Uniform %16 + %18 = OpVariable %17 Uniform + %19 = OpTypePointer Uniform %6 + %22 = OpTypeBool + %24 = OpTypeFloat 32 + %25 = OpTypeStruct %24 + %26 = OpTypePointer Uniform %25 + %27 = OpVariable %26 Uniform + %28 = OpTypePointer Uniform %24 + %31 = OpConstant %24 2 + %56 = OpConstant %6 1 + %58 = OpTypeVector %24 4 + %59 = OpTypePointer Output %58 + %60 = OpVariable %59 Output + %72 = OpUndef %24 + %74 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %73 = OpPhi %6 %74 %5 %77 %34 + %71 = OpPhi %24 %72 %5 %76 %34 + %70 = OpPhi %6 %9 %5 %57 %34 + %20 = OpAccessChain %19 %18 %9 + %21 = OpLoad %6 %20 + %23 = OpSLessThan %22 %70 %21 + OpLoopMerge %12 %34 None + OpBranchConditional %23 %11 %12 + %11 = OpLabel + %29 = OpAccessChain %28 %27 %9 + %30 = OpLoad %24 %29 + %32 = OpFOrdGreaterThan %22 %30 %31 + OpSelectionMerge %34 None + OpBranchConditional %32 %33 %46 + %33 = OpLabel + %40 = OpFAdd %24 %71 %30 + %45 = OpISub %6 %73 %21 + OpBranch %34 + %46 = OpLabel + %50 = OpFMul %24 %71 %30 + %54 = OpSDiv %6 %73 %21 + OpBranch %34 + %34 = OpLabel + %77 = OpPhi %6 %45 %33 %54 %46 + %76 = OpPhi %24 %40 %33 %50 %46 + %57 = OpIAdd %6 %70 %56 + OpBranch %10 + %12 = OpLabel + 
%61 = OpAccessChain %28 %27 %9 + %62 = OpLoad %24 %61 + %66 = OpConvertSToF %24 %21 + %68 = OpConvertSToF %24 %73 + %69 = OpCompositeConstruct %58 %62 %71 %66 %68 + OpStore %60 %69 + OpReturn + OpFunctionEnd + )"; + + spv_target_env env = SPV_ENV_UNIVERSAL_1_3; + Reducer reducer(env); + reducer.SetMessageConsumer(NopDiagnostic); + + // Say that every module is interesting. + reducer.SetInterestingnessFunction( + [](const std::vector&, uint32_t) -> bool { return true; }); + + reducer.AddReductionPass(MakeUnique(env)); + + std::vector binary_in; + SpirvTools t(env); + + ASSERT_TRUE(t.Assemble(original, &binary_in, kReduceAssembleOption)); + std::vector binary_out; + spvtools::ReducerOptions reducer_options; + reducer_options.set_step_limit(500); + + reducer.Run(std::move(binary_in), &binary_out, reducer_options); + + // The reducer should have no impact. + CheckEqual(env, original, binary_out); +} + +TEST(ValidationDuringReductionTest, CheckNotAlwaysInvalidCanMakeProgress) { + // A module with just one unreferenced global value. All but one application + // of BlindlyRemoveGlobalValuesPass will make the module invalid. 
+ std::string original = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %60 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %16 "buf2" + OpMemberName %16 0 "i" + OpName %18 "" + OpName %25 "buf1" + OpMemberName %25 0 "f" + OpName %27 "" + OpName %60 "_GLF_color" + OpMemberDecorate %16 0 Offset 0 + OpDecorate %16 Block + OpDecorate %18 DescriptorSet 0 + OpDecorate %18 Binding 2 + OpMemberDecorate %25 0 Offset 0 + OpDecorate %25 Block + OpDecorate %27 DescriptorSet 0 + OpDecorate %27 Binding 1 + OpDecorate %60 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %9 = OpConstant %6 0 + %16 = OpTypeStruct %6 + %17 = OpTypePointer Uniform %16 + %18 = OpVariable %17 Uniform + %19 = OpTypePointer Uniform %6 + %22 = OpTypeBool + %24 = OpTypeFloat 32 + %25 = OpTypeStruct %24 + %26 = OpTypePointer Uniform %25 + %27 = OpVariable %26 Uniform + %28 = OpTypePointer Uniform %24 + %31 = OpConstant %24 2 + %56 = OpConstant %6 1 + %1000 = OpConstant %6 1000 ; It should be possible to remove this instruction without making the module invalid. 
+ %58 = OpTypeVector %24 4 + %59 = OpTypePointer Output %58 + %60 = OpVariable %59 Output + %72 = OpUndef %24 + %74 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %73 = OpPhi %6 %74 %5 %77 %34 + %71 = OpPhi %24 %72 %5 %76 %34 + %70 = OpPhi %6 %9 %5 %57 %34 + %20 = OpAccessChain %19 %18 %9 + %21 = OpLoad %6 %20 + %23 = OpSLessThan %22 %70 %21 + OpLoopMerge %12 %34 None + OpBranchConditional %23 %11 %12 + %11 = OpLabel + %29 = OpAccessChain %28 %27 %9 + %30 = OpLoad %24 %29 + %32 = OpFOrdGreaterThan %22 %30 %31 + OpSelectionMerge %34 None + OpBranchConditional %32 %33 %46 + %33 = OpLabel + %40 = OpFAdd %24 %71 %30 + %45 = OpISub %6 %73 %21 + OpBranch %34 + %46 = OpLabel + %50 = OpFMul %24 %71 %30 + %54 = OpSDiv %6 %73 %21 + OpBranch %34 + %34 = OpLabel + %77 = OpPhi %6 %45 %33 %54 %46 + %76 = OpPhi %24 %40 %33 %50 %46 + %57 = OpIAdd %6 %70 %56 + OpBranch %10 + %12 = OpLabel + %61 = OpAccessChain %28 %27 %9 + %62 = OpLoad %24 %61 + %66 = OpConvertSToF %24 %21 + %68 = OpConvertSToF %24 %73 + %69 = OpCompositeConstruct %58 %62 %71 %66 %68 + OpStore %60 %69 + OpReturn + OpFunctionEnd + )"; + + // This is the same as the original, except that the constant declaration of + // 1000 is gone. 
+ std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" %60 + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + OpName %16 "buf2" + OpMemberName %16 0 "i" + OpName %18 "" + OpName %25 "buf1" + OpMemberName %25 0 "f" + OpName %27 "" + OpName %60 "_GLF_color" + OpMemberDecorate %16 0 Offset 0 + OpDecorate %16 Block + OpDecorate %18 DescriptorSet 0 + OpDecorate %18 Binding 2 + OpMemberDecorate %25 0 Offset 0 + OpDecorate %25 Block + OpDecorate %27 DescriptorSet 0 + OpDecorate %27 Binding 1 + OpDecorate %60 Location 0 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %9 = OpConstant %6 0 + %16 = OpTypeStruct %6 + %17 = OpTypePointer Uniform %16 + %18 = OpVariable %17 Uniform + %19 = OpTypePointer Uniform %6 + %22 = OpTypeBool + %24 = OpTypeFloat 32 + %25 = OpTypeStruct %24 + %26 = OpTypePointer Uniform %25 + %27 = OpVariable %26 Uniform + %28 = OpTypePointer Uniform %24 + %31 = OpConstant %24 2 + %56 = OpConstant %6 1 + %58 = OpTypeVector %24 4 + %59 = OpTypePointer Output %58 + %60 = OpVariable %59 Output + %72 = OpUndef %24 + %74 = OpUndef %6 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpBranch %10 + %10 = OpLabel + %73 = OpPhi %6 %74 %5 %77 %34 + %71 = OpPhi %24 %72 %5 %76 %34 + %70 = OpPhi %6 %9 %5 %57 %34 + %20 = OpAccessChain %19 %18 %9 + %21 = OpLoad %6 %20 + %23 = OpSLessThan %22 %70 %21 + OpLoopMerge %12 %34 None + OpBranchConditional %23 %11 %12 + %11 = OpLabel + %29 = OpAccessChain %28 %27 %9 + %30 = OpLoad %24 %29 + %32 = OpFOrdGreaterThan %22 %30 %31 + OpSelectionMerge %34 None + OpBranchConditional %32 %33 %46 + %33 = OpLabel + %40 = OpFAdd %24 %71 %30 + %45 = OpISub %6 %73 %21 + OpBranch %34 + %46 = OpLabel + %50 = OpFMul %24 %71 %30 + %54 = OpSDiv %6 %73 %21 + OpBranch %34 + %34 = OpLabel + %77 = OpPhi %6 %45 %33 %54 %46 + %76 = OpPhi %24 %40 %33 %50 %46 + %57 = OpIAdd %6 %70 %56 + OpBranch %10 + %12 = OpLabel + 
%61 = OpAccessChain %28 %27 %9 + %62 = OpLoad %24 %61 + %66 = OpConvertSToF %24 %21 + %68 = OpConvertSToF %24 %73 + %69 = OpCompositeConstruct %58 %62 %71 %66 %68 + OpStore %60 %69 + OpReturn + OpFunctionEnd + )"; + + spv_target_env env = SPV_ENV_UNIVERSAL_1_3; + Reducer reducer(env); + reducer.SetMessageConsumer(NopDiagnostic); + + // Say that every module is interesting. + reducer.SetInterestingnessFunction( + [](const std::vector&, uint32_t) -> bool { return true; }); + + reducer.AddReductionPass(MakeUnique(env)); + + std::vector binary_in; + SpirvTools t(env); + + ASSERT_TRUE(t.Assemble(original, &binary_in, kReduceAssembleOption)); + std::vector binary_out; + spvtools::ReducerOptions reducer_options; + reducer_options.set_step_limit(500); + + reducer.Run(std::move(binary_in), &binary_out, reducer_options); + CheckEqual(env, expected, binary_out); +} + +} // namespace +} // namespace reduce +} // namespace spvtools diff --git a/test/scripts/test_compact_ids.py b/test/scripts/test_compact_ids.py new file mode 100644 index 000000000..b9b5b1bc0 --- /dev/null +++ b/test/scripts/test_compact_ids.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python +# Copyright (c) 2017 Google Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests correctness of opt pass tools/opt --compact-ids.""" + +from __future__ import print_function + +import os.path +import sys +import tempfile + +def test_spirv_file(path, temp_dir): + optimized_spv_path = os.path.join(temp_dir, 'optimized.spv') + optimized_dis_path = os.path.join(temp_dir, 'optimized.dis') + converted_spv_path = os.path.join(temp_dir, 'converted.spv') + converted_dis_path = os.path.join(temp_dir, 'converted.dis') + + os.system('tools/spirv-opt ' + path + ' -o ' + optimized_spv_path + + ' --compact-ids') + os.system('tools/spirv-dis ' + optimized_spv_path + ' -o ' + + optimized_dis_path) + + os.system('tools/spirv-dis ' + path + ' -o ' + converted_dis_path) + os.system('tools/spirv-as ' + converted_dis_path + ' -o ' + + converted_spv_path) + os.system('tools/spirv-dis ' + converted_spv_path + ' -o ' + + converted_dis_path) + + with open(converted_dis_path, 'r') as f: + converted_dis = f.readlines()[3:] + + with open(optimized_dis_path, 'r') as f: + optimized_dis = f.readlines()[3:] + + return converted_dis == optimized_dis + +def print_usage(): + template= \ +"""{script} tests correctness of opt pass tools/opt --compact-ids + +USAGE: python {script} [] + +Requires tools/spirv-dis, tools/spirv-as and tools/spirv-opt to be in path +(call the script from the SPIRV-Tools build output directory). 
+ +TIP: In order to test all .spv files under current dir use +find -name "*.spv" -print0 | xargs -0 -s 2000000 python {script} +""" + print(template.format(script=sys.argv[0])); + +def main(): + if not os.path.isfile('tools/spirv-dis'): + print('error: tools/spirv-dis not found') + print_usage() + exit(1) + + if not os.path.isfile('tools/spirv-as'): + print('error: tools/spirv-as not found') + print_usage() + exit(1) + + if not os.path.isfile('tools/spirv-opt'): + print('error: tools/spirv-opt not found') + print_usage() + exit(1) + + paths = sys.argv[1:] + if not paths: + print_usage() + + num_failed = 0 + + temp_dir = tempfile.mkdtemp() + + for path in paths: + success = test_spirv_file(path, temp_dir) + if not success: + print('Test failed for ' + path) + num_failed += 1 + + print('Tested ' + str(len(paths)) + ' files') + + if num_failed: + print(str(num_failed) + ' tests failed') + exit(1) + else: + print('All tests successful') + exit(0) + +if __name__ == '__main__': + main() diff --git a/test/software_version_test.cpp b/test/software_version_test.cpp new file mode 100644 index 000000000..80b944a30 --- /dev/null +++ b/test/software_version_test.cpp @@ -0,0 +1,67 @@ +// Copyright (c) 2015-2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using ::testing::AnyOf; +using ::testing::Eq; +using ::testing::Ge; +using ::testing::StartsWith; + +void CheckFormOfHighLevelVersion(const std::string& version) { + std::istringstream s(version); + char v = 'x'; + int year = -1; + char period = 'x'; + int index = -1; + s >> v >> year >> period >> index; + EXPECT_THAT(v, Eq('v')); + EXPECT_THAT(year, Ge(2016)); + EXPECT_THAT(period, Eq('.')); + EXPECT_THAT(index, Ge(0)); + EXPECT_TRUE(s.good() || s.eof()); + + std::string rest; + s >> rest; + EXPECT_THAT(rest, AnyOf("", "-dev")); +} + +TEST(SoftwareVersion, ShortIsCorrectForm) { + SCOPED_TRACE("short form"); + CheckFormOfHighLevelVersion(spvSoftwareVersionString()); +} + +TEST(SoftwareVersion, DetailedIsCorrectForm) { + const std::string detailed_version(spvSoftwareVersionDetailsString()); + EXPECT_THAT(detailed_version, StartsWith("SPIRV-Tools v")); + + // Parse the high level version. + const std::string from_v = + detailed_version.substr(detailed_version.find_first_of('v')); + const size_t first_space_after_v_or_npos = from_v.find_first_of(' '); + SCOPED_TRACE(detailed_version); + CheckFormOfHighLevelVersion(from_v.substr(0, first_space_after_v_or_npos)); + + // We don't actually care about what comes after the version number. +} + +} // namespace +} // namespace spvtools diff --git a/test/stats/CMakeLists.txt b/test/stats/CMakeLists.txt new file mode 100644 index 000000000..393cb246d --- /dev/null +++ b/test/stats/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(VAL_TEST_COMMON_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h + ${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h +) + +add_spvtools_unittest(TARGET stats + SRCS stats_aggregate_test.cpp + stats_analyzer_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../tools/stats/spirv_stats.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../tools/stats/stats_analyzer.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS ${SPIRV_TOOLS} +) diff --git a/test/stats/stats_aggregate_test.cpp b/test/stats/stats_aggregate_test.cpp new file mode 100644 index 000000000..074528543 --- /dev/null +++ b/test/stats/stats_aggregate_test.cpp @@ -0,0 +1,438 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for aggregation of SPIR-V module statistics (AggregateStats). 
+ +#include +#include + +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "tools/stats/spirv_stats.h" + +namespace spvtools { +namespace stats { +namespace { + +using spvtest::ScopedContext; + +void DiagnosticsMessageHandler(spv_message_level_t level, const char*, + const spv_position_t& position, + const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + std::cerr << "error: " << position.index << ": " << message << std::endl; + break; + case SPV_MSG_WARNING: + std::cout << "warning: " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + std::cout << "info: " << position.index << ": " << message << std::endl; + break; + default: + break; + } +} + +// Calls AggregateStats for binary compiled from |code|. +void CompileAndAggregateStats(const std::string& code, SpirvStats* stats, + spv_target_env env = SPV_ENV_UNIVERSAL_1_1) { + spvtools::Context ctx(env); + ctx.SetMessageConsumer(DiagnosticsMessageHandler); + spv_binary binary; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(ctx.CContext(), code.c_str(), + code.size(), &binary, nullptr)); + + ASSERT_EQ(SPV_SUCCESS, AggregateStats(ctx.CContext(), binary->code, + binary->wordCount, nullptr, stats)); + spvBinaryDestroy(binary); +} + +TEST(AggregateStats, CapabilityHistogram) { + const std::string code1 = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +)"; + + const std::string code2 = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; + + SpirvStats stats; + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(4u, stats.capability_hist.size()); + EXPECT_EQ(0u, stats.capability_hist.count(SpvCapabilityShader)); + EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityAddresses)); + EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityKernel)); + EXPECT_EQ(1u, 
stats.capability_hist.at(SpvCapabilityGenericPointer)); + EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityLinkage)); + + CompileAndAggregateStats(code2, &stats); + EXPECT_EQ(5u, stats.capability_hist.size()); + EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityShader)); + EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityAddresses)); + EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityKernel)); + EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityGenericPointer)); + EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityLinkage)); + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(5u, stats.capability_hist.size()); + EXPECT_EQ(1u, stats.capability_hist.at(SpvCapabilityShader)); + EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityAddresses)); + EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityKernel)); + EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityGenericPointer)); + EXPECT_EQ(3u, stats.capability_hist.at(SpvCapabilityLinkage)); + + CompileAndAggregateStats(code2, &stats); + EXPECT_EQ(5u, stats.capability_hist.size()); + EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityShader)); + EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityAddresses)); + EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityKernel)); + EXPECT_EQ(2u, stats.capability_hist.at(SpvCapabilityGenericPointer)); + EXPECT_EQ(4u, stats.capability_hist.at(SpvCapabilityLinkage)); +} + +TEST(AggregateStats, ExtensionHistogram) { + const std::string code1 = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Physical32 OpenCL +)"; + + const std::string code2 = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_NV_viewport_array2" +OpExtension "greatest_extension_ever" +OpMemoryModel Logical GLSL450 +)"; + + SpirvStats stats; + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(1u, stats.extension_hist.size()); + EXPECT_EQ(0u, 
stats.extension_hist.count("SPV_NV_viewport_array2")); + EXPECT_EQ(1u, stats.extension_hist.at("SPV_KHR_16bit_storage")); + + CompileAndAggregateStats(code2, &stats); + EXPECT_EQ(3u, stats.extension_hist.size()); + EXPECT_EQ(1u, stats.extension_hist.at("SPV_NV_viewport_array2")); + EXPECT_EQ(1u, stats.extension_hist.at("SPV_KHR_16bit_storage")); + EXPECT_EQ(1u, stats.extension_hist.at("greatest_extension_ever")); + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(3u, stats.extension_hist.size()); + EXPECT_EQ(1u, stats.extension_hist.at("SPV_NV_viewport_array2")); + EXPECT_EQ(2u, stats.extension_hist.at("SPV_KHR_16bit_storage")); + EXPECT_EQ(1u, stats.extension_hist.at("greatest_extension_ever")); + + CompileAndAggregateStats(code2, &stats); + EXPECT_EQ(3u, stats.extension_hist.size()); + EXPECT_EQ(2u, stats.extension_hist.at("SPV_NV_viewport_array2")); + EXPECT_EQ(2u, stats.extension_hist.at("SPV_KHR_16bit_storage")); + EXPECT_EQ(2u, stats.extension_hist.at("greatest_extension_ever")); +} + +TEST(AggregateStats, VersionHistogram) { + const std::string code1 = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; + + SpirvStats stats; + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(1u, stats.version_hist.size()); + EXPECT_EQ(1u, stats.version_hist.at(0x00010100)); + + CompileAndAggregateStats(code1, &stats, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(2u, stats.version_hist.size()); + EXPECT_EQ(1u, stats.version_hist.at(0x00010100)); + EXPECT_EQ(1u, stats.version_hist.at(0x00010000)); + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(2u, stats.version_hist.size()); + EXPECT_EQ(2u, stats.version_hist.at(0x00010100)); + EXPECT_EQ(1u, stats.version_hist.at(0x00010000)); + + CompileAndAggregateStats(code1, &stats, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(2u, stats.version_hist.size()); + EXPECT_EQ(2u, stats.version_hist.at(0x00010100)); + EXPECT_EQ(2u, stats.version_hist.at(0x00010000)); +} + +TEST(AggregateStats, GeneratorHistogram) 
{ + const std::string code1 = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +)"; + + const uint32_t kGeneratorKhronosAssembler = SPV_GENERATOR_KHRONOS_ASSEMBLER + << 16; + + SpirvStats stats; + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(1u, stats.generator_hist.size()); + EXPECT_EQ(1u, stats.generator_hist.at(kGeneratorKhronosAssembler)); + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(1u, stats.generator_hist.size()); + EXPECT_EQ(2u, stats.generator_hist.at(kGeneratorKhronosAssembler)); +} + +TEST(AggregateStats, OpcodeHistogram) { + const std::string code1 = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Int64 +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +%u64 = OpTypeInt 64 0 +%u32 = OpTypeInt 32 0 +%f32 = OpTypeFloat 32 +)"; + + const std::string code2 = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_NV_viewport_array2" +OpMemoryModel Logical GLSL450 +)"; + + SpirvStats stats; + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(4u, stats.opcode_hist.size()); + EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpCapability)); + EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpMemoryModel)); + EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeInt)); + EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpTypeFloat)); + + CompileAndAggregateStats(code2, &stats); + EXPECT_EQ(5u, stats.opcode_hist.size()); + EXPECT_EQ(6u, stats.opcode_hist.at(SpvOpCapability)); + EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpMemoryModel)); + EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeInt)); + EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpTypeFloat)); + EXPECT_EQ(1u, stats.opcode_hist.at(SpvOpExtension)); + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(5u, stats.opcode_hist.size()); + EXPECT_EQ(10u, stats.opcode_hist.at(SpvOpCapability)); + EXPECT_EQ(3u, stats.opcode_hist.at(SpvOpMemoryModel)); + EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpTypeInt)); + EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeFloat)); + EXPECT_EQ(1u, 
stats.opcode_hist.at(SpvOpExtension)); + + CompileAndAggregateStats(code2, &stats); + EXPECT_EQ(5u, stats.opcode_hist.size()); + EXPECT_EQ(12u, stats.opcode_hist.at(SpvOpCapability)); + EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpMemoryModel)); + EXPECT_EQ(4u, stats.opcode_hist.at(SpvOpTypeInt)); + EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpTypeFloat)); + EXPECT_EQ(2u, stats.opcode_hist.at(SpvOpExtension)); +} + +TEST(AggregateStats, OpcodeMarkovHistogram) { + const std::string code1 = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_NV_viewport_array2" +OpMemoryModel Logical GLSL450 +)"; + + const std::string code2 = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Int64 +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +%u64 = OpTypeInt 64 0 +%u32 = OpTypeInt 32 0 +%f32 = OpTypeFloat 32 +)"; + + SpirvStats stats; + stats.opcode_markov_hist.resize(2); + + CompileAndAggregateStats(code1, &stats); + ASSERT_EQ(2u, stats.opcode_markov_hist.size()); + EXPECT_EQ(2u, stats.opcode_markov_hist[0].size()); + EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpCapability).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpExtension).size()); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability)); + EXPECT_EQ(1u, + stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel)); + + EXPECT_EQ(1u, stats.opcode_markov_hist[1].size()); + EXPECT_EQ(2u, stats.opcode_markov_hist[1].at(SpvOpCapability).size()); + EXPECT_EQ(1u, + stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel)); + + CompileAndAggregateStats(code2, &stats); + ASSERT_EQ(2u, stats.opcode_markov_hist.size()); + EXPECT_EQ(4u, stats.opcode_markov_hist[0].size()); + EXPECT_EQ(3u, stats.opcode_markov_hist[0].at(SpvOpCapability).size()); + EXPECT_EQ(1u, 
stats.opcode_markov_hist[0].at(SpvOpExtension).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpMemoryModel).size()); + EXPECT_EQ(2u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).size()); + EXPECT_EQ( + 4u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpCapability)); + EXPECT_EQ(1u, + stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpExtension)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpCapability).at(SpvOpMemoryModel)); + EXPECT_EQ( + 1u, stats.opcode_markov_hist[0].at(SpvOpExtension).at(SpvOpMemoryModel)); + EXPECT_EQ(1u, + stats.opcode_markov_hist[0].at(SpvOpMemoryModel).at(SpvOpTypeInt)); + EXPECT_EQ(1u, stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeInt)); + EXPECT_EQ(1u, + stats.opcode_markov_hist[0].at(SpvOpTypeInt).at(SpvOpTypeFloat)); + + EXPECT_EQ(3u, stats.opcode_markov_hist[1].size()); + EXPECT_EQ(4u, stats.opcode_markov_hist[1].at(SpvOpCapability).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpMemoryModel).size()); + EXPECT_EQ(1u, stats.opcode_markov_hist[1].at(SpvOpTypeInt).size()); + EXPECT_EQ( + 2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpCapability)); + EXPECT_EQ(1u, + stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpExtension)); + EXPECT_EQ( + 2u, stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpMemoryModel)); + EXPECT_EQ(1u, + stats.opcode_markov_hist[1].at(SpvOpCapability).at(SpvOpTypeInt)); + EXPECT_EQ(1u, + stats.opcode_markov_hist[1].at(SpvOpMemoryModel).at(SpvOpTypeInt)); + EXPECT_EQ(1u, + stats.opcode_markov_hist[1].at(SpvOpTypeInt).at(SpvOpTypeFloat)); +} + +TEST(AggregateStats, ConstantLiteralsHistogram) { + const std::string code1 = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpCapability Float64 +OpCapability Int16 +OpCapability Int64 +OpMemoryModel Physical32 OpenCL +%u16 = OpTypeInt 16 0 +%u32 = OpTypeInt 32 0 +%u64 = OpTypeInt 64 0 +%f32 = OpTypeFloat 32 +%f64 = OpTypeFloat 64 +%1 = 
OpConstant %f32 0.1 +%2 = OpConstant %f32 -2 +%3 = OpConstant %f64 -2 +%4 = OpConstant %u16 16 +%5 = OpConstant %u16 2 +%6 = OpConstant %u32 32 +%7 = OpConstant %u64 64 +)"; + + const std::string code2 = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Int16 +OpCapability Int64 +OpMemoryModel Logical GLSL450 +%f32 = OpTypeFloat 32 +%u16 = OpTypeInt 16 0 +%s16 = OpTypeInt 16 1 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%u64 = OpTypeInt 64 0 +%s64 = OpTypeInt 64 1 +%1 = OpConstant %f32 0.1 +%2 = OpConstant %f32 -2 +%3 = OpConstant %u16 1 +%4 = OpConstant %u16 16 +%5 = OpConstant %u16 2 +%6 = OpConstant %s16 -16 +%7 = OpConstant %u32 32 +%8 = OpConstant %s32 2 +%9 = OpConstant %s32 -32 +%10 = OpConstant %u64 64 +%11 = OpConstant %s64 -64 +)"; + + SpirvStats stats; + + CompileAndAggregateStats(code1, &stats); + EXPECT_EQ(2u, stats.f32_constant_hist.size()); + EXPECT_EQ(1u, stats.f64_constant_hist.size()); + EXPECT_EQ(1u, stats.f32_constant_hist.at(0.1f)); + EXPECT_EQ(1u, stats.f32_constant_hist.at(-2.f)); + EXPECT_EQ(1u, stats.f64_constant_hist.at(-2)); + + EXPECT_EQ(2u, stats.u16_constant_hist.size()); + EXPECT_EQ(0u, stats.s16_constant_hist.size()); + EXPECT_EQ(1u, stats.u32_constant_hist.size()); + EXPECT_EQ(0u, stats.s32_constant_hist.size()); + EXPECT_EQ(1u, stats.u64_constant_hist.size()); + EXPECT_EQ(0u, stats.s64_constant_hist.size()); + EXPECT_EQ(1u, stats.u16_constant_hist.at(16)); + EXPECT_EQ(1u, stats.u16_constant_hist.at(2)); + EXPECT_EQ(1u, stats.u32_constant_hist.at(32)); + EXPECT_EQ(1u, stats.u64_constant_hist.at(64)); + + CompileAndAggregateStats(code2, &stats); + EXPECT_EQ(2u, stats.f32_constant_hist.size()); + EXPECT_EQ(1u, stats.f64_constant_hist.size()); + EXPECT_EQ(2u, stats.f32_constant_hist.at(0.1f)); + EXPECT_EQ(2u, stats.f32_constant_hist.at(-2.f)); + EXPECT_EQ(1u, stats.f64_constant_hist.at(-2)); + + EXPECT_EQ(3u, stats.u16_constant_hist.size()); + EXPECT_EQ(1u, stats.s16_constant_hist.size()); + EXPECT_EQ(1u, 
stats.u32_constant_hist.size()); + EXPECT_EQ(2u, stats.s32_constant_hist.size()); + EXPECT_EQ(1u, stats.u64_constant_hist.size()); + EXPECT_EQ(1u, stats.s64_constant_hist.size()); + EXPECT_EQ(2u, stats.u16_constant_hist.at(16)); + EXPECT_EQ(2u, stats.u16_constant_hist.at(2)); + EXPECT_EQ(1u, stats.u16_constant_hist.at(1)); + EXPECT_EQ(1u, stats.s16_constant_hist.at(-16)); + EXPECT_EQ(2u, stats.u32_constant_hist.at(32)); + EXPECT_EQ(1u, stats.s32_constant_hist.at(2)); + EXPECT_EQ(1u, stats.s32_constant_hist.at(-32)); + EXPECT_EQ(2u, stats.u64_constant_hist.at(64)); + EXPECT_EQ(1u, stats.s64_constant_hist.at(-64)); +} + +} // namespace +} // namespace stats +} // namespace spvtools diff --git a/test/stats/stats_analyzer_test.cpp b/test/stats/stats_analyzer_test.cpp new file mode 100644 index 000000000..3764c5bdd --- /dev/null +++ b/test/stats/stats_analyzer_test.cpp @@ -0,0 +1,174 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for the StatsAnalyzer report writers. + +#include +#include + +#include "source/latest_version_spirv_header.h" +#include "test/test_fixture.h" +#include "tools/stats/stats_analyzer.h" + +namespace spvtools { +namespace stats { +namespace { + +// Fills |stats| with some synthetic header stats, as if aggregated from 100 +// modules (100 used for simpler percentage evaluation). 
+void FillDefaultStats(SpirvStats* stats) { + *stats = SpirvStats(); + stats->version_hist[0x00010000] = 40; + stats->version_hist[0x00010100] = 60; + stats->generator_hist[0x00000000] = 64; + stats->generator_hist[0x00010000] = 1; + stats->generator_hist[0x00020000] = 2; + stats->generator_hist[0x00030000] = 3; + stats->generator_hist[0x00040000] = 4; + stats->generator_hist[0x00050000] = 5; + stats->generator_hist[0x00060000] = 6; + stats->generator_hist[0x00070000] = 7; + stats->generator_hist[0x00080000] = 8; + + int num_version_entries = 0; + for (const auto& pair : stats->version_hist) { + num_version_entries += pair.second; + } + + int num_generator_entries = 0; + for (const auto& pair : stats->generator_hist) { + num_generator_entries += pair.second; + } + + EXPECT_EQ(num_version_entries, num_generator_entries); +} + +TEST(StatsAnalyzer, Version) { + SpirvStats stats; + FillDefaultStats(&stats); + + StatsAnalyzer analyzer(stats); + + std::stringstream ss; + analyzer.WriteVersion(ss); + const std::string output = ss.str(); + const std::string expected_output = "Version 1.1 60%\nVersion 1.0 40%\n"; + + EXPECT_EQ(expected_output, output); +} + +TEST(StatsAnalyzer, Generator) { + SpirvStats stats; + FillDefaultStats(&stats); + + StatsAnalyzer analyzer(stats); + + std::stringstream ss; + analyzer.WriteGenerator(ss); + const std::string output = ss.str(); + const std::string expected_output = + "Khronos 64%\nKhronos Glslang Reference Front End 8%\n" + "Khronos SPIR-V Tools Assembler 7%\nKhronos LLVM/SPIR-V Translator 6%" + "\nARM 5%\nNVIDIA 4%\nCodeplay 3%\nValve 2%\nLunarG 1%\n"; + + EXPECT_EQ(expected_output, output); +} + +TEST(StatsAnalyzer, Capability) { + SpirvStats stats; + FillDefaultStats(&stats); + + stats.capability_hist[SpvCapabilityShader] = 25; + stats.capability_hist[SpvCapabilityKernel] = 75; + + StatsAnalyzer analyzer(stats); + + std::stringstream ss; + analyzer.WriteCapability(ss); + const std::string output = ss.str(); + const std::string 
expected_output = "Kernel 75%\nShader 25%\n"; + + EXPECT_EQ(expected_output, output); +} + +TEST(StatsAnalyzer, Extension) { + SpirvStats stats; + FillDefaultStats(&stats); + + stats.extension_hist["greatest_extension_ever"] = 1; + stats.extension_hist["worst_extension_ever"] = 10; + + StatsAnalyzer analyzer(stats); + + std::stringstream ss; + analyzer.WriteExtension(ss); + const std::string output = ss.str(); + const std::string expected_output = + "worst_extension_ever 10%\ngreatest_extension_ever 1%\n"; + + EXPECT_EQ(expected_output, output); +} + +TEST(StatsAnalyzer, Opcode) { + SpirvStats stats; + FillDefaultStats(&stats); + + stats.opcode_hist[SpvOpCapability] = 20; + stats.opcode_hist[SpvOpConstant] = 80; + stats.opcode_hist[SpvOpDecorate] = 100; + + StatsAnalyzer analyzer(stats); + + std::stringstream ss; + analyzer.WriteOpcode(ss); + const std::string output = ss.str(); + const std::string expected_output = + "Total unique opcodes used: 3\nDecorate 50%\n" + "Constant 40%\nCapability 10%\n"; + + EXPECT_EQ(expected_output, output); +} + +TEST(StatsAnalyzer, OpcodeMarkov) { + SpirvStats stats; + FillDefaultStats(&stats); + + stats.opcode_hist[SpvOpFMul] = 400; + stats.opcode_hist[SpvOpFAdd] = 200; + stats.opcode_hist[SpvOpFSub] = 400; + + stats.opcode_markov_hist.resize(1); + auto& hist = stats.opcode_markov_hist[0]; + hist[SpvOpFMul][SpvOpFAdd] = 100; + hist[SpvOpFMul][SpvOpFSub] = 300; + hist[SpvOpFAdd][SpvOpFMul] = 100; + hist[SpvOpFAdd][SpvOpFAdd] = 100; + + StatsAnalyzer analyzer(stats); + + std::stringstream ss; + analyzer.WriteOpcodeMarkov(ss); + const std::string output = ss.str(); + const std::string expected_output = + "FMul -> FSub 75% (base rate 40%, pair occurrences 300)\n" + "FMul -> FAdd 25% (base rate 20%, pair occurrences 100)\n" + "FAdd -> FAdd 50% (base rate 20%, pair occurrences 100)\n" + "FAdd -> FMul 50% (base rate 40%, pair occurrences 100)\n"; + + EXPECT_EQ(expected_output, output); +} + +} // namespace +} // namespace stats +} // 
namespace spvtools diff --git a/test/string_utils_test.cpp b/test/string_utils_test.cpp new file mode 100644 index 000000000..58514158f --- /dev/null +++ b/test/string_utils_test.cpp @@ -0,0 +1,191 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "source/util/string_utils.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace utils { +namespace { + +TEST(ToString, Int) { + EXPECT_EQ("0", ToString(0)); + EXPECT_EQ("1000", ToString(1000)); + EXPECT_EQ("-1", ToString(-1)); + EXPECT_EQ("0", ToString(0LL)); + EXPECT_EQ("1000", ToString(1000LL)); + EXPECT_EQ("-1", ToString(-1LL)); +} + +TEST(ToString, Uint) { + EXPECT_EQ("0", ToString(0U)); + EXPECT_EQ("1000", ToString(1000U)); + EXPECT_EQ("0", ToString(0ULL)); + EXPECT_EQ("1000", ToString(1000ULL)); +} + +TEST(ToString, Float) { + EXPECT_EQ("0", ToString(0.f)); + EXPECT_EQ("1000", ToString(1000.f)); + EXPECT_EQ("-1.5", ToString(-1.5f)); +} + +TEST(ToString, Double) { + EXPECT_EQ("0", ToString(0.)); + EXPECT_EQ("1000", ToString(1000.)); + EXPECT_EQ("-1.5", ToString(-1.5)); +} + +TEST(CardinalToOrdinal, Test) { + EXPECT_EQ("1st", CardinalToOrdinal(1)); + EXPECT_EQ("2nd", CardinalToOrdinal(2)); + EXPECT_EQ("3rd", CardinalToOrdinal(3)); + EXPECT_EQ("4th", CardinalToOrdinal(4)); + EXPECT_EQ("5th", CardinalToOrdinal(5)); + EXPECT_EQ("6th", CardinalToOrdinal(6)); + EXPECT_EQ("7th", CardinalToOrdinal(7)); + 
EXPECT_EQ("8th", CardinalToOrdinal(8)); + EXPECT_EQ("9th", CardinalToOrdinal(9)); + EXPECT_EQ("10th", CardinalToOrdinal(10)); + EXPECT_EQ("11th", CardinalToOrdinal(11)); + EXPECT_EQ("12th", CardinalToOrdinal(12)); + EXPECT_EQ("13th", CardinalToOrdinal(13)); + EXPECT_EQ("14th", CardinalToOrdinal(14)); + EXPECT_EQ("15th", CardinalToOrdinal(15)); + EXPECT_EQ("16th", CardinalToOrdinal(16)); + EXPECT_EQ("17th", CardinalToOrdinal(17)); + EXPECT_EQ("18th", CardinalToOrdinal(18)); + EXPECT_EQ("19th", CardinalToOrdinal(19)); + EXPECT_EQ("20th", CardinalToOrdinal(20)); + EXPECT_EQ("21st", CardinalToOrdinal(21)); + EXPECT_EQ("22nd", CardinalToOrdinal(22)); + EXPECT_EQ("23rd", CardinalToOrdinal(23)); + EXPECT_EQ("24th", CardinalToOrdinal(24)); + EXPECT_EQ("25th", CardinalToOrdinal(25)); + EXPECT_EQ("26th", CardinalToOrdinal(26)); + EXPECT_EQ("27th", CardinalToOrdinal(27)); + EXPECT_EQ("28th", CardinalToOrdinal(28)); + EXPECT_EQ("29th", CardinalToOrdinal(29)); + EXPECT_EQ("30th", CardinalToOrdinal(30)); + EXPECT_EQ("31st", CardinalToOrdinal(31)); + EXPECT_EQ("32nd", CardinalToOrdinal(32)); + EXPECT_EQ("33rd", CardinalToOrdinal(33)); + EXPECT_EQ("34th", CardinalToOrdinal(34)); + EXPECT_EQ("35th", CardinalToOrdinal(35)); + EXPECT_EQ("100th", CardinalToOrdinal(100)); + EXPECT_EQ("101st", CardinalToOrdinal(101)); + EXPECT_EQ("102nd", CardinalToOrdinal(102)); + EXPECT_EQ("103rd", CardinalToOrdinal(103)); + EXPECT_EQ("104th", CardinalToOrdinal(104)); + EXPECT_EQ("105th", CardinalToOrdinal(105)); + EXPECT_EQ("106th", CardinalToOrdinal(106)); + EXPECT_EQ("107th", CardinalToOrdinal(107)); + EXPECT_EQ("108th", CardinalToOrdinal(108)); + EXPECT_EQ("109th", CardinalToOrdinal(109)); + EXPECT_EQ("110th", CardinalToOrdinal(110)); + EXPECT_EQ("111th", CardinalToOrdinal(111)); + EXPECT_EQ("112th", CardinalToOrdinal(112)); + EXPECT_EQ("113th", CardinalToOrdinal(113)); + EXPECT_EQ("114th", CardinalToOrdinal(114)); + EXPECT_EQ("115th", CardinalToOrdinal(115)); + EXPECT_EQ("116th", 
CardinalToOrdinal(116)); + EXPECT_EQ("117th", CardinalToOrdinal(117)); + EXPECT_EQ("118th", CardinalToOrdinal(118)); + EXPECT_EQ("119th", CardinalToOrdinal(119)); + EXPECT_EQ("120th", CardinalToOrdinal(120)); + EXPECT_EQ("121st", CardinalToOrdinal(121)); + EXPECT_EQ("122nd", CardinalToOrdinal(122)); + EXPECT_EQ("123rd", CardinalToOrdinal(123)); + EXPECT_EQ("124th", CardinalToOrdinal(124)); + EXPECT_EQ("125th", CardinalToOrdinal(125)); + EXPECT_EQ("126th", CardinalToOrdinal(126)); + EXPECT_EQ("127th", CardinalToOrdinal(127)); + EXPECT_EQ("128th", CardinalToOrdinal(128)); + EXPECT_EQ("129th", CardinalToOrdinal(129)); + EXPECT_EQ("130th", CardinalToOrdinal(130)); + EXPECT_EQ("131st", CardinalToOrdinal(131)); + EXPECT_EQ("132nd", CardinalToOrdinal(132)); + EXPECT_EQ("133rd", CardinalToOrdinal(133)); + EXPECT_EQ("134th", CardinalToOrdinal(134)); + EXPECT_EQ("135th", CardinalToOrdinal(135)); + EXPECT_EQ("1000th", CardinalToOrdinal(1000)); + EXPECT_EQ("1001st", CardinalToOrdinal(1001)); + EXPECT_EQ("1002nd", CardinalToOrdinal(1002)); + EXPECT_EQ("1003rd", CardinalToOrdinal(1003)); + EXPECT_EQ("1004th", CardinalToOrdinal(1004)); + EXPECT_EQ("1005th", CardinalToOrdinal(1005)); + EXPECT_EQ("1006th", CardinalToOrdinal(1006)); + EXPECT_EQ("1007th", CardinalToOrdinal(1007)); + EXPECT_EQ("1008th", CardinalToOrdinal(1008)); + EXPECT_EQ("1009th", CardinalToOrdinal(1009)); + EXPECT_EQ("1010th", CardinalToOrdinal(1010)); + EXPECT_EQ("1011th", CardinalToOrdinal(1011)); + EXPECT_EQ("1012th", CardinalToOrdinal(1012)); + EXPECT_EQ("1013th", CardinalToOrdinal(1013)); + EXPECT_EQ("1014th", CardinalToOrdinal(1014)); + EXPECT_EQ("1015th", CardinalToOrdinal(1015)); + EXPECT_EQ("1016th", CardinalToOrdinal(1016)); + EXPECT_EQ("1017th", CardinalToOrdinal(1017)); + EXPECT_EQ("1018th", CardinalToOrdinal(1018)); + EXPECT_EQ("1019th", CardinalToOrdinal(1019)); + EXPECT_EQ("1020th", CardinalToOrdinal(1020)); + EXPECT_EQ("1021st", CardinalToOrdinal(1021)); + EXPECT_EQ("1022nd", 
CardinalToOrdinal(1022)); + EXPECT_EQ("1023rd", CardinalToOrdinal(1023)); + EXPECT_EQ("1024th", CardinalToOrdinal(1024)); + EXPECT_EQ("1025th", CardinalToOrdinal(1025)); + EXPECT_EQ("1026th", CardinalToOrdinal(1026)); + EXPECT_EQ("1027th", CardinalToOrdinal(1027)); + EXPECT_EQ("1028th", CardinalToOrdinal(1028)); + EXPECT_EQ("1029th", CardinalToOrdinal(1029)); + EXPECT_EQ("1030th", CardinalToOrdinal(1030)); + EXPECT_EQ("1031st", CardinalToOrdinal(1031)); + EXPECT_EQ("1032nd", CardinalToOrdinal(1032)); + EXPECT_EQ("1033rd", CardinalToOrdinal(1033)); + EXPECT_EQ("1034th", CardinalToOrdinal(1034)); + EXPECT_EQ("1035th", CardinalToOrdinal(1035)); + EXPECT_EQ("1200th", CardinalToOrdinal(1200)); + EXPECT_EQ("1201st", CardinalToOrdinal(1201)); + EXPECT_EQ("1202nd", CardinalToOrdinal(1202)); + EXPECT_EQ("1203rd", CardinalToOrdinal(1203)); + EXPECT_EQ("1204th", CardinalToOrdinal(1204)); + EXPECT_EQ("1205th", CardinalToOrdinal(1205)); + EXPECT_EQ("1206th", CardinalToOrdinal(1206)); + EXPECT_EQ("1207th", CardinalToOrdinal(1207)); + EXPECT_EQ("1208th", CardinalToOrdinal(1208)); + EXPECT_EQ("1209th", CardinalToOrdinal(1209)); + EXPECT_EQ("1210th", CardinalToOrdinal(1210)); + EXPECT_EQ("1211th", CardinalToOrdinal(1211)); + EXPECT_EQ("1212th", CardinalToOrdinal(1212)); + EXPECT_EQ("1213th", CardinalToOrdinal(1213)); + EXPECT_EQ("1214th", CardinalToOrdinal(1214)); + EXPECT_EQ("1215th", CardinalToOrdinal(1215)); + EXPECT_EQ("1216th", CardinalToOrdinal(1216)); + EXPECT_EQ("1217th", CardinalToOrdinal(1217)); + EXPECT_EQ("1218th", CardinalToOrdinal(1218)); + EXPECT_EQ("1219th", CardinalToOrdinal(1219)); + EXPECT_EQ("1220th", CardinalToOrdinal(1220)); + EXPECT_EQ("1221st", CardinalToOrdinal(1221)); + EXPECT_EQ("1222nd", CardinalToOrdinal(1222)); + EXPECT_EQ("1223rd", CardinalToOrdinal(1223)); + EXPECT_EQ("1224th", CardinalToOrdinal(1224)); + EXPECT_EQ("1225th", CardinalToOrdinal(1225)); +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git 
a/test/target_env_test.cpp b/test/target_env_test.cpp new file mode 100644 index 000000000..f9624646d --- /dev/null +++ b/test/target_env_test.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "source/spirv_target_env.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using ::testing::AnyOf; +using ::testing::Eq; +using ::testing::StartsWith; +using ::testing::ValuesIn; + +using TargetEnvTest = ::testing::TestWithParam; +TEST_P(TargetEnvTest, CreateContext) { + spv_target_env env = GetParam(); + spv_context context = spvContextCreate(env); + ASSERT_NE(nullptr, context); + spvContextDestroy(context); // Avoid leaking +} + +TEST_P(TargetEnvTest, ValidDescription) { + const char* description = spvTargetEnvDescription(GetParam()); + ASSERT_NE(nullptr, description); + ASSERT_THAT(description, StartsWith("SPIR-V ")); +} + +TEST_P(TargetEnvTest, ValidSpirvVersion) { + auto spirv_version = spvVersionForTargetEnv(GetParam()); + ASSERT_THAT(spirv_version, AnyOf(0x10000, 0x10100, 0x10200, 0x10300)); +} + +INSTANTIATE_TEST_CASE_P(AllTargetEnvs, TargetEnvTest, + ValuesIn(spvtest::AllTargetEnvironments())); + +TEST(GetContextTest, InvalidTargetEnvProducesNull) { + // Use a value beyond the last valid enum value. 
+ spv_context context = spvContextCreate(static_cast(30)); + EXPECT_EQ(context, nullptr); +} + +// A test case for parsing an environment string. +struct ParseCase { + const char* input; + bool success; // Expect to successfully parse? + spv_target_env env; // The parsed environment, if successful. +}; + +using TargetParseTest = ::testing::TestWithParam; + +TEST_P(TargetParseTest, InvalidTargetEnvProducesNull) { + spv_target_env env; + bool parsed = spvParseTargetEnv(GetParam().input, &env); + EXPECT_THAT(parsed, Eq(GetParam().success)); + EXPECT_THAT(env, Eq(GetParam().env)); +} + +INSTANTIATE_TEST_CASE_P( + TargetParsing, TargetParseTest, + ValuesIn(std::vector{ + {"spv1.0", true, SPV_ENV_UNIVERSAL_1_0}, + {"spv1.1", true, SPV_ENV_UNIVERSAL_1_1}, + {"spv1.2", true, SPV_ENV_UNIVERSAL_1_2}, + {"spv1.3", true, SPV_ENV_UNIVERSAL_1_3}, + {"vulkan1.0", true, SPV_ENV_VULKAN_1_0}, + {"vulkan1.1", true, SPV_ENV_VULKAN_1_1}, + {"opencl2.1", true, SPV_ENV_OPENCL_2_1}, + {"opencl2.2", true, SPV_ENV_OPENCL_2_2}, + {"opengl4.0", true, SPV_ENV_OPENGL_4_0}, + {"opengl4.1", true, SPV_ENV_OPENGL_4_1}, + {"opengl4.2", true, SPV_ENV_OPENGL_4_2}, + {"opengl4.3", true, SPV_ENV_OPENGL_4_3}, + {"opengl4.5", true, SPV_ENV_OPENGL_4_5}, + {"opencl1.2", true, SPV_ENV_OPENCL_1_2}, + {"opencl1.2embedded", true, SPV_ENV_OPENCL_EMBEDDED_1_2}, + {"opencl2.0", true, SPV_ENV_OPENCL_2_0}, + {"opencl2.0embedded", true, SPV_ENV_OPENCL_EMBEDDED_2_0}, + {"opencl2.1embedded", true, SPV_ENV_OPENCL_EMBEDDED_2_1}, + {"opencl2.2embedded", true, SPV_ENV_OPENCL_EMBEDDED_2_2}, + {"webgpu0", true, SPV_ENV_WEBGPU_0}, + {"opencl2.3", false, SPV_ENV_UNIVERSAL_1_0}, + {"opencl3.0", false, SPV_ENV_UNIVERSAL_1_0}, + {"vulkan1.2", false, SPV_ENV_UNIVERSAL_1_0}, + {"vulkan2.0", false, SPV_ENV_UNIVERSAL_1_0}, + {nullptr, false, SPV_ENV_UNIVERSAL_1_0}, + {"", false, SPV_ENV_UNIVERSAL_1_0}, + {"abc", false, SPV_ENV_UNIVERSAL_1_0}, + })); + +} // namespace +} // namespace spvtools diff --git a/test/test_fixture.h 
b/test/test_fixture.h new file mode 100644 index 000000000..436993e21 --- /dev/null +++ b/test/test_fixture.h @@ -0,0 +1,198 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TEST_TEST_FIXTURE_H_ +#define TEST_TEST_FIXTURE_H_ + +#include +#include + +#include "test/unit_spirv.h" + +namespace spvtest { + +// RAII for spv_context. +struct ScopedContext { + ScopedContext(spv_target_env env = SPV_ENV_UNIVERSAL_1_0) + : context(spvContextCreate(env)) {} + ~ScopedContext() { spvContextDestroy(context); } + spv_context context; +}; + +// Common setup for TextToBinary tests. SetText() should be called to populate +// the actual test text. +template +class TextToBinaryTestBase : public T { + public: + // Shorthand for SPIR-V compilation result. + using SpirvVector = std::vector; + + // Offset into a SpirvVector at which the first instruction starts. + static const SpirvVector::size_type kFirstInstruction = 5; + + TextToBinaryTestBase() : diagnostic(nullptr), text(), binary(nullptr) { + char textStr[] = "substitute the text member variable with your test"; + text = {textStr, strlen(textStr)}; + } + + virtual ~TextToBinaryTestBase() { + DestroyBinary(); + if (diagnostic) spvDiagnosticDestroy(diagnostic); + } + + // Returns subvector v[from:end). 
+  SpirvVector Subvector(const SpirvVector& v, SpirvVector::size_type from) {
+    // |from| may equal v.size(), in which case the result is empty.
+    assert(from <= v.size());
+    return SpirvVector(v.begin() + from, v.end());
+  }
+
+  // Compiles SPIR-V text in the given assembly syntax format, asserting
+  // compilation success. Returns the compiled code.
+  // On failure the diagnostic is printed and an empty vector is returned; the
+  // diagnostic stays in |diagnostic| for the caller to inspect.
+  SpirvVector CompileSuccessfully(const std::string& txt,
+                                  spv_target_env env = SPV_ENV_UNIVERSAL_1_0) {
+    // Drop any state left over from a previous call on this fixture.
+    DestroyBinary();
+    DestroyDiagnostic();
+    spv_result_t status =
+        spvTextToBinary(ScopedContext(env).context, txt.c_str(), txt.size(),
+                        &binary, &diagnostic);
+    EXPECT_EQ(SPV_SUCCESS, status) << txt;
+    SpirvVector code_copy;
+    if (status == SPV_SUCCESS) {
+      // Copy the words out so the binary can be released immediately.
+      code_copy = SpirvVector(binary->code, binary->code + binary->wordCount);
+      DestroyBinary();
+    } else {
+      spvDiagnosticPrint(diagnostic);
+    }
+    return code_copy;
+  }
+
+  // Compiles SPIR-V text with the given format, asserting compilation failure.
+  // Returns the error message(s), or the empty string when no diagnostic was
+  // produced (for example because compilation unexpectedly succeeded).
+  std::string CompileFailure(const std::string& txt,
+                             spv_target_env env = SPV_ENV_UNIVERSAL_1_0) {
+    DestroyBinary();
+    DestroyDiagnostic();
+    EXPECT_NE(SPV_SUCCESS,
+              spvTextToBinary(ScopedContext(env).context, txt.c_str(),
+                              txt.size(), &binary, &diagnostic))
+        << txt;
+    DestroyBinary();
+    // The expectation above is non-fatal, so execution reaches this point even
+    // when no diagnostic exists; never dereference a null diagnostic.
+    if (!diagnostic) return "";
+    return diagnostic->error;
+  }
+
+  // Encodes SPIR-V text into binary and then decodes the binary using
+  // given options. Returns the decoded text.
+  std::string EncodeAndDecodeSuccessfully(
+      const std::string& txt,
+      uint32_t disassemble_options = SPV_BINARY_TO_TEXT_OPTION_NONE,
+      spv_target_env env = SPV_ENV_UNIVERSAL_1_0) {
+    DestroyBinary();
+    DestroyDiagnostic();
+    ScopedContext context(env);
+    // The header is always suppressed in the decoded text.
+    disassemble_options |= SPV_BINARY_TO_TEXT_OPTION_NO_HEADER;
+    spv_result_t error = spvTextToBinary(context.context, txt.c_str(),
+                                         txt.size(), &binary, &diagnostic);
+    if (error) {
+      spvDiagnosticPrint(diagnostic);
+      // DestroyDiagnostic() also resets the member, so neither the destructor
+      // nor a later call destroys the same diagnostic a second time.
+      DestroyDiagnostic();
+    }
+    EXPECT_EQ(SPV_SUCCESS, error);
+    if (!binary) return "";
+
+    // Start from null so a failed decode cannot leave an indeterminate pointer.
+    spv_text decoded_text = nullptr;
+    error = spvBinaryToText(context.context, binary->code, binary->wordCount,
+                            disassemble_options, &decoded_text, &diagnostic);
+    if (error) {
+      spvDiagnosticPrint(diagnostic);
+      DestroyDiagnostic();
+    }
+    EXPECT_EQ(SPV_SUCCESS, error) << txt;
+    // The expectation is non-fatal: bail out instead of reading absent text.
+    if (!decoded_text) return "";
+
+    const std::string decoded_string = decoded_text->str;
+    spvTextDestroy(decoded_text);
+
+    return decoded_string;
+  }
+
+  // Encodes SPIR-V text into binary. This is expected to succeed.
+  // The given words are then appended to the binary, and the result
+  // is then decoded. This is expected to fail.
+  // Returns the error message, or the empty string when decoding produced no
+  // diagnostic.
+  std::string EncodeSuccessfullyDecodeFailed(
+      const std::string& txt, const SpirvVector& words_to_append) {
+    DestroyBinary();
+    DestroyDiagnostic();
+    SpirvVector code =
+        spvtest::Concatenate({CompileSuccessfully(txt), words_to_append});
+
+    spv_text decoded_text = nullptr;
+    EXPECT_NE(SPV_SUCCESS,
+              spvBinaryToText(ScopedContext().context, code.data(), code.size(),
+                              SPV_BINARY_TO_TEXT_OPTION_NONE, &decoded_text,
+                              &diagnostic));
+    // If decoding unexpectedly succeeded, release the text it produced.
+    if (decoded_text) spvTextDestroy(decoded_text);
+    if (diagnostic) {
+      std::string error_message = diagnostic->error;
+      DestroyDiagnostic();
+      return error_message;
+    }
+    return "";
+  }
+
+  // Compiles SPIR-V text, asserts success, and returns the words representing
+  // the instructions. In particular, skip the words in the SPIR-V header.
+ SpirvVector CompiledInstructions(const std::string& txt, + spv_target_env env = SPV_ENV_UNIVERSAL_1_0) { + const SpirvVector code = CompileSuccessfully(txt, env); + SpirvVector result; + // Extract just the instructions. + // If the code fails to compile, then return the empty vector. + // In any case, don't crash or invoke undefined behaviour. + if (code.size() >= kFirstInstruction) + result = Subvector(code, kFirstInstruction); + return result; + } + + void SetText(const std::string& code) { + textString = code; + text.str = textString.c_str(); + text.length = textString.size(); + } + + // Destroys the binary, if it exists. + void DestroyBinary() { + spvBinaryDestroy(binary); + binary = nullptr; + } + + // Destroys the diagnostic, if it exists. + void DestroyDiagnostic() { + spvDiagnosticDestroy(diagnostic); + diagnostic = nullptr; + } + + spv_diagnostic diagnostic; + + std::string textString; + spv_text_t text; + spv_binary binary; +}; + +using TextToBinaryTest = TextToBinaryTestBase<::testing::Test>; +} // namespace spvtest + +using RoundTripTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +#endif // TEST_TEST_FIXTURE_H_ diff --git a/test/text_advance_test.cpp b/test/text_advance_test.cpp new file mode 100644 index 000000000..9de77a836 --- /dev/null +++ b/test/text_advance_test.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::AutoText; + +TEST(TextAdvance, LeadingNewLines) { + AutoText input("\n\nWord"); + AssemblyContext data(input, nullptr); + ASSERT_EQ(SPV_SUCCESS, data.advance()); + ASSERT_EQ(0u, data.position().column); + ASSERT_EQ(2u, data.position().line); + ASSERT_EQ(2u, data.position().index); +} + +TEST(TextAdvance, LeadingSpaces) { + AutoText input(" Word"); + AssemblyContext data(input, nullptr); + ASSERT_EQ(SPV_SUCCESS, data.advance()); + ASSERT_EQ(4u, data.position().column); + ASSERT_EQ(0u, data.position().line); + ASSERT_EQ(4u, data.position().index); +} + +TEST(TextAdvance, LeadingTabs) { + AutoText input("\t\t\tWord"); + AssemblyContext data(input, nullptr); + ASSERT_EQ(SPV_SUCCESS, data.advance()); + ASSERT_EQ(3u, data.position().column); + ASSERT_EQ(0u, data.position().line); + ASSERT_EQ(3u, data.position().index); +} + +TEST(TextAdvance, LeadingNewLinesSpacesAndTabs) { + AutoText input("\n\n\t Word"); + AssemblyContext data(input, nullptr); + ASSERT_EQ(SPV_SUCCESS, data.advance()); + ASSERT_EQ(3u, data.position().column); + ASSERT_EQ(2u, data.position().line); + ASSERT_EQ(5u, data.position().index); +} + +TEST(TextAdvance, LeadingWhitespaceAfterCommentLine) { + AutoText input("; comment\n \t \tWord"); + AssemblyContext data(input, nullptr); + ASSERT_EQ(SPV_SUCCESS, data.advance()); + ASSERT_EQ(4u, data.position().column); + ASSERT_EQ(1u, data.position().line); + ASSERT_EQ(14u, data.position().index); +} + +TEST(TextAdvance, EOFAfterCommentLine) { + AutoText input("; comment"); + AssemblyContext data(input, nullptr); + ASSERT_EQ(SPV_END_OF_STREAM, data.advance()); +} + +TEST(TextAdvance, NullTerminator) { + AutoText input(""); + AssemblyContext data(input, nullptr); + ASSERT_EQ(SPV_END_OF_STREAM, data.advance()); +} + +TEST(TextAdvance, NoNullTerminatorAfterCommentLine) { + std::string input = "; comment|padding beyond the end"; + spv_text_t text = {input.data(), 9}; + 
AssemblyContext data(&text, nullptr); + ASSERT_EQ(SPV_END_OF_STREAM, data.advance()); + EXPECT_EQ(9u, data.position().index); +} + +TEST(TextAdvance, NoNullTerminator) { + spv_text_t text = {"OpNop\nSomething else in memory", 6}; + AssemblyContext data(&text, nullptr); + const spv_position_t line_break = {1u, 5u, 5u}; + data.setPosition(line_break); + ASSERT_EQ(SPV_END_OF_STREAM, data.advance()); +} + +// Invokes AssemblyContext::advance() on text, asserts success, and returns +// AssemblyContext::position(). +spv_position_t PositionAfterAdvance(const char* text) { + AutoText input(text); + AssemblyContext data(input, nullptr); + EXPECT_EQ(SPV_SUCCESS, data.advance()); + return data.position(); +} + +TEST(TextAdvance, SkipOverCR) { + const auto pos = PositionAfterAdvance("\rWord"); + EXPECT_EQ(1u, pos.column); + EXPECT_EQ(0u, pos.line); + EXPECT_EQ(1u, pos.index); +} + +TEST(TextAdvance, SkipOverCRs) { + const auto pos = PositionAfterAdvance("\r\r\rWord"); + EXPECT_EQ(3u, pos.column); + EXPECT_EQ(0u, pos.line); + EXPECT_EQ(3u, pos.index); +} + +TEST(TextAdvance, SkipOverCRLF) { + const auto pos = PositionAfterAdvance("\r\nWord"); + EXPECT_EQ(0u, pos.column); + EXPECT_EQ(1u, pos.line); + EXPECT_EQ(2u, pos.index); +} + +TEST(TextAdvance, SkipOverCRLFs) { + const auto pos = PositionAfterAdvance("\r\n\r\nWord"); + EXPECT_EQ(0u, pos.column); + EXPECT_EQ(2u, pos.line); + EXPECT_EQ(4u, pos.index); +} +} // namespace +} // namespace spvtools diff --git a/test/text_destroy_test.cpp b/test/text_destroy_test.cpp new file mode 100644 index 000000000..4c2837ba6 --- /dev/null +++ b/test/text_destroy_test.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +TEST(TextDestroy, DestroyNull) { spvBinaryDestroy(nullptr); } + +TEST(TextDestroy, Default) { + spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_0); + char textStr[] = R"( + OpSource OpenCL_C 12 + OpMemoryModel Physical64 OpenCL + OpSourceExtension "PlaceholderExtensionName" + OpEntryPoint Kernel %0 "" + OpExecutionMode %0 LocalSizeHint 1 1 1 + %1 = OpTypeVoid + %2 = OpTypeBool + %3 = OpTypeInt 8 0 + %4 = OpTypeInt 8 1 + %5 = OpTypeInt 16 0 + %6 = OpTypeInt 16 1 + %7 = OpTypeInt 32 0 + %8 = OpTypeInt 32 1 + %9 = OpTypeInt 64 0 + %10 = OpTypeInt 64 1 + %11 = OpTypeFloat 16 + %12 = OpTypeFloat 32 + %13 = OpTypeFloat 64 + %14 = OpTypeVector %3 2 + )"; + + spv_binary binary = nullptr; + spv_diagnostic diagnostic = nullptr; + EXPECT_EQ(SPV_SUCCESS, spvTextToBinary(context, textStr, strlen(textStr), + &binary, &diagnostic)); + EXPECT_NE(nullptr, binary); + EXPECT_NE(nullptr, binary->code); + EXPECT_NE(0u, binary->wordCount); + if (diagnostic) { + spvDiagnosticPrint(diagnostic); + ASSERT_TRUE(false); + } + + spv_text resultText = nullptr; + EXPECT_EQ(SPV_SUCCESS, + spvBinaryToText(context, binary->code, binary->wordCount, 0, + &resultText, &diagnostic)); + spvBinaryDestroy(binary); + if (diagnostic) { + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + ASSERT_TRUE(false); + } + EXPECT_NE(nullptr, resultText->str); + EXPECT_NE(0u, resultText->length); + spvTextDestroy(resultText); + spvContextDestroy(context); +} + +} // namespace +} // namespace spvtools 
diff --git a/test/text_literal_test.cpp b/test/text_literal_test.cpp new file mode 100644 index 000000000..702808931 --- /dev/null +++ b/test/text_literal_test.cpp @@ -0,0 +1,412 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using ::testing::Eq; + +TEST(TextLiteral, GoodI32) { + spv_literal_t l; + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("-0", &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_INT_32, l.type); + EXPECT_EQ(0, l.value.i32); + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("-2147483648", &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_INT_32, l.type); + EXPECT_EQ((-2147483647L - 1), l.value.i32); +} + +TEST(TextLiteral, GoodU32) { + spv_literal_t l; + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("0", &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_UINT_32, l.type); + EXPECT_EQ(0, l.value.i32); + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("4294967295", &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_UINT_32, l.type); + EXPECT_EQ(4294967295, l.value.u32); +} + +TEST(TextLiteral, GoodI64) { + spv_literal_t l; + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("-2147483649", &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_INT_64, l.type); + EXPECT_EQ(-2147483649LL, l.value.i64); +} + +TEST(TextLiteral, GoodU64) { + spv_literal_t l; + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("4294967296", &l)); + 
EXPECT_EQ(SPV_LITERAL_TYPE_UINT_64, l.type); + EXPECT_EQ(4294967296u, l.value.u64); +} + +TEST(TextLiteral, GoodFloat) { + spv_literal_t l; + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("1.0", &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_FLOAT_32, l.type); + EXPECT_EQ(1.0, l.value.f); + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("1.5", &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_FLOAT_32, l.type); + EXPECT_EQ(1.5, l.value.f); + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral("-.25", &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_FLOAT_32, l.type); + EXPECT_EQ(-.25, l.value.f); +} + +TEST(TextLiteral, BadString) { + spv_literal_t l; + + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("", &l)); + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("-", &l)); + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("--", &l)); + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("1-2", &l)); + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("123a", &l)); + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("12.2.3", &l)); + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("\"", &l)); + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("\"z", &l)); + EXPECT_EQ(SPV_FAILED_MATCH, spvTextToLiteral("a\"", &l)); +} + +class GoodStringTest + : public ::testing::TestWithParam> {}; + +TEST_P(GoodStringTest, GoodStrings) { + spv_literal_t l; + + ASSERT_EQ(SPV_SUCCESS, spvTextToLiteral(std::get<0>(GetParam()), &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_STRING, l.type); + EXPECT_EQ(std::get<1>(GetParam()), l.str); +} + +INSTANTIATE_TEST_CASE_P( + TextLiteral, GoodStringTest, + ::testing::ValuesIn(std::vector>{ + {R"("-")", "-"}, + {R"("--")", "--"}, + {R"("1-2")", "1-2"}, + {R"("123a")", "123a"}, + {R"("12.2.3")", "12.2.3"}, + {R"("\"")", "\""}, + {R"("\\")", "\\"}, + {"\"\\foo\nbar\"", "foo\nbar"}, + {"\"\\foo\\\nbar\"", "foo\nbar"}, + {"\"\xE4\xBA\xB2\"", "\xE4\xBA\xB2"}, + {"\"\\\xE4\xBA\xB2\"", "\xE4\xBA\xB2"}, + {"\"this \\\" and this \\\\ and \\\xE4\xBA\xB2\"", + "this \" and this \\ and \xE4\xBA\xB2"}}), ); + +TEST(TextLiteral, StringTooLong) { + spv_literal_t l; 
+ std::string too_long = + std::string("\"") + + std::string(SPV_LIMIT_LITERAL_STRING_BYTES_MAX + 1, 'a') + "\""; + EXPECT_EQ(SPV_ERROR_OUT_OF_MEMORY, spvTextToLiteral(too_long.data(), &l)); +} + +TEST(TextLiteral, GoodLongString) { + spv_literal_t l; + // The universal limit of 65535 Unicode characters might make this + // fail validation, since SPV_LIMIT_LITERAL_STRING_BYTES_MAX is 4*65535. + // However, as an implementation detail, we'll allow the assembler + // to parse it. Otherwise we'd have to scan the string for valid UTF-8 + // characters. + std::string unquoted(SPV_LIMIT_LITERAL_STRING_BYTES_MAX, 'a'); + std::string good_long = std::string("\"") + unquoted + "\""; + EXPECT_EQ(SPV_SUCCESS, spvTextToLiteral(good_long.data(), &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_STRING, l.type); + EXPECT_EQ(unquoted.data(), l.str); +} + +TEST(TextLiteral, GoodUTF8String) { + const std::string unquoted = + spvtest::MakeLongUTF8String(SPV_LIMIT_LITERAL_STRING_UTF8_CHARS_MAX); + const std::string good_long = std::string("\"") + unquoted + "\""; + spv_literal_t l; + EXPECT_EQ(SPV_SUCCESS, spvTextToLiteral(good_long.data(), &l)); + EXPECT_EQ(SPV_LITERAL_TYPE_STRING, l.type); + EXPECT_EQ(unquoted.data(), l.str); +} + +// A test case for parsing literal numbers. 
+struct TextLiteralCase { + uint32_t bitwidth; + const char* text; + bool is_signed; + bool success; + std::vector expected_values; +}; + +using IntegerTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +std::vector successfulEncode(const TextLiteralCase& test, + IdTypeClass type) { + spv_instruction_t inst; + std::string message; + auto capture_message = [&message](spv_message_level_t, const char*, + const spv_position_t&, + const char* m) { message = m; }; + IdType expected_type{test.bitwidth, test.is_signed, type}; + EXPECT_EQ(SPV_SUCCESS, + AssemblyContext(nullptr, capture_message) + .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT, + expected_type, &inst)) + << message; + return inst.words; +} + +std::string failedEncode(const TextLiteralCase& test, IdTypeClass type) { + spv_instruction_t inst; + std::string message; + auto capture_message = [&message](spv_message_level_t, const char*, + const spv_position_t&, + const char* m) { message = m; }; + IdType expected_type{test.bitwidth, test.is_signed, type}; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + AssemblyContext(nullptr, capture_message) + .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT, + expected_type, &inst)); + return message; +} + +TEST_P(IntegerTest, IntegerBounds) { + if (GetParam().success) { + EXPECT_THAT(successfulEncode(GetParam(), IdTypeClass::kScalarIntegerType), + Eq(GetParam().expected_values)); + } else { + std::stringstream ss; + ss << "Integer " << GetParam().text << " does not fit in a " + << GetParam().bitwidth << "-bit " + << (GetParam().is_signed ? "signed" : "unsigned") << " integer"; + EXPECT_THAT(failedEncode(GetParam(), IdTypeClass::kScalarIntegerType), + Eq(ss.str())); + } +} + +// Four nicely named methods for making TextLiteralCase values. +// Their names have underscores in some places to make it easier +// to read the table that follows. 
+TextLiteralCase Make_Ok__Signed(uint32_t bitwidth, const char* text, + std::vector encoding) { + return TextLiteralCase{bitwidth, text, true, true, encoding}; +} +TextLiteralCase Make_Ok__Unsigned(uint32_t bitwidth, const char* text, + std::vector encoding) { + return TextLiteralCase{bitwidth, text, false, true, encoding}; +} +TextLiteralCase Make_Bad_Signed(uint32_t bitwidth, const char* text) { + return TextLiteralCase{bitwidth, text, true, false, {}}; +} +TextLiteralCase Make_Bad_Unsigned(uint32_t bitwidth, const char* text) { + return TextLiteralCase{bitwidth, text, false, false, {}}; +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + DecimalIntegers, IntegerTest, + ::testing::ValuesIn(std::vector{ + // Check max value and overflow value for 1-bit numbers. + Make_Ok__Signed(1, "0", {0}), + Make_Ok__Unsigned(1, "1", {1}), + Make_Bad_Signed(1, "1"), + Make_Bad_Unsigned(1, "2"), + + // Check max value and overflow value for 2-bit numbers. + Make_Ok__Signed(2, "1", {1}), + Make_Ok__Unsigned(2, "3", {3}), + Make_Bad_Signed(2, "2"), + Make_Bad_Unsigned(2, "4"), + + // Check max negative value and overflow value for signed + // 1- and 2-bit numbers. Signed negative numbers are sign-extended. + Make_Ok__Signed(1, "-0", {uint32_t(0)}), + Make_Ok__Signed(1, "-1", {uint32_t(-1)}), + Make_Ok__Signed(2, "-0", {0}), + Make_Ok__Signed(2, "-1", {uint32_t(-1)}), + Make_Ok__Signed(2, "-2", {uint32_t(-2)}), + Make_Bad_Signed(2, "-3"), + + Make_Bad_Unsigned(2, "2224323424242424"), + Make_Ok__Unsigned(16, "65535", {0xFFFF}), + Make_Bad_Unsigned(16, "65536"), + Make_Bad_Signed(16, "65535"), + Make_Ok__Signed(16, "32767", {0x7FFF}), + Make_Ok__Signed(16, "-32768", {0xFFFF8000}), + + // Check values around 32-bits in magnitude. 
+ Make_Ok__Unsigned(33, "4294967296", {0, 1}), + Make_Ok__Unsigned(33, "4294967297", {1, 1}), + Make_Bad_Unsigned(33, "8589934592"), + Make_Bad_Signed(33, "4294967296"), + Make_Ok__Signed(33, "-4294967296", {0x0, 0xFFFFFFFF}), + Make_Ok__Unsigned(64, "4294967296", {0, 1}), + Make_Ok__Unsigned(64, "4294967297", {1, 1}), + + // Check max value and overflow value for 64-bit numbers. + Make_Ok__Signed(64, "9223372036854775807", {0xffffffff, 0x7fffffff}), + Make_Bad_Signed(64, "9223372036854775808"), + Make_Ok__Unsigned(64, "9223372036854775808", {0x00000000, 0x80000000}), + Make_Ok__Unsigned(64, "18446744073709551615", {0xffffffff, 0xffffffff}), + Make_Ok__Signed(64, "-9223372036854775808", {0x00000000, 0x80000000}), + + }),); +// clang-format on + +using IntegerLeadingMinusTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(IntegerLeadingMinusTest, CantHaveLeadingMinusOnUnsigned) { + EXPECT_FALSE(GetParam().success); + EXPECT_THAT(failedEncode(GetParam(), IdTypeClass::kScalarIntegerType), + Eq("Cannot put a negative number in an unsigned literal")); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + DecimalAndHexIntegers, IntegerLeadingMinusTest, + ::testing::ValuesIn(std::vector{ + // Unsigned numbers never allow a leading minus sign. + Make_Bad_Unsigned(16, "-0"), + Make_Bad_Unsigned(16, "-0x0"), + Make_Bad_Unsigned(16, "-0x1"), + Make_Bad_Unsigned(32, "-0"), + Make_Bad_Unsigned(32, "-0x0"), + Make_Bad_Unsigned(32, "-0x1"), + Make_Bad_Unsigned(64, "-0"), + Make_Bad_Unsigned(64, "-0x0"), + Make_Bad_Unsigned(64, "-0x1"), + }),); + +// clang-format off +INSTANTIATE_TEST_CASE_P( + HexIntegers, IntegerTest, + ::testing::ValuesIn(std::vector{ + // Check 0x and 0X prefices. 
+ Make_Ok__Signed(16, "0x1234", {0x1234}), + Make_Ok__Signed(16, "0X1234", {0x1234}), + + // Check 1-bit numbers + Make_Ok__Signed(1, "0x0", {0}), + Make_Ok__Signed(1, "0x1", {uint32_t(-1)}), + Make_Ok__Unsigned(1, "0x0", {0}), + Make_Ok__Unsigned(1, "0x1", {1}), + Make_Bad_Signed(1, "0x2"), + Make_Bad_Unsigned(1, "0x2"), + + // Check 2-bit numbers + Make_Ok__Signed(2, "0x0", {0}), + Make_Ok__Signed(2, "0x1", {1}), + Make_Ok__Signed(2, "0x2", {uint32_t(-2)}), + Make_Ok__Signed(2, "0x3", {uint32_t(-1)}), + Make_Ok__Unsigned(2, "0x0", {0}), + Make_Ok__Unsigned(2, "0x1", {1}), + Make_Ok__Unsigned(2, "0x2", {2}), + Make_Ok__Unsigned(2, "0x3", {3}), + Make_Bad_Signed(2, "0x4"), + Make_Bad_Unsigned(2, "0x4"), + + // Check 8-bit numbers + Make_Ok__Signed(8, "0x7f", {0x7f}), + Make_Ok__Signed(8, "0x80", {0xffffff80}), + Make_Ok__Unsigned(8, "0x80", {0x80}), + Make_Ok__Unsigned(8, "0xff", {0xff}), + Make_Bad_Signed(8, "0x100"), + Make_Bad_Unsigned(8, "0x100"), + + // Check 16-bit numbers + Make_Ok__Signed(16, "0x7fff", {0x7fff}), + Make_Ok__Signed(16, "0x8000", {0xffff8000}), + Make_Ok__Unsigned(16, "0x8000", {0x8000}), + Make_Ok__Unsigned(16, "0xffff", {0xffff}), + Make_Bad_Signed(16, "0x10000"), + Make_Bad_Unsigned(16, "0x10000"), + + // Check 32-bit numbers + Make_Ok__Signed(32, "0x7fffffff", {0x7fffffff}), + Make_Ok__Signed(32, "0x80000000", {0x80000000}), + Make_Ok__Unsigned(32, "0x80000000", {0x80000000}), + Make_Ok__Unsigned(32, "0xffffffff", {0xffffffff}), + Make_Bad_Signed(32, "0x100000000"), + Make_Bad_Unsigned(32, "0x100000000"), + + // Check 48-bit numbers + Make_Ok__Unsigned(48, "0x7ffffffff", {0xffffffff, 7}), + Make_Ok__Unsigned(48, "0x800000000", {0, 8}), + Make_Ok__Signed(48, "0x7fffffffffff", {0xffffffff, 0x7fff}), + Make_Ok__Signed(48, "0x800000000000", {0, 0xffff8000}), + Make_Bad_Signed(48, "0x1000000000000"), + Make_Bad_Unsigned(48, "0x1000000000000"), + + // Check 64-bit numbers + Make_Ok__Signed(64, "0x7fffffffffffffff", {0xffffffff, 0x7fffffff}), + 
Make_Ok__Signed(64, "0x8000000000000000", {0x00000000, 0x80000000}), + Make_Ok__Unsigned(64, "0x7fffffffffffffff", {0xffffffff, 0x7fffffff}), + Make_Ok__Unsigned(64, "0x8000000000000000", {0x00000000, 0x80000000}), + }),); +// clang-format on + +TEST(OverflowIntegerParse, Decimal) { + std::string signed_input = "-18446744073709551616"; + std::string expected_message0 = + "Invalid signed integer literal: " + signed_input; + EXPECT_THAT(failedEncode(Make_Bad_Signed(64, signed_input.c_str()), + IdTypeClass::kScalarIntegerType), + Eq(expected_message0)); + + std::string unsigned_input = "18446744073709551616"; + std::string expected_message1 = + "Invalid unsigned integer literal: " + unsigned_input; + EXPECT_THAT(failedEncode(Make_Bad_Unsigned(64, unsigned_input.c_str()), + IdTypeClass::kScalarIntegerType), + Eq(expected_message1)); + + // TODO(dneto): When the given number doesn't have a leading sign, + // we say we're trying to parse an unsigned number, even when the caller + // asked for a signed number. This is kind of weird, but it's an + // artefact of how we do the parsing. + EXPECT_THAT(failedEncode(Make_Bad_Signed(64, unsigned_input.c_str()), + IdTypeClass::kScalarIntegerType), + Eq(expected_message1)); +} + +TEST(OverflowIntegerParse, Hex) { + std::string input = "0x10000000000000000"; + std::string expected_message = "Invalid unsigned integer literal: " + input; + EXPECT_THAT(failedEncode(Make_Bad_Signed(64, input.c_str()), + IdTypeClass::kScalarIntegerType), + Eq(expected_message)); + EXPECT_THAT(failedEncode(Make_Bad_Unsigned(64, input.c_str()), + IdTypeClass::kScalarIntegerType), + Eq(expected_message)); +} + +} // namespace +} // namespace spvtools diff --git a/test/text_start_new_inst_test.cpp b/test/text_start_new_inst_test.cpp new file mode 100644 index 000000000..ff35ac84c --- /dev/null +++ b/test/text_start_new_inst_test.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::AutoText; + +TEST(TextStartsWithOp, YesAtStart) { + EXPECT_TRUE(AssemblyContext(AutoText("OpFoo"), nullptr).isStartOfNewInst()); + EXPECT_TRUE(AssemblyContext(AutoText("OpFoo"), nullptr).isStartOfNewInst()); + EXPECT_TRUE(AssemblyContext(AutoText("OpEnCL"), nullptr).isStartOfNewInst()); +} + +TEST(TextStartsWithOp, YesAtMiddle) { + { + AutoText text(" OpFoo"); + AssemblyContext dat(text, nullptr); + dat.seekForward(2); + EXPECT_TRUE(dat.isStartOfNewInst()); + } + { + AutoText text("xx OpFoo"); + AssemblyContext dat(text, nullptr); + dat.seekForward(2); + EXPECT_TRUE(dat.isStartOfNewInst()); + } +} + +TEST(TextStartsWithOp, NoIfTooFar) { + AutoText text(" OpFoo"); + AssemblyContext dat(text, nullptr); + dat.seekForward(3); + EXPECT_FALSE(dat.isStartOfNewInst()); +} + +TEST(TextStartsWithOp, NoRegular) { + EXPECT_FALSE( + AssemblyContext(AutoText("Fee Fi Fo Fum"), nullptr).isStartOfNewInst()); + EXPECT_FALSE(AssemblyContext(AutoText("123456"), nullptr).isStartOfNewInst()); + EXPECT_FALSE(AssemblyContext(AutoText("123456"), nullptr).isStartOfNewInst()); + EXPECT_FALSE(AssemblyContext(AutoText("OpenCL"), nullptr).isStartOfNewInst()); +} + +TEST(TextStartsWithOp, YesForValueGenerationForm) { + EXPECT_TRUE( + AssemblyContext(AutoText("%foo = OpAdd"), nullptr).isStartOfNewInst()); + EXPECT_TRUE( + 
AssemblyContext(AutoText("%foo = OpAdd"), nullptr).isStartOfNewInst()); +} + +TEST(TextStartsWithOp, NoForNearlyValueGeneration) { + EXPECT_FALSE( + AssemblyContext(AutoText("%foo = "), nullptr).isStartOfNewInst()); + EXPECT_FALSE(AssemblyContext(AutoText("%foo "), nullptr).isStartOfNewInst()); + EXPECT_FALSE(AssemblyContext(AutoText("%foo"), nullptr).isStartOfNewInst()); +} + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.annotation_test.cpp b/test/text_to_binary.annotation_test.cpp new file mode 100644 index 000000000..7aec90555 --- /dev/null +++ b/test/text_to_binary.annotation_test.cpp @@ -0,0 +1,510 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Annotation" section of the +// SPIR-V spec. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::EnumCase; +using spvtest::MakeInstruction; +using spvtest::MakeVector; +using spvtest::TextToBinaryTest; +using ::testing::Combine; +using ::testing::Eq; +using ::testing::Values; +using ::testing::ValuesIn; + +// Test OpDecorate + +using OpDecorateSimpleTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam< + std::tuple>>>; + +TEST_P(OpDecorateSimpleTest, AnySimpleDecoration) { + // This string should assemble, but should not validate. 
+ std::stringstream input; + input << "OpDecorate %1 " << std::get<1>(GetParam()).name(); + for (auto operand : std::get<1>(GetParam()).operands()) + input << " " << operand; + input << std::endl; + EXPECT_THAT(CompiledInstructions(input.str(), std::get<0>(GetParam())), + Eq(MakeInstruction(SpvOpDecorate, + {1, uint32_t(std::get<1>(GetParam()).value())}, + std::get<1>(GetParam()).operands()))); + // Also check disassembly. + EXPECT_THAT( + EncodeAndDecodeSuccessfully(input.str(), SPV_BINARY_TO_TEXT_OPTION_NONE, + std::get<0>(GetParam())), + Eq(input.str())); +} + +#define CASE(NAME) SpvDecoration##NAME, #NAME +INSTANTIATE_TEST_CASE_P( + TextToBinaryDecorateSimple, OpDecorateSimpleTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector>{ + // The operand literal values are arbitrarily chosen, + // but there are the right number of them. + {CASE(RelaxedPrecision), {}}, + {CASE(SpecId), {100}}, + {CASE(Block), {}}, + {CASE(BufferBlock), {}}, + {CASE(RowMajor), {}}, + {CASE(ColMajor), {}}, + {CASE(ArrayStride), {4}}, + {CASE(MatrixStride), {16}}, + {CASE(GLSLShared), {}}, + {CASE(GLSLPacked), {}}, + {CASE(CPacked), {}}, + // Placeholder line for enum value 12 + {CASE(NoPerspective), {}}, + {CASE(Flat), {}}, + {CASE(Patch), {}}, + {CASE(Centroid), {}}, + {CASE(Sample), {}}, + {CASE(Invariant), {}}, + {CASE(Restrict), {}}, + {CASE(Aliased), {}}, + {CASE(Volatile), {}}, + {CASE(Constant), {}}, + {CASE(Coherent), {}}, + {CASE(NonWritable), {}}, + {CASE(NonReadable), {}}, + {CASE(Uniform), {}}, + {CASE(SaturatedConversion), {}}, + {CASE(Stream), {2}}, + {CASE(Location), {6}}, + {CASE(Component), {3}}, + {CASE(Index), {14}}, + {CASE(Binding), {19}}, + {CASE(DescriptorSet), {7}}, + {CASE(Offset), {12}}, + {CASE(XfbBuffer), {1}}, + {CASE(XfbStride), {8}}, + {CASE(NoContraction), {}}, + {CASE(InputAttachmentIndex), {102}}, + {CASE(Alignment), {16}}, + })), ); + +INSTANTIATE_TEST_CASE_P(TextToBinaryDecorateSimpleV11, OpDecorateSimpleTest, + 
Combine(Values(SPV_ENV_UNIVERSAL_1_1), + Values(EnumCase{ + CASE(MaxByteOffset), {128}})), ); +#undef CASE + +TEST_F(OpDecorateSimpleTest, WrongDecoration) { + EXPECT_THAT(CompileFailure("OpDecorate %1 xxyyzz"), + Eq("Invalid decoration 'xxyyzz'.")); +} + +TEST_F(OpDecorateSimpleTest, ExtraOperandsOnDecorationExpectingNone) { + EXPECT_THAT(CompileFailure("OpDecorate %1 RelaxedPrecision 99"), + Eq("Expected or at the beginning of an " + "instruction, found '99'.")); +} + +TEST_F(OpDecorateSimpleTest, ExtraOperandsOnDecorationExpectingOne) { + EXPECT_THAT(CompileFailure("OpDecorate %1 SpecId 99 100"), + Eq("Expected or at the beginning of an " + "instruction, found '100'.")); +} + +TEST_F(OpDecorateSimpleTest, ExtraOperandsOnDecorationExpectingTwo) { + EXPECT_THAT( + CompileFailure("OpDecorate %1 LinkageAttributes \"abc\" Import 42"), + Eq("Expected or at the beginning of an " + "instruction, found '42'.")); +} + +// A single test case for an enum decoration. +struct DecorateEnumCase { + // Place the enum value first, so it's easier to read the binary dumps when + // the test fails. + uint32_t value; // The value within the enum, e.g. Position + std::string name; + uint32_t enum_value; // Which enum, e.g. BuiltIn + std::string enum_name; +}; + +using OpDecorateEnumTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpDecorateEnumTest, AnyEnumDecoration) { + // This string should assemble, but should not validate. + const std::string input = + "OpDecorate %1 " + GetParam().enum_name + " " + GetParam().name; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpDecorate, {1, GetParam().enum_value, + GetParam().value}))); +} + +// Test OpDecorate BuiltIn. 
+// clang-format off +#define CASE(NAME) \ + { SpvBuiltIn##NAME, #NAME, SpvDecorationBuiltIn, "BuiltIn" } +INSTANTIATE_TEST_CASE_P(TextToBinaryDecorateBuiltIn, OpDecorateEnumTest, + ::testing::ValuesIn(std::vector{ + CASE(Position), + CASE(PointSize), + CASE(ClipDistance), + CASE(CullDistance), + CASE(VertexId), + CASE(InstanceId), + CASE(PrimitiveId), + CASE(InvocationId), + CASE(Layer), + CASE(ViewportIndex), + CASE(TessLevelOuter), + CASE(TessLevelInner), + CASE(TessCoord), + CASE(PatchVertices), + CASE(FragCoord), + CASE(PointCoord), + CASE(FrontFacing), + CASE(SampleId), + CASE(SamplePosition), + CASE(SampleMask), + // Value 21 intentionally missing. + CASE(FragDepth), + CASE(HelperInvocation), + CASE(NumWorkgroups), + CASE(WorkgroupSize), + CASE(WorkgroupId), + CASE(LocalInvocationId), + CASE(GlobalInvocationId), + CASE(LocalInvocationIndex), + CASE(WorkDim), + CASE(GlobalSize), + CASE(EnqueuedWorkgroupSize), + CASE(GlobalOffset), + CASE(GlobalLinearId), + // Value 35 intentionally missing. 
+ CASE(SubgroupSize), + CASE(SubgroupMaxSize), + CASE(NumSubgroups), + CASE(NumEnqueuedSubgroups), + CASE(SubgroupId), + CASE(SubgroupLocalInvocationId), + CASE(VertexIndex), + CASE(InstanceIndex), + }),); +#undef CASE +// clang-format on + +TEST_F(OpDecorateEnumTest, WrongBuiltIn) { + EXPECT_THAT(CompileFailure("OpDecorate %1 BuiltIn xxyyzz"), + Eq("Invalid built-in 'xxyyzz'.")); +} + +// Test OpDecorate FuncParamAttr +// clang-format off +#define CASE(NAME) \ + { SpvFunctionParameterAttribute##NAME, #NAME, SpvDecorationFuncParamAttr, "FuncParamAttr" } +INSTANTIATE_TEST_CASE_P(TextToBinaryDecorateFuncParamAttr, OpDecorateEnumTest, + ::testing::ValuesIn(std::vector{ + CASE(Zext), + CASE(Sext), + CASE(ByVal), + CASE(Sret), + CASE(NoAlias), + CASE(NoCapture), + CASE(NoWrite), + CASE(NoReadWrite), + }),); +#undef CASE +// clang-format on + +TEST_F(OpDecorateEnumTest, WrongFuncParamAttr) { + EXPECT_THAT(CompileFailure("OpDecorate %1 FuncParamAttr xxyyzz"), + Eq("Invalid function parameter attribute 'xxyyzz'.")); +} + +// Test OpDecorate FPRoundingMode +// clang-format off +#define CASE(NAME) \ + { SpvFPRoundingMode##NAME, #NAME, SpvDecorationFPRoundingMode, "FPRoundingMode" } +INSTANTIATE_TEST_CASE_P(TextToBinaryDecorateFPRoundingMode, OpDecorateEnumTest, + ::testing::ValuesIn(std::vector{ + CASE(RTE), + CASE(RTZ), + CASE(RTP), + CASE(RTN), + }),); +#undef CASE +// clang-format on + +TEST_F(OpDecorateEnumTest, WrongFPRoundingMode) { + EXPECT_THAT(CompileFailure("OpDecorate %1 FPRoundingMode xxyyzz"), + Eq("Invalid floating-point rounding mode 'xxyyzz'.")); +} + +// Test OpDecorate FPFastMathMode. +// These can by named enums for the single-bit masks. However, we don't support +// symbolic combinations of the masks. Rather, they can use ! +// syntax, e.g. 
!0x3 + +// clang-format off +#define CASE(ENUM,NAME) \ + { SpvFPFastMathMode##ENUM, #NAME, SpvDecorationFPFastMathMode, "FPFastMathMode" } +INSTANTIATE_TEST_CASE_P(TextToBinaryDecorateFPFastMathMode, OpDecorateEnumTest, + ::testing::ValuesIn(std::vector{ + CASE(MaskNone, None), + CASE(NotNaNMask, NotNaN), + CASE(NotInfMask, NotInf), + CASE(NSZMask, NSZ), + CASE(AllowRecipMask, AllowRecip), + CASE(FastMask, Fast), + }),); +#undef CASE +// clang-format on + +TEST_F(OpDecorateEnumTest, CombinedFPFastMathMask) { + // Sample a single combination. This ensures we've integrated + // the instruction parsing logic with spvTextParseMask. + const std::string input = "OpDecorate %1 FPFastMathMode NotNaN|NotInf|NSZ"; + const uint32_t expected_enum = SpvDecorationFPFastMathMode; + const uint32_t expected_mask = SpvFPFastMathModeNotNaNMask | + SpvFPFastMathModeNotInfMask | + SpvFPFastMathModeNSZMask; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpDecorate, {1, expected_enum, expected_mask}))); +} + +TEST_F(OpDecorateEnumTest, WrongFPFastMathMode) { + EXPECT_THAT( + CompileFailure("OpDecorate %1 FPFastMathMode NotNaN|xxyyzz"), + Eq("Invalid floating-point fast math mode operand 'NotNaN|xxyyzz'.")); +} + +// Test OpDecorate Linkage + +// A single test case for a linkage +struct DecorateLinkageCase { + uint32_t linkage_type_value; + std::string linkage_type_name; + std::string external_name; +}; + +using OpDecorateLinkageTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>; + +TEST_P(OpDecorateLinkageTest, AnyLinkageDecoration) { + // This string should assemble, but should not validate. 
+ const std::string input = "OpDecorate %1 LinkageAttributes \"" + + GetParam().external_name + "\" " + + GetParam().linkage_type_name; + std::vector expected_operands{1, SpvDecorationLinkageAttributes}; + std::vector encoded_external_name = + MakeVector(GetParam().external_name); + expected_operands.insert(expected_operands.end(), + encoded_external_name.begin(), + encoded_external_name.end()); + expected_operands.push_back(GetParam().linkage_type_value); + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpDecorate, expected_operands))); +} + +// clang-format off +#define CASE(ENUM) SpvLinkageType##ENUM, #ENUM +INSTANTIATE_TEST_CASE_P(TextToBinaryDecorateLinkage, OpDecorateLinkageTest, + ::testing::ValuesIn(std::vector{ + { CASE(Import), "a" }, + { CASE(Export), "foo" }, + { CASE(Import), "some kind of long name with spaces etc." }, + // TODO(dneto): utf-8, escaping, quoting cases. + }),); +#undef CASE +// clang-format on + +TEST_F(OpDecorateLinkageTest, WrongType) { + EXPECT_THAT(CompileFailure("OpDecorate %1 LinkageAttributes \"foo\" xxyyzz"), + Eq("Invalid linkage type 'xxyyzz'.")); +} + +// Test OpGroupMemberDecorate + +TEST_F(TextToBinaryTest, GroupMemberDecorateGoodOneTarget) { + EXPECT_THAT(CompiledInstructions("OpGroupMemberDecorate %group %id0 42"), + Eq(MakeInstruction(SpvOpGroupMemberDecorate, {1, 2, 42}))); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateGoodTwoTargets) { + EXPECT_THAT( + CompiledInstructions("OpGroupMemberDecorate %group %id0 96 %id1 42"), + Eq(MakeInstruction(SpvOpGroupMemberDecorate, {1, 2, 96, 3, 42}))); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateMissingGroupId) { + EXPECT_THAT(CompileFailure("OpGroupMemberDecorate"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateInvalidGroupId) { + EXPECT_THAT(CompileFailure("OpGroupMemberDecorate 16"), + Eq("Expected id to start with %.")); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateInvalidTargetId) { + 
EXPECT_THAT(CompileFailure("OpGroupMemberDecorate %group 12"), + Eq("Expected id to start with %.")); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateMissingTargetMemberNumber) { + EXPECT_THAT(CompileFailure("OpGroupMemberDecorate %group %id0"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateInvalidTargetMemberNumber) { + EXPECT_THAT(CompileFailure("OpGroupMemberDecorate %group %id0 %id1"), + Eq("Invalid unsigned integer literal: %id1")); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateInvalidSecondTargetId) { + EXPECT_THAT(CompileFailure("OpGroupMemberDecorate %group %id1 42 12"), + Eq("Expected id to start with %.")); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateMissingSecondTargetMemberNumber) { + EXPECT_THAT(CompileFailure("OpGroupMemberDecorate %group %id0 42 %id1"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(TextToBinaryTest, GroupMemberDecorateInvalidSecondTargetMemberNumber) { + EXPECT_THAT(CompileFailure("OpGroupMemberDecorate %group %id0 42 %id1 %id2"), + Eq("Invalid unsigned integer literal: %id2")); +} + +// Test OpMemberDecorate + +using OpMemberDecorateSimpleTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam< + std::tuple>>>; + +TEST_P(OpMemberDecorateSimpleTest, AnySimpleDecoration) { + // This string should assemble, but should not validate. + std::stringstream input; + input << "OpMemberDecorate %1 42 " << std::get<1>(GetParam()).name(); + for (auto operand : std::get<1>(GetParam()).operands()) + input << " " << operand; + input << std::endl; + EXPECT_THAT( + CompiledInstructions(input.str(), std::get<0>(GetParam())), + Eq(MakeInstruction(SpvOpMemberDecorate, + {1, 42, uint32_t(std::get<1>(GetParam()).value())}, + std::get<1>(GetParam()).operands()))); + // Also check disassembly. 
+ EXPECT_THAT( + EncodeAndDecodeSuccessfully(input.str(), SPV_BINARY_TO_TEXT_OPTION_NONE, + std::get<0>(GetParam())), + Eq(input.str())); +} + +#define CASE(NAME) SpvDecoration##NAME, #NAME +INSTANTIATE_TEST_CASE_P( + TextToBinaryDecorateSimple, OpMemberDecorateSimpleTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector>{ + // The operand literal values are arbitrarily chosen, + // but there are the right number of them. + {CASE(RelaxedPrecision), {}}, + {CASE(SpecId), {100}}, + {CASE(Block), {}}, + {CASE(BufferBlock), {}}, + {CASE(RowMajor), {}}, + {CASE(ColMajor), {}}, + {CASE(ArrayStride), {4}}, + {CASE(MatrixStride), {16}}, + {CASE(GLSLShared), {}}, + {CASE(GLSLPacked), {}}, + {CASE(CPacked), {}}, + // Placeholder line for enum value 12 + {CASE(NoPerspective), {}}, + {CASE(Flat), {}}, + {CASE(Patch), {}}, + {CASE(Centroid), {}}, + {CASE(Sample), {}}, + {CASE(Invariant), {}}, + {CASE(Restrict), {}}, + {CASE(Aliased), {}}, + {CASE(Volatile), {}}, + {CASE(Constant), {}}, + {CASE(Coherent), {}}, + {CASE(NonWritable), {}}, + {CASE(NonReadable), {}}, + {CASE(Uniform), {}}, + {CASE(SaturatedConversion), {}}, + {CASE(Stream), {2}}, + {CASE(Location), {6}}, + {CASE(Component), {3}}, + {CASE(Index), {14}}, + {CASE(Binding), {19}}, + {CASE(DescriptorSet), {7}}, + {CASE(Offset), {12}}, + {CASE(XfbBuffer), {1}}, + {CASE(XfbStride), {8}}, + {CASE(NoContraction), {}}, + {CASE(InputAttachmentIndex), {102}}, + {CASE(Alignment), {16}}, + })), ); + +INSTANTIATE_TEST_CASE_P( + TextToBinaryDecorateSimpleV11, OpMemberDecorateSimpleTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_1), + Values(EnumCase{CASE(MaxByteOffset), {128}})), ); +#undef CASE + +TEST_F(OpMemberDecorateSimpleTest, WrongDecoration) { + EXPECT_THAT(CompileFailure("OpMemberDecorate %1 9 xxyyzz"), + Eq("Invalid decoration 'xxyyzz'.")); +} + +TEST_F(OpMemberDecorateSimpleTest, ExtraOperandsOnDecorationExpectingNone) { + EXPECT_THAT(CompileFailure("OpMemberDecorate %1 12 
RelaxedPrecision 99"), + Eq("Expected or at the beginning of an " + "instruction, found '99'.")); +} + +TEST_F(OpMemberDecorateSimpleTest, ExtraOperandsOnDecorationExpectingOne) { + EXPECT_THAT(CompileFailure("OpMemberDecorate %1 0 SpecId 99 100"), + Eq("Expected or at the beginning of an " + "instruction, found '100'.")); +} + +TEST_F(OpMemberDecorateSimpleTest, ExtraOperandsOnDecorationExpectingTwo) { + EXPECT_THAT(CompileFailure( + "OpMemberDecorate %1 1 LinkageAttributes \"abc\" Import 42"), + Eq("Expected or at the beginning of an " + "instruction, found '42'.")); +} + +// TODO(dneto): OpMemberDecorate cases for decorations with parameters which +// are: not just lists of literal numbers. + +// TODO(dneto): OpDecorationGroup +// TODO(dneto): OpGroupDecorate + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.barrier_test.cpp b/test/text_to_binary.barrier_test.cpp new file mode 100644 index 000000000..545d26ff2 --- /dev/null +++ b/test/text_to_binary.barrier_test.cpp @@ -0,0 +1,170 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Barrier Instructions" section +// of the SPIR-V spec. 
+ +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::MakeInstruction; +using spvtest::TextToBinaryTest; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Eq; + +// Test OpMemoryBarrier + +using OpMemoryBarrier = spvtest::TextToBinaryTest; + +TEST_F(OpMemoryBarrier, Good) { + const std::string input = "OpMemoryBarrier %1 %2\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpMemoryBarrier, {1, 2}))); + EXPECT_THAT(EncodeAndDecodeSuccessfully(input), Eq(input)); +} + +TEST_F(OpMemoryBarrier, BadMissingScopeId) { + const std::string input = "OpMemoryBarrier\n"; + EXPECT_THAT(CompileFailure(input), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(OpMemoryBarrier, BadInvalidScopeId) { + const std::string input = "OpMemoryBarrier 99\n"; + EXPECT_THAT(CompileFailure(input), Eq("Expected id to start with %.")); +} + +TEST_F(OpMemoryBarrier, BadMissingMemorySemanticsId) { + const std::string input = "OpMemoryBarrier %scope\n"; + EXPECT_THAT(CompileFailure(input), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(OpMemoryBarrier, BadInvalidMemorySemanticsId) { + const std::string input = "OpMemoryBarrier %scope 14\n"; + EXPECT_THAT(CompileFailure(input), Eq("Expected id to start with %.")); +} + +// TODO(dneto): OpControlBarrier +// TODO(dneto): OpGroupAsyncCopy +// TODO(dneto): OpGroupWaitEvents +// TODO(dneto): OpGroupAll +// TODO(dneto): OpGroupAny +// TODO(dneto): OpGroupBroadcast +// TODO(dneto): OpGroupIAdd +// TODO(dneto): OpGroupFAdd +// TODO(dneto): OpGroupFMin +// TODO(dneto): OpGroupUMin +// TODO(dneto): OpGroupSMin +// TODO(dneto): OpGroupFMax +// TODO(dneto): OpGroupUMax +// TODO(dneto): OpGroupSMax + +using NamedMemoryBarrierTest = spvtest::TextToBinaryTest; + +// OpMemoryNamedBarrier is not in 1.0, but it is enabled by a capability. +// We should be able to assemble it. 
Validation checks are in another test +// file. +TEST_F(NamedMemoryBarrierTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("OpMemoryNamedBarrier %bar %scope %semantics", + SPV_ENV_UNIVERSAL_1_0), + ElementsAre(spvOpcodeMake(4, SpvOpMemoryNamedBarrier), _, _, _)); +} + +TEST_F(NamedMemoryBarrierTest, ArgumentCount) { + EXPECT_THAT(CompileFailure("OpMemoryNamedBarrier", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT( + CompileFailure("OpMemoryNamedBarrier %bar", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT( + CompileFailure("OpMemoryNamedBarrier %bar %scope", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT( + CompiledInstructions("OpMemoryNamedBarrier %bar %scope %semantics", + SPV_ENV_UNIVERSAL_1_1), + ElementsAre(spvOpcodeMake(4, SpvOpMemoryNamedBarrier), _, _, _)); + EXPECT_THAT( + CompileFailure("OpMemoryNamedBarrier %bar %scope %semantics %extra", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected '=', found end of stream.")); +} + +TEST_F(NamedMemoryBarrierTest, ArgumentTypes) { + EXPECT_THAT(CompileFailure("OpMemoryNamedBarrier 123 %scope %semantics", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); + EXPECT_THAT(CompileFailure("OpMemoryNamedBarrier %bar %scope \"semantics\"", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); +} + +using TypeNamedBarrierTest = spvtest::TextToBinaryTest; + +TEST_F(TypeNamedBarrierTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%t = OpTypeNamedBarrier", SPV_ENV_UNIVERSAL_1_0), + ElementsAre(spvOpcodeMake(2, SpvOpTypeNamedBarrier), _)); +} + +TEST_F(TypeNamedBarrierTest, ArgumentCount) { + EXPECT_THAT(CompileFailure("OpTypeNamedBarrier", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected at the beginning of an instruction, " + "found 'OpTypeNamedBarrier'.")); + EXPECT_THAT( + CompiledInstructions("%t = OpTypeNamedBarrier", SPV_ENV_UNIVERSAL_1_1), + 
ElementsAre(spvOpcodeMake(2, SpvOpTypeNamedBarrier), _)); + EXPECT_THAT( + CompileFailure("%t = OpTypeNamedBarrier 1 2 3", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected or at the beginning of an instruction, " + "found '1'.")); +} + +using NamedBarrierInitializeTest = spvtest::TextToBinaryTest; + +TEST_F(NamedBarrierInitializeTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%bar = OpNamedBarrierInitialize %type %count", + SPV_ENV_UNIVERSAL_1_0), + ElementsAre(spvOpcodeMake(4, SpvOpNamedBarrierInitialize), _, _, _)); +} + +TEST_F(NamedBarrierInitializeTest, ArgumentCount) { + EXPECT_THAT( + CompileFailure("%bar = OpNamedBarrierInitialize", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT(CompileFailure("%bar = OpNamedBarrierInitialize %ype", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT( + CompiledInstructions("%bar = OpNamedBarrierInitialize %type %count", + SPV_ENV_UNIVERSAL_1_1), + ElementsAre(spvOpcodeMake(4, SpvOpNamedBarrierInitialize), _, _, _)); + EXPECT_THAT( + CompileFailure("%bar = OpNamedBarrierInitialize %type %count \"extra\"", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected or at the beginning of an instruction, " + "found '\"extra\"'.")); +} + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.constant_test.cpp b/test/text_to_binary.constant_test.cpp new file mode 100644 index 000000000..1a24b528f --- /dev/null +++ b/test/text_to_binary.constant_test.cpp @@ -0,0 +1,830 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Constant-Creation Instructions" +// section of the SPIR-V spec. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::Concatenate; +using spvtest::EnumCase; +using spvtest::MakeInstruction; +using ::testing::Eq; + +// Test Sampler Addressing Mode enum values + +using SamplerAddressingModeTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(SamplerAddressingModeTest, AnySamplerAddressingMode) { + const std::string input = + "%result = OpConstantSampler %type " + GetParam().name() + " 0 Nearest"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpConstantSampler, + {1, 2, GetParam().value(), 0, 0}))); +} + +// clang-format off +#define CASE(NAME) { SpvSamplerAddressingMode##NAME, #NAME } +INSTANTIATE_TEST_CASE_P( + TextToBinarySamplerAddressingMode, SamplerAddressingModeTest, + ::testing::ValuesIn(std::vector>{ + CASE(None), + CASE(ClampToEdge), + CASE(Clamp), + CASE(Repeat), + CASE(RepeatMirrored), + }),); +#undef CASE +// clang-format on + +TEST_F(SamplerAddressingModeTest, WrongMode) { + EXPECT_THAT(CompileFailure("%r = OpConstantSampler %t xxyyzz 0 Nearest"), + Eq("Invalid sampler addressing mode 'xxyyzz'.")); +} + +// Test Sampler Filter Mode enum values + +using SamplerFilterModeTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(SamplerFilterModeTest, AnySamplerFilterMode) { + const std::string input = +
"%result = OpConstantSampler %type Clamp 0 " + GetParam().name(); + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpConstantSampler, + {1, 2, 2, 0, GetParam().value()}))); +} + +// clang-format off +#define CASE(NAME) { SpvSamplerFilterMode##NAME, #NAME} +INSTANTIATE_TEST_CASE_P( + TextToBinarySamplerFilterMode, SamplerFilterModeTest, + ::testing::ValuesIn(std::vector>{ + CASE(Nearest), + CASE(Linear), + }),); +#undef CASE +// clang-format on + +TEST_F(SamplerFilterModeTest, WrongMode) { + EXPECT_THAT(CompileFailure("%r = OpConstantSampler %t Clamp 0 xxyyzz"), + Eq("Invalid sampler filter mode 'xxyyzz'.")); +} + +struct ConstantTestCase { + std::string constant_type; + std::string constant_value; + std::vector expected_instructions; +}; + +using OpConstantValidTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpConstantValidTest, ValidTypes) { + const std::string input = "%1 = " + GetParam().constant_type + + "\n" + "%2 = OpConstant %1 " + + GetParam().constant_value + "\n"; + std::vector instructions; + EXPECT_THAT(CompiledInstructions(input), Eq(GetParam().expected_instructions)) + << " type: " << GetParam().constant_type + << " literal: " << GetParam().constant_value; +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpConstantValid, OpConstantValidTest, + ::testing::ValuesIn(std::vector{ + // Check 16 bits + {"OpTypeInt 16 0", "0x1234", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 0x1234})})}, + {"OpTypeInt 16 0", "0x8000", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 0x8000})})}, + {"OpTypeInt 16 0", "0", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 0})})}, + {"OpTypeInt 16 0", "65535", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 65535})})}, + {"OpTypeInt 16 0", "0xffff", + 
Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 65535})})}, + {"OpTypeInt 16 1", "0x8000", // Test sign extension. + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0xffff8000})})}, + {"OpTypeInt 16 1", "-32", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 1}), + MakeInstruction(SpvOpConstant, {1, 2, uint32_t(-32)})})}, + {"OpTypeInt 16 1", "0", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0})})}, + {"OpTypeInt 16 1", "-0", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0})})}, + {"OpTypeInt 16 1", "-0x0", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0})})}, + {"OpTypeInt 16 1", "-32768", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 1}), + MakeInstruction(SpvOpConstant, {1, 2, uint32_t(-32768)})})}, + // Check 32 bits + {"OpTypeInt 32 0", "42", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 42})})}, + {"OpTypeInt 32 1", "-32", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpConstant, {1, 2, uint32_t(-32)})})}, + {"OpTypeInt 32 1", "0", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0})})}, + {"OpTypeInt 32 1", "-0", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0})})}, + {"OpTypeInt 32 1", "-0x0", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0})})}, + {"OpTypeInt 32 1", "-0x001", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpConstant, {1, 2, uint32_t(-1)})})}, + {"OpTypeInt 32 1", "2147483647", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0x7fffffffu})})}, + {"OpTypeInt 32 1", 
"-2147483648", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0x80000000u})})}, + {"OpTypeFloat 32", "1.0", + Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpConstant, {1, 2, 0x3f800000})})}, + {"OpTypeFloat 32", "10.0", + Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpConstant, {1, 2, 0x41200000})})}, + {"OpTypeFloat 32", "-0x1p+128", // -infinity + Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpConstant, {1, 2, 0xFF800000})})}, + {"OpTypeFloat 32", "0x1p+128", // +infinity + Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpConstant, {1, 2, 0x7F800000})})}, + {"OpTypeFloat 32", "-0x1.8p+128", // A -NaN + Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpConstant, {1, 2, 0xFFC00000})})}, + {"OpTypeFloat 32", "-0x1.0002p+128", // Another -NaN + Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpConstant, {1, 2, 0xFF800100})})}, + // Check 48 bits + {"OpTypeInt 48 0", "0x1234", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 48, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 0x1234, 0})})}, + {"OpTypeInt 48 0", "0x800000000001", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 48, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 1, 0x00008000})})}, + {"OpTypeInt 48 1", "0x800000000000", // Test sign extension.
+ Concatenate({MakeInstruction(SpvOpTypeInt, {1, 48, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0, 0xffff8000})})}, + {"OpTypeInt 48 1", "-32", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 48, 1}), + MakeInstruction(SpvOpConstant, {1, 2, uint32_t(-32), uint32_t(-1)})})}, + // Check 64 bits + {"OpTypeInt 64 0", "0x1234", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 0x1234, 0})})}, + {"OpTypeInt 64 0", "18446744073709551615", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 0xffffffffu, 0xffffffffu})})}, + {"OpTypeInt 64 0", "0xffffffffffffffff", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 0xffffffffu, 0xffffffffu})})}, + {"OpTypeInt 64 1", "0x1234", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0x1234, 0})})}, + {"OpTypeInt 64 1", "-42", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 1}), + MakeInstruction(SpvOpConstant, {1, 2, uint32_t(-42), uint32_t(-1)})})}, + {"OpTypeInt 64 1", "-0x01", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0xffffffffu, 0xffffffffu})})}, + {"OpTypeInt 64 1", "9223372036854775807", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0xffffffffu, 0x7fffffffu})})}, + {"OpTypeInt 64 1", "0x7fffffff", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 0x7fffffffu, 0})})}, + }),); +// clang-format on + +// A test case for checking OpConstant with invalid literals with a leading +// minus. 
+struct InvalidLeadingMinusCase { + std::string type; + std::string literal; +}; + +using OpConstantInvalidLeadingMinusTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>; + +TEST_P(OpConstantInvalidLeadingMinusTest, InvalidCase) { + const std::string input = "%1 = " + GetParam().type + + "\n" + "%2 = OpConstant %1 " + + GetParam().literal; + EXPECT_THAT(CompileFailure(input), + Eq("Cannot put a negative number in an unsigned literal")); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpConstantInvalidLeadingMinus, OpConstantInvalidLeadingMinusTest, + ::testing::ValuesIn(std::vector{ + {"OpTypeInt 16 0", "-0"}, + {"OpTypeInt 16 0", "-0x0"}, + {"OpTypeInt 16 0", "-1"}, + {"OpTypeInt 32 0", "-0"}, + {"OpTypeInt 32 0", "-0x0"}, + {"OpTypeInt 32 0", "-1"}, + {"OpTypeInt 64 0", "-0"}, + {"OpTypeInt 64 0", "-0x0"}, + {"OpTypeInt 64 0", "-1"}, + }),); +// clang-format on + +// A test case for invalid floating point literals. +struct InvalidFloatConstantCase { + uint32_t width; + std::string literal; +}; + +using OpConstantInvalidFloatConstant = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>; + +TEST_P(OpConstantInvalidFloatConstant, Samples) { + // Check both kinds of instructions that take literal floats. + for (const auto& instruction : {"OpConstant", "OpSpecConstant"}) { + std::stringstream input; + input << "%1 = OpTypeFloat " << GetParam().width << "\n" + << "%2 = " << instruction << " %1 " << GetParam().literal; + std::stringstream expected_error; + expected_error << "Invalid " << GetParam().width + << "-bit float literal: " << GetParam().literal; + EXPECT_THAT(CompileFailure(input.str()), Eq(expected_error.str())); + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TextToBinaryInvalidFloatConstant, OpConstantInvalidFloatConstant, + ::testing::ValuesIn(std::vector{ + {16, "abc"}, + {16, "--1"}, + {16, "-+1"}, + {16, "+-1"}, + {16, "++1"}, + {16, "1e30"}, // Overflow is an error for 16-bit floats. 
+ {16, "-1e30"}, + {16, "1e40"}, + {16, "-1e40"}, + {16, "1e400"}, + {16, "-1e400"}, + {32, "abc"}, + {32, "--1"}, + {32, "-+1"}, + {32, "+-1"}, + {32, "++1"}, + {32, "1e40"}, // Overflow is an error for 32-bit floats. + {32, "-1e40"}, + {32, "1e400"}, + {32, "-1e400"}, + {64, "abc"}, + {64, "--1"}, + {64, "-+1"}, + {64, "+-1"}, + {64, "++1"}, + {32, "1e400"}, // Overflow is an error for 64-bit floats. NOTE(review): width here is 32, repeating the 32-bit cases above; 64 was presumably intended - confirm. + {32, "-1e400"}, + }),); +// clang-format on + +using OpConstantInvalidTypeTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpConstantInvalidTypeTest, InvalidTypes) { + const std::string input = "%1 = " + GetParam() + + "\n" + "%2 = OpConstant %1 0\n"; + EXPECT_THAT( + CompileFailure(input), + Eq("Type for Constant must be a scalar floating point or integer type")); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpConstantInvalidValidType, OpConstantInvalidTypeTest, + ::testing::ValuesIn(std::vector{ + {"OpTypeVoid", + "OpTypeBool", + "OpTypeVector %a 32", + "OpTypeMatrix %a 32", + "OpTypeImage %a 1D 0 0 0 0 Unknown", + "OpTypeSampler", + "OpTypeSampledImage %a", + "OpTypeArray %a %b", + "OpTypeRuntimeArray %a", + "OpTypeStruct %a", + "OpTypeOpaque \"Foo\"", + "OpTypePointer UniformConstant %a", + "OpTypeFunction %a %b", + "OpTypeEvent", + "OpTypeDeviceEvent", + "OpTypeReserveId", + "OpTypeQueue", + "OpTypePipe ReadOnly", + "OpTypeForwardPointer %a UniformConstant", + // At least one thing that isn't a type at all + "OpNot %a %b" + }, + }),); +// clang-format on + +using OpSpecConstantValidTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpSpecConstantValidTest, ValidTypes) { + const std::string input = "%1 = " + GetParam().constant_type + + "\n" + "%2 = OpSpecConstant %1 " + + GetParam().constant_value + "\n"; + std::vector instructions; + EXPECT_THAT(CompiledInstructions(input), + Eq(GetParam().expected_instructions)); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( +
TextToBinaryOpSpecConstantValid, OpSpecConstantValidTest, + ::testing::ValuesIn(std::vector{ + // Check 16 bits + {"OpTypeInt 16 0", "0x1234", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 0}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0x1234})})}, + {"OpTypeInt 16 0", "0x8000", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 0}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0x8000})})}, + {"OpTypeInt 16 1", "0x8000", // Test sign extension. + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 1}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0xffff8000})})}, + {"OpTypeInt 16 1", "-32", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 16, 1}), + MakeInstruction(SpvOpSpecConstant, {1, 2, uint32_t(-32)})})}, + // Check 32 bits + {"OpTypeInt 32 0", "42", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 0}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 42})})}, + {"OpTypeInt 32 1", "-32", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpSpecConstant, {1, 2, uint32_t(-32)})})}, + {"OpTypeFloat 32", "1.0", + Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0x3f800000})})}, + {"OpTypeFloat 32", "10.0", + Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0x41200000})})}, + // Check 48 bits + {"OpTypeInt 48 0", "0x1234", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 48, 0}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0x1234, 0})})}, + {"OpTypeInt 48 0", "0x800000000001", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 48, 0}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 1, 0x00008000})})}, + {"OpTypeInt 48 1", "0x800000000000", // Test sign extension. 
+ Concatenate({MakeInstruction(SpvOpTypeInt, {1, 48, 1}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0, 0xffff8000})})}, + {"OpTypeInt 48 1", "-32", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 48, 1}), + MakeInstruction(SpvOpSpecConstant, {1, 2, uint32_t(-32), uint32_t(-1)})})}, + // Check 64 bits + {"OpTypeInt 64 0", "0x1234", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 0}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0x1234, 0})})}, + {"OpTypeInt 64 1", "0x1234", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 1}), + MakeInstruction(SpvOpSpecConstant, {1, 2, 0x1234, 0})})}, + {"OpTypeInt 64 1", "-42", + Concatenate({MakeInstruction(SpvOpTypeInt, {1, 64, 1}), + MakeInstruction(SpvOpSpecConstant, {1, 2, uint32_t(-42), uint32_t(-1)})})}, + }),); +// clang-format on + +using OpSpecConstantInvalidTypeTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpSpecConstantInvalidTypeTest, InvalidTypes) { + const std::string input = "%1 = " + GetParam() + + "\n" + "%2 = OpSpecConstant %1 0\n"; + EXPECT_THAT(CompileFailure(input), + Eq("Type for SpecConstant must be a scalar floating point or " + "integer type")); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpSpecConstantInvalidValidType, OpSpecConstantInvalidTypeTest, + ::testing::ValuesIn(std::vector{ + {"OpTypeVoid", + "OpTypeBool", + "OpTypeVector %a 32", + "OpTypeMatrix %a 32", + "OpTypeImage %a 1D 0 0 0 0 Unknown", + "OpTypeSampler", + "OpTypeSampledImage %a", + "OpTypeArray %a %b", + "OpTypeRuntimeArray %a", + "OpTypeStruct %a", + "OpTypeOpaque \"Foo\"", + "OpTypePointer UniformConstant %a", + "OpTypeFunction %a %b", + "OpTypeEvent", + "OpTypeDeviceEvent", + "OpTypeReserveId", + "OpTypeQueue", + "OpTypePipe ReadOnly", + "OpTypeForwardPointer %a UniformConstant", + // At least one thing that isn't a type at all + "OpNot %a %b" + }, + }),); +// clang-format on + +const int64_t kMaxUnsigned48Bit = (int64_t(1) << 48) - 1; +const int64_t kMaxSigned48Bit = 
(int64_t(1) << 47) - 1; +const int64_t kMinSigned48Bit = -kMaxSigned48Bit - 1; + +INSTANTIATE_TEST_CASE_P( + OpConstantRoundTrip, RoundTripTest, + ::testing::ValuesIn(std::vector{ + // 16 bit + "%1 = OpTypeInt 16 0\n%2 = OpConstant %1 0\n", + "%1 = OpTypeInt 16 0\n%2 = OpConstant %1 65535\n", + "%1 = OpTypeInt 16 1\n%2 = OpConstant %1 -32768\n", + "%1 = OpTypeInt 16 1\n%2 = OpConstant %1 32767\n", + "%1 = OpTypeInt 32 0\n%2 = OpConstant %1 0\n", + // 32 bit + std::string("%1 = OpTypeInt 32 0\n%2 = OpConstant %1 0\n"), + std::string("%1 = OpTypeInt 32 0\n%2 = OpConstant %1 ") + + std::to_string(std::numeric_limits::max()) + "\n", + std::string("%1 = OpTypeInt 32 1\n%2 = OpConstant %1 ") + + std::to_string(std::numeric_limits::max()) + "\n", + std::string("%1 = OpTypeInt 32 1\n%2 = OpConstant %1 ") + + std::to_string(std::numeric_limits::min()) + "\n", + // 48 bit + std::string("%1 = OpTypeInt 48 0\n%2 = OpConstant %1 0\n"), + std::string("%1 = OpTypeInt 48 0\n%2 = OpConstant %1 ") + + std::to_string(kMaxUnsigned48Bit) + "\n", + std::string("%1 = OpTypeInt 48 1\n%2 = OpConstant %1 ") + + std::to_string(kMaxSigned48Bit) + "\n", + std::string("%1 = OpTypeInt 48 1\n%2 = OpConstant %1 ") + + std::to_string(kMinSigned48Bit) + "\n", + // 64 bit + std::string("%1 = OpTypeInt 64 0\n%2 = OpConstant %1 0\n"), + std::string("%1 = OpTypeInt 64 0\n%2 = OpConstant %1 ") + + std::to_string(std::numeric_limits::max()) + "\n", + std::string("%1 = OpTypeInt 64 1\n%2 = OpConstant %1 ") + + std::to_string(std::numeric_limits::max()) + "\n", + std::string("%1 = OpTypeInt 64 1\n%2 = OpConstant %1 ") + + std::to_string(std::numeric_limits::min()) + "\n", + // 32-bit float + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0\n", + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 13.5\n", + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -12.5\n", + // 64-bit float + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0\n", + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 1.79767e+308\n", + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 
-1.79767e+308\n", + }), ); + +INSTANTIATE_TEST_CASE_P( + OpConstantHalfRoundTrip, RoundTripTest, + ::testing::ValuesIn(std::vector{ + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x0p+0\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x0p+0\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+0\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.1p+0\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p-1\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.8p+1\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+1\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+0\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.1p+0\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p-1\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.8p+1\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+1\n", + + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-16\n", // some denorms + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p-24\n", + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p-24\n", + + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1p+16\n", // +inf + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1p+16\n", // -inf + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.01p+16\n", // nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.11p+16\n", // nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffp+16\n", // nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.ffcp+16\n", // nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 0x1.004p+16\n", // nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.01p+16\n", // -nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.11p+16\n", // -nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffp+16\n", // -nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.ffcp+16\n", // -nan + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 -0x1.004p+16\n", // -nan + }), ); + +// clang-format off +// (Clang-format really wants to break up these strings across lines.)
+INSTANTIATE_TEST_CASE_P( + OpConstantRoundTripNonFinite, RoundTripTest, + ::testing::ValuesIn(std::vector{ + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1p+128\n", // -inf + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1p+128\n", // inf + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1.8p+128\n", // -nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1.0002p+128\n", // -nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1.0018p+128\n", // -nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1.01ep+128\n", // -nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 -0x1.fffffep+128\n", // -nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.8p+128\n", // +nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.0002p+128\n", // +nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.0018p+128\n", // +nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.01ep+128\n", // +nan + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 0x1.fffffep+128\n", // +nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1p+1024\n", // -inf + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1p+1024\n", // +inf + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.8p+1024\n", // -nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.0fp+1024\n", // -nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.0000000000001p+1024\n", // -nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.00003p+1024\n", // -nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 -0x1.fffffffffffffp+1024\n", // -nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1.8p+1024\n", // +nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1.0fp+1024\n", // +nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1.0000000000001p+1024\n", // +nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1.00003p+1024\n", // +nan + "%1 = OpTypeFloat 64\n%2 = OpConstant %1 0x1.fffffffffffffp+1024\n", // +nan + }),); +// clang-format on + +INSTANTIATE_TEST_CASE_P( + OpSpecConstantRoundTrip, RoundTripTest, + ::testing::ValuesIn(std::vector{ + // 16 bit + "%1 = OpTypeInt 16 0\n%2 = OpSpecConstant %1 0\n", + "%1 =
OpTypeInt 16 0\n%2 = OpSpecConstant %1 65535\n", + "%1 = OpTypeInt 16 1\n%2 = OpSpecConstant %1 -32768\n", + "%1 = OpTypeInt 16 1\n%2 = OpSpecConstant %1 32767\n", + "%1 = OpTypeInt 32 0\n%2 = OpSpecConstant %1 0\n", + // 32 bit + std::string("%1 = OpTypeInt 32 0\n%2 = OpSpecConstant %1 0\n"), + std::string("%1 = OpTypeInt 32 0\n%2 = OpSpecConstant %1 ") + + std::to_string(std::numeric_limits::max()) + "\n", + std::string("%1 = OpTypeInt 32 1\n%2 = OpSpecConstant %1 ") + + std::to_string(std::numeric_limits::max()) + "\n", + std::string("%1 = OpTypeInt 32 1\n%2 = OpSpecConstant %1 ") + + std::to_string(std::numeric_limits::min()) + "\n", + // 48 bit + std::string("%1 = OpTypeInt 48 0\n%2 = OpSpecConstant %1 0\n"), + std::string("%1 = OpTypeInt 48 0\n%2 = OpSpecConstant %1 ") + + std::to_string(kMaxUnsigned48Bit) + "\n", + std::string("%1 = OpTypeInt 48 1\n%2 = OpSpecConstant %1 ") + + std::to_string(kMaxSigned48Bit) + "\n", + std::string("%1 = OpTypeInt 48 1\n%2 = OpSpecConstant %1 ") + + std::to_string(kMinSigned48Bit) + "\n", + // 64 bit + std::string("%1 = OpTypeInt 64 0\n%2 = OpSpecConstant %1 0\n"), + std::string("%1 = OpTypeInt 64 0\n%2 = OpSpecConstant %1 ") + + std::to_string(std::numeric_limits::max()) + "\n", + std::string("%1 = OpTypeInt 64 1\n%2 = OpSpecConstant %1 ") + + std::to_string(std::numeric_limits::max()) + "\n", + std::string("%1 = OpTypeInt 64 1\n%2 = OpSpecConstant %1 ") + + std::to_string(std::numeric_limits::min()) + "\n", + // 32-bit float + "%1 = OpTypeFloat 32\n%2 = OpSpecConstant %1 0\n", + "%1 = OpTypeFloat 32\n%2 = OpSpecConstant %1 13.5\n", + "%1 = OpTypeFloat 32\n%2 = OpSpecConstant %1 -12.5\n", + // 64-bit float + "%1 = OpTypeFloat 64\n%2 = OpSpecConstant %1 0\n", + "%1 = OpTypeFloat 64\n%2 = OpSpecConstant %1 1.79767e+308\n", + "%1 = OpTypeFloat 64\n%2 = OpSpecConstant %1 -1.79767e+308\n", + }), ); + +// Test OpSpecConstantOp + +using OpSpecConstantOpTestWithIds = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>>; + +// 
The operands to the OpSpecConstantOp opcode are all Ids. +TEST_P(OpSpecConstantOpTestWithIds, Assembly) { + std::stringstream input; + input << "%2 = OpSpecConstantOp %1 " << GetParam().name(); + for (auto id : GetParam().operands()) input << " %" << id; + input << "\n"; + + EXPECT_THAT(CompiledInstructions(input.str()), + Eq(MakeInstruction(SpvOpSpecConstantOp, + {1, 2, uint32_t(GetParam().value())}, + GetParam().operands()))); + + // Check the disassembler as well. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input.str()), input.str()); +} + +// clang-format off +#define CASE1(NAME) { SpvOp##NAME, #NAME, {3} } +#define CASE2(NAME) { SpvOp##NAME, #NAME, {3, 4} } +#define CASE3(NAME) { SpvOp##NAME, #NAME, {3, 4, 5} } +#define CASE4(NAME) { SpvOp##NAME, #NAME, {3, 4, 5, 6} } +#define CASE5(NAME) { SpvOp##NAME, #NAME, {3, 4, 5, 6, 7} } +#define CASE6(NAME) { SpvOp##NAME, #NAME, {3, 4, 5, 6, 7, 8} } +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpSpecConstantOp, OpSpecConstantOpTestWithIds, + ::testing::ValuesIn(std::vector>{ + // Conversion + CASE1(SConvert), + CASE1(FConvert), + CASE1(ConvertFToS), + CASE1(ConvertSToF), + CASE1(ConvertFToU), + CASE1(ConvertUToF), + CASE1(UConvert), + CASE1(ConvertPtrToU), + CASE1(ConvertUToPtr), + CASE1(GenericCastToPtr), + CASE1(PtrCastToGeneric), + CASE1(Bitcast), + CASE1(QuantizeToF16), + // Arithmetic + CASE1(SNegate), + CASE1(Not), + CASE2(IAdd), + CASE2(ISub), + CASE2(IMul), + CASE2(UDiv), + CASE2(SDiv), + CASE2(UMod), + CASE2(SRem), + CASE2(SMod), + CASE2(ShiftRightLogical), + CASE2(ShiftRightArithmetic), + CASE2(ShiftLeftLogical), + CASE2(BitwiseOr), + CASE2(BitwiseAnd), + CASE2(BitwiseXor), + CASE1(FNegate), + CASE2(FAdd), + CASE2(FSub), + CASE2(FMul), + CASE2(FDiv), + CASE2(FRem), + CASE2(FMod), + // Composite operations use literal numbers. So they're in another test. 
+ // Logical + CASE2(LogicalOr), + CASE2(LogicalAnd), + CASE1(LogicalNot), + CASE2(LogicalEqual), + CASE2(LogicalNotEqual), + CASE3(Select), + // Comparison + CASE2(IEqual), + CASE2(INotEqual), // Allowed in 1.0 Rev 7 + CASE2(ULessThan), + CASE2(SLessThan), + CASE2(UGreaterThan), + CASE2(SGreaterThan), + CASE2(ULessThanEqual), + CASE2(SLessThanEqual), + CASE2(UGreaterThanEqual), + CASE2(SGreaterThanEqual), + // Memory + // For AccessChain, there is a base Id, then a sequence of index Ids. + // Having no index Ids is a corner case. + CASE1(AccessChain), + CASE2(AccessChain), + CASE6(AccessChain), + CASE1(InBoundsAccessChain), + CASE2(InBoundsAccessChain), + CASE6(InBoundsAccessChain), + // PtrAccessChain also has an element Id. + CASE2(PtrAccessChain), + CASE3(PtrAccessChain), + CASE6(PtrAccessChain), + CASE2(InBoundsPtrAccessChain), + CASE3(InBoundsPtrAccessChain), + CASE6(InBoundsPtrAccessChain), + }),); +#undef CASE1 +#undef CASE2 +#undef CASE3 +#undef CASE4 +#undef CASE5 +#undef CASE6 +// clang-format on + +using OpSpecConstantOpTestWithTwoIdsThenLiteralNumbers = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>>; + +// The operands to the OpSpecConstantOp opcode are two Ids followed by a +// sequence of literal numbers. +TEST_P(OpSpecConstantOpTestWithTwoIdsThenLiteralNumbers, Assembly) { + std::stringstream input; + input << "%2 = OpSpecConstantOp %1 " << GetParam().name() << " %3 %4"; + for (auto number : GetParam().operands()) input << " " << number; + input << "\n"; + + EXPECT_THAT(CompiledInstructions(input.str()), + Eq(MakeInstruction(SpvOpSpecConstantOp, + {1, 2, uint32_t(GetParam().value()), 3, 4}, + GetParam().operands()))); + + // Check the disassembler as well. 
+ EXPECT_THAT(EncodeAndDecodeSuccessfully(input.str()), input.str()); +} + +#define CASE(NAME) SpvOp##NAME, #NAME +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpSpecConstantOp, + OpSpecConstantOpTestWithTwoIdsThenLiteralNumbers, + ::testing::ValuesIn(std::vector>{ + // For VectorShuffle, there are two vector operands, and at least + // two literal selectors. OpenCL can have up to 16-element vectors. + {CASE(VectorShuffle), {0, 0}}, + {CASE(VectorShuffle), {4, 3, 2, 1}}, + {CASE(VectorShuffle), {0, 2, 4, 6, 1, 3, 5, 7}}, + {CASE(VectorShuffle), + {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}}, + // For CompositeInsert, there is an object to insert, the target + // composite, and then literal indices. + {CASE(CompositeInsert), {0}}, + {CASE(CompositeInsert), {4, 3, 99, 1}}, + }), ); + +using OpSpecConstantOpTestWithOneIdThenLiteralNumbers = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>>; + +// The operands to the OpSpecConstantOp opcode are one Id followed by a +// sequence of literal numbers. +TEST_P(OpSpecConstantOpTestWithOneIdThenLiteralNumbers, Assembly) { + std::stringstream input; + input << "%2 = OpSpecConstantOp %1 " << GetParam().name() << " %3"; + for (auto number : GetParam().operands()) input << " " << number; + input << "\n"; + + EXPECT_THAT(CompiledInstructions(input.str()), + Eq(MakeInstruction(SpvOpSpecConstantOp, + {1, 2, uint32_t(GetParam().value()), 3}, + GetParam().operands()))); + + // Check the disassembler as well. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input.str()), input.str()); +} + +#define CASE(NAME) SpvOp##NAME, #NAME +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpSpecConstantOp, + OpSpecConstantOpTestWithOneIdThenLiteralNumbers, + ::testing::ValuesIn(std::vector>{ + // For CompositeExtract, the universal limit permits up to 255 literal + // indices. Let's only test a few.
+ {CASE(CompositeExtract), {0}}, + {CASE(CompositeExtract), {0, 99, 42, 16, 17, 12, 19}}, + }), ); + +// TODO(dneto): OpConstantTrue +// TODO(dneto): OpConstantFalse +// TODO(dneto): OpConstantComposite +// TODO(dneto): OpConstantSampler: other variations Param is 0 or 1 +// TODO(dneto): OpConstantNull +// TODO(dneto): OpSpecConstantTrue +// TODO(dneto): OpSpecConstantFalse +// TODO(dneto): OpSpecConstantComposite +// TODO(dneto): Negative tests for OpSpecConstantOp + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.control_flow_test.cpp b/test/text_to_binary.control_flow_test.cpp new file mode 100644 index 000000000..07f110884 --- /dev/null +++ b/test/text_to_binary.control_flow_test.cpp @@ -0,0 +1,394 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Control Flow" section of the +// SPIR-V spec. 
+ +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::Concatenate; +using spvtest::EnumCase; +using spvtest::MakeInstruction; +using spvtest::TextToBinaryTest; +using ::testing::Combine; +using ::testing::Eq; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +// Test OpSelectionMerge + +using OpSelectionMergeTest = spvtest::TextToBinaryTestBase< + TestWithParam>>; + +TEST_P(OpSelectionMergeTest, AnySingleSelectionControlMask) { + const std::string input = "OpSelectionMerge %1 " + GetParam().name(); + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpSelectionMerge, {1, GetParam().value()}))); +} + +// clang-format off +#define CASE(VALUE,NAME) { SpvSelectionControl##VALUE, NAME} +INSTANTIATE_TEST_CASE_P(TextToBinarySelectionMerge, OpSelectionMergeTest, + ValuesIn(std::vector>{ + CASE(MaskNone, "None"), + CASE(FlattenMask, "Flatten"), + CASE(DontFlattenMask, "DontFlatten"), + }),); +#undef CASE +// clang-format on + +TEST_F(OpSelectionMergeTest, CombinedSelectionControlMask) { + const std::string input = "OpSelectionMerge %1 Flatten|DontFlatten"; + const uint32_t expected_mask = + SpvSelectionControlFlattenMask | SpvSelectionControlDontFlattenMask; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpSelectionMerge, {1, expected_mask}))); +} + +TEST_F(OpSelectionMergeTest, WrongSelectionControl) { + // Case sensitive: "flatten" != "Flatten" and thus wrong. 
+ EXPECT_THAT(CompileFailure("OpSelectionMerge %1 flatten|DontFlatten"), + Eq("Invalid selection control operand 'flatten|DontFlatten'.")); +} + +// Test OpLoopMerge + +using OpLoopMergeTest = spvtest::TextToBinaryTestBase< + TestWithParam>>>; + +TEST_P(OpLoopMergeTest, AnySingleLoopControlMask) { + const auto ctrl = std::get<1>(GetParam()); + std::ostringstream input; + input << "OpLoopMerge %merge %continue " << ctrl.name(); + for (auto num : ctrl.operands()) input << " " << num; + EXPECT_THAT(CompiledInstructions(input.str(), std::get<0>(GetParam())), + Eq(MakeInstruction(SpvOpLoopMerge, {1, 2, ctrl.value()}, + ctrl.operands()))); +} + +#define CASE(VALUE, NAME) \ + { SpvLoopControl##VALUE, NAME } +#define CASE1(VALUE, NAME, PARM) \ + { \ + SpvLoopControl##VALUE, NAME, { PARM } \ + } +INSTANTIATE_TEST_CASE_P( + TextToBinaryLoopMerge, OpLoopMergeTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector>{ + // clang-format off + CASE(MaskNone, "None"), + CASE(UnrollMask, "Unroll"), + CASE(DontUnrollMask, "DontUnroll"), + // clang-format on + })), ); + +INSTANTIATE_TEST_CASE_P( + TextToBinaryLoopMergeV11, OpLoopMergeTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector>{ + // clang-format off + CASE(DependencyInfiniteMask, "DependencyInfinite"), + CASE1(DependencyLengthMask, "DependencyLength", 234), + {SpvLoopControlUnrollMask|SpvLoopControlDependencyLengthMask, + "DependencyLength|Unroll", {33}}, + // clang-format on + })), ); +#undef CASE +#undef CASE1 + +TEST_F(OpLoopMergeTest, CombinedLoopControlMask) { + const std::string input = "OpLoopMerge %merge %continue Unroll|DontUnroll"; + const uint32_t expected_mask = + SpvLoopControlUnrollMask | SpvLoopControlDontUnrollMask; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpLoopMerge, {1, 2, expected_mask}))); +} + +TEST_F(OpLoopMergeTest, WrongLoopControl) { + EXPECT_THAT(CompileFailure("OpLoopMerge %m %c none"), + Eq("Invalid loop control 
operand 'none'.")); +} + +// Test OpSwitch + +TEST_F(TextToBinaryTest, SwitchGoodZeroTargets) { + EXPECT_THAT(CompiledInstructions("OpSwitch %selector %default"), + Eq(MakeInstruction(SpvOpSwitch, {1, 2}))); +} + +TEST_F(TextToBinaryTest, SwitchGoodOneTarget) { + EXPECT_THAT(CompiledInstructions("%1 = OpTypeInt 32 0\n" + "%2 = OpConstant %1 52\n" + "OpSwitch %2 %default 12 %target0"), + Eq(Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 52}), + MakeInstruction(SpvOpSwitch, {2, 3, 12, 4})}))); +} + +TEST_F(TextToBinaryTest, SwitchGoodTwoTargets) { + EXPECT_THAT( + CompiledInstructions("%1 = OpTypeInt 32 0\n" + "%2 = OpConstant %1 52\n" + "OpSwitch %2 %default 12 %target0 42 %target1"), + Eq(Concatenate({ + MakeInstruction(SpvOpTypeInt, {1, 32, 0}), + MakeInstruction(SpvOpConstant, {1, 2, 52}), + MakeInstruction(SpvOpSwitch, {2, 3, 12, 4, 42, 5}), + }))); +} + +TEST_F(TextToBinaryTest, SwitchBadMissingSelector) { + EXPECT_THAT(CompileFailure("OpSwitch"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(TextToBinaryTest, SwitchBadInvalidSelector) { + EXPECT_THAT(CompileFailure("OpSwitch 12"), + Eq("Expected id to start with %.")); +} + +TEST_F(TextToBinaryTest, SwitchBadMissingDefault) { + EXPECT_THAT(CompileFailure("OpSwitch %selector"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(TextToBinaryTest, SwitchBadInvalidDefault) { + EXPECT_THAT(CompileFailure("OpSwitch %selector 12"), + Eq("Expected id to start with %.")); +} + +TEST_F(TextToBinaryTest, SwitchBadInvalidLiteral) { + // The assembler recognizes "OpSwitch %selector %default" as a complete + // instruction. Then it tries to parse "%abc" as the start of a new + // instruction, but can't since it hits the end of stream. 
+ const auto input = R"(%i32 = OpTypeInt 32 0 + %selector = OpConstant %i32 42 + OpSwitch %selector %default %abc)"; + EXPECT_THAT(CompileFailure(input), Eq("Expected '=', found end of stream.")); +} + +TEST_F(TextToBinaryTest, SwitchBadMissingTarget) { + EXPECT_THAT(CompileFailure("%1 = OpTypeInt 32 0\n" + "%2 = OpConstant %1 52\n" + "OpSwitch %2 %default 12"), + Eq("Expected operand, found end of stream.")); +} + +// A test case for an OpSwitch. +// It is also parameterized to test encodings OpConstant +// integer literals. This can capture both single and multi-word +// integer literal tests. +struct SwitchTestCase { + std::string constant_type_args; + std::string constant_value_arg; + std::string case_value_arg; + std::vector expected_instructions; +}; + +using OpSwitchValidTest = + spvtest::TextToBinaryTestBase>; + +// Tests the encoding of OpConstant literal values, and also +// the literal integer cases in an OpSwitch. This can +// test both single and multi-word integer literal encodings. +TEST_P(OpSwitchValidTest, ValidTypes) { + const std::string input = "%1 = OpTypeInt " + GetParam().constant_type_args + + "\n" + "%2 = OpConstant %1 " + + GetParam().constant_value_arg + + "\n" + "OpSwitch %2 %default " + + GetParam().case_value_arg + " %4\n"; + std::vector instructions; + EXPECT_THAT(CompiledInstructions(input), + Eq(GetParam().expected_instructions)); +} + +// Constructs a SwitchTestCase from the given integer_width, signedness, +// constant value string, and expected encoded constant. 
+SwitchTestCase MakeSwitchTestCase(uint32_t integer_width, + uint32_t integer_signedness, + std::string constant_str, + std::vector encoded_constant, + std::string case_value_str, + std::vector encoded_case_value) { + std::stringstream ss; + ss << integer_width << " " << integer_signedness; + return SwitchTestCase{ + ss.str(), + constant_str, + case_value_str, + {Concatenate( + {MakeInstruction(SpvOpTypeInt, + {1, integer_width, integer_signedness}), + MakeInstruction(SpvOpConstant, + Concatenate({{1, 2}, encoded_constant})), + MakeInstruction(SpvOpSwitch, + Concatenate({{2, 3}, encoded_case_value, {4}}))})}}; +} + +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpSwitchValid1Word, OpSwitchValidTest, + ValuesIn(std::vector({ + MakeSwitchTestCase(32, 0, "42", {42}, "100", {100}), + MakeSwitchTestCase(32, 1, "-1", {0xffffffff}, "100", {100}), + // SPIR-V 1.0 Rev 1 clarified that for an integer narrower than 32-bits, + // its bits will appear in the lower order bits of the 32-bit word, and + // a signed integer is sign-extended. + MakeSwitchTestCase(7, 0, "127", {127}, "100", {100}), + MakeSwitchTestCase(14, 0, "99", {99}, "100", {100}), + MakeSwitchTestCase(16, 0, "65535", {65535}, "100", {100}), + MakeSwitchTestCase(16, 1, "101", {101}, "100", {100}), + // Demonstrate sign extension + MakeSwitchTestCase(16, 1, "-2", {0xfffffffe}, "100", {100}), + // Hex cases + MakeSwitchTestCase(16, 1, "0x7ffe", {0x7ffe}, "0x1234", {0x1234}), + MakeSwitchTestCase(16, 1, "0x8000", {0xffff8000}, "0x8100", + {0xffff8100}), + MakeSwitchTestCase(16, 0, "0x8000", {0x00008000}, "0x8100", {0x8100}), + })), ); + +// NB: The words LOW ORDER bits show up first. 
+INSTANTIATE_TEST_CASE_P( + TextToBinaryOpSwitchValid2Words, OpSwitchValidTest, + ValuesIn(std::vector({ + MakeSwitchTestCase(33, 0, "101", {101, 0}, "500", {500, 0}), + MakeSwitchTestCase(48, 1, "-1", {0xffffffff, 0xffffffff}, "900", + {900, 0}), + MakeSwitchTestCase(64, 1, "-2", {0xfffffffe, 0xffffffff}, "-5", + {0xfffffffb, uint32_t(-1)}), + // Hex cases + MakeSwitchTestCase(48, 1, "0x7fffffffffff", {0xffffffff, 0x00007fff}, + "100", {100, 0}), + MakeSwitchTestCase(48, 1, "0x800000000000", {0x00000000, 0xffff8000}, + "0x800000000000", {0x00000000, 0xffff8000}), + MakeSwitchTestCase(48, 0, "0x800000000000", {0x00000000, 0x00008000}, + "0x800000000000", {0x00000000, 0x00008000}), + MakeSwitchTestCase(63, 0, "0x500000000", {0, 5}, "12", {12, 0}), + MakeSwitchTestCase(64, 0, "0x600000000", {0, 6}, "12", {12, 0}), + MakeSwitchTestCase(64, 1, "0x700000123", {0x123, 7}, "12", {12, 0}), + })), ); + +INSTANTIATE_TEST_CASE_P( + OpSwitchRoundTripUnsignedIntegers, RoundTripTest, + ValuesIn(std::vector({ + // Unsigned 16-bit. + "%1 = OpTypeInt 16 0\n%2 = OpConstant %1 65535\nOpSwitch %2 %3\n", + // Unsigned 32-bit, three non-default cases. + "%1 = OpTypeInt 32 0\n%2 = OpConstant %1 123456\n" + "OpSwitch %2 %3 100 %4 102 %5 1000000 %6\n", + // Unsigned 48-bit, three non-default cases. + "%1 = OpTypeInt 48 0\n%2 = OpConstant %1 5000000000\n" + "OpSwitch %2 %3 100 %4 102 %5 6000000000 %6\n", + // Unsigned 64-bit, three non-default cases. + "%1 = OpTypeInt 64 0\n%2 = OpConstant %1 9223372036854775807\n" + "OpSwitch %2 %3 100 %4 102 %5 9000000000000000000 %6\n", + })), ); + +INSTANTIATE_TEST_CASE_P( + OpSwitchRoundTripSignedIntegers, RoundTripTest, + ValuesIn(std::vector{ + // Signed 16-bit, with two non-default cases + "%1 = OpTypeInt 16 1\n%2 = OpConstant %1 32767\n" + "OpSwitch %2 %3 99 %4 -102 %5\n", + "%1 = OpTypeInt 16 1\n%2 = OpConstant %1 -32768\n" + "OpSwitch %2 %3 99 %4 -102 %5\n", + // Signed 32-bit, two non-default cases. 
+ "%1 = OpTypeInt 32 1\n%2 = OpConstant %1 -123456\n" + "OpSwitch %2 %3 100 %4 -123456 %5\n", + "%1 = OpTypeInt 32 1\n%2 = OpConstant %1 123456\n" + "OpSwitch %2 %3 100 %4 123456 %5\n", + // Signed 48-bit, three non-default cases. + "%1 = OpTypeInt 48 1\n%2 = OpConstant %1 5000000000\n" + "OpSwitch %2 %3 100 %4 -7000000000 %5 6000000000 %6\n", + "%1 = OpTypeInt 48 1\n%2 = OpConstant %1 -5000000000\n" + "OpSwitch %2 %3 100 %4 -7000000000 %5 6000000000 %6\n", + // Signed 64-bit, three non-default cases. + "%1 = OpTypeInt 64 1\n%2 = OpConstant %1 9223372036854775807\n" + "OpSwitch %2 %3 100 %4 7000000000 %5 -1000000000000000000 %6\n", + "%1 = OpTypeInt 64 1\n%2 = OpConstant %1 -9223372036854775808\n" + "OpSwitch %2 %3 100 %4 7000000000 %5 -1000000000000000000 %6\n", + }), ); + +using OpSwitchInvalidTypeTestCase = + spvtest::TextToBinaryTestBase>; + +TEST_P(OpSwitchInvalidTypeTestCase, InvalidTypes) { + const std::string input = + "%1 = " + GetParam() + + "\n" + "%3 = OpCopyObject %1 %2\n" // We only care the type of the expression + "%4 = OpSwitch %3 %default 32 %c\n"; + EXPECT_THAT(CompileFailure(input), + Eq("The selector operand for OpSwitch must be the result of an " + "instruction that generates an integer scalar")); +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + TextToBinaryOpSwitchInvalidTests, OpSwitchInvalidTypeTestCase, + ValuesIn(std::vector{ + {"OpTypeVoid", + "OpTypeBool", + "OpTypeFloat 32", + "OpTypeVector %a 32", + "OpTypeMatrix %a 32", + "OpTypeImage %a 1D 0 0 0 0 Unknown", + "OpTypeSampler", + "OpTypeSampledImage %a", + "OpTypeArray %a %b", + "OpTypeRuntimeArray %a", + "OpTypeStruct %a", + "OpTypeOpaque \"Foo\"", + "OpTypePointer UniformConstant %a", + "OpTypeFunction %a %b", + "OpTypeEvent", + "OpTypeDeviceEvent", + "OpTypeReserveId", + "OpTypeQueue", + "OpTypePipe ReadOnly", + "OpTypeForwardPointer %a UniformConstant", + // At least one thing that isn't a type at all + "OpNot %a %b" + }, + }),); +// clang-format on + +// TODO(dneto): OpPhi 
+// TODO(dneto): OpLoopMerge +// TODO(dneto): OpLabel +// TODO(dneto): OpBranch +// TODO(dneto): OpSwitch +// TODO(dneto): OpKill +// TODO(dneto): OpReturn +// TODO(dneto): OpReturnValue +// TODO(dneto): OpUnreachable +// TODO(dneto): OpLifetimeStart +// TODO(dneto): OpLifetimeStop + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.debug_test.cpp b/test/text_to_binary.debug_test.cpp new file mode 100644 index 000000000..b85650e5e --- /dev/null +++ b/test/text_to_binary.debug_test.cpp @@ -0,0 +1,214 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Debug" section of the +// SPIR-V spec. + +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::MakeInstruction; +using spvtest::MakeVector; +using spvtest::TextToBinaryTest; +using ::testing::Eq; + +// Test OpSource + +// A single test case for OpSource +struct LanguageCase { + uint32_t get_language_value() const { + return static_cast(language_value); + } + const char* language_name; + SpvSourceLanguage language_value; + uint32_t version; +}; + +// clang-format off +// The list of OpSource cases to use. 
+const LanguageCase kLanguageCases[] = { +#define CASE(NAME, VERSION) \ + { #NAME, SpvSourceLanguage##NAME, VERSION } + CASE(Unknown, 0), + CASE(Unknown, 999), + CASE(ESSL, 310), + CASE(GLSL, 450), + CASE(OpenCL_C, 120), + CASE(OpenCL_C, 200), + CASE(OpenCL_C, 210), + CASE(OpenCL_CPP, 210), + CASE(HLSL, 5), + CASE(HLSL, 6), +#undef CASE +}; +// clang-format on + +using OpSourceTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpSourceTest, AnyLanguage) { + const std::string input = std::string("OpSource ") + + GetParam().language_name + " " + + std::to_string(GetParam().version); + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpSource, {GetParam().get_language_value(), + GetParam().version}))); +} + +INSTANTIATE_TEST_CASE_P(TextToBinaryTestDebug, OpSourceTest, + ::testing::ValuesIn(kLanguageCases), ); + +TEST_F(OpSourceTest, WrongLanguage) { + EXPECT_THAT(CompileFailure("OpSource xxyyzz 12345"), + Eq("Invalid source language 'xxyyzz'.")); +} + +TEST_F(TextToBinaryTest, OpSourceAcceptsOptionalFileId) { + // In the grammar, the file id is an OperandOptionalId. 
+ const std::string input = "OpSource GLSL 450 %file_id"; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpSource, {SpvSourceLanguageGLSL, 450, 1}))); +} + +TEST_F(TextToBinaryTest, OpSourceAcceptsOptionalSourceText) { + std::string fake_source = "To be or not to be"; + const std::string input = + "OpSource GLSL 450 %file_id \"" + fake_source + "\""; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpSource, {SpvSourceLanguageGLSL, 450, 1}, + MakeVector(fake_source)))); +} + +// Test OpSourceContinued + +using OpSourceContinuedTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpSourceContinuedTest, AnyExtension) { + // TODO(dneto): utf-8, quoting, escaping + const std::string input = + std::string("OpSourceContinued \"") + GetParam() + "\""; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpSourceContinued, MakeVector(GetParam())))); +} + +// TODO(dneto): utf-8, quoting, escaping +INSTANTIATE_TEST_CASE_P(TextToBinaryTestDebug, OpSourceContinuedTest, + ::testing::ValuesIn(std::vector{ + "", "foo bar this and that"}), ); + +// Test OpSourceExtension + +using OpSourceExtensionTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpSourceExtensionTest, AnyExtension) { + // TODO(dneto): utf-8, quoting, escaping + const std::string input = + std::string("OpSourceExtension \"") + GetParam() + "\""; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpSourceExtension, MakeVector(GetParam())))); +} + +// TODO(dneto): utf-8, quoting, escaping +INSTANTIATE_TEST_CASE_P(TextToBinaryTestDebug, OpSourceExtensionTest, + ::testing::ValuesIn(std::vector{ + "", "foo bar this and that"}), ); + +TEST_F(TextToBinaryTest, OpLine) { + EXPECT_THAT(CompiledInstructions("OpLine %srcfile 42 99"), + Eq(MakeInstruction(SpvOpLine, {1, 42, 99}))); +} + +TEST_F(TextToBinaryTest, OpNoLine) { + EXPECT_THAT(CompiledInstructions("OpNoLine"), + 
Eq(MakeInstruction(SpvOpNoLine, {}))); +} + +using OpStringTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpStringTest, AnyString) { + // TODO(dneto): utf-8, quoting, escaping + const std::string input = + std::string("%result = OpString \"") + GetParam() + "\""; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpString, {1}, MakeVector(GetParam())))); +} + +// TODO(dneto): utf-8, quoting, escaping +INSTANTIATE_TEST_CASE_P(TextToBinaryTestDebug, OpStringTest, + ::testing::ValuesIn(std::vector{ + "", "foo bar this and that"}), ); + +using OpNameTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpNameTest, AnyString) { + const std::string input = + std::string("OpName %target \"") + GetParam() + "\""; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpName, {1}, MakeVector(GetParam())))); +} + +// UTF-8, quoting, escaping, etc. are covered in the StringLiterals tests in +// BinaryToText.Literal.cpp. +INSTANTIATE_TEST_CASE_P(TextToBinaryTestDebug, OpNameTest, + ::testing::Values("", "foo bar this and that"), ); + +using OpMemberNameTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpMemberNameTest, AnyString) { + // TODO(dneto): utf-8, quoting, escaping + const std::string input = + std::string("OpMemberName %type 42 \"") + GetParam() + "\""; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpMemberName, {1, 42}, MakeVector(GetParam())))); +} + +// TODO(dneto): utf-8, quoting, escaping +INSTANTIATE_TEST_CASE_P(TextToBinaryTestDebug, OpMemberNameTest, + ::testing::ValuesIn(std::vector{ + "", "foo bar this and that"}), ); + +// TODO(dneto): Parse failures? 
+ +using OpModuleProcessedTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpModuleProcessedTest, AnyString) { + const std::string input = + std::string("OpModuleProcessed \"") + GetParam() + "\""; + EXPECT_THAT( + CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_1), + Eq(MakeInstruction(SpvOpModuleProcessed, MakeVector(GetParam())))); +} + +INSTANTIATE_TEST_CASE_P(TextToBinaryTestDebug, OpModuleProcessedTest, + ::testing::Values("", "foo bar this and that"), ); + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.device_side_enqueue_test.cpp b/test/text_to_binary.device_side_enqueue_test.cpp new file mode 100644 index 000000000..25c100b8e --- /dev/null +++ b/test/text_to_binary.device_side_enqueue_test.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Device-Side Enqueue Instructions" +// section of the SPIR-V spec. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::MakeInstruction; +using ::testing::Eq; + +// Test OpEnqueueKernel + +struct KernelEnqueueCase { + std::string local_size_source; + std::vector local_size_operands; +}; + +using OpEnqueueKernelGood = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(OpEnqueueKernelGood, Sample) { + const std::string input = + "%result = OpEnqueueKernel %type %queue %flags %NDRange %num_events" + " %wait_events %ret_event %invoke %param %param_size %param_align " + + GetParam().local_size_source; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpEnqueueKernel, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + GetParam().local_size_operands))); +} + +INSTANTIATE_TEST_CASE_P( + TextToBinaryTest, OpEnqueueKernelGood, + ::testing::ValuesIn(std::vector{ + // Provide IDs for pointer-to-local arguments for the + // invoked function. + // Test up to 10 such arguments. + // I (dneto) can't find a limit on the number of kernel + // arguments in OpenCL C 2.0 Rev 29, e.g. in section 6.9 + // Restrictions. + {"", {}}, + {"%l0", {13}}, + {"%l0 %l1", {13, 14}}, + {"%l0 %l1 %l2", {13, 14, 15}}, + {"%l0 %l1 %l2 %l3", {13, 14, 15, 16}}, + {"%l0 %l1 %l2 %l3 %l4", {13, 14, 15, 16, 17}}, + {"%l0 %l1 %l2 %l3 %l4 %l5", {13, 14, 15, 16, 17, 18}}, + {"%l0 %l1 %l2 %l3 %l4 %l5 %l6", {13, 14, 15, 16, 17, 18, 19}}, + {"%l0 %l1 %l2 %l3 %l4 %l5 %l6 %l7", {13, 14, 15, 16, 17, 18, 19, 20}}, + {"%l0 %l1 %l2 %l3 %l4 %l5 %l6 %l7 %l8", + {13, 14, 15, 16, 17, 18, 19, 20, 21}}, + {"%l0 %l1 %l2 %l3 %l4 %l5 %l6 %l7 %l8 %l9", + {13, 14, 15, 16, 17, 18, 19, 20, 21, 22}}, + }), ); + +// Test some bad parses of OpEnqueueKernel. For other cases, we're relying +// on the uniformity of the parsing algorithm. The following two tests, ensure +// that every required ID operand is specified, and is actually an ID operand. 
+using OpKernelEnqueueBad = spvtest::TextToBinaryTest; + +TEST_F(OpKernelEnqueueBad, MissingLastOperand) { + EXPECT_THAT( + CompileFailure( + "%result = OpEnqueueKernel %type %queue %flags %NDRange %num_events" + " %wait_events %ret_event %invoke %param %param_size"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(OpKernelEnqueueBad, InvalidLastOperand) { + EXPECT_THAT( + CompileFailure( + "%result = OpEnqueueKernel %type %queue %flags %NDRange %num_events" + " %wait_events %ret_event %invoke %param %param_size 42"), + Eq("Expected id to start with %.")); +} + +// TODO(dneto): OpEnqueueMarker +// TODO(dneto): OpGetKernelNDRangeSubGroupCount +// TODO(dneto): OpGetKernelNDRangeMaxSubGroupSize +// TODO(dneto): OpGetKernelWorkGroupSize +// TODO(dneto): OpGetKernelPreferredWorkGroupSizeMultiple +// TODO(dneto): OpRetainEvent +// TODO(dneto): OpReleaseEvent +// TODO(dneto): OpCreateUserEvent +// TODO(dneto): OpSetUserEventStatus +// TODO(dneto): OpCaptureEventProfilingInfo +// TODO(dneto): OpGetDefaultQueue +// TODO(dneto): OpBuildNDRange +// TODO(dneto): OpBuildNDRange + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.extension_test.cpp b/test/text_to_binary.extension_test.cpp new file mode 100644 index 000000000..5c0bf9885 --- /dev/null +++ b/test/text_to_binary.extension_test.cpp @@ -0,0 +1,857 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Assembler tests for instructions in the "Extension Instruction" section +// of the SPIR-V spec. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/latest_version_glsl_std_450_header.h" +#include "source/latest_version_opencl_std_header.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::Concatenate; +using spvtest::MakeInstruction; +using spvtest::MakeVector; +using spvtest::TextToBinaryTest; +using ::testing::Combine; +using ::testing::Eq; +using ::testing::Values; +using ::testing::ValuesIn; + +// Returns a generator of common Vulkan environment values to be tested. +std::vector CommonVulkanEnvs() { + return {SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_2, + SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1}; +} + +TEST_F(TextToBinaryTest, InvalidExtInstImportName) { + EXPECT_THAT(CompileFailure("%1 = OpExtInstImport \"Haskell.std\""), + Eq("Invalid extended instruction import 'Haskell.std'")); +} + +TEST_F(TextToBinaryTest, InvalidImportId) { + EXPECT_THAT(CompileFailure("%1 = OpTypeVoid\n" + "%2 = OpExtInst %1 %1"), + Eq("Invalid extended instruction import Id 2")); +} + +TEST_F(TextToBinaryTest, InvalidImportInstruction) { + const std::string input = R"(%1 = OpTypeVoid + %2 = OpExtInstImport "OpenCL.std" + %3 = OpExtInst %1 %2 not_in_the_opencl)"; + EXPECT_THAT(CompileFailure(input), + Eq("Invalid extended instruction name 'not_in_the_opencl'.")); +} + +TEST_F(TextToBinaryTest, MultiImport) { + const std::string input = R"(%2 = OpExtInstImport "OpenCL.std" + %2 = OpExtInstImport "OpenCL.std")"; + EXPECT_THAT(CompileFailure(input), + Eq("Import Id is being defined a second time")); +} + +TEST_F(TextToBinaryTest, TooManyArguments) { + const std::string input = R"(%opencl = OpExtInstImport "OpenCL.std" + %2 = OpExtInst %float %opencl cos %x %oops")"; + EXPECT_THAT(CompileFailure(input), Eq("Expected '=', found end of stream.")); +} 
+ +TEST_F(TextToBinaryTest, ExtInstFromTwoDifferentImports) { + const std::string input = R"(%1 = OpExtInstImport "OpenCL.std" +%2 = OpExtInstImport "GLSL.std.450" +%4 = OpExtInst %3 %1 native_sqrt %5 +%7 = OpExtInst %6 %2 MatrixInverse %8 +)"; + + // Make sure it assembles correctly. + EXPECT_THAT( + CompiledInstructions(input), + Eq(Concatenate({ + MakeInstruction(SpvOpExtInstImport, {1}, MakeVector("OpenCL.std")), + MakeInstruction(SpvOpExtInstImport, {2}, MakeVector("GLSL.std.450")), + MakeInstruction( + SpvOpExtInst, + {3, 4, 1, uint32_t(OpenCLLIB::Entrypoints::Native_sqrt), 5}), + MakeInstruction(SpvOpExtInst, + {6, 7, 2, uint32_t(GLSLstd450MatrixInverse), 8}), + }))); + + // Make sure it disassembles correctly. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input), Eq(input)); +} + +// A test case for assembling into words in an instruction. +struct AssemblyCase { + std::string input; + std::vector expected; +}; + +using ExtensionAssemblyTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(ExtensionAssemblyTest, Samples) { + const spv_target_env& env = std::get<0>(GetParam()); + const AssemblyCase& ac = std::get<1>(GetParam()); + + // Check that it assembles correctly. + EXPECT_THAT(CompiledInstructions(ac.input, env), Eq(ac.expected)); +} + +using ExtensionRoundTripTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(ExtensionRoundTripTest, Samples) { + const spv_target_env& env = std::get<0>(GetParam()); + const AssemblyCase& ac = std::get<1>(GetParam()); + + // Check that it assembles correctly. + EXPECT_THAT(CompiledInstructions(ac.input, env), Eq(ac.expected)); + + // Check round trip through the disassembler. 
+ EXPECT_THAT(EncodeAndDecodeSuccessfully(ac.input, + SPV_BINARY_TO_TEXT_OPTION_NONE, env), + Eq(ac.input)) + << "target env: " << spvTargetEnvDescription(env) << "\n"; +} + +// SPV_KHR_shader_ballot + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_shader_ballot, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0), + ValuesIn(std::vector{ + {"OpCapability SubgroupBallotKHR\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilitySubgroupBallotKHR})}, + {"%2 = OpSubgroupBallotKHR %1 %3\n", + MakeInstruction(SpvOpSubgroupBallotKHR, {1, 2, 3})}, + {"%2 = OpSubgroupFirstInvocationKHR %1 %3\n", + MakeInstruction(SpvOpSubgroupFirstInvocationKHR, {1, 2, 3})}, + {"OpDecorate %1 BuiltIn SubgroupEqMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupEqMaskKHR})}, + {"OpDecorate %1 BuiltIn SubgroupGeMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupGeMaskKHR})}, + {"OpDecorate %1 BuiltIn SubgroupGtMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupGtMaskKHR})}, + {"OpDecorate %1 BuiltIn SubgroupLeMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupLeMaskKHR})}, + {"OpDecorate %1 BuiltIn SubgroupLtMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupLtMaskKHR})}, + })), ); + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_shader_ballot_vulkan_1_1, ExtensionRoundTripTest, + // In SPIR-V 1.3 and Vulkan 1.1 we can drop the KHR suffix on the + // builtin enums. 
+ Combine(Values(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_1), + ValuesIn(std::vector{ + {"OpCapability SubgroupBallotKHR\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilitySubgroupBallotKHR})}, + {"%2 = OpSubgroupBallotKHR %1 %3\n", + MakeInstruction(SpvOpSubgroupBallotKHR, {1, 2, 3})}, + {"%2 = OpSubgroupFirstInvocationKHR %1 %3\n", + MakeInstruction(SpvOpSubgroupFirstInvocationKHR, {1, 2, 3})}, + {"OpDecorate %1 BuiltIn SubgroupEqMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupEqMask})}, + {"OpDecorate %1 BuiltIn SubgroupGeMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupGeMask})}, + {"OpDecorate %1 BuiltIn SubgroupGtMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupGtMask})}, + {"OpDecorate %1 BuiltIn SubgroupLeMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupLeMask})}, + {"OpDecorate %1 BuiltIn SubgroupLtMask\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupLtMask})}, + })), ); + +// The old builtin names (with KHR suffix) still work in the assmebler, and +// map to the enums without the KHR. +INSTANTIATE_TEST_CASE_P( + SPV_KHR_shader_ballot_vulkan_1_1_alias_check, ExtensionAssemblyTest, + // In SPIR-V 1.3 and Vulkan 1.1 we can drop the KHR suffix on the + // builtin enums. 
+ Combine(Values(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_1), + ValuesIn(std::vector{ + {"OpDecorate %1 BuiltIn SubgroupEqMaskKHR\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupEqMask})}, + {"OpDecorate %1 BuiltIn SubgroupGeMaskKHR\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupGeMask})}, + {"OpDecorate %1 BuiltIn SubgroupGtMaskKHR\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupGtMask})}, + {"OpDecorate %1 BuiltIn SubgroupLeMaskKHR\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupLeMask})}, + {"OpDecorate %1 BuiltIn SubgroupLtMaskKHR\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInSubgroupLtMask})}, + })), ); + +// SPV_KHR_shader_draw_parameters + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_shader_draw_parameters, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine( + ValuesIn(CommonVulkanEnvs()), + ValuesIn(std::vector{ + {"OpCapability DrawParameters\n", + MakeInstruction(SpvOpCapability, {SpvCapabilityDrawParameters})}, + {"OpDecorate %1 BuiltIn BaseVertex\n", + MakeInstruction(SpvOpDecorate, + {1, SpvDecorationBuiltIn, SpvBuiltInBaseVertex})}, + {"OpDecorate %1 BuiltIn BaseInstance\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInBaseInstance})}, + {"OpDecorate %1 BuiltIn DrawIndex\n", + MakeInstruction(SpvOpDecorate, + {1, SpvDecorationBuiltIn, SpvBuiltInDrawIndex})}, + })), ); + +// SPV_KHR_subgroup_vote + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_subgroup_vote, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. 
+ Combine(ValuesIn(CommonVulkanEnvs()), + ValuesIn(std::vector{ + {"OpCapability SubgroupVoteKHR\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilitySubgroupVoteKHR})}, + {"%2 = OpSubgroupAnyKHR %1 %3\n", + MakeInstruction(SpvOpSubgroupAnyKHR, {1, 2, 3})}, + {"%2 = OpSubgroupAllKHR %1 %3\n", + MakeInstruction(SpvOpSubgroupAllKHR, {1, 2, 3})}, + {"%2 = OpSubgroupAllEqualKHR %1 %3\n", + MakeInstruction(SpvOpSubgroupAllEqualKHR, {1, 2, 3})}, + })), ); + +// SPV_KHR_16bit_storage + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_16bit_storage, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine(ValuesIn(CommonVulkanEnvs()), + ValuesIn(std::vector{ + {"OpCapability StorageBuffer16BitAccess\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStorageUniformBufferBlock16})}, + {"OpCapability StorageBuffer16BitAccess\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStorageBuffer16BitAccess})}, + {"OpCapability StorageUniform16\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityUniformAndStorageBuffer16BitAccess})}, + {"OpCapability StorageUniform16\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStorageUniform16})}, + {"OpCapability StoragePushConstant16\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStoragePushConstant16})}, + {"OpCapability StorageInputOutput16\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStorageInputOutput16})}, + })), ); + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_16bit_storage_alias_check, ExtensionAssemblyTest, + Combine(ValuesIn(CommonVulkanEnvs()), + ValuesIn(std::vector{ + // The old name maps to the new enum. + {"OpCapability StorageUniformBufferBlock16\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStorageBuffer16BitAccess})}, + // The new name maps to the old enum. 
+ {"OpCapability UniformAndStorageBuffer16BitAccess\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStorageUniform16})}, + })), ); + +// SPV_KHR_device_group + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_device_group, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine(ValuesIn(CommonVulkanEnvs()), + ValuesIn(std::vector{ + {"OpCapability DeviceGroup\n", + MakeInstruction(SpvOpCapability, {SpvCapabilityDeviceGroup})}, + {"OpDecorate %1 BuiltIn DeviceIndex\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInDeviceIndex})}, + })), ); + +// SPV_KHR_8bit_storage + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_8bit_storage, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine( + ValuesIn(CommonVulkanEnvs()), + ValuesIn(std::vector{ + {"OpCapability StorageBuffer8BitAccess\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStorageBuffer8BitAccess})}, + {"OpCapability UniformAndStorageBuffer8BitAccess\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityUniformAndStorageBuffer8BitAccess})}, + {"OpCapability StoragePushConstant8\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityStoragePushConstant8})}, + })), ); + +// SPV_KHR_multiview + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_multiview, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. 
+ Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0), + ValuesIn(std::vector{ + {"OpCapability MultiView\n", + MakeInstruction(SpvOpCapability, {SpvCapabilityMultiView})}, + {"OpDecorate %1 BuiltIn ViewIndex\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInViewIndex})}, + })), ); + +// SPV_AMD_shader_explicit_vertex_parameter + +#define PREAMBLE \ + "%1 = OpExtInstImport \"SPV_AMD_shader_explicit_vertex_parameter\"\n" +INSTANTIATE_TEST_CASE_P( + SPV_AMD_shader_explicit_vertex_parameter, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0), + ValuesIn(std::vector{ + {PREAMBLE "%3 = OpExtInst %2 %1 InterpolateAtVertexAMD %4 %5\n", + Concatenate( + {MakeInstruction( + SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_explicit_vertex_parameter")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 1, 4, 5})})}, + })), ); +#undef PREAMBLE + +// SPV_AMD_shader_trinary_minmax + +#define PREAMBLE "%1 = OpExtInstImport \"SPV_AMD_shader_trinary_minmax\"\n" +INSTANTIATE_TEST_CASE_P( + SPV_AMD_shader_trinary_minmax, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. 
+ Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0), + ValuesIn(std::vector{ + {PREAMBLE "%3 = OpExtInst %2 %1 FMin3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 1, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 UMin3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 2, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 SMin3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 3, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 FMax3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 4, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 UMax3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 5, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 SMax3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 6, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 FMid3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 7, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 UMid3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 8, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 SMid3AMD %4 %5 %6\n", + Concatenate( + {MakeInstruction(SpvOpExtInstImport, 
{1}, + MakeVector("SPV_AMD_shader_trinary_minmax")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 9, 4, 5, 6})})}, + })), ); +#undef PREAMBLE + +// SPV_AMD_gcn_shader + +#define PREAMBLE "%1 = OpExtInstImport \"SPV_AMD_gcn_shader\"\n" +INSTANTIATE_TEST_CASE_P( + SPV_AMD_gcn_shader, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0), + ValuesIn(std::vector{ + {PREAMBLE "%3 = OpExtInst %2 %1 CubeFaceIndexAMD %4\n", + Concatenate({MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_gcn_shader")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 1, 4})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 CubeFaceCoordAMD %4\n", + Concatenate({MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_gcn_shader")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 2, 4})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 TimeAMD\n", + Concatenate({MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_gcn_shader")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 3})})}, + })), ); +#undef PREAMBLE + +// SPV_AMD_shader_ballot + +#define PREAMBLE "%1 = OpExtInstImport \"SPV_AMD_shader_ballot\"\n" +INSTANTIATE_TEST_CASE_P( + SPV_AMD_shader_ballot, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. 
+ Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0), + ValuesIn(std::vector{ + {PREAMBLE "%3 = OpExtInst %2 %1 SwizzleInvocationsAMD %4 %5\n", + Concatenate({MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_ballot")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 1, 4, 5})})}, + {PREAMBLE + "%3 = OpExtInst %2 %1 SwizzleInvocationsMaskedAMD %4 %5\n", + Concatenate({MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_ballot")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 2, 4, 5})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 WriteInvocationAMD %4 %5 %6\n", + Concatenate({MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_ballot")), + MakeInstruction(SpvOpExtInst, + {2, 3, 1, 3, 4, 5, 6})})}, + {PREAMBLE "%3 = OpExtInst %2 %1 MbcntAMD %4\n", + Concatenate({MakeInstruction(SpvOpExtInstImport, {1}, + MakeVector("SPV_AMD_shader_ballot")), + MakeInstruction(SpvOpExtInst, {2, 3, 1, 4, 4})})}, + })), ); +#undef PREAMBLE + +// SPV_KHR_variable_pointers + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_variable_pointers, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_VULKAN_1_0), + ValuesIn(std::vector{ + {"OpCapability VariablePointers\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityVariablePointers})}, + {"OpCapability VariablePointersStorageBuffer\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityVariablePointersStorageBuffer})}, + })), ); + +// SPV_KHR_vulkan_memory_model + +INSTANTIATE_TEST_CASE_P( + SPV_KHR_vulkan_memory_model, ExtensionRoundTripTest, + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + // + // Note: SPV_KHR_vulkan_memory_model adds scope enum value QueueFamilyKHR. 
+ // Scope enums are used in ID definitions elsewhere, that don't know they + // are using particular enums. So the assembler doesn't support assembling + // those enums names into the corresponding values. So there is no asm/dis + // tests for those enums. + Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1), + ValuesIn(std::vector{ + {"OpCapability VulkanMemoryModelKHR\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityVulkanMemoryModelKHR})}, + {"OpCapability VulkanMemoryModelDeviceScopeKHR\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityVulkanMemoryModelDeviceScopeKHR})}, + {"OpMemoryModel Logical VulkanKHR\n", + MakeInstruction(SpvOpMemoryModel, {SpvAddressingModelLogical, + SpvMemoryModelVulkanKHR})}, + {"OpStore %1 %2 MakePointerAvailableKHR %3\n", + MakeInstruction(SpvOpStore, + {1, 2, SpvMemoryAccessMakePointerAvailableKHRMask, + 3})}, + {"OpStore %1 %2 Volatile|MakePointerAvailableKHR %3\n", + MakeInstruction(SpvOpStore, + {1, 2, + int(SpvMemoryAccessMakePointerAvailableKHRMask) | + int(SpvMemoryAccessVolatileMask), + 3})}, + {"OpStore %1 %2 Aligned|MakePointerAvailableKHR 4 %3\n", + MakeInstruction(SpvOpStore, + {1, 2, + int(SpvMemoryAccessMakePointerAvailableKHRMask) | + int(SpvMemoryAccessAlignedMask), + 4, 3})}, + {"OpStore %1 %2 MakePointerAvailableKHR|NonPrivatePointerKHR %3\n", + MakeInstruction(SpvOpStore, + {1, 2, + int(SpvMemoryAccessMakePointerAvailableKHRMask) | + int(SpvMemoryAccessNonPrivatePointerKHRMask), + 3})}, + {"%2 = OpLoad %1 %3 MakePointerVisibleKHR %4\n", + MakeInstruction(SpvOpLoad, + {1, 2, 3, SpvMemoryAccessMakePointerVisibleKHRMask, + 4})}, + {"%2 = OpLoad %1 %3 Volatile|MakePointerVisibleKHR %4\n", + MakeInstruction(SpvOpLoad, + {1, 2, 3, + int(SpvMemoryAccessMakePointerVisibleKHRMask) | + int(SpvMemoryAccessVolatileMask), + 4})}, + {"%2 = OpLoad %1 %3 Aligned|MakePointerVisibleKHR 8 %4\n", + MakeInstruction(SpvOpLoad, + {1, 2, 3, + 
int(SpvMemoryAccessMakePointerVisibleKHRMask) | + int(SpvMemoryAccessAlignedMask), + 8, 4})}, + {"%2 = OpLoad %1 %3 MakePointerVisibleKHR|NonPrivatePointerKHR " + "%4\n", + MakeInstruction(SpvOpLoad, + {1, 2, 3, + int(SpvMemoryAccessMakePointerVisibleKHRMask) | + int(SpvMemoryAccessNonPrivatePointerKHRMask), + 4})}, + {"OpCopyMemory %1 %2 " + "MakePointerAvailableKHR|" + "MakePointerVisibleKHR|" + "NonPrivatePointerKHR " + "%3 %4\n", + MakeInstruction(SpvOpCopyMemory, + {1, 2, + (int(SpvMemoryAccessMakePointerVisibleKHRMask) | + int(SpvMemoryAccessMakePointerAvailableKHRMask) | + int(SpvMemoryAccessNonPrivatePointerKHRMask)), + 3, 4})}, + {"OpCopyMemorySized %1 %2 %3 " + "MakePointerAvailableKHR|" + "MakePointerVisibleKHR|" + "NonPrivatePointerKHR " + "%4 %5\n", + MakeInstruction(SpvOpCopyMemorySized, + {1, 2, 3, + (int(SpvMemoryAccessMakePointerVisibleKHRMask) | + int(SpvMemoryAccessMakePointerAvailableKHRMask) | + int(SpvMemoryAccessNonPrivatePointerKHRMask)), + 4, 5})}, + // Image operands + {"OpImageWrite %1 %2 %3 MakeTexelAvailableKHR " + "%4\n", + MakeInstruction( + SpvOpImageWrite, + {1, 2, 3, int(SpvImageOperandsMakeTexelAvailableKHRMask), 4})}, + {"OpImageWrite %1 %2 %3 MakeTexelAvailableKHR|NonPrivateTexelKHR " + "%4\n", + MakeInstruction(SpvOpImageWrite, + {1, 2, 3, + int(SpvImageOperandsMakeTexelAvailableKHRMask) | + int(SpvImageOperandsNonPrivateTexelKHRMask), + 4})}, + {"OpImageWrite %1 %2 %3 " + "MakeTexelAvailableKHR|NonPrivateTexelKHR|VolatileTexelKHR " + "%4\n", + MakeInstruction(SpvOpImageWrite, + {1, 2, 3, + int(SpvImageOperandsMakeTexelAvailableKHRMask) | + int(SpvImageOperandsNonPrivateTexelKHRMask) | + int(SpvImageOperandsVolatileTexelKHRMask), + 4})}, + {"%2 = OpImageRead %1 %3 %4 MakeTexelVisibleKHR " + "%5\n", + MakeInstruction(SpvOpImageRead, + {1, 2, 3, 4, + int(SpvImageOperandsMakeTexelVisibleKHRMask), + 5})}, + {"%2 = OpImageRead %1 %3 %4 " + "MakeTexelVisibleKHR|NonPrivateTexelKHR " + "%5\n", + MakeInstruction(SpvOpImageRead, + {1, 2, 
3, 4, + int(SpvImageOperandsMakeTexelVisibleKHRMask) | + int(SpvImageOperandsNonPrivateTexelKHRMask), + 5})}, + {"%2 = OpImageRead %1 %3 %4 " + "MakeTexelVisibleKHR|NonPrivateTexelKHR|VolatileTexelKHR " + "%5\n", + MakeInstruction(SpvOpImageRead, + {1, 2, 3, 4, + int(SpvImageOperandsMakeTexelVisibleKHRMask) | + int(SpvImageOperandsNonPrivateTexelKHRMask) | + int(SpvImageOperandsVolatileTexelKHRMask), + 5})}, + + // Memory semantics ID values are numbers put into a SPIR-V + // constant integer referenced by Id. There is no token for + // them, and so no assembler or disassembler support required. + // Similar for Scope ID. + })), ); + +// SPV_GOOGLE_decorate_string + +INSTANTIATE_TEST_CASE_P( + SPV_GOOGLE_decorate_string, ExtensionRoundTripTest, + Combine( + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_VULKAN_1_0), + ValuesIn(std::vector{ + {"OpDecorateStringGOOGLE %1 HlslSemanticGOOGLE \"ABC\"\n", + MakeInstruction(SpvOpDecorateStringGOOGLE, + {1, SpvDecorationHlslSemanticGOOGLE}, + MakeVector("ABC"))}, + {"OpMemberDecorateStringGOOGLE %1 3 HlslSemanticGOOGLE \"DEF\"\n", + MakeInstruction(SpvOpMemberDecorateStringGOOGLE, + {1, 3, SpvDecorationHlslSemanticGOOGLE}, + MakeVector("DEF"))}, + })), ); + +// SPV_GOOGLE_hlsl_functionality1 + +INSTANTIATE_TEST_CASE_P( + SPV_GOOGLE_hlsl_functionality1, ExtensionRoundTripTest, + Combine( + // We'll get coverage over operand tables by trying the universal + // environments, and at least one specific environment. + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_VULKAN_1_0), + // HlslSemanticGOOGLE is tested in SPV_GOOGLE_decorate_string, since + // they are coupled together. 
+ ValuesIn(std::vector{ + {"OpDecorateId %1 HlslCounterBufferGOOGLE %2\n", + MakeInstruction(SpvOpDecorateId, + {1, SpvDecorationHlslCounterBufferGOOGLE, 2})}, + })), ); + +// SPV_NV_viewport_array2 + +INSTANTIATE_TEST_CASE_P( + SPV_NV_viewport_array2, ExtensionRoundTripTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3, + SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1), + ValuesIn(std::vector{ + {"OpExtension \"SPV_NV_viewport_array2\"\n", + MakeInstruction(SpvOpExtension, + MakeVector("SPV_NV_viewport_array2"))}, + // The EXT and NV extensions have the same token number for this + // capability. + {"OpCapability ShaderViewportIndexLayerEXT\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityShaderViewportIndexLayerNV})}, + // Check the new capability's token number + {"OpCapability ShaderViewportIndexLayerEXT\n", + MakeInstruction(SpvOpCapability, {5254})}, + // Decorations + {"OpDecorate %1 ViewportRelativeNV\n", + MakeInstruction(SpvOpDecorate, + {1, SpvDecorationViewportRelativeNV})}, + {"OpDecorate %1 BuiltIn ViewportMaskNV\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationBuiltIn, + SpvBuiltInViewportMaskNV})}, + })), ); + +// SPV_NV_shader_subgroup_partitioned + +INSTANTIATE_TEST_CASE_P( + SPV_NV_shader_subgroup_partitioned, ExtensionRoundTripTest, + Combine( + Values(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_1), + ValuesIn(std::vector{ + {"OpExtension \"SPV_NV_shader_subgroup_partitioned\"\n", + MakeInstruction(SpvOpExtension, + MakeVector("SPV_NV_shader_subgroup_partitioned"))}, + {"OpCapability GroupNonUniformPartitionedNV\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityGroupNonUniformPartitionedNV})}, + // Check the new capability's token number + {"OpCapability GroupNonUniformPartitionedNV\n", + MakeInstruction(SpvOpCapability, {5297})}, + {"%2 = OpGroupNonUniformPartitionNV %1 %3\n", + MakeInstruction(SpvOpGroupNonUniformPartitionNV, {1, 2, 3})}, + // Check the new instruction's 
token number + {"%2 = OpGroupNonUniformPartitionNV %1 %3\n", + MakeInstruction(static_cast(5296), {1, 2, 3})}, + // Check the new group operations + {"%2 = OpGroupIAdd %1 %3 PartitionedReduceNV %4\n", + MakeInstruction(SpvOpGroupIAdd, + {1, 2, 3, SpvGroupOperationPartitionedReduceNV, + 4})}, + {"%2 = OpGroupIAdd %1 %3 PartitionedReduceNV %4\n", + MakeInstruction(SpvOpGroupIAdd, {1, 2, 3, 6, 4})}, + {"%2 = OpGroupIAdd %1 %3 PartitionedInclusiveScanNV %4\n", + MakeInstruction(SpvOpGroupIAdd, + {1, 2, 3, + SpvGroupOperationPartitionedInclusiveScanNV, 4})}, + {"%2 = OpGroupIAdd %1 %3 PartitionedInclusiveScanNV %4\n", + MakeInstruction(SpvOpGroupIAdd, {1, 2, 3, 7, 4})}, + {"%2 = OpGroupIAdd %1 %3 PartitionedExclusiveScanNV %4\n", + MakeInstruction(SpvOpGroupIAdd, + {1, 2, 3, + SpvGroupOperationPartitionedExclusiveScanNV, 4})}, + {"%2 = OpGroupIAdd %1 %3 PartitionedExclusiveScanNV %4\n", + MakeInstruction(SpvOpGroupIAdd, {1, 2, 3, 8, 4})}, + })), ); + +// SPV_EXT_descriptor_indexing + +INSTANTIATE_TEST_CASE_P( + SPV_EXT_descriptor_indexing, ExtensionRoundTripTest, + Combine( + Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_0, + SPV_ENV_VULKAN_1_1), + ValuesIn(std::vector{ + {"OpExtension \"SPV_EXT_descriptor_indexing\"\n", + MakeInstruction(SpvOpExtension, + MakeVector("SPV_EXT_descriptor_indexing"))}, + // Check capabilities, by name + {"OpCapability ShaderNonUniformEXT\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityShaderNonUniformEXT})}, + {"OpCapability RuntimeDescriptorArrayEXT\n", + MakeInstruction(SpvOpCapability, + {SpvCapabilityRuntimeDescriptorArrayEXT})}, + {"OpCapability InputAttachmentArrayDynamicIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityInputAttachmentArrayDynamicIndexingEXT})}, + {"OpCapability UniformTexelBufferArrayDynamicIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityUniformTexelBufferArrayDynamicIndexingEXT})}, + 
{"OpCapability StorageTexelBufferArrayDynamicIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityStorageTexelBufferArrayDynamicIndexingEXT})}, + {"OpCapability UniformBufferArrayNonUniformIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityUniformBufferArrayNonUniformIndexingEXT})}, + {"OpCapability SampledImageArrayNonUniformIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilitySampledImageArrayNonUniformIndexingEXT})}, + {"OpCapability StorageBufferArrayNonUniformIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityStorageBufferArrayNonUniformIndexingEXT})}, + {"OpCapability StorageImageArrayNonUniformIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityStorageImageArrayNonUniformIndexingEXT})}, + {"OpCapability InputAttachmentArrayNonUniformIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityInputAttachmentArrayNonUniformIndexingEXT})}, + {"OpCapability UniformTexelBufferArrayNonUniformIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityUniformTexelBufferArrayNonUniformIndexingEXT})}, + {"OpCapability StorageTexelBufferArrayNonUniformIndexingEXT\n", + MakeInstruction( + SpvOpCapability, + {SpvCapabilityStorageTexelBufferArrayNonUniformIndexingEXT})}, + // Check capabilities, by number + {"OpCapability ShaderNonUniformEXT\n", + MakeInstruction(SpvOpCapability, {5301})}, + {"OpCapability RuntimeDescriptorArrayEXT\n", + MakeInstruction(SpvOpCapability, {5302})}, + {"OpCapability InputAttachmentArrayDynamicIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5303})}, + {"OpCapability UniformTexelBufferArrayDynamicIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5304})}, + {"OpCapability StorageTexelBufferArrayDynamicIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5305})}, + {"OpCapability UniformBufferArrayNonUniformIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5306})}, + {"OpCapability SampledImageArrayNonUniformIndexingEXT\n", + 
MakeInstruction(SpvOpCapability, {5307})}, + {"OpCapability StorageBufferArrayNonUniformIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5308})}, + {"OpCapability StorageImageArrayNonUniformIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5309})}, + {"OpCapability InputAttachmentArrayNonUniformIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5310})}, + {"OpCapability UniformTexelBufferArrayNonUniformIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5311})}, + {"OpCapability StorageTexelBufferArrayNonUniformIndexingEXT\n", + MakeInstruction(SpvOpCapability, {5312})}, + + // Check the decoration token + {"OpDecorate %1 NonUniformEXT\n", + MakeInstruction(SpvOpDecorate, {1, SpvDecorationNonUniformEXT})}, + {"OpDecorate %1 NonUniformEXT\n", + MakeInstruction(SpvOpDecorate, {1, 5300})}, + })), ); + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.function_test.cpp b/test/text_to_binary.function_test.cpp new file mode 100644 index 000000000..748461fb1 --- /dev/null +++ b/test/text_to_binary.function_test.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Function" section of the +// SPIR-V spec. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::EnumCase; +using spvtest::MakeInstruction; +using spvtest::TextToBinaryTest; +using ::testing::Eq; + +// Test OpFunction + +using OpFunctionControlTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(OpFunctionControlTest, AnySingleFunctionControlMask) { + const std::string input = "%result_id = OpFunction %result_type " + + GetParam().name() + " %function_type "; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpFunction, {1, 2, GetParam().value(), 3}))); +} + +// clang-format off +#define CASE(VALUE,NAME) { SpvFunctionControl##VALUE, NAME } +INSTANTIATE_TEST_CASE_P(TextToBinaryFunctionTest, OpFunctionControlTest, + ::testing::ValuesIn(std::vector>{ + CASE(MaskNone, "None"), + CASE(InlineMask, "Inline"), + CASE(DontInlineMask, "DontInline"), + CASE(PureMask, "Pure"), + CASE(ConstMask, "Const"), + }),); +#undef CASE +// clang-format on + +TEST_F(OpFunctionControlTest, CombinedFunctionControlMask) { + // Sample a single combination. This ensures we've integrated + // the instruction parsing logic with spvTextParseMask. 
+ const std::string input = + "%result_id = OpFunction %result_type Inline|Pure|Const %function_type"; + const uint32_t expected_mask = SpvFunctionControlInlineMask | + SpvFunctionControlPureMask | + SpvFunctionControlConstMask; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpFunction, {1, 2, expected_mask, 3}))); +} + +TEST_F(OpFunctionControlTest, WrongFunctionControl) { + EXPECT_THAT(CompileFailure("%r = OpFunction %t Inline|Unroll %ft"), + Eq("Invalid function control operand 'Inline|Unroll'.")); +} + +// TODO(dneto): OpFunctionParameter +// TODO(dneto): OpFunctionEnd +// TODO(dneto): OpFunctionCall + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.group_test.cpp b/test/text_to_binary.group_test.cpp new file mode 100644 index 000000000..2f4b76d2f --- /dev/null +++ b/test/text_to_binary.group_test.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Group Instructions" section of the +// SPIR-V spec. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::EnumCase; +using spvtest::MakeInstruction; +using ::testing::Eq; + +// Test GroupOperation enum + +using GroupOperationTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(GroupOperationTest, AnyGroupOperation) { + const std::string input = + "%result = OpGroupIAdd %type %scope " + GetParam().name() + " %x"; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpGroupIAdd, {1, 2, 3, GetParam().value(), 4}))); +} + +// clang-format off +#define CASE(NAME) { SpvGroupOperation##NAME, #NAME} +INSTANTIATE_TEST_CASE_P(TextToBinaryGroupOperation, GroupOperationTest, + ::testing::ValuesIn(std::vector>{ + CASE(Reduce), + CASE(InclusiveScan), + CASE(ExclusiveScan), + }),); +#undef CASE +// clang-format on + +TEST_F(GroupOperationTest, WrongGroupOperation) { + EXPECT_THAT(CompileFailure("%r = OpGroupUMin %t %e xxyyzz %x"), + Eq("Invalid group operation 'xxyyzz'.")); +} + +// TODO(dneto): OpGroupAsyncCopy +// TODO(dneto): OpGroupWaitEvents +// TODO(dneto): OpGroupAll +// TODO(dneto): OpGroupAny +// TODO(dneto): OpGroupBroadcast +// TODO(dneto): OpGroupIAdd +// TODO(dneto): OpGroupFAdd +// TODO(dneto): OpGroupFMin +// TODO(dneto): OpGroupUMin +// TODO(dneto): OpGroupSMin +// TODO(dneto): OpGroupFMax +// TODO(dneto): OpGroupUMax +// TODO(dneto): OpGroupSMax + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.image_test.cpp b/test/text_to_binary.image_test.cpp new file mode 100644 index 000000000..c1adedf44 --- /dev/null +++ b/test/text_to_binary.image_test.cpp @@ -0,0 +1,276 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Image Instructions" section of +// the SPIR-V spec. + +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::MakeInstruction; +using spvtest::TextToBinaryTest; +using ::testing::Eq; + +// An example case for a mask value with operands. +struct ImageOperandsCase { + std::string image_operands; + // The expected mask, followed by its operands. + std::vector expected_mask_and_operands; +}; + +// Test all kinds of image operands. + +using ImageOperandsTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(ImageOperandsTest, Sample) { + const std::string input = + "%2 = OpImageFetch %1 %3 %4" + GetParam().image_operands + "\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpImageFetch, {1, 2, 3, 4}, + GetParam().expected_mask_and_operands))); +} + +#define MASK(NAME) SpvImageOperands##NAME##Mask +INSTANTIATE_TEST_CASE_P( + TextToBinaryImageOperandsAny, ImageOperandsTest, + ::testing::ValuesIn(std::vector{ + // TODO(dneto): Rev32 adds many more values, and rearranges their + // values. + // Image operands are optional. + {"", {}}, + // Test each kind, alone. 
+ {" Bias %5", {MASK(Bias), 5}}, + {" Lod %5", {MASK(Lod), 5}}, + {" Grad %5 %6", {MASK(Grad), 5, 6}}, + {" ConstOffset %5", {MASK(ConstOffset), 5}}, + {" Offset %5", {MASK(Offset), 5}}, + {" ConstOffsets %5", {MASK(ConstOffsets), 5}}, + {" Sample %5", {MASK(Sample), 5}}, + {" MinLod %5", {MASK(MinLod), 5}}, + }), ); +#undef MASK +#define MASK(NAME) static_cast(SpvImageOperands##NAME##Mask) +INSTANTIATE_TEST_CASE_P( + TextToBinaryImageOperandsCombination, ImageOperandsTest, + ::testing::ValuesIn(std::vector{ + // TODO(dneto): Rev32 adds many more values, and rearranges their + // values. + // Test adjacent pairs, so we can easily debug the values when it fails. + {" Bias|Lod %5 %6", {MASK(Bias) | MASK(Lod), 5, 6}}, + {" Lod|Grad %5 %6 %7", {MASK(Lod) | MASK(Grad), 5, 6, 7}}, + {" Grad|ConstOffset %5 %6 %7", + {MASK(Grad) | MASK(ConstOffset), 5, 6, 7}}, + {" ConstOffset|Offset %5 %6", {MASK(ConstOffset) | MASK(Offset), 5, 6}}, + {" Offset|ConstOffsets %5 %6", + {MASK(Offset) | MASK(ConstOffsets), 5, 6}}, + {" ConstOffsets|Sample %5 %6", + {MASK(ConstOffsets) | MASK(Sample), 5, 6}}, + // Test all masks together. + {" Bias|Lod|Grad|ConstOffset|Offset|ConstOffsets|Sample" + " %5 %6 %7 %8 %9 %10 %11 %12", + {MASK(Bias) | MASK(Lod) | MASK(Grad) | MASK(ConstOffset) | + MASK(Offset) | MASK(ConstOffsets) | MASK(Sample), + 5, 6, 7, 8, 9, 10, 11, 12}}, + // The same, but with mask value names reversed. 
+ {" Sample|ConstOffsets|Offset|ConstOffset|Grad|Lod|Bias" + " %5 %6 %7 %8 %9 %10 %11 %12", + {MASK(Bias) | MASK(Lod) | MASK(Grad) | MASK(ConstOffset) | + MASK(Offset) | MASK(ConstOffsets) | MASK(Sample), + 5, 6, 7, 8, 9, 10, 11, 12}}}), ); +#undef MASK + +TEST_F(ImageOperandsTest, WrongOperand) { + EXPECT_THAT(CompileFailure("%r = OpImageFetch %t %i %c xxyyzz"), + Eq("Invalid image operand 'xxyyzz'.")); +} + +// Test OpImage + +using OpImageTest = TextToBinaryTest; + +TEST_F(OpImageTest, Valid) { + const std::string input = "%2 = OpImage %1 %3\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpImage, {1, 2, 3}))); + + // Test the disassembler. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input), input); +} + +TEST_F(OpImageTest, InvalidTypeOperand) { + EXPECT_THAT(CompileFailure("%2 = OpImage 42"), + Eq("Expected id to start with %.")); +} + +TEST_F(OpImageTest, MissingSampledImageOperand) { + EXPECT_THAT(CompileFailure("%2 = OpImage %1"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(OpImageTest, InvalidSampledImageOperand) { + EXPECT_THAT(CompileFailure("%2 = OpImage %1 1000"), + Eq("Expected id to start with %.")); +} + +TEST_F(OpImageTest, TooManyOperands) { + // We should improve this message, to say what instruction we're trying to + // parse. 
+ EXPECT_THAT(CompileFailure("%2 = OpImage %1 %3 %4"), // an Id + Eq("Expected '=', found end of stream.")); + + EXPECT_THAT(CompileFailure("%2 = OpImage %1 %3 99"), // a number + Eq("Expected or at the beginning of an " + "instruction, found '99'.")); + EXPECT_THAT(CompileFailure("%2 = OpImage %1 %3 \"abc\""), // a string + Eq("Expected or at the beginning of an " + "instruction, found '\"abc\"'.")); +} + +// Test OpImageSparseRead + +using OpImageSparseReadTest = TextToBinaryTest; + +TEST_F(OpImageSparseReadTest, OnlyRequiredOperands) { + const std::string input = "%2 = OpImageSparseRead %1 %3 %4\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpImageSparseRead, {1, 2, 3, 4}))); + // Test the disassembler. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input), input); +} + +// Test all kinds of image operands on OpImageSparseRead + +using ImageSparseReadImageOperandsTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>; + +TEST_P(ImageSparseReadImageOperandsTest, Sample) { + const std::string input = + "%2 = OpImageSparseRead %1 %3 %4" + GetParam().image_operands + "\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpImageSparseRead, {1, 2, 3, 4}, + GetParam().expected_mask_and_operands))); + // Test the disassembler. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input), input); +} + +#define MASK(NAME) SpvImageOperands##NAME##Mask +INSTANTIATE_TEST_CASE_P(ImageSparseReadImageOperandsAny, + ImageSparseReadImageOperandsTest, + ::testing::ValuesIn(std::vector{ + // Image operands are optional. + {"", {}}, + // Test each kind, alone. 
+ {" Bias %5", {MASK(Bias), 5}}, + {" Lod %5", {MASK(Lod), 5}}, + {" Grad %5 %6", {MASK(Grad), 5, 6}}, + {" ConstOffset %5", {MASK(ConstOffset), 5}}, + {" Offset %5", {MASK(Offset), 5}}, + {" ConstOffsets %5", {MASK(ConstOffsets), 5}}, + {" Sample %5", {MASK(Sample), 5}}, + {" MinLod %5", {MASK(MinLod), 5}}, + }), ); +#undef MASK +#define MASK(NAME) static_cast(SpvImageOperands##NAME##Mask) +INSTANTIATE_TEST_CASE_P( + ImageSparseReadImageOperandsCombination, ImageSparseReadImageOperandsTest, + ::testing::ValuesIn(std::vector{ + // values. + // Test adjacent pairs, so we can easily debug the values when it fails. + {" Bias|Lod %5 %6", {MASK(Bias) | MASK(Lod), 5, 6}}, + {" Lod|Grad %5 %6 %7", {MASK(Lod) | MASK(Grad), 5, 6, 7}}, + {" Grad|ConstOffset %5 %6 %7", + {MASK(Grad) | MASK(ConstOffset), 5, 6, 7}}, + {" ConstOffset|Offset %5 %6", {MASK(ConstOffset) | MASK(Offset), 5, 6}}, + {" Offset|ConstOffsets %5 %6", + {MASK(Offset) | MASK(ConstOffsets), 5, 6}}, + {" ConstOffsets|Sample %5 %6", + {MASK(ConstOffsets) | MASK(Sample), 5, 6}}, + // Test all masks together. + {" Bias|Lod|Grad|ConstOffset|Offset|ConstOffsets|Sample" + " %5 %6 %7 %8 %9 %10 %11 %12", + {MASK(Bias) | MASK(Lod) | MASK(Grad) | MASK(ConstOffset) | + MASK(Offset) | MASK(ConstOffsets) | MASK(Sample), + 5, 6, 7, 8, 9, 10, 11, 12}}, + // Don't try the masks reversed, since this is a round trip test, + // and the disassembler will sort them. 
+ }), ); +#undef MASK + +TEST_F(OpImageSparseReadTest, InvalidTypeOperand) { + EXPECT_THAT(CompileFailure("%2 = OpImageSparseRead 42"), + Eq("Expected id to start with %.")); +} + +TEST_F(OpImageSparseReadTest, MissingImageOperand) { + EXPECT_THAT(CompileFailure("%2 = OpImageSparseRead %1"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(OpImageSparseReadTest, InvalidImageOperand) { + EXPECT_THAT(CompileFailure("%2 = OpImageSparseRead %1 1000"), + Eq("Expected id to start with %.")); +} + +TEST_F(OpImageSparseReadTest, MissingCoordinateOperand) { + EXPECT_THAT(CompileFailure("%2 = OpImageSparseRead %1 %2"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(OpImageSparseReadTest, InvalidCoordinateOperand) { + EXPECT_THAT(CompileFailure("%2 = OpImageSparseRead %1 %2 1000"), + Eq("Expected id to start with %.")); +} + +// TODO(dneto): OpSampledImage +// TODO(dneto): OpImageSampleImplicitLod +// TODO(dneto): OpImageSampleExplicitLod +// TODO(dneto): OpImageSampleDrefImplicitLod +// TODO(dneto): OpImageSampleDrefExplicitLod +// TODO(dneto): OpImageSampleProjImplicitLod +// TODO(dneto): OpImageSampleProjExplicitLod +// TODO(dneto): OpImageSampleProjDrefImplicitLod +// TODO(dneto): OpImageSampleProjDrefExplicitLod +// TODO(dneto): OpImageGather +// TODO(dneto): OpImageDrefGather +// TODO(dneto): OpImageRead +// TODO(dneto): OpImageWrite +// TODO(dneto): OpImageQueryFormat +// TODO(dneto): OpImageQueryOrder +// TODO(dneto): OpImageQuerySizeLod +// TODO(dneto): OpImageQuerySize +// TODO(dneto): OpImageQueryLod +// TODO(dneto): OpImageQueryLevels +// TODO(dneto): OpImageQuerySamples +// TODO(dneto): OpImageSparseSampleImplicitLod +// TODO(dneto): OpImageSparseSampleExplicitLod +// TODO(dneto): OpImageSparseSampleDrefImplicitLod +// TODO(dneto): OpImageSparseSampleDrefExplicitLod +// TODO(dneto): OpImageSparseSampleProjImplicitLod +// TODO(dneto): OpImageSparseSampleProjExplicitLod +// TODO(dneto): OpImageSparseSampleProjDrefImplicitLod +// 
TODO(dneto): OpImageSparseSampleProjDrefExplicitLod +// TODO(dneto): OpImageSparseFetch +// TODO(dneto): OpImageSparseDrefGather +// TODO(dneto): OpImageSparseTexelsResident + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.literal_test.cpp b/test/text_to_binary.literal_test.cpp new file mode 100644 index 000000000..bcbb63e0d --- /dev/null +++ b/test/text_to_binary.literal_test.cpp @@ -0,0 +1,125 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for literal numbers and literal strings. 
+ +#include + +#include "test/test_fixture.h" + +namespace spvtools { +namespace { + +using spvtest::TextToBinaryTest; + +TEST_F(TextToBinaryTest, LiteralStringInPlaceOfLiteralNumber) { + EXPECT_EQ( + R"(Invalid unsigned integer literal: "I shouldn't be a string")", + CompileFailure(R"(OpSource GLSL "I shouldn't be a string")")); +} + +TEST_F(TextToBinaryTest, GarbageInPlaceOfLiteralString) { + EXPECT_EQ("Invalid literal string 'nice-source-code'.", + CompileFailure("OpSourceExtension nice-source-code")); +} + +TEST_F(TextToBinaryTest, LiteralNumberInPlaceOfLiteralString) { + EXPECT_EQ("Expected literal string, found literal number '1000'.", + CompileFailure("OpSourceExtension 1000")); +} + +TEST_F(TextToBinaryTest, LiteralFloatInPlaceOfLiteralInteger) { + EXPECT_EQ("Invalid unsigned integer literal: 10.5", + CompileFailure("OpSource GLSL 10.5")); + + EXPECT_EQ("Invalid unsigned integer literal: 0.2", + CompileFailure(R"(OpMemberName %type 0.2 "member0.2")")); + + EXPECT_EQ("Invalid unsigned integer literal: 32.42", + CompileFailure("%int = OpTypeInt 32.42 0")); + + EXPECT_EQ("Invalid unsigned integer literal: 4.5", + CompileFailure("%mat = OpTypeMatrix %vec 4.5")); + + EXPECT_EQ("Invalid unsigned integer literal: 1.5", + CompileFailure("OpExecutionMode %main LocalSize 1.5 1.6 1.7")); + + EXPECT_EQ("Invalid unsigned integer literal: 0.123", + CompileFailure("%i32 = OpTypeInt 32 1\n" + "%c = OpConstant %i32 0.123")); +} + +TEST_F(TextToBinaryTest, LiteralInt64) { + const std::string code = + "%1 = OpTypeInt 64 0\n%2 = OpConstant %1 123456789021\n"; + EXPECT_EQ(code, EncodeAndDecodeSuccessfully(code)); +} + +TEST_F(TextToBinaryTest, LiteralDouble) { + const std::string code = + "%1 = OpTypeFloat 64\n%2 = OpSpecConstant %1 3.14159265358979\n"; + EXPECT_EQ(code, EncodeAndDecodeSuccessfully(code)); +} + +TEST_F(TextToBinaryTest, LiteralStringASCIILong) { + // SPIR-V allows strings up to 65535 characters. 
+ // Test the simple case of UTF-8 code points corresponding + // to ASCII characters. + EXPECT_EQ(65535, SPV_LIMIT_LITERAL_STRING_UTF8_CHARS_MAX); + const std::string code = + "OpSourceExtension \"" + + std::string(SPV_LIMIT_LITERAL_STRING_UTF8_CHARS_MAX, 'o') + "\"\n"; + EXPECT_EQ(code, EncodeAndDecodeSuccessfully(code)); +} + +TEST_F(TextToBinaryTest, LiteralStringUTF8LongEncodings) { + // SPIR-V allows strings up to 65535 characters. + // Test the case of many Unicode characters, each of which has + // a 4-byte UTF-8 encoding. + + // An instruction is at most 65535 words long. The first one + // contains the wordcount and opcode. So the worst case number of + // 4-byte UTF-8 characters is 65533, since we also need to + // store a terminating null character. + + // This string fits exactly into 65534 words. + const std::string good_string = + spvtest::MakeLongUTF8String(65533) + // The following single character has a 3 byte encoding, + // which fits snugly against the terminating null. + + "\xe8\x80\x80"; + + // These strings will overflow any instruction with 0 or 1 other + // arguments, respectively. + const std::string bad_0_arg_string = spvtest::MakeLongUTF8String(65534); + const std::string bad_1_arg_string = spvtest::MakeLongUTF8String(65533); + + const std::string good_code = "OpSourceExtension \"" + good_string + "\"\n"; + EXPECT_EQ(good_code, EncodeAndDecodeSuccessfully(good_code)); + + // Prove that it works on more than one instruction. + const std::string good_code_2 = "OpSourceContinued \"" + good_string + "\"\n"; + EXPECT_EQ(good_code, EncodeAndDecodeSuccessfully(good_code)); + + // Failure cases. 
+ EXPECT_EQ("Instruction too long: more than 65535 words.", + CompileFailure("OpSourceExtension \"" + bad_0_arg_string + "\"\n")); + EXPECT_EQ("Instruction too long: more than 65535 words.", + CompileFailure("OpSourceContinued \"" + bad_0_arg_string + "\"\n")); + EXPECT_EQ("Instruction too long: more than 65535 words.", + CompileFailure("OpName %target \"" + bad_1_arg_string + "\"\n")); +} + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.memory_test.cpp b/test/text_to_binary.memory_test.cpp new file mode 100644 index 000000000..ead08e6fd --- /dev/null +++ b/test/text_to_binary.memory_test.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Memory Instructions" section of +// the SPIR-V spec. 
+ +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::EnumCase; +using spvtest::MakeInstruction; +using spvtest::TextToBinaryTest; +using ::testing::Eq; + +// Test assembly of Memory Access masks + +using MemoryAccessTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(MemoryAccessTest, AnySingleMemoryAccessMask) { + std::stringstream input; + input << "OpStore %ptr %value " << GetParam().name(); + for (auto operand : GetParam().operands()) input << " " << operand; + EXPECT_THAT(CompiledInstructions(input.str()), + Eq(MakeInstruction(SpvOpStore, {1, 2, GetParam().value()}, + GetParam().operands()))); +} + +INSTANTIATE_TEST_CASE_P( + TextToBinaryMemoryAccessTest, MemoryAccessTest, + ::testing::ValuesIn(std::vector>{ + {SpvMemoryAccessMaskNone, "None", {}}, + {SpvMemoryAccessVolatileMask, "Volatile", {}}, + {SpvMemoryAccessAlignedMask, "Aligned", {16}}, + {SpvMemoryAccessNontemporalMask, "Nontemporal", {}}, + }), ); + +TEST_F(TextToBinaryTest, CombinedMemoryAccessMask) { + const std::string input = "OpStore %ptr %value Volatile|Aligned 16"; + const uint32_t expected_mask = + SpvMemoryAccessVolatileMask | SpvMemoryAccessAlignedMask; + EXPECT_THAT(expected_mask, Eq(3u)); + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpStore, {1, 2, expected_mask, 16}))); +} + +// Test Storage Class enum values + +using StorageClassTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(StorageClassTest, AnyStorageClass) { + const std::string input = "%1 = OpVariable %2 " + GetParam().name(); + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpVariable, {1, 2, GetParam().value()}))); +} + +// clang-format off +#define CASE(NAME) { SpvStorageClass##NAME, #NAME, {} } +INSTANTIATE_TEST_CASE_P( + TextToBinaryStorageClassTest, StorageClassTest, + ::testing::ValuesIn(std::vector>{ + 
CASE(UniformConstant), + CASE(Input), + CASE(Uniform), + CASE(Output), + CASE(Workgroup), + CASE(CrossWorkgroup), + CASE(Private), + CASE(Function), + CASE(Generic), + CASE(PushConstant), + CASE(AtomicCounter), + CASE(Image), + }),); +#undef CASE +// clang-format on + +// TODO(dneto): OpVariable with initializers +// TODO(dneto): OpImageTexelPointer +// TODO(dneto): OpLoad +// TODO(dneto): OpStore +// TODO(dneto): OpCopyMemory +// TODO(dneto): OpCopyMemorySized +// TODO(dneto): OpAccessChain +// TODO(dneto): OpInBoundsAccessChain +// TODO(dneto): OpPtrAccessChain +// TODO(dneto): OpArrayLength +// TODO(dneto): OpGenericPtrMemSemantics + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.misc_test.cpp b/test/text_to_binary.misc_test.cpp new file mode 100644 index 000000000..03b1e0914 --- /dev/null +++ b/test/text_to_binary.misc_test.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Miscellaneous" section of the +// SPIR-V spec.
+ +#include "test/unit_spirv.h" + +#include "gmock/gmock.h" +#include "test/test_fixture.h" + +namespace spvtools { +namespace { + +using SpirvVector = spvtest::TextToBinaryTest::SpirvVector; +using spvtest::MakeInstruction; +using ::testing::Eq; +using TextToBinaryMisc = spvtest::TextToBinaryTest; + +TEST_F(TextToBinaryMisc, OpNop) { + EXPECT_THAT(CompiledInstructions("OpNop"), Eq(MakeInstruction(SpvOpNop, {}))); +} + +TEST_F(TextToBinaryMisc, OpUndef) { + const SpirvVector code = CompiledInstructions(R"(%f32 = OpTypeFloat 32 + %u = OpUndef %f32)"); + const uint32_t typeID = 1; + EXPECT_THAT(code[1], Eq(typeID)); + EXPECT_THAT(Subvector(code, 3), Eq(MakeInstruction(SpvOpUndef, {typeID, 2}))); +} + +TEST_F(TextToBinaryMisc, OpWrong) { + EXPECT_THAT(CompileFailure(" OpWrong %1 %2"), + Eq("Invalid Opcode name 'OpWrong'")); +} + +TEST_F(TextToBinaryMisc, OpWrongAfterRight) { + const auto assembly = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpXYZ +)"; + EXPECT_THAT(CompileFailure(assembly), Eq("Invalid Opcode name 'OpXYZ'")); +} + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.mode_setting_test.cpp b/test/text_to_binary.mode_setting_test.cpp new file mode 100644 index 000000000..ed4fa2fb4 --- /dev/null +++ b/test/text_to_binary.mode_setting_test.cpp @@ -0,0 +1,302 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Assembler tests for instructions in the "Mode-Setting" section of the +// SPIR-V spec. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::EnumCase; +using spvtest::MakeInstruction; +using spvtest::MakeVector; +using ::testing::Combine; +using ::testing::Eq; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +// Test OpMemoryModel + +// An example case for OpMemoryModel +struct MemoryModelCase { + uint32_t get_addressing_value() const { + return static_cast(addressing_value); + } + uint32_t get_memory_value() const { + return static_cast(memory_value); + } + SpvAddressingModel addressing_value; + std::string addressing_name; + SpvMemoryModel memory_value; + std::string memory_name; +}; + +using OpMemoryModelTest = + spvtest::TextToBinaryTestBase>; + +TEST_P(OpMemoryModelTest, AnyMemoryModelCase) { + const std::string input = "OpMemoryModel " + GetParam().addressing_name + + " " + GetParam().memory_name; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpMemoryModel, {GetParam().get_addressing_value(), + GetParam().get_memory_value()}))); +} + +#define CASE(ADDRESSING, MEMORY) \ + { \ + SpvAddressingModel##ADDRESSING, #ADDRESSING, SpvMemoryModel##MEMORY, \ + #MEMORY \ + } +// clang-format off +INSTANTIATE_TEST_CASE_P(TextToBinaryMemoryModel, OpMemoryModelTest, + ValuesIn(std::vector{ + // These cases exercise each addressing model, and + // each memory model, but not necessarily in + // combination. 
+ CASE(Logical,Simple), + CASE(Logical,GLSL450), + CASE(Physical32,OpenCL), + CASE(Physical64,OpenCL), + }),); +#undef CASE +// clang-format on + +TEST_F(OpMemoryModelTest, WrongModel) { + EXPECT_THAT(CompileFailure("OpMemoryModel xxyyzz Simple"), + Eq("Invalid addressing model 'xxyyzz'.")); + EXPECT_THAT(CompileFailure("OpMemoryModel Logical xxyyzz"), + Eq("Invalid memory model 'xxyyzz'.")); +} + +// Test OpEntryPoint + +// An example case for OpEntryPoint +struct EntryPointCase { + uint32_t get_execution_value() const { + return static_cast(execution_value); + } + SpvExecutionModel execution_value; + std::string execution_name; + std::string entry_point_name; +}; + +using OpEntryPointTest = + spvtest::TextToBinaryTestBase>; + +TEST_P(OpEntryPointTest, AnyEntryPointCase) { + // TODO(dneto): utf-8, escaping, quoting cases for entry point name. + const std::string input = "OpEntryPoint " + GetParam().execution_name + + " %1 \"" + GetParam().entry_point_name + "\""; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpEntryPoint, {GetParam().get_execution_value(), 1}, + MakeVector(GetParam().entry_point_name)))); +} + +// clang-format off +#define CASE(NAME) SpvExecutionModel##NAME, #NAME +INSTANTIATE_TEST_CASE_P(TextToBinaryEntryPoint, OpEntryPointTest, + ValuesIn(std::vector{ + { CASE(Vertex), "" }, + { CASE(TessellationControl), "my tess" }, + { CASE(TessellationEvaluation), "really fancy" }, + { CASE(Geometry), "Euclid" }, + { CASE(Fragment), "FAT32" }, + { CASE(GLCompute), "cubic" }, + { CASE(Kernel), "Sanders" }, + }),); +#undef CASE +// clang-format on + +TEST_F(OpEntryPointTest, WrongModel) { + EXPECT_THAT(CompileFailure("OpEntryPoint xxyyzz %1 \"fun\""), + Eq("Invalid execution model 'xxyyzz'.")); +} + +// Test OpExecutionMode +using OpExecutionModeTest = spvtest::TextToBinaryTestBase< + TestWithParam>>>; + +TEST_P(OpExecutionModeTest, AnyExecutionMode) { + // This string should assemble, but should not validate. 
+ std::stringstream input; + input << "OpExecutionMode %1 " << std::get<1>(GetParam()).name(); + for (auto operand : std::get<1>(GetParam()).operands()) + input << " " << operand; + EXPECT_THAT(CompiledInstructions(input.str(), std::get<0>(GetParam())), + Eq(MakeInstruction(SpvOpExecutionMode, + {1, std::get<1>(GetParam()).value()}, + std::get<1>(GetParam()).operands()))); +} + +#define CASE(NAME) SpvExecutionMode##NAME, #NAME +INSTANTIATE_TEST_CASE_P( + TextToBinaryExecutionMode, OpExecutionModeTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector>{ + // The operand literal values are arbitrarily chosen, + // but there are the right number of them. + {CASE(Invocations), {101}}, + {CASE(SpacingEqual), {}}, + {CASE(SpacingFractionalEven), {}}, + {CASE(SpacingFractionalOdd), {}}, + {CASE(VertexOrderCw), {}}, + {CASE(VertexOrderCcw), {}}, + {CASE(PixelCenterInteger), {}}, + {CASE(OriginUpperLeft), {}}, + {CASE(OriginLowerLeft), {}}, + {CASE(EarlyFragmentTests), {}}, + {CASE(PointMode), {}}, + {CASE(Xfb), {}}, + {CASE(DepthReplacing), {}}, + {CASE(DepthGreater), {}}, + {CASE(DepthLess), {}}, + {CASE(DepthUnchanged), {}}, + {CASE(LocalSize), {64, 1, 2}}, + {CASE(LocalSizeHint), {8, 2, 4}}, + {CASE(InputPoints), {}}, + {CASE(InputLines), {}}, + {CASE(InputLinesAdjacency), {}}, + {CASE(Triangles), {}}, + {CASE(InputTrianglesAdjacency), {}}, + {CASE(Quads), {}}, + {CASE(Isolines), {}}, + {CASE(OutputVertices), {21}}, + {CASE(OutputPoints), {}}, + {CASE(OutputLineStrip), {}}, + {CASE(OutputTriangleStrip), {}}, + {CASE(VecTypeHint), {96}}, + {CASE(ContractionOff), {}}, + })), ); + +INSTANTIATE_TEST_CASE_P( + TextToBinaryExecutionModeV11, OpExecutionModeTest, + Combine(Values(SPV_ENV_UNIVERSAL_1_1), + ValuesIn(std::vector>{ + {CASE(Initializer)}, + {CASE(Finalizer)}, + {CASE(SubgroupSize), {12}}, + {CASE(SubgroupsPerWorkgroup), {64}}})), ); +#undef CASE + +TEST_F(OpExecutionModeTest, WrongMode) { + 
EXPECT_THAT(CompileFailure("OpExecutionMode %1 xxyyzz"), + Eq("Invalid execution mode 'xxyyzz'.")); +} + +TEST_F(OpExecutionModeTest, TooManyModes) { + EXPECT_THAT(CompileFailure("OpExecutionMode %1 Xfb PointMode"), + Eq("Expected or at the beginning of an " + "instruction, found 'PointMode'.")); +} + +// Test OpCapability + +using OpCapabilityTest = + spvtest::TextToBinaryTestBase>>; + +TEST_P(OpCapabilityTest, AnyCapability) { + const std::string input = "OpCapability " + GetParam().name(); + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpCapability, {GetParam().value()}))); +} + +// clang-format off +#define CASE(NAME) { SpvCapability##NAME, #NAME } +INSTANTIATE_TEST_CASE_P(TextToBinaryCapability, OpCapabilityTest, + ValuesIn(std::vector>{ + CASE(Matrix), + CASE(Shader), + CASE(Geometry), + CASE(Tessellation), + CASE(Addresses), + CASE(Linkage), + CASE(Kernel), + CASE(Vector16), + CASE(Float16Buffer), + CASE(Float16), + CASE(Float64), + CASE(Int64), + CASE(Int64Atomics), + CASE(ImageBasic), + CASE(ImageReadWrite), + CASE(ImageMipmap), + // Value 16 intentionally missing + CASE(Pipes), + CASE(Groups), + CASE(DeviceEnqueue), + CASE(LiteralSampler), + CASE(AtomicStorage), + CASE(Int16), + CASE(TessellationPointSize), + CASE(GeometryPointSize), + CASE(ImageGatherExtended), + // Value 26 intentionally missing + CASE(StorageImageMultisample), + CASE(UniformBufferArrayDynamicIndexing), + CASE(SampledImageArrayDynamicIndexing), + CASE(StorageBufferArrayDynamicIndexing), + CASE(StorageImageArrayDynamicIndexing), + CASE(ClipDistance), + CASE(CullDistance), + CASE(ImageCubeArray), + CASE(SampleRateShading), + CASE(ImageRect), + CASE(SampledRect), + CASE(GenericPointer), + CASE(Int8), + CASE(InputAttachment), + CASE(SparseResidency), + CASE(MinLod), + CASE(Sampled1D), + CASE(Image1D), + CASE(SampledCubeArray), + CASE(SampledBuffer), + CASE(ImageBuffer), + CASE(ImageMSArray), + CASE(StorageImageExtendedFormats), + CASE(ImageQuery), + 
CASE(DerivativeControl), + CASE(InterpolationFunction), + CASE(TransformFeedback), + }),); +#undef CASE +// clang-format on + +using TextToBinaryCapability = spvtest::TextToBinaryTest; + +TEST_F(TextToBinaryCapability, BadMissingCapability) { + EXPECT_THAT(CompileFailure("OpCapability"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(TextToBinaryCapability, BadInvalidCapability) { + EXPECT_THAT(CompileFailure("OpCapability 123"), + Eq("Invalid capability '123'.")); +} + +// TODO(dneto): OpExecutionMode + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.pipe_storage_test.cpp b/test/text_to_binary.pipe_storage_test.cpp new file mode 100644 index 000000000..f74dbcfdf --- /dev/null +++ b/test/text_to_binary.pipe_storage_test.cpp @@ -0,0 +1,126 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "test/test_fixture.h" + +namespace spvtools { +namespace { + +using ::spvtest::MakeInstruction; +using ::testing::Eq; + +using OpTypePipeStorageTest = spvtest::TextToBinaryTest; + +// It can assemble, but should not validate. Validation checks for version +// and capability are in another test file. 
+TEST_F(OpTypePipeStorageTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%res = OpTypePipeStorage", SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpTypePipeStorage, {1}))); +} + +TEST_F(OpTypePipeStorageTest, ArgumentCount) { + EXPECT_THAT( + CompileFailure("OpTypePipeStorage", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected at the beginning of an instruction, found " + "'OpTypePipeStorage'.")); + EXPECT_THAT( + CompiledInstructions("%res = OpTypePipeStorage", SPV_ENV_UNIVERSAL_1_1), + Eq(MakeInstruction(SpvOpTypePipeStorage, {1}))); + EXPECT_THAT(CompileFailure("%res = OpTypePipeStorage %1 %2 %3 %4 %5", + SPV_ENV_UNIVERSAL_1_1), + Eq("'=' expected after result id.")); +} + +using OpConstantPipeStorageTest = spvtest::TextToBinaryTest; + +TEST_F(OpConstantPipeStorageTest, OpcodeAssemblesInV10) { + EXPECT_THAT(CompiledInstructions("%1 = OpConstantPipeStorage %2 3 4 5", + SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpConstantPipeStorage, {1, 2, 3, 4, 5}))); +} + +TEST_F(OpConstantPipeStorageTest, ArgumentCount) { + EXPECT_THAT( + CompileFailure("OpConstantPipeStorage", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected at the beginning of an instruction, found " + "'OpConstantPipeStorage'.")); + EXPECT_THAT( + CompileFailure("%1 = OpConstantPipeStorage", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT(CompileFailure("%1 = OpConstantPipeStorage %2 3 4", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT(CompiledInstructions("%1 = OpConstantPipeStorage %2 3 4 5", + SPV_ENV_UNIVERSAL_1_1), + Eq(MakeInstruction(SpvOpConstantPipeStorage, {1, 2, 3, 4, 5}))); + EXPECT_THAT(CompileFailure("%1 = OpConstantPipeStorage %2 3 4 5 %6 %7", + SPV_ENV_UNIVERSAL_1_1), + Eq("'=' expected after result id.")); +} + +TEST_F(OpConstantPipeStorageTest, ArgumentTypes) { + EXPECT_THAT(CompileFailure("%1 = OpConstantPipeStorage %2 %3 4 5", + SPV_ENV_UNIVERSAL_1_1), + Eq("Invalid unsigned integer literal: %3")); + 
EXPECT_THAT(CompileFailure("%1 = OpConstantPipeStorage %2 3 %4 5", + SPV_ENV_UNIVERSAL_1_1), + Eq("Invalid unsigned integer literal: %4")); + EXPECT_THAT(CompileFailure("%1 = OpConstantPipeStorage 2 3 4 5", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); + EXPECT_THAT(CompileFailure("%1 = OpConstantPipeStorage %2 3 4 \"ab\"", + SPV_ENV_UNIVERSAL_1_1), + Eq("Invalid unsigned integer literal: \"ab\"")); +} + +using OpCreatePipeFromPipeStorageTest = spvtest::TextToBinaryTest; + +TEST_F(OpCreatePipeFromPipeStorageTest, OpcodeAssemblesInV10) { + EXPECT_THAT(CompiledInstructions("%1 = OpCreatePipeFromPipeStorage %2 %3", + SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpCreatePipeFromPipeStorage, {1, 2, 3}))); +} + +TEST_F(OpCreatePipeFromPipeStorageTest, ArgumentCount) { + EXPECT_THAT( + CompileFailure("OpCreatePipeFromPipeStorage", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected at the beginning of an instruction, found " + "'OpCreatePipeFromPipeStorage'.")); + EXPECT_THAT( + CompileFailure("%1 = OpCreatePipeFromPipeStorage", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT(CompileFailure("%1 = OpCreatePipeFromPipeStorage %2 OpNop", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found next instruction instead.")); + EXPECT_THAT(CompiledInstructions("%1 = OpCreatePipeFromPipeStorage %2 %3", + SPV_ENV_UNIVERSAL_1_1), + Eq(MakeInstruction(SpvOpCreatePipeFromPipeStorage, {1, 2, 3}))); + EXPECT_THAT(CompileFailure("%1 = OpCreatePipeFromPipeStorage %2 %3 %4 %5", + SPV_ENV_UNIVERSAL_1_1), + Eq("'=' expected after result id.")); +} + +TEST_F(OpCreatePipeFromPipeStorageTest, ArgumentTypes) { + EXPECT_THAT(CompileFailure("%1 = OpCreatePipeFromPipeStorage \"\" %3", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); + EXPECT_THAT(CompileFailure("%1 = OpCreatePipeFromPipeStorage %2 3", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); +} + +} // namespace +} // namespace spvtools diff --git 
a/test/text_to_binary.reserved_sampling_test.cpp b/test/text_to_binary.reserved_sampling_test.cpp new file mode 100644 index 000000000..42e4e2aeb --- /dev/null +++ b/test/text_to_binary.reserved_sampling_test.cpp @@ -0,0 +1,63 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for illegal instructions + +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using ::spvtest::MakeInstruction; +using ::testing::Eq; + +using ReservedSamplingInstTest = RoundTripTest; + +TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjImplicitLod) { + std::string input = "%2 = OpImageSparseSampleProjImplicitLod %1 %3 %4\n"; + EXPECT_THAT( + CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpImageSparseSampleProjImplicitLod, {1, 2, 3, 4}))); +} + +TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjExplicitLod) { + std::string input = + "%2 = OpImageSparseSampleProjExplicitLod %1 %3 %4 Lod %5\n"; + EXPECT_THAT(CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpImageSparseSampleProjExplicitLod, + {1, 2, 3, 4, SpvImageOperandsLodMask, 5}))); +} + +TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjDrefImplicitLod) { + std::string input = + "%2 = OpImageSparseSampleProjDrefImplicitLod %1 %3 %4 %5\n"; + EXPECT_THAT(CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_0), + 
Eq(MakeInstruction(SpvOpImageSparseSampleProjDrefImplicitLod, + {1, 2, 3, 4, 5}))); +} + +TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjDrefExplicitLod) { + std::string input = + "%2 = OpImageSparseSampleProjDrefExplicitLod %1 %3 %4 %5 Lod %6\n"; + EXPECT_THAT(CompiledInstructions(input, SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpImageSparseSampleProjDrefExplicitLod, + {1, 2, 3, 4, 5, SpvImageOperandsLodMask, 6}))); +} + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.subgroup_dispatch_test.cpp b/test/text_to_binary.subgroup_dispatch_test.cpp new file mode 100644 index 000000000..967e3c38b --- /dev/null +++ b/test/text_to_binary.subgroup_dispatch_test.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Device-Side Enqueue Instructions" +// section of the SPIR-V spec. + +#include "test/unit_spirv.h" + +#include "gmock/gmock.h" +#include "test/test_fixture.h" + +namespace spvtools { +namespace { + +using ::spvtest::MakeInstruction; +using ::testing::Eq; + +using OpGetKernelLocalSizeForSubgroupCountTest = spvtest::TextToBinaryTest; + +// We should be able to assemble it. Validation checks are in another test +// file. 
+TEST_F(OpGetKernelLocalSizeForSubgroupCountTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%res = OpGetKernelLocalSizeForSubgroupCount %type " + "%sgcount %invoke %param %param_size %param_align", + SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpGetKernelLocalSizeForSubgroupCount, + {1, 2, 3, 4, 5, 6, 7}))); +} + +TEST_F(OpGetKernelLocalSizeForSubgroupCountTest, ArgumentCount) { + EXPECT_THAT(CompileFailure("OpGetKernelLocalSizeForSubgroupCount", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected at the beginning of an instruction, " + "found 'OpGetKernelLocalSizeForSubgroupCount'.")); + EXPECT_THAT(CompileFailure("%res = OpGetKernelLocalSizeForSubgroupCount", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT( + CompileFailure("%1 = OpGetKernelLocalSizeForSubgroupCount %2 %3 %4 %5 %6", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT( + CompiledInstructions("%res = OpGetKernelLocalSizeForSubgroupCount %type " + "%sgcount %invoke %param %param_size %param_align", + SPV_ENV_UNIVERSAL_1_1), + Eq(MakeInstruction(SpvOpGetKernelLocalSizeForSubgroupCount, + {1, 2, 3, 4, 5, 6, 7}))); + EXPECT_THAT( + CompileFailure("%res = OpGetKernelLocalSizeForSubgroupCount %type " + "%sgcount %invoke %param %param_size %param_align %extra", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected '=', found end of stream.")); +} + +TEST_F(OpGetKernelLocalSizeForSubgroupCountTest, ArgumentTypes) { + EXPECT_THAT(CompileFailure( + "%1 = OpGetKernelLocalSizeForSubgroupCount 2 %3 %4 %5 %6 %7", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); + EXPECT_THAT( + CompileFailure( + "%1 = OpGetKernelLocalSizeForSubgroupCount %2 %3 %4 %5 %6 \"abc\"", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); +} + +using OpGetKernelMaxNumSubgroupsTest = spvtest::TextToBinaryTest; + +TEST_F(OpGetKernelMaxNumSubgroupsTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%res = 
OpGetKernelMaxNumSubgroups %type " + "%invoke %param %param_size %param_align", + SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpGetKernelMaxNumSubgroups, {1, 2, 3, 4, 5, 6}))); +} + +TEST_F(OpGetKernelMaxNumSubgroupsTest, ArgumentCount) { + EXPECT_THAT( + CompileFailure("OpGetKernelMaxNumSubgroups", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected at the beginning of an instruction, found " + "'OpGetKernelMaxNumSubgroups'.")); + EXPECT_THAT(CompileFailure("%res = OpGetKernelMaxNumSubgroups", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT(CompileFailure("%1 = OpGetKernelMaxNumSubgroups %2 %3 %4 %5", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found end of stream.")); + EXPECT_THAT( + CompiledInstructions("%res = OpGetKernelMaxNumSubgroups %type " + "%invoke %param %param_size %param_align", + SPV_ENV_UNIVERSAL_1_1), + Eq(MakeInstruction(SpvOpGetKernelMaxNumSubgroups, {1, 2, 3, 4, 5, 6}))); + EXPECT_THAT(CompileFailure("%res = OpGetKernelMaxNumSubgroups %type %invoke " + "%param %param_size %param_align %extra", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected '=', found end of stream.")); +} + +TEST_F(OpGetKernelMaxNumSubgroupsTest, ArgumentTypes) { + EXPECT_THAT(CompileFailure("%1 = OpGetKernelMaxNumSubgroups 2 %3 %4 %5 %6", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); + EXPECT_THAT( + CompileFailure("%1 = OpGetKernelMaxNumSubgroups %2 %3 %4 %5 \"abc\"", + SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); +} + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary.type_declaration_test.cpp b/test/text_to_binary.type_declaration_test.cpp new file mode 100644 index 000000000..c6f158f29 --- /dev/null +++ b/test/text_to_binary.type_declaration_test.cpp @@ -0,0 +1,293 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Assembler tests for instructions in the "Type-Declaration" section of the +// SPIR-V spec. + +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::EnumCase; +using spvtest::MakeInstruction; +using ::testing::Eq; + +// Test Dim enums via OpTypeImage + +using DimTest = + spvtest::TextToBinaryTestBase<::testing::TestWithParam>>; + +TEST_P(DimTest, AnyDim) { + const std::string input = + "%1 = OpTypeImage %2 " + GetParam().name() + " 2 3 0 4 Rgba8\n"; + EXPECT_THAT( + CompiledInstructions(input), + Eq(MakeInstruction(SpvOpTypeImage, {1, 2, GetParam().value(), 2, 3, 0, 4, + SpvImageFormatRgba8}))); + + // Check the disassembler as well. 
+ EXPECT_THAT(EncodeAndDecodeSuccessfully(input), Eq(input)); +} + +// clang-format off +#define CASE(NAME) {SpvDim##NAME, #NAME} +INSTANTIATE_TEST_CASE_P( + TextToBinaryDim, DimTest, + ::testing::ValuesIn(std::vector>{ + CASE(1D), + CASE(2D), + CASE(3D), + CASE(Cube), + CASE(Rect), + CASE(Buffer), + CASE(SubpassData), + }),); +#undef CASE +// clang-format on + +TEST_F(DimTest, WrongDim) { + EXPECT_THAT(CompileFailure("%i = OpTypeImage %t xxyyzz 1 2 3 4 R8"), + Eq("Invalid dimensionality 'xxyyzz'.")); +} + +// Test ImageFormat enums via OpTypeImage + +using ImageFormatTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(ImageFormatTest, AnyImageFormatAndNoAccessQualifier) { + const std::string input = + "%1 = OpTypeImage %2 1D 2 3 0 4 " + GetParam().name() + "\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpTypeImage, {1, 2, SpvDim1D, 2, 3, 0, 4, + GetParam().value()}))); + // Check the disassembler as well. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input), Eq(input)); +} + +// clang-format off +#define CASE(NAME) {SpvImageFormat##NAME, #NAME} +INSTANTIATE_TEST_CASE_P( + TextToBinaryImageFormat, ImageFormatTest, + ::testing::ValuesIn(std::vector>{ + CASE(Unknown), + CASE(Rgba32f), + CASE(Rgba16f), + CASE(R32f), + CASE(Rgba8), + CASE(Rgba8Snorm), + CASE(Rg32f), + CASE(Rg16f), + CASE(R11fG11fB10f), + CASE(R16f), + CASE(Rgba16), + CASE(Rgb10A2), + CASE(Rg16), + CASE(Rg8), + CASE(R16), + CASE(R8), + CASE(Rgba16Snorm), + CASE(Rg16Snorm), + CASE(Rg8Snorm), + CASE(R16Snorm), + CASE(R8Snorm), + CASE(Rgba32i), + CASE(Rgba16i), + CASE(Rgba8i), + CASE(R32i), + CASE(Rg32i), + CASE(Rg16i), + CASE(Rg8i), + CASE(R16i), + CASE(R8i), + CASE(Rgba32ui), + CASE(Rgba16ui), + CASE(Rgba8ui), + CASE(R32ui), + CASE(Rgb10a2ui), + CASE(Rg32ui), + CASE(Rg16ui), + CASE(Rg8ui), + CASE(R16ui), + CASE(R8ui), + }),); +#undef CASE +// clang-format on + +TEST_F(ImageFormatTest, WrongFormat) { + EXPECT_THAT(CompileFailure("%r = OpTypeImage %t 1D 2 3 0 
4 xxyyzz"), + Eq("Invalid image format 'xxyyzz'.")); +} + +// Test AccessQualifier enums via OpTypeImage. +using ImageAccessQualifierTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(ImageAccessQualifierTest, AnyAccessQualifier) { + const std::string input = + "%1 = OpTypeImage %2 1D 2 3 0 4 Rgba8 " + GetParam().name() + "\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpTypeImage, + {1, 2, SpvDim1D, 2, 3, 0, 4, + SpvImageFormatRgba8, GetParam().value()}))); + // Check the disassembler as well. + EXPECT_THAT(EncodeAndDecodeSuccessfully(input), Eq(input)); +} + +// clang-format off +#define CASE(NAME) {SpvAccessQualifier##NAME, #NAME} +INSTANTIATE_TEST_CASE_P( + AccessQualifier, ImageAccessQualifierTest, + ::testing::ValuesIn(std::vector>{ + CASE(ReadOnly), + CASE(WriteOnly), + CASE(ReadWrite), + }),); +// clang-format on +#undef CASE + +// Test AccessQualifier enums via OpTypePipe. + +using OpTypePipeTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(OpTypePipeTest, AnyAccessQualifier) { + const std::string input = "%1 = OpTypePipe " + GetParam().name() + "\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(MakeInstruction(SpvOpTypePipe, {1, GetParam().value()}))); + // Check the disassembler as well. 
+ EXPECT_THAT(EncodeAndDecodeSuccessfully(input), Eq(input)); +} + +// clang-format off +#define CASE(NAME) {SpvAccessQualifier##NAME, #NAME} +INSTANTIATE_TEST_CASE_P( + TextToBinaryTypePipe, OpTypePipeTest, + ::testing::ValuesIn(std::vector>{ + CASE(ReadOnly), + CASE(WriteOnly), + CASE(ReadWrite), + }),); +#undef CASE +// clang-format on + +TEST_F(OpTypePipeTest, WrongAccessQualifier) { + EXPECT_THAT(CompileFailure("%1 = OpTypePipe xxyyzz"), + Eq("Invalid access qualifier 'xxyyzz'.")); +} + +using OpTypeForwardPointerTest = spvtest::TextToBinaryTest; + +#define CASE(storage_class) \ + do { \ + EXPECT_THAT( \ + CompiledInstructions("OpTypeForwardPointer %pt " #storage_class), \ + Eq(MakeInstruction(SpvOpTypeForwardPointer, \ + {1, SpvStorageClass##storage_class}))); \ + } while (0) + +TEST_F(OpTypeForwardPointerTest, ValidStorageClass) { + CASE(UniformConstant); + CASE(Input); + CASE(Uniform); + CASE(Output); + CASE(Workgroup); + CASE(CrossWorkgroup); + CASE(Private); + CASE(Function); + CASE(Generic); + CASE(PushConstant); + CASE(AtomicCounter); + CASE(Image); + CASE(StorageBuffer); +} + +#undef CASE + +TEST_F(OpTypeForwardPointerTest, MissingType) { + EXPECT_THAT(CompileFailure("OpTypeForwardPointer"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(OpTypeForwardPointerTest, MissingClass) { + EXPECT_THAT(CompileFailure("OpTypeForwardPointer %pt"), + Eq("Expected operand, found end of stream.")); +} + +TEST_F(OpTypeForwardPointerTest, WrongClass) { + EXPECT_THAT(CompileFailure("OpTypeForwardPointer %pt xxyyzz"), + Eq("Invalid storage class 'xxyyzz'.")); +} + +using OpSizeOfTest = spvtest::TextToBinaryTest; + +// We should be able to assemble it. Validation checks are in another test +// file. 
+TEST_F(OpSizeOfTest, OpcodeAssemblesInV10) { + EXPECT_THAT( + CompiledInstructions("%1 = OpSizeOf %2 %3", SPV_ENV_UNIVERSAL_1_0), + Eq(MakeInstruction(SpvOpSizeOf, {1, 2, 3}))); +} + +TEST_F(OpSizeOfTest, ArgumentCount) { + EXPECT_THAT( + CompileFailure("OpSizeOf", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected at the beginning of an instruction, found " + "'OpSizeOf'.")); + EXPECT_THAT(CompileFailure("%res = OpSizeOf OpNop", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected operand, found next instruction instead.")); + EXPECT_THAT( + CompiledInstructions("%1 = OpSizeOf %2 %3", SPV_ENV_UNIVERSAL_1_1), + Eq(MakeInstruction(SpvOpSizeOf, {1, 2, 3}))); + EXPECT_THAT( + CompileFailure("%1 = OpSizeOf %2 %3 44 55 ", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected or at the beginning of an instruction, " + "found '44'.")); +} + +TEST_F(OpSizeOfTest, ArgumentTypes) { + EXPECT_THAT(CompileFailure("%1 = OpSizeOf 2 %3", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); + EXPECT_THAT(CompileFailure("%1 = OpSizeOf %2 \"abc\"", SPV_ENV_UNIVERSAL_1_1), + Eq("Expected id to start with %.")); +} + +// TODO(dneto): OpTypeVoid +// TODO(dneto): OpTypeBool +// TODO(dneto): OpTypeInt +// TODO(dneto): OpTypeFloat +// TODO(dneto): OpTypeVector +// TODO(dneto): OpTypeMatrix +// TODO(dneto): OpTypeImage +// TODO(dneto): OpTypeSampler +// TODO(dneto): OpTypeSampledImage +// TODO(dneto): OpTypeArray +// TODO(dneto): OpTypeRuntimeArray +// TODO(dneto): OpTypeStruct +// TODO(dneto): OpTypeOpaque +// TODO(dneto): OpTypePointer +// TODO(dneto): OpTypeFunction +// TODO(dneto): OpTypeEvent +// TODO(dneto): OpTypeDeviceEvent +// TODO(dneto): OpTypeReserveId +// TODO(dneto): OpTypeQueue + +} // namespace +} // namespace spvtools diff --git a/test/text_to_binary_test.cpp b/test/text_to_binary_test.cpp new file mode 100644 index 000000000..4ba37ad4d --- /dev/null +++ b/test/text_to_binary_test.cpp @@ -0,0 +1,269 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/spirv_constant.h" +#include "source/util/bitutils.h" +#include "source/util/hex_float.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::AutoText; +using spvtest::Concatenate; +using spvtest::MakeInstruction; +using spvtest::ScopedContext; +using spvtest::TextToBinaryTest; +using testing::Eq; +using testing::IsNull; +using testing::NotNull; + +// A mask parsing test case. +struct MaskCase { + spv_operand_type_t which_enum; + uint32_t expected_value; + const char* expression; +}; + +using GoodMaskParseTest = ::testing::TestWithParam; + +TEST_P(GoodMaskParseTest, GoodMaskExpressions) { + spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_0); + + uint32_t value; + EXPECT_EQ(SPV_SUCCESS, + AssemblyGrammar(context).parseMaskOperand( + GetParam().which_enum, GetParam().expression, &value)); + EXPECT_EQ(GetParam().expected_value, value); + + spvContextDestroy(context); +} + +INSTANTIATE_TEST_CASE_P( + ParseMask, GoodMaskParseTest, + ::testing::ValuesIn(std::vector{ + {SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, 0, "None"}, + {SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, 1, "NotNaN"}, + {SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, 2, "NotInf"}, + {SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, 3, "NotNaN|NotInf"}, + // Mask expressions are symmetric. 
+ {SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, 3, "NotInf|NotNaN"}, + // Repeating a value has no effect. + {SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, 3, "NotInf|NotNaN|NotInf"}, + // Using 3 operands still works. + {SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, 0x13, "NotInf|NotNaN|Fast"}, + {SPV_OPERAND_TYPE_SELECTION_CONTROL, 0, "None"}, + {SPV_OPERAND_TYPE_SELECTION_CONTROL, 1, "Flatten"}, + {SPV_OPERAND_TYPE_SELECTION_CONTROL, 2, "DontFlatten"}, + // Weirdly, you can specify to flatten and don't flatten a selection. + {SPV_OPERAND_TYPE_SELECTION_CONTROL, 3, "Flatten|DontFlatten"}, + {SPV_OPERAND_TYPE_LOOP_CONTROL, 0, "None"}, + {SPV_OPERAND_TYPE_LOOP_CONTROL, 1, "Unroll"}, + {SPV_OPERAND_TYPE_LOOP_CONTROL, 2, "DontUnroll"}, + // Weirdly, you can specify to unroll and don't unroll a loop. + {SPV_OPERAND_TYPE_LOOP_CONTROL, 3, "Unroll|DontUnroll"}, + {SPV_OPERAND_TYPE_FUNCTION_CONTROL, 0, "None"}, + {SPV_OPERAND_TYPE_FUNCTION_CONTROL, 1, "Inline"}, + {SPV_OPERAND_TYPE_FUNCTION_CONTROL, 2, "DontInline"}, + {SPV_OPERAND_TYPE_FUNCTION_CONTROL, 4, "Pure"}, + {SPV_OPERAND_TYPE_FUNCTION_CONTROL, 8, "Const"}, + {SPV_OPERAND_TYPE_FUNCTION_CONTROL, 0xd, "Inline|Const|Pure"}, + }), ); + +using BadFPFastMathMaskParseTest = ::testing::TestWithParam; + +TEST_P(BadFPFastMathMaskParseTest, BadMaskExpressions) { + spv_context context = spvContextCreate(SPV_ENV_UNIVERSAL_1_0); + + uint32_t value; + EXPECT_NE(SPV_SUCCESS, + AssemblyGrammar(context).parseMaskOperand( + SPV_OPERAND_TYPE_FP_FAST_MATH_MODE, GetParam(), &value)); + + spvContextDestroy(context); +} + +INSTANTIATE_TEST_CASE_P(ParseMask, BadFPFastMathMaskParseTest, + ::testing::ValuesIn(std::vector{ + nullptr, "", "NotValidEnum", "|", "NotInf|", + "|NotInf", "NotInf||NotNaN", + "Unroll" // A good word, but for the wrong enum + }), ); + +TEST_F(TextToBinaryTest, InvalidText) { + ASSERT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(ScopedContext().context, nullptr, 0, &binary, + &diagnostic)); + EXPECT_NE(nullptr, diagnostic); + 
EXPECT_THAT(diagnostic->error, Eq(std::string("Missing assembly text."))); +} + +TEST_F(TextToBinaryTest, InvalidPointer) { + SetText( + "OpEntryPoint Kernel 0 \"\"\nOpExecutionMode 0 LocalSizeHint 1 1 1\n"); + ASSERT_EQ(SPV_ERROR_INVALID_POINTER, + spvTextToBinary(ScopedContext().context, text.str, text.length, + nullptr, &diagnostic)); +} + +TEST_F(TextToBinaryTest, InvalidPrefix) { + EXPECT_EQ( + "Expected or at the beginning of an instruction, " + "found 'Invalid'.", + CompileFailure("Invalid")); +} + +TEST_F(TextToBinaryTest, EmptyAssemblyString) { + // An empty assembly module is valid! + // It should produce a valid module with zero instructions. + EXPECT_THAT(CompiledInstructions(""), Eq(std::vector{})); +} + +TEST_F(TextToBinaryTest, StringSpace) { + const std::string code = ("OpSourceExtension \"string with spaces\"\n"); + EXPECT_EQ(code, EncodeAndDecodeSuccessfully(code)); +} + +TEST_F(TextToBinaryTest, UnknownBeginningOfInstruction) { + EXPECT_EQ( + "Expected or at the beginning of an instruction, " + "found 'Google'.", + CompileFailure( + "\nOpSource OpenCL_C 12\nOpMemoryModel Physical64 OpenCL\nGoogle\n")); + EXPECT_EQ(4u, diagnostic->position.line + 1); + EXPECT_EQ(1u, diagnostic->position.column + 1); +} + +TEST_F(TextToBinaryTest, NoEqualSign) { + EXPECT_EQ("Expected '=', found end of stream.", + CompileFailure("\nOpSource OpenCL_C 12\n" + "OpMemoryModel Physical64 OpenCL\n%2\n")); + EXPECT_EQ(5u, diagnostic->position.line + 1); + EXPECT_EQ(1u, diagnostic->position.column + 1); +} + +TEST_F(TextToBinaryTest, NoOpCode) { + EXPECT_EQ("Expected opcode, found end of stream.", + CompileFailure("\nOpSource OpenCL_C 12\n" + "OpMemoryModel Physical64 OpenCL\n%2 =\n")); + EXPECT_EQ(5u, diagnostic->position.line + 1); + EXPECT_EQ(1u, diagnostic->position.column + 1); +} + +TEST_F(TextToBinaryTest, WrongOpCode) { + EXPECT_EQ("Invalid Opcode prefix 'Wahahaha'.", + CompileFailure("\nOpSource OpenCL_C 12\n" + "OpMemoryModel Physical64 OpenCL\n%2 = Wahahaha\n")); 
+ EXPECT_EQ(4u, diagnostic->position.line + 1); + EXPECT_EQ(6u, diagnostic->position.column + 1); +} + +TEST_F(TextToBinaryTest, CRLF) { + const std::string input = + "%i32 = OpTypeInt 32 1\r\n%c = OpConstant %i32 123\r\n"; + EXPECT_THAT(CompiledInstructions(input), + Eq(Concatenate({MakeInstruction(SpvOpTypeInt, {1, 32, 1}), + MakeInstruction(SpvOpConstant, {1, 2, 123})}))); +} + +using TextToBinaryFloatValueTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(TextToBinaryFloatValueTest, Samples) { + const std::string input = + "%1 = OpTypeFloat 32\n%2 = OpConstant %1 " + GetParam().first; + EXPECT_THAT(CompiledInstructions(input), + Eq(Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 32}), + MakeInstruction(SpvOpConstant, + {1, 2, GetParam().second})}))); +} + +INSTANTIATE_TEST_CASE_P( + FloatValues, TextToBinaryFloatValueTest, + ::testing::ValuesIn(std::vector>{ + {"0.0", 0x00000000}, // +0 + {"!0x00000001", 0x00000001}, // +denorm + {"!0x00800000", 0x00800000}, // +norm + {"1.5", 0x3fc00000}, + {"!0x7f800000", 0x7f800000}, // +inf + {"!0x7f800001", 0x7f800001}, // NaN + + {"-0.0", 0x80000000}, // -0 + {"!0x80000001", 0x80000001}, // -denorm + {"!0x80800000", 0x80800000}, // -norm + {"-2.5", 0xc0200000}, + {"!0xff800000", 0xff800000}, // -inf + {"!0xff800001", 0xff800001}, // NaN + }), ); + +using TextToBinaryHalfValueTest = spvtest::TextToBinaryTestBase< + ::testing::TestWithParam>>; + +TEST_P(TextToBinaryHalfValueTest, Samples) { + const std::string input = + "%1 = OpTypeFloat 16\n%2 = OpConstant %1 " + GetParam().first; + EXPECT_THAT(CompiledInstructions(input), + Eq(Concatenate({MakeInstruction(SpvOpTypeFloat, {1, 16}), + MakeInstruction(SpvOpConstant, + {1, 2, GetParam().second})}))); +} + +INSTANTIATE_TEST_CASE_P( + HalfValues, TextToBinaryHalfValueTest, + ::testing::ValuesIn(std::vector>{ + {"0.0", 0x00000000}, + {"1.0", 0x00003c00}, + {"1.000844", 0x00003c00}, // Truncate to 1.0 + {"1.000977", 0x00003c01}, // Don't have to 
truncate + {"1.001465", 0x00003c01}, // Truncate to 1.0000977 + {"1.5", 0x00003e00}, + {"-1.0", 0x0000bc00}, + {"2.0", 0x00004000}, + {"-2.0", 0x0000c000}, + {"0x1p1", 0x00004000}, + {"-0x1p1", 0x0000c000}, + {"0x1.8p1", 0x00004200}, + {"0x1.8p4", 0x00004e00}, + {"0x1.801p4", 0x00004e00}, + {"0x1.804p4", 0x00004e01}, + }), ); + +TEST(CreateContext, InvalidEnvironment) { + spv_target_env env; + std::memset(&env, 99, sizeof(env)); + EXPECT_THAT(spvContextCreate(env), IsNull()); +} + +TEST(CreateContext, UniversalEnvironment) { + auto c = spvContextCreate(SPV_ENV_UNIVERSAL_1_0); + EXPECT_THAT(c, NotNull()); + spvContextDestroy(c); +} + +TEST(CreateContext, VulkanEnvironment) { + auto c = spvContextCreate(SPV_ENV_VULKAN_1_0); + EXPECT_THAT(c, NotNull()); + spvContextDestroy(c); +} + +} // namespace +} // namespace spvtools diff --git a/test/text_word_get_test.cpp b/test/text_word_get_test.cpp new file mode 100644 index 000000000..b74a680fa --- /dev/null +++ b/test/text_word_get_test.cpp @@ -0,0 +1,254 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "test/unit_spirv.h" + +namespace spvtools { +namespace { + +using spvtest::AutoText; + +#define TAB "\t" +#define NEWLINE "\n" +#define BACKSLASH R"(\)" +#define QUOTE R"(")" + +TEST(TextWordGet, NullTerminator) { + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ( + SPV_SUCCESS, + AssemblyContext(AutoText("Word"), nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(4u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(4u, endPosition.index); + ASSERT_STREQ("Word", word.c_str()); +} + +TEST(TextWordGet, TabTerminator) { + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, AssemblyContext(AutoText("Word\t"), nullptr) + .getWord(&word, &endPosition)); + ASSERT_EQ(4u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(4u, endPosition.index); + ASSERT_STREQ("Word", word.c_str()); +} + +TEST(TextWordGet, SpaceTerminator) { + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ( + SPV_SUCCESS, + AssemblyContext(AutoText("Word "), nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(4u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(4u, endPosition.index); + ASSERT_STREQ("Word", word.c_str()); +} + +TEST(TextWordGet, SemicolonTerminator) { + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ( + SPV_SUCCESS, + AssemblyContext(AutoText("Wo;rd"), nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(2u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(2u, endPosition.index); + ASSERT_STREQ("Wo", word.c_str()); +} + +TEST(TextWordGet, NoTerminator) { + const std::string full_text = "abcdefghijklmn"; + for (size_t len = 1; len <= full_text.size(); ++len) { + std::string word; + spv_text_t text = {full_text.data(), len}; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, + AssemblyContext(&text, nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(len, endPosition.column); 
+ ASSERT_EQ(len, endPosition.index); + ASSERT_EQ(full_text.substr(0, len), word); + } +} + +TEST(TextWordGet, MultipleWords) { + AutoText input("Words in a sentence"); + AssemblyContext data(input, nullptr); + + spv_position_t endPosition = {}; + const char* words[] = {"Words", "in", "a", "sentence"}; + + std::string word; + for (uint32_t wordIndex = 0; wordIndex < 4; ++wordIndex) { + ASSERT_EQ(SPV_SUCCESS, data.getWord(&word, &endPosition)); + ASSERT_EQ(strlen(words[wordIndex]), + endPosition.column - data.position().column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(strlen(words[wordIndex]), + endPosition.index - data.position().index); + ASSERT_STREQ(words[wordIndex], word.c_str()); + + data.setPosition(endPosition); + if (3 != wordIndex) { + ASSERT_EQ(SPV_SUCCESS, data.advance()); + } else { + ASSERT_EQ(SPV_END_OF_STREAM, data.advance()); + } + } +} + +TEST(TextWordGet, QuotesAreKept) { + AutoText input(R"("quotes" "around words")"); + const char* expected[] = {R"("quotes")", R"("around words")"}; + AssemblyContext data(input, nullptr); + + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, data.getWord(&word, &endPosition)); + EXPECT_EQ(8u, endPosition.column); + EXPECT_EQ(0u, endPosition.line); + EXPECT_EQ(8u, endPosition.index); + EXPECT_STREQ(expected[0], word.c_str()); + + // Move to the next word. 
+ data.setPosition(endPosition); + data.seekForward(1); + + ASSERT_EQ(SPV_SUCCESS, data.getWord(&word, &endPosition)); + EXPECT_EQ(23u, endPosition.column); + EXPECT_EQ(0u, endPosition.line); + EXPECT_EQ(23u, endPosition.index); + EXPECT_STREQ(expected[1], word.c_str()); +} + +TEST(TextWordGet, QuotesBetweenWordsActLikeGlue) { + AutoText input(R"(quotes" "between words)"); + const char* expected[] = {R"(quotes" "between)", "words"}; + AssemblyContext data(input, nullptr); + + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, data.getWord(&word, &endPosition)); + EXPECT_EQ(16u, endPosition.column); + EXPECT_EQ(0u, endPosition.line); + EXPECT_EQ(16u, endPosition.index); + EXPECT_STREQ(expected[0], word.c_str()); + + // Move to the next word. + data.setPosition(endPosition); + data.seekForward(1); + + ASSERT_EQ(SPV_SUCCESS, data.getWord(&word, &endPosition)); + EXPECT_EQ(22u, endPosition.column); + EXPECT_EQ(0u, endPosition.line); + EXPECT_EQ(22u, endPosition.index); + EXPECT_STREQ(expected[1], word.c_str()); +} + +TEST(TextWordGet, QuotingWhitespace) { + AutoText input(QUOTE "white " NEWLINE TAB " space" QUOTE); + // Whitespace surrounded by quotes acts like glue. 
+ std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, + AssemblyContext(input, nullptr).getWord(&word, &endPosition)); + EXPECT_EQ(input.str.length(), endPosition.column); + EXPECT_EQ(0u, endPosition.line); + EXPECT_EQ(input.str.length(), endPosition.index); + EXPECT_EQ(input.str, word); +} + +TEST(TextWordGet, QuoteAlone) { + AutoText input(QUOTE); + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, + AssemblyContext(input, nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(1u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(1u, endPosition.index); + ASSERT_STREQ(QUOTE, word.c_str()); +} + +TEST(TextWordGet, EscapeAlone) { + AutoText input(BACKSLASH); + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, + AssemblyContext(input, nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(1u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(1u, endPosition.index); + ASSERT_STREQ(BACKSLASH, word.c_str()); +} + +TEST(TextWordGet, EscapeAtEndOfInput) { + AutoText input("word" BACKSLASH); + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, + AssemblyContext(input, nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(5u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(5u, endPosition.index); + ASSERT_STREQ("word" BACKSLASH, word.c_str()); +} + +TEST(TextWordGet, Escaping) { + AutoText input("w" BACKSLASH QUOTE "o" BACKSLASH NEWLINE "r" BACKSLASH ";d"); + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, + AssemblyContext(input, nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(10u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(10u, endPosition.index); + ASSERT_EQ(input.str, word); +} + +TEST(TextWordGet, EscapingEscape) { + AutoText input("word" BACKSLASH BACKSLASH " abc"); + std::string word; + spv_position_t endPosition = {}; + ASSERT_EQ(SPV_SUCCESS, + 
AssemblyContext(input, nullptr).getWord(&word, &endPosition)); + ASSERT_EQ(6u, endPosition.column); + ASSERT_EQ(0u, endPosition.line); + ASSERT_EQ(6u, endPosition.index); + ASSERT_STREQ("word" BACKSLASH BACKSLASH, word.c_str()); +} + +TEST(TextWordGet, CRLF) { + AutoText input("abc\r\nd"); + AssemblyContext data(input, nullptr); + std::string word; + spv_position_t pos = {}; + ASSERT_EQ(SPV_SUCCESS, data.getWord(&word, &pos)); + EXPECT_EQ(3u, pos.column); + EXPECT_STREQ("abc", word.c_str()); + data.setPosition(pos); + data.advance(); + ASSERT_EQ(SPV_SUCCESS, data.getWord(&word, &pos)); + EXPECT_EQ(1u, pos.column); + EXPECT_STREQ("d", word.c_str()); +} + +} // namespace +} // namespace spvtools diff --git a/test/timer_test.cpp b/test/timer_test.cpp new file mode 100644 index 000000000..e53af6653 --- /dev/null +++ b/test/timer_test.cpp @@ -0,0 +1,142 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" +#include "source/util/timer.h" + +namespace spvtools { +namespace utils { +namespace { + +// A mock class to mimic Timer class for a testing purpose. It has fixed +// CPU/WALL/USR/SYS time, RSS delta, and the delta of the number of page faults. 
+class MockTimer : public Timer { + public: + MockTimer(std::ostream* out, bool measure_mem_usage = false) + : Timer(out, measure_mem_usage) {} + double CPUTime() override { return 0.019123; } + double WallTime() override { return 0.019723; } + double UserTime() override { return 0.012723; } + double SystemTime() override { return 0.002723; } + long RSS() const override { return 360L; } + long PageFault() const override { return 3600L; } +}; + +// This unit test checks whether the actual output of MockTimer::Report() is the +// same as fixed CPU/WALL/USR/SYS time, RSS delta, and the delta of the number +// of page faults that are returned by MockTimer. +TEST(MockTimer, DoNothing) { + std::ostringstream buf; + + PrintTimerDescription(&buf); + MockTimer timer(&buf); + timer.Start(); + + // Do nothing. + + timer.Stop(); + timer.Report("TimerTest"); + + EXPECT_EQ(0.019123, timer.CPUTime()); + EXPECT_EQ(0.019723, timer.WallTime()); + EXPECT_EQ(0.012723, timer.UserTime()); + EXPECT_EQ(0.002723, timer.SystemTime()); + EXPECT_EQ( + " PASS name CPU time WALL time USR time" + " SYS time\n TimerTest 0.02 0.02" + " 0.01 0.00\n", + buf.str()); +} + +// This unit test checks whether the ScopedTimer correctly reports +// the fixed CPU/WALL/USR/SYS time, RSS delta, and the delta of the number of +// page faults that are returned by MockTimer. +TEST(MockTimer, TestScopedTimer) { + std::ostringstream buf; + + { + ScopedTimer scopedtimer(&buf, "ScopedTimerTest"); + // Do nothing. + } + + EXPECT_EQ( + " ScopedTimerTest 0.02 0.02 0.01" + " 0.00\n", + buf.str()); +} + +// A mock class to mimic CumulativeTimer class for a testing purpose. It has +// fixed CPU/WALL/USR/SYS time, RSS delta, and the delta of the number of page +// faults for each measurement (i.e., a pair of Start() and Stop()). If the +// number of measurements increases, it increases |count_stop_| by the number of +// calling Stop() and the amount of each resource usage is proportional to +// |count_stop_|. 
+class MockCumulativeTimer : public CumulativeTimer { + public: + MockCumulativeTimer(std::ostream* out, bool measure_mem_usage = false) + : CumulativeTimer(out, measure_mem_usage), count_stop_(0) {} + double CPUTime() override { return count_stop_ * 0.019123; } + double WallTime() override { return count_stop_ * 0.019723; } + double UserTime() override { return count_stop_ * 0.012723; } + double SystemTime() override { return count_stop_ * 0.002723; } + long RSS() const override { return count_stop_ * 360L; } + long PageFault() const override { return count_stop_ * 3600L; } + + // Calling Stop() does nothing but just increases |count_stop_| by 1. + void Stop() override { ++count_stop_; }; + + private: + unsigned int count_stop_; +}; + +// This unit test checks whether the MockCumulativeTimer correctly reports the +// cumulative CPU/WALL/USR/SYS time, RSS delta, and the delta of the number of +// page faults whose values are fixed for each measurement (i.e., a pair of +// Start() and Stop()). +TEST(MockCumulativeTimer, DoNothing) { + CumulativeTimer* ctimer; + std::ostringstream buf; + + { + ctimer = new MockCumulativeTimer(&buf); + ctimer->Start(); + + // Do nothing. + + ctimer->Stop(); + } + + { + ctimer->Start(); + + // Do nothing. + + ctimer->Stop(); + ctimer->Report("CumulativeTimerTest"); + } + + EXPECT_EQ( + " CumulativeTimerTest 0.04 0.04 0.03" + " 0.01\n", + buf.str()); + + if (ctimer) delete ctimer; +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/test/tools/CMakeLists.txt b/test/tools/CMakeLists.txt new file mode 100644 index 000000000..cee95cadb --- /dev/null +++ b/test/tools/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright (c) 2018 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +spirv_add_nosetests(expect) +spirv_add_nosetests(spirv_test_framework) + +add_subdirectory(opt) diff --git a/test/tools/expect.py b/test/tools/expect.py new file mode 100755 index 000000000..c9596506a --- /dev/null +++ b/test/tools/expect.py @@ -0,0 +1,677 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A number of common spirv result checks coded in mixin classes. + +A test case can use these checks by declaring their enclosing mixin classes +as superclass and providing the expected_* variables required by the check_*() +methods in the mixin classes. +""" +import difflib +import os +import re +import subprocess +from spirv_test_framework import SpirvTest + + +def convert_to_unix_line_endings(source): + """Converts all line endings in source to be unix line endings.""" + return source.replace('\r\n', '\n').replace('\r', '\n') + + +def substitute_file_extension(filename, extension): + """Substitutes file extension, respecting known shader extensions. 
+ + foo.vert -> foo.vert.[extension] [similarly for .frag, .comp, etc.] + foo.glsl -> foo.[extension] + foo.unknown -> foo.[extension] + foo -> foo.[extension] + """ + if filename[-5:] not in [ + '.vert', '.frag', '.tesc', '.tese', '.geom', '.comp', '.spvasm' + ]: + return filename.rsplit('.', 1)[0] + '.' + extension + else: + return filename + '.' + extension + + +def get_object_filename(source_filename): + """Gets the object filename for the given source file.""" + return substitute_file_extension(source_filename, 'spv') + + +def get_assembly_filename(source_filename): + """Gets the assembly filename for the given source file.""" + return substitute_file_extension(source_filename, 'spvasm') + + +def verify_file_non_empty(filename): + """Checks that a given file exists and is not empty.""" + if not os.path.isfile(filename): + return False, 'Cannot find file: ' + filename + if not os.path.getsize(filename): + return False, 'Empty file: ' + filename + return True, '' + + +class ReturnCodeIsZero(SpirvTest): + """Mixin class for checking that the return code is zero.""" + + def check_return_code_is_zero(self, status): + if status.returncode: + return False, 'Non-zero return code: {ret}\n'.format( + ret=status.returncode) + return True, '' + + +class NoOutputOnStdout(SpirvTest): + """Mixin class for checking that there is no output on stdout.""" + + def check_no_output_on_stdout(self, status): + if status.stdout: + return False, 'Non empty stdout: {out}\n'.format(out=status.stdout) + return True, '' + + +class NoOutputOnStderr(SpirvTest): + """Mixin class for checking that there is no output on stderr.""" + + def check_no_output_on_stderr(self, status): + if status.stderr: + return False, 'Non empty stderr: {err}\n'.format(err=status.stderr) + return True, '' + + +class SuccessfulReturn(ReturnCodeIsZero, NoOutputOnStdout, NoOutputOnStderr): + """Mixin class for checking that return code is zero and no output on + stdout and stderr.""" + pass + + +class 
NoGeneratedFiles(SpirvTest): + """Mixin class for checking that there is no file generated.""" + + def check_no_generated_files(self, status): + all_files = os.listdir(status.directory) + input_files = status.input_filenames + if all([f.startswith(status.directory) for f in input_files]): + all_files = [os.path.join(status.directory, f) for f in all_files] + generated_files = set(all_files) - set(input_files) + if len(generated_files) == 0: + return True, '' + else: + return False, 'Extra files generated: {}'.format(generated_files) + + +class CorrectBinaryLengthAndPreamble(SpirvTest): + """Provides methods for verifying preamble for a SPIR-V binary.""" + + def verify_binary_length_and_header(self, binary, spv_version=0x10000): + """Checks that the given SPIR-V binary has valid length and header. + + Returns: + False, error string if anything is invalid + True, '' otherwise + Args: + binary: a bytes object containing the SPIR-V binary + spv_version: target SPIR-V version number, with same encoding + as the version word in a SPIR-V header. + """ + + def read_word(binary, index, little_endian): + """Reads the index-th word from the given binary file.""" + word = binary[index * 4:(index + 1) * 4] + if little_endian: + word = reversed(word) + return reduce(lambda w, b: (w << 8) | ord(b), word, 0) + + def check_endianness(binary): + """Checks the endianness of the given SPIR-V binary. + + Returns: + True if it's little endian, False if it's big endian. + None if magic number is wrong. 
+ """ + first_word = read_word(binary, 0, True) + if first_word == 0x07230203: + return True + first_word = read_word(binary, 0, False) + if first_word == 0x07230203: + return False + return None + + num_bytes = len(binary) + if num_bytes % 4 != 0: + return False, ('Incorrect SPV binary: size should be a multiple' + ' of words') + if num_bytes < 20: + return False, 'Incorrect SPV binary: size less than 5 words' + + preamble = binary[0:19] + little_endian = check_endianness(preamble) + # SPIR-V module magic number + if little_endian is None: + return False, 'Incorrect SPV binary: wrong magic number' + + # SPIR-V version number + version = read_word(preamble, 1, little_endian) + # TODO(dneto): Recent Glslang uses version word 0 for opengl_compat + # profile + + if version != spv_version and version != 0: + return False, 'Incorrect SPV binary: wrong version number' + # Shaderc-over-Glslang (0x000d....) or + # SPIRV-Tools (0x0007....) generator number + if read_word(preamble, 2, little_endian) != 0x000d0007 and \ + read_word(preamble, 2, little_endian) != 0x00070000: + return False, ('Incorrect SPV binary: wrong generator magic ' 'number') + # reserved for instruction schema + if read_word(preamble, 4, little_endian) != 0: + return False, 'Incorrect SPV binary: the 5th byte should be 0' + + return True, '' + + +class CorrectObjectFilePreamble(CorrectBinaryLengthAndPreamble): + """Provides methods for verifying preamble for a SPV object file.""" + + def verify_object_file_preamble(self, filename, spv_version=0x10000): + """Checks that the given SPIR-V binary file has correct preamble.""" + + success, message = verify_file_non_empty(filename) + if not success: + return False, message + + with open(filename, 'rb') as object_file: + object_file.seek(0, os.SEEK_END) + num_bytes = object_file.tell() + + object_file.seek(0) + + binary = bytes(object_file.read()) + return self.verify_binary_length_and_header(binary, spv_version) + + return True, '' + + +class 
CorrectAssemblyFilePreamble(SpirvTest): + """Provides methods for verifying preamble for a SPV assembly file.""" + + def verify_assembly_file_preamble(self, filename): + success, message = verify_file_non_empty(filename) + if not success: + return False, message + + with open(filename) as assembly_file: + line1 = assembly_file.readline() + line2 = assembly_file.readline() + line3 = assembly_file.readline() + + if (line1 != '; SPIR-V\n' or line2 != '; Version: 1.0\n' or + (not line3.startswith('; Generator: Google Shaderc over Glslang;'))): + return False, 'Incorrect SPV assembly' + + return True, '' + + +class ValidObjectFile(SuccessfulReturn, CorrectObjectFilePreamble): + """Mixin class for checking that every input file generates a valid SPIR-V 1.0 + object file following the object file naming rule, and there is no output on + stdout/stderr.""" + + def check_object_file_preamble(self, status): + for input_filename in status.input_filenames: + object_filename = get_object_filename(input_filename) + success, message = self.verify_object_file_preamble( + os.path.join(status.directory, object_filename)) + if not success: + return False, message + return True, '' + + +class ValidObjectFile1_3(ReturnCodeIsZero, CorrectObjectFilePreamble): + """Mixin class for checking that every input file generates a valid SPIR-V 1.3 + object file following the object file naming rule, and there is no output on + stdout/stderr.""" + + def check_object_file_preamble(self, status): + for input_filename in status.input_filenames: + object_filename = get_object_filename(input_filename) + success, message = self.verify_object_file_preamble( + os.path.join(status.directory, object_filename), 0x10300) + if not success: + return False, message + return True, '' + + +class ValidObjectFileWithAssemblySubstr(SuccessfulReturn, + CorrectObjectFilePreamble): + """Mixin class for checking that every input file generates a valid object + + file following the object file naming rule, there is no 
output on + stdout/stderr, and the disassmbly contains a specified substring per + input. + """ + + def check_object_file_disassembly(self, status): + for an_input in status.inputs: + object_filename = get_object_filename(an_input.filename) + obj_file = str(os.path.join(status.directory, object_filename)) + success, message = self.verify_object_file_preamble(obj_file) + if not success: + return False, message + cmd = [status.test_manager.disassembler_path, '--no-color', obj_file] + process = subprocess.Popen( + args=cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=status.directory) + output = process.communicate(None) + disassembly = output[0] + if not isinstance(an_input.assembly_substr, str): + return False, 'Missing assembly_substr member' + if an_input.assembly_substr not in disassembly: + return False, ('Incorrect disassembly output:\n{asm}\n' + 'Expected substring not found:\n{exp}'.format( + asm=disassembly, exp=an_input.assembly_substr)) + return True, '' + + +class ValidNamedObjectFile(SuccessfulReturn, CorrectObjectFilePreamble): + """Mixin class for checking that a list of object files with the given + names are correctly generated, and there is no output on stdout/stderr. + + To mix in this class, subclasses need to provide expected_object_filenames + as the expected object filenames. 
+ """ + + def check_object_file_preamble(self, status): + for object_filename in self.expected_object_filenames: + success, message = self.verify_object_file_preamble( + os.path.join(status.directory, object_filename)) + if not success: + return False, message + return True, '' + + +class ValidFileContents(SpirvTest): + """Mixin class to test that a specific file contains specific text + To mix in this class, subclasses need to provide expected_file_contents as + the contents of the file and target_filename to determine the location.""" + + def check_file(self, status): + target_filename = os.path.join(status.directory, self.target_filename) + if not os.path.isfile(target_filename): + return False, 'Cannot find file: ' + target_filename + with open(target_filename, 'r') as target_file: + file_contents = target_file.read() + if isinstance(self.expected_file_contents, str): + if file_contents == self.expected_file_contents: + return True, '' + return False, ('Incorrect file output: \n{act}\n' + 'Expected:\n{exp}' + 'With diff:\n{diff}'.format( + act=file_contents, + exp=self.expected_file_contents, + diff='\n'.join( + list( + difflib.unified_diff( + self.expected_file_contents.split('\n'), + file_contents.split('\n'), + fromfile='expected_output', + tofile='actual_output'))))) + elif isinstance(self.expected_file_contents, type(re.compile(''))): + if self.expected_file_contents.search(file_contents): + return True, '' + return False, ('Incorrect file output: \n{act}\n' + 'Expected matching regex pattern:\n{exp}'.format( + act=file_contents, + exp=self.expected_file_contents.pattern)) + return False, ( + 'Could not open target file ' + target_filename + ' for reading') + + +class ValidAssemblyFile(SuccessfulReturn, CorrectAssemblyFilePreamble): + """Mixin class for checking that every input file generates a valid assembly + file following the assembly file naming rule, and there is no output on + stdout/stderr.""" + + def check_assembly_file_preamble(self, status): + 
for input_filename in status.input_filenames: + assembly_filename = get_assembly_filename(input_filename) + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + return True, '' + + +class ValidAssemblyFileWithSubstr(ValidAssemblyFile): + """Mixin class for checking that every input file generates a valid assembly + file following the assembly file naming rule, there is no output on + stdout/stderr, and all assembly files have the given substring specified + by expected_assembly_substr. + + To mix in this class, subclasses need to provde expected_assembly_substr + as the expected substring. + """ + + def check_assembly_with_substr(self, status): + for input_filename in status.input_filenames: + assembly_filename = get_assembly_filename(input_filename) + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + with open(assembly_filename, 'r') as f: + content = f.read() + if self.expected_assembly_substr not in convert_to_unix_line_endings( + content): + return False, ('Incorrect assembly output:\n{asm}\n' + 'Expected substring not found:\n{exp}'.format( + asm=content, exp=self.expected_assembly_substr)) + return True, '' + + +class ValidAssemblyFileWithoutSubstr(ValidAssemblyFile): + """Mixin class for checking that every input file generates a valid assembly + file following the assembly file naming rule, there is no output on + stdout/stderr, and no assembly files have the given substring specified + by unexpected_assembly_substr. + + To mix in this class, subclasses need to provde unexpected_assembly_substr + as the substring we expect not to see. 
+ """ + + def check_assembly_for_substr(self, status): + for input_filename in status.input_filenames: + assembly_filename = get_assembly_filename(input_filename) + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + with open(assembly_filename, 'r') as f: + content = f.read() + if self.unexpected_assembly_substr in convert_to_unix_line_endings( + content): + return False, ('Incorrect assembly output:\n{asm}\n' + 'Unexpected substring found:\n{unexp}'.format( + asm=content, exp=self.unexpected_assembly_substr)) + return True, '' + + +class ValidNamedAssemblyFile(SuccessfulReturn, CorrectAssemblyFilePreamble): + """Mixin class for checking that a list of assembly files with the given + names are correctly generated, and there is no output on stdout/stderr. + + To mix in this class, subclasses need to provide expected_assembly_filenames + as the expected assembly filenames. + """ + + def check_object_file_preamble(self, status): + for assembly_filename in self.expected_assembly_filenames: + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + return True, '' + + +class ErrorMessage(SpirvTest): + """Mixin class for tests that fail with a specific error message. + + To mix in this class, subclasses need to provide expected_error as the + expected error message. + + The test should fail if the subprocess was terminated by a signal. + """ + + def check_has_error_message(self, status): + if not status.returncode: + return False, ('Expected error message, but returned success from ' + 'command execution') + if status.returncode < 0: + # On Unix, a negative value -N for Popen.returncode indicates + # termination by signal N. 
+ # https://docs.python.org/2/library/subprocess.html + return False, ('Expected error message, but command was terminated by ' + 'signal ' + str(status.returncode)) + if not status.stderr: + return False, 'Expected error message, but no output on stderr' + if self.expected_error != convert_to_unix_line_endings(status.stderr): + return False, ('Incorrect stderr output:\n{act}\n' + 'Expected:\n{exp}'.format( + act=status.stderr, exp=self.expected_error)) + return True, '' + + +class ErrorMessageSubstr(SpirvTest): + """Mixin class for tests that fail with a specific substring in the error + message. + + To mix in this class, subclasses need to provide expected_error_substr as + the expected error message substring. + + The test should fail if the subprocess was terminated by a signal. + """ + + def check_has_error_message_as_substring(self, status): + if not status.returncode: + return False, ('Expected error message, but returned success from ' + 'command execution') + if status.returncode < 0: + # On Unix, a negative value -N for Popen.returncode indicates + # termination by signal N. + # https://docs.python.org/2/library/subprocess.html + return False, ('Expected error message, but command was terminated by ' + 'signal ' + str(status.returncode)) + if not status.stderr: + return False, 'Expected error message, but no output on stderr' + if self.expected_error_substr not in convert_to_unix_line_endings( + status.stderr): + return False, ('Incorrect stderr output:\n{act}\n' + 'Expected substring not found in stderr:\n{exp}'.format( + act=status.stderr, exp=self.expected_error_substr)) + return True, '' + + +class WarningMessage(SpirvTest): + """Mixin class for tests that succeed but have a specific warning message. + + To mix in this class, subclasses need to provide expected_warning as the + expected warning message. 
+ """ + + def check_has_warning_message(self, status): + if status.returncode: + return False, ('Expected warning message, but returned failure from' + ' command execution') + if not status.stderr: + return False, 'Expected warning message, but no output on stderr' + if self.expected_warning != convert_to_unix_line_endings(status.stderr): + return False, ('Incorrect stderr output:\n{act}\n' + 'Expected:\n{exp}'.format( + act=status.stderr, exp=self.expected_warning)) + return True, '' + + +class ValidObjectFileWithWarning(NoOutputOnStdout, CorrectObjectFilePreamble, + WarningMessage): + """Mixin class for checking that every input file generates a valid object + file following the object file naming rule, with a specific warning message. + """ + + def check_object_file_preamble(self, status): + for input_filename in status.input_filenames: + object_filename = get_object_filename(input_filename) + success, message = self.verify_object_file_preamble( + os.path.join(status.directory, object_filename)) + if not success: + return False, message + return True, '' + + +class ValidAssemblyFileWithWarning(NoOutputOnStdout, + CorrectAssemblyFilePreamble, WarningMessage): + """Mixin class for checking that every input file generates a valid assembly + file following the assembly file naming rule, with a specific warning + message.""" + + def check_assembly_file_preamble(self, status): + for input_filename in status.input_filenames: + assembly_filename = get_assembly_filename(input_filename) + success, message = self.verify_assembly_file_preamble( + os.path.join(status.directory, assembly_filename)) + if not success: + return False, message + return True, '' + + +class StdoutMatch(SpirvTest): + """Mixin class for tests that can expect output on stdout. + + To mix in this class, subclasses need to provide expected_stdout as the + expected stdout output. + + For expected_stdout, if it's True, then they expect something on stdout but + will not check what it is. 
If it's a string, expect an exact match. If it's + anything else, it is assumed to be a compiled regular expression which will + be matched against re.search(). It will expect + expected_stdout.search(status.stdout) to be true. + """ + + def check_stdout_match(self, status): + # "True" in this case means we expect something on stdout, but we do not + # care what it is, we want to distinguish this from "blah" which means we + # expect exactly the string "blah". + if self.expected_stdout is True: + if not status.stdout: + return False, 'Expected something on stdout' + elif type(self.expected_stdout) == str: + if self.expected_stdout != convert_to_unix_line_endings(status.stdout): + return False, ('Incorrect stdout output:\n{ac}\n' + 'Expected:\n{ex}'.format( + ac=status.stdout, ex=self.expected_stdout)) + else: + if not self.expected_stdout.search( + convert_to_unix_line_endings(status.stdout)): + return False, ('Incorrect stdout output:\n{ac}\n' + 'Expected to match regex:\n{ex}'.format( + ac=status.stdout, ex=self.expected_stdout.pattern)) + return True, '' + + +class StderrMatch(SpirvTest): + """Mixin class for tests that can expect output on stderr. + + To mix in this class, subclasses need to provide expected_stderr as the + expected stderr output. + + For expected_stderr, if it's True, then they expect something on stderr, + but will not check what it is. If it's a string, expect an exact match. + If it's anything else, it is assumed to be a compiled regular expression + which will be matched against re.search(). It will expect + expected_stderr.search(status.stderr) to be true. + """ + + def check_stderr_match(self, status): + # "True" in this case means we expect something on stderr, but we do not + # care what it is, we want to distinguish this from "blah" which means we + # expect exactly the string "blah". 
+ if self.expected_stderr is True: + if not status.stderr: + return False, 'Expected something on stderr' + elif type(self.expected_stderr) == str: + if self.expected_stderr != convert_to_unix_line_endings(status.stderr): + return False, ('Incorrect stderr output:\n{ac}\n' + 'Expected:\n{ex}'.format( + ac=status.stderr, ex=self.expected_stderr)) + else: + if not self.expected_stderr.search( + convert_to_unix_line_endings(status.stderr)): + return False, ('Incorrect stderr output:\n{ac}\n' + 'Expected to match regex:\n{ex}'.format( + ac=status.stderr, ex=self.expected_stderr.pattern)) + return True, '' + + +class StdoutNoWiderThan80Columns(SpirvTest): + """Mixin class for tests that require stdout to 80 characters or narrower. + + To mix in this class, subclasses need to provide expected_stdout as the + expected stdout output. + """ + + def check_stdout_not_too_wide(self, status): + if not status.stdout: + return True, '' + else: + for line in status.stdout.splitlines(): + if len(line) > 80: + return False, ('Stdout line longer than 80 columns: %s' % line) + return True, '' + + +class NoObjectFile(SpirvTest): + """Mixin class for checking that no input file has a corresponding object + file.""" + + def check_no_object_file(self, status): + for input_filename in status.input_filenames: + object_filename = get_object_filename(input_filename) + full_object_file = os.path.join(status.directory, object_filename) + print('checking %s' % full_object_file) + if os.path.isfile(full_object_file): + return False, ( + 'Expected no object file, but found: %s' % full_object_file) + return True, '' + + +class NoNamedOutputFiles(SpirvTest): + """Mixin class for checking that no specified output files exist. 
+ + The expected_output_filenames member should be full pathnames.""" + + def check_no_named_output_files(self, status): + for object_filename in self.expected_output_filenames: + if os.path.isfile(object_filename): + return False, ( + 'Expected no output file, but found: %s' % object_filename) + return True, '' + + +class ExecutedListOfPasses(SpirvTest): + """Mixin class for checking that a list of passes where executed. + + It works by analyzing the output of the --print-all flag to spirv-opt. + + For this mixin to work, the class member expected_passes should be a sequence + of pass names as returned by Pass::name(). + """ + + def check_list_of_executed_passes(self, status): + # Collect all the output lines containing a pass name. + pass_names = [] + pass_name_re = re.compile(r'.*IR before pass (?P[\S]+)') + for line in status.stderr.splitlines(): + match = pass_name_re.match(line) + if match: + pass_names.append(match.group('pass_name')) + + for (expected, actual) in zip(self.expected_passes, pass_names): + if expected != actual: + return False, ( + 'Expected pass "%s" but found pass "%s"\n' % (expected, actual)) + + return True, '' diff --git a/test/tools/expect_nosetest.py b/test/tools/expect_nosetest.py new file mode 100755 index 000000000..b591a2d07 --- /dev/null +++ b/test/tools/expect_nosetest.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for the expect module.""" + +import expect +from spirv_test_framework import TestStatus +from nose.tools import assert_equal, assert_true, assert_false +import re + + +def nosetest_get_object_name(): + """Tests get_object_filename().""" + source_and_object_names = [('a.vert', 'a.vert.spv'), ('b.frag', 'b.frag.spv'), + ('c.tesc', 'c.tesc.spv'), ('d.tese', 'd.tese.spv'), + ('e.geom', 'e.geom.spv'), ('f.comp', 'f.comp.spv'), + ('file', 'file.spv'), ('file.', 'file.spv'), + ('file.uk', + 'file.spv'), ('file.vert.', + 'file.vert.spv'), ('file.vert.bla', + 'file.vert.spv')] + actual_object_names = [ + expect.get_object_filename(f[0]) for f in source_and_object_names + ] + expected_object_names = [f[1] for f in source_and_object_names] + + assert_equal(actual_object_names, expected_object_names) + + +class TestStdoutMatchADotC(expect.StdoutMatch): + expected_stdout = re.compile('a.c') + + +def nosetest_stdout_match_regex_has_match(): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout='0abc1', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + assert_true(test.check_stdout_match(status)[0]) + + +def nosetest_stdout_match_regex_no_match(): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout='ab', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + assert_false(test.check_stdout_match(status)[0]) + + +def nosetest_stdout_match_regex_empty_stdout(): + test = TestStdoutMatchADotC() + status = TestStatus( + test_manager=None, + returncode=0, + stdout='', + stderr=None, + directory=None, + inputs=None, + input_filenames=None) + assert_false(test.check_stdout_match(status)[0]) diff --git a/test/tools/opt/CMakeLists.txt b/test/tools/opt/CMakeLists.txt new file mode 100644 index 000000000..21aa247f1 --- /dev/null +++ b/test/tools/opt/CMakeLists.txt @@ -0,0 +1,25 @@ +# Copyright (c) 2018 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT ${SPIRV_SKIP_TESTS}) + if(${PYTHONINTERP_FOUND}) + add_test(NAME spirv_opt_cli_tools_tests + COMMAND ${PYTHON_EXECUTABLE} + ${CMAKE_CURRENT_SOURCE_DIR}/../spirv_test_framework.py + $<TARGET_FILE:spirv-opt> $<TARGET_FILE:spirv-as> $<TARGET_FILE:spirv-dis> + --test-dir ${CMAKE_CURRENT_SOURCE_DIR}) + else() + message("Skipping CLI tools tests - Python executable not found") + endif() +endif() diff --git a/test/tools/opt/flags.py b/test/tools/opt/flags.py new file mode 100644 index 000000000..69462fcc5 --- /dev/null +++ b/test/tools/opt/flags.py @@ -0,0 +1,333 @@ +# Copyright (c) 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def empty_main_assembly():
  """Returns SPIR-V assembly text for a vertex shader with an empty main()."""
  return """
         OpCapability Shader
         OpMemoryModel Logical GLSL450
         OpEntryPoint Vertex %4 "main"
         OpName %4 "main"
    %2 = OpTypeVoid
    %3 = OpTypeFunction %2
    %4 = OpFunction %2 None %3
    %5 = OpLabel
         OpReturn
         OpFunctionEnd"""


@inside_spirv_testsuite('SpirvOptBase')
class TestAssemblyFileAsOnlyParameter(expect.ValidObjectFile1_3):
  """Tests that spirv-opt accepts a SPIR-V object file."""

  # The assembly is written to a temporary .spvasm file and assembled when
  # the placeholder is expanded into spirv_args.
  shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm')
  output = placeholder.TempFileName('output.spv')
  spirv_args = [shader, '-o', output]
  # NOTE(review): the parentheses do not create a tuple; this binds the
  # TempFileName placeholder itself, not a one-element sequence.
  expected_object_filenames = (output)


@inside_spirv_testsuite('SpirvOptFlags')
class TestHelpFlag(expect.ReturnCodeIsZero, expect.StdoutMatch):
  """Test the --help flag."""

  spirv_args = ['--help']
  expected_stdout = re.compile(r'.*The SPIR-V binary is read from ')


@inside_spirv_testsuite('SpirvOptFlags')
class TestValidPassFlags(expect.ValidObjectFile1_3,
                         expect.ExecutedListOfPasses):
  """Tests that spirv-opt accepts all valid optimization flags."""

  # Command-line flags, in the order they are passed. The bare numbers are
  # the arguments of the flag that precedes them.
  flags = [
      '--ccp', '--cfg-cleanup', '--combine-access-chains', '--compact-ids',
      '--convert-local-access-chains', '--copy-propagate-arrays',
      '--eliminate-common-uniform', '--eliminate-dead-branches',
      '--eliminate-dead-code-aggressive', '--eliminate-dead-const',
      '--eliminate-dead-functions', '--eliminate-dead-inserts',
      '--eliminate-dead-variables', '--eliminate-insert-extract',
      '--eliminate-local-multi-store', '--eliminate-local-single-block',
      '--eliminate-local-single-store', '--flatten-decorations',
      '--fold-spec-const-op-composite', '--freeze-spec-const',
      '--if-conversion', '--inline-entry-points-exhaustive', '--loop-fission',
      '20', '--loop-fusion', '5', '--loop-unroll', '--loop-unroll-partial', '3',
      '--loop-peeling', '--merge-blocks', '--merge-return', '--loop-unswitch',
      '--private-to-local', '--reduce-load-size', '--redundancy-elimination',
      '--remove-duplicates', '--replace-invalid-opcode', '--ssa-rewrite',
      '--scalar-replacement', '--scalar-replacement=42', '--strength-reduction',
      '--strip-debug', '--strip-reflect', '--vector-dce', '--workaround-1209',
      '--unify-const'
  ]
  # Pass names expected in the --print-all output, one per flag above and in
  # the same order.
  expected_passes = [
      'ccp',
      'cfg-cleanup',
      'combine-access-chains',
      'compact-ids',
      'convert-local-access-chains',
      'copy-propagate-arrays',
      'eliminate-common-uniform',
      'eliminate-dead-branches',
      'eliminate-dead-code-aggressive',
      'eliminate-dead-const',
      'eliminate-dead-functions',
      'eliminate-dead-inserts',
      'eliminate-dead-variables',
      # --eliminate-insert-extract runs the simplify-instructions pass.
      'simplify-instructions',
      'eliminate-local-multi-store',
      'eliminate-local-single-block',
      'eliminate-local-single-store',
      'flatten-decorations',
      'fold-spec-const-op-composite',
      'freeze-spec-const',
      'if-conversion',
      'inline-entry-points-exhaustive',
      'loop-fission',
      'loop-fusion',
      'loop-unroll',
      'loop-unroll',
      'loop-peeling',
      'merge-blocks',
      'merge-return',
      'loop-unswitch',
      'private-to-local',
      'reduce-load-size',
      'redundancy-elimination',
      'remove-duplicates',
      'replace-invalid-opcode',
      'ssa-rewrite',
      'scalar-replacement=100',
      'scalar-replacement=42',
      'strength-reduction',
      'strip-debug',
      'strip-reflect',
      'vector-dce',
      'workaround-1209',
      'unify-const'
  ]
  shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm')
  output = placeholder.TempFileName('output.spv')
  # --print-all makes spirv-opt emit the "IR before pass" lines that
  # ExecutedListOfPasses parses.
  spirv_args = [shader, '-o', output, '--print-all'] + flags
  # NOTE(review): parenthesized placeholder, not a tuple.
  expected_object_filenames = (output)


@inside_spirv_testsuite('SpirvOptFlags')
class TestPerformanceOptimizationPasses(expect.ValidObjectFile1_3,
                                        expect.ExecutedListOfPasses):
  """Tests that spirv-opt schedules all the passes triggered by -O."""

  flags = ['-O']
  # Expected pass schedule for -O, in execution order.
  expected_passes = [
      'eliminate-dead-branches',
      'merge-return',
      'inline-entry-points-exhaustive',
      'eliminate-dead-code-aggressive',
      'private-to-local',
      'eliminate-local-single-block',
      'eliminate-local-single-store',
      'eliminate-dead-code-aggressive',
      'scalar-replacement=100',
      'convert-local-access-chains',
      'eliminate-local-single-block',
      'eliminate-local-single-store',
      'eliminate-dead-code-aggressive',
      'eliminate-local-multi-store',
      'eliminate-dead-code-aggressive',
      'ccp',
      'eliminate-dead-code-aggressive',
      'redundancy-elimination',
      'combine-access-chains',
      'simplify-instructions',
      'vector-dce',
      'eliminate-dead-inserts',
      'eliminate-dead-branches',
      'simplify-instructions',
      'if-conversion',
      'copy-propagate-arrays',
      'reduce-load-size',
      'eliminate-dead-code-aggressive',
      'merge-blocks',
      'redundancy-elimination',
      'eliminate-dead-branches',
      'merge-blocks',
      'simplify-instructions',
  ]
  shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm')
  output = placeholder.TempFileName('output.spv')
  spirv_args = [shader, '-o', output, '--print-all'] + flags
  # NOTE(review): parenthesized placeholder, not a tuple.
  expected_object_filenames = (output)


@inside_spirv_testsuite('SpirvOptFlags')
class TestSizeOptimizationPasses(expect.ValidObjectFile1_3,
                                 expect.ExecutedListOfPasses):
  """Tests that spirv-opt schedules all the passes triggered by -Os."""

  flags = ['-Os']
  # Expected pass schedule for -Os, in execution order.
  expected_passes = [
      'eliminate-dead-branches',
      'merge-return',
      'inline-entry-points-exhaustive',
      'eliminate-dead-code-aggressive',
      'private-to-local',
      'scalar-replacement=100',
      'convert-local-access-chains',
      'eliminate-local-single-block',
      'eliminate-local-single-store',
      'eliminate-dead-code-aggressive',
      'simplify-instructions',
      'eliminate-dead-inserts',
      'eliminate-local-multi-store',
      'eliminate-dead-code-aggressive',
      'ccp',
      'eliminate-dead-code-aggressive',
      'eliminate-dead-branches',
      'if-conversion',
      'eliminate-dead-code-aggressive',
      'merge-blocks',
      'simplify-instructions',
      'eliminate-dead-inserts',
      'redundancy-elimination',
      'cfg-cleanup',
      'eliminate-dead-code-aggressive',
  ]
  shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm')
  output = placeholder.TempFileName('output.spv')
  spirv_args = [shader, '-o', output, '--print-all'] + flags
  # NOTE(review): parenthesized placeholder, not a tuple.
  expected_object_filenames = (output)


@inside_spirv_testsuite('SpirvOptFlags')
class TestLegalizationPasses(expect.ValidObjectFile1_3,
                             expect.ExecutedListOfPasses):
  """Tests that spirv-opt schedules all the passes triggered by --legalize-hlsl.
  """

  flags = ['--legalize-hlsl']
  # Expected pass schedule for --legalize-hlsl, in execution order.
  expected_passes = [
      'eliminate-dead-branches',
      'merge-return',
      'inline-entry-points-exhaustive',
      'eliminate-dead-functions',
      'private-to-local',
      'eliminate-local-single-block',
      'eliminate-local-single-store',
      'eliminate-dead-code-aggressive',
      'scalar-replacement=0',
      'eliminate-local-single-block',
      'eliminate-local-single-store',
      'eliminate-dead-code-aggressive',
      'eliminate-local-multi-store',
      'eliminate-dead-code-aggressive',
      'ccp',
      'loop-unroll',
      'eliminate-dead-branches',
      'simplify-instructions',
      'eliminate-dead-code-aggressive',
      'copy-propagate-arrays',
      'vector-dce',
      'eliminate-dead-inserts',
      'reduce-load-size',
      'eliminate-dead-code-aggressive',
  ]
  shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm')
  output = placeholder.TempFileName('output.spv')
  spirv_args = [shader, '-o', output, '--print-all'] + flags
  # NOTE(review): parenthesized placeholder, not a tuple.
  expected_object_filenames = (output)


# The remaining classes each pass one malformed flag value and expect the
# given substring in the tool's error output.


@inside_spirv_testsuite('SpirvOptFlags')
class TestScalarReplacementArgsNegative(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --scalar-replacement."""

  spirv_args = ['--scalar-replacement=-10']
  expected_error_substr = 'must have no arguments or a non-negative integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestScalarReplacementArgsInvalidNumber(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --scalar-replacement."""

  spirv_args = ['--scalar-replacement=a10f']
  expected_error_substr = 'must have no arguments or a non-negative integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestLoopFissionArgsNegative(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --loop-fission."""

  spirv_args = ['--loop-fission=-10']
  expected_error_substr = 'must have a positive integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestLoopFissionArgsInvalidNumber(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --loop-fission."""

  spirv_args = ['--loop-fission=a10f']
  expected_error_substr = 'must have a positive integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestLoopFusionArgsNegative(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --loop-fusion."""

  spirv_args = ['--loop-fusion=-10']
  expected_error_substr = 'must have a positive integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestLoopFusionArgsInvalidNumber(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --loop-fusion."""

  spirv_args = ['--loop-fusion=a10f']
  expected_error_substr = 'must have a positive integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestLoopUnrollPartialArgsNegative(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --loop-unroll-partial."""

  spirv_args = ['--loop-unroll-partial=-10']
  expected_error_substr = 'must have a positive integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestLoopUnrollPartialArgsInvalidNumber(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --loop-unroll-partial."""

  spirv_args = ['--loop-unroll-partial=a10f']
  expected_error_substr = 'must have a positive integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestLoopPeelingThresholdArgsNegative(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --loop-peeling-threshold."""

  spirv_args = ['--loop-peeling-threshold=-10']
  expected_error_substr = 'must have a positive integer argument'


@inside_spirv_testsuite('SpirvOptFlags')
class TestLoopPeelingThresholdArgsInvalidNumber(expect.ErrorMessageSubstr):
  """Tests invalid arguments to --loop-peeling-threshold."""

  spirv_args = ['--loop-peeling-threshold=a10f']
  expected_error_substr = 'must have a positive integer argument'
def empty_main_assembly():
  """Returns SPIR-V assembly text for a vertex shader with an empty main()."""
  return """
         OpCapability Shader
         OpMemoryModel Logical GLSL450
         OpEntryPoint Vertex %4 "main"
         OpName %4 "main"
    %2 = OpTypeVoid
    %3 = OpTypeFunction %2
    %4 = OpFunction %2 None %3
    %5 = OpLabel
         OpReturn
         OpFunctionEnd"""


@inside_spirv_testsuite('SpirvOptConfigFile')
class TestOconfigEmpty(expect.SuccessfulReturn):
  """Tests empty config files are accepted."""

  shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm')
  config = placeholder.ConfigFlagsFile('', '.cfg')
  spirv_args = [shader, '-o', placeholder.TempFileName('output.spv'), config]


# This class used to share the name TestOconfigComments with the class below.
# The second definition rebound the module attribute, so the test runner
# (which discovers classes through inspect.getmembers) never saw this one.
@inside_spirv_testsuite('SpirvOptConfigFile')
class TestOconfigCommentsLoopUnroll(expect.SuccessfulReturn):
  """Tests config files with a comment line, -O and --loop-unroll.

  https://github.com/KhronosGroup/SPIRV-Tools/issues/1778
  """

  shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm')
  config = placeholder.ConfigFlagsFile("""
# This is a comment.
-O
--loop-unroll
""", '.cfg')
  spirv_args = [shader, '-o', placeholder.TempFileName('output.spv'), config]


@inside_spirv_testsuite('SpirvOptConfigFile')
class TestOconfigComments(expect.SuccessfulReturn):
  """Tests config files with a comment line, -O and --relax-struct-store.

  https://github.com/KhronosGroup/SPIRV-Tools/issues/1778
  """

  shader = placeholder.FileSPIRVShader(empty_main_assembly(), '.spvasm')
  config = placeholder.ConfigFlagsFile("""
# This is a comment.
-O
--relax-struct-store
""", '.cfg')
  spirv_args = [shader, '-o', placeholder.TempFileName('output.spv'), config]
class PlaceHolderException(Exception):
  """Exception class for PlaceHolder."""
  pass


class PlaceHolder(object):
  """Base class for placeholders."""

  def instantiate_for_spirv_args(self, testcase):
    """Instantiation rules for spirv_args.

    This method will be called when the current placeholder appears in
    spirv_args.

    Returns:
      A string to replace the current placeholder in spirv_args.
    """
    raise PlaceHolderException('Subclass should implement this function.')

  def instantiate_for_expectation(self, testcase):
    """Instantiation rules for expected_*.

    This method will be called when the current placeholder appears in
    expected_*.

    Returns:
      A string to replace the current placeholder in expected_*.
    """
    raise PlaceHolderException('Subclass should implement this function.')


def _write_temp_file(directory, suffix, text):
  """Writes text to a new temporary file in directory and returns its path."""
  fd, path = tempfile.mkstemp(dir=directory, suffix=suffix)
  # The context manager closes the descriptor even if write() raises.
  with os.fdopen(fd, 'w') as handle:
    handle.write(text)
  return path


class FileShader(PlaceHolder):
  """Stands for a shader whose source code is in a file."""

  def __init__(self, source, suffix, assembly_substr=None):
    assert isinstance(source, str)
    assert isinstance(suffix, str)
    self.source = source
    self.suffix = suffix
    self.filename = None
    # If provided, this is a substring which is expected to be in
    # the disassembly of the module generated from this input file.
    self.assembly_substr = assembly_substr

  def instantiate_for_spirv_args(self, testcase):
    """Creates a temporary file and writes the source into it.

    Returns:
      The name of the temporary file.
    """
    self.filename = _write_temp_file(testcase.directory, self.suffix,
                                     self.source)
    return self.filename

  def instantiate_for_expectation(self, testcase):
    """Returns the file name chosen by instantiate_for_spirv_args()."""
    assert self.filename is not None
    return self.filename


class ConfigFlagsFile(PlaceHolder):
  """Stands for a configuration file for spirv-opt generated out of a string."""

  def __init__(self, content, suffix):
    assert isinstance(content, str)
    assert isinstance(suffix, str)
    self.content = content
    self.suffix = suffix
    self.filename = None

  def instantiate_for_spirv_args(self, testcase):
    """Creates a temporary file and writes content into it.

    Returns:
      The -Oconfig=<file> argument naming the temporary file.
    """
    self.filename = _write_temp_file(testcase.directory, self.suffix,
                                     self.content)
    return '-Oconfig=%s' % self.filename

  def instantiate_for_expectation(self, testcase):
    """Returns the file name chosen by instantiate_for_spirv_args()."""
    assert self.filename is not None
    return self.filename


class FileSPIRVShader(PlaceHolder):
  """Stands for a source shader file which must be converted to SPIR-V."""

  def __init__(self, source, suffix, assembly_substr=None):
    assert isinstance(source, str)
    assert isinstance(suffix, str)
    self.source = source
    self.suffix = suffix
    self.filename = None
    # If provided, this is a substring which is expected to be in
    # the disassembly of the module generated from this input file.
    self.assembly_substr = assembly_substr

  def instantiate_for_spirv_args(self, testcase):
    """Creates a temporary file, writes the source into it and assembles it.

    Returns:
      The name of the assembled temporary file.

    Raises:
      PlaceHolderException: if the assembler exits with a non-zero code or
        prints anything on stdout or stderr.
    """
    asm_filename = _write_temp_file(testcase.directory, self.suffix,
                                    self.source)
    self.filename = '%s.spv' % asm_filename
    cmd = [
        testcase.test_manager.assembler_path, asm_filename, '-o', self.filename
    ]
    process = subprocess.Popen(
        args=cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=testcase.directory)
    stdout, stderr = process.communicate()
    # An explicit raise (rather than assert) keeps the check active under
    # "python -O" and reports what the assembler actually said.
    if process.returncode != 0 or stdout or stderr:
      raise PlaceHolderException(
          'Assembling %s failed with exit code %s\nSTDOUT:\n%s\nSTDERR:\n%s' %
          (asm_filename, process.returncode, stdout, stderr))
    return self.filename

  def instantiate_for_expectation(self, testcase):
    """Returns the file name chosen by instantiate_for_spirv_args()."""
    assert self.filename is not None
    return self.filename


class StdinShader(PlaceHolder):
  """Stands for a shader whose source code is from stdin."""

  def __init__(self, source):
    assert isinstance(source, str)
    self.source = source
    self.filename = None

  def instantiate_for_spirv_args(self, testcase):
    """Writes the source code back to the TestCase instance."""
    testcase.stdin_shader = self.source
    self.filename = '-'
    return self.filename

  def instantiate_for_expectation(self, testcase):
    """Returns '-' once instantiate_for_spirv_args() has run."""
    assert self.filename is not None
    return self.filename


class TempFileName(PlaceHolder):
  """Stands for a temporary file's name."""

  def __init__(self, filename):
    assert isinstance(filename, str)
    assert filename != ''
    self.filename = filename

  def instantiate_for_spirv_args(self, testcase):
    """Returns filename joined onto the test case's directory."""
    return os.path.join(testcase.directory, self.filename)

  def instantiate_for_expectation(self, testcase):
    """Returns filename joined onto the test case's directory."""
    return os.path.join(testcase.directory, self.filename)


class SpecializedString(PlaceHolder):
  """Returns a string that has been specialized based on TestCase.

  The string is specialized by expanding it as a string.Template
  with all of the specialization being done with each $param replaced
  by the associated member on TestCase.
  """

  def __init__(self, filename):
    assert isinstance(filename, str)
    assert filename != ''
    # Despite its name this holds the template text to be expanded.
    self.filename = filename

  def instantiate_for_spirv_args(self, testcase):
    """Expands the template using the test case's instance attributes."""
    return Template(self.filename).substitute(vars(testcase))

  def instantiate_for_expectation(self, testcase):
    """Expands the template using the test case's instance attributes."""
    return Template(self.filename).substitute(vars(testcase))
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Manages and runs tests from the current working directory. + +This will traverse the current working directory and look for python files that +contain subclasses of SpirvTest. + +If a class has an @inside_spirv_testsuite decorator, an instance of that +class will be created and serve as a test case in that testsuite. The test +case is then run by the following steps: + + 1. A temporary directory will be created. + 2. The spirv_args member variable will be inspected and all placeholders in it + will be expanded by calling instantiate_for_spirv_args() on placeholders. + The transformed list elements are then supplied as arguments to the spirv-* + tool under test. + 3. If the environment member variable exists, its write() method will be + invoked. + 4. All expected_* member variables will be inspected and all placeholders in + them will be expanded by calling instantiate_for_expectation() on those + placeholders. After placeholder expansion, if the expected_* variable is + a list, its element will be joined together with '' to form a single + string. These expected_* variables are to be used by the check_*() methods. + 5. The spirv-* tool will be run with the arguments supplied in spirv_args. + 6. All check_*() member methods will be called by supplying a TestStatus as + argument. Each check_*() method is expected to return a (Success, Message) + pair where Success is a boolean indicating success and Message is an error + message. + 7. If any check_*() method fails, the error message is output and the + current test case fails. + +If --leave-output was not specified, all temporary files and directories will +be deleted. 
EXPECTED_BEHAVIOR_PREFIX = 'expected_'
VALIDATE_METHOD_PREFIX = 'check_'


def get_all_variables(instance):
  """Lists the names of every non-callable attribute of instance."""
  names = []
  for attribute in dir(instance):
    if not callable(getattr(instance, attribute)):
      names.append(attribute)
  return names


def get_all_methods(instance):
  """Lists the names of every callable attribute of instance."""
  names = []
  for attribute in dir(instance):
    if callable(getattr(instance, attribute)):
      names.append(attribute)
  return names


def get_all_superclasses(cls):
  """Collects every ancestor of cls, bases before the classes derived from them.

  Returns:
    A duplicate-free list of ancestors. The ancestors of each direct base come
    first (visiting the direct bases in declaration order), followed by the
    direct bases themselves in declaration order.
  """
  ordered = []

  def remember(candidate):
    if candidate not in ordered:
      ordered.append(candidate)

  for base in cls.__bases__:
    for ancestor in get_all_superclasses(base):
      remember(ancestor)
  for base in cls.__bases__:
    remember(base)
  return ordered


def get_all_test_methods(test_class):
  """Collects the names of all check_*() methods reachable from test_class.

  Returns:
    A duplicate-free list of method names. Classes are visited in the order
    given by get_all_superclasses() with test_class last, and a name is kept
    at the position where it is first seen.
  """
  lineage = get_all_superclasses(test_class) + [test_class]
  ordered = []
  for klass in lineage:
    for method in get_all_methods(klass):
      if not method.startswith(VALIDATE_METHOD_PREFIX):
        continue
      if method not in ordered:
        ordered.append(method)
  return ordered


class SpirvTest:
  """Base class for spirv test cases.

  A subclass describes one test: spirv_args holds the command-line arguments
  for the tool, and each check_*() method validates the outcome. Every
  check_*() method receives a TestStatus and returns a (Success, Message)
  pair; the test passes only when all of them report success. The checks are
  commonly inherited from mixin classes.
  """

  def name(self):
    return type(self).__name__


class TestStatus:
  """A struct for holding run status of a test case."""

  def __init__(self, test_manager, returncode, stdout, stderr, directory,
               inputs, input_filenames):
    self.test_manager = test_manager
    # Outcome of the tool invocation.
    self.returncode = returncode
    self.stdout = stdout
    self.stderr = stderr
    # Temporary directory where the test ran.
    self.directory = directory
    # Inputs as PlaceHolder objects, and the names of the input shader files
    # (potentially including paths).
    self.inputs = inputs
    self.input_filenames = input_filenames


class SpirvTestException(Exception):
  """SpirvTest exception class."""
  pass


def inside_spirv_testsuite(testsuite_name):
  """Class decorator that validates a test case and files it under a suite.

  The decorated object must be a class derived from SpirvTest, must define
  spirv_args as a list and must provide at least one check_*() method.
  On success the class gains a parent_testsuite attribute set to
  testsuite_name and is returned unchanged otherwise.

  Raises:
    SpirvTestException: when one of the requirements above is not met.
  """

  def actual_decorator(cls):
    if not inspect.isclass(cls):
      raise SpirvTestException('Test case should be a class')
    if not issubclass(cls, SpirvTest):
      raise SpirvTestException(
          'All test cases should be subclasses of SpirvTest')

    variables = get_all_variables(cls)
    if 'spirv_args' not in variables:
      raise SpirvTestException('No spirv_args found in the test case')
    if not isinstance(cls.spirv_args, list):
      raise SpirvTestException('spirv_args needs to be a list')

    has_check = any(
        method.startswith(VALIDATE_METHOD_PREFIX)
        for method in get_all_methods(cls))
    if not has_check:
      raise SpirvTestException('No check_*() methods found in the test case')

    # NOTE(review): get_all_variables() yields attribute *names*, which are
    # always str, so this condition can never fail. It is kept as written;
    # checking the attribute values instead would change which test classes
    # are accepted.
    if not all(isinstance(name, (bool, str, list)) for name in variables):
      raise SpirvTestException(
          'expected_* variables are only allowed to be bool, str, or '
          'list type.')

    cls.parent_testsuite = testsuite_name
    return cls

  return actual_decorator


class TestManager:
  """Manages and runs a set of tests."""

  def __init__(self, executable_path, assembler_path, disassembler_path):
    # Paths of the tool under test and of the helper tools.
    self.executable_path = executable_path
    self.assembler_path = assembler_path
    self.disassembler_path = disassembler_path
    # Bookkeeping.
    self.num_tests = 0
    self.num_successes = 0
    self.num_failures = 0
    self.leave_output = False
    # Maps a suite name to the list of its test cases.
    self.tests = defaultdict(list)

  def notify_result(self, test_case, success, message):
    """Records the outcome of one test case and prints a progress line."""
    if success:
      self.num_successes += 1
      verdict = 'Passed'
    else:
      self.num_failures += 1
      verdict = '-Failed-'
    finished = self.num_successes + self.num_failures
    counter_string = '%s/%s' % (str(finished), str(self.num_tests))
    print('%-10s %-40s ' % (counter_string, test_case.test.name()) + verdict)
    if not success:
      # Show how the tool was invoked together with the diagnostics.
      print(' '.join(test_case.command))
      print(message)

  def add_test(self, testsuite, test):
    """Wraps test in a TestCase and registers it under testsuite."""
    self.tests[testsuite].append(TestCase(test, self))
    self.num_tests += 1

  def run_tests(self):
    """Runs every registered test case, suite by suite."""
    for suite, cases in self.tests.items():
      print('SPIRV tool test suite: "{suite}"'.format(suite=suite))
      for case in cases:
        case.runTest()
class TestCase:
  """A single test case that runs in its own directory."""

  def __init__(self, test, test_manager):
    self.test = test
    self.test_manager = test_manager
    self.inputs = []  # inputs, as PlaceHolder objects.
    self.file_shaders = []  # filenames of shader files.
    self.stdin_shader = None  # text to be passed to spirv_tool as stdin
    # Set by setUp(). It stays None until the temporary directory exists so
    # that tearDown() is safe to call after a failed setUp().
    self.directory = None
    # Set by runTest(). The manager prints it when a test fails.
    self.command = []

  def setUp(self):
    """Creates environment and instantiates placeholders for the test case."""

    self.directory = tempfile.mkdtemp(dir=os.getcwd())
    # Keep the unexpanded arguments: the placeholders among them are the
    # inputs of the test.
    spirv_args = self.test.spirv_args
    # Instantiate placeholders in spirv_args
    self.test.spirv_args = [
        arg.instantiate_for_spirv_args(self)
        if isinstance(arg, PlaceHolder) else arg for arg in self.test.spirv_args
    ]
    # Get all shader files' names
    self.inputs = [arg for arg in spirv_args if isinstance(arg, PlaceHolder)]
    self.file_shaders = [arg.filename for arg in self.inputs]

    if 'environment' in get_all_variables(self.test):
      self.test.environment.write(self.directory)

    expectations = [
        v for v in get_all_variables(self.test)
        if v.startswith(EXPECTED_BEHAVIOR_PREFIX)
    ]
    # Instantiate placeholders in expectations
    for expectation_name in expectations:
      expectation = getattr(self.test, expectation_name)
      if isinstance(expectation, list):
        expanded_expections = [
            element.instantiate_for_expectation(self)
            if isinstance(element, PlaceHolder) else element
            for element in expectation
        ]
        setattr(self.test, expectation_name, expanded_expections)
      elif isinstance(expectation, PlaceHolder):
        setattr(self.test, expectation_name,
                expectation.instantiate_for_expectation(self))

  def tearDown(self):
    """Removes the directory if we were not instructed to do otherwise."""
    if self.directory is not None and not self.test_manager.leave_output:
      shutil.rmtree(self.directory)

  def runTest(self):
    """Sets up and runs a test, reports any failures and then cleans up.

    Any exception raised while preparing, launching or checking the test is
    reported to the manager as a failure of this test case instead of
    aborting the whole run.
    """
    success = False
    message = ''
    # Defaults for the report below; they are replaced once the tool has
    # actually run. Without them a failure before communicate() would raise
    # NameError and hide the original error.
    stdout = ''
    stderr = ''
    try:
      self.setUp()
      self.command = [self.test_manager.executable_path]
      self.command.extend(self.test.spirv_args)

      process = subprocess.Popen(
          args=self.command,
          stdin=subprocess.PIPE,
          stdout=subprocess.PIPE,
          stderr=subprocess.PIPE,
          cwd=self.directory)
      stdout, stderr = process.communicate(self.stdin_shader)
      test_status = TestStatus(self.test_manager, process.returncode, stdout,
                               stderr, self.directory, self.inputs,
                               self.file_shaders)
      run_results = [
          getattr(self.test, test_method)(test_status)
          for test_method in get_all_test_methods(self.test.__class__)
      ]
      # Each check returns (success, message); transpose into two tuples.
      success, message = zip(*run_results)
      success = all(success)
      message = '\n'.join(message)
    except Exception as e:
      success = False
      message = str(e)
    self.test_manager.notify_result(
        self, success,
        message + '\nSTDOUT:\n%s\nSTDERR:\n%s' % (stdout, stderr))
    self.tearDown()


def main():
  """Parses the command line, collects the decorated tests and runs them.

  Exits with status -1 when at least one test failed.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'spirv_tool',
      metavar='path/to/spirv_tool',
      type=str,
      nargs=1,
      help='Path to the spirv-* tool under test')
  parser.add_argument(
      'spirv_as',
      metavar='path/to/spirv-as',
      type=str,
      nargs=1,
      help='Path to spirv-as')
  parser.add_argument(
      'spirv_dis',
      metavar='path/to/spirv-dis',
      type=str,
      nargs=1,
      help='Path to spirv-dis')
  parser.add_argument(
      '--leave-output',
      action='store_const',
      const=1,
      help='Do not clean up temporary directories')
  parser.add_argument(
      '--test-dir', nargs=1, help='Directory to gather the tests from')
  args = parser.parse_args()
  # Take a copy: binding sys.path itself would make the "reset" below a
  # no-op and let every visited directory pile up on the search path.
  default_path = list(sys.path)
  root_dir = os.getcwd()
  if args.test_dir:
    root_dir = args.test_dir[0]
  manager = TestManager(args.spirv_tool[0], args.spirv_as[0], args.spirv_dis[0])
  if args.leave_output:
    manager.leave_output = True
  for root, _, filenames in os.walk(root_dir):
    for filename in fnmatch.filter(filenames, '*.py'):
      if filename.endswith('nosetest.py'):
        # Skip nose tests, which are for testing functions of
        # the test framework.
        continue
      # Search the original path plus the directory holding this module.
      sys.path = default_path + [root]
      mod = __import__(os.path.splitext(filename)[0])
      for _, obj, in inspect.getmembers(mod):
        # Only classes tagged by @inside_spirv_testsuite are test cases.
        if inspect.isclass(obj) and hasattr(obj, 'parent_testsuite'):
          manager.add_test(obj.parent_testsuite, obj())
  manager.run_tests()
  if manager.num_failures > 0:
    sys.exit(-1)


if __name__ == '__main__':
  main()
# Fixture hierarchy for exercising get_all_superclasses() and
# get_all_test_methods(). Every class contributes exactly one `check_*`
# method named after itself, so the expected results below can be read
# straight off the inheritance graph.
class Root:

    def check_root(self):
        pass


# Five direct children of Root.
class A(Root):

    def check_a(self):
        pass


class B(Root):

    def check_b(self):
        pass


class C(Root):

    def check_c(self):
        pass


class D(Root):

    def check_d(self):
        pass


class E(Root):

    def check_e(self):
        pass


# Multiple inheritance from three siblings.
class H(B, C, D):

    def check_h(self):
        pass


# Single inheritance, two levels below Root.
class I(E):

    def check_i(self):
        pass


# Joins the two branches above.
class O(H, I):

    def check_o(self):
        pass


class U(A, O):

    def check_u(self):
        pass


# Lists A both directly and indirectly (through U).
class X(U, A):

    def check_x(self):
        pass


# A second, independent hierarchy with two unrelated roots.
class R1:

    def check_r1(self):
        pass


class R2:

    def check_r2(self):
        pass


class Multi(R1, R2):

    def check_multi(self):
        pass


def nosetest_get_all_superclasses():
    """Tests get_all_superclasses()."""
    # (class under test, expected superclass list), checked in this order.
    cases = [
        (A, [Root]),
        (B, [Root]),
        (C, [Root]),
        (D, [Root]),
        (E, [Root]),
        (H, [Root, B, C, D]),
        (I, [Root, E]),
        (O, [Root, B, C, D, E, H, I]),
        (U, [Root, B, C, D, E, H, I, A, O]),
        (X, [Root, B, C, D, E, H, I, A, O, U]),
        (Multi, [R1, R2]),
    ]
    for cls, expected_superclasses in cases:
        assert_equal(get_all_superclasses(cls), expected_superclasses)


def nosetest_get_all_methods():
    """Tests get_all_test_methods()."""
    # Shared prefix of the expectations for O, U and X.
    up_to_i = [
        'check_root', 'check_b', 'check_c', 'check_d', 'check_e', 'check_h',
        'check_i'
    ]
    # (class under test, expected method-name list), checked in this order.
    cases = [
        (A, ['check_root', 'check_a']),
        (B, ['check_root', 'check_b']),
        (C, ['check_root', 'check_c']),
        (D, ['check_root', 'check_d']),
        (E, ['check_root', 'check_e']),
        (H, ['check_root', 'check_b', 'check_c', 'check_d', 'check_h']),
        (I, ['check_root', 'check_e', 'check_i']),
        (O, up_to_i + ['check_o']),
        (U, up_to_i + ['check_a', 'check_o', 'check_u']),
        (X, up_to_i + ['check_a', 'check_o', 'check_u', 'check_x']),
        (Multi, ['check_r1', 'check_r2', 'check_multi']),
    ]
    for cls, expected_methods in cases:
        assert_equal(get_all_test_methods(cls), expected_methods)
+ +#include "test/unit_spirv.h" + +#include "gmock/gmock.h" +#include "test/test_fixture.h" + +namespace spvtools { +namespace { + +using spvtest::MakeVector; +using ::testing::Eq; +using Words = std::vector; + +TEST(MakeVector, Samples) { + EXPECT_THAT(MakeVector(""), Eq(Words{0})); + EXPECT_THAT(MakeVector("a"), Eq(Words{0x0061})); + EXPECT_THAT(MakeVector("ab"), Eq(Words{0x006261})); + EXPECT_THAT(MakeVector("abc"), Eq(Words{0x00636261})); + EXPECT_THAT(MakeVector("abcd"), Eq(Words{0x64636261, 0x00})); + EXPECT_THAT(MakeVector("abcde"), Eq(Words{0x64636261, 0x0065})); +} + +TEST(WordVectorPrintTo, PreservesFlagsAndFill) { + std::stringstream s; + s << std::setw(4) << std::oct << std::setfill('x') << 8 << " "; + spvtest::PrintTo(spvtest::WordVector({10, 16}), &s); + // The octal setting and fill character should be preserved + // from before the PrintTo. + // Width is reset after each emission of a regular scalar type. + // So set it explicitly again. + s << std::setw(4) << 9; + + EXPECT_THAT(s.str(), Eq("xx10 0x0000000a 0x00000010 xx11")); +} + +TEST_P(RoundTripTest, Sample) { + EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam())) + << GetParam(); +} + +} // namespace +} // namespace spvtools diff --git a/test/unit_spirv.h b/test/unit_spirv.h new file mode 100644 index 000000000..224428884 --- /dev/null +++ b/test/unit_spirv.h @@ -0,0 +1,234 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TEST_UNIT_SPIRV_H_ +#define TEST_UNIT_SPIRV_H_ + +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "source/assembly_grammar.h" +#include "source/binary.h" +#include "source/diagnostic.h" +#include "source/enum_set.h" +#include "source/opcode.h" +#include "source/spirv_endian.h" +#include "source/text.h" +#include "source/text_handler.h" +#include "source/val/validate.h" +#include "spirv-tools/libspirv.h" + +#ifdef __ANDROID__ +#include +namespace std { +template +std::string to_string(const T& val) { + std::ostringstream os; + os << val; + return os.str(); +} +} // namespace std +#endif + +// Determine endianness & predicate tests on it +enum { + I32_ENDIAN_LITTLE = 0x03020100ul, + I32_ENDIAN_BIG = 0x00010203ul, +}; + +static const union { + unsigned char bytes[4]; + uint32_t value; +} o32_host_order = {{0, 1, 2, 3}}; +#define I32_ENDIAN_HOST (o32_host_order.value) + +// A namespace for utilities used in SPIR-V Tools unit tests. +namespace spvtest { + +class WordVector; + +// Emits the given word vector to the given stream. +// This function can be used by the gtest value printer. +void PrintTo(const WordVector& words, ::std::ostream* os); + +// A proxy class to allow us to easily write out vectors of SPIR-V words. +class WordVector { + public: + explicit WordVector(const std::vector& val) : value_(val) {} + explicit WordVector(const spv_binary_t& binary) + : value_(binary.code, binary.code + binary.wordCount) {} + + // Returns the underlying vector. + const std::vector& value() const { return value_; } + + // Returns the string representation of this word vector. 
+ std::string str() const { + std::ostringstream os; + PrintTo(*this, &os); + return os.str(); + } + + private: + const std::vector value_; +}; + +inline void PrintTo(const WordVector& words, ::std::ostream* os) { + size_t count = 0; + const auto saved_flags = os->flags(); + const auto saved_fill = os->fill(); + for (uint32_t value : words.value()) { + *os << "0x" << std::setw(8) << std::setfill('0') << std::hex << value + << " "; + if (count++ % 8 == 7) { + *os << std::endl; + } + } + os->flags(saved_flags); + os->fill(saved_fill); +} + +// Returns a vector of words representing a single instruction with the +// given opcode and operand words as a vector. +inline std::vector MakeInstruction( + SpvOp opcode, const std::vector& args) { + std::vector result{ + spvOpcodeMake(uint16_t(args.size() + 1), opcode)}; + result.insert(result.end(), args.begin(), args.end()); + return result; +} + +// Returns a vector of words representing a single instruction with the +// given opcode and whose operands are the concatenation of the two given +// argument lists. +inline std::vector MakeInstruction( + SpvOp opcode, std::vector args, + const std::vector& extra_args) { + args.insert(args.end(), extra_args.begin(), extra_args.end()); + return MakeInstruction(opcode, args); +} + +// Returns the vector of words representing the concatenation +// of all input vectors. +inline std::vector Concatenate( + const std::vector>& instructions) { + std::vector result; + for (const auto& instruction : instructions) { + result.insert(result.end(), instruction.begin(), instruction.end()); + } + return result; +} + +// Encodes a string as a sequence of words, using the SPIR-V encoding. +inline std::vector MakeVector(std::string input) { + std::vector result; + uint32_t word = 0; + size_t num_bytes = input.size(); + // SPIR-V strings are null-terminated. The byte_index == num_bytes + // case is used to push the terminating null byte. 
+ for (size_t byte_index = 0; byte_index <= num_bytes; byte_index++) { + const auto new_byte = + (byte_index < num_bytes ? uint8_t(input[byte_index]) : uint8_t(0)); + word |= (new_byte << (8 * (byte_index % sizeof(uint32_t)))); + if (3 == (byte_index % sizeof(uint32_t))) { + result.push_back(word); + word = 0; + } + } + // Emit a trailing partial word. + if ((num_bytes + 1) % sizeof(uint32_t)) { + result.push_back(word); + } + return result; +} + +// A type for easily creating spv_text_t values, with an implicit conversion to +// spv_text. +struct AutoText { + explicit AutoText(const std::string& value) + : str(value), text({str.data(), str.size()}) {} + operator spv_text() { return &text; } + std::string str; + spv_text_t text; +}; + +// An example case for an enumerated value, optionally with operands. +template +class EnumCase { + public: + EnumCase() = default; // Required by ::testing::Combine(). + EnumCase(E val, std::string enum_name, std::vector ops = {}) + : enum_value_(val), name_(enum_name), operands_(ops) {} + // Returns the enum value as a uint32_t. + uint32_t value() const { return static_cast(enum_value_); } + // Returns the name of the enumerant. + const std::string& name() const { return name_; } + // Returns a reference to the operands. + const std::vector& operands() const { return operands_; } + + private: + E enum_value_; + std::string name_; + std::vector operands_; +}; + +// Returns a string with num_4_byte_chars Unicode characters, +// each of which has a 4-byte UTF-8 encoding. +inline std::string MakeLongUTF8String(size_t num_4_byte_chars) { + // An example of a longest valid UTF-8 character. + // Be explicit about the character type because Microsoft compilers can + // otherwise interpret the character string as being over wide (16-bit) + // characters. Ideally, we would just use a C++11 UTF-8 string literal, + // but we want to support older Microsoft compilers. 
+ const std::basic_string earth_africa("\xF0\x9F\x8C\x8D"); + EXPECT_EQ(4u, earth_africa.size()); + + std::string result; + result.reserve(num_4_byte_chars * 4); + for (size_t i = 0; i < num_4_byte_chars; i++) { + result += earth_africa; + } + EXPECT_EQ(4 * num_4_byte_chars, result.size()); + return result; +} + +// Returns a vector of all valid target environment enums. +inline std::vector AllTargetEnvironments() { + return { + SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, + SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_1_2, + SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_0, + SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_1, + SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_2, + SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_0, + SPV_ENV_OPENGL_4_1, SPV_ENV_OPENGL_4_2, + SPV_ENV_OPENGL_4_3, SPV_ENV_OPENGL_4_5, + SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3, + SPV_ENV_VULKAN_1_1, SPV_ENV_WEBGPU_0, + }; +} + +// Returns the capabilities in a CapabilitySet as an ordered vector. +inline std::vector ElementsIn( + const spvtools::CapabilitySet& capabilities) { + std::vector result; + capabilities.ForEach([&result](SpvCapability c) { result.push_back(c); }); + return result; +} + +} // namespace spvtest +#endif // TEST_UNIT_SPIRV_H_ diff --git a/test/util/CMakeLists.txt b/test/util/CMakeLists.txt new file mode 100644 index 000000000..8cdb35f42 --- /dev/null +++ b/test/util/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright (c) 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +add_spvtools_unittest(TARGET utils + SRCS ilist_test.cpp + bit_vector_test.cpp + small_vector_test.cpp + LIBS SPIRV-Tools-opt +) diff --git a/test/util/bit_vector_test.cpp b/test/util/bit_vector_test.cpp new file mode 100644 index 000000000..8d967f8f9 --- /dev/null +++ b/test/util/bit_vector_test.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" + +#include "source/util/bit_vector.h" + +namespace spvtools { +namespace utils { +namespace { + +using BitVectorTest = ::testing::Test; + +TEST(BitVectorTest, Initialize) { + BitVector bvec; + + // Checks that all values are 0. Also tests checking a bit past the end of + // the vector containing the bits. + for (int i = 1; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Get(i)); + } +} + +TEST(BitVectorTest, Set) { + BitVector bvec; + + // Since 10,000 is larger than the initial size, this tests the resizing + // code. + for (int i = 3; i < 10000; i *= 2) { + bvec.Set(i); + } + + // Check that bits that were not set are 0. + for (int i = 1; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Get(i)); + } + + // Check that bits that were set are 1. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_TRUE(bvec.Get(i)); + } +} + +TEST(BitVectorTest, SetReturnValue) { + BitVector bvec; + + // Make sure |Set| returns false when the bit was not set. 
+ for (int i = 3; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Set(i)); + } + + // Make sure |Set| returns true when the bit was already set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_TRUE(bvec.Set(i)); + } +} + +TEST(BitVectorTest, Clear) { + BitVector bvec; + for (int i = 3; i < 10000; i *= 2) { + bvec.Set(i); + } + + // Check that the bits were properly set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_TRUE(bvec.Get(i)); + } + + // Clear all of the bits except for bit 3. + for (int i = 6; i < 10000; i *= 2) { + bvec.Clear(i); + } + + // Make sure bit 3 was not cleared. + EXPECT_TRUE(bvec.Get(3)); + + // Make sure all of the other bits that were set have been cleared. + for (int i = 6; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Get(i)); + } +} + +TEST(BitVectorTest, ClearReturnValue) { + BitVector bvec; + for (int i = 3; i < 10000; i *= 2) { + bvec.Set(i); + } + + // Make sure |Clear| returns true if the bit was set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_TRUE(bvec.Clear(i)); + } + + // Make sure |Clear| returns false if the bit was not set. + for (int i = 3; i < 10000; i *= 2) { + EXPECT_FALSE(bvec.Clear(i)); + } +} + +TEST(BitVectorTest, SimpleOrTest) { + BitVector bvec1; + bvec1.Set(3); + bvec1.Set(4); + + BitVector bvec2; + bvec2.Set(2); + bvec2.Set(4); + + // Check that |bvec1| changed when doing the |Or| operation. + EXPECT_TRUE(bvec1.Or(bvec2)); + + // Check that the values are all correct. + EXPECT_FALSE(bvec1.Get(0)); + EXPECT_FALSE(bvec1.Get(1)); + EXPECT_TRUE(bvec1.Get(2)); + EXPECT_TRUE(bvec1.Get(3)); + EXPECT_TRUE(bvec1.Get(4)); +} + +TEST(BitVectorTest, ResizingOrTest) { + BitVector bvec1; + bvec1.Set(3); + bvec1.Set(4); + + BitVector bvec2; + bvec2.Set(10000); + + // Similar to above except with a large value to test resizing. 
+ EXPECT_TRUE(bvec1.Or(bvec2)); + EXPECT_FALSE(bvec1.Get(0)); + EXPECT_FALSE(bvec1.Get(1)); + EXPECT_FALSE(bvec1.Get(2)); + EXPECT_TRUE(bvec1.Get(3)); + EXPECT_TRUE(bvec1.Get(10000)); +} + +TEST(BitVectorTest, SubsetOrTest) { + BitVector bvec1; + bvec1.Set(3); + bvec1.Set(4); + + BitVector bvec2; + bvec2.Set(3); + + // |Or| returns false if |bvec1| does not change. + EXPECT_FALSE(bvec1.Or(bvec2)); +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/test/util/ilist_test.cpp b/test/util/ilist_test.cpp new file mode 100644 index 000000000..4a546f993 --- /dev/null +++ b/test/util/ilist_test.cpp @@ -0,0 +1,325 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "source/util/ilist.h" + +namespace spvtools { +namespace utils { +namespace { + +using ::testing::ElementsAre; +using IListTest = ::testing::Test; + +class TestNode : public IntrusiveNodeBase { + public: + TestNode() : IntrusiveNodeBase() {} + int data_; +}; + +class TestList : public IntrusiveList { + public: + TestList() = default; + TestList(TestList&& that) : IntrusiveList(std::move(that)) {} + TestList& operator=(TestList&& that) { + static_cast&>(*this) = + static_cast&&>(that); + return *this; + } +}; + +// This test checks the push_back method, as well as using an iterator to +// traverse the list from begin() to end(). 
This implicitly test the +// PreviousNode and NextNode functions. +TEST(IListTest, PushBack) { + TestNode nodes[10]; + TestList list; + for (int i = 0; i < 10; i++) { + nodes[i].data_ = i; + list.push_back(&nodes[i]); + } + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)); +} + +// Returns a list containing the values 0 to n-1 using the first n elements of +// nodes to build the list. +TestList BuildList(TestNode nodes[], int n) { + TestList list; + for (int i = 0; i < n; i++) { + nodes[i].data_ = i; + list.push_back(&nodes[i]); + } + return list; +} + +// Test decrementing begin() +TEST(IListTest, DecrementingBegin) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 10); + EXPECT_EQ(--list.begin(), list.end()); +} + +// Test incrementing end() +TEST(IListTest, IncrementingEnd1) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 10); + EXPECT_EQ((++list.end())->data_, 0); +} + +// Test incrementing end() should equal begin() +TEST(IListTest, IncrementingEnd2) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 10); + EXPECT_EQ(++list.end(), list.begin()); +} + +// Test decrementing end() +TEST(IListTest, DecrementingEnd) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 10); + EXPECT_EQ((--list.end())->data_, 9); +} + +// Test the move constructor for the list class. +TEST(IListTest, MoveConstructor) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 10); + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)); +} + +// Using a const list so we can test the const_iterator. 
+TEST(IListTest, ConstIterator) { + TestNode nodes[10]; + const TestList list = BuildList(nodes, 10); + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)); +} + +// Uses the move assignement instead of the move constructor. +TEST(IListTest, MoveAssignment) { + TestNode nodes[10]; + TestList list; + list = BuildList(nodes, 10); + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)); +} + +// Test inserting a new element at the end of a list using the IntrusiveNodeBase +// "InsertAfter" function. +TEST(IListTest, InsertAfter1) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 5); + + nodes[5].data_ = 5; + nodes[5].InsertAfter(&nodes[4]); + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5)); +} + +// Test inserting a new element in the middle of a list using the +// IntrusiveNodeBase "InsertAfter" function. +TEST(IListTest, InsertAfter2) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 5); + + nodes[5].data_ = 5; + nodes[5].InsertAfter(&nodes[2]); + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 5, 3, 4)); +} + +// Test moving an element already in the list in the middle of a list using the +// IntrusiveNodeBase "InsertAfter" function. +TEST(IListTest, MoveUsingInsertAfter1) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 6); + + nodes[5].InsertAfter(&nodes[2]); + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 5, 3, 4)); +} + +// Move the element at the start of the list into the middle. 
+TEST(IListTest, MoveUsingInsertAfter2) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 6); + + nodes[0].InsertAfter(&nodes[2]); + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(1, 2, 0, 3, 4, 5)); +} + +// Move an element in the middle of the list to the end. +TEST(IListTest, MoveUsingInsertAfter3) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 6); + + nodes[2].InsertAfter(&nodes[5]); + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 3, 4, 5, 2)); +} + +// Removing an element from the middle of a list. +TEST(IListTest, Remove1) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 6); + + nodes[2].RemoveFromList(); + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 3, 4, 5)); +} + +// Removing an element from the beginning of the list. +TEST(IListTest, Remove2) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 6); + + nodes[0].RemoveFromList(); + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(1, 2, 3, 4, 5)); +} + +// Removing the last element of a list. +TEST(IListTest, Remove3) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 6); + + nodes[5].RemoveFromList(); + + std::vector output; + for (auto& i : list) output.push_back(i.data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4)); +} + +// Test that operator== and operator!= work properly for the iterator class. +TEST(IListTest, IteratorEqual) { + TestNode nodes[10]; + TestList list = BuildList(nodes, 6); + + std::vector output; + for (auto i = list.begin(); i != list.end(); ++i) + for (auto j = list.begin(); j != list.end(); ++j) + if (i == j) output.push_back(i->data_); + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5)); +} + +// Test MoveBefore. Moving into middle of a list. 
+TEST(IListTest, MoveBefore1) { + TestNode nodes[10]; + TestList list1 = BuildList(nodes, 6); + TestList list2 = BuildList(nodes + 6, 3); + + TestList::iterator insertion_point = list1.begin(); + ++insertion_point; + insertion_point.MoveBefore(&list2); + + std::vector output; + for (auto i = list1.begin(); i != list1.end(); ++i) { + output.push_back(i->data_); + } + + EXPECT_THAT(output, ElementsAre(0, 0, 1, 2, 1, 2, 3, 4, 5)); +} + +// Test MoveBefore. Moving to the start of a list. +TEST(IListTest, MoveBefore2) { + TestNode nodes[10]; + TestList list1 = BuildList(nodes, 6); + TestList list2 = BuildList(nodes + 6, 3); + + TestList::iterator insertion_point = list1.begin(); + insertion_point.MoveBefore(&list2); + + std::vector output; + for (auto i = list1.begin(); i != list1.end(); ++i) { + output.push_back(i->data_); + } + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 0, 1, 2, 3, 4, 5)); +} + +// Test MoveBefore. Moving to the end of a list. +TEST(IListTest, MoveBefore3) { + TestNode nodes[10]; + TestList list1 = BuildList(nodes, 6); + TestList list2 = BuildList(nodes + 6, 3); + + TestList::iterator insertion_point = list1.end(); + insertion_point.MoveBefore(&list2); + + std::vector output; + for (auto i = list1.begin(); i != list1.end(); ++i) { + output.push_back(i->data_); + } + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5, 0, 1, 2)); +} + +// Test MoveBefore. Moving an empty list. 
+TEST(IListTest, MoveBefore4) { + TestNode nodes[10]; + TestList list1 = BuildList(nodes, 6); + TestList list2; + + TestList::iterator insertion_point = list1.end(); + insertion_point.MoveBefore(&list2); + + std::vector output; + for (auto i = list1.begin(); i != list1.end(); ++i) { + output.push_back(i->data_); + } + + EXPECT_THAT(output, ElementsAre(0, 1, 2, 3, 4, 5)); +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/test/util/small_vector_test.cpp b/test/util/small_vector_test.cpp new file mode 100644 index 000000000..01d7df185 --- /dev/null +++ b/test/util/small_vector_test.cpp @@ -0,0 +1,598 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "source/util/small_vector.h" + +namespace spvtools { +namespace utils { +namespace { + +using SmallVectorTest = ::testing::Test; + +TEST(SmallVectorTest, Initialize_default) { + SmallVector vec; + + EXPECT_TRUE(vec.empty()); + EXPECT_EQ(vec.size(), 0); + EXPECT_EQ(vec.begin(), vec.end()); +} + +TEST(SmallVectorTest, Initialize_list1) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_FALSE(vec.empty()); + EXPECT_EQ(vec.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec.size(); ++i) { + EXPECT_EQ(vec[i], result[i]); + } +} + +TEST(SmallVectorTest, Initialize_list2) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_FALSE(vec.empty()); + EXPECT_EQ(vec.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec.size(); ++i) { + EXPECT_EQ(vec[i], result[i]); + } +} + +TEST(SmallVectorTest, Initialize_copy1) { + SmallVector vec1 = {0, 1, 2, 3}; + SmallVector vec2(vec1); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + + EXPECT_EQ(vec1, vec2); +} + +TEST(SmallVectorTest, Initialize_copy2) { + SmallVector vec1 = {0, 1, 2, 3}; + SmallVector vec2(vec1); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + + EXPECT_EQ(vec1, vec2); +} + +TEST(SmallVectorTest, Initialize_copy_vec1) { + std::vector vec1 = {0, 1, 2, 3}; + SmallVector vec2(vec1); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + + EXPECT_EQ(vec1, vec2); +} + +TEST(SmallVectorTest, Initialize_copy_vec2) { + std::vector vec1 = {0, 1, 2, 3}; + SmallVector vec2(vec1); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], 
result[i]); + } + + EXPECT_EQ(vec1, vec2); +} + +TEST(SmallVectorTest, Initialize_move1) { + SmallVector vec1 = {0, 1, 2, 3}; + SmallVector vec2(std::move(vec1)); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + EXPECT_TRUE(vec1.empty()); +} + +TEST(SmallVectorTest, Initialize_move2) { + SmallVector vec1 = {0, 1, 2, 3}; + SmallVector vec2(std::move(vec1)); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + EXPECT_TRUE(vec1.empty()); +} + +TEST(SmallVectorTest, Initialize_move_vec1) { + std::vector vec1 = {0, 1, 2, 3}; + SmallVector vec2(std::move(vec1)); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + EXPECT_TRUE(vec1.empty()); +} + +TEST(SmallVectorTest, Initialize_move_vec2) { + std::vector vec1 = {0, 1, 2, 3}; + SmallVector vec2(std::move(vec1)); + + EXPECT_EQ(vec2.size(), 4); + + uint32_t result[] = {0, 1, 2, 3}; + for (uint32_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(vec2[i], result[i]); + } + EXPECT_TRUE(vec1.empty()); +} + +TEST(SmallVectorTest, Initialize_iterators1) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + uint32_t result[] = {0, 1, 2, 3}; + + uint32_t i = 0; + for (uint32_t p : vec) { + EXPECT_EQ(p, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators2) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + uint32_t result[] = {0, 1, 2, 3}; + + uint32_t i = 0; + for (uint32_t p : vec) { + EXPECT_EQ(p, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators3) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + uint32_t result[] = {0, 1, 2, 3}; + + uint32_t i = 0; + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + EXPECT_EQ(*it, 
result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators4) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + uint32_t result[] = {0, 1, 2, 3}; + + uint32_t i = 0; + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + EXPECT_EQ(*it, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators_write1) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + *it *= 2; + } + + uint32_t result[] = {0, 2, 4, 6}; + + uint32_t i = 0; + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + EXPECT_EQ(*it, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_iterators_write2) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 4); + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + *it *= 2; + } + + uint32_t result[] = {0, 2, 4, 6}; + + uint32_t i = 0; + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + EXPECT_EQ(*it, result[i]); + i++; + } +} + +TEST(SmallVectorTest, Initialize_front) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.front(), 0); + for (SmallVector::iterator it = vec.begin(); it != vec.end(); + ++it) { + *it += 2; + } + EXPECT_EQ(vec.front(), 2); +} + +TEST(SmallVectorTest, Erase_element_front1) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.front(), 0); + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin()); + EXPECT_EQ(vec.front(), 1); + EXPECT_EQ(vec.size(), 3); +} + +TEST(SmallVectorTest, Erase_element_front2) { + SmallVector vec = {0, 1, 2, 3}; + + EXPECT_EQ(vec.front(), 0); + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin()); + EXPECT_EQ(vec.front(), 1); + EXPECT_EQ(vec.size(), 3); +} + +TEST(SmallVectorTest, Erase_element_back1) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec[3], 3); + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 3); + EXPECT_EQ(vec.size(), 3); 
+ EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_element_back2) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec[3], 3); + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 3); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_element_middle1) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 1, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 2); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_element_middle2) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 1, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 2); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_1) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin(), vec.end()); + EXPECT_EQ(vec.size(), 0); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_2) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin(), vec.end()); + EXPECT_EQ(vec.size(), 0); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_3) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {2, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin(), vec.begin() + 2); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_4) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {2, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin(), vec.begin() + 2); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_5) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {0, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 1, vec.begin() + 3); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Erase_range_6) { + SmallVector vec = {0, 1, 2, 
3}; + SmallVector result = {0, 3}; + + EXPECT_EQ(vec.size(), 4); + vec.erase(vec.begin() + 1, vec.begin() + 3); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Push_back) { + SmallVector vec; + SmallVector result = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 0); + vec.push_back(0); + EXPECT_EQ(vec.size(), 1); + vec.push_back(1); + EXPECT_EQ(vec.size(), 2); + vec.push_back(2); + EXPECT_EQ(vec.size(), 3); + vec.push_back(3); + EXPECT_EQ(vec.size(), 4); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Emplace_back) { + SmallVector vec; + SmallVector result = {0, 1, 2, 3}; + + EXPECT_EQ(vec.size(), 0); + vec.emplace_back(0); + EXPECT_EQ(vec.size(), 1); + vec.emplace_back(1); + EXPECT_EQ(vec.size(), 2); + vec.emplace_back(2); + EXPECT_EQ(vec.size(), 3); + vec.emplace_back(3); + EXPECT_EQ(vec.size(), 4); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Clear) { + SmallVector vec = {0, 1, 2, 3}; + SmallVector result = {}; + + EXPECT_EQ(vec.size(), 4); + vec.clear(); + EXPECT_EQ(vec.size(), 0); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Insert1) { + SmallVector vec = {}; + SmallVector insert_values = {10, 11}; + SmallVector result = {10, 11}; + + EXPECT_EQ(vec.size(), 0); + auto ret = + vec.insert(vec.begin(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 2); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert2) { + SmallVector vec = {}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {10, 11, 12}; + + EXPECT_EQ(vec.size(), 0); + auto ret = + vec.insert(vec.begin(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert3) { + SmallVector vec = {0}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {10, 11, 12, 0}; + + EXPECT_EQ(vec.size(), 1); + auto ret = + vec.insert(vec.begin(), insert_values.begin(), insert_values.end()); + 
EXPECT_EQ(vec.size(), 4); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert4) { + SmallVector vec = {0}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {10, 11, 12, 0}; + + EXPECT_EQ(vec.size(), 1); + auto ret = + vec.insert(vec.begin(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 4); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert5) { + SmallVector vec = {0, 1, 2}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {0, 1, 2, 10, 11, 12}; + + EXPECT_EQ(vec.size(), 3); + auto ret = vec.insert(vec.end(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert6) { + SmallVector vec = {0, 1, 2}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {0, 1, 2, 10, 11, 12}; + + EXPECT_EQ(vec.size(), 3); + auto ret = vec.insert(vec.end(), insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert7) { + SmallVector vec = {0, 1, 2}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {0, 10, 11, 12, 1, 2}; + + EXPECT_EQ(vec.size(), 3); + auto ret = + vec.insert(vec.begin() + 1, insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Insert8) { + SmallVector vec = {0, 1, 2}; + SmallVector insert_values = {10, 11, 12}; + SmallVector result = {0, 10, 11, 12, 1, 2}; + + EXPECT_EQ(vec.size(), 3); + auto ret = + vec.insert(vec.begin() + 1, insert_values.begin(), insert_values.end()); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); + EXPECT_EQ(*ret, 10); +} + +TEST(SmallVectorTest, Resize1) { + SmallVector vec = {0, 1, 2}; + SmallVector result = {0, 1, 2, 10, 10, 10}; + + EXPECT_EQ(vec.size(), 3); + vec.resize(6, 10); + 
EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize2) { + SmallVector vec = {0, 1, 2}; + SmallVector result = {0, 1, 2, 10, 10, 10}; + + EXPECT_EQ(vec.size(), 3); + vec.resize(6, 10); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize3) { + SmallVector vec = {0, 1, 2}; + SmallVector result = {0, 1, 2, 10, 10, 10}; + + EXPECT_EQ(vec.size(), 3); + vec.resize(6, 10); + EXPECT_EQ(vec.size(), 6); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize4) { + SmallVector vec = {0, 1, 2, 10, 10, 10}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec.size(), 6); + vec.resize(3, 10); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize5) { + SmallVector vec = {0, 1, 2, 10, 10, 10}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec.size(), 6); + vec.resize(3, 10); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +TEST(SmallVectorTest, Resize6) { + SmallVector vec = {0, 1, 2, 10, 10, 10}; + SmallVector result = {0, 1, 2}; + + EXPECT_EQ(vec.size(), 6); + vec.resize(3, 10); + EXPECT_EQ(vec.size(), 3); + EXPECT_EQ(vec, result); +} + +} // namespace +} // namespace utils +} // namespace spvtools diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt new file mode 100644 index 000000000..d478a7d1c --- /dev/null +++ b/test/val/CMakeLists.txt @@ -0,0 +1,80 @@ +# Copyright (c) 2016 The Khronos Group Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set(VAL_TEST_COMMON_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h + ${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h + ${CMAKE_CURRENT_SOURCE_DIR}/val_fixtures.h +) + +add_spvtools_unittest(TARGET val_abcde + SRCS + val_adjacency_test.cpp + val_arithmetics_test.cpp + val_atomics_test.cpp + val_barriers_test.cpp + val_bitwise_test.cpp + val_builtins_test.cpp + val_capability_test.cpp + val_cfg_test.cpp + val_composites_test.cpp + val_constants_test.cpp + val_conversion_test.cpp + val_data_test.cpp + val_decoration_test.cpp + val_derivatives_test.cpp + val_explicit_reserved_test.cpp + val_extensions_test.cpp + val_ext_inst_test.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS ${SPIRV_TOOLS} + PCH_FILE pch_test_val +) + +add_spvtools_unittest(TARGET val_limits + SRCS val_limits_test.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS ${SPIRV_TOOLS} +) + +add_spvtools_unittest(TARGET val_ijklmnop + SRCS + val_id_test.cpp + val_image_test.cpp + val_interfaces_test.cpp + val_layout_test.cpp + val_literals_test.cpp + val_logicals_test.cpp + val_memory_test.cpp + val_modes_test.cpp + val_non_uniform_test.cpp + val_primitives_test.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS ${SPIRV_TOOLS} + PCH_FILE pch_test_val +) + +add_spvtools_unittest(TARGET val_stuvw + SRCS + val_ssa_test.cpp + val_state_test.cpp + val_storage_test.cpp + val_type_unique_test.cpp + val_validation_state_test.cpp + val_version_test.cpp + val_webgpu_test.cpp + ${VAL_TEST_COMMON_SRCS} + LIBS ${SPIRV_TOOLS} + PCH_FILE pch_test_val +) diff --git a/test/val/pch_test_val.cpp b/test/val/pch_test_val.cpp new file mode 100644 index 000000000..fc92e375a --- /dev/null +++ b/test/val/pch_test_val.cpp @@ -0,0 +1,15 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "pch_test_val.h" diff --git a/test/val/pch_test_val.h b/test/val/pch_test_val.h new file mode 100644 index 000000000..7b5881c43 --- /dev/null +++ b/test/val/pch_test_val.h @@ -0,0 +1,19 @@ +// Copyright (c) 2018 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" diff --git a/test/val/val_adjacency_test.cpp b/test/val/val_adjacency_test.cpp new file mode 100644 index 000000000..5c1124ae5 --- /dev/null +++ b/test/val/val_adjacency_test.cpp @@ -0,0 +1,470 @@ +// Copyright (c) 2018 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateAdjacency = spvtest::ValidateBase; + +TEST_F(ValidateAdjacency, OpPhiBeginsModuleFail) { + const std::string module = R"( +%result = OpPhi %bool %true %true_label %false %false_label +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +OpBranch %true_label +%true_label = OpLabel +OpBranch %false_label +%false_label = OpLabel +OpBranch %end_label +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(module); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID 1[%bool] has not been defined")); +} + +TEST_F(ValidateAdjacency, OpLoopMergeEndsModuleFail) { + const std::string module = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +OpBranch %loop +%loop = OpLabel +OpLoopMerge %end %loop None +)"; + + CompileSuccessfully(module); + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Missing OpFunctionEnd at end of module")); +} + +TEST_F(ValidateAdjacency, OpSelectionMergeEndsModuleFail) { + const std::string module = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +OpBranch %merge +%merge = OpLabel +OpSelectionMerge %merge None +)"; + + CompileSuccessfully(module); + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Missing OpFunctionEnd at end of module")); +} + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "OpCapability Shader", + const std::string& execution_model = "Fragment") { + std::ostringstream ss; + ss << capabilities_and_extensions << "\n"; + ss << "OpMemoryModel Logical GLSL450\n"; + ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + if (execution_model == "Fragment") { + ss << "OpExecutionMode %main OriginUpperLeft\n"; + } + + ss << R"( +%string = OpString "" +%void = OpTypeVoid +%bool = OpTypeBool +%int = OpTypeInt 32 0 +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%zero = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%func = OpTypeFunction %void +%func_int = OpTypePointer Function %int +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + + ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +TEST_F(ValidateAdjacency, OpPhiPreceededByOpLabelSuccess) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = OpLabel +%line = OpLine %string 0 0 +%result = OpPhi %bool %true %true_label %false %false_label +)"; + + 
CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpPhiPreceededByOpPhiSuccess) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = OpLabel +%1 = OpPhi %bool %true %true_label %false %false_label +%2 = OpPhi %bool %true %true_label %false %false_label +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpPhiPreceededByOpLineSuccess) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = OpLabel +%line = OpLine %string 0 0 +%result = OpPhi %bool %true %true_label %false %false_label +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpPhiPreceededByBadOpFail) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = OpLabel +OpNop +%result = OpPhi %bool %true %true_label %false %false_label +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi must appear within a non-entry block before all " + "non-OpPhi instructions")); +} + +TEST_F(ValidateAdjacency, OpPhiPreceededByOpLineAndBadOpFail) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = 
OpLabel +OpNop +OpLine %string 1 1 +%result = OpPhi %bool %true %true_label %false %false_label +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi must appear within a non-entry block before all " + "non-OpPhi instructions")); +} + +TEST_F(ValidateAdjacency, OpPhiFollowedByOpLineGood) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = OpLabel +%result = OpPhi %bool %true %true_label %false %false_label +OpLine %string 1 1 +OpNop +OpNop +OpLine %string 2 1 +OpNop +OpLine %string 3 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpPhiMultipleOpLineAndOpPhiFail) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = OpLabel +OpLine %string 1 1 +%value = OpPhi %int %zero %true_label %int_1 %false_label +OpNop +OpLine %string 2 1 +OpNop +OpLine %string 3 1 +%result = OpPhi %bool %true %true_label %false %false_label +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi must appear within a non-entry block before all " + "non-OpPhi instructions")); +} + +TEST_F(ValidateAdjacency, OpPhiMultipleOpLineAndOpPhiGood) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = OpLabel +OpLine %string 1 1 +%value = OpPhi %int %zero %true_label %int_1 
%false_label +OpLine %string 2 1 +%result = OpPhi %bool %true %true_label %false %false_label +OpLine %string 3 1 +OpNop +OpNop +OpLine %string 4 1 +OpNop +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpPhiInEntryBlockBad) { + const std::string body = R"( +OpLine %string 1 1 +%value = OpPhi %int +OpLine %string 2 1 +OpNop +OpLine %string 3 1 +OpNop +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi must appear within a non-entry block before all " + "non-OpPhi instructions")); +} + +TEST_F(ValidateAdjacency, OpVariableInFunctionGood) { + const std::string body = R"( +OpLine %string 1 1 +%var = OpVariable %func_int Function +OpLine %string 2 1 +OpNop +OpLine %string 3 1 +OpNop +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpVariableInFunctionMultipleGood) { + const std::string body = R"( +OpLine %string 1 1 +%1 = OpVariable %func_int Function +OpLine %string 2 1 +%2 = OpVariable %func_int Function +%3 = OpVariable %func_int Function +OpNop +OpLine %string 3 1 +OpNop +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpVariableInFunctionBad) { + const std::string body = R"( +%1 = OpUndef %int +%2 = OpVariable %func_int Function +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("All OpVariable instructions in a function must be the " + "first instructions")); +} + +TEST_F(ValidateAdjacency, OpVariableInFunctionMultipleBad) { + const std::string body = R"( +OpNop +%1 = OpVariable %func_int Function +OpLine %string 1 1 +%2 = OpVariable %func_int Function +OpNop +OpNop 
+OpLine %string 2 1 +%3 = OpVariable %func_int Function +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("All OpVariable instructions in a function must be the " + "first instructions")); +} + +TEST_F(ValidateAdjacency, OpLoopMergePreceedsOpBranchSuccess) { + const std::string body = R"( +OpBranch %loop +%loop = OpLabel +OpLoopMerge %end %loop None +OpBranch %loop +%end = OpLabel +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpLoopMergePreceedsOpBranchConditionalSuccess) { + const std::string body = R"( +OpBranch %loop +%loop = OpLabel +OpLoopMerge %end %loop None +OpBranchConditional %true %loop %end +%end = OpLabel +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpLoopMergePreceedsBadOpFail) { + const std::string body = R"( +OpBranch %loop +%loop = OpLabel +OpLoopMerge %end %loop None +OpNop +OpBranchConditional %true %loop %end +%end = OpLabel +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpLoopMerge must immediately precede either an " + "OpBranch or OpBranchConditional instruction.")); +} + +TEST_F(ValidateAdjacency, OpSelectionMergePreceedsOpBranchConditionalSuccess) { + const std::string body = R"( +OpSelectionMerge %end_label None +OpBranchConditional %true %true_label %false_label +%true_label = OpLabel +OpBranch %end_label +%false_label = OpLabel +OpBranch %end_label +%end_label = OpLabel +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpSelectionMergePreceedsOpSwitchSuccess) { + const std::string body = R"( +OpSelectionMerge %merge None +OpSwitch %zero 
%merge 0 %label +%label = OpLabel +OpBranch %merge +%merge = OpLabel +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAdjacency, OpSelectionMergePreceedsBadOpFail) { + const std::string body = R"( +OpSelectionMerge %merge None +OpNop +OpSwitch %zero %merge 0 %label +%label = OpLabel +OpBranch %merge +%merge = OpLabel +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpSelectionMerge must immediately precede either an " + "OpBranchConditional or OpSwitch instruction")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp new file mode 100644 index 000000000..87e006c12 --- /dev/null +++ b/test/val/val_arithmetics_test.cpp @@ -0,0 +1,1278 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for validation rules of arithmetic instructions. 
+ +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateArithmetics = spvtest::ValidateBase; + +std::string GenerateCode(const std::string& main_body) { + const std::string prefix = + R"( +OpCapability Shader +OpCapability Int64 +OpCapability Float64 +OpCapability Matrix +%ext_inst = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%f64 = OpTypeFloat 64 +%u64 = OpTypeInt 64 0 +%s64 = OpTypeInt 64 1 +%boolvec2 = OpTypeVector %bool 2 +%s32vec2 = OpTypeVector %s32 2 +%u32vec2 = OpTypeVector %u32 2 +%u64vec2 = OpTypeVector %u64 2 +%f32vec2 = OpTypeVector %f32 2 +%f64vec2 = OpTypeVector %f64 2 +%boolvec3 = OpTypeVector %bool 3 +%u32vec3 = OpTypeVector %u32 3 +%u64vec3 = OpTypeVector %u64 3 +%s32vec3 = OpTypeVector %s32 3 +%f32vec3 = OpTypeVector %f32 3 +%f64vec3 = OpTypeVector %f64 3 +%boolvec4 = OpTypeVector %bool 4 +%u32vec4 = OpTypeVector %u32 4 +%u64vec4 = OpTypeVector %u64 4 +%s32vec4 = OpTypeVector %s32 4 +%f32vec4 = OpTypeVector %f32 4 +%f64vec4 = OpTypeVector %f64 4 + +%f32mat22 = OpTypeMatrix %f32vec2 2 +%f32mat23 = OpTypeMatrix %f32vec2 3 +%f32mat32 = OpTypeMatrix %f32vec3 2 +%f32mat33 = OpTypeMatrix %f32vec3 3 +%f64mat22 = OpTypeMatrix %f64vec2 2 + +%struct_f32_f32 = OpTypeStruct %f32 %f32 +%struct_u32_u32 = OpTypeStruct %u32 %u32 +%struct_u32_u32_u32 = OpTypeStruct %u32 %u32 %u32 +%struct_s32_s32 = OpTypeStruct %s32 %s32 +%struct_s32_u32 = OpTypeStruct %s32 %u32 +%struct_u32vec2_u32vec2 = OpTypeStruct %u32vec2 %u32vec2 +%struct_s32vec2_s32vec2 = OpTypeStruct %s32vec2 %s32vec2 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 
+%f32_3 = OpConstant %f32 3 +%f32_4 = OpConstant %f32 4 +%f32_pi = OpConstant %f32 3.14159 + +%s32_0 = OpConstant %s32 0 +%s32_1 = OpConstant %s32 1 +%s32_2 = OpConstant %s32 2 +%s32_3 = OpConstant %s32 3 +%s32_4 = OpConstant %s32 4 +%s32_m1 = OpConstant %s32 -1 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%f64_0 = OpConstant %f64 0 +%f64_1 = OpConstant %f64 1 +%f64_2 = OpConstant %f64 2 +%f64_3 = OpConstant %f64 3 +%f64_4 = OpConstant %f64 4 + +%s64_0 = OpConstant %s64 0 +%s64_1 = OpConstant %s64 1 +%s64_2 = OpConstant %s64 2 +%s64_3 = OpConstant %s64 3 +%s64_4 = OpConstant %s64 4 +%s64_m1 = OpConstant %s64 -1 + +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant %u64 1 +%u64_2 = OpConstant %u64 2 +%u64_3 = OpConstant %u64 3 +%u64_4 = OpConstant %u64 4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + +%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1 +%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2 +%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2 +%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3 +%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3 +%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4 + +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2 +%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +%f32vec4_1234 = 
OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4 + +%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1 +%f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2 +%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2 +%f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3 +%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3 +%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4 + +%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12 +%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12 +%f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123 +%f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123 + +%f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12 + +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string suffix = + R"( +OpReturn +OpFunctionEnd)"; + + return prefix + main_body + suffix; +} + +TEST_F(ValidateArithmetics, F32Success) { + const std::string body = R"( +%val1 = OpFMul %f32 %f32_0 %f32_1 +%val2 = OpFSub %f32 %f32_2 %f32_0 +%val3 = OpFAdd %f32 %val1 %val2 +%val4 = OpFNegate %f32 %val3 +%val5 = OpFDiv %f32 %val4 %val1 +%val6 = OpFRem %f32 %val4 %f32_2 +%val7 = OpFMod %f32 %val4 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, F64Success) { + const std::string body = R"( +%val1 = OpFMul %f64 %f64_0 %f64_1 +%val2 = OpFSub %f64 %f64_2 %f64_0 +%val3 = OpFAdd %f64 %val1 %val2 +%val4 = OpFNegate %f64 %val3 +%val5 = OpFDiv %f64 %val4 %val1 +%val6 = OpFRem %f64 %val4 %f64_2 +%val7 = OpFMod %f64 %val4 %f64_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, Int32Success) { + const std::string body = R"( +%val1 = OpIMul %u32 %s32_0 %u32_1 +%val2 = 
OpIMul %s32 %s32_2 %u32_1 +%val3 = OpIAdd %u32 %val1 %val2 +%val4 = OpIAdd %s32 %val1 %val2 +%val5 = OpISub %u32 %val3 %val4 +%val6 = OpISub %s32 %val4 %val3 +%val7 = OpSDiv %s32 %val4 %val3 +%val8 = OpSNegate %s32 %val7 +%val9 = OpSRem %s32 %val4 %val3 +%val10 = OpSMod %s32 %val4 %val3 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, Int64Success) { + const std::string body = R"( +%val1 = OpIMul %u64 %s64_0 %u64_1 +%val2 = OpIMul %s64 %s64_2 %u64_1 +%val3 = OpIAdd %u64 %val1 %val2 +%val4 = OpIAdd %s64 %val1 %val2 +%val5 = OpISub %u64 %val3 %val4 +%val6 = OpISub %s64 %val4 %val3 +%val7 = OpSDiv %s64 %val4 %val3 +%val8 = OpSNegate %s64 %val7 +%val9 = OpSRem %s64 %val4 %val3 +%val10 = OpSMod %s64 %val4 %val3 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, F32Vec2Success) { + const std::string body = R"( +%val1 = OpFMul %f32vec2 %f32vec2_01 %f32vec2_12 +%val2 = OpFSub %f32vec2 %f32vec2_12 %f32vec2_01 +%val3 = OpFAdd %f32vec2 %val1 %val2 +%val4 = OpFNegate %f32vec2 %val3 +%val5 = OpFDiv %f32vec2 %val4 %val1 +%val6 = OpFRem %f32vec2 %val4 %f32vec2_12 +%val7 = OpFMod %f32vec2 %val4 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, F64Vec2Success) { + const std::string body = R"( +%val1 = OpFMul %f64vec2 %f64vec2_01 %f64vec2_12 +%val2 = OpFSub %f64vec2 %f64vec2_12 %f64vec2_01 +%val3 = OpFAdd %f64vec2 %val1 %val2 +%val4 = OpFNegate %f64vec2 %val3 +%val5 = OpFDiv %f64vec2 %val4 %val1 +%val6 = OpFRem %f64vec2 %val4 %f64vec2_12 +%val7 = OpFMod %f64vec2 %val4 %f64vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, U32Vec2Success) { + const std::string body = R"( +%val1 = OpIMul %u32vec2 
%u32vec2_01 %u32vec2_12 +%val2 = OpISub %u32vec2 %u32vec2_12 %u32vec2_01 +%val3 = OpIAdd %u32vec2 %val1 %val2 +%val4 = OpSNegate %u32vec2 %val3 +%val5 = OpSDiv %u32vec2 %val4 %val1 +%val6 = OpSRem %u32vec2 %val4 %u32vec2_12 +%val7 = OpSMod %u32vec2 %val4 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, FNegateTypeIdU32) { + const std::string body = R"( +%val = OpFNegate %u32 %u32_0 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected floating scalar or vector type as Result Type: FNegate")); +} + +TEST_F(ValidateArithmetics, FNegateTypeIdVec2U32) { + const std::string body = R"( +%val = OpFNegate %u32vec2 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected floating scalar or vector type as Result Type: FNegate")); +} + +TEST_F(ValidateArithmetics, FNegateWrongOperand) { + const std::string body = R"( +%val = OpFNegate %f32 %u32_0 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected arithmetic operands to be of Result Type: " + "FNegate operand index 2")); +} + +TEST_F(ValidateArithmetics, FMulTypeIdU32) { + const std::string body = R"( +%val = OpFMul %u32 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected floating scalar or vector type as Result Type: FMul")); +} + +TEST_F(ValidateArithmetics, FMulTypeIdVec2U32) { + const std::string body = R"( +%val = OpFMul %u32vec2 %u32vec2_01 %u32vec2_12 +)"; + + 
CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected floating scalar or vector type as Result Type: FMul")); +} + +TEST_F(ValidateArithmetics, FMulWrongOperand1) { + const std::string body = R"( +%val = OpFMul %f32 %u32_0 %f32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected arithmetic operands to be of Result Type: " + "FMul operand index 2")); +} + +TEST_F(ValidateArithmetics, FMulWrongOperand2) { + const std::string body = R"( +%val = OpFMul %f32 %f32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected arithmetic operands to be of Result Type: " + "FMul operand index 3")); +} + +TEST_F(ValidateArithmetics, FMulWrongVectorOperand1) { + const std::string body = R"( +%val = OpFMul %f64vec3 %f32vec3_123 %f64vec3_012 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected arithmetic operands to be of Result Type: " + "FMul operand index 2")); +} + +TEST_F(ValidateArithmetics, FMulWrongVectorOperand2) { + const std::string body = R"( +%val = OpFMul %f32vec3 %f32vec3_123 %f64vec3_012 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected arithmetic operands to be of Result Type: " + "FMul operand index 3")); +} + +TEST_F(ValidateArithmetics, IMulFloatTypeId) { + const std::string body = R"( +%val = OpIMul %f32 %u32_0 %s32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected int scalar or vector type as Result Type: IMul")); +} + +TEST_F(ValidateArithmetics, IMulFloatOperand1) { + const std::string body = R"( +%val = OpIMul %u32 %f32_0 %s32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar or vector type as operand: " + "IMul operand index 2")); +} + +TEST_F(ValidateArithmetics, IMulFloatOperand2) { + const std::string body = R"( +%val = OpIMul %u32 %s32_0 %f32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar or vector type as operand: " + "IMul operand index 3")); +} + +TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand1) { + const std::string body = R"( +%val = OpIMul %u64 %u32_0 %s64_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected arithmetic operands to have the same bit width " + "as Result Type: IMul operand index 2")); +} + +TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand2) { + const std::string body = R"( +%val = OpIMul %u32 %u32_0 %s64_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected arithmetic operands to have the same bit width " + "as Result Type: IMul operand index 3")); +} + +TEST_F(ValidateArithmetics, IMulWrongBitWidthVector) { + const std::string body = R"( +%val = OpIMul %u64vec3 %u32vec3_012 %u32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected arithmetic operands to have 
the same bit width " + "as Result Type: IMul operand index 2")); +} + +TEST_F(ValidateArithmetics, IMulVectorScalarOperand1) { + const std::string body = R"( +%val = OpIMul %u32vec2 %u32_0 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected arithmetic operands to have the same dimension " + "as Result Type: IMul operand index 2")); +} + +TEST_F(ValidateArithmetics, IMulVectorScalarOperand2) { + const std::string body = R"( +%val = OpIMul %u32vec2 %u32vec2_01 %u32_0 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected arithmetic operands to have the same dimension " + "as Result Type: IMul operand index 3")); +} + +TEST_F(ValidateArithmetics, IMulScalarVectorOperand1) { + const std::string body = R"( +%val = OpIMul %s32 %u32vec2_01 %u32_0 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected arithmetic operands to have the same dimension " + "as Result Type: IMul operand index 2")); +} + +TEST_F(ValidateArithmetics, IMulScalarVectorOperand2) { + const std::string body = R"( +%val = OpIMul %u32 %u32_0 %s32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected arithmetic operands to have the same dimension " + "as Result Type: IMul operand index 3")); +} + +TEST_F(ValidateArithmetics, SNegateFloat) { + const std::string body = R"( +%val = OpSNegate %s32 %f32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar or 
vector type as operand: " + "SNegate operand index 2")); +} + +TEST_F(ValidateArithmetics, UDivFloatType) { + const std::string body = R"( +%val = OpUDiv %f32 %u32_2 %u32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected unsigned int scalar or vector type as Result Type: UDiv")); +} + +TEST_F(ValidateArithmetics, UDivSignedIntType) { + const std::string body = R"( +%val = OpUDiv %s32 %u32_2 %u32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected unsigned int scalar or vector type as Result Type: UDiv")); +} + +TEST_F(ValidateArithmetics, UDivWrongOperand1) { + const std::string body = R"( +%val = OpUDiv %u64 %f64_2 %u64_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected arithmetic operands to be of Result Type: " + "UDiv operand index 2")); +} + +TEST_F(ValidateArithmetics, UDivWrongOperand2) { + const std::string body = R"( +%val = OpUDiv %u64 %u64_2 %u32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected arithmetic operands to be of Result Type: " + "UDiv operand index 3")); +} + +TEST_F(ValidateArithmetics, DotSuccess) { + const std::string body = R"( +%val = OpDot %f32 %f32vec2_01 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, DotWrongTypeId) { + const std::string body = R"( +%val = OpDot %u32 %u32vec2_01 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float scalar type as Result Type: Dot")); +} + +TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) { + const std::string body = R"( +%val = OpDot %f32 %f32 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 6[%float] cannot be a " + "type")); +} + +TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) { + const std::string body = R"( +%val = OpDot %f32 %f32vec3_012 %f32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected float vector as operand: Dot operand index 3")); +} + +TEST_F(ValidateArithmetics, DotWrongComponentOperand1) { + const std::string body = R"( +%val = OpDot %f64 %f32vec2_01 %f64vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected component type to be equal to Result Type: " + "Dot operand index 2")); +} + +TEST_F(ValidateArithmetics, DotWrongComponentOperand2) { + const std::string body = R"( +%val = OpDot %f32 %f32vec2_01 %f64vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected component type to be equal to Result Type: " + "Dot operand index 3")); +} + +TEST_F(ValidateArithmetics, DotDifferentVectorSize) { + const std::string body = R"( +%val = OpDot %f32 %f32vec2_01 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected operands to have the same number of componenets: Dot")); +} + +TEST_F(ValidateArithmetics, 
VectorTimesScalarSuccess) { + const std::string body = R"( +%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) { + const std::string body = R"( +%val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float vector type as Result Type: " + "VectorTimesScalar")); +} + +TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) { + const std::string body = R"( +%val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected vector operand type to be equal to Result Type: " + "VectorTimesScalar")); +} + +TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) { + const std::string body = R"( +%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected scalar operand type to be equal to the component " + "type of the vector operand: VectorTimesScalar")); +} + +TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) { + const std::string body = R"( +%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) { + const std::string body = R"( +%val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); 
+ EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float matrix type as Result Type: " + "MatrixTimesScalar")); +} + +TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) { + const std::string body = R"( +%val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected matrix operand type to be equal to Result Type: " + "MatrixTimesScalar")); +} + +TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) { + const std::string body = R"( +%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected scalar operand type to be equal to the component " + "type of the matrix operand: MatrixTimesScalar")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float vector type as Result Type: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) { + const 
std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float vector type as left operand: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected component types of Result Type and vector to be equal: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float matrix type as right operand: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected component types of Result Type and matrix to be equal: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected number of columns of the matrix to be equal to Result 
Type " + "vector size: VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected number of rows of the matrix to be equal to the vector " + "operand size: VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float vector type as Result Type: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float matrix type as left operand: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123 +)"; 
+ + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected column type of the matrix to be equal to Result Type: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float vector type as right operand: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected component types of the operands to be equal: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected number of columns of the matrix to be equal to the vector " + "size: MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123 +)"; + + 
CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected float matrix type as Result Type: MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected float matrix type as left operand: MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected float matrix type as right operand: MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected column types of Result Type and left matrix 
to be equal: " + "MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected component types of Result Type and right " + "matrix to be equal: " + "MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of columns of Result Type and right " + "matrix to be equal: " + "MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of columns of left matrix and number " + "of rows of right " + "matrix to be equal: MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, OuterProduct2x2Success) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, OuterProduct3x2Success) { + const std::string body = R"( +%val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, OuterProduct2x3Success) { + const std::string body = R"( +%val = 
OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, OuterProductWrongTypeId) { + const std::string body = R"( +%val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected float matrix type as Result Type: " + "OuterProduct")); +} + +TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected column type of Result Type to be equal to the type " + "of the left operand: OuterProduct")); +} + +TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected float vector type as right operand: OuterProduct")); +} + +TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected component types of the operands to be equal: " + "OuterProduct")); +} + +TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123 +)"; + + 
CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected number of columns of the matrix to be equal to the " + "vector size of the right operand: OuterProduct")); +} + +TEST_F(ValidateArithmetics, IAddCarrySuccess) { + const std::string body = R"( +%val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1 +%val2 = OpIAddCarry %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, IAddCarryResultTypeNotStruct) { + const std::string body = R"( +%val = OpIAddCarry %u32 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected a struct as Result Type: IAddCarry")); +} + +TEST_F(ValidateArithmetics, IAddCarryResultTypeNotTwoMembers) { + const std::string body = R"( +%val = OpIAddCarry %struct_u32_u32_u32 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type struct to have two members: IAddCarry")); +} + +TEST_F(ValidateArithmetics, IAddCarryResultTypeMemberNotUnsignedInt) { + const std::string body = R"( +%val = OpIAddCarry %struct_s32_s32 %s32_0 %s32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type struct member types to be " + "unsigned integer scalar " + "or vector: IAddCarry")); +} + +TEST_F(ValidateArithmetics, IAddCarryWrongLeftOperand) { + const std::string body = R"( +%val = OpIAddCarry %struct_u32_u32 %s32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected both operands to be of Result Type member " + "type: IAddCarry")); +} + +TEST_F(ValidateArithmetics, IAddCarryWrongRightOperand) { + const std::string body = R"( +%val = OpIAddCarry %struct_u32_u32 %u32_0 %s32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected both operands to be of Result Type member " + "type: IAddCarry")); +} + +TEST_F(ValidateArithmetics, OpSMulExtendedSuccess) { + const std::string body = R"( +%val1 = OpSMulExtended %struct_u32_u32 %u32_0 %u32_1 +%val2 = OpSMulExtended %struct_s32_s32 %s32_0 %s32_1 +%val3 = OpSMulExtended %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12 +%val4 = OpSMulExtended %struct_s32vec2_s32vec2 %s32vec2_01 %s32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, SMulExtendedResultTypeMemberNotInt) { + const std::string body = R"( +%val = OpSMulExtended %struct_f32_f32 %f32_0 %f32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type struct member types to be integer scalar " + "or vector: SMulExtended")); +} + +TEST_F(ValidateArithmetics, SMulExtendedResultTypeMembersNotIdentical) { + const std::string body = R"( +%val = OpSMulExtended %struct_s32_u32 %s32_0 %s32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type struct member types to be identical: " + "SMulExtended")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_atomics_test.cpp 
b/test/val/val_atomics_test.cpp new file mode 100644 index 000000000..1f3001886 --- /dev/null +++ b/test/val/val_atomics_test.cpp @@ -0,0 +1,1933 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateAtomics = spvtest::ValidateBase; + +std::string GenerateShaderCodeImpl( + const std::string& body, const std::string& capabilities_and_extensions, + const std::string& definitions, const std::string& memory_model) { + std::ostringstream ss; + ss << R"( +OpCapability Shader +)"; + ss << capabilities_and_extensions; + ss << "OpMemoryModel Logical " << memory_model << "\n"; + ss << R"( +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%f32vec4 = OpTypeVector %f32 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%f32vec4_0000 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0 + +%cross_device = OpConstant %u32 0 +%device = OpConstant %u32 1 +%workgroup = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 +%invocation = OpConstant %u32 4 +%queuefamily = OpConstant %u32 5 + +%relaxed = 
OpConstant %u32 0 +%acquire = OpConstant %u32 2 +%release = OpConstant %u32 4 +%acquire_release = OpConstant %u32 8 +%acquire_and_release = OpConstant %u32 6 +%sequentially_consistent = OpConstant %u32 16 +%acquire_release_uniform_workgroup = OpConstant %u32 328 + +%f32_ptr = OpTypePointer Workgroup %f32 +%f32_var = OpVariable %f32_ptr Workgroup + +%u32_ptr = OpTypePointer Workgroup %u32 +%u32_var = OpVariable %u32_ptr Workgroup + +%f32vec4_ptr = OpTypePointer Workgroup %f32vec4 +%f32vec4_var = OpVariable %f32vec4_ptr Workgroup + +%f32_ptr_function = OpTypePointer Function %f32 +)"; + ss << definitions; + ss << R"( +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + ss << body; + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& memory_model = "GLSL450") { + const std::string defintions = R"( +%u64 = OpTypeInt 64 0 +%s64 = OpTypeInt 64 1 + +%u64_1 = OpConstant %u64 1 +%s64_1 = OpConstant %s64 1 + +%u64_ptr = OpTypePointer Workgroup %u64 +%s64_ptr = OpTypePointer Workgroup %s64 +%u64_var = OpVariable %u64_ptr Workgroup +%s64_var = OpVariable %s64_ptr Workgroup +)"; + return GenerateShaderCodeImpl( + body, "OpCapability Int64\n" + capabilities_and_extensions, defintions, + memory_model); +} + +std::string GenerateWebGPUShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "") { + const std::string vulkan_memory_capability = R"( +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability VulkanMemoryModelKHR +)"; + const std::string vulkan_memory_extension = R"( +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + return GenerateShaderCodeImpl(body, + vulkan_memory_capability + + capabilities_and_extensions + + vulkan_memory_extension, + "", "VulkanKHR"); +} + +std::string GenerateKernelCode( + const std::string& body, + const std::string& capabilities_and_extensions = "") 
{ + std::ostringstream ss; + ss << R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpCapability Int64 +)"; + + ss << capabilities_and_extensions; + ss << R"( +OpMemoryModel Physical32 OpenCL +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%u64 = OpTypeInt 64 0 +%f32vec4 = OpTypeVector %f32 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u64_1 = OpConstant %u64 1 +%f32vec4_0000 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0 + +%cross_device = OpConstant %u32 0 +%device = OpConstant %u32 1 +%workgroup = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 +%invocation = OpConstant %u32 4 + +%relaxed = OpConstant %u32 0 +%acquire = OpConstant %u32 2 +%release = OpConstant %u32 4 +%acquire_release = OpConstant %u32 8 +%acquire_and_release = OpConstant %u32 6 +%sequentially_consistent = OpConstant %u32 16 +%acquire_release_uniform_workgroup = OpConstant %u32 328 +%acquire_release_atomic_counter_workgroup = OpConstant %u32 1288 + +%f32_ptr = OpTypePointer Workgroup %f32 +%f32_var = OpVariable %f32_ptr Workgroup + +%u32_ptr = OpTypePointer Workgroup %u32 +%u32_var = OpVariable %u32_ptr Workgroup + +%u64_ptr = OpTypePointer Workgroup %u64 +%u64_var = OpVariable %u64_ptr Workgroup + +%f32vec4_ptr = OpTypePointer Workgroup %f32vec4 +%f32vec4_var = OpVariable %f32vec4_ptr Workgroup + +%f32_ptr_function = OpTypePointer Function %f32 +%f32_ptr_uniformconstant = OpTypePointer UniformConstant %f32 +%f32_uc_var = OpVariable %f32_ptr_uniformconstant UniformConstant + +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + + ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +TEST_F(ValidateAtomics, AtomicLoadShaderSuccess) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %device %relaxed +%val2 = OpAtomicLoad %u32 %u32_var %workgroup %acquire +%val3 = 
OpAtomicLoad %u64 %u64_var %subgroup %sequentially_consistent +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicLoadKernelSuccess) { + const std::string body = R"( +%val1 = OpAtomicLoad %f32 %f32_var %device %relaxed +%val2 = OpAtomicLoad %u32 %u32_var %workgroup %sequentially_consistent +%val3 = OpAtomicLoad %u64 %u64_var %subgroup %acquire +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicLoadVulkanSuccess) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %device %relaxed +%val2 = OpAtomicLoad %u32 %u32_var %workgroup %acquire +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateAtomics, AtomicStoreOpenCLFunctionPointerStorageTypeSuccess) { + const std::string body = R"( +%f32_var_function = OpVariable %f32_ptr_function Function +OpAtomicStore %f32_var_function %device %relaxed %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_OPENCL_1_2); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_1_2)); +} + +TEST_F(ValidateAtomics, AtomicStoreVulkanFunctionPointerStorageType) { + const std::string body = R"( +%f32_var_function = OpVariable %f32_ptr_function Function +OpAtomicStore %f32_var_function %device %relaxed %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicStore: expected Pointer Storage Class to be Uniform, " + "Workgroup, CrossWorkgroup, Generic, AtomicCounter, Image or " + "StorageBuffer")); +} + +// TODO(atgoo@github.com): the corresponding check fails Vulkan CTS, +// reenable once fixed. 
+TEST_F(ValidateAtomics, DISABLED_AtomicLoadVulkanSubgroup) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %subgroup %acquire +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicLoad: in Vulkan environment memory scope is " + "limited to Device, Workgroup and Invocation")); +} + +TEST_F(ValidateAtomics, AtomicLoadVulkanRelease) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %workgroup %release +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Vulkan spec disallows OpAtomicLoad with Memory Semantics " + "Release, AcquireRelease and SequentiallyConsistent")); +} + +TEST_F(ValidateAtomics, AtomicLoadVulkanAcquireRelease) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Vulkan spec disallows OpAtomicLoad with Memory Semantics " + "Release, AcquireRelease and SequentiallyConsistent")); +} + +TEST_F(ValidateAtomics, AtomicLoadVulkanSequentiallyConsistent) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %workgroup %sequentially_consistent +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Vulkan spec disallows OpAtomicLoad with Memory Semantics " + "Release, AcquireRelease and SequentiallyConsistent")); +} + +TEST_F(ValidateAtomics, AtomicLoadShaderFloat) { + const std::string body = R"( 
+%val1 = OpAtomicLoad %f32 %f32_var %device %relaxed +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicLoad: " + "expected Result Type to be int scalar type")); +} + +TEST_F(ValidateAtomics, AtomicLoadVulkanInt64) { + const std::string body = R"( +%val1 = OpAtomicLoad %u64 %u64_var %device %relaxed +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "AtomicLoad: 64-bit atomics require the Int64Atomics capability")); +} + +TEST_F(ValidateAtomics, AtomicLoadWebGPUShaderSuccess) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %queuefamily %relaxed +%val2 = OpAtomicLoad %u32 %u32_var %workgroup %acquire +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateAtomics, AtomicLoadWebGPUShaderSequentiallyConsistentFailure) { + const std::string body = R"( +%val3 = OpAtomicLoad %u32 %u32_var %subgroup %sequentially_consistent +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "WebGPU spec disallows any bit masks in Memory Semantics that are " + "not Acquire, Release, AcquireRelease, UniformMemory, " + "WorkgroupMemory, ImageMemory, OutputMemoryKHR, MakeAvailableKHR, or " + "MakeVisibleKHR\n %34 = OpAtomicLoad %uint %29 %uint_3 %uint_16\n")); +} + +TEST_F(ValidateAtomics, VK_KHR_shader_atomic_int64Success) { + const std::string body = R"( +%val1 = OpAtomicUMin %u64 %u64_var %device %relaxed %u64_1 +%val2 = OpAtomicUMax %u64 %u64_var %device %relaxed %u64_1 +%val3 = OpAtomicSMin %u64 %u64_var %device %relaxed 
%u64_1 +%val4 = OpAtomicSMax %u64 %u64_var %device %relaxed %u64_1 +%val5 = OpAtomicAnd %u64 %u64_var %device %relaxed %u64_1 +%val6 = OpAtomicOr %u64 %u64_var %device %relaxed %u64_1 +%val7 = OpAtomicXor %u64 %u64_var %device %relaxed %u64_1 +%val8 = OpAtomicIAdd %u64 %u64_var %device %relaxed %u64_1 +%val9 = OpAtomicExchange %u64 %u64_var %device %relaxed %u64_1 +%val10 = OpAtomicCompareExchange %u64 %u64_var %device %relaxed %relaxed %u64_1 %u64_1 + +%val11 = OpAtomicUMin %s64 %s64_var %device %relaxed %s64_1 +%val12 = OpAtomicUMax %s64 %s64_var %device %relaxed %s64_1 +%val13 = OpAtomicSMin %s64 %s64_var %device %relaxed %s64_1 +%val14 = OpAtomicSMax %s64 %s64_var %device %relaxed %s64_1 +%val15 = OpAtomicAnd %s64 %s64_var %device %relaxed %s64_1 +%val16 = OpAtomicOr %s64 %s64_var %device %relaxed %s64_1 +%val17 = OpAtomicXor %s64 %s64_var %device %relaxed %s64_1 +%val18 = OpAtomicIAdd %s64 %s64_var %device %relaxed %s64_1 +%val19 = OpAtomicExchange %s64 %s64_var %device %relaxed %s64_1 +%val20 = OpAtomicCompareExchange %s64 %s64_var %device %relaxed %relaxed %s64_1 %s64_1 + +%val21 = OpAtomicLoad %u64 %u64_var %device %relaxed +%val22 = OpAtomicLoad %s64 %s64_var %device %relaxed + +OpAtomicStore %u64_var %device %relaxed %u64_1 +OpAtomicStore %s64_var %device %relaxed %s64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body, "OpCapability Int64Atomics\n"), + SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateAtomics, VK_KHR_shader_atomic_int64MissingCapability) { + const std::string body = R"( +%val1 = OpAtomicUMin %u64 %u64_var %device %relaxed %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "AtomicUMin: 64-bit atomics require the Int64Atomics capability")); +} + +TEST_F(ValidateAtomics, AtomicLoadWrongResultType) { + const 
std::string body = R"( +%val1 = OpAtomicLoad %f32vec4 %f32vec4_var %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicLoad: " + "expected Result Type to be int or float scalar type")); +} + +TEST_F(ValidateAtomics, AtomicLoadWrongPointerType) { + const std::string body = R"( +%val1 = OpAtomicLoad %f32 %f32_ptr %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 27[%_ptr_Workgroup_float] cannot be a type")); +} + +TEST_F(ValidateAtomics, AtomicLoadWrongPointerDataType) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %f32_var %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicLoad: " + "expected Pointer to point to a value of type Result Type")); +} + +TEST_F(ValidateAtomics, AtomicLoadWrongScopeType) { + const std::string body = R"( +%val1 = OpAtomicLoad %f32 %f32_var %f32_1 %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicLoad: expected Memory Scope to be a 32-bit int\n %40 = " + "OpAtomicLoad %float %28 %float_1 %uint_0_1\n")); +} + +TEST_F(ValidateAtomics, AtomicLoadWrongMemorySemanticsType) { + const std::string body = R"( +%val1 = OpAtomicLoad %f32 %f32_var %device %u64_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicLoad: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateAtomics, AtomicStoreKernelSuccess) { + const std::string body = R"( +OpAtomicStore 
%f32_var %device %relaxed %f32_1 +OpAtomicStore %u32_var %subgroup %release %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicStoreShaderSuccess) { + const std::string body = R"( +OpAtomicStore %u32_var %device %release %u32_1 +OpAtomicStore %u32_var %subgroup %sequentially_consistent %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicStoreVulkanSuccess) { + const std::string body = R"( +OpAtomicStore %u32_var %device %release %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateAtomics, AtomicStoreVulkanAcquire) { + const std::string body = R"( +OpAtomicStore %u32_var %device %acquire %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Vulkan spec disallows OpAtomicStore with Memory Semantics " + "Acquire, AcquireRelease and SequentiallyConsistent")); +} + +TEST_F(ValidateAtomics, AtomicStoreVulkanAcquireRelease) { + const std::string body = R"( +OpAtomicStore %u32_var %device %acquire_release %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Vulkan spec disallows OpAtomicStore with Memory Semantics " + "Acquire, AcquireRelease and SequentiallyConsistent")); +} + +TEST_F(ValidateAtomics, AtomicStoreVulkanSequentiallyConsistent) { + const std::string body = R"( +OpAtomicStore %u32_var %device %sequentially_consistent %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, 
ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Vulkan spec disallows OpAtomicStore with Memory Semantics " + "Acquire, AcquireRelease and SequentiallyConsistent")); +} + +TEST_F(ValidateAtomics, AtomicStoreWebGPUSuccess) { + const std::string body = R"( +OpAtomicStore %u32_var %queuefamily %release %u32_1 +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateAtomics, AtomicStoreWebGPUSequentiallyConsistent) { + const std::string body = R"( +OpAtomicStore %u32_var %queuefamily %sequentially_consistent %u32_1 +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "WebGPU spec disallows any bit masks in Memory Semantics that are " + "not Acquire, Release, AcquireRelease, UniformMemory, " + "WorkgroupMemory, ImageMemory, OutputMemoryKHR, MakeAvailableKHR, or " + "MakeVisibleKHR\n" + " OpAtomicStore %29 %uint_5 %uint_16 %uint_1\n")); +} + +TEST_F(ValidateAtomics, AtomicStoreWrongPointerType) { + const std::string body = R"( +OpAtomicStore %f32_1 %device %relaxed %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicStore: expected Pointer to be of type OpTypePointer")); +} + +TEST_F(ValidateAtomics, AtomicStoreWrongPointerDataType) { + const std::string body = R"( +OpAtomicStore %f32vec4_var %device %relaxed %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicStore: " + "expected Pointer to be a pointer to int or float scalar " + "type")); +} + +TEST_F(ValidateAtomics, 
AtomicStoreWrongPointerStorageType) { + const std::string body = R"( +OpAtomicStore %f32_uc_var %device %relaxed %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicStore: expected Pointer Storage Class to be Uniform, " + "Workgroup, CrossWorkgroup, Generic, AtomicCounter, Image or " + "StorageBuffer")); +} + +TEST_F(ValidateAtomics, AtomicStoreWrongScopeType) { + const std::string body = R"( +OpAtomicStore %f32_var %f32_1 %relaxed %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicStore: expected Memory Scope to be a 32-bit int\n " + "OpAtomicStore %28 %float_1 %uint_0_1 %float_1\n")); +} + +TEST_F(ValidateAtomics, AtomicStoreWrongMemorySemanticsType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %f32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicStore: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateAtomics, AtomicStoreWrongValueType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicStore: " + "expected Value type and the type pointed to by Pointer to " + "be the same")); +} + +TEST_F(ValidateAtomics, AtomicExchangeShaderSuccess) { + const std::string body = R"( +%val1 = OpAtomicStore %u32_var %device %relaxed %u32_1 +%val2 = OpAtomicExchange %u32 %u32_var %device %relaxed %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, 
AtomicExchangeKernelSuccess) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicExchange %f32 %f32_var %device %relaxed %f32_0 +%val3 = OpAtomicStore %u32_var %device %relaxed %u32_1 +%val4 = OpAtomicExchange %u32 %u32_var %device %relaxed %u32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicExchangeShaderFloat) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicExchange %f32 %f32_var %device %relaxed %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicExchange: " + "expected Result Type to be int scalar type")); +} + +TEST_F(ValidateAtomics, AtomicExchangeWrongResultType) { + const std::string body = R"( +%val1 = OpStore %f32vec4_var %f32vec4_0000 +%val2 = OpAtomicExchange %f32vec4 %f32vec4_var %device %relaxed %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicExchange: " + "expected Result Type to be int or float scalar type")); +} + +TEST_F(ValidateAtomics, AtomicExchangeWrongPointerType) { + const std::string body = R"( +%val2 = OpAtomicExchange %f32 %f32vec4_ptr %device %relaxed %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 33[%_ptr_Workgroup_v4float] cannot be a " + "type")); +} + +TEST_F(ValidateAtomics, AtomicExchangeWrongPointerDataType) { + const std::string body = R"( +%val1 = OpStore %f32vec4_var %f32vec4_0000 +%val2 = OpAtomicExchange %f32 %f32vec4_var %device %relaxed %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicExchange: " + "expected Pointer to point to a value of type Result Type")); +} + +TEST_F(ValidateAtomics, AtomicExchangeWrongScopeType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicExchange %f32 %f32_var %f32_1 %relaxed %f32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "AtomicExchange: expected Memory Scope to be a 32-bit int\n %40 = " + "OpAtomicExchange %float %28 %float_1 %uint_0_1 %float_0\n")); +} + +TEST_F(ValidateAtomics, AtomicExchangeWrongMemorySemanticsType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicExchange %f32 %f32_var %device %f32_1 %f32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "AtomicExchange: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateAtomics, AtomicExchangeWrongValueType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicExchange %f32 %f32_var %device %relaxed %u32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicExchange: " + "expected Value to be of type Result Type")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeShaderSuccess) { + const std::string body = R"( +%val1 = OpAtomicStore %u32_var %device %relaxed %u32_1 +%val2 = OpAtomicCompareExchange %u32 %u32_var %device %relaxed %relaxed %u32_0 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, 
AtomicCompareExchangeKernelSuccess) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicCompareExchange %f32 %f32_var %device %relaxed %relaxed %f32_0 %f32_1 +%val3 = OpAtomicStore %u32_var %device %relaxed %u32_1 +%val4 = OpAtomicCompareExchange %u32 %u32_var %device %relaxed %relaxed %u32_0 %u32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeShaderFloat) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val1 = OpAtomicCompareExchange %f32 %f32_var %device %relaxed %relaxed %f32_0 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicCompareExchange: " + "expected Result Type to be int scalar type")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWrongResultType) { + const std::string body = R"( +%val1 = OpStore %f32vec4_var %f32vec4_0000 +%val2 = OpAtomicCompareExchange %f32vec4 %f32vec4_var %device %relaxed %relaxed %f32vec4_0000 %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicCompareExchange: " + "expected Result Type to be int or float scalar type")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWrongPointerType) { + const std::string body = R"( +%val2 = OpAtomicCompareExchange %f32 %f32vec4_ptr %device %relaxed %relaxed %f32vec4_0000 %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 33[%_ptr_Workgroup_v4float] cannot be a " + "type")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWrongPointerDataType) { + const std::string body = R"( +%val1 = OpStore 
%f32vec4_var %f32vec4_0000 +%val2 = OpAtomicCompareExchange %f32 %f32vec4_var %device %relaxed %relaxed %f32_0 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicCompareExchange: " + "expected Pointer to point to a value of type Result Type")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWrongScopeType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicCompareExchange %f32 %f32_var %f32_1 %relaxed %relaxed %f32_0 %f32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicCompareExchange: expected Memory Scope to be a 32-bit " + "int\n %40 = OpAtomicCompareExchange %float %28 %float_1 " + "%uint_0_1 %uint_0_1 %float_0 %float_0\n")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWrongMemorySemanticsType1) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicCompareExchange %f32 %f32_var %device %f32_1 %relaxed %f32_0 %f32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicCompareExchange: expected Memory Semantics to " + "be a 32-bit int")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWrongMemorySemanticsType2) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicCompareExchange %f32 %f32_var %device %relaxed %f32_1 %f32_0 %f32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicCompareExchange: expected Memory Semantics to " + "be a 32-bit int")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeUnequalRelease) { + 
const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicCompareExchange %f32 %f32_var %device %relaxed %release %f32_0 %f32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicCompareExchange: Memory Semantics Release and " + "AcquireRelease cannot be used for operand Unequal")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWrongValueType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicCompareExchange %f32 %f32_var %device %relaxed %relaxed %u32_0 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicCompareExchange: " + "expected Value to be of type Result Type")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWrongComparatorType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicCompareExchange %f32 %f32_var %device %relaxed %relaxed %f32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicCompareExchange: " + "expected Comparator to be of type Result Type")); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWeakSuccess) { + const std::string body = R"( +%val3 = OpAtomicStore %u32_var %device %relaxed %u32_1 +%val4 = OpAtomicCompareExchangeWeak %u32 %u32_var %device %relaxed %relaxed %u32_0 %u32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicCompareExchangeWeakWrongResultType) { + const std::string body = R"( +OpAtomicStore %f32_var %device %relaxed %f32_1 +%val2 = OpAtomicCompareExchangeWeak %f32 %f32_var %device %relaxed %relaxed %f32_0 %f32_1 
+)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicCompareExchangeWeak: " + "expected Result Type to be int scalar type")); +} + +TEST_F(ValidateAtomics, AtomicArithmeticsSuccess) { + const std::string body = R"( +OpAtomicStore %u32_var %device %relaxed %u32_1 +%val1 = OpAtomicIIncrement %u32 %u32_var %device %acquire_release +%val2 = OpAtomicIDecrement %u32 %u32_var %device %acquire_release +%val3 = OpAtomicIAdd %u32 %u32_var %device %acquire_release %u32_1 +%val4 = OpAtomicISub %u32 %u32_var %device %acquire_release %u32_1 +%val5 = OpAtomicUMin %u32 %u32_var %device %acquire_release %u32_1 +%val6 = OpAtomicUMax %u32 %u32_var %device %acquire_release %u32_1 +%val7 = OpAtomicSMin %u32 %u32_var %device %sequentially_consistent %u32_1 +%val8 = OpAtomicSMax %u32 %u32_var %device %sequentially_consistent %u32_1 +%val9 = OpAtomicAnd %u32 %u32_var %device %sequentially_consistent %u32_1 +%val10 = OpAtomicOr %u32 %u32_var %device %sequentially_consistent %u32_1 +%val11 = OpAtomicXor %u32 %u32_var %device %sequentially_consistent %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicFlagsSuccess) { + const std::string body = R"( +OpAtomicFlagClear %u32_var %device %release +%val1 = OpAtomicFlagTestAndSet %bool %u32_var %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicFlagTestAndSetWrongResultType) { + const std::string body = R"( +%val1 = OpAtomicFlagTestAndSet %u32 %u32_var %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicFlagTestAndSet: " + "expected Result Type to be bool scalar type")); +} + 
+TEST_F(ValidateAtomics, AtomicFlagTestAndSetNotPointer) { + const std::string body = R"( +%val1 = OpAtomicFlagTestAndSet %bool %u32_1 %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicFlagTestAndSet: " + "expected Pointer to be of type OpTypePointer")); +} + +TEST_F(ValidateAtomics, AtomicFlagTestAndSetNotIntPointer) { + const std::string body = R"( +%val1 = OpAtomicFlagTestAndSet %bool %f32_var %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicFlagTestAndSet: " + "expected Pointer to point to a value of 32-bit int type")); +} + +TEST_F(ValidateAtomics, AtomicFlagTestAndSetNotInt32Pointer) { + const std::string body = R"( +%val1 = OpAtomicFlagTestAndSet %bool %u64_var %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicFlagTestAndSet: " + "expected Pointer to point to a value of 32-bit int type")); +} + +TEST_F(ValidateAtomics, AtomicFlagTestAndSetWrongScopeType) { + const std::string body = R"( +%val1 = OpAtomicFlagTestAndSet %bool %u32_var %u64_1 %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "AtomicFlagTestAndSet: expected Memory Scope to be a 32-bit int\n " + "%40 = OpAtomicFlagTestAndSet %bool %30 %ulong_1 %uint_0_1\n")); +} + +TEST_F(ValidateAtomics, AtomicFlagTestAndSetWrongMemorySemanticsType) { + const std::string body = R"( +%val1 = OpAtomicFlagTestAndSet %bool %u32_var %device %u64_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicFlagTestAndSet: " + "expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateAtomics, AtomicFlagClearAcquire) { + const std::string body = R"( +OpAtomicFlagClear %u32_var %device %acquire +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Memory Semantics Acquire and AcquireRelease cannot be " + "used with AtomicFlagClear")); +} + +TEST_F(ValidateAtomics, AtomicFlagClearNotPointer) { + const std::string body = R"( +OpAtomicFlagClear %u32_1 %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicFlagClear: " + "expected Pointer to be of type OpTypePointer")); +} + +TEST_F(ValidateAtomics, AtomicFlagClearNotIntPointer) { + const std::string body = R"( +OpAtomicFlagClear %f32_var %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicFlagClear: " + "expected Pointer to point to a value of 32-bit int type")); +} + +TEST_F(ValidateAtomics, AtomicFlagClearNotInt32Pointer) { + const std::string body = R"( +OpAtomicFlagClear %u64_var %device %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicFlagClear: " + "expected Pointer to point to a value of 32-bit int type")); +} + +TEST_F(ValidateAtomics, AtomicFlagClearWrongScopeType) { + const std::string body = R"( +OpAtomicFlagClear %u32_var %u64_1 %relaxed +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicFlagClear: expected 
Memory Scope to be a 32-bit " + "int\n OpAtomicFlagClear %30 %ulong_1 %uint_0_1\n")); +} + +TEST_F(ValidateAtomics, AtomicFlagClearWrongMemorySemanticsType) { + const std::string body = R"( +OpAtomicFlagClear %u32_var %device %u64_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "AtomicFlagClear: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateAtomics, AtomicIIncrementAcquireAndRelease) { + const std::string body = R"( +OpAtomicStore %u32_var %device %relaxed %u32_1 +%val1 = OpAtomicIIncrement %u32 %u32_var %device %acquire_and_release +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicIIncrement: Memory Semantics can have at most " + "one of the following bits set: Acquire, Release, " + "AcquireRelease or SequentiallyConsistent\n %40 = " + "OpAtomicIIncrement %uint %30 %uint_1_0 %uint_6\n")); +} + +TEST_F(ValidateAtomics, AtomicUniformMemorySemanticsShader) { + const std::string body = R"( +OpAtomicStore %u32_var %device %relaxed %u32_1 +%val1 = OpAtomicIIncrement %u32 %u32_var %device %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicUniformMemorySemanticsKernel) { + const std::string body = R"( +OpAtomicStore %u32_var %device %relaxed %u32_1 +%val1 = OpAtomicIIncrement %u32 %u32_var %device %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicIIncrement: Memory Semantics UniformMemory " + "requires capability Shader")); +} + +// Lack of the AtomicStorage capability is intentionally ignored, see +// 
https://github.com/KhronosGroup/glslang/issues/1618 for the reasoning why. +TEST_F(ValidateAtomics, AtomicCounterMemorySemanticsNoCapability) { + const std::string body = R"( + OpAtomicStore %u32_var %device %relaxed %u32_1 +%val1 = OpAtomicIIncrement %u32 %u32_var %device +%acquire_release_atomic_counter_workgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, AtomicCounterMemorySemanticsWithCapability) { + const std::string body = R"( +OpAtomicStore %u32_var %device %relaxed %u32_1 +%val1 = OpAtomicIIncrement %u32 %u32_var %device %acquire_release_atomic_counter_workgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body, "OpCapability AtomicStorage\n")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicLoad) { + const std::string body = R"( +%ld = OpAtomicLoad %u32 %u32_var %workgroup %sequentially_consistent +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicStore) { + const std::string body = R"( +OpAtomicStore %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, + VulkanMemoryModelBanSequentiallyConsistentAtomicExchange) { + const std::string body = R"( +%ex = OpAtomicExchange %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, + VulkanMemoryModelBanSequentiallyConsistentAtomicCompareExchangeEqual) { + const std::string body = R"( +%ex = OpAtomicCompareExchange %u32 %u32_var %workgroup %sequentially_consistent %relaxed %u32_0 %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, + VulkanMemoryModelBanSequentiallyConsistentAtomicCompareExchangeUnequal) { + const std::string body = R"( +%ex = OpAtomicCompareExchange %u32 %u32_var %workgroup %relaxed %sequentially_consistent %u32_0 %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, + VulkanMemoryModelBanSequentiallyConsistentAtomicIIncrement) { + const std::string body = R"( +%inc = OpAtomicIIncrement %u32 %u32_var %workgroup %sequentially_consistent +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, + VulkanMemoryModelBanSequentiallyConsistentAtomicIDecrement) { + const std::string body = R"( +%dec = OpAtomicIDecrement %u32 %u32_var %workgroup %sequentially_consistent +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicIAdd) { + const std::string body = R"( +%add = OpAtomicIAdd %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicISub) { + const std::string body = R"( +%sub = OpAtomicISub %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicSMin) { + const std::string body = R"( +%min = OpAtomicSMin %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicUMin) { + const std::string body = R"( +%min = OpAtomicUMin %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + 
"used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicSMax) { + const std::string body = R"( +%max = OpAtomicSMax %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicUMax) { + const std::string body = R"( +%max = OpAtomicUMax %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicAnd) { + const std::string body = R"( +%and = OpAtomicAnd %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + 
+TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicOr) { + const std::string body = R"( +%or = OpAtomicOr %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelBanSequentiallyConsistentAtomicXor) { + const std::string body = R"( +%xor = OpAtomicXor %u32 %u32_var %workgroup %sequentially_consistent %u32_0 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateAtomics, OutputMemoryKHRRequiresVulkanMemoryModelKHR) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 4100 +%5 = OpTypeFunction %2 +%workgroup = OpConstant %3 2 +%ptr = OpTypePointer Workgroup %3 +%var = OpVariable %ptr Workgroup +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpAtomicStore %var %workgroup %semantics %workgroup +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicStore: Memory Semantics 
OutputMemoryKHR " + "requires capability VulkanMemoryModelKHR")); +} + +TEST_F(ValidateAtomics, MakeAvailableKHRRequiresVulkanMemoryModelKHR) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 8196 +%5 = OpTypeFunction %2 +%workgroup = OpConstant %3 2 +%ptr = OpTypePointer Workgroup %3 +%var = OpVariable %ptr Workgroup +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpAtomicStore %var %workgroup %semantics %workgroup +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicStore: Memory Semantics MakeAvailableKHR " + "requires capability VulkanMemoryModelKHR")); +} + +TEST_F(ValidateAtomics, MakeVisibleKHRRequiresVulkanMemoryModelKHR) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 16386 +%5 = OpTypeFunction %2 +%workgroup = OpConstant %3 2 +%ptr = OpTypePointer Workgroup %3 +%var = OpVariable %ptr Workgroup +%1 = OpFunction %2 None %5 +%7 = OpLabel +%ld = OpAtomicLoad %3 %var %workgroup %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicLoad: Memory Semantics MakeVisibleKHR requires " + "capability VulkanMemoryModelKHR")); +} + +TEST_F(ValidateAtomics, MakeAvailableKHRRequiresReleaseSemantics) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = 
OpConstant %3 8448 +%5 = OpTypeFunction %2 +%workgroup = OpConstant %3 2 +%ptr = OpTypePointer Workgroup %3 +%var = OpVariable %ptr Workgroup +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpAtomicStore %var %workgroup %semantics %workgroup +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicStore: MakeAvailableKHR Memory Semantics also requires " + "either Release or AcquireRelease Memory Semantics")); +} + +TEST_F(ValidateAtomics, MakeVisibleKHRRequiresAcquireSemantics) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 16640 +%5 = OpTypeFunction %2 +%workgroup = OpConstant %3 2 +%ptr = OpTypePointer Workgroup %3 +%var = OpVariable %ptr Workgroup +%1 = OpFunction %2 None %5 +%7 = OpLabel +%ld = OpAtomicLoad %3 %var %workgroup %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicLoad: MakeVisibleKHR Memory Semantics also requires " + "either Acquire or AcquireRelease Memory Semantics")); +} + +TEST_F(ValidateAtomics, MakeAvailableKHRRequiresStorageSemantics) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 8196 +%5 = OpTypeFunction %2 +%workgroup = OpConstant %3 2 +%ptr = OpTypePointer Workgroup %3 +%var = OpVariable %ptr 
Workgroup +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpAtomicStore %var %workgroup %semantics %workgroup +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "AtomicStore: expected Memory Semantics to include a storage class")); +} + +TEST_F(ValidateAtomics, MakeVisibleKHRRequiresStorageSemantics) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 16386 +%5 = OpTypeFunction %2 +%workgroup = OpConstant %3 2 +%ptr = OpTypePointer Workgroup %3 +%var = OpVariable %ptr Workgroup +%1 = OpFunction %2 None %5 +%7 = OpLabel +%ld = OpAtomicLoad %3 %var %workgroup %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "AtomicLoad: expected Memory Semantics to include a storage class")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelAllowsQueueFamilyKHR) { + const std::string body = R"( +%val = OpAtomicAnd %u32 %u32_var %queuefamily %relaxed %u32_1 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateAtomics, NonVulkanMemoryModelDisallowsQueueFamilyKHR) { + const std::string body = R"( +%val = OpAtomicAnd %u32 %u32_var %queuefamily %relaxed %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_1); + 
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("AtomicAnd: Memory Scope QueueFamilyKHR requires " + "capability VulkanMemoryModelKHR\n %42 = OpAtomicAnd " + "%uint %29 %uint_5 %uint_0_1 %uint_1\n")); +} + +TEST_F(ValidateAtomics, SemanticsSpecConstantShader) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%spec_const = OpSpecConstant %int 0 +%workgroup = OpConstant %int 2 +%ptr_int_workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_workgroup Workgroup +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%ld = OpAtomicLoad %int %var %workgroup %spec_const +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Memory Semantics ids must be OpConstant when Shader " + "capability is present")); +} + +TEST_F(ValidateAtomics, SemanticsSpecConstantKernel) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%spec_const = OpSpecConstant %int 0 +%workgroup = OpConstant %int 2 +%ptr_int_workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_workgroup Workgroup +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%ld = OpAtomicLoad %int %var %workgroup %spec_const +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, ScopeSpecConstantShader) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 
+%spec_const = OpSpecConstant %int 0 +%relaxed = OpConstant %int 0 +%ptr_int_workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_workgroup Workgroup +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%ld = OpAtomicLoad %int %var %spec_const %relaxed +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Scope ids must be OpConstant when Shader capability is present")); +} + +TEST_F(ValidateAtomics, ScopeSpecConstantKernel) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%spec_const = OpSpecConstant %int 0 +%relaxed = OpConstant %int 0 +%ptr_int_workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_workgroup Workgroup +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%ld = OpAtomicLoad %int %var %spec_const %relaxed +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelDeviceScopeBad) { + const std::string body = R"( +%val = OpAtomicAnd %u32 %u32_var %device %relaxed %u32_1 +)"; + + const std::string extra = R"(OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateAtomics, VulkanMemoryModelDeviceScopeGood) { + const std::string body = R"( +%val = OpAtomicAnd %u32 %u32_var %device %relaxed %u32_1 +)"; + + const std::string extra = 
R"(OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra, "VulkanKHR"), + SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateAtomics, WebGPUCrossDeviceMemoryScopeBad) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %cross_device %relaxed +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicLoad: in WebGPU environment Memory Scope is limited to " + "Workgroup, Subgroup and QueuFamilyKHR\n" + " %34 = OpAtomicLoad %uint %29 %uint_0_0 %uint_0_1\n")); +} + +TEST_F(ValidateAtomics, WebGPUDeviceMemoryScopeBad) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %device %relaxed +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicLoad: in WebGPU environment Memory Scope is limited to " + "Workgroup, Subgroup and QueuFamilyKHR\n" + " %34 = OpAtomicLoad %uint %29 %uint_1_0 %uint_0_1\n")); +} + +TEST_F(ValidateAtomics, WebGPUWorkgroupMemoryScopeGood) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %workgroup %relaxed +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateAtomics, WebGPUSubgroupMemoryScopeGood) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %subgroup %relaxed +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateAtomics, 
WebGPUInvocationMemoryScopeBad) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %invocation %relaxed +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("AtomicLoad: in WebGPU environment Memory Scope is limited to " + "Workgroup, Subgroup and QueuFamilyKHR\n" + " %34 = OpAtomicLoad %uint %29 %uint_4 %uint_0_1\n")); +} + +TEST_F(ValidateAtomics, WebGPUQueueFamilyMemoryScopeGood) { + const std::string body = R"( +%val1 = OpAtomicLoad %u32 %u32_var %queuefamily %relaxed +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_barriers_test.cpp b/test/val/val_barriers_test.cpp new file mode 100644 index 000000000..264d130bd --- /dev/null +++ b/test/val/val_barriers_test.cpp @@ -0,0 +1,1284 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateBarriers = spvtest::ValidateBase; + +std::string GenerateShaderCodeImpl( + const std::string& body, const std::string& capabilities_and_extensions, + const std::string& definitions, const std::string& execution_model, + const std::string& memory_model) { + std::ostringstream ss; + ss << R"( +OpCapability Shader +)"; + + ss << capabilities_and_extensions; + ss << memory_model << std::endl; + ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + if (execution_model == "Fragment") { + ss << "OpExecutionMode %main OriginUpperLeft\n"; + } else if (execution_model == "Geometry") { + ss << "OpExecutionMode %main InputPoints\n"; + ss << "OpExecutionMode %main OutputPoints\n"; + } else if (execution_model == "GLCompute") { + ss << "OpExecutionMode %main LocalSize 1 1 1\n"; + } + + ss << R"( +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_4 = OpConstant %u32 4 +)"; + ss << definitions; + ss << R"( +%cross_device = OpConstant %u32 0 +%device = OpConstant %u32 1 +%workgroup = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 +%invocation = OpConstant %u32 4 +%queuefamily = OpConstant %u32 5 + +%none = OpConstant %u32 0 +%acquire = OpConstant %u32 2 +%release = OpConstant %u32 4 +%acquire_release = OpConstant %u32 8 +%acquire_and_release = OpConstant %u32 6 +%sequentially_consistent = OpConstant %u32 16 +%acquire_release_uniform_workgroup = OpConstant %u32 328 +%acquire_and_release_uniform = OpConstant %u32 70 +%acquire_release_subgroup = OpConstant %u32 136 +%uniform = OpConstant %u32 64 + +%main = OpFunction %void None %func +%main_entry = 
OpLabel +)"; + + ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& execution_model = "GLCompute") { + const std::string int64_capability = R"( +OpCapability Int64 +)"; + const std::string int64_declarations = R"( +%u64 = OpTypeInt 64 0 +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant %u64 1 +)"; + const std::string memory_model = "OpMemoryModel Logical GLSL450"; + return GenerateShaderCodeImpl( + body, int64_capability + capabilities_and_extensions, int64_declarations, + execution_model, memory_model); +} + +std::string GenerateWebGPUShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& execution_model = "GLCompute") { + const std::string vulkan_memory_capability = R"( +OpCapability VulkanMemoryModelKHR +)"; + const std::string vulkan_memory_extension = R"( +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + const std::string memory_model = "OpMemoryModel Logical VulkanKHR"; + return GenerateShaderCodeImpl(body, + vulkan_memory_capability + + capabilities_and_extensions + + vulkan_memory_extension, + "", execution_model, memory_model); +} + +std::string GenerateKernelCode( + const std::string& body, + const std::string& capabilities_and_extensions = "") { + std::ostringstream ss; + ss << R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpCapability Int64 +OpCapability NamedBarrier +)"; + + ss << capabilities_and_extensions; + ss << R"( +OpMemoryModel Physical32 OpenCL +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%u64 = OpTypeInt 64 0 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_4 = OpConstant %f32 4 +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_4 = OpConstant %u32 4 +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant 
%u64 1 +%u64_4 = OpConstant %u64 4 + +%cross_device = OpConstant %u32 0 +%device = OpConstant %u32 1 +%workgroup = OpConstant %u32 2 +%subgroup = OpConstant %u32 3 +%invocation = OpConstant %u32 4 + +%none = OpConstant %u32 0 +%acquire = OpConstant %u32 2 +%release = OpConstant %u32 4 +%acquire_release = OpConstant %u32 8 +%acquire_and_release = OpConstant %u32 6 +%sequentially_consistent = OpConstant %u32 16 +%acquire_release_workgroup = OpConstant %u32 264 + +%named_barrier = OpTypeNamedBarrier + +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + + ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +TEST_F(ValidateBarriers, OpControlBarrierGLComputeSuccess) { + const std::string body = R"( +OpControlBarrier %device %device %none +OpControlBarrier %workgroup %workgroup %acquire +OpControlBarrier %workgroup %device %release +OpControlBarrier %cross_device %cross_device %acquire_release +OpControlBarrier %cross_device %cross_device %sequentially_consistent +OpControlBarrier %cross_device %cross_device %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBarriers, OpControlBarrierKernelSuccess) { + const std::string body = R"( +OpControlBarrier %device %device %none +OpControlBarrier %workgroup %workgroup %acquire +OpControlBarrier %workgroup %device %release +OpControlBarrier %cross_device %cross_device %acquire_release +OpControlBarrier %cross_device %cross_device %sequentially_consistent +OpControlBarrier %cross_device %cross_device %acquire_release_workgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierTesselationControlSuccess) { + const std::string body = R"( +OpControlBarrier %device %device %none +OpControlBarrier %workgroup %workgroup %acquire 
+OpControlBarrier %workgroup %device %release +OpControlBarrier %cross_device %cross_device %acquire_release +OpControlBarrier %cross_device %cross_device %sequentially_consistent +OpControlBarrier %cross_device %cross_device %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body, "OpCapability Tessellation\n", + "TessellationControl")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBarriers, OpControlBarrierVulkanSuccess) { + const std::string body = R"( +OpControlBarrier %workgroup %device %none +OpControlBarrier %workgroup %workgroup %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateBarriers, OpControlBarrierWebGPUSuccess) { + const std::string body = R"( +OpControlBarrier %workgroup %queuefamily %none +OpControlBarrier %workgroup %workgroup %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateBarriers, OpControlBarrierExecutionModelFragmentSpirv12) { + const std::string body = R"( +OpControlBarrier %device %device %none +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment"), + SPV_ENV_UNIVERSAL_1_2); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + +TEST_F(ValidateBarriers, OpControlBarrierExecutionModelFragmentSpirv13) { + const std::string body = R"( +OpControlBarrier %device %device %none +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment"), + SPV_ENV_UNIVERSAL_1_3); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateBarriers, 
OpControlBarrierFloatExecutionScope) { + const std::string body = R"( +OpControlBarrier %f32_1 %device %none +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ControlBarrier: expected Execution Scope to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpControlBarrierU64ExecutionScope) { + const std::string body = R"( +OpControlBarrier %u64_1 %device %none +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ControlBarrier: expected Execution Scope to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpControlBarrierFloatMemoryScope) { + const std::string body = R"( +OpControlBarrier %device %f32_1 %none +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ControlBarrier: expected Memory Scope to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpControlBarrierU64MemoryScope) { + const std::string body = R"( +OpControlBarrier %device %u64_1 %none +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ControlBarrier: expected Memory Scope to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpControlBarrierFloatMemorySemantics) { + const std::string body = R"( +OpControlBarrier %device %device %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "ControlBarrier: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpControlBarrierU64MemorySemantics) { + const std::string body = R"( +OpControlBarrier %device %device %u64_0 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "ControlBarrier: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpControlBarrierVulkanExecutionScopeDevice) { + const std::string body = R"( +OpControlBarrier %device %workgroup %none +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ControlBarrier: in Vulkan environment Execution Scope " + "is limited to Workgroup and Subgroup")); +} + +TEST_F(ValidateBarriers, OpControlBarrierWebGPUExecutionScopeDevice) { + const std::string body = R"( +OpControlBarrier %device %workgroup %none +)"; + + CompileSuccessfully(GenerateWebGPUShaderCode(body), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ControlBarrier: in WebGPU environment Execution Scope " + "is limited to Workgroup and Subgroup")); +} + +TEST_F(ValidateBarriers, OpControlBarrierVulkanMemoryScopeSubgroup) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %none +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ControlBarrier: in Vulkan 1.0 environment Memory Scope is " + "limited to Device, Workgroup and Invocation")); +} + +TEST_F(ValidateBarriers, OpControlBarrierVulkan1p1MemoryScopeSubgroup) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %none +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierVulkan1p1MemoryScopeCrossDevice) { + 
const std::string body = R"( +OpControlBarrier %subgroup %cross_device %none +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ControlBarrier: in Vulkan environment, Memory Scope " + "cannot be CrossDevice")); +} + +TEST_F(ValidateBarriers, OpControlBarrierAcquireAndRelease) { + const std::string body = R"( +OpControlBarrier %device %device %acquire_and_release_uniform +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ControlBarrier: Memory Semantics can have at most one " + "of the following bits set: Acquire, Release, " + "AcquireRelease or SequentiallyConsistent")); +} + +// TODO(atgoo@github.com): the corresponding check fails Vulkan CTS, +// reenable once fixed. +TEST_F(ValidateBarriers, DISABLED_OpControlBarrierVulkanSubgroupStorageClass) { + const std::string body = R"( +OpControlBarrier %workgroup %device %acquire_release_subgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "ControlBarrier: expected Memory Semantics to include a " + "Vulkan-supported storage class if Memory Semantics is not None")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionFragment1p1) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %acquire_release_subgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierWorkgroupExecutionFragment1p1) { + const std::string body = R"( +OpControlBarrier %workgroup %workgroup %acquire_release +)"; + + 
CompileSuccessfully(GenerateShaderCode(body, "", "Fragment"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpControlBarrier execution scope must be Subgroup for " + "Fragment, Vertex, Geometry and TessellationEvaluation " + "execution models")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionFragment1p0) { + const std::string body = R"( +OpControlBarrier %subgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment"), + SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionVertex1p1) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %acquire_release_subgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierWorkgroupExecutionVertex1p1) { + const std::string body = R"( +OpControlBarrier %workgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpControlBarrier execution scope must be Subgroup for " + "Fragment, Vertex, Geometry and TessellationEvaluation " + "execution models")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionVertex1p0) { + const std::string body = R"( +OpControlBarrier %subgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex"), + SPV_ENV_VULKAN_1_0); + 
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionGeometry1p1) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %acquire_release_subgroup +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability Geometry\n", "Geometry"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpControlBarrierWorkgroupExecutionGeometry1p1) { + const std::string body = R"( +OpControlBarrier %workgroup %workgroup %acquire_release +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability Geometry\n", "Geometry"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpControlBarrier execution scope must be Subgroup for " + "Fragment, Vertex, Geometry and TessellationEvaluation " + "execution models")); +} + +TEST_F(ValidateBarriers, OpControlBarrierSubgroupExecutionGeometry1p0) { + const std::string body = R"( +OpControlBarrier %subgroup %workgroup %acquire_release +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability Geometry\n", "Geometry"), + SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + +TEST_F(ValidateBarriers, + OpControlBarrierSubgroupExecutionTessellationEvaluation1p1) { + const std::string body = R"( +OpControlBarrier %subgroup %subgroup %acquire_release_subgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body, "OpCapability Tessellation\n", + "TessellationEvaluation"), + 
SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, + OpControlBarrierWorkgroupExecutionTessellationEvaluation1p1) { + const std::string body = R"( +OpControlBarrier %workgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "OpCapability Tessellation\n", + "TessellationEvaluation"), + SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpControlBarrier execution scope must be Subgroup for " + "Fragment, Vertex, Geometry and TessellationEvaluation " + "execution models")); +} + +TEST_F(ValidateBarriers, + OpControlBarrierSubgroupExecutionTessellationEvaluation1p0) { + const std::string body = R"( +OpControlBarrier %subgroup %workgroup %acquire_release +)"; + + CompileSuccessfully(GenerateShaderCode(body, "OpCapability Tessellation\n", + "TessellationEvaluation"), + SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpControlBarrier requires one of the following Execution " + "Models: TessellationControl, GLCompute or Kernel")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierSuccess) { + const std::string body = R"( +OpMemoryBarrier %cross_device %acquire_release_uniform_workgroup +OpMemoryBarrier %device %uniform +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierKernelSuccess) { + const std::string body = R"( +OpMemoryBarrier %cross_device %acquire_release_workgroup +OpMemoryBarrier %device %none +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierVulkanSuccess) { + const std::string body = R"( +OpMemoryBarrier %workgroup 
%acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierFloatMemoryScope) { + const std::string body = R"( +OpMemoryBarrier %f32_1 %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Scope to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierU64MemoryScope) { + const std::string body = R"( +OpMemoryBarrier %u64_1 %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Scope to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierFloatMemorySemantics) { + const std::string body = R"( +OpMemoryBarrier %device %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierU64MemorySemantics) { + const std::string body = R"( +OpMemoryBarrier %device %u64_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierVulkanMemoryScopeSubgroup) { + const std::string body = R"( +OpMemoryBarrier %subgroup %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, 
ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("MemoryBarrier: in Vulkan 1.0 environment Memory Scope is " + "limited to Device, Workgroup and Invocation")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierVulkan1p1MemoryScopeSubgroup) { + const std::string body = R"( +OpMemoryBarrier %subgroup %acquire_release_uniform_workgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierAcquireAndRelease) { + const std::string body = R"( +OpMemoryBarrier %device %acquire_and_release_uniform +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MemoryBarrier: Memory Semantics can have at most one " + "of the following bits set: Acquire, Release, " + "AcquireRelease or SequentiallyConsistent")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierVulkanMemorySemanticsNone) { + const std::string body = R"( +OpMemoryBarrier %device %none +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("MemoryBarrier: Vulkan specification requires Memory Semantics " + "to have one of the following bits set: Acquire, Release, " + "AcquireRelease or SequentiallyConsistent")); +} + +TEST_F(ValidateBarriers, OpMemoryBarrierVulkanMemorySemanticsAcquire) { + const std::string body = R"( +OpMemoryBarrier %device %acquire +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Semantics to include a " + "Vulkan-supported storage class")); +} + +TEST_F(ValidateBarriers, 
OpMemoryBarrierVulkanSubgroupStorageClass) { + const std::string body = R"( +OpMemoryBarrier %device %acquire_release_subgroup +)"; + + CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Semantics to include a " + "Vulkan-supported storage class")); +} + +TEST_F(ValidateBarriers, OpNamedBarrierInitializeSuccess) { + const std::string body = R"( +%barrier = OpNamedBarrierInitialize %named_barrier %u32_4 +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); +} + +TEST_F(ValidateBarriers, OpNamedBarrierInitializeWrongResultType) { + const std::string body = R"( +%barrier = OpNamedBarrierInitialize %u32 %u32_4 +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NamedBarrierInitialize: expected Result Type to be " + "OpTypeNamedBarrier")); +} + +TEST_F(ValidateBarriers, OpNamedBarrierInitializeFloatSubgroupCount) { + const std::string body = R"( +%barrier = OpNamedBarrierInitialize %named_barrier %f32_4 +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NamedBarrierInitialize: expected Subgroup Count to be " + "a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpNamedBarrierInitializeU64SubgroupCount) { + const std::string body = R"( +%barrier = OpNamedBarrierInitialize %named_barrier %u64_4 +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), 
+ HasSubstr("NamedBarrierInitialize: expected Subgroup Count to be " + "a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpMemoryNamedBarrierSuccess) { + const std::string body = R"( +%barrier = OpNamedBarrierInitialize %named_barrier %u32_4 +OpMemoryNamedBarrier %barrier %workgroup %acquire_release_workgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); +} + +TEST_F(ValidateBarriers, OpMemoryNamedBarrierNotNamedBarrier) { + const std::string body = R"( +OpMemoryNamedBarrier %u32_1 %workgroup %acquire_release_workgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MemoryNamedBarrier: expected Named Barrier to be of " + "type OpTypeNamedBarrier")); +} + +TEST_F(ValidateBarriers, OpMemoryNamedBarrierFloatMemoryScope) { + const std::string body = R"( +%barrier = OpNamedBarrierInitialize %named_barrier %u32_4 +OpMemoryNamedBarrier %barrier %f32_1 %acquire_release_workgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "MemoryNamedBarrier: expected Memory Scope to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, OpMemoryNamedBarrierFloatMemorySemantics) { + const std::string body = R"( +%barrier = OpNamedBarrierInitialize %named_barrier %u32_4 +OpMemoryNamedBarrier %barrier %workgroup %f32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "MemoryNamedBarrier: expected Memory Semantics to be a 32-bit int")); +} + +TEST_F(ValidateBarriers, 
OpMemoryNamedBarrierAcquireAndRelease) { + const std::string body = R"( +%barrier = OpNamedBarrierInitialize %named_barrier %u32_4 +OpMemoryNamedBarrier %barrier %workgroup %acquire_and_release +)"; + + CompileSuccessfully(GenerateKernelCode(body), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MemoryNamedBarrier: Memory Semantics can have at most " + "one of the following bits set: Acquire, Release, " + "AcquireRelease or SequentiallyConsistent")); +} + +TEST_F(ValidateBarriers, TypeAsMemoryScope) { + const std::string body = R"( +OpMemoryBarrier %u32 %u32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 5[%uint] cannot be a " + "type")); +} + +TEST_F(ValidateBarriers, + OpControlBarrierVulkanMemoryModelBanSequentiallyConsistent) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 16 +%5 = OpTypeFunction %2 +%6 = OpConstant %3 5 +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpControlBarrier %6 %6 %4 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateBarriers, + OpMemoryBarrierVulkanMemoryModelBanSequentiallyConsistent) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %1 "func" 
+OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 16 +%5 = OpTypeFunction %2 +%6 = OpConstant %3 5 +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpMemoryBarrier %6 %4 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("SequentiallyConsistent memory semantics cannot be " + "used with the VulkanKHR memory model.")); +} + +TEST_F(ValidateBarriers, OutputMemoryKHRRequireVulkanMemoryModelKHR) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 4104 +%5 = OpTypeFunction %2 +%device = OpConstant %3 1 +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpControlBarrier %device %device %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ControlBarrier: Memory Semantics OutputMemoryKHR " + "requires capability VulkanMemoryModelKHR")); +} + +TEST_F(ValidateBarriers, MakeAvailableKHRRequireVulkanMemoryModelKHR) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 8264 +%5 = OpTypeFunction %2 +%device = OpConstant %3 1 +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpControlBarrier %device %device %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ControlBarrier: Memory Semantics MakeAvailableKHR " + "requires capability VulkanMemoryModelKHR")); +} + +TEST_F(ValidateBarriers, 
MakeVisibleKHRRequireVulkanMemoryModelKHR) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%semantics = OpConstant %3 16456 +%5 = OpTypeFunction %2 +%device = OpConstant %3 1 +%1 = OpFunction %2 None %5 +%7 = OpLabel +OpControlBarrier %device %device %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ControlBarrier: Memory Semantics MakeVisibleKHR " + "requires capability VulkanMemoryModelKHR")); +} + +TEST_F(ValidateBarriers, MakeAvailableKHRRequiresReleaseSemantics) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%workgroup = OpConstant %int 2 +%semantics = OpConstant %int 8448 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpControlBarrier %workgroup %workgroup %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ControlBarrier: MakeAvailableKHR Memory Semantics also " + "requires either Release or AcquireRelease Memory Semantics")); +} + +TEST_F(ValidateBarriers, MakeVisibleKHRRequiresAcquireSemantics) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%workgroup = OpConstant %int 2 +%semantics = OpConstant %int 
16640 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpControlBarrier %workgroup %workgroup %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ControlBarrier: MakeVisibleKHR Memory Semantics also requires " + "either Acquire or AcquireRelease Memory Semantics")); +} + +TEST_F(ValidateBarriers, MakeAvailableKHRRequiresStorageSemantics) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%workgroup = OpConstant %int 2 +%semantics = OpConstant %int 8196 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpMemoryBarrier %workgroup %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Semantics to include a " + "storage class")); +} + +TEST_F(ValidateBarriers, MakeVisibleKHRRequiresStorageSemantics) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%workgroup = OpConstant %int 2 +%semantics = OpConstant %int 16386 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpMemoryBarrier %workgroup %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + 
ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MemoryBarrier: expected Memory Semantics to include a " + "storage class")); +} + +TEST_F(ValidateBarriers, SemanticsSpecConstantShader) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_workgroup Workgroup +%voidfn = OpTypeFunction %void +%spec_const = OpSpecConstant %int 0 +%workgroup = OpConstant %int 2 +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpMemoryBarrier %workgroup %spec_const +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Memory Semantics ids must be OpConstant when Shader " + "capability is present")); +} + +TEST_F(ValidateBarriers, SemanticsSpecConstantKernel) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_workgroup Workgroup +%voidfn = OpTypeFunction %void +%spec_const = OpSpecConstant %int 0 +%workgroup = OpConstant %int 2 +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpMemoryBarrier %workgroup %spec_const +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBarriers, ScopeSpecConstantShader) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_workgroup Workgroup +%voidfn = OpTypeFunction %void 
+%spec_const = OpSpecConstant %int 0 +%relaxed = OpConstant %int 0 +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpMemoryBarrier %spec_const %relaxed +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Scope ids must be OpConstant when Shader " + "capability is present")); +} + +TEST_F(ValidateBarriers, ScopeSpecConstantKernel) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_workgroup Workgroup +%voidfn = OpTypeFunction %void +%spec_const = OpSpecConstant %int 0 +%relaxed = OpConstant %int 0 +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpMemoryBarrier %spec_const %relaxed +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBarriers, VulkanMemoryModelDeviceScopeBad) { + const std::string text = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%semantics = OpConstant %int 0 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpMemoryBarrier %device %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateBarriers, VulkanMemoryModelDeviceScopeGood) { + const std::string text = R"( +OpCapability Shader 
+OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%semantics = OpConstant %int 0 +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpMemoryBarrier %device %semantics +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_bitwise_test.cpp b/test/val/val_bitwise_test.cpp new file mode 100644 index 000000000..1001def8f --- /dev/null +++ b/test/val/val_bitwise_test.cpp @@ -0,0 +1,549 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for the validation rules of bitwise instructions.
+ +#include <string> + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateBitwise = spvtest::ValidateBase<bool>; + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "") { + const std::string capabilities = + R"( +OpCapability Shader +OpCapability Int64 +OpCapability Float64)"; + + const std::string after_extension_before_body = + R"( +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%f64 = OpTypeFloat 64 +%u64 = OpTypeInt 64 0 +%s64 = OpTypeInt 64 1 +%boolvec2 = OpTypeVector %bool 2 +%s32vec2 = OpTypeVector %s32 2 +%u32vec2 = OpTypeVector %u32 2 +%u64vec2 = OpTypeVector %u64 2 +%f32vec2 = OpTypeVector %f32 2 +%f64vec2 = OpTypeVector %f64 2 +%boolvec3 = OpTypeVector %bool 3 +%u32vec3 = OpTypeVector %u32 3 +%u64vec3 = OpTypeVector %u64 3 +%s32vec3 = OpTypeVector %s32 3 +%f32vec3 = OpTypeVector %f32 3 +%f64vec3 = OpTypeVector %f64 3 +%boolvec4 = OpTypeVector %bool 4 +%u32vec4 = OpTypeVector %u32 4 +%u64vec4 = OpTypeVector %u64 4 +%s32vec4 = OpTypeVector %s32 4 +%f32vec4 = OpTypeVector %f32 4 +%f64vec4 = OpTypeVector %f64 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +%f32_3 = OpConstant %f32 3 +%f32_4 = OpConstant %f32 4 + +%s32_0 = OpConstant %s32 0 +%s32_1 = OpConstant %s32 1 +%s32_2 = OpConstant %s32 2 +%s32_3 = OpConstant %s32 3 +%s32_4 = OpConstant %s32 4 +%s32_m1 = OpConstant %s32 -1 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%f64_0 = OpConstant %f64 0 +%f64_1 = OpConstant %f64 1 +%f64_2 = OpConstant %f64 2 +%f64_3 =
OpConstant %f64 3 +%f64_4 = OpConstant %f64 4 + +%s64_0 = OpConstant %s64 0 +%s64_1 = OpConstant %s64 1 +%s64_2 = OpConstant %s64 2 +%s64_3 = OpConstant %s64 3 +%s64_4 = OpConstant %s64 4 +%s64_m1 = OpConstant %s64 -1 + +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant %u64 1 +%u64_2 = OpConstant %u64 2 +%u64_3 = OpConstant %u64 3 +%u64_4 = OpConstant %u64 4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + +%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1 +%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2 +%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2 +%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3 +%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3 +%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4 + +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2 +%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4 + +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string after_body = + R"( +OpReturn +OpFunctionEnd)"; + + return capabilities + capabilities_and_extensions + + after_extension_before_body + body + after_body; +} + +TEST_F(ValidateBitwise, ShiftAllSuccess) { + const std::string body = R"( +%val1 = OpShiftRightLogical %u64 %u64_1 %s32_2 +%val2 = OpShiftRightArithmetic %s32vec2 %s32vec2_12 %s32vec2_12 +%val3 = 
OpShiftLeftLogical %u32vec2 %s32vec2_12 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBitwise, OpShiftRightLogicalWrongResultType) { + const std::string body = R"( +%val1 = OpShiftRightLogical %bool %u64_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar or vector type as Result Type: " + "ShiftRightLogical")); +} + +TEST_F(ValidateBitwise, OpShiftRightLogicalBaseNotInt) { + const std::string body = R"( +%val1 = OpShiftRightLogical %u32 %f32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Base to be int scalar or vector: ShiftRightLogical")); +} + +TEST_F(ValidateBitwise, OpShiftRightLogicalBaseWrongDimension) { + const std::string body = R"( +%val1 = OpShiftRightLogical %u32 %u32vec2_12 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Base to have the same dimension as Result Type: " + "ShiftRightLogical")); +} + +TEST_F(ValidateBitwise, OpShiftRightLogicalBaseWrongBitWidth) { + const std::string body = R"( +%val1 = OpShiftRightLogical %u64 %u32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Base to have the same bit width as Result Type: " + "ShiftRightLogical")); +} + +TEST_F(ValidateBitwise, OpShiftRightLogicalShiftNotInt) { + const std::string body = R"( +%val1 = OpShiftRightLogical %u32 %u32_1 %f32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Shift to be int scalar or vector: ShiftRightLogical")); +} + +TEST_F(ValidateBitwise, OpShiftRightLogicalShiftWrongDimension) { + const std::string body = R"( +%val1 = OpShiftRightLogical %u32 %u32_1 %s32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Shift to have the same dimension as Result Type: " + "ShiftRightLogical")); +} + +TEST_F(ValidateBitwise, LogicAllSuccess) { + const std::string body = R"( +%val1 = OpBitwiseOr %u64 %u64_1 %s64_0 +%val2 = OpBitwiseAnd %s64 %s64_1 %u64_0 +%val3 = OpBitwiseXor %s32vec2 %s32vec2_12 %u32vec2_01 +%val4 = OpNot %s32vec2 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBitwise, OpBitwiseAndWrongResultType) { + const std::string body = R"( +%val1 = OpBitwiseAnd %bool %u64_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected int scalar or vector type as Result Type: BitwiseAnd")); +} + +TEST_F(ValidateBitwise, OpBitwiseAndLeftNotInt) { + const std::string body = R"( +%val1 = OpBitwiseAnd %u32 %f32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar or vector as operand: BitwiseAnd " + "operand index 2")); +} + +TEST_F(ValidateBitwise, OpBitwiseAndRightNotInt) { + const std::string body = R"( +%val1 = OpBitwiseAnd %u32 %u32_1 %f32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar or vector as operand: BitwiseAnd " + "operand index 3")); +} + +TEST_F(ValidateBitwise, OpBitwiseAndLeftWrongDimension) { + const std::string body = R"( +%val1 = OpBitwiseAnd %u32 %u32vec2_12 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to have the same dimension as Result Type: " + "BitwiseAnd operand index 2")); +} + +TEST_F(ValidateBitwise, OpBitwiseAndRightWrongDimension) { + const std::string body = R"( +%val1 = OpBitwiseAnd %u32 %s32_2 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to have the same dimension as Result Type: " + "BitwiseAnd operand index 3")); +} + +TEST_F(ValidateBitwise, OpBitwiseAndLeftWrongBitWidth) { + const std::string body = R"( +%val1 = OpBitwiseAnd %u64 %u32_1 %s64_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to have the same bit width as Result Type: " + "BitwiseAnd operand index 2")); +} + +TEST_F(ValidateBitwise, OpBitwiseAndRightWrongBitWidth) { + const std::string body = R"( +%val1 = OpBitwiseAnd %u64 %u64_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to have the same bit width as Result Type: " + "BitwiseAnd operand index 3")); +} + +TEST_F(ValidateBitwise, OpBitFieldInsertSuccess) { + const std::string body = R"( +%val1 = OpBitFieldInsert %u64 %u64_1 %u64_2 %s32_1 %s32_2 +%val2 = OpBitFieldInsert %s32vec2 %s32vec2_12 %s32vec2_12 %s32_1 %u32_2 
+)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBitwise, OpBitFieldInsertWrongResultType) { + const std::string body = R"( +%val1 = OpBitFieldInsert %bool %u64_1 %u64_2 %s32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected int scalar or vector type as Result Type: BitFieldInsert")); +} + +TEST_F(ValidateBitwise, OpBitFieldInsertWrongBaseType) { + const std::string body = R"( +%val1 = OpBitFieldInsert %u64 %s64_1 %u64_2 %s32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Base Type to be equal to Result Type: BitFieldInsert")); +} + +TEST_F(ValidateBitwise, OpBitFieldInsertWrongInsertType) { + const std::string body = R"( +%val1 = OpBitFieldInsert %u64 %u64_1 %s64_2 %s32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Insert Type to be equal to Result Type: BitFieldInsert")); +} + +TEST_F(ValidateBitwise, OpBitFieldInsertOffsetNotInt) { + const std::string body = R"( +%val1 = OpBitFieldInsert %u64 %u64_1 %u64_2 %f32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Offset Type to be int scalar: BitFieldInsert")); +} + +TEST_F(ValidateBitwise, OpBitFieldInsertCountNotInt) { + const std::string body = R"( +%val1 = OpBitFieldInsert %u64 %u64_1 %u64_2 %u32_1 %f32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Count Type to be int scalar: BitFieldInsert")); +} + +TEST_F(ValidateBitwise, OpBitFieldSExtractSuccess) { + const std::string body = R"( +%val1 = OpBitFieldSExtract %u64 %u64_1 %s32_1 %s32_2 +%val2 = OpBitFieldSExtract %s32vec2 %s32vec2_12 %s32_1 %u32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBitwise, OpBitFieldSExtractWrongResultType) { + const std::string body = R"( +%val1 = OpBitFieldSExtract %bool %u64_1 %s32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar or vector type as Result Type: " + "BitFieldSExtract")); +} + +TEST_F(ValidateBitwise, OpBitFieldSExtractWrongBaseType) { + const std::string body = R"( +%val1 = OpBitFieldSExtract %u64 %s64_1 %s32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Base Type to be equal to Result Type: BitFieldSExtract")); +} + +TEST_F(ValidateBitwise, OpBitFieldSExtractOffsetNotInt) { + const std::string body = R"( +%val1 = OpBitFieldSExtract %u64 %u64_1 %f32_1 %s32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Offset Type to be int scalar: BitFieldSExtract")); +} + +TEST_F(ValidateBitwise, OpBitFieldSExtractCountNotInt) { + const std::string body = R"( +%val1 = OpBitFieldSExtract %u64 %u64_1 %u32_1 %f32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Count Type to be int scalar: 
BitFieldSExtract")); +} + +TEST_F(ValidateBitwise, OpBitReverseSuccess) { + const std::string body = R"( +%val1 = OpBitReverse %u64 %u64_1 +%val2 = OpBitReverse %s32vec2 %s32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBitwise, OpBitReverseWrongResultType) { + const std::string body = R"( +%val1 = OpBitReverse %bool %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected int scalar or vector type as Result Type: BitReverse")); +} + +TEST_F(ValidateBitwise, OpBitReverseWrongBaseType) { + const std::string body = R"( +%val1 = OpBitReverse %u64 %s64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Base Type to be equal to Result Type: BitReverse")); +} + +TEST_F(ValidateBitwise, OpBitCountSuccess) { + const std::string body = R"( +%val1 = OpBitCount %s32 %u64_1 +%val2 = OpBitCount %u32vec2 %s32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateBitwise, OpBitCountWrongResultType) { + const std::string body = R"( +%val1 = OpBitCount %bool %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected int scalar or vector type as Result Type: BitCount")); +} + +TEST_F(ValidateBitwise, OpBitCountBaseNotInt) { + const std::string body = R"( +%val1 = OpBitCount %u32 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Base Type to be int scalar 
or vector: BitCount")); +} + +TEST_F(ValidateBitwise, OpBitCountBaseWrongDimension) { + const std::string body = R"( +%val1 = OpBitCount %u32 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Base dimension to be equal to Result Type dimension: " + "BitCount")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_builtins_test.cpp b/test/val/val_builtins_test.cpp new file mode 100644 index 000000000..ec0758275 --- /dev/null +++ b/test/val/val_builtins_test.cpp @@ -0,0 +1,2283 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests validation rules for the BuiltIn decoration. Parameterized cases +// combine a built-in with an execution model, storage class and data type +// and check the validator result and diagnostic.
+ +#include <cstring> +#include <sstream> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +struct TestResult { + TestResult(spv_result_t in_validation_result = SPV_SUCCESS, + const char* in_error_str = nullptr, + const char* in_error_str2 = nullptr) + : validation_result(in_validation_result), + error_str(in_error_str), + error_str2(in_error_str2) {} + spv_result_t validation_result; + const char* error_str; + const char* error_str2; +}; + +using ::testing::Combine; +using ::testing::HasSubstr; +using ::testing::Not; +using ::testing::Values; +using ::testing::ValuesIn; + +using ValidateBuiltIns = spvtest::ValidateBase<bool>; +using ValidateVulkanCombineBuiltInExecutionModelDataTypeResult = + spvtest::ValidateBase<std::tuple<const char*, const char*, const char*, + const char*, TestResult>>; +using ValidateVulkanCombineBuiltInArrayedVariable = spvtest::ValidateBase< + std::tuple<const char*, const char*, const char*, const char*, TestResult>>; + +struct EntryPoint { + std::string name; + std::string execution_model; + std::string execution_modes; + std::string body; + std::string interfaces; +}; + +class CodeGenerator { + public: + std::string Build() const; + + std::vector<EntryPoint> entry_points_; + std::string capabilities_; + std::string extensions_; + std::string memory_model_; + std::string before_types_; + std::string types_; + std::string after_types_; + std::string add_at_the_end_; +}; + +std::string CodeGenerator::Build() const { + std::ostringstream ss; + + ss << capabilities_; + ss << extensions_; + ss << memory_model_; + + for (const EntryPoint& entry_point : entry_points_) { + ss << "OpEntryPoint " << entry_point.execution_model << " %" + << entry_point.name << " \"" << entry_point.name << "\" " + << entry_point.interfaces << "\n"; + } + + for (const EntryPoint& entry_point : entry_points_) { + ss << entry_point.execution_modes << "\n"; + } + + ss << before_types_; + ss << types_; + ss << after_types_; + + for (const EntryPoint& entry_point : entry_points_) { + ss << "\n"; + ss << "%"
<< entry_point.name << " = OpFunction %void None %func\n"; + ss << "%" << entry_point.name << "_entry = OpLabel\n"; + ss << entry_point.body; + ss << "\nOpReturn\nOpFunctionEnd\n"; + } + + ss << add_at_the_end_; + + return ss.str(); +} + +std::string GetDefaultShaderCapabilities() { + return R"( +OpCapability Shader +OpCapability Geometry +OpCapability Tessellation +OpCapability Float64 +OpCapability Int64 +OpCapability MultiViewport +OpCapability SampleRateShading +)"; +} + +std::string GetDefaultShaderTypes() { + return R"( +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%f64 = OpTypeFloat 64 +%u32 = OpTypeInt 32 0 +%u64 = OpTypeInt 64 0 +%f32vec2 = OpTypeVector %f32 2 +%f32vec3 = OpTypeVector %f32 3 +%f32vec4 = OpTypeVector %f32 4 +%f64vec2 = OpTypeVector %f64 2 +%f64vec3 = OpTypeVector %f64 3 +%f64vec4 = OpTypeVector %f64 4 +%u32vec2 = OpTypeVector %u32 2 +%u32vec3 = OpTypeVector %u32 3 +%u64vec3 = OpTypeVector %u64 3 +%u32vec4 = OpTypeVector %u32 4 +%u64vec2 = OpTypeVector %u64 2 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +%f32_3 = OpConstant %f32 3 +%f32_4 = OpConstant %f32 4 +%f32_h = OpConstant %f32 0.5 +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2 +%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4 + +%f64_0 = OpConstant %f64 0 +%f64_1 = OpConstant %f64 1 +%f64_2 = OpConstant %f64 2 +%f64_3 = OpConstant %f64 3 +%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1 +%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2 +%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = 
OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant %u64 1 +%u64_2 = OpConstant %u64 2 +%u64_3 = OpConstant %u64 3 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u64vec2_01 = OpConstantComposite %u64vec2 %u64_0 %u64_1 + +%u32arr2 = OpTypeArray %u32 %u32_2 +%u32arr3 = OpTypeArray %u32 %u32_3 +%u32arr4 = OpTypeArray %u32 %u32_4 +%u64arr2 = OpTypeArray %u64 %u32_2 +%u64arr3 = OpTypeArray %u64 %u32_3 +%u64arr4 = OpTypeArray %u64 %u32_4 +%f32arr2 = OpTypeArray %f32 %u32_2 +%f32arr3 = OpTypeArray %f32 %u32_3 +%f32arr4 = OpTypeArray %f32 %u32_4 +%f64arr2 = OpTypeArray %f64 %u32_2 +%f64arr3 = OpTypeArray %f64 %u32_3 +%f64arr4 = OpTypeArray %f64 %u32_4 + +%f32vec3arr3 = OpTypeArray %f32vec3 %u32_3 +%f32vec4arr3 = OpTypeArray %f32vec4 %u32_3 +%f64vec4arr3 = OpTypeArray %f64vec4 %u32_3 +)"; +} + +CodeGenerator GetDefaultShaderCodeGenerator() { + CodeGenerator generator; + generator.capabilities_ = GetDefaultShaderCapabilities(); + generator.memory_model_ = "OpMemoryModel Logical GLSL450\n"; + generator.types_ = GetDefaultShaderTypes(); + return generator; +} + +TEST_P(ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, InMain) { + const char* const built_in = std::get<0>(GetParam()); + const char* const execution_model = std::get<1>(GetParam()); + const char* const storage_class = std::get<2>(GetParam()); + const char* const data_type = std::get<3>(GetParam()); + const TestResult& test_result = std::get<4>(GetParam()); + + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = "OpMemberDecorate %built_in_type 0 BuiltIn "; + generator.before_types_ += built_in; + generator.before_types_ += "\n"; + + std::ostringstream after_types; + after_types << "%built_in_type = OpTypeStruct " << data_type << "\n"; + after_types << 
"%built_in_ptr = OpTypePointer " << storage_class + << " %built_in_type\n"; + after_types << "%built_in_var = OpVariable %built_in_ptr " << storage_class + << "\n"; + after_types << "%data_ptr = OpTypePointer " << storage_class << " " + << data_type << "\n"; + generator.after_types_ = after_types.str(); + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = execution_model; + if (strncmp(storage_class, "Input", 5) == 0 || + strncmp(storage_class, "Output", 6) == 0) { + entry_point.interfaces = "%built_in_var"; + } + + std::ostringstream execution_modes; + if (0 == std::strcmp(execution_model, "Fragment")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " OriginUpperLeft\n"; + if (0 == std::strcmp(built_in, "FragDepth")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " DepthReplacing\n"; + } + } + if (0 == std::strcmp(execution_model, "Geometry")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " InputPoints\n"; + execution_modes << "OpExecutionMode %" << entry_point.name + << " OutputPoints\n"; + } + if (0 == std::strcmp(execution_model, "GLCompute")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " LocalSize 1 1 1\n"; + } + entry_point.execution_modes = execution_modes.str(); + + entry_point.body = R"( +%ptr = OpAccessChain %data_ptr %built_in_var %u32_0 +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(test_result.validation_result, + ValidateInstructions(SPV_ENV_VULKAN_1_0)); + if (test_result.error_str) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str)); + } + if (test_result.error_str2) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str2)); + } +} + +TEST_P(ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, InFunction) { + const char* const built_in = std::get<0>(GetParam()); + const char* const 
execution_model = std::get<1>(GetParam()); + const char* const storage_class = std::get<2>(GetParam()); + const char* const data_type = std::get<3>(GetParam()); + const TestResult& test_result = std::get<4>(GetParam()); + + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = "OpMemberDecorate %built_in_type 0 BuiltIn "; + generator.before_types_ += built_in; + generator.before_types_ += "\n"; + + std::ostringstream after_types; + after_types << "%built_in_type = OpTypeStruct " << data_type << "\n"; + after_types << "%built_in_ptr = OpTypePointer " << storage_class + << " %built_in_type\n"; + after_types << "%built_in_var = OpVariable %built_in_ptr " << storage_class + << "\n"; + after_types << "%data_ptr = OpTypePointer " << storage_class << " " + << data_type << "\n"; + generator.after_types_ = after_types.str(); + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = execution_model; + if (strncmp(storage_class, "Input", 5) == 0 || + strncmp(storage_class, "Output", 6) == 0) { + entry_point.interfaces = "%built_in_var"; + } + + std::ostringstream execution_modes; + if (0 == std::strcmp(execution_model, "Fragment")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " OriginUpperLeft\n"; + if (0 == std::strcmp(built_in, "FragDepth")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " DepthReplacing\n"; + } + } + if (0 == std::strcmp(execution_model, "Geometry")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " InputPoints\n"; + execution_modes << "OpExecutionMode %" << entry_point.name + << " OutputPoints\n"; + } + if (0 == std::strcmp(execution_model, "GLCompute")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " LocalSize 1 1 1\n"; + } + entry_point.execution_modes = execution_modes.str(); + + entry_point.body = R"( +%val2 = OpFunctionCall %void %foo +)"; + + generator.add_at_the_end_ = R"( +%foo = OpFunction %void 
None %func +%foo_entry = OpLabel +%ptr = OpAccessChain %data_ptr %built_in_var %u32_0 +OpReturn +OpFunctionEnd +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(test_result.validation_result, + ValidateInstructions(SPV_ENV_VULKAN_1_0)); + if (test_result.error_str) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str)); + } + if (test_result.error_str2) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str2)); + } +} + +TEST_P(ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, Variable) { + const char* const built_in = std::get<0>(GetParam()); + const char* const execution_model = std::get<1>(GetParam()); + const char* const storage_class = std::get<2>(GetParam()); + const char* const data_type = std::get<3>(GetParam()); + const TestResult& test_result = std::get<4>(GetParam()); + + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = "OpDecorate %built_in_var BuiltIn "; + generator.before_types_ += built_in; + generator.before_types_ += "\n"; + + std::ostringstream after_types; + after_types << "%built_in_ptr = OpTypePointer " << storage_class << " " + << data_type << "\n"; + after_types << "%built_in_var = OpVariable %built_in_ptr " << storage_class + << "\n"; + generator.after_types_ = after_types.str(); + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = execution_model; + if (strncmp(storage_class, "Input", 5) == 0 || + strncmp(storage_class, "Output", 6) == 0) { + entry_point.interfaces = "%built_in_var"; + } + // Any kind of reference would do. 
+ entry_point.body = R"( +%val = OpBitcast %u64 %built_in_var +)"; + + std::ostringstream execution_modes; + if (0 == std::strcmp(execution_model, "Fragment")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " OriginUpperLeft\n"; + if (0 == std::strcmp(built_in, "FragDepth")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " DepthReplacing\n"; + } + } + if (0 == std::strcmp(execution_model, "Geometry")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " InputPoints\n"; + execution_modes << "OpExecutionMode %" << entry_point.name + << " OutputPoints\n"; + } + if (0 == std::strcmp(execution_model, "GLCompute")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " LocalSize 1 1 1\n"; + } + entry_point.execution_modes = execution_modes.str(); + + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(test_result.validation_result, + ValidateInstructions(SPV_ENV_VULKAN_1_0)); + if (test_result.error_str) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str)); + } + if (test_result.error_str2) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str2)); + } +} + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceOutputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("ClipDistance", "CullDistance"), + Values("Vertex", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Output"), Values("%f32arr2", "%f32arr4"), + Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceInputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("ClipDistance", "CullDistance"), + Values("Fragment", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%f32arr2", "%f32arr4"), + Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceFragmentOutput, + 
ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("ClipDistance", "CullDistance"), Values("Fragment"), + Values("Output"), Values("%f32arr4"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow BuiltIn ClipDistance/CullDistance " + "to be used for variables with Output storage class if " + "execution model is Fragment.", + "which is called with execution model Fragment."))), ); + +INSTANTIATE_TEST_CASE_P( + VertexIdAndInstanceIdVertexInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("VertexId", "InstanceId"), Values("Vertex"), Values("Input"), + Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow BuiltIn VertexId/InstanceId to be " + "used."))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceVertexInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("ClipDistance", "CullDistance"), Values("Vertex"), + Values("Input"), Values("%f32arr4"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow BuiltIn ClipDistance/CullDistance " + "to be used for variables with Input storage class if " + "execution model is Vertex.", + "which is called with execution model Vertex."))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("ClipDistance", "CullDistance"), Values("GLCompute"), + Values("Input", "Output"), Values("%f32arr4"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be used only with Fragment, Vertex, TessellationControl, " + "TessellationEvaluation or Geometry execution models"))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceNotArray, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("ClipDistance", "CullDistance"), Values("Fragment"), + Values("Input"), Values("%f32vec2", "%f32vec4", "%f32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 
32-bit float array", + "is not an array"))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceNotFloatArray, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("ClipDistance", "CullDistance"), Values("Fragment"), + Values("Input"), Values("%u32arr2", "%u64arr4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float array", + "components are not float scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceNotF32Array, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("ClipDistance", "CullDistance"), Values("Fragment"), + Values("Input"), Values("%f64arr2", "%f64arr4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float array", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + FragCoordSuccess, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragCoord"), Values("Fragment"), Values("Input"), + Values("%f32vec4"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + FragCoordNotFragment, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("FragCoord"), + Values("Vertex", "GLCompute", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%f32vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Fragment execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + FragCoordNotInput, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragCoord"), Values("Fragment"), Values("Output"), + Values("%f32vec4"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + FragCoordNotFloatVector, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragCoord"), Values("Fragment"), Values("Input"), + Values("%f32arr4", "%u32vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + 
"needs to be a 4-component 32-bit float vector", + "is not a float vector"))), ); + +INSTANTIATE_TEST_CASE_P( + FragCoordNotFloatVec4, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragCoord"), Values("Fragment"), Values("Input"), + Values("%f32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "has 3 components"))), ); + +INSTANTIATE_TEST_CASE_P( + FragCoordNotF32Vec4, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragCoord"), Values("Fragment"), Values("Input"), + Values("%f64vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + FragDepthSuccess, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragDepth"), Values("Fragment"), Values("Output"), + Values("%f32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + FragDepthNotFragment, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("FragDepth"), + Values("Vertex", "GLCompute", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Output"), Values("%f32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Fragment execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + FragDepthNotOutput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragDepth"), Values("Fragment"), Values("Input"), + Values("%f32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Output storage class", + "uses storage class Input"))), ); + +INSTANTIATE_TEST_CASE_P( + FragDepthNotFloatScalar, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragDepth"), Values("Fragment"), Values("Output"), + Values("%f32vec4", "%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float scalar", + "is not a 
float scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + FragDepthNotF32, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FragDepth"), Values("Fragment"), Values("Output"), + Values("%f64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + FrontFacingAndHelperInvocationSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FrontFacing", "HelperInvocation"), Values("Fragment"), + Values("Input"), Values("%bool"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + FrontFacingAndHelperInvocationNotFragment, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("FrontFacing", "HelperInvocation"), + Values("Vertex", "GLCompute", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%bool"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Fragment execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + FrontFacingAndHelperInvocationNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FrontFacing", "HelperInvocation"), Values("Fragment"), + Values("Output"), Values("%bool"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + FrontFacingAndHelperInvocationNotBool, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("FrontFacing", "HelperInvocation"), Values("Fragment"), + Values("Input"), Values("%f32", "%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a bool scalar", + "is not a bool scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + ComputeShaderInputInt32Vec3Success, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("GlobalInvocationId", "LocalInvocationId", "NumWorkgroups", + "WorkgroupId"), + Values("GLCompute"), 
Values("Input"), Values("%u32vec3"), + Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + ComputeShaderInputInt32Vec3NotGLCompute, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("GlobalInvocationId", "LocalInvocationId", "NumWorkgroups", + "WorkgroupId"), + Values("Vertex", "Fragment", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%u32vec3"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be used only with GLCompute execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + ComputeShaderInputInt32Vec3NotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("GlobalInvocationId", "LocalInvocationId", "NumWorkgroups", + "WorkgroupId"), + Values("GLCompute"), Values("Output"), Values("%u32vec3"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + ComputeShaderInputInt32Vec3NotIntVector, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("GlobalInvocationId", "LocalInvocationId", "NumWorkgroups", + "WorkgroupId"), + Values("GLCompute"), Values("Input"), + Values("%u32arr3", "%f32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 3-component 32-bit int vector", + "is not an int vector"))), ); + +INSTANTIATE_TEST_CASE_P( + ComputeShaderInputInt32Vec3NotIntVec3, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("GlobalInvocationId", "LocalInvocationId", "NumWorkgroups", + "WorkgroupId"), + Values("GLCompute"), Values("Input"), Values("%u32vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 3-component 32-bit int vector", + "has 4 components"))), ); + +INSTANTIATE_TEST_CASE_P( + ComputeShaderInputInt32Vec3NotInt32Vec, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("GlobalInvocationId", "LocalInvocationId", "NumWorkgroups", + 
"WorkgroupId"), + Values("GLCompute"), Values("Input"), Values("%u64vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 3-component 32-bit int vector", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + InvocationIdSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InvocationId"), Values("Geometry", "TessellationControl"), + Values("Input"), Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + InvocationIdInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InvocationId"), + Values("Vertex", "Fragment", "GLCompute", "TessellationEvaluation"), + Values("Input"), Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with TessellationControl or " + "Geometry execution models"))), ); + +INSTANTIATE_TEST_CASE_P( + InvocationIdNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InvocationId"), Values("Geometry", "TessellationControl"), + Values("Output"), Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + InvocationIdNotIntScalar, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InvocationId"), Values("Geometry", "TessellationControl"), + Values("Input"), Values("%f32", "%u32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "is not an int scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + InvocationIdNotInt32, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InvocationId"), Values("Geometry", "TessellationControl"), + Values("Input"), Values("%u64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + InstanceIndexSuccess, + 
ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InstanceIndex"), Values("Vertex"), Values("Input"), + Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + InstanceIndexInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("InstanceIndex"), + Values("Geometry", "Fragment", "GLCompute", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Vertex execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + InstanceIndexNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InstanceIndex"), Values("Vertex"), Values("Output"), + Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + InstanceIndexNotIntScalar, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InstanceIndex"), Values("Vertex"), Values("Input"), + Values("%f32", "%u32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "is not an int scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + InstanceIndexNotInt32, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("InstanceIndex"), Values("Vertex"), Values("Input"), + Values("%u64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexInputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Layer", "ViewportIndex"), Values("Fragment"), + Values("Input"), Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexOutputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Layer", "ViewportIndex"), Values("Geometry"), + 
Values("Output"), Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Layer", "ViewportIndex"), + Values("TessellationControl", "GLCompute"), Values("Input"), + Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be used only with Vertex, TessellationEvaluation, " + "Geometry, or Fragment execution models"))), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexExecutionModelEnabledByCapability, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Layer", "ViewportIndex"), + Values("Vertex", "TessellationEvaluation"), Values("Output"), + Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "requires the ShaderViewportIndexLayerEXT capability"))), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexFragmentNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("Layer", "ViewportIndex"), Values("Fragment"), Values("Output"), + Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "Output storage class if execution model is Fragment", + "which is called with execution model Fragment"))), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexGeometryNotOutput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("Layer", "ViewportIndex"), + Values("Vertex", "TessellationEvaluation", "Geometry"), Values("Input"), + Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "Input storage class if execution model is Vertex, " + "TessellationEvaluation, or Geometry", + "which is called with execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexNotIntScalar, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Layer", "ViewportIndex"), Values("Fragment"), + Values("Input"), Values("%f32", "%u32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be 
a 32-bit int scalar", + "is not an int scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + LayerAndViewportIndexNotInt32, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Layer", "ViewportIndex"), Values("Fragment"), + Values("Input"), Values("%u64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + PatchVerticesSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PatchVertices"), + Values("TessellationEvaluation", "TessellationControl"), + Values("Input"), Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PatchVerticesInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PatchVertices"), + Values("Vertex", "Fragment", "GLCompute", "Geometry"), + Values("Input"), Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with TessellationControl or " + "TessellationEvaluation execution models"))), ); + +INSTANTIATE_TEST_CASE_P( + PatchVerticesNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PatchVertices"), + Values("TessellationEvaluation", "TessellationControl"), + Values("Output"), Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + PatchVerticesNotIntScalar, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PatchVertices"), + Values("TessellationEvaluation", "TessellationControl"), + Values("Input"), Values("%f32", "%u32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "is not an int scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + PatchVerticesNotInt32, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PatchVertices"), + Values("TessellationEvaluation", 
"TessellationControl"), + Values("Input"), Values("%u64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + PointCoordSuccess, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointCoord"), Values("Fragment"), Values("Input"), + Values("%f32vec2"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PointCoordNotFragment, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("PointCoord"), + Values("Vertex", "GLCompute", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%f32vec2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Fragment execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + PointCoordNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointCoord"), Values("Fragment"), Values("Output"), + Values("%f32vec2"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + PointCoordNotFloatVector, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointCoord"), Values("Fragment"), Values("Input"), + Values("%f32arr2", "%u32vec2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float vector", + "is not a float vector"))), ); + +INSTANTIATE_TEST_CASE_P( + PointCoordNotFloatVec3, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointCoord"), Values("Fragment"), Values("Input"), + Values("%f32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float vector", + "has 3 components"))), ); + +INSTANTIATE_TEST_CASE_P( + PointCoordNotF32Vec4, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointCoord"), Values("Fragment"), Values("Input"), + Values("%f64vec2"), 
+ Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float vector", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeOutputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointSize"), + Values("Vertex", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Output"), Values("%f32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeInputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointSize"), + Values("Geometry", "TessellationControl", "TessellationEvaluation"), + Values("Input"), Values("%f32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeVertexInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointSize"), Values("Vertex"), Values("Input"), + Values("%f32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow BuiltIn PointSize " + "to be used for variables with Input storage class if " + "execution model is Vertex.", + "which is called with execution model Vertex."))), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointSize"), Values("GLCompute", "Fragment"), + Values("Input", "Output"), Values("%f32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be used only with Vertex, TessellationControl, " + "TessellationEvaluation or Geometry execution models"))), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeNotFloatScalar, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PointSize"), Values("Vertex"), Values("Output"), + Values("%f32vec4", "%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float scalar", + "is not a float scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeNotF32, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + 
Combine(Values("PointSize"), Values("Vertex"), Values("Output"), + Values("%f64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + PositionOutputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Position"), + Values("Vertex", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Output"), Values("%f32vec4"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PositionInputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Position"), + Values("Geometry", "TessellationControl", "TessellationEvaluation"), + Values("Input"), Values("%f32vec4"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PositionVertexInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Position"), Values("Vertex"), Values("Input"), + Values("%f32vec4"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow BuiltIn Position " + "to be used for variables with Input storage class if " + "execution model is Vertex.", + "which is called with execution model Vertex."))), ); + +INSTANTIATE_TEST_CASE_P( + PositionInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Position"), Values("GLCompute", "Fragment"), + Values("Input", "Output"), Values("%f32vec4"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be used only with Vertex, TessellationControl, " + "TessellationEvaluation or Geometry execution models"))), ); + +INSTANTIATE_TEST_CASE_P( + PositionNotFloatVector, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Position"), Values("Geometry"), Values("Input"), + Values("%f32arr4", "%u32vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "is not a float vector"))), ); + +INSTANTIATE_TEST_CASE_P( + PositionNotFloatVec4, + 
ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Position"), Values("Geometry"), Values("Input"), + Values("%f32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "has 3 components"))), ); + +INSTANTIATE_TEST_CASE_P( + PositionNotF32Vec4, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("Position"), Values("Geometry"), Values("Input"), + Values("%f64vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + PrimitiveIdInputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PrimitiveId"), + Values("Fragment", "TessellationControl", "TessellationEvaluation", + "Geometry"), + Values("Input"), Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PrimitiveIdOutputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PrimitiveId"), Values("Geometry"), Values("Output"), + Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PrimitiveIdInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PrimitiveId"), Values("Vertex", "GLCompute"), + Values("Input"), Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be used only with Fragment, TessellationControl, " + "TessellationEvaluation or Geometry execution models"))), ); + +INSTANTIATE_TEST_CASE_P( + PrimitiveIdFragmentNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("PrimitiveId"), Values("Fragment"), Values("Output"), + Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "Output storage class if execution model is Fragment", + "which is called with execution model Fragment"))), ); + +INSTANTIATE_TEST_CASE_P( + PrimitiveIdGeometryNotInput, + 
ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PrimitiveId"), + Values("TessellationControl", "TessellationEvaluation"), + Values("Output"), Values("%u32"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Output storage class if execution model is Tessellation", + "which is called with execution model Tessellation"))), ); + +INSTANTIATE_TEST_CASE_P( + PrimitiveIdNotIntScalar, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PrimitiveId"), Values("Fragment"), Values("Input"), + Values("%f32", "%u32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "is not an int scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + PrimitiveIdNotInt32, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("PrimitiveId"), Values("Fragment"), Values("Input"), + Values("%u64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleIdSuccess, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SampleId"), Values("Fragment"), Values("Input"), + Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + SampleIdInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("SampleId"), + Values("Vertex", "GLCompute", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Fragment execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleIdNotInput, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("SampleId"), Values("Fragment"), Values("Output"), + Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "Vulkan spec allows BuiltIn SampleId to be only used " + "for variables with Input storage class"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleIdNotIntScalar, + 
ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SampleId"), Values("Fragment"), Values("Input"), + Values("%f32", "%u32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "is not an int scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleIdNotInt32, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SampleId"), Values("Fragment"), Values("Input"), + Values("%u64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleMaskSuccess, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SampleMask"), Values("Fragment"), Values("Input", "Output"), + Values("%u32arr2", "%u32arr4"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + SampleMaskInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("SampleMask"), + Values("Vertex", "GLCompute", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%u32arr2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Fragment execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleMaskWrongStorageClass, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SampleMask"), Values("Fragment"), Values("Workgroup"), + Values("%u32arr2"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec allows BuiltIn SampleMask to be only used for " + "variables with Input or Output storage class"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleMaskNotArray, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SampleMask"), Values("Fragment"), Values("Input"), + Values("%f32", "%u32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int array", + "is not an array"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleMaskNotIntArray, + 
ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SampleMask"), Values("Fragment"), Values("Input"), + Values("%f32arr2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int array", + "components are not int scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + SampleMaskNotInt32Array, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SampleMask"), Values("Fragment"), Values("Input"), + Values("%u64arr2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int array", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + SamplePositionSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SamplePosition"), Values("Fragment"), Values("Input"), + Values("%f32vec2"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + SamplePositionNotFragment, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("SamplePosition"), + Values("Vertex", "GLCompute", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%f32vec2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Fragment execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + SamplePositionNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SamplePosition"), Values("Fragment"), Values("Output"), + Values("%f32vec2"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + SamplePositionNotFloatVector, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SamplePosition"), Values("Fragment"), Values("Input"), + Values("%f32arr2", "%u32vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float vector", + "is not a float vector"))), ); + +INSTANTIATE_TEST_CASE_P( + SamplePositionNotFloatVec2, + 
ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SamplePosition"), Values("Fragment"), Values("Input"), + Values("%f32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float vector", + "has 3 components"))), ); + +INSTANTIATE_TEST_CASE_P( + SamplePositionNotF32Vec2, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("SamplePosition"), Values("Fragment"), Values("Input"), + Values("%f64vec2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float vector", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + TessCoordSuccess, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessCoord"), Values("TessellationEvaluation"), + Values("Input"), Values("%f32vec3"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + TessCoordNotFragment, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("TessCoord"), + Values("Vertex", "GLCompute", "Geometry", "TessellationControl", + "Fragment"), + Values("Input"), Values("%f32vec3"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be used only with TessellationEvaluation execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + TessCoordNotInput, ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessCoord"), Values("Fragment"), Values("Output"), + Values("%f32vec3"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "to be only used for variables with Input storage class", + "uses storage class Output"))), ); + +INSTANTIATE_TEST_CASE_P( + TessCoordNotFloatVector, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessCoord"), Values("Fragment"), Values("Input"), + Values("%f32arr3", "%u32vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 3-component 32-bit float vector", + "is not a float vector"))), ); + +INSTANTIATE_TEST_CASE_P( + TessCoordNotFloatVec3, + 
ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessCoord"), Values("Fragment"), Values("Input"), + Values("%f32vec2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 3-component 32-bit float vector", + "has 2 components"))), ); + +INSTANTIATE_TEST_CASE_P( + TessCoordNotF32Vec3, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessCoord"), Values("Fragment"), Values("Input"), + Values("%f64vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 3-component 32-bit float vector", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelOuterTeseInputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), Values("TessellationEvaluation"), + Values("Input"), Values("%f32arr4"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelOuterTescOutputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), Values("TessellationControl"), + Values("Output"), Values("%f32arr4"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelOuterInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), + Values("Vertex", "GLCompute", "Geometry", "Fragment"), + Values("Input"), Values("%f32arr4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with TessellationControl or " + "TessellationEvaluation execution models."))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelOuterOutputTese, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), Values("TessellationEvaluation"), + Values("Output"), Values("%f32arr4"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be " + "used for variables with Output storage class if execution " + "model is TessellationEvaluation."))), ); + 
+INSTANTIATE_TEST_CASE_P( + TessLevelOuterInputTesc, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), Values("TessellationControl"), + Values("Input"), Values("%f32arr4"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be " + "used for variables with Input storage class if execution " + "model is TessellationControl."))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelOuterNotArray, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), Values("TessellationEvaluation"), + Values("Input"), Values("%f32vec4", "%f32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float array", + "is not an array"))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelOuterNotFloatArray, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), Values("TessellationEvaluation"), + Values("Input"), Values("%u32arr4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float array", + "components are not float scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelOuterNotFloatArr4, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), Values("TessellationEvaluation"), + Values("Input"), Values("%f32arr3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float array", + "has 3 components"))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelOuterNotF32Arr4, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelOuter"), Values("TessellationEvaluation"), + Values("Input"), Values("%f64arr4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float array", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerTeseInputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + 
Combine(Values("TessLevelInner"), Values("TessellationEvaluation"), + Values("Input"), Values("%f32arr2"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerTescOutputSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelInner"), Values("TessellationControl"), + Values("Output"), Values("%f32arr2"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelInner"), + Values("Vertex", "GLCompute", "Geometry", "Fragment"), + Values("Input"), Values("%f32arr2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with TessellationControl or " + "TessellationEvaluation execution models."))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerOutputTese, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelInner"), Values("TessellationEvaluation"), + Values("Output"), Values("%f32arr2"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be " + "used for variables with Output storage class if execution " + "model is TessellationEvaluation."))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerInputTesc, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelInner"), Values("TessellationControl"), + Values("Input"), Values("%f32arr2"), + Values(TestResult( + SPV_ERROR_INVALID_DATA, + "Vulkan spec doesn't allow TessLevelOuter/TessLevelInner to be " + "used for variables with Input storage class if execution " + "model is TessellationControl."))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerNotArray, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelInner"), Values("TessellationEvaluation"), + Values("Input"), Values("%f32vec2", "%f32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float 
array", + "is not an array"))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerNotFloatArray, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelInner"), Values("TessellationEvaluation"), + Values("Input"), Values("%u32arr2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float array", + "components are not float scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerNotFloatArr2, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelInner"), Values("TessellationEvaluation"), + Values("Input"), Values("%f32arr3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float array", + "has 3 components"))), ); + +INSTANTIATE_TEST_CASE_P( + TessLevelInnerNotF32Arr2, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("TessLevelInner"), Values("TessellationEvaluation"), + Values("Input"), Values("%f64arr2"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 2-component 32-bit float array", + "has components with bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + VertexIndexSuccess, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("VertexIndex"), Values("Vertex"), Values("Input"), + Values("%u32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + VertexIndexInvalidExecutionModel, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("VertexIndex"), + Values("Fragment", "GLCompute", "Geometry", "TessellationControl", + "TessellationEvaluation"), + Values("Input"), Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "to be used only with Vertex execution model"))), ); + +INSTANTIATE_TEST_CASE_P( + VertexIndexNotInput, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine( + Values("VertexIndex"), Values("Vertex"), Values("Output"), + Values("%u32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "Vulkan spec allows 
BuiltIn VertexIndex to be only " + "used for variables with Input storage class"))), ); + +INSTANTIATE_TEST_CASE_P( + VertexIndexNotIntScalar, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("VertexIndex"), Values("Vertex"), Values("Input"), + Values("%f32", "%u32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "is not an int scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + VertexIndexNotInt32, + ValidateVulkanCombineBuiltInExecutionModelDataTypeResult, + Combine(Values("VertexIndex"), Values("Vertex"), Values("Input"), + Values("%u64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit int scalar", + "has bit width 64"))), ); + +TEST_P(ValidateVulkanCombineBuiltInArrayedVariable, Variable) { + const char* const built_in = std::get<0>(GetParam()); + const char* const execution_model = std::get<1>(GetParam()); + const char* const storage_class = std::get<2>(GetParam()); + const char* const data_type = std::get<3>(GetParam()); + const TestResult& test_result = std::get<4>(GetParam()); + + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = "OpDecorate %built_in_var BuiltIn "; + generator.before_types_ += built_in; + generator.before_types_ += "\n"; + + std::ostringstream after_types; + after_types << "%built_in_array = OpTypeArray " << data_type << " %u32_3\n"; + after_types << "%built_in_ptr = OpTypePointer " << storage_class + << " %built_in_array\n"; + after_types << "%built_in_var = OpVariable %built_in_ptr " << storage_class + << "\n"; + generator.after_types_ = after_types.str(); + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = execution_model; + entry_point.interfaces = "%built_in_var"; + // Any kind of reference would do. 
+ entry_point.body = R"( +%val = OpBitcast %u64 %built_in_var +)"; + + std::ostringstream execution_modes; + if (0 == std::strcmp(execution_model, "Fragment")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " OriginUpperLeft\n"; + if (0 == std::strcmp(built_in, "FragDepth")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " DepthReplacing\n"; + } + } + if (0 == std::strcmp(execution_model, "Geometry")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " InputPoints\n"; + execution_modes << "OpExecutionMode %" << entry_point.name + << " OutputPoints\n"; + } + if (0 == std::strcmp(execution_model, "GLCompute")) { + execution_modes << "OpExecutionMode %" << entry_point.name + << " LocalSize 1 1 1\n"; + } + entry_point.execution_modes = execution_modes.str(); + + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(test_result.validation_result, + ValidateInstructions(SPV_ENV_VULKAN_1_0)); + if (test_result.error_str) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str)); + } + if (test_result.error_str2) { + EXPECT_THAT(getDiagnosticString(), HasSubstr(test_result.error_str2)); + } +} + +INSTANTIATE_TEST_CASE_P(PointSizeArrayedF32TessControl, + ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("PointSize"), + Values("TessellationControl"), Values("Input"), + Values("%f32"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeArrayedF64TessControl, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("PointSize"), Values("TessellationControl"), Values("Input"), + Values("%f64"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float scalar", + "has bit width 64"))), ); + +INSTANTIATE_TEST_CASE_P( + PointSizeArrayedF32Vertex, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("PointSize"), Values("Vertex"), Values("Output"), + Values("%f32"), + 
Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float scalar", + "is not a float scalar"))), ); + +INSTANTIATE_TEST_CASE_P(PositionArrayedF32Vec4TessControl, + ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("Position"), + Values("TessellationControl"), Values("Input"), + Values("%f32vec4"), Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + PositionArrayedF32Vec3TessControl, + ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("Position"), Values("TessellationControl"), Values("Input"), + Values("%f32vec3"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "has 3 components"))), ); + +INSTANTIATE_TEST_CASE_P( + PositionArrayedF32Vec4Vertex, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("Position"), Values("Vertex"), Values("Output"), + Values("%f32"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 4-component 32-bit float vector", + "is not a float vector"))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceOutputSuccess, + ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("ClipDistance", "CullDistance"), + Values("Geometry", "TessellationControl", "TessellationEvaluation"), + Values("Output"), Values("%f32arr2", "%f32arr4"), + Values(TestResult())), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceVertexInput, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("ClipDistance", "CullDistance"), Values("Fragment"), + Values("Input"), Values("%f32arr4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit float array", + "components are not float scalar"))), ); + +INSTANTIATE_TEST_CASE_P( + ClipAndCullDistanceNotArray, ValidateVulkanCombineBuiltInArrayedVariable, + Combine(Values("ClipDistance", "CullDistance"), + Values("Geometry", "TessellationControl", "TessellationEvaluation"), + Values("Input"), Values("%f32vec2", "%f32vec4"), + Values(TestResult(SPV_ERROR_INVALID_DATA, + "needs to be a 32-bit 
float array", + "components are not float scalar"))), ); + +TEST_F(ValidateBuiltIns, WorkgroupSizeSuccess) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %u32vec3 %u32_1 %u32_1 %u32_1 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.body = R"( +%copy = OpCopyObject %u32vec3 %workgroup_size +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateBuiltIns, WorkgroupSizeFragment) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %u32vec3 %u32_1 %u32_1 %u32_1 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "Fragment"; + entry_point.execution_modes = "OpExecutionMode %main OriginUpperLeft"; + entry_point.body = R"( +%copy = OpCopyObject %u32vec3 %workgroup_size +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec allows BuiltIn WorkgroupSize to be used " + "only with GLCompute execution model")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("is referencing ID <2> (OpConstantComposite) which is " + "decorated with BuiltIn WorkgroupSize in function <1> " + "called with execution model Fragment")); +} + +TEST_F(ValidateBuiltIns, WorkgroupSizeNotConstant) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + 
generator.before_types_ = R"( +OpDecorate %copy BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %u32vec3 %u32_1 %u32_1 %u32_1 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.body = R"( +%copy = OpCopyObject %u32vec3 %workgroup_size +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec requires BuiltIn WorkgroupSize to be a " + "constant. ID <2> (OpCopyObject) is not a constant")); +} + +TEST_F(ValidateBuiltIns, WorkgroupSizeNotVector) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstant %u32 16 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.body = R"( +%copy = OpCopyObject %u32 %workgroup_size +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("According to the Vulkan spec BuiltIn WorkgroupSize " + "variable needs to be a 3-component 32-bit int vector. 
" + "ID <2> (OpConstant) is not an int vector.")); +} + +TEST_F(ValidateBuiltIns, WorkgroupSizeNotIntVector) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %f32vec3 %f32_1 %f32_1 %f32_1 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.body = R"( +%copy = OpCopyObject %f32vec3 %workgroup_size +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("According to the Vulkan spec BuiltIn WorkgroupSize " + "variable needs to be a 3-component 32-bit int vector. " + "ID <2> (OpConstantComposite) is not an int vector.")); +} + +TEST_F(ValidateBuiltIns, WorkgroupSizeNotVec3) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %u32vec2 %u32_1 %u32_1 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.body = R"( +%copy = OpCopyObject %u32vec2 %workgroup_size +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("According to the Vulkan spec BuiltIn WorkgroupSize " + "variable needs to be a 3-component 32-bit int vector. 
" + "ID <2> (OpConstantComposite) has 2 components.")); +} + +TEST_F(ValidateBuiltIns, WorkgroupSizeNotInt32Vec) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %u64vec3 %u64_1 %u64_1 %u64_1 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.body = R"( +%copy = OpCopyObject %u64vec3 %workgroup_size +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("According to the Vulkan spec BuiltIn WorkgroupSize variable " + "needs to be a 3-component 32-bit int vector. ID <2> " + "(OpConstantComposite) has components with bit width 64.")); +} + +TEST_F(ValidateBuiltIns, WorkgroupSizePrivateVar) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %u32vec3 %u32_1 %u32_1 %u32_1 +%private_ptr_u32vec3 = OpTypePointer Private %u32vec3 +%var = OpVariable %private_ptr_u32vec3 Private %workgroup_size +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.body = R"( +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateBuiltIns, GeometryPositionInOutSuccess) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + + generator.before_types_ = R"( +OpMemberDecorate %input_type 0 BuiltIn Position +OpMemberDecorate %output_type 0 BuiltIn 
Position +)"; + + generator.after_types_ = R"( +%input_type = OpTypeStruct %f32vec4 +%arrayed_input_type = OpTypeArray %input_type %u32_3 +%input_ptr = OpTypePointer Input %arrayed_input_type +%input = OpVariable %input_ptr Input +%input_f32vec4_ptr = OpTypePointer Input %f32vec4 +%output_type = OpTypeStruct %f32vec4 +%arrayed_output_type = OpTypeArray %output_type %u32_3 +%output_ptr = OpTypePointer Output %arrayed_output_type +%output = OpVariable %output_ptr Output +%output_f32vec4_ptr = OpTypePointer Output %f32vec4 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "Geometry"; + entry_point.interfaces = "%input %output"; + entry_point.body = R"( +%input_pos = OpAccessChain %input_f32vec4_ptr %input %u32_0 %u32_0 +%output_pos = OpAccessChain %output_f32vec4_ptr %output %u32_0 %u32_0 +%pos = OpLoad %f32vec4 %input_pos +OpStore %output_pos %pos +)"; + generator.entry_points_.push_back(std::move(entry_point)); + generator.entry_points_[0].execution_modes = + "OpExecutionMode %main InputPoints\nOpExecutionMode %main OutputPoints\n"; + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateBuiltIns, WorkgroupIdNotVec3) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %workgroup_size BuiltIn WorkgroupSize +OpDecorate %workgroup_id BuiltIn WorkgroupId +)"; + + generator.after_types_ = R"( +%workgroup_size = OpConstantComposite %u32vec3 %u32_1 %u32_1 %u32_1 + %input_ptr = OpTypePointer Input %u32vec2 + %workgroup_id = OpVariable %input_ptr Input +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "GLCompute"; + entry_point.interfaces = "%workgroup_id"; + entry_point.body = R"( +%copy_size = OpCopyObject %u32vec3 %workgroup_size + %load_id = OpLoad %u32vec2 %workgroup_id +)"; + generator.entry_points_.push_back(std::move(entry_point)); 
+ + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("According to the Vulkan spec BuiltIn WorkgroupId " + "variable needs to be a 3-component 32-bit int vector. " + "ID <2> (OpVariable) has 2 components.")); +} + +TEST_F(ValidateBuiltIns, TwoBuiltInsFirstFails) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + + generator.before_types_ = R"( +OpMemberDecorate %input_type 0 BuiltIn FragCoord +OpMemberDecorate %output_type 0 BuiltIn Position +)"; + + generator.after_types_ = R"( +%input_type = OpTypeStruct %f32vec4 +%input_ptr = OpTypePointer Input %input_type +%input = OpVariable %input_ptr Input +%input_f32vec4_ptr = OpTypePointer Input %f32vec4 +%output_type = OpTypeStruct %f32vec4 +%output_ptr = OpTypePointer Output %output_type +%output = OpVariable %output_ptr Output +%output_f32vec4_ptr = OpTypePointer Output %f32vec4 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "Geometry"; + entry_point.interfaces = "%input %output"; + entry_point.body = R"( +%input_pos = OpAccessChain %input_f32vec4_ptr %input %u32_0 +%output_pos = OpAccessChain %output_f32vec4_ptr %output %u32_0 +%pos = OpLoad %f32vec4 %input_pos +OpStore %output_pos %pos +)"; + generator.entry_points_.push_back(std::move(entry_point)); + generator.entry_points_[0].execution_modes = + "OpExecutionMode %main InputPoints\nOpExecutionMode %main OutputPoints\n"; + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec allows BuiltIn FragCoord to be used only " + "with Fragment execution model")); +} + +TEST_F(ValidateBuiltIns, TwoBuiltInsSecondFails) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + + generator.before_types_ = R"( +OpMemberDecorate 
%input_type 0 BuiltIn Position +OpMemberDecorate %output_type 0 BuiltIn FragCoord +)"; + + generator.after_types_ = R"( +%input_type = OpTypeStruct %f32vec4 +%input_ptr = OpTypePointer Input %input_type +%input = OpVariable %input_ptr Input +%input_f32vec4_ptr = OpTypePointer Input %f32vec4 +%output_type = OpTypeStruct %f32vec4 +%output_ptr = OpTypePointer Output %output_type +%output = OpVariable %output_ptr Output +%output_f32vec4_ptr = OpTypePointer Output %f32vec4 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "Geometry"; + entry_point.interfaces = "%input %output"; + entry_point.body = R"( +%input_pos = OpAccessChain %input_f32vec4_ptr %input %u32_0 +%output_pos = OpAccessChain %output_f32vec4_ptr %output %u32_0 +%pos = OpLoad %f32vec4 %input_pos +OpStore %output_pos %pos +)"; + generator.entry_points_.push_back(std::move(entry_point)); + generator.entry_points_[0].execution_modes = + "OpExecutionMode %main InputPoints\nOpExecutionMode %main OutputPoints\n"; + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec allows BuiltIn FragCoord to be only used " + "for variables with Input storage class")); +} + +TEST_F(ValidateBuiltIns, VertexPositionVariableSuccess) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpDecorate %position BuiltIn Position +)"; + + generator.after_types_ = R"( +%f32vec4_ptr_output = OpTypePointer Output %f32vec4 +%position = OpVariable %f32vec4_ptr_output Output +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "Vertex"; + entry_point.interfaces = "%position"; + entry_point.body = R"( +OpStore %position %f32vec4_0123 +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + 
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateBuiltIns, FragmentPositionTwoEntryPoints) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpMemberDecorate %output_type 0 BuiltIn Position +)"; + + generator.after_types_ = R"( +%output_type = OpTypeStruct %f32vec4 +%output_ptr = OpTypePointer Output %output_type +%output = OpVariable %output_ptr Output +%output_f32vec4_ptr = OpTypePointer Output %f32vec4 +)"; + + EntryPoint entry_point; + entry_point.name = "vmain"; + entry_point.execution_model = "Vertex"; + entry_point.interfaces = "%output"; + entry_point.body = R"( +%val1 = OpFunctionCall %void %foo +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + entry_point.name = "fmain"; + entry_point.execution_model = "Fragment"; + entry_point.interfaces = "%output"; + entry_point.execution_modes = "OpExecutionMode %fmain OriginUpperLeft"; + entry_point.body = R"( +%val2 = OpFunctionCall %void %foo +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + generator.add_at_the_end_ = R"( +%foo = OpFunction %void None %func +%foo_entry = OpLabel +%position = OpAccessChain %output_f32vec4_ptr %output %u32_0 +OpStore %position %f32vec4_0123 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec allows BuiltIn Position to be used only " + "with Vertex, TessellationControl, " + "TessellationEvaluation or Geometry execution models")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("called with execution model Fragment")); +} + +TEST_F(ValidateBuiltIns, FragmentFragDepthNoDepthReplacing) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpMemberDecorate %output_type 0 BuiltIn FragDepth +)"; + + generator.after_types_ = R"( +%output_type = 
OpTypeStruct %f32 +%output_ptr = OpTypePointer Output %output_type +%output = OpVariable %output_ptr Output +%output_f32_ptr = OpTypePointer Output %f32 +)"; + + EntryPoint entry_point; + entry_point.name = "main"; + entry_point.execution_model = "Fragment"; + entry_point.interfaces = "%output"; + entry_point.execution_modes = "OpExecutionMode %main OriginUpperLeft"; + entry_point.body = R"( +%val2 = OpFunctionCall %void %foo +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + generator.add_at_the_end_ = R"( +%foo = OpFunction %void None %func +%foo_entry = OpLabel +%frag_depth = OpAccessChain %output_f32_ptr %output %u32_0 +OpStore %frag_depth %f32_1 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec requires DepthReplacing execution mode to " + "be declared when using BuiltIn FragDepth")); +} + +TEST_F(ValidateBuiltIns, FragmentFragDepthOneMainHasDepthReplacingOtherHasnt) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.before_types_ = R"( +OpMemberDecorate %output_type 0 BuiltIn FragDepth +)"; + + generator.after_types_ = R"( +%output_type = OpTypeStruct %f32 +%output_ptr = OpTypePointer Output %output_type +%output = OpVariable %output_ptr Output +%output_f32_ptr = OpTypePointer Output %f32 +)"; + + EntryPoint entry_point; + entry_point.name = "main_d_r"; + entry_point.execution_model = "Fragment"; + entry_point.interfaces = "%output"; + entry_point.execution_modes = + "OpExecutionMode %main_d_r OriginUpperLeft\n" + "OpExecutionMode %main_d_r DepthReplacing"; + entry_point.body = R"( +%val2 = OpFunctionCall %void %foo +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + entry_point.name = "main_no_d_r"; + entry_point.execution_model = "Fragment"; + entry_point.interfaces = "%output"; + entry_point.execution_modes = 
"OpExecutionMode %main_no_d_r OriginUpperLeft"; + entry_point.body = R"( +%val3 = OpFunctionCall %void %foo +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + generator.add_at_the_end_ = R"( +%foo = OpFunction %void None %func +%foo_entry = OpLabel +%frag_depth = OpAccessChain %output_f32_ptr %output %u32_0 +OpStore %frag_depth %f32_1 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec requires DepthReplacing execution mode to " + "be declared when using BuiltIn FragDepth")); +} + +TEST_F(ValidateBuiltIns, AllowInstanceIdWithIntersectionShader) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.capabilities_ += R"( +OpCapability RayTracingNV +)"; + + generator.extensions_ = R"( +OpExtension "SPV_NV_ray_tracing" +)"; + + generator.before_types_ = R"( +OpMemberDecorate %input_type 0 BuiltIn InstanceId +)"; + + generator.after_types_ = R"( +%input_type = OpTypeStruct %u32 +%input_ptr = OpTypePointer Input %input_type +%input = OpVariable %input_ptr Input +)"; + + EntryPoint entry_point; + entry_point.name = "main_d_r"; + entry_point.execution_model = "IntersectionNV"; + entry_point.interfaces = "%input"; + entry_point.body = R"( +%val2 = OpFunctionCall %void %foo +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + generator.add_at_the_end_ = R"( +%foo = OpFunction %void None %func +%foo_entry = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateBuiltIns, DisallowInstanceIdWithRayGenShader) { + CodeGenerator generator = GetDefaultShaderCodeGenerator(); + generator.capabilities_ += R"( +OpCapability RayTracingNV +)"; + + generator.extensions_ = R"( +OpExtension "SPV_NV_ray_tracing" +)"; + + 
generator.before_types_ = R"( +OpMemberDecorate %input_type 0 BuiltIn InstanceId +)"; + + generator.after_types_ = R"( +%input_type = OpTypeStruct %u32 +%input_ptr = OpTypePointer Input %input_type +%input_ptr_u32 = OpTypePointer Input %u32 +%input = OpVariable %input_ptr Input +)"; + + EntryPoint entry_point; + entry_point.name = "main_d_r"; + entry_point.execution_model = "RayGenerationNV"; + entry_point.interfaces = "%input"; + entry_point.body = R"( +%input_member = OpAccessChain %input_ptr_u32 %input %u32_0 +)"; + generator.entry_points_.push_back(std::move(entry_point)); + + CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vulkan spec allows BuiltIn InstanceId to be used " + "only with IntersectionNV, ClosestHitNV and " + "AnyHitNV execution models")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_capability_test.cpp b/test/val/val_capability_test.cpp new file mode 100644 index 000000000..488e957ae --- /dev/null +++ b/test/val/val_capability_test.cpp @@ -0,0 +1,2407 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Validation tests for capability requirements + +#include <sstream> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "source/assembly_grammar.h" +#include "source/spirv_target_env.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using spvtest::ScopedContext; +using testing::Combine; +using testing::HasSubstr; +using testing::Values; +using testing::ValuesIn; + +// Parameter for validation test fixtures. The first std::string is a +// capability name that will begin the assembly under test, the second the +// remainder assembly, and the std::vector at the end determines whether the +// test expects success or failure. See below for details and convenience +// methods to access each one. +// +// The assembly to test is composed from a variable top line and a fixed +// remainder. The top line will be an OpCapability instruction, while the +// remainder will be some assembly text that succeeds or fails to assemble +// depending on which capability was chosen. For instance, the following will +// succeed: +// +// OpCapability Pipes ; implies Kernel +// OpLifetimeStop %1 0 ; requires Kernel +// +// and the following will fail: +// +// OpCapability Kernel +// %1 = OpTypeNamedBarrier ; requires NamedBarrier +// +// So how does the test parameter capture which capabilities should cause +// success and which shouldn't? The answer is in the last element: it's a +// std::vector of capabilities that make the remainder assembly succeed. So if +// the first-line capability exists in that std::vector, success is expected; +// otherwise, failure is expected in the tests. +// +// We will use testing::Combine() to vary the first line: when we combine +// AllCapabilities() with a single remainder assembly, we generate enough test +// cases to try the assembly with every possible capability that could be +// declared. 
However, Combine() only produces tuples -- it cannot produce, say, +// a struct. Therefore, this type must be a tuple. +using CapTestParameter = + std::tuple<std::string, std::pair<std::string, std::vector<std::string>>>; + +const std::string& Capability(const CapTestParameter& p) { + return std::get<0>(p); +} +const std::string& Remainder(const CapTestParameter& p) { + return std::get<1>(p).first; +} +const std::vector<std::string>& MustSucceed(const CapTestParameter& p) { + return std::get<1>(p).second; +} + +// Creates assembly to test from p. +std::string MakeAssembly(const CapTestParameter& p) { + std::ostringstream ss; + const std::string& capability = Capability(p); + if (!capability.empty()) { + ss << "OpCapability " << capability << "\n"; + } + ss << Remainder(p); + return ss.str(); +} + +// Expected validation result for p. +spv_result_t ExpectedResult(const CapTestParameter& p) { + const auto& caps = MustSucceed(p); + auto found = find(begin(caps), end(caps), Capability(p)); + return (found == end(caps)) ? SPV_ERROR_INVALID_CAPABILITY : SPV_SUCCESS; +} + +// Assembles using v1.0, unless the parameter's capability requires v1.1. +using ValidateCapability = spvtest::ValidateBase<CapTestParameter>; + +// Always assembles using v1.1. +using ValidateCapabilityV11 = spvtest::ValidateBase<CapTestParameter>; + +// Always assembles using Vulkan 1.0. +// TODO(dneto): Refactor all these tests to scale better across environments. +using ValidateCapabilityVulkan10 = spvtest::ValidateBase<CapTestParameter>; +// Always assembles using OpenGL 4.0. +using ValidateCapabilityOpenGL40 = spvtest::ValidateBase<CapTestParameter>; +// Always assembles using Vulkan 1.1. +using ValidateCapabilityVulkan11 = spvtest::ValidateBase<CapTestParameter>; +// Always assembles using WebGPU. 
+using ValidateCapabilityWebGPU = spvtest::ValidateBase<CapTestParameter>; + +TEST_F(ValidateCapability, Default) { + const char str[] = R"( + OpCapability Kernel + OpCapability Linkage + OpCapability Matrix + OpMemoryModel Logical OpenCL +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 3 +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// clang-format off +const std::vector<std::string>& AllCapabilities() { + static const auto r = new std::vector<std::string>{ + "", + "Matrix", + "Shader", + "Geometry", + "Tessellation", + "Addresses", + "Linkage", + "Kernel", + "Vector16", + "Float16Buffer", + "Float16", + "Float64", + "Int64", + "Int64Atomics", + "ImageBasic", + "ImageReadWrite", + "ImageMipmap", + "Pipes", + "Groups", + "DeviceEnqueue", + "LiteralSampler", + "AtomicStorage", + "Int16", + "TessellationPointSize", + "GeometryPointSize", + "ImageGatherExtended", + "StorageImageMultisample", + "UniformBufferArrayDynamicIndexing", + "SampledImageArrayDynamicIndexing", + "StorageBufferArrayDynamicIndexing", + "StorageImageArrayDynamicIndexing", + "ClipDistance", + "CullDistance", + "ImageCubeArray", + "SampleRateShading", + "ImageRect", + "SampledRect", + "GenericPointer", + "Int8", + "InputAttachment", + "SparseResidency", + "MinLod", + "Sampled1D", + "Image1D", + "SampledCubeArray", + "SampledBuffer", + "ImageBuffer", + "ImageMSArray", + "StorageImageExtendedFormats", + "ImageQuery", + "DerivativeControl", + "InterpolationFunction", + "TransformFeedback", + "GeometryStreams", + "StorageImageReadWithoutFormat", + "StorageImageWriteWithoutFormat", + "MultiViewport", + "SubgroupDispatch", + "NamedBarrier", + "PipeStorage", + "GroupNonUniform", + "GroupNonUniformVote", + "GroupNonUniformArithmetic", + "GroupNonUniformBallot", + "GroupNonUniformShuffle", + "GroupNonUniformShuffleRelative", + "GroupNonUniformClustered", + "GroupNonUniformQuad", + "DrawParameters", + "StorageBuffer16BitAccess", + "StorageUniformBufferBlock16", 
"UniformAndStorageBuffer16BitAccess", + "StorageUniform16", + "StoragePushConstant16", + "StorageInputOutput16", + "DeviceGroup", + "MultiView", + "VariablePointersStorageBuffer", + "VariablePointers"}; + return *r; +} + +const std::vector<std::string>& AllSpirV10Capabilities() { + static const auto r = new std::vector<std::string>{ + "", + "Matrix", + "Shader", + "Geometry", + "Tessellation", + "Addresses", + "Linkage", + "Kernel", + "Vector16", + "Float16Buffer", + "Float16", + "Float64", + "Int64", + "Int64Atomics", + "ImageBasic", + "ImageReadWrite", + "ImageMipmap", + "Pipes", + "Groups", + "DeviceEnqueue", + "LiteralSampler", + "AtomicStorage", + "Int16", + "TessellationPointSize", + "GeometryPointSize", + "ImageGatherExtended", + "StorageImageMultisample", + "UniformBufferArrayDynamicIndexing", + "SampledImageArrayDynamicIndexing", + "StorageBufferArrayDynamicIndexing", + "StorageImageArrayDynamicIndexing", + "ClipDistance", + "CullDistance", + "ImageCubeArray", + "SampleRateShading", + "ImageRect", + "SampledRect", + "GenericPointer", + "Int8", + "InputAttachment", + "SparseResidency", + "MinLod", + "Sampled1D", + "Image1D", + "SampledCubeArray", + "SampledBuffer", + "ImageBuffer", + "ImageMSArray", + "StorageImageExtendedFormats", + "ImageQuery", + "DerivativeControl", + "InterpolationFunction", + "TransformFeedback", + "GeometryStreams", + "StorageImageReadWithoutFormat", + "StorageImageWriteWithoutFormat", + "MultiViewport"}; + return *r; +} + +const std::vector<std::string>& AllVulkan10Capabilities() { + static const auto r = new std::vector<std::string>{ + "", + "Matrix", + "Shader", + "InputAttachment", + "Sampled1D", + "Image1D", + "SampledBuffer", + "ImageBuffer", + "ImageQuery", + "DerivativeControl", + "Geometry", + "Tessellation", + "Float16", + "Float64", + "Int64", + "Int64Atomics", + "Int16", + "TessellationPointSize", + "GeometryPointSize", + "ImageGatherExtended", + "StorageImageMultisample", + "UniformBufferArrayDynamicIndexing", + "SampledImageArrayDynamicIndexing", + 
"StorageBufferArrayDynamicIndexing", + "StorageImageArrayDynamicIndexing", + "ClipDistance", + "CullDistance", + "ImageCubeArray", + "SampleRateShading", + "Int8", + "SparseResidency", + "MinLod", + "SampledCubeArray", + "ImageMSArray", + "StorageImageExtendedFormats", + "InterpolationFunction", + "StorageImageReadWithoutFormat", + "StorageImageWriteWithoutFormat", + "MultiViewport", + "TransformFeedback", + "GeometryStreams"}; + return *r; +} + +const std::vector<std::string>& AllVulkan11Capabilities() { + static const auto r = new std::vector<std::string>{ + "", + "Matrix", + "Shader", + "InputAttachment", + "Sampled1D", + "Image1D", + "SampledBuffer", + "ImageBuffer", + "ImageQuery", + "DerivativeControl", + "Geometry", + "Tessellation", + "Float16", + "Float64", + "Int64", + "Int64Atomics", + "Int16", + "TessellationPointSize", + "GeometryPointSize", + "ImageGatherExtended", + "StorageImageMultisample", + "UniformBufferArrayDynamicIndexing", + "SampledImageArrayDynamicIndexing", + "StorageBufferArrayDynamicIndexing", + "StorageImageArrayDynamicIndexing", + "ClipDistance", + "CullDistance", + "ImageCubeArray", + "SampleRateShading", + "Int8", + "SparseResidency", + "MinLod", + "SampledCubeArray", + "ImageMSArray", + "StorageImageExtendedFormats", + "InterpolationFunction", + "StorageImageReadWithoutFormat", + "StorageImageWriteWithoutFormat", + "MultiViewport", + "GroupNonUniform", + "GroupNonUniformVote", + "GroupNonUniformArithmetic", + "GroupNonUniformBallot", + "GroupNonUniformShuffle", + "GroupNonUniformShuffleRelative", + "GroupNonUniformClustered", + "GroupNonUniformQuad", + "DrawParameters", + "StorageBuffer16BitAccess", + "StorageUniformBufferBlock16", + "UniformAndStorageBuffer16BitAccess", + "StorageUniform16", + "StoragePushConstant16", + "StorageInputOutput16", + "DeviceGroup", + "MultiView", + "VariablePointersStorageBuffer", + "VariablePointers", + "TransformFeedback", + "GeometryStreams"}; + return *r; +} + +const std::vector<std::string>& AllWebGPUCapabilities() { + static const auto 
r = new std::vector<std::string>{ + "", + "Shader", + "Matrix", + "Sampled1D", + "Image1D", + "ImageQuery", + "DerivativeControl"}; + return *r; +} + +const std::vector<std::string>& MatrixDependencies() { + static const auto r = new std::vector<std::string>{ + "Matrix", + "Shader", + "Geometry", + "Tessellation", + "AtomicStorage", + "TessellationPointSize", + "GeometryPointSize", + "ImageGatherExtended", + "StorageImageMultisample", + "UniformBufferArrayDynamicIndexing", + "SampledImageArrayDynamicIndexing", + "StorageBufferArrayDynamicIndexing", + "StorageImageArrayDynamicIndexing", + "ClipDistance", + "CullDistance", + "ImageCubeArray", + "SampleRateShading", + "ImageRect", + "SampledRect", + "InputAttachment", + "SparseResidency", + "MinLod", + "SampledCubeArray", + "ImageMSArray", + "StorageImageExtendedFormats", + "ImageQuery", + "DerivativeControl", + "InterpolationFunction", + "TransformFeedback", + "GeometryStreams", + "StorageImageReadWithoutFormat", + "StorageImageWriteWithoutFormat", + "MultiViewport", + "DrawParameters", + "MultiView", + "VariablePointersStorageBuffer", + "VariablePointers"}; + return *r; +} + +const std::vector<std::string>& ShaderDependencies() { + static const auto r = new std::vector<std::string>{ + "Shader", + "Geometry", + "Tessellation", + "AtomicStorage", + "TessellationPointSize", + "GeometryPointSize", + "ImageGatherExtended", + "StorageImageMultisample", + "UniformBufferArrayDynamicIndexing", + "SampledImageArrayDynamicIndexing", + "StorageBufferArrayDynamicIndexing", + "StorageImageArrayDynamicIndexing", + "ClipDistance", + "CullDistance", + "ImageCubeArray", + "SampleRateShading", + "ImageRect", + "SampledRect", + "InputAttachment", + "SparseResidency", + "MinLod", + "SampledCubeArray", + "ImageMSArray", + "StorageImageExtendedFormats", + "ImageQuery", + "DerivativeControl", + "InterpolationFunction", + "TransformFeedback", + "GeometryStreams", + "StorageImageReadWithoutFormat", + "StorageImageWriteWithoutFormat", + "MultiViewport", + "DrawParameters", + "MultiView", + 
"VariablePointersStorageBuffer", + "VariablePointers"}; + return *r; +} + +const std::vector<std::string>& TessellationDependencies() { + static const auto r = new std::vector<std::string>{ + "Tessellation", + "TessellationPointSize"}; + return *r; +} + +const std::vector<std::string>& GeometryDependencies() { + static const auto r = new std::vector<std::string>{ + "Geometry", + "GeometryPointSize", + "GeometryStreams", + "MultiViewport"}; + return *r; +} + +const std::vector<std::string>& GeometryTessellationDependencies() { + static const auto r = new std::vector<std::string>{ + "Tessellation", + "TessellationPointSize", + "Geometry", + "GeometryPointSize", + "GeometryStreams", + "MultiViewport"}; + return *r; +} + +// Returns the names of capabilities that directly depend on Kernel, +// plus itself. +const std::vector<std::string>& KernelDependencies() { + static const auto r = new std::vector<std::string>{ + "Kernel", + "Vector16", + "Float16Buffer", + "ImageBasic", + "ImageReadWrite", + "ImageMipmap", + "Pipes", + "DeviceEnqueue", + "LiteralSampler", + "SubgroupDispatch", + "NamedBarrier", + "PipeStorage"}; + return *r; +} + +const std::vector<std::string>& KernelAndGroupNonUniformDependencies() { + static const auto r = new std::vector<std::string>{ + "Kernel", + "Vector16", + "Float16Buffer", + "ImageBasic", + "ImageReadWrite", + "ImageMipmap", + "Pipes", + "DeviceEnqueue", + "LiteralSampler", + "SubgroupDispatch", + "NamedBarrier", + "PipeStorage", + "GroupNonUniform", + "GroupNonUniformVote", + "GroupNonUniformArithmetic", + "GroupNonUniformBallot", + "GroupNonUniformShuffle", + "GroupNonUniformShuffleRelative", + "GroupNonUniformClustered", + "GroupNonUniformQuad"}; + return *r; +} + +const std::vector<std::string>& AddressesDependencies() { + static const auto r = new std::vector<std::string>{ + "Addresses", + "GenericPointer"}; + return *r; +} + +const std::vector<std::string>& Sampled1DDependencies() { + static const auto r = new std::vector<std::string>{ + "Sampled1D", + "Image1D"}; + return *r; +} + +const std::vector<std::string>& SampledRectDependencies() { + static const auto r = new std::vector<std::string>{ + "SampledRect", + "ImageRect"}; + return 
*r; +} + +const std::vector<std::string>& SampledBufferDependencies() { + static const auto r = new std::vector<std::string>{ + "SampledBuffer", + "ImageBuffer"}; + return *r; +} + +const char kOpenCLMemoryModel[] = \ + " OpCapability Kernel" + " OpMemoryModel Logical OpenCL "; + +const char kGLSL450MemoryModel[] = \ + " OpCapability Shader" + " OpMemoryModel Logical GLSL450 "; + +const char kVulkanMemoryModel[] = \ + " OpCapability Shader" + " OpCapability VulkanMemoryModelKHR" + " OpExtension \"SPV_KHR_vulkan_memory_model\"" + " OpMemoryModel Logical VulkanKHR "; + +const char kVoidFVoid[] = \ + " %void = OpTypeVoid" + " %void_f = OpTypeFunction %void" + " %func = OpFunction %void None %void_f" + " %label = OpLabel" + " OpReturn" + " OpFunctionEnd "; + +const char kVoidFVoid2[] = \ + " %void_f = OpTypeFunction %voidt" + " %func = OpFunction %voidt None %void_f" + " %label = OpLabel" + " OpReturn" + " OpFunctionEnd "; + +INSTANTIATE_TEST_CASE_P(ExecutionModel, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint TessellationControl %func \"shader\"" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint TessellationEvaluation %func \"shader\"" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint Geometry %func \"shader\"" + + " OpExecutionMode %func InputPoints" + + " OpExecutionMode %func OutputPoints" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint Fragment %func \"shader\"" + + " OpExecutionMode %func OriginUpperLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint GLCompute %func 
\"shader\"" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Kernel %func \"shader\"" + + std::string(kVoidFVoid), KernelDependencies()) +)),); + +INSTANTIATE_TEST_CASE_P(AddressingAndMemoryModel, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( +std::make_pair(" OpCapability Shader" + " OpMemoryModel Logical Simple" + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid), AllCapabilities()), +std::make_pair(" OpCapability Shader" + " OpMemoryModel Logical GLSL450" + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid), AllCapabilities()), +std::make_pair(" OpCapability Kernel" + " OpMemoryModel Logical OpenCL" + " OpEntryPoint Kernel %func \"compute\"" + + std::string(kVoidFVoid), AllCapabilities()), +std::make_pair(" OpCapability Shader" + " OpMemoryModel Physical32 Simple" + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Shader" + " OpMemoryModel Physical32 GLSL450" + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Kernel" + " OpMemoryModel Physical32 OpenCL" + " OpEntryPoint Kernel %func \"compute\"" + + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Shader" + " OpMemoryModel Physical64 Simple" + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Shader" + " OpMemoryModel Physical64 GLSL450" + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid), AddressesDependencies()), +std::make_pair(" OpCapability Kernel" + " OpMemoryModel Physical64 OpenCL" + " OpEntryPoint Kernel %func \"compute\"" + + std::string(kVoidFVoid), AddressesDependencies()) +)),); + +INSTANTIATE_TEST_CASE_P(ExecutionMode, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( 
+std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func Invocations 42" + + " OpExecutionMode %func InputPoints" + + " OpExecutionMode %func OutputPoints" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func SpacingEqual" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func SpacingFractionalEven" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func SpacingFractionalOdd" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func VertexOrderCw" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func VertexOrderCcw" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Fragment %func \"shader\" " + "OpExecutionMode %func PixelCenterInteger" + + " OpExecutionMode %func OriginUpperLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Fragment %func \"shader\" " + "OpExecutionMode %func OriginUpperLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Fragment %func \"shader\" " + "OpExecutionMode %func OriginLowerLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Fragment %func 
\"shader\" " + "OpExecutionMode %func EarlyFragmentTests" + + " OpExecutionMode %func OriginUpperLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func PointMode" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Vertex %func \"shader\" " + "OpExecutionMode %func Xfb" + + std::string(kVoidFVoid), std::vector<std::string>{"TransformFeedback"}), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Fragment %func \"shader\" " + "OpExecutionMode %func DepthReplacing" + + " OpExecutionMode %func OriginUpperLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Fragment %func \"shader\" " + "OpExecutionMode %func DepthGreater" + + " OpExecutionMode %func OriginUpperLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Fragment %func \"shader\" " + "OpExecutionMode %func DepthLess" + + " OpExecutionMode %func OriginUpperLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Fragment %func \"shader\" " + "OpExecutionMode %func DepthUnchanged" + + " OpExecutionMode %func OriginUpperLeft" + + std::string(kVoidFVoid), ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"shader\" " + "OpExecutionMode %func LocalSize 42 42 42" + + std::string(kVoidFVoid), AllCapabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Kernel %func \"shader\" " + "OpExecutionMode %func LocalSizeHint 42 42 42" + + std::string(kVoidFVoid), KernelDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func InputPoints" + + " OpExecutionMode 
%func OutputPoints" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func InputLines" + + " OpExecutionMode %func OutputLineStrip" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func InputLinesAdjacency" + + " OpExecutionMode %func OutputLineStrip" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func Triangles" + + " OpExecutionMode %func OutputTriangleStrip" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func Triangles" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func InputTrianglesAdjacency" + + " OpExecutionMode %func OutputTriangleStrip" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func Quads" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + "OpExecutionMode %func Isolines" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func OutputVertices 42" + + " OpExecutionMode %func OutputPoints" + + " OpExecutionMode %func InputPoints" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint TessellationControl %func \"shader\" " + 
"OpExecutionMode %func OutputVertices 42" + + std::string(kVoidFVoid), TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func OutputPoints" + + " OpExecutionMode %func InputPoints" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func OutputLineStrip" + + " OpExecutionMode %func InputLines" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Geometry %func \"shader\" " + "OpExecutionMode %func OutputTriangleStrip" + + " OpExecutionMode %func Triangles" + + std::string(kVoidFVoid), GeometryDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Kernel %func \"shader\" " + "OpExecutionMode %func VecTypeHint 2" + + std::string(kVoidFVoid), KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Kernel %func \"shader\" " + "OpExecutionMode %func ContractionOff" + + std::string(kVoidFVoid), KernelDependencies()))),); + +// clang-format on + +INSTANTIATE_TEST_CASE_P( + ExecutionModeV11, ValidateCapabilityV11, + Combine(ValuesIn(AllCapabilities()), + Values(std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"shader\" " + "OpExecutionMode %func SubgroupSize 1" + + std::string(kVoidFVoid), + std::vector{"SubgroupDispatch"}), + std::make_pair( + std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"shader\" " + "OpExecutionMode %func SubgroupsPerWorkgroup 65535" + + std::string(kVoidFVoid), + std::vector{"SubgroupDispatch"}))), ); +// clang-format off + +INSTANTIATE_TEST_CASE_P(StorageClass, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = 
OpTypePointer UniformConstant %intt\n" + " %var = OpVariable %ptrt UniformConstant\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint Kernel %func \"compute\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer Input %intt" + " %var = OpVariable %ptrt Input\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer Uniform %intt\n" + " %var = OpVariable %ptrt Uniform\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer Output %intt\n" + " %var = OpVariable %ptrt Output\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer Workgroup %intt\n" + " %var = OpVariable %ptrt Workgroup\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer CrossWorkgroup %intt\n" + " %var = OpVariable %ptrt CrossWorkgroup\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint Kernel %func \"compute\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer Private %intt\n" + " %var = OpVariable %ptrt Private\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + " OpEntryPoint Kernel %func \"compute\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer PushConstant %intt\n" + " %var = OpVariable %ptrt PushConstant\n" + std::string(kVoidFVoid), + ShaderDependencies()), 
+std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer AtomicCounter %intt\n" + " %var = OpVariable %ptrt AtomicCounter\n" + std::string(kVoidFVoid), + std::vector{"AtomicStorage"}), +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + + " %intt = OpTypeInt 32 0\n" + " %ptrt = OpTypePointer Image %intt\n" + " %var = OpVariable %ptrt Image\n" + std::string(kVoidFVoid), + AllCapabilities()) +)),); + +INSTANTIATE_TEST_CASE_P(Dim, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + + " %voidt = OpTypeVoid" + " %imgt = OpTypeImage %voidt 1D 0 0 0 0 Unknown" + std::string(kVoidFVoid2), + Sampled1DDependencies()), +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + + " %voidt = OpTypeVoid" + " %imgt = OpTypeImage %voidt 2D 0 0 0 0 Unknown" + std::string(kVoidFVoid2), + AllCapabilities()), +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + + " %voidt = OpTypeVoid" + " %imgt = OpTypeImage %voidt 3D 0 0 0 0 Unknown" + std::string(kVoidFVoid2), + AllCapabilities()), +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + + " %voidt = OpTypeVoid" + " %imgt = OpTypeImage %voidt Cube 0 0 0 0 Unknown" + std::string(kVoidFVoid2), + ShaderDependencies()), +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + + " %voidt = OpTypeVoid" + " %imgt = OpTypeImage %voidt Rect 0 0 0 0 Unknown" + std::string(kVoidFVoid2), + SampledRectDependencies()), 
+std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + + " %voidt = OpTypeVoid" + " %imgt = OpTypeImage %voidt Buffer 0 0 0 0 Unknown" + std::string(kVoidFVoid2), + SampledBufferDependencies()), +std::make_pair(" OpCapability ImageBasic" + + std::string(kOpenCLMemoryModel) + + std::string(" OpEntryPoint Kernel %func \"compute\"") + + " %voidt = OpTypeVoid" + " %imgt = OpTypeImage %voidt SubpassData 0 0 0 2 Unknown" + std::string(kVoidFVoid2), + std::vector{"InputAttachment"}) +)),); + +// NOTE: All Sampler Address Modes require kernel capabilities but the +// OpConstantSampler requires LiteralSampler which depends on Kernel +INSTANTIATE_TEST_CASE_P(SamplerAddressingMode, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + " %samplert = OpTypeSampler" + " %sampler = OpConstantSampler %samplert None 1 Nearest" + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}), +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + " %samplert = OpTypeSampler" + " %sampler = OpConstantSampler %samplert ClampToEdge 1 Nearest" + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}), +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + " %samplert = OpTypeSampler" + " %sampler = OpConstantSampler %samplert Clamp 1 Nearest" + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}), +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + " %samplert = OpTypeSampler" + " %sampler = OpConstantSampler %samplert Repeat 1 Nearest" + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}), +std::make_pair(std::string(kGLSL450MemoryModel) + + " OpEntryPoint Vertex %func \"shader\"" + " %samplert = OpTypeSampler" + " %sampler = OpConstantSampler %samplert 
RepeatMirrored 1 Nearest" + + std::string(kVoidFVoid), + std::vector{"LiteralSampler"}) +)),); + +// TODO(umar): Sampler Filter Mode +// TODO(umar): Image Format +// TODO(umar): Image Channel Order +// TODO(umar): Image Channel Data Type +// TODO(umar): Image Operands +// TODO(umar): FP Fast Math Mode +// TODO(umar): FP Rounding Mode +// TODO(umar): Linkage Type +// TODO(umar): Access Qualifier +// TODO(umar): Function Parameter Attribute + +INSTANTIATE_TEST_CASE_P(Decoration, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt RelaxedPrecision\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Block\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt BufferBlock\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt RowMajor\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + MatrixDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt ColMajor\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + MatrixDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt ArrayStride 1\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt MatrixStride 1\n" + "%intt = OpTypeInt 32 0\n" + 
std::string(kVoidFVoid), + MatrixDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt GLSLShared\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt GLSLPacked\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt CPacked\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt NoPerspective\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Flat\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Patch\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Centroid\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Sample\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"SampleRateShading"}), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Invariant\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + 
"OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Restrict\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Aliased\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Volatile\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt Constant\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Coherent\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt NonWritable\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt NonReadable\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + // Uniform must target a non-void value. 
+ "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %int0 Uniform\n" + "%intt = OpTypeInt 32 0\n" + + "%int0 = OpConstantNull %intt" + + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt SaturatedConversion\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Stream 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"GeometryStreams"}), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Location 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Component 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Index 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Binding 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt DescriptorSet 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt Offset 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate 
%intt XfbBuffer 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"TransformFeedback"}), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt XfbStride 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"TransformFeedback"}), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt FuncParamAttr Zext\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt FPFastMathMode Fast\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt LinkageAttributes \"other\" Import\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"Linkage"}), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt NoContraction\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %intt InputAttachmentIndex 0\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"InputAttachment"}), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %intt Alignment 4\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()) +)),); + +// clang-format on +INSTANTIATE_TEST_CASE_P( + DecorationSpecId, ValidateCapability, + Combine( + ValuesIn(AllSpirV10Capabilities()), + Values(std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %1 SpecId 1\n" + "%intt = OpTypeInt 32 0\n" + "%1 = OpSpecConstant %intt 
0\n" + + std::string(kVoidFVoid), + ShaderDependencies()))), ); + +INSTANTIATE_TEST_CASE_P( + DecorationV11, ValidateCapabilityV11, + Combine(ValuesIn(AllCapabilities()), + Values(std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %p MaxByteOffset 0 " + "%i32 = OpTypeInt 32 0 " + "%pi32 = OpTypePointer Workgroup %i32 " + "%p = OpVariable %pi32 Workgroup " + + std::string(kVoidFVoid), + AddressesDependencies()), + // Trying to test OpDecorate here, but if this fails due to + // incorrect OpMemoryModel validation, that must also be + // fixed. + std::make_pair( + std::string("OpMemoryModel Logical OpenCL " + "OpEntryPoint Kernel %func \"compute\" \n" + "OpDecorate %1 SpecId 1 " + "%intt = OpTypeInt 32 0 " + "%1 = OpSpecConstant %intt 0") + + std::string(kVoidFVoid), + KernelDependencies()), + std::make_pair( + std::string("OpMemoryModel Logical Simple " + "OpEntryPoint Vertex %func \"shader\" \n" + "OpDecorate %1 SpecId 1 " + "%intt = OpTypeInt 32 0 " + "%1 = OpSpecConstant %intt 0") + + std::string(kVoidFVoid), + ShaderDependencies()))), ); +// clang-format off + +INSTANTIATE_TEST_CASE_P(BuiltIn, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn Position\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +// Just mentioning PointSize, ClipDistance, or CullDistance as a BuiltIn does +// not trigger the requirement for the associated capability. 
+// See https://github.com/KhronosGroup/SPIRV-Tools/issues/365 +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn PointSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn ClipDistance\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn CullDistance\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn VertexId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn InstanceId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn PrimitiveId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + GeometryTessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn InvocationId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + GeometryTessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn Layer\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + GeometryDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn ViewportIndex\n" + "%intt = 
OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"MultiViewport"}), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn TessLevelOuter\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn TessLevelInner\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn TessCoord\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn PatchVertices\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + TessellationDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn FragCoord\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn PointCoord\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn FrontFacing\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn SampleId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"SampleRateShading"}), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt 
BuiltIn SamplePosition\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + std::vector{"SampleRateShading"}), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn SampleMask\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn FragDepth\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn HelperInvocation\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn VertexIndex\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn InstanceIndex\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn NumWorkgroups\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn WorkgroupSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn WorkgroupId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate 
%intt BuiltIn LocalInvocationId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn GlobalInvocationId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn LocalInvocationIndex\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllCapabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn WorkDim\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn GlobalSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn EnqueuedWorkgroupSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn GlobalOffset\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn GlobalLinearId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn SubgroupSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelAndGroupNonUniformDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func 
\"shader\" \n" + + "OpDecorate %intt BuiltIn SubgroupMaxSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn NumSubgroups\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelAndGroupNonUniformDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn NumEnqueuedSubgroups\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn SubgroupId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelAndGroupNonUniformDependencies()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn SubgroupLocalInvocationId\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + KernelAndGroupNonUniformDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn VertexIndex\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()), +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "OpDecorate %intt BuiltIn InstanceIndex\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + ShaderDependencies()) +)),); + +// Ensure that mere mention of PointSize, ClipDistance, or CullDistance as +// BuiltIns does not trigger the requirement for the associated +// capability. +// See https://github.com/KhronosGroup/SPIRV-Tools/issues/365 +INSTANTIATE_TEST_CASE_P(BuiltIn, ValidateCapabilityVulkan10, + Combine( + // All capabilities to try. 
+ ValuesIn(AllSpirV10Capabilities()), + Values( +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpMemberDecorate %block 0 BuiltIn PointSize\n" + "%f32 = OpTypeFloat 32\n" + "%block = OpTypeStruct %f32\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + // Capabilities which should succeed. + AllVulkan10Capabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpMemberDecorate %block 0 BuiltIn ClipDistance\n" + "%f32 = OpTypeFloat 32\n" + "%intt = OpTypeInt 32 0\n" + "%intt_4 = OpConstant %intt 4\n" + "%f32arr4 = OpTypeArray %f32 %intt_4\n" + "%block = OpTypeStruct %f32arr4\n" + std::string(kVoidFVoid), + AllVulkan10Capabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + "OpMemberDecorate %block 0 BuiltIn CullDistance\n" + "%f32 = OpTypeFloat 32\n" + "%intt = OpTypeInt 32 0\n" + "%intt_4 = OpConstant %intt 4\n" + "%f32arr4 = OpTypeArray %f32 %intt_4\n" + "%block = OpTypeStruct %f32arr4\n" + std::string(kVoidFVoid), + AllVulkan10Capabilities()) +)),); + +INSTANTIATE_TEST_CASE_P(BuiltIn, ValidateCapabilityOpenGL40, + Combine( + // OpenGL 4.0 is based on SPIR-V 1.0 + ValuesIn(AllSpirV10Capabilities()), + Values( +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn PointSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllSpirV10Capabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn ClipDistance\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllSpirV10Capabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn CullDistance\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllSpirV10Capabilities()) +)),); + 
+INSTANTIATE_TEST_CASE_P(Capabilities, ValidateCapabilityWebGPU, + Combine( + // All capabilities to try. + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kVulkanMemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + std::string(kVoidFVoid), + AllWebGPUCapabilities()) +)),); + +INSTANTIATE_TEST_CASE_P(Capabilities, ValidateCapabilityVulkan11, + Combine( + // All capabilities to try. + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn PointSize\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllVulkan11Capabilities()), +std::make_pair(std::string(kGLSL450MemoryModel) + + "OpEntryPoint Vertex %func \"shader\" \n" + + "OpDecorate %intt BuiltIn CullDistance\n" + "%intt = OpTypeInt 32 0\n" + std::string(kVoidFVoid), + AllVulkan11Capabilities()) +)),); + +// TODO(umar): Selection Control +// TODO(umar): Loop Control +// TODO(umar): Function Control +// TODO(umar): Memory Semantics +// TODO(umar): Memory Access +// TODO(umar): Scope +// TODO(umar): Group Operation +// TODO(umar): Kernel Enqueue Flags +// TODO(umar): Kernel Profiling Flags + +INSTANTIATE_TEST_CASE_P(MatrixOp, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values( +std::make_pair(std::string(kOpenCLMemoryModel) + + "OpEntryPoint Kernel %func \"compute\" \n" + + "%f32 = OpTypeFloat 32\n" + "%vec3 = OpTypeVector %f32 3\n" + "%mat33 = OpTypeMatrix %vec3 3\n" + std::string(kVoidFVoid), + MatrixDependencies()))),); +// clang-format on + +#if 0 +// TODO(atgoo@github.com) The following test is not valid as it generates +// invalid combinations of images, instructions and image operands. +// +// Creates assembly containing an OpImageFetch instruction using operands for +// the image-operands part. The assembly defines constants %fzero and %izero +// that can be used for operands where IDs are required. 
The assembly is valid, +// apart from not declaring any capabilities required by the operands. +string ImageOperandsTemplate(const std::string& operands) { + ostringstream ss; + // clang-format off + ss << R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL + +%i32 = OpTypeInt 32 0 +%f32 = OpTypeFloat 32 +%v4i32 = OpTypeVector %i32 4 +%timg = OpTypeImage %i32 2D 0 0 0 0 Unknown +%pimg = OpTypePointer UniformConstant %timg +%tfun = OpTypeFunction %i32 + +%vimg = OpVariable %pimg UniformConstant +%izero = OpConstant %i32 0 +%fzero = OpConstant %f32 0. + +%main = OpFunction %i32 None %tfun +%lbl = OpLabel +%img = OpLoad %timg %vimg +%r1 = OpImageFetch %v4i32 %img %izero )" << operands << R"( +OpReturnValue %izero +OpFunctionEnd +)"; + // clang-format on + return ss.str(); +} + +INSTANTIATE_TEST_CASE_P( + TwoImageOperandsMask, ValidateCapability, + Combine( + ValuesIn(AllCapabilities()), + Values(std::make_pair(ImageOperandsTemplate("Bias|Lod %fzero %fzero"), + ShaderDependencies()), + std::make_pair(ImageOperandsTemplate("Lod|Offset %fzero %izero"), + std::vector{"ImageGatherExtended"}), + std::make_pair(ImageOperandsTemplate("Sample|MinLod %izero %fzero"), + std::vector{"MinLod"}), + std::make_pair(ImageOperandsTemplate("Lod|Sample %fzero %izero"), + AllCapabilities()))), ); +#endif + +// TODO(umar): Instruction capability checks + +spv_result_t spvCoreOperandTableNameLookup(spv_target_env env, + const spv_operand_table table, + const spv_operand_type_t type, + const char* name, + const size_t nameLength) { + if (!table) return SPV_ERROR_INVALID_TABLE; + if (!name) return SPV_ERROR_INVALID_POINTER; + + for (uint64_t typeIndex = 0; typeIndex < table->count; ++typeIndex) { + const auto& group = table->types[typeIndex]; + if (type != group.type) continue; + for (uint64_t index = 0; index < group.count; ++index) { + const auto& entry = group.entries[index]; + // Check for min version only. 
+ if (spvVersionForTargetEnv(env) >= entry.minVersion && + nameLength == strlen(entry.name) && + !strncmp(entry.name, name, nameLength)) { + return SPV_SUCCESS; + } + } + } + + return SPV_ERROR_INVALID_LOOKUP; +} + +// True if capability exists in core spec of env. +bool Exists(const std::string& capability, spv_target_env env) { + ScopedContext sc(env); + return SPV_SUCCESS == + spvCoreOperandTableNameLookup(env, sc.context->operand_table, + SPV_OPERAND_TYPE_CAPABILITY, + capability.c_str(), capability.size()); +} + +TEST_P(ValidateCapability, Capability) { + const std::string capability = Capability(GetParam()); + spv_target_env env = SPV_ENV_UNIVERSAL_1_0; + if (!capability.empty()) { + if (Exists(capability, SPV_ENV_UNIVERSAL_1_0)) + env = SPV_ENV_UNIVERSAL_1_0; + else if (Exists(capability, SPV_ENV_UNIVERSAL_1_1)) + env = SPV_ENV_UNIVERSAL_1_1; + else if (Exists(capability, SPV_ENV_UNIVERSAL_1_2)) + env = SPV_ENV_UNIVERSAL_1_2; + else + env = SPV_ENV_UNIVERSAL_1_3; + } + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, env); + ASSERT_EQ(ExpectedResult(GetParam()), ValidateInstructions(env)) + << "target env: " << spvTargetEnvDescription(env) << "\ntest code:\n" + << test_code; +} + +TEST_P(ValidateCapabilityV11, Capability) { + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_UNIVERSAL_1_1)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)) + << test_code; + } +} + +TEST_P(ValidateCapabilityVulkan10, Capability) { + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_VULKAN_1_0)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_VULKAN_1_0); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_VULKAN_1_0)) + << test_code; + } +} 
+ +TEST_P(ValidateCapabilityVulkan11, Capability) { + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_VULKAN_1_1)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_VULKAN_1_1); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_VULKAN_1_1)) + << test_code; + } +} + +TEST_P(ValidateCapabilityOpenGL40, Capability) { + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_OPENGL_4_0)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_OPENGL_4_0); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_OPENGL_4_0)) + << test_code; + } +} + +TEST_P(ValidateCapabilityWebGPU, Capability) { + const std::string capability = Capability(GetParam()); + if (Exists(capability, SPV_ENV_WEBGPU_0)) { + const std::string test_code = MakeAssembly(GetParam()); + CompileSuccessfully(test_code, SPV_ENV_WEBGPU_0); + ASSERT_EQ(ExpectedResult(GetParam()), + ValidateInstructions(SPV_ENV_WEBGPU_0)) + << test_code; + } +} + +TEST_F(ValidateCapability, SemanticsIdIsAnIdNotALiteral) { + // From https://github.com/KhronosGroup/SPIRV-Tools/issues/248 + // The validator was interpreting the memory semantics ID number + // as the value to be checked rather than an ID that references + // another value to be checked. + // In this case a raw ID of 64 was mistaken to mean a literal + // semantic value of UniformMemory, which would require the Shader + // capability. + const char str[] = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL + +; %i32 has ID 1 +%i32 = OpTypeInt 32 0 +%tf = OpTypeFunction %i32 +%pi32 = OpTypePointer CrossWorkgroup %i32 +%var = OpVariable %pi32 CrossWorkgroup +%c = OpConstant %i32 100 +%scope = OpConstant %i32 1 ; Device scope + +; Fake an instruction with 64 as the result id. 
+; !64 = OpConstantNull %i32 +!0x3002e !1 !64 + +%f = OpFunction %i32 None %tf +%l = OpLabel +%result = OpAtomicIAdd %i32 %var %scope !64 %c +OpReturnValue %result +OpFunctionEnd +)"; + + CompileSuccessfully(str); + + // Since we are forcing usage of 64, the "id bound" in the binary header + // must be overwritten so that 64 is considered within bound. + // ID Bound is at index 3 of the binary. Set it to 65. + OverwriteAssembledBinary(3, 65); + + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCapability, IntSignednessKernelGood) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%i32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCapability, IntSignednessKernelBad) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +%i32 = OpTypeInt 32 1 +)"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Signedness in OpTypeInt must always be 0 when " + "Kernel capability is used.")); +} + +TEST_F(ValidateCapability, IntSignednessShaderGood) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%u32 = OpTypeInt 32 0 +%i32 = OpTypeInt 32 1 +)"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCapability, NonVulkan10Capability) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%u32 = OpTypeInt 32 0 +%i32 = OpTypeInt 32 1 +)"; + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Capability Linkage is not allowed by Vulkan 1.0")); +} + +TEST_F(ValidateCapability, 
Vulkan10EnabledByExtension) { + const std::string spirv = R"( +OpCapability Shader +OpCapability DrawParameters +OpExtension "SPV_KHR_shader_draw_parameters" +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %func "shader" +OpMemberDecorate %block 0 BuiltIn PointSize +%f32 = OpTypeFloat 32 +%block = OpTypeStruct %f32 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0)); +} + +TEST_F(ValidateCapability, Vulkan10NotEnabledByExtension) { + const std::string spirv = R"( +OpCapability Shader +OpCapability DrawParameters +OpMemoryModel Logical GLSL450 +OpEntryPoint Vertex %func "shader" +OpDecorate %intt BuiltIn PointSize +%intt = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability DrawParameters is not allowed by Vulkan 1.0")); +} + +TEST_F(ValidateCapability, NonOpenCL12FullCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Pipes +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv, SPV_ENV_OPENCL_1_2); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_1_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability Pipes is not allowed by OpenCL 1.2 Full Profile")); +} + +TEST_F(ValidateCapability, OpenCL12FullEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability ImageBasic +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_1_2)); +} + +TEST_F(ValidateCapability, 
OpenCL12FullNotEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_1_2); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_1_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Capability Sampled1D is not allowed by OpenCL 1.2 Full Profile")); +} + +TEST_F(ValidateCapability, NonOpenCL12EmbeddedCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Int64 +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_1_2); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_1_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Capability Int64 is not allowed by OpenCL 1.2 Embedded Profile")); +} + +TEST_F(ValidateCapability, OpenCL12EmbeddedEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability ImageBasic +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_1_2)); +} + +TEST_F(ValidateCapability, OpenCL12EmbeddedNotEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_1_2); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_1_2)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Capability Sampled1D is not 
allowed by OpenCL 1.2 " + "Embedded Profile")); +} + +TEST_F(ValidateCapability, OpenCL20FullCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Pipes +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_2_0)); +} + +TEST_F(ValidateCapability, NonOpenCL20FullCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Matrix +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_2_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Capability Matrix is not allowed by OpenCL 2.0/2.1 Full Profile")); +} + +TEST_F(ValidateCapability, OpenCL20FullEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability ImageBasic +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_2_0)); +} + +TEST_F(ValidateCapability, OpenCL20FullNotEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_2_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Capability Sampled1D is not allowed by OpenCL 2.0/2.1 " + "Full Profile")); +} + +TEST_F(ValidateCapability, NonOpenCL20EmbeddedCapability) { + const std::string spirv = R"( +OpCapability Kernel 
+OpCapability Addresses +OpCapability Linkage +OpCapability Int64 +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_2_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Capability Int64 is not allowed by OpenCL 2.0/2.1 " + "Embedded Profile")); +} + +TEST_F(ValidateCapability, OpenCL20EmbeddedEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability ImageBasic +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_2_0)); +} + +TEST_F(ValidateCapability, OpenCL20EmbeddedNotEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_2_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Capability Sampled1D is not allowed by OpenCL 2.0/2.1 " + "Embedded Profile")); +} + +TEST_F(ValidateCapability, OpenCL22FullCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability PipeStorage +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_2_2)); +} + +TEST_F(ValidateCapability, NonOpenCL22FullCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Matrix +OpMemoryModel Physical64 
OpenCL +%u32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_2); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_2_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability Matrix is not allowed by OpenCL 2.2 Full Profile")); +} + +TEST_F(ValidateCapability, OpenCL22FullEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability ImageBasic +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_2_2)); +} + +TEST_F(ValidateCapability, OpenCL22FullNotEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_2_2); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_2_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Capability Sampled1D is not allowed by OpenCL 2.2 Full Profile")); +} + +TEST_F(ValidateCapability, NonOpenCL22EmbeddedCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Int64 +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)"; + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_2); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_2_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Capability Int64 is not allowed by OpenCL 2.2 Embedded Profile")); +} + +TEST_F(ValidateCapability, OpenCL22EmbeddedEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability ImageBasic +OpCapability Sampled1D 
+OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_2_2)); +} + +TEST_F(ValidateCapability, OpenCL22EmbeddedNotEnabledByCapability) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability Linkage +OpCapability Sampled1D +OpMemoryModel Physical64 OpenCL +%u32 = OpTypeInt 32 0 +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_OPENCL_EMBEDDED_2_2); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_OPENCL_EMBEDDED_2_2)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Capability Sampled1D is not allowed by OpenCL 2.2 " + "Embedded Profile")); +} + +// Three tests to check enablement of an enum (a decoration) which is not +// in core, and is directly enabled by a capability, but not directly enabled +// by an extension. See https://github.com/KhronosGroup/SPIRV-Tools/issues/1596 + +TEST_F(ValidateCapability, DecorationFromExtensionMissingEnabledByCapability) { + // Decoration ViewportRelativeNV is enabled by ShaderViewportMaskNV, which in + // turn is enabled by SPV_NV_viewport_array2. + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical Simple +OpDecorate %void ViewportRelativeNV +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 2 of Decorate requires one of these " + "capabilities: ShaderViewportMaskNV")); +} + +TEST_F(ValidateCapability, CapabilityEnabledByMissingExtension) { + // Capability ShaderViewportMaskNV is enabled by SPV_NV_viewport_array2. 
+ const std::string spirv = R"( +OpCapability Shader +OpCapability ShaderViewportMaskNV +OpMemoryModel Logical Simple +)" + std::string(kVoidFVoid); + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_ERROR_MISSING_EXTENSION, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("operand 5255 requires one of these extensions: " + "SPV_NV_viewport_array2")); +} + +TEST_F(ValidateCapability, + DecorationEnabledByCapabilityEnabledByPresentExtension) { + // Decoration ViewportRelativeNV is enabled by ShaderViewportMaskNV, which in + // turn is enabled by SPV_NV_viewport_array2. + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability ShaderViewportMaskNV +OpExtension "SPV_NV_viewport_array2" +OpMemoryModel Logical Simple +OpDecorate %void ViewportRelativeNV +%void = OpTypeVoid +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)) + << getDiagnosticString(); +} + +// Three tests to check enablement of an instruction which is not in core, and +// is directly enabled by a capability, but not directly enabled by an +// extension. See https://github.com/KhronosGroup/SPIRV-Tools/issues/1624 +// Instruction OpSubgroupShuffleINTEL is enabled by SubgroupShuffleINTEL, which +// in turn is enabled by SPV_INTEL_subgroups. + +TEST_F(ValidateCapability, InstructionFromExtensionMissingEnabledByCapability) { + // Instruction OpSubgroupShuffleINTEL is enabled by SubgroupShuffleINTEL, + // which in turn is enabled by SPV_INTEL_subgroups.
+ const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +; OpCapability SubgroupShuffleINTEL +OpExtension "SPV_INTEL_subgroups" +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%voidfn = OpTypeFunction %void +%zero = OpConstant %uint 0 +%main = OpFunction %void None %voidfn +%entry = OpLabel +%foo = OpSubgroupShuffleINTEL %uint %zero %zero +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Opcode SubgroupShuffleINTEL requires one of these " + "capabilities: SubgroupShuffleINTEL")); +} + +TEST_F(ValidateCapability, + InstructionEnablingCapabilityEnabledByMissingExtension) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability SubgroupShuffleINTEL +; OpExtension "SPV_INTEL_subgroups" +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%voidfn = OpTypeFunction %void +%zero = OpConstant %uint 0 +%main = OpFunction %void None %voidfn +%entry = OpLabel +%foo = OpSubgroupShuffleINTEL %uint %zero %zero +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_ERROR_MISSING_EXTENSION, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("operand 5568 requires one of these extensions: " + "SPV_INTEL_subgroups")); +} + +TEST_F(ValidateCapability, + InstructionEnabledByCapabilityEnabledByPresentExtension) { + const std::string spirv = R"( +OpCapability Kernel +OpCapability Addresses +OpCapability SubgroupShuffleINTEL +OpExtension "SPV_INTEL_subgroups" +OpMemoryModel Physical32 OpenCL +OpEntryPoint Kernel %main "main" +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%voidfn = OpTypeFunction %void +%zero = OpConstant %uint 0 +%main = OpFunction 
%void None %voidfn +%entry = OpLabel +%foo = OpSubgroupShuffleINTEL %uint %zero %zero +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)) + << getDiagnosticString(); +} + +TEST_F(ValidateCapability, VulkanMemoryModelWithVulkanKHR) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)) + << getDiagnosticString(); +} + +TEST_F(ValidateCapability, VulkanMemoryModelWithGLSL450) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical GLSL450 +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("VulkanMemoryModelKHR capability must only be " + "specified if the VulkanKHR memory model is used")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp new file mode 100644 index 000000000..aed0a5788 --- /dev/null +++ b/test/val/val_cfg_test.cpp @@ -0,0 +1,2067 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for Control Flow Graph + +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" + +#include "source/diagnostic.h" +#include "source/val/validate.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; + +using ValidateCFG = spvtest::ValidateBase; +using spvtest::ScopedContext; + +std::string nameOps() { return ""; } + +template +std::string nameOps(std::pair head, Args... names) { + return "OpName %" + head.first + " \"" + head.second + "\"\n" + + nameOps(names...); +} + +template +std::string nameOps(std::string head, Args... names) { + return "OpName %" + head + " \"" + head + "\"\n" + nameOps(names...); +} + +/// This class allows the easy creation of complex control flow without writing +/// SPIR-V. This class is used in the test cases below. 
+class Block { + std::string label_; + std::string body_; + SpvOp type_; + std::vector successors_; + + public: + /// Creates a Block with a given label + /// + /// @param[in]: label the label id of the block + /// @param[in]: type the branch instruction that ends the block + explicit Block(std::string label, SpvOp type = SpvOpBranch) + : label_(label), body_(), type_(type), successors_() {} + + /// Sets the instructions which will appear in the body of the block + Block& SetBody(std::string body) { + body_ = body; + return *this; + } + + Block& AppendBody(std::string body) { + body_ += body; + return *this; + } + + /// Converts the block into a SPIR-V string + operator std::string() { + std::stringstream out; + out << std::setw(8) << "%" + label_ + " = OpLabel \n"; + if (!body_.empty()) { + out << body_; + } + + switch (type_) { + case SpvOpBranchConditional: + out << "OpBranchConditional %cond "; + for (Block& b : successors_) { + out << "%" + b.label_ + " "; + } + break; + case SpvOpSwitch: { + out << "OpSwitch %one %" + successors_.front().label_; + std::stringstream ss; + for (size_t i = 1; i < successors_.size(); i++) { + ss << " " << i << " %" << successors_[i].label_; + } + out << ss.str(); + } break; + case SpvOpReturn: + assert(successors_.size() == 0); + out << "OpReturn\n"; + break; + case SpvOpUnreachable: + assert(successors_.size() == 0); + out << "OpUnreachable\n"; + break; + case SpvOpBranch: + assert(successors_.size() == 1); + out << "OpBranch %" + successors_.front().label_; + break; + default: + assert(1 == 0 && "Unhandled"); + } + out << "\n"; + + return out.str(); + } + friend Block& operator>>(Block& curr, std::vector successors); + friend Block& operator>>(Block& lhs, Block& successor); +}; + +/// Assigns the successors for the Block on the lhs +Block& operator>>(Block& lhs, std::vector successors) { + if (lhs.type_ == SpvOpBranchConditional) { + assert(successors.size() == 2); + } else if (lhs.type_ == SpvOpSwitch) { +
assert(successors.size() > 1); + } + lhs.successors_ = successors; + return lhs; +} + +/// Assigns the successor for the Block on the lhs +Block& operator>>(Block& lhs, Block& successor) { + assert(lhs.type_ == SpvOpBranch); + lhs.successors_.push_back(successor); + return lhs; +} + +const char* header(SpvCapability cap) { + static const char* shader_header = + "OpCapability Shader\n" + "OpCapability Linkage\n" + "OpMemoryModel Logical GLSL450\n"; + + static const char* kernel_header = + "OpCapability Kernel\n" + "OpCapability Linkage\n" + "OpMemoryModel Logical OpenCL\n"; + + return (cap == SpvCapabilityShader) ? shader_header : kernel_header; +} + +const char* types_consts() { + static const char* types = + "%voidt = OpTypeVoid\n" + "%boolt = OpTypeBool\n" + "%intt = OpTypeInt 32 0\n" + "%one = OpConstant %intt 1\n" + "%two = OpConstant %intt 2\n" + "%ptrt = OpTypePointer Function %intt\n" + "%funct = OpTypeFunction %voidt\n"; + + return types; +} + +INSTANTIATE_TEST_CASE_P(StructuredControlFlow, ValidateCFG, + ::testing::Values(SpvCapabilityShader, + SpvCapabilityKernel)); + +TEST_P(ValidateCFG, LoopReachableFromEntryButNeverLeadingToReturn) { + // In this case, the loop is reachable from a node without a predecessor, + // but never reaches a node with a return. + // + // This motivates the need for the pseudo-exit node to have a node + // from a cycle in its predecessors list. Otherwise the validator's + // post-dominance calculation will go into an infinite loop. 
+ // + // For more motivation, see + // https://github.com/KhronosGroup/SPIRV-Tools/issues/279 + std::string str = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + + OpName %entry "entry" + OpName %loop "loop" + OpName %exit "exit" + +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt + +%main = OpFunction %voidt None %funct +%entry = OpLabel + OpBranch %loop +%loop = OpLabel + OpLoopMerge %exit %loop None + OpBranch %loop +%exit = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << str; +} + +TEST_P(ValidateCFG, LoopUnreachableFromEntryButLeadingToReturn) { + // In this case, the loop is not reachable from a node without a + // predecessor, but eventually reaches a node with a return. + // + // This motivates the need for the pseudo-entry node to have a node + // from a cycle in its successors list. Otherwise the validator's + // dominance calculation will go into an infinite loop. + // + // For more motivation, see + // https://github.com/KhronosGroup/SPIRV-Tools/issues/279 + // Before that fix, we'd have an infinite loop when calculating + // post-dominators. 
+ std::string str = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + + OpName %entry "entry" + OpName %loop "loop" + OpName %cont "cont" + OpName %exit "exit" + +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%boolt = OpTypeBool +%false = OpConstantFalse %boolt + +%main = OpFunction %voidt None %funct +%entry = OpLabel + OpReturn + +%loop = OpLabel + OpLoopMerge %exit %cont None + OpBranch %cont + +%cont = OpLabel + OpBranchConditional %false %loop %exit + +%exit = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) + << str << getDiagnosticString(); +} + +TEST_P(ValidateCFG, Simple) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block cont("cont"); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) { + loop.SetBody("OpLoopMerge %merge %cont None\n"); + } + + std::string str = header(GetParam()) + + nameOps("loop", "entry", "cont", "merge", + std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> std::vector({cont, merge}); + str += cont >> loop; + str += merge; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, Variable) { + Block entry("entry"); + Block cont("cont"); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%var = OpVariable %ptrt Function\n"); + + std::string str = header(GetParam()) + + nameOps(std::make_pair("func", "Main")) + types_consts() + + " %func = OpFunction %voidt None %funct\n"; + str += entry >> cont; + str += cont >> exit; + str += exit; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, VariableNotInFirstBlockBad) { + Block entry("entry"); + 
Block cont("cont"); + Block exit("exit", SpvOpReturn); + + // Invalid on purpose: a variable may only be defined in the entry block + cont.SetBody("%var = OpVariable %ptrt Function\n"); + + std::string str = header(GetParam()) + + nameOps(std::make_pair("func", "Main")) + types_consts() + + " %func = OpFunction %voidt None %funct\n"; + + str += entry >> cont; + str += cont >> exit; + str += exit; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Variables can only be defined in the first block of a function")); +} + +TEST_P(ValidateCFG, BlockSelfLoopIsOk) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n"); + + std::string str = header(GetParam()) + + nameOps("loop", "merge", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + // loop branches to itself, but does not trigger an error. 
+ str += loop >> std::vector<Block>({merge, loop}); + str += merge; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_P(ValidateCFG, BlockAppearsBeforeDominatorBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block cont("cont"); + Block branch("branch", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) branch.SetBody("OpSelectionMerge %merge None\n"); + + std::string str = header(GetParam()) + + nameOps("cont", "branch", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> branch; + str += cont >> merge; // cont appears before its dominator + str += branch >> std::vector<Block>({cont, merge}); + str += merge; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("Block .\\[%cont\\] appears in the binary " + "before its dominator .\\[%branch\\]\n" + " %branch = OpLabel\n")); +} + +TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop"); + Block selection("selection", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) loop.SetBody(" OpLoopMerge %merge %loop None\n"); + + // cannot share the same merge + if (is_shader) selection.SetBody("OpSelectionMerge %merge None\n"); + + std::string str = + header(GetParam()) + nameOps("merge", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> selection; + str += selection >> std::vector<Block>({loop, merge}); + str += merge; + str += "OpFunctionEnd\n"; + + 
CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("Block .\\[%merge\\] is already a merge block " + "for another header\n" + " %Main = OpFunction %void None %9\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block selection("selection", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) selection.SetBody(" OpSelectionMerge %merge None\n"); + + // cannot share the same merge + if (is_shader) loop.SetBody(" OpLoopMerge %merge %loop None\n"); + + std::string str = + header(GetParam()) + nameOps("merge", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> selection; + str += selection >> std::vector<Block>({merge, loop}); + str += loop >> std::vector<Block>({loop, merge}); + str += merge; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("Block .\\[%merge\\] is already a merge block " + "for another header\n" + " %Main = OpFunction %void None %9\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchTargetFirstBlockBadSinceEntryBlock) { + Block entry("entry"); + Block bad("bad"); + Block end("end", SpvOpReturn); + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> bad; + str += bad >> entry; // Cannot target entry block + str += end; + str += "OpFunctionEnd\n"; + + 
CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("First block .\\[%entry\\] of function " + ".\\[%Main\\] is targeted by block .\\[%bad\\]\n" + " %Main = OpFunction %void None %10\n")); +} + +TEST_P(ValidateCFG, BranchTargetFirstBlockBadSinceValue) { + Block entry("entry"); + entry.SetBody("%undef = OpUndef %voidt\n"); + Block bad("bad"); + Block end("end", SpvOpReturn); + Block badvalue("undef"); // This references the OpUndef. + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> bad; + str += + bad >> badvalue; // Check branch to a function value (it's not a block!) + str += end; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("Block\\(s\\) \\{11\\[%11\\]\\} are referenced but not " + "defined in function .\\[%Main\\]\n %Main = OpFunction " + "%void None %10\n")) + << str; +} + +TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { + Block entry("entry"); + Block bad("bad", SpvOpBranchConditional); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + bad.SetBody(" OpLoopMerge %entry %exit None\n"); + + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> bad; + str += bad >> std::vector<Block>({entry, exit}); // cannot target entry block + str += exit; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("First block .\\[%entry\\] of function .\\[%Main\\] " + "is targeted by block .\\[%bad\\]\n" + " %Main = OpFunction %void None 
%10\n")); +} + +TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) { + Block entry("entry"); + Block bad("bad", SpvOpBranchConditional); + Block t("t"); + Block merge("merge"); + Block end("end", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + bad.SetBody("OpLoopMerge %merge %cont None\n"); + + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> bad; + str += bad >> std::vector<Block>({t, entry}); + str += merge >> end; + str += end; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("First block .\\[%entry\\] of function .\\[%Main\\] " + "is targeted by block .\\[%bad\\]\n" + " %Main = OpFunction %void None %10\n")); +} + +TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) { + Block entry("entry"); + Block bad("bad", SpvOpSwitch); + Block block1("block1"); + Block block2("block2"); + Block block3("block3"); + Block def("def"); // default block + Block merge("merge"); + Block end("end", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + bad.SetBody("OpSelectionMerge %merge None\n"); + + std::string str = header(GetParam()) + + nameOps("entry", "bad", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> bad; + str += bad >> std::vector<Block>({def, block1, block2, block3, entry}); + str += def >> merge; + str += block1 >> merge; + str += block2 >> merge; + str += block3 >> merge; + str += merge >> end; + str += end; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("First block .\\[%entry\\] of function .\\[%Main\\] " + "is targeted by block .\\[%bad\\]\n" + " %Main = OpFunction %void None 
%10\n")); +} + +TEST_P(ValidateCFG, BranchToBlockInOtherFunctionBad) { + Block entry("entry"); + Block middle("middle", SpvOpBranchConditional); + Block end("end", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + middle.SetBody("OpSelectionMerge %end None\n"); + + Block entry2("entry2"); + Block middle2("middle2"); + Block end2("end2", SpvOpReturn); + + std::string str = + header(GetParam()) + nameOps("middle2", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> middle; + str += middle >> std::vector<Block>({end, middle2}); + str += end; + str += "OpFunctionEnd\n"; + + str += "%func2 = OpFunction %voidt None %funct\n"; + str += entry2 >> middle2; + str += middle2 >> end2; + str += end2; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("Block\\(s\\) \\{.\\[%middle2\\]\\} are referenced but not " + "defined in function .\\[%Main\\]\n" + " %Main = OpFunction %void None %9\n")); +} + +TEST_P(ValidateCFG, HeaderDoesntDominatesMergeBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block head("head", SpvOpBranchConditional); + Block f("f"); + Block merge("merge", SpvOpReturn); + + head.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + + if (is_shader) head.AppendBody("OpSelectionMerge %merge None\n"); + + std::string str = header(GetParam()) + + nameOps("head", "merge", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> merge; + str += head >> std::vector<Block>({merge, f}); + str += f >> merge; + str += merge; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("The selection construct with the selection header " + ".\\[%head\\] 
does not dominate the merge block " + ".\\[%merge\\]\n %merge = OpLabel\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, HeaderDoesntStrictlyDominateMergeBad) { + // If a merge block is reachable, then it must be strictly dominated by + // its header block. + bool is_shader = GetParam() == SpvCapabilityShader; + Block head("head", SpvOpBranchConditional); + Block exit("exit", SpvOpReturn); + + head.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + + if (is_shader) head.AppendBody("OpSelectionMerge %head None\n"); + + std::string str = header(GetParam()) + + nameOps("head", "exit", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += head >> std::vector<Block>({exit, exit}); + str += exit; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("The selection construct with the selection header " + ".\\[%head\\] does not strictly dominate the merge block " + ".\\[%head\\]\n %head = OpLabel\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << str; + } +} + +TEST_P(ValidateCFG, UnreachableMerge) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block branch("branch", SpvOpBranchConditional); + Block t("t", SpvOpReturn); + Block f("f", SpvOpReturn); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) branch.AppendBody("OpSelectionMerge %merge None\n"); + + std::string str = header(GetParam()) + + nameOps("branch", "merge", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> branch; + str += branch >> std::vector<Block>({t, f}); + str += t; + str += f; + str += merge; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, 
ValidateInstructions()); +} + +TEST_P(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block branch("branch", SpvOpBranchConditional); + Block t("t", SpvOpReturn); + Block f("f", SpvOpReturn); + Block merge("merge", SpvOpUnreachable); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) branch.AppendBody("OpSelectionMerge %merge None\n"); + + std::string str = header(GetParam()) + + nameOps("branch", "merge", std::make_pair("func", "Main")) + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> branch; + str += branch >> std::vector<Block>({t, f}); + str += t; + str += f; + str += merge; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, UnreachableBlock) { + Block entry("entry"); + Block unreachable("unreachable"); + Block exit("exit", SpvOpReturn); + + std::string str = + header(GetParam()) + + nameOps("unreachable", "exit", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> exit; + str += unreachable >> exit; + str += exit; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, UnreachableBranch) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block unreachable("unreachable", SpvOpBranchConditional); + Block unreachablechildt("unreachablechildt"); + Block unreachablechildf("unreachablechildf"); + Block merge("merge"); + Block exit("exit", SpvOpReturn); + + unreachable.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) unreachable.AppendBody("OpSelectionMerge %merge None\n"); + std::string str = + header(GetParam()) + + nameOps("unreachable", "exit", std::make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += 
entry >> exit; + str += + unreachable >> std::vector<Block>({unreachablechildt, unreachablechildf}); + str += unreachablechildt >> merge; + str += unreachablechildf >> merge; + str += merge >> exit; + str += exit; + str += "OpFunctionEnd\n"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, EmptyFunction) { + std::string str = header(GetParam()) + std::string(types_consts()) + + R"(%func = OpFunction %voidt None %funct + %l = OpLabel + OpReturn + OpFunctionEnd)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, SingleBlockLoop) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) loop.AppendBody("OpLoopMerge %exit %loop None\n"); + + std::string str = header(GetParam()) + std::string(types_consts()) + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> std::vector<Block>({loop, exit}); + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, NestedLoops) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop1("loop1"); + Block loop1_cont_break_block("loop1_cont_break_block", + SpvOpBranchConditional); + Block loop2("loop2", SpvOpBranchConditional); + Block loop2_merge("loop2_merge"); + Block loop1_merge("loop1_merge"); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) { + loop1.SetBody("OpLoopMerge %loop1_merge %loop2 None\n"); + loop2.SetBody("OpLoopMerge %loop2_merge %loop2 None\n"); + } + + std::string str = header(GetParam()) + nameOps("loop2", "loop2_merge") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop1; + str += 
loop1 >> loop1_cont_break_block; + str += loop1_cont_break_block >> std::vector<Block>({loop1_merge, loop2}); + str += loop2 >> std::vector<Block>({loop2, loop2_merge}); + str += loop2_merge >> loop1; + str += loop1_merge >> exit; + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, NestedSelection) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + const int N = 256; + std::vector<Block> if_blocks; + std::vector<Block> merge_blocks; + Block inner("inner"); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + + if_blocks.emplace_back("if0", SpvOpBranchConditional); + + if (is_shader) if_blocks[0].SetBody("OpSelectionMerge %if_merge0 None\n"); + merge_blocks.emplace_back("if_merge0", SpvOpReturn); + + for (int i = 1; i < N; i++) { + std::stringstream ss; + ss << i; + if_blocks.emplace_back("if" + ss.str(), SpvOpBranchConditional); + if (is_shader) + if_blocks[i].SetBody("OpSelectionMerge %if_merge" + ss.str() + " None\n"); + merge_blocks.emplace_back("if_merge" + ss.str(), SpvOpBranch); + } + std::string str = header(GetParam()) + std::string(types_consts()) + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> if_blocks[0]; + for (int i = 0; i < N - 1; i++) { + str += + if_blocks[i] >> std::vector<Block>({if_blocks[i + 1], merge_blocks[i]}); + } + str += if_blocks.back() >> std::vector<Block>({inner, merge_blocks.back()}); + str += inner >> merge_blocks.back(); + for (int i = N - 1; i > 0; i--) { + str += merge_blocks[i] >> merge_blocks[i - 1]; + } + str += merge_blocks[0]; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateCFG, BackEdgeBlockDoesntPostDominateContinueTargetBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop1("loop1", SpvOpBranchConditional); + Block loop2("loop2", SpvOpBranchConditional); + Block loop2_merge("loop2_merge", 
SpvOpBranchConditional); + Block be_block("be_block"); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) { + loop1.SetBody("OpLoopMerge %exit %loop2_merge None\n"); + loop2.SetBody("OpLoopMerge %loop2_merge %loop2 None\n"); + } + + std::string str = header(GetParam()) + + nameOps("loop1", "loop2", "be_block", "loop2_merge") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop1; + str += loop1 >> std::vector<Block>({loop2, exit}); + str += loop2 >> std::vector<Block>({loop2, loop2_merge}); + str += loop2_merge >> std::vector<Block>({be_block, exit}); + str += be_block >> loop1; + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("The continue construct with the continue target " + ".\\[%loop2_merge\\] is not post dominated by the " + "back-edge block .\\[%be_block\\]\n" + " %be_block = OpLabel\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchingToNonLoopHeaderBlockBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block split("split", SpvOpBranchConditional); + Block t("t"); + Block f("f"); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) split.SetBody("OpSelectionMerge %exit None\n"); + + std::string str = header(GetParam()) + nameOps("split", "f") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> split; + str += split >> std::vector<Block>({t, f}); + str += t >> exit; + str += f >> split; + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("Back-edges \\(.\\[%f\\] -> .\\[%split\\]\\) 
can only " + "be formed between a block and a loop header.\n" + " %f = OpLabel\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchingToSameNonLoopHeaderBlockBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block split("split", SpvOpBranchConditional); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) split.SetBody("OpSelectionMerge %exit None\n"); + + std::string str = header(GetParam()) + nameOps("split") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> split; + str += split >> std::vector<Block>({split, exit}); + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex( + "Back-edges \\(.\\[%split\\] -> .\\[%split\\]\\) can only be " + "formed between a block and a loop header.\n %split = OpLabel\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, MultipleBackEdgeBlocksToLoopHeaderBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block back0("back0"); + Block back1("back1"); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) loop.SetBody("OpLoopMerge %merge %back0 None\n"); + + std::string str = header(GetParam()) + nameOps("loop", "back0", "back1") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> std::vector<Block>({back0, back1}); + str += back0 >> loop; + str += back1 >> loop; + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex( + "Loop header 
.\\[%loop\\] is targeted by 2 back-edge blocks but " + "the standard requires exactly one\n %loop = OpLabel\n")) + << str; + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, ContinueTargetMustBePostDominatedByBackEdge) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block cheader("cheader", SpvOpBranchConditional); + Block be_block("be_block"); + Block merge("merge", SpvOpReturn); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) loop.SetBody("OpLoopMerge %merge %cheader None\n"); + + std::string str = header(GetParam()) + nameOps("cheader", "be_block") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> std::vector<Block>({cheader, merge}); + str += cheader >> std::vector<Block>({exit, be_block}); + str += exit; // Branches out of a continue construct + str += be_block >> loop; + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("The continue construct with the continue target " + ".\\[%cheader\\] is not post dominated by the " + "back-edge block .\\[%be_block\\]\n" + " %be_block = OpLabel\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block cont("cont", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n"); + + std::string str = header(GetParam()) + nameOps("cont", "loop") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry 
>> loop; + str += loop >> std::vector<Block>({cont, merge}); + str += cont >> std::vector<Block>({loop, merge}); + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("The continue construct with the continue target " + ".\\[%loop\\] is not post dominated by the " + "back-edge block .\\[%cont\\]\n" + " %cont = OpLabel\n")) + << str; + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchOutOfConstructBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block cont("cont", SpvOpBranchConditional); + Block merge("merge"); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n"); + + std::string str = header(GetParam()) + nameOps("cont", "loop") + + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> std::vector<Block>({cont, merge}); + str += cont >> std::vector<Block>({loop, exit}); + str += merge >> exit; + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("The continue construct with the continue target " + ".\\[%loop\\] is not post dominated by the " + "back-edge block .\\[%cont\\]\n" + " %cont = OpLabel\n")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_F(ValidateCFG, OpSwitchToUnreachableBlock) { + Block entry("entry", SpvOpSwitch); + Block case0("case0"); + Block case1("case1"); + Block case2("case2"); + Block def("default", SpvOpUnreachable); + Block phi("phi", SpvOpReturn); + + std::string str = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" 
%id +OpExecutionMode %main LocalSize 1 1 1 +OpSource GLSL 430 +OpName %main "main" +OpDecorate %id BuiltIn GlobalInvocationId +%void = OpTypeVoid +%voidf = OpTypeFunction %void +%u32 = OpTypeInt 32 0 +%f32 = OpTypeFloat 32 +%uvec3 = OpTypeVector %u32 3 +%fvec3 = OpTypeVector %f32 3 +%uvec3ptr = OpTypePointer Input %uvec3 +%id = OpVariable %uvec3ptr Input +%one = OpConstant %u32 1 +%three = OpConstant %u32 3 +%main = OpFunction %void None %voidf +)"; + + entry.SetBody( + "%idval = OpLoad %uvec3 %id\n" + "%x = OpCompositeExtract %u32 %idval 0\n" + "%selector = OpUMod %u32 %x %three\n" + "OpSelectionMerge %phi None\n"); + str += entry >> std::vector<Block>({def, case0, case1, case2}); + str += case1 >> phi; + str += def; + str += phi; + str += case0 >> phi; + str += case2 >> phi; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, LoopWithZeroBackEdgesBad) { + std::string str = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpName %loop "loop" +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%main = OpFunction %voidt None %funct +%loop = OpLabel + OpLoopMerge %exit %exit None + OpBranch %exit +%exit = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("Loop header .\\[%loop\\] is targeted by " + "0 back-edge blocks but the standard requires exactly " + "one\n %loop = OpLabel\n")); +} + +TEST_F(ValidateCFG, LoopWithBackEdgeFromUnreachableContinueConstructGood) { + std::string str = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpName %loop "loop" +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%floatt = OpTypeFloat 32 +%boolt = OpTypeBool +%one = OpConstant %floatt 1 +%two = 
OpConstant %floatt 2 +%main = OpFunction %voidt None %funct +%entry = OpLabel + OpBranch %loop +%loop = OpLabel + OpLoopMerge %exit %cont None + OpBranch %16 +%16 = OpLabel +%cond = OpFOrdLessThan %boolt %one %two + OpBranchConditional %cond %body %exit +%body = OpLabel + OpReturn +%cont = OpLabel ; Reachable only from OpLoopMerge ContinueTarget parameter + OpBranch %loop ; Should be considered a back-edge +%exit = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_P(ValidateCFG, + NestedConstructWithUnreachableMergeBlockBranchingToOuterMergeBlock) { + // Test for https://github.com/KhronosGroup/SPIRV-Tools/issues/297 + // The nested construct has an unreachable merge block (%inner_merge) + // that branches to the outer merge block (%exit). Validation is + // expected to succeed. + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry", SpvOpBranchConditional); + Block inner_head("inner_head", SpvOpBranchConditional); + Block inner_true("inner_true", SpvOpReturn); + Block inner_false("inner_false", SpvOpReturn); + Block inner_merge("inner_merge"); + Block exit("exit", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) { + entry.AppendBody("OpSelectionMerge %exit None\n"); + inner_head.SetBody("OpSelectionMerge %inner_merge None\n"); + } + + std::string str = header(GetParam()) + + nameOps("entry", "inner_merge", "exit") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> std::vector<Block>({inner_head, exit}); + str += inner_head >> std::vector<Block>({inner_true, inner_false}); + str += inner_true; + str += inner_false; + str += inner_merge >> exit; + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_P(ValidateCFG, ContinueTargetCanBeMergeBlockForNestedStructureGood) { + // This example is valid. 
It shows that the validator can't just add + // an edge from the loop head to the continue target. If that edge + // is added, then the "if_merge" block is both the continue target + // for the loop and also the merge block for the nested selection, but + // then it wouldn't be dominated by "if_head", the header block for the + // nested selection. + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop"); + Block if_head("if_head", SpvOpBranchConditional); + Block if_true("if_true"); + Block if_merge("if_merge", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) { + loop.SetBody("OpLoopMerge %merge %if_merge None\n"); + if_head.SetBody("OpSelectionMerge %if_merge None\n"); + } + + std::string str = + header(GetParam()) + + nameOps("entry", "loop", "if_head", "if_true", "if_merge", "merge") + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> if_head; + str += if_head >> std::vector<Block>({if_true, if_merge}); + str += if_true >> if_merge; + str += if_merge >> std::vector<Block>({loop, merge}); + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_P(ValidateCFG, SingleLatchBlockMultipleBranchesToLoopHeader) { + // This test case ensures we allow both branches of a loop latch block + // to go back to the loop header. It still counts as a single back edge. 
+ bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block latch("latch", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) { + loop.SetBody("OpLoopMerge %merge %latch None\n"); + } + + std::string str = + header(GetParam()) + nameOps("entry", "loop", "latch", "merge") + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> std::vector({latch, merge}); + str += latch >> std::vector({loop, loop}); // This is the key + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) + << str << getDiagnosticString(); +} + +TEST_P(ValidateCFG, SingleLatchBlockHeaderContinueTargetIsItselfGood) { + // This test case ensures we don't count a Continue Target from a loop + // header to itself as a self-loop when computing back edges. + // Also, it detects that there is an edge from %latch to the pseudo-exit + // node, rather than from %loop. In particular, it detects that we + // have used the *reverse* textual order of blocks when computing + // predecessor traversal roots. 
+ bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop"); + Block latch("latch"); + Block merge("merge", SpvOpReturn); + + entry.SetBody("%cond = OpSLessThan %boolt %one %two\n"); + if (is_shader) { + loop.SetBody("OpLoopMerge %merge %loop None\n"); + } + + std::string str = + header(GetParam()) + nameOps("entry", "loop", "latch", "merge") + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> latch; + str += latch >> loop; + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) + << str << getDiagnosticString(); +} + +// Unit test to check the case where a basic block is the entry block of 2 +// different constructs. In this case, the basic block is the entry block of a +// continue construct as well as a selection construct. See issue# 517 for more +// details. +TEST_F(ValidateCFG, BasicBlockIsEntryBlockOfTwoConstructsGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %void = OpTypeVoid + %bool = OpTypeBool + %int = OpTypeInt 32 1 + %void_func = OpTypeFunction %void + %int_0 = OpConstant %int 0 + %testfun = OpFunction %void None %void_func + %label_1 = OpLabel + OpBranch %start + %start = OpLabel + %cond = OpSLessThan %bool %int_0 %int_0 + ; + ; Note: In this case, the "target" block is both the entry block of + ; the continue construct of the loop as well as the entry block of + ; the selection construct. 
+ ; + OpLoopMerge %loop_merge %target None + OpBranchConditional %cond %target %loop_merge + %loop_merge = OpLabel + OpReturn + %target = OpLabel + OpSelectionMerge %selection_merge None + OpBranchConditional %cond %do_stuff %do_other_stuff + %do_other_stuff = OpLabel + OpBranch %selection_merge + %selection_merge = OpLabel + OpBranch %start + %do_stuff = OpLabel + OpBranch %selection_merge + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, OpReturnInNonVoidFunc) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %int = OpTypeInt 32 1 + %int_func = OpTypeFunction %int + %testfun = OpFunction %int None %int_func + %label_1 = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpReturn can only be called from a function with void return type.\n" + " OpReturn")); +} + +TEST_F(ValidateCFG, StructuredCFGBranchIntoSelectionBody) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%entry = OpLabel +OpSelectionMerge %merge None +OpBranchConditional %true %then %merge +%merge = OpLabel +OpBranch %then +%then = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("branches to the selection construct, but not to the " + "selection header 6\n %7 = OpLabel")); +} + +TEST_F(ValidateCFG, SwitchDefaultOnly) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = 
OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpFunction %1 None %4 +%6 = OpLabel +OpSelectionMerge %7 None +OpSwitch %3 %7 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, SwitchSingleCase) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpFunction %1 None %4 +%6 = OpLabel +OpSelectionMerge %7 None +OpSwitch %3 %7 0 %8 +%8 = OpLabel +OpBranch %7 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, MultipleFallThroughBlocks) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranchConditional %6 %11 %12 +%11 = OpLabel +OpBranch %9 +%12 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Case construct that targets 10[%10] has branches to multiple other " + "case construct targets 12[%12] and 11[%11]\n %10 = OpLabel")); +} + +TEST_F(ValidateCFG, MultipleFallThroughToDefault) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranch %9 +%11 = OpLabel +OpBranch %10 +%12 = OpLabel +OpBranch %10 +%9 = 
OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Multiple case constructs have branches to the case construct " + "that targets 10[%10]\n %10 = OpLabel")); +} + +TEST_F(ValidateCFG, MultipleFallThroughToNonDefault) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranch %12 +%11 = OpLabel +OpBranch %12 +%12 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Multiple case constructs have branches to the case construct " + "that targets 12[%12]\n %12 = OpLabel")); +} + +TEST_F(ValidateCFG, DuplicateTargetWithFallThrough) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %10 1 %11 +%10 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, WrongOperandList) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel 
+OpBranch %9 +%12 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Case construct that targets 12[%12] has branches to the case " + "construct that targets 11[%11], but does not immediately " + "precede it in the OpSwitch's target list\n" + " OpSwitch %uint_0 %10 0 %11 1 %12")); +} + +TEST_F(ValidateCFG, WrongOperandListThroughDefault) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 +%10 = OpLabel +OpBranch %11 +%12 = OpLabel +OpBranch %10 +%11 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Case construct that targets 12[%12] has branches to the case " + "construct that targets 11[%11], but does not immediately " + "precede it in the OpSwitch's target list\n" + " OpSwitch %uint_0 %10 0 %11 1 %12")); +} + +TEST_F(ValidateCFG, WrongOperandListNotLast) { + std::string text = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpFunction %1 None %4 +%8 = OpLabel +OpSelectionMerge %9 None +OpSwitch %3 %10 0 %11 1 %12 2 %13 +%10 = OpLabel +OpBranch %9 +%12 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %9 +%13 = OpLabel +OpBranch %9 +%9 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + 
getDiagnosticString(), + HasSubstr("Case construct that targets 12[%12] has branches to the case " + "construct that targets 11[%11], but does not immediately " + "precede it in the OpSwitch's target list\n" + " OpSwitch %uint_0 %10 0 %11 1 %12 2 %13")); +} + +TEST_F(ValidateCFG, GoodUnreachableSwitch) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +%3 = OpTypeVoid +%4 = OpTypeFunction %3 +%5 = OpTypeBool +%6 = OpConstantTrue %5 +%7 = OpTypeInt 32 1 +%9 = OpConstant %7 0 +%2 = OpFunction %3 None %4 +%10 = OpLabel +OpSelectionMerge %11 None +OpBranchConditional %6 %12 %13 +%12 = OpLabel +OpReturn +%13 = OpLabel +OpReturn +%11 = OpLabel +OpSelectionMerge %14 None +OpSwitch %9 %14 0 %15 +%15 = OpLabel +OpBranch %14 +%14 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_THAT(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, InvalidCaseExit) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeFunction %2 +%5 = OpConstant %3 0 +%1 = OpFunction %2 None %4 +%6 = OpLabel +OpSelectionMerge %7 None +OpSwitch %5 %7 0 %8 1 %9 +%8 = OpLabel +OpBranch %10 +%9 = OpLabel +OpBranch %10 +%10 = OpLabel +OpReturn +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Case construct that targets 8[%8] has invalid branch " + "to block 10[%10] (not another case construct, " + "corresponding merge, outer loop merge or outer loop " + "continue)")); +} + +TEST_F(ValidateCFG, GoodCaseExitsToOuterConstructs) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft 
+%void = OpTypeVoid +%bool = OpTypeBool +%true = OpConstantTrue %bool +%int = OpTypeInt 32 0 +%int0 = OpConstant %int 0 +%func_ty = OpTypeFunction %void +%func = OpFunction %void None %func_ty +%1 = OpLabel +OpBranch %2 +%2 = OpLabel +OpLoopMerge %7 %6 None +OpBranch %3 +%3 = OpLabel +OpSelectionMerge %5 None +OpSwitch %int0 %5 0 %4 +%4 = OpLabel +OpBranchConditional %true %6 %7 +%5 = OpLabel +OpBranchConditional %true %6 %7 +%6 = OpLabel +OpBranch %2 +%7 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, GoodUnreachableSelection) { + const std::string text = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%8 = OpTypeFunction %void +%bool = OpTypeBool +%false = OpConstantFalse %bool +%main = OpFunction %void None %8 +%15 = OpLabel +OpBranch %16 +%16 = OpLabel +OpLoopMerge %17 %18 None +OpBranch %19 +%19 = OpLabel +OpBranchConditional %false %21 %17 +%21 = OpLabel +OpSelectionMerge %22 None +OpBranchConditional %false %23 %22 +%23 = OpLabel +OpBranch %24 +%24 = OpLabel +OpLoopMerge %25 %26 None +OpBranch %27 +%27 = OpLabel +OpReturn +%26 = OpLabel +OpBranchConditional %false %24 %25 +%25 = OpLabel +OpSelectionMerge %28 None +OpBranchConditional %false %18 %28 +%28 = OpLabel +OpBranch %22 +%22 = OpLabel +OpBranch %18 +%18 = OpLabel +OpBranch %16 +%17 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, ShaderWithPhiPtr) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %void = OpTypeVoid + %5 = OpTypeFunction %void + %1 = OpFunction %void 
None %5 + %6 = OpLabel + %7 = OpVariable %_ptr_Function_bool Function + %8 = OpVariable %_ptr_Function_bool Function + %9 = OpUndef %bool + OpSelectionMerge %10 None + OpBranchConditional %9 %11 %10 + %11 = OpLabel + OpBranch %10 + %10 = OpLabel + %12 = OpPhi %_ptr_Function_bool %7 %6 %8 %11 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Using pointers with OpPhi requires capability " + "VariablePointers or VariablePointersStorageBuffer")); +} + +TEST_F(ValidateCFG, VarPtrShaderWithPhiPtr) { + const std::string text = R"( + OpCapability Shader + OpCapability VariablePointers + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %void = OpTypeVoid + %5 = OpTypeFunction %void + %1 = OpFunction %void None %5 + %6 = OpLabel + %7 = OpVariable %_ptr_Function_bool Function + %8 = OpVariable %_ptr_Function_bool Function + %9 = OpUndef %bool + OpSelectionMerge %10 None + OpBranchConditional %9 %11 %10 + %11 = OpLabel + OpBranch %10 + %10 = OpLabel + %12 = OpPhi %_ptr_Function_bool %7 %6 %8 %11 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, VarPtrStgBufShaderWithPhiStgBufPtr) { + const std::string text = R"( + OpCapability Shader + OpCapability VariablePointersStorageBuffer + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + %bool = OpTypeBool + %float = OpTypeFloat 32 +%_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float + %7 = OpVariable %_ptr_StorageBuffer_float StorageBuffer + %8 = OpVariable %_ptr_StorageBuffer_float StorageBuffer + %void = 
OpTypeVoid + %5 = OpTypeFunction %void + %1 = OpFunction %void None %5 + %6 = OpLabel + %9 = OpUndef %bool + OpSelectionMerge %10 None + OpBranchConditional %9 %11 %10 + %11 = OpLabel + OpBranch %10 + %10 = OpLabel + %12 = OpPhi %_ptr_StorageBuffer_float %7 %6 %8 %11 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, KernelWithPhiPtr) { + const std::string text = R"( + OpCapability Kernel + OpCapability Addresses + OpMemoryModel Physical32 OpenCL + OpEntryPoint Kernel %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %void = OpTypeVoid + %5 = OpTypeFunction %void + %1 = OpFunction %void None %5 + %6 = OpLabel + %7 = OpVariable %_ptr_Function_bool Function + %8 = OpVariable %_ptr_Function_bool Function + %9 = OpUndef %bool + OpSelectionMerge %10 None + OpBranchConditional %9 %11 %10 + %11 = OpLabel + OpBranch %10 + %10 = OpLabel + %12 = OpPhi %_ptr_Function_bool %7 %6 %8 %11 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +/// TODO(umar): Nested CFG constructs + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_composites_test.cpp b/test/val/val_composites_test.cpp new file mode 100644 index 000000000..bf7f15d51 --- /dev/null +++ b/test/val/val_composites_test.cpp @@ -0,0 +1,1496 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateComposites = spvtest::ValidateBase; + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& execution_model = "Fragment") { + std::ostringstream ss; + ss << R"( +OpCapability Shader +OpCapability Float64 +)"; + + ss << capabilities_and_extensions; + ss << "OpMemoryModel Logical GLSL450\n"; + ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + if (execution_model == "Fragment") { + ss << "OpExecutionMode %main OriginUpperLeft\n"; + } + + ss << R"( +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%f64 = OpTypeFloat 64 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%f32vec2 = OpTypeVector %f32 2 +%f32vec3 = OpTypeVector %f32 3 +%f32vec4 = OpTypeVector %f32 4 +%f64vec2 = OpTypeVector %f64 2 +%u32vec2 = OpTypeVector %u32 2 +%u32vec4 = OpTypeVector %u32 4 +%f64mat22 = OpTypeMatrix %f64vec2 2 +%f32mat22 = OpTypeMatrix %f32vec2 2 +%f32mat23 = OpTypeMatrix %f32vec2 3 +%f32mat32 = OpTypeMatrix %f32vec3 2 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +%f32_3 = OpConstant %f32 3 +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 + +%f32mat22_1212 = OpConstantComposite %f32mat22 
%f32vec2_12 %f32vec2_12 +%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12 + +%f32vec2arr3 = OpTypeArray %f32vec2 %u32_3 +%f32vec2rarr = OpTypeRuntimeArray %f32vec2 + +%f32u32struct = OpTypeStruct %f32 %u32 +%big_struct = OpTypeStruct %f32 %f32vec4 %f32mat23 %f32vec2arr3 %f32vec2rarr %f32u32struct + +%ptr_big_struct = OpTypePointer Uniform %big_struct +%var_big_struct = OpVariable %ptr_big_struct Uniform + +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + + ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +// Returns header for legacy tests taken from val_id_test.cpp. +std::string GetHeaderForTestsFromValId() { + return R"( +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpCapability Pipes +OpCapability LiteralSampler +OpCapability DeviceEnqueue +OpCapability Vector16 +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float64 +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%int = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v3float = OpTypeVector %float 3 +%mat4x3 = OpTypeMatrix %v3float 4 +%_ptr_Private_mat4x3 = OpTypePointer Private %mat4x3 +%_ptr_Private_float = OpTypePointer Private %float +%my_matrix = OpVariable %_ptr_Private_mat4x3 Private +%my_float_var = OpVariable %_ptr_Private_float Private +%_ptr_Function_float = OpTypePointer Function %float +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_3 = OpConstant %int 3 +%int_5 = OpConstant %int 5 + +; Making the following nested structures. 
+; +; struct S { +; bool b; +; vec4 v[5]; +; int i; +; mat4x3 m[5]; +; } +; uniform blockName { +; S s; +; bool cond; +; RunTimeArray arr; +; } + +%f32arr = OpTypeRuntimeArray %float +%v4float = OpTypeVector %float 4 +%array5_mat4x3 = OpTypeArray %mat4x3 %int_5 +%array5_vec4 = OpTypeArray %v4float %int_5 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Function_vec4 = OpTypePointer Function %v4float +%_ptr_Uniform_vec4 = OpTypePointer Uniform %v4float +%struct_s = OpTypeStruct %int %array5_vec4 %int %array5_mat4x3 +%struct_blockName = OpTypeStruct %struct_s %int %f32arr +%_ptr_Uniform_blockName = OpTypePointer Uniform %struct_blockName +%_ptr_Uniform_struct_s = OpTypePointer Uniform %struct_s +%_ptr_Uniform_array5_mat4x3 = OpTypePointer Uniform %array5_mat4x3 +%_ptr_Uniform_mat4x3 = OpTypePointer Uniform %mat4x3 +%_ptr_Uniform_v3float = OpTypePointer Uniform %v3float +%blockName_var = OpVariable %_ptr_Uniform_blockName Uniform +%spec_int = OpSpecConstant %int 2 +%func = OpFunction %void None %void_f +%my_label = OpLabel +)"; +} + +TEST_F(ValidateComposites, VectorExtractDynamicSuccess) { + const std::string body = R"( +%val1 = OpVectorExtractDynamic %f32 %f32vec4_0123 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, VectorExtractDynamicWrongResultType) { + const std::string body = R"( +%val1 = OpVectorExtractDynamic %f32vec4 %f32vec4_0123 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a scalar type")); +} + +TEST_F(ValidateComposites, VectorExtractDynamicNotVector) { + const std::string body = R"( +%val1 = OpVectorExtractDynamic %f32 %f32mat22_1212 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Vector type to be OpTypeVector")); +} + +TEST_F(ValidateComposites, VectorExtractDynamicWrongVectorComponent) { + const std::string body = R"( +%val1 = OpVectorExtractDynamic %f32 %u32vec4_0123 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Vector component type to be equal to Result Type")); +} + +TEST_F(ValidateComposites, VectorExtractDynamicWrongIndexType) { + const std::string body = R"( +%val1 = OpVectorExtractDynamic %f32 %f32vec4_0123 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Index to be int scalar")); +} + +TEST_F(ValidateComposites, VectorInsertDynamicSuccess) { + const std::string body = R"( +%val1 = OpVectorInsertDynamic %f32vec4 %f32vec4_0123 %f32_1 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, VectorInsertDynamicWrongResultType) { + const std::string body = R"( +%val1 = OpVectorInsertDynamic %f32 %f32vec4_0123 %f32_1 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeVector")); +} + +TEST_F(ValidateComposites, VectorInsertDynamicNotVector) { + const std::string body = R"( +%val1 = OpVectorInsertDynamic %f32vec4 %f32mat22_1212 %f32_1 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Vector type to be equal to Result Type")); +} + +TEST_F(ValidateComposites, VectorInsertDynamicWrongComponentType) { + 
const std::string body = R"( +%val1 = OpVectorInsertDynamic %f32vec4 %f32vec4_0123 %u32_1 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Component type to be equal to Result Type " + "component type")); +} + +TEST_F(ValidateComposites, VectorInsertDynamicWrongIndexType) { + const std::string body = R"( +%val1 = OpVectorInsertDynamic %f32vec4 %f32vec4_0123 %f32_1 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Index to be int scalar")); +} + +TEST_F(ValidateComposites, CompositeConstructNotComposite) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a composite type")); +} + +TEST_F(ValidateComposites, CompositeConstructVectorSuccess) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec4 %f32vec2_12 %f32vec2_12 +%val2 = OpCompositeConstruct %f32vec4 %f32vec2_12 %f32_0 %f32_0 +%val3 = OpCompositeConstruct %f32vec4 %f32_0 %f32_0 %f32vec2_12 +%val4 = OpCompositeConstruct %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CompositeConstructVectorOnlyOneConstituent) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec4 %f32vec4_0123 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of constituents to be at least 2")); +} + +TEST_F(ValidateComposites, 
CompositeConstructVectorWrongConsituent1) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec4 %f32 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 5[%float] cannot be a " + "type")); +} + +TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituent2) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec4 %f32vec2_12 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Constituents to be scalars or vectors of the same " + "type as Result Type components")); +} + +TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituent3) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec4 %f32vec2_12 %u32_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Constituents to be scalars or vectors of the same " + "type as Result Type components")); +} + +TEST_F(ValidateComposites, CompositeConstructVectorWrongComponentNumber1) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec4 %f32vec2_12 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected total number of given components to be equal to the " + "size of Result Type vector")); +} + +TEST_F(ValidateComposites, CompositeConstructVectorWrongComponentNumber2) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec4 %f32vec2_12 %f32vec2_12 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected total number of given components to be equal to the " + "size of Result Type vector")); +} + +TEST_F(ValidateComposites, CompositeConstructMatrixSuccess) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32mat22 %f32vec2_12 %f32vec2_12 +%val2 = OpCompositeConstruct %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituentNumber1) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32mat22 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected total number of Constituents to be equal to the " + "number of columns of Result Type matrix")); +} + +TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituentNumber2) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32mat22 %f32vec2_12 %f32vec2_12 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected total number of Constituents to be equal to the " + "number of columns of Result Type matrix")); +} + +TEST_F(ValidateComposites, CompositeConstructVectorWrongConsituent) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32mat22 %f32vec2_12 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Constituent type to be equal to the column type " + "Result Type matrix")); +} + +TEST_F(ValidateComposites, CompositeConstructArraySuccess) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec2arr3 
%f32vec2_12 %f32vec2_12 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CompositeConstructArrayWrongConsituentNumber1) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec2arr3 %f32vec2_12 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected total number of Constituents to be equal to the " + "number of elements of Result Type array")); +} + +TEST_F(ValidateComposites, CompositeConstructArrayWrongConsituentNumber2) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec2arr3 %f32vec2_12 %f32vec2_12 %f32vec2_12 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected total number of Constituents to be equal to the " + "number of elements of Result Type array")); +} + +TEST_F(ValidateComposites, CompositeConstructArrayWrongConsituent) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32vec2arr3 %f32vec2_12 %u32vec2_01 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Constituent type to be equal to the column type " + "Result Type array")); +} + +TEST_F(ValidateComposites, CompositeConstructStructSuccess) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32u32struct %f32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CompositeConstructStructWrongConstituentNumber1) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32u32struct %f32_0 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected total number of Constituents to be equal to the " + "number of members of Result Type struct")); +} + +TEST_F(ValidateComposites, CompositeConstructStructWrongConstituentNumber2) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32u32struct %f32_0 %u32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected total number of Constituents to be equal to the " + "number of members of Result Type struct")); +} + +TEST_F(ValidateComposites, CompositeConstructStructWrongConstituent) { + const std::string body = R"( +%val1 = OpCompositeConstruct %f32u32struct %f32_0 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Constituent type to be equal to the " + "corresponding member type of Result Type struct")); +} + +TEST_F(ValidateComposites, CopyObjectSuccess) { + const std::string body = R"( +%val1 = OpCopyObject %f32 %f32_0 +%val2 = OpCopyObject %f32vec4 %f32vec4_0123 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CopyObjectResultTypeNotType) { + const std::string body = R"( +%val1 = OpCopyObject %f32_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID 19[%float_0] is not a type id")); +} + +TEST_F(ValidateComposites, CopyObjectWrongOperandType) { + const std::string body = R"( +%val1 = OpCopyObject %f32 %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type and Operand type to be the same")); +} + +TEST_F(ValidateComposites, TransposeSuccess) { + const std::string body = R"( +%val1 = OpTranspose %f32mat32 %f32mat23_121212 +%val2 = OpTranspose %f32mat22 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, TransposeResultTypeNotMatrix) { + const std::string body = R"( +%val1 = OpTranspose %f32vec4 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a matrix type")); +} + +TEST_F(ValidateComposites, TransposeDifferentComponentTypes) { + const std::string body = R"( +%val1 = OpTranspose %f64mat22 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected component types of Matrix and Result Type to be " + "identical")); +} + +TEST_F(ValidateComposites, TransposeIncompatibleDimensions1) { + const std::string body = R"( +%val1 = OpTranspose %f32mat23 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of columns and the column size " + "of Matrix to be the reverse of those of Result Type")); +} + +TEST_F(ValidateComposites, TransposeIncompatibleDimensions2) { + const std::string body = R"( +%val1 = OpTranspose %f32mat32 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of columns 
and the column size " + "of Matrix to be the reverse of those of Result Type")); +} + +TEST_F(ValidateComposites, TransposeIncompatibleDimensions3) { + const std::string body = R"( +%val1 = OpTranspose %f32mat23 %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected number of columns and the column size " + "of Matrix to be the reverse of those of Result Type")); +} + +TEST_F(ValidateComposites, CompositeExtractSuccess) { + const std::string body = R"( +%val1 = OpCompositeExtract %f32 %f32vec4_0123 1 +%val2 = OpCompositeExtract %u32 %u32vec4_0123 0 +%val3 = OpCompositeExtract %f32 %f32mat22_1212 0 1 +%val4 = OpCompositeExtract %f32vec2 %f32mat22_1212 0 +%array = OpCompositeConstruct %f32vec2arr3 %f32vec2_12 %f32vec2_12 %f32vec2_12 +%val5 = OpCompositeExtract %f32vec2 %array 2 +%val6 = OpCompositeExtract %f32 %array 2 1 +%struct = OpLoad %big_struct %var_big_struct +%val7 = OpCompositeExtract %f32 %struct 0 +%val8 = OpCompositeExtract %f32vec4 %struct 1 +%val9 = OpCompositeExtract %f32 %struct 1 2 +%val10 = OpCompositeExtract %f32mat23 %struct 2 +%val11 = OpCompositeExtract %f32vec2 %struct 2 2 +%val12 = OpCompositeExtract %f32 %struct 2 2 1 +%val13 = OpCompositeExtract %f32vec2 %struct 3 2 +%val14 = OpCompositeExtract %f32 %struct 3 2 1 +%val15 = OpCompositeExtract %f32vec2 %struct 4 100 +%val16 = OpCompositeExtract %f32 %struct 4 1000 1 +%val17 = OpCompositeExtract %f32 %struct 5 0 +%val18 = OpCompositeExtract %u32 %struct 5 1 +%val19 = OpCompositeExtract %big_struct %struct +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CompositeExtractNotObject) { + const std::string body = R"( +%val1 = OpCompositeExtract %f32 %f32vec4 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_ID, 
ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 11[%v4float] cannot " + "be a type")); +} + +TEST_F(ValidateComposites, CompositeExtractNotComposite) { + const std::string body = R"( +%val1 = OpCompositeExtract %f32 %f32_1 0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); +} + +TEST_F(ValidateComposites, CompositeExtractVectorOutOfBounds) { + const std::string body = R"( +%val1 = OpCompositeExtract %f32 %f32vec4_0123 4 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vector access is out of bounds, " + "vector size is 4, but access index is 4")); +} + +TEST_F(ValidateComposites, CompositeExtractMatrixOutOfCols) { + const std::string body = R"( +%val1 = OpCompositeExtract %f32 %f32mat23_121212 3 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Matrix access is out of bounds, " + "matrix has 3 columns, but access index is 3")); +} + +TEST_F(ValidateComposites, CompositeExtractMatrixOutOfRows) { + const std::string body = R"( +%val1 = OpCompositeExtract %f32 %f32mat23_121212 2 5 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vector access is out of bounds, " + "vector size is 2, but access index is 5")); +} + +TEST_F(ValidateComposites, CompositeExtractArrayOutOfBounds) { + const std::string body = R"( +%array = OpCompositeConstruct %f32vec2arr3 %f32vec2_12 %f32vec2_12 %f32vec2_12 +%val1 = OpCompositeExtract %f32vec2 %array 3 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Array access is out of bounds, " + "array size is 3, but access index is 3")); +} + +TEST_F(ValidateComposites, CompositeExtractStructOutOfBounds) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeExtract %f32 %struct 6 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Index is out of bounds, can not find index 6 in the " + "structure '37'. This structure has 6 members. " + "Largest valid index is 5.")); +} + +TEST_F(ValidateComposites, CompositeExtractNestedVectorOutOfBounds) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeExtract %f32 %struct 3 1 5 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vector access is out of bounds, " + "vector size is 2, but access index is 5")); +} + +TEST_F(ValidateComposites, CompositeExtractTooManyIndices) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeExtract %f32 %struct 3 1 1 2 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Reached non-composite type while " + "indexes still remain to be traversed.")); +} + +TEST_F(ValidateComposites, CompositeExtractWrongType1) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeExtract %f32vec2 %struct 3 1 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Result type (OpTypeVector) does not match the type that results 
" + "from indexing into the composite (OpTypeFloat).")); +} + +TEST_F(ValidateComposites, CompositeExtractWrongType2) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeExtract %f32 %struct 3 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result type (OpTypeFloat) does not match the type " + "that results from indexing into the composite " + "(OpTypeVector).")); +} + +TEST_F(ValidateComposites, CompositeExtractWrongType3) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeExtract %f32 %struct 2 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result type (OpTypeFloat) does not match the type " + "that results from indexing into the composite " + "(OpTypeVector).")); +} + +TEST_F(ValidateComposites, CompositeExtractWrongType4) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeExtract %f32 %struct 4 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result type (OpTypeFloat) does not match the type " + "that results from indexing into the composite " + "(OpTypeVector).")); +} + +TEST_F(ValidateComposites, CompositeExtractWrongType5) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeExtract %f32 %struct 5 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Result type (OpTypeFloat) does not match the " + "type that results from indexing into the composite (OpTypeInt).")); +} + +TEST_F(ValidateComposites, 
CompositeInsertSuccess) { + const std::string body = R"( +%val1 = OpCompositeInsert %f32vec4 %f32_1 %f32vec4_0123 0 +%val2 = OpCompositeInsert %u32vec4 %u32_1 %u32vec4_0123 0 +%val3 = OpCompositeInsert %f32mat22 %f32_2 %f32mat22_1212 0 1 +%val4 = OpCompositeInsert %f32mat22 %f32vec2_01 %f32mat22_1212 0 +%array = OpCompositeConstruct %f32vec2arr3 %f32vec2_12 %f32vec2_12 %f32vec2_12 +%val5 = OpCompositeInsert %f32vec2arr3 %f32vec2_01 %array 2 +%val6 = OpCompositeInsert %f32vec2arr3 %f32_3 %array 2 1 +%struct = OpLoad %big_struct %var_big_struct +%val7 = OpCompositeInsert %big_struct %f32_3 %struct 0 +%val8 = OpCompositeInsert %big_struct %f32vec4_0123 %struct 1 +%val9 = OpCompositeInsert %big_struct %f32_3 %struct 1 2 +%val10 = OpCompositeInsert %big_struct %f32mat23_121212 %struct 2 +%val11 = OpCompositeInsert %big_struct %f32vec2_01 %struct 2 2 +%val12 = OpCompositeInsert %big_struct %f32_3 %struct 2 2 1 +%val13 = OpCompositeInsert %big_struct %f32vec2_01 %struct 3 2 +%val14 = OpCompositeInsert %big_struct %f32_3 %struct 3 2 1 +%val15 = OpCompositeInsert %big_struct %f32vec2_01 %struct 4 100 +%val16 = OpCompositeInsert %big_struct %f32_3 %struct 4 1000 1 +%val17 = OpCompositeInsert %big_struct %f32_3 %struct 5 0 +%val18 = OpCompositeInsert %big_struct %u32_3 %struct 5 1 +%val19 = OpCompositeInsert %big_struct %struct %struct +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, CompositeInsertResultTypeDifferentFromComposite) { + const std::string body = R"( +%val1 = OpCompositeInsert %f32 %f32_1 %f32vec4_0123 0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Result Type must be the same as Composite type in " + "OpCompositeInsert yielding Result Id 5.")); +} + +TEST_F(ValidateComposites, CompositeInsertNotComposite) { + const std::string body = R"( +%val1 = 
OpCompositeInsert %f32 %f32_1 %f32_0 0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); +} + +TEST_F(ValidateComposites, CompositeInsertVectorOutOfBounds) { + const std::string body = R"( +%val1 = OpCompositeInsert %f32vec4 %f32_1 %f32vec4_0123 4 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vector access is out of bounds, " + "vector size is 4, but access index is 4")); +} + +TEST_F(ValidateComposites, CompositeInsertMatrixOutOfCols) { + const std::string body = R"( +%val1 = OpCompositeInsert %f32mat23 %f32_1 %f32mat23_121212 3 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Matrix access is out of bounds, " + "matrix has 3 columns, but access index is 3")); +} + +TEST_F(ValidateComposites, CompositeInsertMatrixOutOfRows) { + const std::string body = R"( +%val1 = OpCompositeInsert %f32mat23 %f32_1 %f32mat23_121212 2 5 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vector access is out of bounds, " + "vector size is 2, but access index is 5")); +} + +TEST_F(ValidateComposites, CompositeInsertArrayOutOfBounds) { + const std::string body = R"( +%array = OpCompositeConstruct %f32vec2arr3 %f32vec2_12 %f32vec2_12 %f32vec2_12 +%val1 = OpCompositeInsert %f32vec2arr3 %f32vec2_01 %array 3 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Array access is out of bounds, array " + "size is 3, but access index is 
3")); +} + +TEST_F(ValidateComposites, CompositeInsertStructOutOfBounds) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeInsert %big_struct %f32_1 %struct 6 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Index is out of bounds, can not find index 6 in the " + "structure '37'. This structure has 6 members. " + "Largest valid index is 5.")); +} + +TEST_F(ValidateComposites, CompositeInsertNestedVectorOutOfBounds) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeInsert %big_struct %f32_1 %struct 3 1 5 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Vector access is out of bounds, " + "vector size is 2, but access index is 5")); +} + +TEST_F(ValidateComposites, CompositeInsertTooManyIndices) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeInsert %big_struct %f32_1 %struct 3 1 1 2 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); +} + +TEST_F(ValidateComposites, CompositeInsertWrongType1) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeInsert %big_struct %f32vec2_01 %struct 3 1 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Object type (OpTypeVector) does not match the " + "type that results from indexing into the Composite " + "(OpTypeFloat).")); +} + +TEST_F(ValidateComposites, CompositeInsertWrongType2) { + const 
std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeInsert %big_struct %f32_1 %struct 3 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Object type (OpTypeFloat) does not match the type " + "that results from indexing into the Composite " + "(OpTypeVector).")); +} + +TEST_F(ValidateComposites, CompositeInsertWrongType3) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeInsert %big_struct %f32_1 %struct 2 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Object type (OpTypeFloat) does not match the type " + "that results from indexing into the Composite " + "(OpTypeVector).")); +} + +TEST_F(ValidateComposites, CompositeInsertWrongType4) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeInsert %big_struct %f32_1 %struct 4 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Object type (OpTypeFloat) does not match the type " + "that results from indexing into the Composite " + "(OpTypeVector).")); +} + +TEST_F(ValidateComposites, CompositeInsertWrongType5) { + const std::string body = R"( +%struct = OpLoad %big_struct %var_big_struct +%val1 = OpCompositeInsert %big_struct %f32_1 %struct 5 1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Object type (OpTypeFloat) does not match the type " + "that results from indexing into the Composite " + "(OpTypeInt).")); +} + +// Tests ported from val_id_test.cpp. + +// Valid. 
Tests both CompositeExtract and CompositeInsert with 255 indexes. +TEST_F(ValidateComposites, CompositeExtractInsertLimitsGood) { + int depth = 255; + std::string header = GetHeaderForTestsFromValId(); + header.erase(header.find("%func")); + std::ostringstream spirv; + spirv << header << std::endl; + + // Build nested structures. Struct 'i' contains struct 'i-1' + spirv << "%s_depth_1 = OpTypeStruct %float\n"; + for (int i = 2; i <= depth; ++i) { + spirv << "%s_depth_" << i << " = OpTypeStruct %s_depth_" << i - 1 << "\n"; + } + + // Define Pointer and Variable to use for CompositeExtract/Insert. + spirv << "%_ptr_Uniform_deep_struct = OpTypePointer Uniform %s_depth_" + << depth << "\n"; + spirv << "%deep_var = OpVariable %_ptr_Uniform_deep_struct Uniform\n"; + + // Function Start + spirv << R"( + %func = OpFunction %void None %void_f + %my_label = OpLabel + )"; + + // OpCompositeExtract/Insert with 'n' indexes (n = depth) + spirv << "%deep = OpLoad %s_depth_" << depth << " %deep_var" << std::endl; + spirv << "%entry = OpCompositeExtract %float %deep"; + for (int i = 0; i < depth; ++i) { + spirv << " 0"; + } + spirv << std::endl; + spirv << "%new_composite = OpCompositeInsert %s_depth_" << depth + << " %entry %deep"; + for (int i = 0; i < depth; ++i) { + spirv << " 0"; + } + spirv << std::endl; + + // Function end + spirv << R"( + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: 256 indexes passed to OpCompositeExtract. Limit is 255. 
+TEST_F(ValidateComposites, CompositeExtractArgCountExceededLimitBad) {
+  std::ostringstream text;
+  text << GetHeaderForTestsFromValId() << "\n";
+  text << "%matrix = OpLoad %mat4x3 %my_matrix" << "\n";
+  text << "%entry = OpCompositeExtract %float %matrix";
+  for (int remaining = 256; remaining > 0; --remaining) {  // 256 indexes.
+    text << " 0";
+  }
+  text << R"(
+    OpReturn
+    OpFunctionEnd
+  )";
+  CompileSuccessfully(text.str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The number of indexes in OpCompositeExtract may not "
+                        "exceed 255. Found 256 indexes."));
+}
+
+// Invalid: OpCompositeInsert is given 256 indexes, one over the 255 limit.
+TEST_F(ValidateComposites, CompositeInsertArgCountExceededLimitBad) {
+  std::ostringstream text;
+  text << GetHeaderForTestsFromValId() << "\n";
+  text << "%matrix = OpLoad %mat4x3 %my_matrix" << "\n";
+  text << "%new_composite = OpCompositeInsert %mat4x3 %int_0 %matrix";
+  for (int remaining = 256; remaining > 0; --remaining) {  // 256 indexes.
+    text << " 0";
+  }
+  text << R"(
+    OpReturn
+    OpFunctionEnd
+  )";
+  CompileSuccessfully(text.str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The number of indexes in OpCompositeInsert may not "
+                        "exceed 255. Found 256 indexes."));
+}
+
+// Invalid: OpCompositeInsert's Result Type differs from its Composite type.
+TEST_F(ValidateComposites, CompositeInsertWrongResultTypeBad) {
+  std::ostringstream text;
+  text << GetHeaderForTestsFromValId() << "\n";
+  text << "%matrix = OpLoad %mat4x3 %my_matrix" << "\n";
+  text << "%float_entry = OpCompositeExtract %float %matrix 0 1" << "\n";
+  text << "%new_composite = OpCompositeInsert %float %float_entry %matrix 0 1"
+       << "\n";
+  text << R"(OpReturn
+    OpFunctionEnd)";
+  CompileSuccessfully(text.str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("The Result Type must be the same as Composite type"));
+}
+
+// Valid: OpCompositeExtract with no indexes whose Result Type equals the type
+// of the composite operand.
+TEST_F(ValidateComposites, CompositeExtractNoIndexesGood) {
+  std::ostringstream text;
+  text << GetHeaderForTestsFromValId() << "\n";
+  text << "%matrix = OpLoad %mat4x3 %my_matrix" << "\n";
+  text << "%float_entry = OpCompositeExtract %mat4x3 %matrix" << "\n";
+  text << R"(OpReturn
+    OpFunctionEnd)";
+  CompileSuccessfully(text.str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: No Indexes were passed to OpCompositeExtract, but the Result Type is
+// different from the Base Composite type.
+TEST_F(ValidateComposites, CompositeExtractNoIndexesBad) { + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << std::endl; + spirv << "%matrix = OpLoad %mat4x3 %my_matrix" << std::endl; + spirv << "%float_entry = OpCompositeExtract %float %matrix" << std::endl; + spirv << R"(OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result type (OpTypeFloat) does not match the type " + "that results from indexing into the composite " + "(OpTypeMatrix).")); +} + +// Valid: No Indexes were passed to OpCompositeInsert, and the type of the +// Object argument matches the Composite type. +TEST_F(ValidateComposites, CompositeInsertMissingIndexesGood) { + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << std::endl; + spirv << "%matrix = OpLoad %mat4x3 %my_matrix" << std::endl; + spirv << "%matrix_2 = OpLoad %mat4x3 %my_matrix" << std::endl; + spirv << "%new_composite = OpCompositeInsert %mat4x3 %matrix_2 %matrix"; + spirv << R"( + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: No Indexes were passed to OpCompositeInsert, but the type of the +// Object argument does not match the Composite type. 
+TEST_F(ValidateComposites, CompositeInsertMissingIndexesBad) { + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << std::endl; + spirv << "%matrix = OpLoad %mat4x3 %my_matrix" << std::endl; + spirv << "%new_composite = OpCompositeInsert %mat4x3 %int_0 %matrix"; + spirv << R"( + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Object type (OpTypeInt) does not match the type " + "that results from indexing into the Composite " + "(OpTypeMatrix).")); +} + +// Valid: Tests that we can index into Struct, Array, Matrix, and Vector! +TEST_F(ValidateComposites, CompositeExtractInsertIndexIntoAllTypesGood) { + // indexes that we are passing are: 0, 3, 1, 2, 0 + // 0 will select the struct_s within the base struct (blockName) + // 3 will select the Array that contains 5 matrices + // 1 will select the Matrix that is at index 1 of the array + // 2 will select the column (which is a vector) within the matrix at index 2 + // 0 will select the element at the index 0 of the vector. (which is a float). + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << R"( + %myblock = OpLoad %struct_blockName %blockName_var + %ss = OpCompositeExtract %struct_s %myblock 0 + %sa = OpCompositeExtract %array5_mat4x3 %myblock 0 3 + %sm = OpCompositeExtract %mat4x3 %myblock 0 3 1 + %sc = OpCompositeExtract %v3float %myblock 0 3 1 2 + %fl = OpCompositeExtract %float %myblock 0 3 1 2 0 + ; + ; Now let's insert back at different levels... 
+ ; + %b1 = OpCompositeInsert %struct_blockName %ss %myblock 0 + %b2 = OpCompositeInsert %struct_blockName %sa %myblock 0 3 + %b3 = OpCompositeInsert %struct_blockName %sm %myblock 0 3 1 + %b4 = OpCompositeInsert %struct_blockName %sc %myblock 0 3 1 2 + %b5 = OpCompositeInsert %struct_blockName %fl %myblock 0 3 1 2 0 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid. More indexes are provided than needed for OpCompositeExtract. +TEST_F(ValidateComposites, CompositeExtractReachedScalarBad) { + // indexes that we are passing are: 0, 3, 1, 2, 0, 1 (the last one is extra) + // 0 will select the struct_s within the base struct (blockName) + // 3 will select the Array that contains 5 matrices + // 1 will select the Matrix that is at index 1 of the array + // 2 will select the column (which is a vector) within the matrix at index 2 + // 0 will select the element at the index 0 of the vector. (which is a float). + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << R"( + %myblock = OpLoad %struct_blockName %blockName_var + %fl = OpCompositeExtract %float %myblock 0 3 1 2 0 1 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); +} + +// Invalid. More indexes are provided than needed for OpCompositeInsert. +TEST_F(ValidateComposites, CompositeInsertReachedScalarBad) { + // indexes that we are passing are: 0, 3, 1, 2, 0, 1 (the last one is extra) + // 0 will select the struct_s within the base struct (blockName) + // 3 will select the Array that contains 5 matrices + // 1 will select the Matrix that is at index 1 of the array + // 2 will select the column (which is a vector) within the matrix at index 2 + // 0 will select the element at the index 0 of the vector. (which is a float).
+ std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << R"( + %myblock = OpLoad %struct_blockName %blockName_var + %fl = OpCompositeExtract %float %myblock 0 3 1 2 0 + %b5 = OpCompositeInsert %struct_blockName %fl %myblock 0 3 1 2 0 1 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Reached non-composite type while indexes still remain " + "to be traversed.")); +} + +// Invalid. Result type doesn't match the type we get from indexing into +// the composite. +TEST_F(ValidateComposites, + CompositeExtractResultTypeDoesntMatchIndexedTypeBad) { + // indexes that we are passing are: 0, 3, 1, 2, 0 + // 0 will select the struct_s within the base struct (blockName) + // 3 will select the Array that contains 5 matrices + // 1 will select the Matrix that is at index 1 of the array + // 2 will select the column (which is a vector) within the matrix at index 2 + // 0 will select the element at the index 0 of the vector. (which is a float). + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << R"( + %myblock = OpLoad %struct_blockName %blockName_var + %fl = OpCompositeExtract %int %myblock 0 3 1 2 0 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result type (OpTypeInt) does not match the type that " + "results from indexing into the composite " + "(OpTypeFloat).")); +} + +// Invalid. Given object type doesn't match the type we get from indexing into +// the composite. 
+TEST_F(ValidateComposites, CompositeInsertObjectTypeDoesntMatchIndexedTypeBad) { + // indexes that we are passing are: 0, 3, 1, 2, 0 + // 0 will select the struct_s within the base struct (blockName) + // 3 will select the Array that contains 5 matrices + // 1 will select the Matrix that is at index 1 of the array + // 2 will select the column (which is a vector) within the matrix at index 2 + // 0 will select the element at the index 0 of the vector. (which is a float). + // We are trying to insert an integer where we should be inserting a float. + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << R"( + %myblock = OpLoad %struct_blockName %blockName_var + %b5 = OpCompositeInsert %struct_blockName %int_0 %myblock 0 3 1 2 0 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Object type (OpTypeInt) does not match the type " + "that results from indexing into the Composite " + "(OpTypeFloat).")); +} + +// Invalid. Index into a struct is larger than the number of struct members. +TEST_F(ValidateComposites, CompositeExtractStructIndexOutOfBoundBad) { + // struct_blockName has 3 members (index 0,1,2). We'll try to access index 3. + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << R"( + %myblock = OpLoad %struct_blockName %blockName_var + %ss = OpCompositeExtract %struct_s %myblock 3 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Index is out of bounds, can not find index 3 in the " + "structure '25'. This structure has 3 members. " + "Largest valid index is 2.")); +} + +// Invalid. Index into a struct is larger than the number of struct members. +TEST_F(ValidateComposites, CompositeInsertStructIndexOutOfBoundBad) { + // struct_blockName has 3 members (index 0,1,2). 
We'll try to access index 3. + std::ostringstream spirv; + spirv << GetHeaderForTestsFromValId() << R"( + %myblock = OpLoad %struct_blockName %blockName_var + %ss = OpCompositeExtract %struct_s %myblock 0 + %new_composite = OpCompositeInsert %struct_blockName %ss %myblock 3 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Index is out of bounds, can not find index 3 in the structure " + " '25'. This structure has 3 members. Largest valid index " + "is 2.")); +} + +// #1403: Ensure that the default spec constant value is not used to check the +// extract index. +TEST_F(ValidateComposites, ExtractFromSpecConstantSizedArray) { + std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpDecorate %spec_const SpecId 1 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%spec_const = OpSpecConstant %uint 3 +%uint_array = OpTypeArray %uint %spec_const +%undef = OpUndef %uint_array +%voidf = OpTypeFunction %void +%func = OpFunction %void None %voidf +%1 = OpLabel +%2 = OpCompositeExtract %uint %undef 4 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// #1403: Ensure that spec constant ops do not produce false positives. 
+TEST_F(ValidateComposites, ExtractFromSpecConstantOpSizedArray) { + std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpDecorate %spec_const SpecId 1 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%const = OpConstant %uint 1 +%spec_const = OpSpecConstant %uint 3 +%spec_const_op = OpSpecConstantOp %uint IAdd %spec_const %const +%uint_array = OpTypeArray %uint %spec_const_op +%undef = OpUndef %uint_array +%voidf = OpTypeFunction %void +%func = OpFunction %void None %voidf +%1 = OpLabel +%2 = OpCompositeExtract %uint %undef 4 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// #1403: Ensure that the default spec constant value is not used to check the +// size of the array for a composite construct. This code has limited actual +// value as it is incorrect unless the specialization constant is assigned the +// value of 2, but it is still a valid module. +TEST_F(ValidateComposites, CompositeConstructSpecConstantSizedArray) { + std::string spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +OpDecorate %spec_const SpecId 1 +%void = OpTypeVoid +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%spec_const = OpSpecConstant %uint 3 +%uint_array = OpTypeArray %uint %spec_const +%voidf = OpTypeFunction %void +%func = OpFunction %void None %voidf +%1 = OpLabel +%2 = OpCompositeConstruct %uint_array %uint_0 %uint_0 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateComposites, ExtractDynamicLabelIndex) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%void_fn = OpTypeFunction %void +%float_0 = OpConstant %float 0 +%v4float_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%func = OpFunction 
%void None %void_fn +%1 = OpLabel +%ex = OpVectorExtractDynamic %float %v4float_0 %v4float_0 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Index to be int scalar")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_constants_test.cpp b/test/val/val_constants_test.cpp new file mode 100644 index 000000000..824d6fda5 --- /dev/null +++ b/test/val/val_constants_test.cpp @@ -0,0 +1,299 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test validation of constants. +// +// This file contains newer tests. Older tests may be in other files such as +// val_id_test.cpp. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::ValuesIn; + +using ValidateConstant = spvtest::ValidateBase; + +#define kBasicTypes \ + "%bool = OpTypeBool " \ + "%uint = OpTypeInt 32 0 " \ + "%uint2 = OpTypeVector %uint 2 " \ + "%float = OpTypeFloat 32 " \ + "%_ptr_uint = OpTypePointer Workgroup %uint " \ + "%uint_0 = OpConstantNull %uint " \ + "%uint2_0 = OpConstantNull %uint " \ + "%float_0 = OpConstantNull %float " \ + "%false = OpConstantFalse %bool " \ + "%true = OpConstantTrue %bool " \ + "%null = OpConstantNull %_ptr_uint " + +#define kShaderPreamble \ + "OpCapability Shader\n" \ + "OpCapability Linkage\n" \ + "OpMemoryModel Logical Simple\n" + +#define kKernelPreamble \ + "OpCapability Kernel\n" \ + "OpCapability Linkage\n" \ + "OpCapability Addresses\n" \ + "OpMemoryModel Physical32 OpenCL\n" + +struct ConstantOpCase { + spv_target_env env; + std::string assembly; + bool expect_success; + std::string expect_err; +}; + +using ValidateConstantOp = spvtest::ValidateBase; + +TEST_P(ValidateConstantOp, Samples) { + const auto env = GetParam().env; + CompileSuccessfully(GetParam().assembly, env); + const auto result = ValidateInstructions(env); + if (GetParam().expect_success) { + EXPECT_EQ(SPV_SUCCESS, result); + EXPECT_THAT(getDiagnosticString(), Eq("")); + } else { + EXPECT_EQ(SPV_ERROR_INVALID_ID, result); + EXPECT_THAT(getDiagnosticString(), HasSubstr(GetParam().expect_err)); + } +} + +#define GOOD_SHADER_10(STR) \ + { SPV_ENV_UNIVERSAL_1_0, kShaderPreamble kBasicTypes STR, true, "" } +#define GOOD_KERNEL_10(STR) \ + { SPV_ENV_UNIVERSAL_1_0, kKernelPreamble kBasicTypes STR, true, "" } +INSTANTIATE_TEST_CASE_P( + UniversalInShader, ValidateConstantOp, + ValuesIn(std::vector{ + // TODO(dneto): Conversions must change width. 
+ GOOD_SHADER_10("%v = OpSpecConstantOp %uint SConvert %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %float FConvert %float_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint SNegate %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint Not %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint IAdd %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint ISub %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint IMul %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint UDiv %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint SDiv %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint UMod %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint SRem %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint SMod %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %uint ShiftRightLogical %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %uint ShiftRightArithmetic %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %uint ShiftLeftLogical %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %uint BitwiseOr %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %uint BitwiseXor %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %uint2 VectorShuffle %uint2_0 %uint2_0 1 3"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %uint CompositeExtract %uint2_0 1"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %uint2 CompositeInsert %uint_0 %uint2_0 1"), + GOOD_SHADER_10("%v = OpSpecConstantOp %bool LogicalOr %true %false"), + GOOD_SHADER_10("%v = OpSpecConstantOp %bool LogicalNot %true"), + GOOD_SHADER_10("%v = OpSpecConstantOp %bool LogicalAnd %true %false"), + GOOD_SHADER_10("%v = OpSpecConstantOp %bool LogicalEqual %true %false"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %bool LogicalNotEqual %true %false"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %uint Select %true %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %bool IEqual %uint_0 
%uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %bool INotEqual %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %bool ULessThan %uint_0 %uint_0"), + GOOD_SHADER_10("%v = OpSpecConstantOp %bool SLessThan %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %bool ULessThanEqual %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %bool SLessThanEqual %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %bool UGreaterThan %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %bool UGreaterThanEqual %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %bool SGreaterThan %uint_0 %uint_0"), + GOOD_SHADER_10( + "%v = OpSpecConstantOp %bool SGreaterThanEqual %uint_0 %uint_0"), + })); + +INSTANTIATE_TEST_CASE_P( + UniversalInKernel, ValidateConstantOp, + ValuesIn(std::vector{ + // TODO(dneto): Conversions must change width. + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint SConvert %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float FConvert %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint SNegate %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint Not %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint IAdd %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint ISub %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint IMul %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint UDiv %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint SDiv %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint UMod %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint SRem %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint SMod %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %uint ShiftRightLogical %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %uint ShiftRightArithmetic %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %uint ShiftLeftLogical %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = 
OpSpecConstantOp %uint BitwiseOr %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %uint BitwiseXor %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %uint2 VectorShuffle %uint2_0 %uint2_0 1 3"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %uint CompositeExtract %uint2_0 1"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %uint2 CompositeInsert %uint_0 %uint2_0 1"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %bool LogicalOr %true %false"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %bool LogicalNot %true"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %bool LogicalAnd %true %false"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %bool LogicalEqual %true %false"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %bool LogicalNotEqual %true %false"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %uint Select %true %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %bool IEqual %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %bool INotEqual %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %bool ULessThan %uint_0 %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %bool SLessThan %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %bool ULessThanEqual %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %bool SLessThanEqual %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %bool UGreaterThan %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %bool UGreaterThanEqual %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %bool SGreaterThan %uint_0 %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %bool SGreaterThanEqual %uint_0 %uint_0"), + })); + +INSTANTIATE_TEST_CASE_P( + KernelInKernel, ValidateConstantOp, + ValuesIn(std::vector{ + // TODO(dneto): Conversions must change width. 
+ GOOD_KERNEL_10("%v = OpSpecConstantOp %uint ConvertFToS %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float ConvertSToF %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint ConvertFToU %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float ConvertUToF %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint UConvert %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %_ptr_uint GenericCastToPtr %null"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %_ptr_uint PtrCastToGeneric %null"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %uint Bitcast %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float FNegate %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float FAdd %float_0 %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float FSub %float_0 %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float FMul %float_0 %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float FDiv %float_0 %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float FRem %float_0 %float_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %float FMod %float_0 %float_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %_ptr_uint AccessChain %null %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %_ptr_uint InBoundsAccessChain " + "%null %uint_0"), + GOOD_KERNEL_10( + "%v = OpSpecConstantOp %_ptr_uint PtrAccessChain %null %uint_0"), + GOOD_KERNEL_10("%v = OpSpecConstantOp %_ptr_uint " + "InBoundsPtrAccessChain %null %uint_0"), + })); + +#define BAD_SHADER_10(STR, NAME) \ + { \ + SPV_ENV_UNIVERSAL_1_0, kShaderPreamble kBasicTypes STR, false, \ + "Specialization constant operation " NAME \ + " requires Kernel capability" \ + } +INSTANTIATE_TEST_CASE_P( + KernelInShader, ValidateConstantOp, + ValuesIn(std::vector{ + // TODO(dneto): Conversions must change width. 
+ BAD_SHADER_10("%v = OpSpecConstantOp %uint ConvertFToS %float_0", + "ConvertFToS"), + BAD_SHADER_10("%v = OpSpecConstantOp %float ConvertSToF %uint_0", + "ConvertSToF"), + BAD_SHADER_10("%v = OpSpecConstantOp %uint ConvertFToU %float_0", + "ConvertFToU"), + BAD_SHADER_10("%v = OpSpecConstantOp %float ConvertUToF %uint_0", + "ConvertUToF"), + BAD_SHADER_10("%v = OpSpecConstantOp %_ptr_uint GenericCastToPtr %null", + "GenericCastToPtr"), + BAD_SHADER_10("%v = OpSpecConstantOp %_ptr_uint PtrCastToGeneric %null", + "PtrCastToGeneric"), + BAD_SHADER_10("%v = OpSpecConstantOp %uint Bitcast %uint_0", "Bitcast"), + BAD_SHADER_10("%v = OpSpecConstantOp %float FNegate %float_0", + "FNegate"), + BAD_SHADER_10("%v = OpSpecConstantOp %float FAdd %float_0 %float_0", + "FAdd"), + BAD_SHADER_10("%v = OpSpecConstantOp %float FSub %float_0 %float_0", + "FSub"), + BAD_SHADER_10("%v = OpSpecConstantOp %float FMul %float_0 %float_0", + "FMul"), + BAD_SHADER_10("%v = OpSpecConstantOp %float FDiv %float_0 %float_0", + "FDiv"), + BAD_SHADER_10("%v = OpSpecConstantOp %float FRem %float_0 %float_0", + "FRem"), + BAD_SHADER_10("%v = OpSpecConstantOp %float FMod %float_0 %float_0", + "FMod"), + BAD_SHADER_10( + "%v = OpSpecConstantOp %_ptr_uint AccessChain %null %uint_0", + "AccessChain"), + BAD_SHADER_10("%v = OpSpecConstantOp %_ptr_uint InBoundsAccessChain " + "%null %uint_0", + "InBoundsAccessChain"), + BAD_SHADER_10( + "%v = OpSpecConstantOp %_ptr_uint PtrAccessChain %null %uint_0", + "PtrAccessChain"), + BAD_SHADER_10("%v = OpSpecConstantOp %_ptr_uint " + "InBoundsPtrAccessChain %null %uint_0", + "InBoundsPtrAccessChain"), + })); + +INSTANTIATE_TEST_CASE_P( + UConvertInAMD_gpu_shader_int16, ValidateConstantOp, + ValuesIn(std::vector{ + // SPV_AMD_gpu_shader_int16 should enable UConvert for OpSpecConstantOp + // https://github.com/KhronosGroup/glslang/issues/848 + {SPV_ENV_UNIVERSAL_1_0, + "OpCapability Shader " + "OpCapability Linkage ; So we don't need to define a function\n" +
"OpExtension \"SPV_AMD_gpu_shader_int16\" " + "OpMemoryModel Logical Simple " kBasicTypes + "%v = OpSpecConstantOp %uint UConvert %uint_0", + true, ""}, + })); + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp new file mode 100644 index 000000000..5e4ad4908 --- /dev/null +++ b/test/val/val_conversion_test.cpp @@ -0,0 +1,1419 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for validation of conversion instructions.
+ +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateConversion = spvtest::ValidateBase; + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& decorations = "", const std::string& types = "", + const std::string& variables = "") { + const std::string capabilities = + R"( +OpCapability Shader +OpCapability Int64 +OpCapability Float64)"; + + const std::string after_extension_before_decorations = + R"( +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft)"; + + const std::string after_decorations_before_types = + R"( +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%f64 = OpTypeFloat 64 +%u64 = OpTypeInt 64 0 +%s64 = OpTypeInt 64 1 +%boolvec2 = OpTypeVector %bool 2 +%s32vec2 = OpTypeVector %s32 2 +%u32vec2 = OpTypeVector %u32 2 +%u64vec2 = OpTypeVector %u64 2 +%f32vec2 = OpTypeVector %f32 2 +%f64vec2 = OpTypeVector %f64 2 +%boolvec3 = OpTypeVector %bool 3 +%u32vec3 = OpTypeVector %u32 3 +%u64vec3 = OpTypeVector %u64 3 +%s32vec3 = OpTypeVector %s32 3 +%f32vec3 = OpTypeVector %f32 3 +%f64vec3 = OpTypeVector %f64 3 +%boolvec4 = OpTypeVector %bool 4 +%u32vec4 = OpTypeVector %u32 4 +%u64vec4 = OpTypeVector %u64 4 +%s32vec4 = OpTypeVector %s32 4 +%f32vec4 = OpTypeVector %f32 4 +%f64vec4 = OpTypeVector %f64 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +%f32_3 = OpConstant %f32 3 +%f32_4 = OpConstant %f32 4 + +%s32_0 = OpConstant %s32 0 +%s32_1 = OpConstant %s32 1 +%s32_2 = OpConstant %s32 2 +%s32_3 = OpConstant %s32 3 +%s32_4 = OpConstant %s32 4 +%s32_m1 = OpConstant %s32 -1 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 
+%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%f64_0 = OpConstant %f64 0 +%f64_1 = OpConstant %f64 1 +%f64_2 = OpConstant %f64 2 +%f64_3 = OpConstant %f64 3 +%f64_4 = OpConstant %f64 4 + +%s64_0 = OpConstant %s64 0 +%s64_1 = OpConstant %s64 1 +%s64_2 = OpConstant %s64 2 +%s64_3 = OpConstant %s64 3 +%s64_4 = OpConstant %s64 4 +%s64_m1 = OpConstant %s64 -1 + +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant %u64 1 +%u64_2 = OpConstant %u64 2 +%u64_3 = OpConstant %u64 3 +%u64_4 = OpConstant %u64 4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + +%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1 +%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2 +%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2 +%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3 +%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3 +%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4 + +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2 +%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4 + +%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1 +%f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2 +%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2 +%f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3 +%f64vec4_0123 
= OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3 +%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4 + +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool + +%f32ptr_func = OpTypePointer Function %f32)"; + + const std::string after_variables_before_body = + R"( +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string after_body = + R"( +OpReturn +OpFunctionEnd)"; + + return capabilities + capabilities_and_extensions + + after_extension_before_decorations + decorations + + after_decorations_before_types + types + variables + + after_variables_before_body + body + after_body; +} + +std::string GenerateKernelCode( + const std::string& body, + const std::string& capabilities_and_extensions = "") { + const std::string capabilities = + R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpCapability GenericPointer +OpCapability Int64 +OpCapability Float64)"; + + const std::string after_extension_before_body = + R"( +OpMemoryModel Physical32 OpenCL +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%f64 = OpTypeFloat 64 +%u64 = OpTypeInt 64 0 +%boolvec2 = OpTypeVector %bool 2 +%u32vec2 = OpTypeVector %u32 2 +%u64vec2 = OpTypeVector %u64 2 +%f32vec2 = OpTypeVector %f32 2 +%f64vec2 = OpTypeVector %f64 2 +%boolvec3 = OpTypeVector %bool 3 +%u32vec3 = OpTypeVector %u32 3 +%u64vec3 = OpTypeVector %u64 3 +%f32vec3 = OpTypeVector %f32 3 +%f64vec3 = OpTypeVector %f64 3 +%boolvec4 = OpTypeVector %bool 4 +%u32vec4 = OpTypeVector %u32 4 +%u64vec4 = OpTypeVector %u64 4 +%f32vec4 = OpTypeVector %f32 4 +%f64vec4 = OpTypeVector %f64 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +%f32_3 = OpConstant %f32 3 +%f32_4 = OpConstant %f32 4 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%f64_0 = 
OpConstant %f64 0 +%f64_1 = OpConstant %f64 1 +%f64_2 = OpConstant %f64 2 +%f64_3 = OpConstant %f64 3 +%f64_4 = OpConstant %f64 4 + +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant %u64 1 +%u64_2 = OpConstant %u64 2 +%u64_3 = OpConstant %u64 3 +%u64_4 = OpConstant %u64 4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2 +%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4 + +%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1 +%f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2 +%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2 +%f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3 +%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3 +%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4 + +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool + +%f32ptr_func = OpTypePointer Function %f32 +%u32ptr_func = OpTypePointer Function %u32 +%f32ptr_gen = OpTypePointer Generic %f32 +%f32ptr_inp = OpTypePointer Input %f32 +%f32ptr_wg = OpTypePointer Workgroup %f32 +%f32ptr_cwg = OpTypePointer CrossWorkgroup %f32 + +%f32inp = OpVariable %f32ptr_inp Input + +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string after_body = + R"( +OpReturn +OpFunctionEnd)"; + + return capabilities + 
capabilities_and_extensions + + after_extension_before_body + body + after_body; +} + +TEST_F(ValidateConversion, ConvertFToUSuccess) { + const std::string body = R"( +%val1 = OpConvertFToU %u32 %f32_1 +%val2 = OpConvertFToU %u32 %f64_0 +%val3 = OpConvertFToU %u32vec2 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, ConvertFToUWrongResultType) { + const std::string body = R"( +%val = OpConvertFToU %s32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected unsigned int scalar or vector type as Result " + "Type: ConvertFToU")); +} + +TEST_F(ValidateConversion, ConvertFToUWrongInputType) { + const std::string body = R"( +%val = OpConvertFToU %u32 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input to be float scalar or vector: ConvertFToU")); +} + +TEST_F(ValidateConversion, ConvertFToUDifferentDimension) { + const std::string body = R"( +%val = OpConvertFToU %u32 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have the same dimension as Result " + "Type: ConvertFToU")); +} + +TEST_F(ValidateConversion, ConvertFToSSuccess) { + const std::string body = R"( +%val1 = OpConvertFToS %s32 %f32_1 +%val2 = OpConvertFToS %u32 %f64_0 +%val3 = OpConvertFToS %s32vec2 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, ConvertFToSWrongResultType) { + const std::string body = R"( +%val = OpConvertFToS %bool %f32_1 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected int scalar or vector type as Result Type: ConvertFToS")); +} + +TEST_F(ValidateConversion, ConvertFToSWrongInputType) { + const std::string body = R"( +%val = OpConvertFToS %s32 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input to be float scalar or vector: ConvertFToS")); +} + +TEST_F(ValidateConversion, ConvertFToSDifferentDimension) { + const std::string body = R"( +%val = OpConvertFToS %u32 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have the same dimension as Result " + "Type: ConvertFToS")); +} + +TEST_F(ValidateConversion, ConvertSToFSuccess) { + const std::string body = R"( +%val1 = OpConvertSToF %f32 %u32_1 +%val2 = OpConvertSToF %f32 %s64_0 +%val3 = OpConvertSToF %f32vec2 %s32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, ConvertSToFWrongResultType) { + const std::string body = R"( +%val = OpConvertSToF %u32 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected float scalar or vector type as Result Type: ConvertSToF")); +} + +TEST_F(ValidateConversion, ConvertSToFWrongInputType) { + const std::string body = R"( +%val = OpConvertSToF %f32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + 
HasSubstr("Expected input to be int scalar or vector: ConvertSToF")); +} + +TEST_F(ValidateConversion, ConvertSToFDifferentDimension) { + const std::string body = R"( +%val = OpConvertSToF %f32 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have the same dimension as Result " + "Type: ConvertSToF")); +} + +TEST_F(ValidateConversion, UConvertSuccess) { + const std::string body = R"( +%val1 = OpUConvert %u32 %u64_1 +%val2 = OpUConvert %u64 %s32_0 +%val3 = OpUConvert %u64vec2 %s32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, UConvertWrongResultType) { + const std::string body = R"( +%val = OpUConvert %s32 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected unsigned int scalar or vector type as Result " + "Type: UConvert")); +} + +TEST_F(ValidateConversion, UConvertWrongInputType) { + const std::string body = R"( +%val = OpUConvert %u32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to be int scalar or vector: UConvert")); +} + +TEST_F(ValidateConversion, UConvertDifferentDimension) { + const std::string body = R"( +%val = OpUConvert %u32 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have the same dimension as Result " + "Type: UConvert")); +} + +TEST_F(ValidateConversion, UConvertSameBitWidth) { + const std::string body = R"( +%val = OpUConvert %u32 %s32_1 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have different bit width from " + "Result Type: UConvert")); +} + +TEST_F(ValidateConversion, SConvertSuccess) { + const std::string body = R"( +%val1 = OpSConvert %s32 %u64_1 +%val2 = OpSConvert %s64 %s32_0 +%val3 = OpSConvert %u64vec2 %s32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, SConvertWrongResultType) { + const std::string body = R"( +%val = OpSConvert %f32 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected int scalar or vector type as Result Type: SConvert")); +} + +TEST_F(ValidateConversion, SConvertWrongInputType) { + const std::string body = R"( +%val = OpSConvert %u32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to be int scalar or vector: SConvert")); +} + +TEST_F(ValidateConversion, SConvertDifferentDimension) { + const std::string body = R"( +%val = OpSConvert %s32 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have the same dimension as Result " + "Type: SConvert")); +} + +TEST_F(ValidateConversion, SConvertSameBitWidth) { + const std::string body = R"( +%val = OpSConvert %u32 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have different bit width from " + "Result 
Type: SConvert")); +} + +TEST_F(ValidateConversion, FConvertSuccess) { + const std::string body = R"( +%val1 = OpFConvert %f32 %f64_1 +%val2 = OpFConvert %f64 %f32_0 +%val3 = OpFConvert %f64vec2 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, FConvertWrongResultType) { + const std::string body = R"( +%val = OpFConvert %u32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected float scalar or vector type as Result Type: FConvert")); +} + +TEST_F(ValidateConversion, FConvertWrongInputType) { + const std::string body = R"( +%val = OpFConvert %f32 %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input to be float scalar or vector: FConvert")); +} + +TEST_F(ValidateConversion, FConvertDifferentDimension) { + const std::string body = R"( +%val = OpFConvert %f64 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have the same dimension as Result " + "Type: FConvert")); +} + +TEST_F(ValidateConversion, FConvertSameBitWidth) { + const std::string body = R"( +%val = OpFConvert %f32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have different bit width from " + "Result Type: FConvert")); +} + +TEST_F(ValidateConversion, QuantizeToF16Success) { + const std::string body = R"( +%val1 = OpQuantizeToF16 %f32 %f32_1 +%val2 = OpQuantizeToF16 %f32 %f32_0 +%val3 = OpQuantizeToF16 %f32vec2 
%f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, QuantizeToF16WrongResultType) { + const std::string body = R"( +%val = OpQuantizeToF16 %u32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected 32-bit float scalar or vector type as Result Type: " + "QuantizeToF16")); +} + +TEST_F(ValidateConversion, QuantizeToF16WrongResultTypeBitWidth) { + const std::string body = R"( +%val = OpQuantizeToF16 %u64 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected 32-bit float scalar or vector type as Result Type: " + "QuantizeToF16")); +} + +TEST_F(ValidateConversion, QuantizeToF16WrongInputType) { + const std::string body = R"( +%val = OpQuantizeToF16 %f32 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected input type to be equal to Result Type: QuantizeToF16")); +} + +TEST_F(ValidateConversion, ConvertFToS8BitStorage) { + const std::string capabilities_and_extensions = R"( +OpCapability StorageBuffer8BitAccess +OpExtension "SPV_KHR_8bit_storage" +OpExtension "SPV_KHR_storage_buffer_storage_class" +)"; + + const std::string decorations = R"( +OpDecorate %ssbo Block +OpDecorate %ssbo Binding 0 +OpDecorate %ssbo DescriptorSet 0 +OpMemberDecorate %ssbo 0 Offset 0 +)"; + + const std::string types = R"( +%i8 = OpTypeInt 8 1 +%i8ptr = OpTypePointer StorageBuffer %i8 +%ssbo = OpTypeStruct %i8 +%ssboptr = OpTypePointer StorageBuffer %ssbo +)"; + + const std::string variables = R"( +%var = OpVariable %ssboptr StorageBuffer +)"; + + const std::string body = R"( 
+%val = OpConvertFToS %i8 %f32_2 +%accesschain = OpAccessChain %i8ptr %var %u32_0 +OpStore %accesschain %val +)"; + + CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions, + decorations, types, variables) + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Invalid cast to 8-bit integer from a floating-point: ConvertFToS")); +} + +TEST_F(ValidateConversion, ConvertFToU8BitStorage) { + const std::string capabilities_and_extensions = R"( +OpCapability StorageBuffer8BitAccess +OpExtension "SPV_KHR_8bit_storage" +OpExtension "SPV_KHR_storage_buffer_storage_class" +)"; + + const std::string decorations = R"( +OpDecorate %ssbo Block +OpDecorate %ssbo Binding 0 +OpDecorate %ssbo DescriptorSet 0 +OpMemberDecorate %ssbo 0 Offset 0 +)"; + + const std::string types = R"( +%u8 = OpTypeInt 8 0 +%u8ptr = OpTypePointer StorageBuffer %u8 +%ssbo = OpTypeStruct %u8 +%ssboptr = OpTypePointer StorageBuffer %ssbo +)"; + + const std::string variables = R"( +%var = OpVariable %ssboptr StorageBuffer +)"; + + const std::string body = R"( +%val = OpConvertFToU %u8 %f32_2 +%accesschain = OpAccessChain %u8ptr %var %u32_0 +OpStore %accesschain %val +)"; + + CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions, + decorations, types, variables) + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Invalid cast to 8-bit integer from a floating-point: ConvertFToU")); +} + +TEST_F(ValidateConversion, ConvertSToF8BitStorage) { + const std::string capabilities_and_extensions = R"( +OpCapability StorageBuffer8BitAccess +OpExtension "SPV_KHR_8bit_storage" +OpExtension "SPV_KHR_storage_buffer_storage_class" +)"; + + const std::string decorations = R"( +OpDecorate %ssbo Block +OpDecorate %ssbo Binding 0 +OpDecorate %ssbo DescriptorSet 0 +OpMemberDecorate %ssbo 0 Offset 0 +)"; + + const std::string types = R"( +%i8 
= OpTypeInt 8 1 +%i8ptr = OpTypePointer StorageBuffer %i8 +%ssbo = OpTypeStruct %i8 +%ssboptr = OpTypePointer StorageBuffer %ssbo +)"; + + const std::string variables = R"( +%var = OpVariable %ssboptr StorageBuffer +)"; + + const std::string body = R"( +%accesschain = OpAccessChain %i8ptr %var %u32_0 +%load = OpLoad %i8 %accesschain +%val = OpConvertSToF %f32 %load +)"; + + CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions, + decorations, types, variables) + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Invalid cast to floating-point from an 8-bit integer: ConvertSToF")); +} + +TEST_F(ValidateConversion, ConvertUToF8BitStorage) { + const std::string capabilities_and_extensions = R"( +OpCapability StorageBuffer8BitAccess +OpExtension "SPV_KHR_8bit_storage" +OpExtension "SPV_KHR_storage_buffer_storage_class" +)"; + + const std::string decorations = R"( +OpDecorate %ssbo Block +OpDecorate %ssbo Binding 0 +OpDecorate %ssbo DescriptorSet 0 +OpMemberDecorate %ssbo 0 Offset 0 +)"; + + const std::string types = R"( +%u8 = OpTypeInt 8 0 +%u8ptr = OpTypePointer StorageBuffer %u8 +%ssbo = OpTypeStruct %u8 +%ssboptr = OpTypePointer StorageBuffer %ssbo +)"; + + const std::string variables = R"( +%var = OpVariable %ssboptr StorageBuffer +)"; + + const std::string body = R"( +%accesschain = OpAccessChain %u8ptr %var %u32_0 +%load = OpLoad %u8 %accesschain +%val = OpConvertUToF %f32 %load +)"; + + CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions, + decorations, types, variables) + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Invalid cast to floating-point from an 8-bit integer: ConvertUToF")); +} + +TEST_F(ValidateConversion, ConvertPtrToUSuccess) { + const std::string body = R"( +%ptr = OpVariable %f32ptr_func Function +%val1 = OpConvertPtrToU %u32 %ptr +%val2 = 
OpConvertPtrToU %u64 %ptr +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, ConvertPtrToUWrongResultType) { + const std::string body = R"( +%ptr = OpVariable %f32ptr_func Function +%val = OpConvertPtrToU %f32 %ptr +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected unsigned int scalar type as Result Type: " + "ConvertPtrToU")); +} + +TEST_F(ValidateConversion, ConvertPtrToUNotPointer) { + const std::string body = R"( +%val = OpConvertPtrToU %u32 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to be a pointer: ConvertPtrToU")); +} + +TEST_F(ValidateConversion, SatConvertSToUSuccess) { + const std::string body = R"( +%val1 = OpSatConvertSToU %u32 %u64_2 +%val2 = OpSatConvertSToU %u64 %u32_1 +%val3 = OpSatConvertSToU %u64vec2 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, SatConvertSToUWrongResultType) { + const std::string body = R"( +%val = OpSatConvertSToU %f32 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar or vector type as Result Type: " + "SatConvertSToU")); +} + +TEST_F(ValidateConversion, SatConvertSToUWrongInputType) { + const std::string body = R"( +%val = OpSatConvertSToU %u32 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected int scalar or vector as input: SatConvertSToU")); +} + 
+TEST_F(ValidateConversion, SatConvertSToUDifferentDimension) { + const std::string body = R"( +%val = OpSatConvertSToU %u32 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input to have the same dimension as Result Type: " + "SatConvertSToU")); +} + +TEST_F(ValidateConversion, ConvertUToPtrSuccess) { + const std::string body = R"( +%val1 = OpConvertUToPtr %f32ptr_func %u32_1 +%val2 = OpConvertUToPtr %f32ptr_func %u64_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, ConvertUToPtrWrongResultType) { + const std::string body = R"( +%val = OpConvertUToPtr %f32 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a pointer: ConvertUToPtr")); +} + +TEST_F(ValidateConversion, ConvertUToPtrNotInt) { + const std::string body = R"( +%val = OpConvertUToPtr %f32ptr_func %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar as input: ConvertUToPtr")); +} + +TEST_F(ValidateConversion, ConvertUToPtrNotIntScalar) { + const std::string body = R"( +%val = OpConvertUToPtr %f32ptr_func %u32vec2_12 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected int scalar as input: ConvertUToPtr")); +} + +TEST_F(ValidateConversion, PtrCastToGenericSuccess) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%val = OpPtrCastToGeneric %f32ptr_gen %ptr_func +)"; + + 
CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, PtrCastToGenericWrongResultType) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%val = OpPtrCastToGeneric %f32 %ptr_func +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be a pointer: PtrCastToGeneric")); +} + +TEST_F(ValidateConversion, PtrCastToGenericWrongResultStorageClass) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%val = OpPtrCastToGeneric %f32ptr_func %ptr_func +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have storage class Generic: " + "PtrCastToGeneric")); +} + +TEST_F(ValidateConversion, PtrCastToGenericWrongInputType) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%val = OpPtrCastToGeneric %f32ptr_gen %f32 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 4[%float] cannot be a " + "type")); +} + +TEST_F(ValidateConversion, PtrCastToGenericWrongInputStorageClass) { + const std::string body = R"( +%val = OpPtrCastToGeneric %f32ptr_gen %f32inp +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have storage class Workgroup, " + "CrossWorkgroup or Function: PtrCastToGeneric")); +} + +TEST_F(ValidateConversion, PtrCastToGenericPointToDifferentType) { + const std::string body = R"( +%ptr_func = OpVariable %u32ptr_func Function +%val = OpPtrCastToGeneric 
%f32ptr_gen %ptr_func +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input and Result Type to point to the same type: " + "PtrCastToGeneric")); +} + +TEST_F(ValidateConversion, GenericCastToPtrSuccess) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtr %f32ptr_func %ptr_gen +%ptr_wg = OpGenericCastToPtr %f32ptr_wg %ptr_gen +%ptr_cwg = OpGenericCastToPtr %f32ptr_cwg %ptr_gen +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, GenericCastToPtrWrongResultType) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtr %f32 %ptr_gen +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be a pointer: GenericCastToPtr")); +} + +TEST_F(ValidateConversion, GenericCastToPtrWrongResultStorageClass) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtr %f32ptr_gen %ptr_gen +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have storage class Workgroup, " + "CrossWorkgroup or Function: GenericCastToPtr")); +} + +TEST_F(ValidateConversion, GenericCastToPtrWrongInputType) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtr 
%f32ptr_func %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to be a pointer: GenericCastToPtr")); +} + +TEST_F(ValidateConversion, GenericCastToPtrWrongInputStorageClass) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_func2 = OpGenericCastToPtr %f32ptr_func %ptr_func +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have storage class Generic: " + "GenericCastToPtr")); +} + +TEST_F(ValidateConversion, GenericCastToPtrPointToDifferentType) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtr %u32ptr_func %ptr_gen +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input and Result Type to point to the same type: " + "GenericCastToPtr")); +} + +TEST_F(ValidateConversion, GenericCastToPtrExplicitSuccess) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtrExplicit %f32ptr_func %ptr_gen Function +%ptr_wg = OpGenericCastToPtrExplicit %f32ptr_wg %ptr_gen Workgroup +%ptr_cwg = OpGenericCastToPtrExplicit %f32ptr_cwg %ptr_gen CrossWorkgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, GenericCastToPtrExplicitWrongResultType) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtrExplicit 
%f32 %ptr_gen Function +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Result Type to be a pointer: GenericCastToPtrExplicit")); +} + +TEST_F(ValidateConversion, GenericCastToPtrExplicitResultStorageClassDiffers) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtrExplicit %f32ptr_func %ptr_gen Workgroup +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be of target storage class: " + "GenericCastToPtrExplicit")); +} + +TEST_F(ValidateConversion, GenericCastToPtrExplicitWrongResultStorageClass) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtrExplicit %f32ptr_gen %ptr_gen Generic +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected target storage class to be Workgroup, " + "CrossWorkgroup or Function: GenericCastToPtrExplicit")); +} + +TEST_F(ValidateConversion, GenericCastToPtrExplicitWrongInputType) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtrExplicit %f32ptr_func %f32_1 Function +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input to be a pointer: GenericCastToPtrExplicit")); +} + +TEST_F(ValidateConversion, GenericCastToPtrExplicitWrongInputStorageClass) { + const 
std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_func2 = OpGenericCastToPtrExplicit %f32ptr_func %ptr_func Function +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to have storage class Generic: " + "GenericCastToPtrExplicit")); +} + +TEST_F(ValidateConversion, GenericCastToPtrExplicitPointToDifferentType) { + const std::string body = R"( +%ptr_func = OpVariable %f32ptr_func Function +%ptr_gen = OpPtrCastToGeneric %f32ptr_gen %ptr_func +%ptr_func2 = OpGenericCastToPtrExplicit %u32ptr_func %ptr_gen Function +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input and Result Type to point to the same type: " + "GenericCastToPtrExplicit")); +} + +TEST_F(ValidateConversion, BitcastSuccess) { + const std::string body = R"( +%ptr = OpVariable %f32ptr_func Function +%val1 = OpBitcast %u32 %ptr +%val2 = OpBitcast %u64 %ptr +%val3 = OpBitcast %f32ptr_func %u32_1 +%val4 = OpBitcast %f32ptr_wg %u64_1 +%val5 = OpBitcast %f32 %u32_1 +%val6 = OpBitcast %f32vec2 %u32vec2_12 +%val7 = OpBitcast %f32vec2 %u64_1 +%val8 = OpBitcast %f64 %u32vec2_12 +%val9 = OpBitcast %f32vec4 %f64vec2_12 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, BitcastInputHasNoType) { + const std::string body = R"( +%val = OpBitcast %u32 %f32 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 4[%float] cannot be a " + "type")); +} + +TEST_F(ValidateConversion, BitcastWrongResultType) { + const std::string body = R"( +%val = OpBitcast %bool %f32_1 +)"; + + 
CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be a pointer or int or float vector " + "or scalar type: Bitcast")); +} + +TEST_F(ValidateConversion, BitcastWrongInputType) { + const std::string body = R"( +%val = OpBitcast %u32 %true +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected input to be a pointer or int or float vector " + "or scalar: Bitcast")); +} + +TEST_F(ValidateConversion, BitcastPtrWrongInputType) { + const std::string body = R"( +%val = OpBitcast %u32ptr_func %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected input to be a pointer or int scalar if Result Type " + "is pointer: Bitcast")); +} + +TEST_F(ValidateConversion, BitcastPtrWrongResultType) { + const std::string body = R"( +%val = OpBitcast %f32 %f32inp +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Pointer can only be converted to another pointer or int scalar: " + "Bitcast")); +} + +TEST_F(ValidateConversion, BitcastDifferentTotalBitWidth) { + const std::string body = R"( +%val = OpBitcast %f32 %u64_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected input to have the same total bit width as Result Type: " + "Bitcast")); +} + +TEST_F(ValidateConversion, ConvertUToPtrInputIsAType) { + const std::string spirv = R"( +OpCapability Addresses +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical 
GLSL450 +%int = OpTypeInt 32 0 +%ptr_int = OpTypePointer Function %int +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%1 = OpConvertUToPtr %ptr_int %int +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 1[%uint] cannot be a " + "type")); +} + +TEST_F(ValidateConversion, ConvertUToPtrPSBSuccess) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%uint64 = OpTypeInt 64 0 +%u64_1 = OpConstant %uint64 1 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpConvertUToPtr %ptr %u64_1 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, ConvertUToPtrPSBStorageClass) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%uint64 = OpTypeInt 64 0 +%u64_1 = OpConstant %uint64 1 +%ptr = OpTypePointer Function %uint64 +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpConvertUToPtr %ptr %u64_1 +%val2 = OpConvertPtrToU %uint64 %val1 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Pointer storage class 
must be " + "PhysicalStorageBufferEXT: ConvertUToPtr")); +} + +TEST_F(ValidateConversion, ConvertPtrToUPSBSuccess) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 RestrictPointerEXT +%uint64 = OpTypeInt 64 0 +%u64_1 = OpConstant %uint64 1 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +%val2 = OpLoad %ptr %val1 +%val3 = OpConvertPtrToU %uint64 %val2 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateConversion, ConvertPtrToUPSBStorageClass) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%uint64 = OpTypeInt 64 0 +%u64_1 = OpConstant %uint64 1 +%ptr = OpTypePointer Function %uint64 +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %ptr Function +%val2 = OpConvertPtrToU %uint64 %val1 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Pointer storage class must be " + "PhysicalStorageBufferEXT: ConvertPtrToU")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_data_test.cpp b/test/val/val_data_test.cpp new file mode 100644 index 
000000000..0aded92b5 --- /dev/null +++ b/test/val/val_data_test.cpp @@ -0,0 +1,939 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for Data Rules. + +#include <sstream> +#include <string> +#include <utility> + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; + +using ValidateData = spvtest::ValidateBase<std::pair<std::string, bool>>; + +std::string HeaderWith(std::string cap) { + return std::string("OpCapability Shader OpCapability Linkage OpCapability ") + + cap + " OpMemoryModel Logical GLSL450 "; +} + +std::string WebGPUHeaderWith(std::string cap) { + return R"( +OpCapability Shader +OpCapability )" + + cap + R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +)"; +} + +std::string webgpu_header = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +)"; + +std::string header = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +)"; +std::string header_with_addresses = R"( + OpCapability Addresses + OpCapability Kernel + OpCapability GenericPointer + OpCapability Linkage + OpMemoryModel Physical32 OpenCL +)"; +std::string header_with_vec16_cap = R"( + OpCapability Shader + OpCapability Vector16 + OpCapability Linkage +
OpMemoryModel Logical GLSL450 +)"; +std::string header_with_int8 = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Int8 + OpMemoryModel Logical GLSL450 +)"; +std::string header_with_int16 = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Int16 + OpMemoryModel Logical GLSL450 +)"; +std::string header_with_int64 = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Int64 + OpMemoryModel Logical GLSL450 +)"; +std::string header_with_float16 = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Float16 + OpMemoryModel Logical GLSL450 +)"; +std::string header_with_float16_buffer = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Float16Buffer + OpMemoryModel Logical GLSL450 +)"; +std::string header_with_float64 = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Float64 + OpMemoryModel Logical GLSL450 +)"; + +std::string invalid_comp_error = "Illegal number of components"; +std::string missing_cap_error = "requires the Vector16 capability"; +std::string missing_int8_cap_error = "requires the Int8 capability"; +std::string missing_int16_cap_error = + "requires the Int16 capability," + " or an extension that explicitly enables 16-bit integers."; +std::string missing_int64_cap_error = "requires the Int64 capability"; +std::string missing_float16_cap_error = + "requires the Float16 or Float16Buffer capability," + " or an extension that explicitly enables 16-bit floating point."; +std::string missing_float64_cap_error = "requires the Float64 capability"; +std::string invalid_num_bits_error = "Invalid number of bits"; + +TEST_F(ValidateData, vec0) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 0 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(invalid_comp_error)); +} + +TEST_F(ValidateData, vec1) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = 
OpTypeVector %1 1 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(invalid_comp_error)); +} + +TEST_F(ValidateData, vec2) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 2 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, vec3) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 3 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, vec4) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 4 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, vec5) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 5 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(invalid_comp_error)); +} + +TEST_F(ValidateData, vec8) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 8 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_cap_error)); +} + +TEST_F(ValidateData, vec8_with_capability) { + std::string str = header_with_vec16_cap + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 8 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, vec16) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 16 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_cap_error)); +} + +TEST_F(ValidateData, vec16_with_capability) { + std::string str = header_with_vec16_cap + R"( +%1 =
OpTypeFloat 32 +%2 = OpTypeVector %1 16 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, vec15) { + std::string str = header + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 15 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(invalid_comp_error)); +} + +TEST_F(ValidateData, int8_good) { + std::string str = header_with_int8 + "%2 = OpTypeInt 8 0"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, int8_bad) { + std::string str = header + "%2 = OpTypeInt 8 1"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_int8_cap_error)); +} + +TEST_F(ValidateData, int8_with_storage_buffer_8bit_access_good) { + std::string str = HeaderWith( + "StorageBuffer8BitAccess " + "OpExtension \"SPV_KHR_8bit_storage\"") + + " %2 = OpTypeInt 8 0"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateData, int8_with_uniform_and_storage_buffer_8bit_access_good) { + std::string str = HeaderWith( + "UniformAndStorageBuffer8BitAccess " + "OpExtension \"SPV_KHR_8bit_storage\"") + + " %2 = OpTypeInt 8 0"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateData, int8_with_storage_push_constant_8_good) { + std::string str = HeaderWith( + "StoragePushConstant8 " + "OpExtension \"SPV_KHR_8bit_storage\"") + + " %2 = OpTypeInt 8 0"; + CompileSuccessfully(str.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateData, webgpu_int8_bad) { + std::string str = WebGPUHeaderWith("Int8") + "%2 = OpTypeInt 8 0"; + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + 
ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability Int8 is not allowed by WebGPU specification (or " + "requires extension)\n" + " OpCapability Int8\n")); +} + +TEST_F(ValidateData, int16_good) { + std::string str = header_with_int16 + "%2 = OpTypeInt 16 1"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, storage_uniform_buffer_block_16_good) { + std::string str = HeaderWith( + "StorageUniformBufferBlock16 " + "OpExtension \"SPV_KHR_16bit_storage\"") + + "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, storage_uniform_16_good) { + std::string str = + HeaderWith("StorageUniform16 OpExtension \"SPV_KHR_16bit_storage\"") + + "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, storage_push_constant_16_good) { + std::string str = HeaderWith( + "StoragePushConstant16 " + "OpExtension \"SPV_KHR_16bit_storage\"") + + "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, storage_input_output_16_good) { + std::string str = HeaderWith( + "StorageInputOutput16 " + "OpExtension \"SPV_KHR_16bit_storage\"") + + "%2 = OpTypeInt 16 1 %3 = OpTypeFloat 16"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, amd_gpu_shader_half_float_fetch_16_good) { + std::string str = R"( + OpCapability Shader + OpCapability Linkage + OpExtension "SPV_AMD_gpu_shader_half_float_fetch" + OpMemoryModel Logical GLSL450 + %2 = OpTypeFloat 16)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, int16_bad) { + 
std::string str = header + "%2 = OpTypeInt 16 1"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_int16_cap_error)); +} + +TEST_F(ValidateData, webgpu_int16_bad) { + std::string str = WebGPUHeaderWith("Int16") + "%2 = OpTypeInt 16 1"; + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability Int16 is not allowed by WebGPU specification (or " + "requires extension)\n" + " OpCapability Int16\n")); +} + +TEST_F(ValidateData, webgpu_int32_good) { + std::string str = webgpu_header + R"( + OpEntryPoint Fragment %func "func" + OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 32 0 + %void = OpTypeVoid +%func_t = OpTypeFunction %void + %func = OpFunction %void None %func_t + %1 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateData, int64_good) { + std::string str = header_with_int64 + "%2 = OpTypeInt 64 1"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, int64_bad) { + std::string str = header + "%2 = OpTypeInt 64 1"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_int64_cap_error)); +} + +TEST_F(ValidateData, webgpu_int64_bad) { + std::string str = WebGPUHeaderWith("Int64") + "%2 = OpTypeInt 64 1"; + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability Int64 is not allowed by WebGPU specification (or " + "requires extension)\n" + " OpCapability Int64\n")); +} + +// Number of 
bits in an integer may be only one of: {8,16,32,64} +TEST_F(ValidateData, int_invalid_num_bits) { + std::string str = header + "%2 = OpTypeInt 48 1"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(invalid_num_bits_error)); +} + +TEST_F(ValidateData, float16_good) { + std::string str = header_with_float16 + "%2 = OpTypeFloat 16"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, float16_buffer_good) { + std::string str = header_with_float16_buffer + "%2 = OpTypeFloat 16"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, float16_bad) { + std::string str = header + "%2 = OpTypeFloat 16"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_float16_cap_error)); +} + +TEST_F(ValidateData, webgpu_float16_bad) { + std::string str = WebGPUHeaderWith("Float16") + "%2 = OpTypeFloat 16"; + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability Float16 is not allowed by WebGPU specification (or " + "requires extension)\n" + " OpCapability Float16\n")); +} + +TEST_F(ValidateData, webgpu_float32_good) { + std::string str = webgpu_header + R"( + OpEntryPoint Fragment %func "func" + OpExecutionMode %func OriginUpperLeft +%float_t = OpTypeFloat 32 + %void = OpTypeVoid + %func_t = OpTypeFunction %void + %func = OpFunction %void None %func_t + %1 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateData, float64_good) { + std::string str = header_with_float64 + "%2 = OpTypeFloat 64"; + 
CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, float64_bad) { + std::string str = header + "%2 = OpTypeFloat 64"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(missing_float64_cap_error)); +} + +TEST_F(ValidateData, webgpu_float64_bad) { + std::string str = WebGPUHeaderWith("Float64") + "%2 = OpTypeFloat 64"; + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, + ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability Float64 is not allowed by WebGPU specification (or " + "requires extension)\n" + " OpCapability Float64\n")); +} + +// Number of bits in a float may be only one of: {16,32,64} +TEST_F(ValidateData, float_invalid_num_bits) { + std::string str = header + "%2 = OpTypeFloat 48"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(invalid_num_bits_error)); +} + +TEST_F(ValidateData, matrix_data_type_float) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 3 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, ids_should_be_validated_before_data) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%mat33 = OpTypeMatrix %vec3 3 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID 3[%3] has not been defined")); +} + +TEST_F(ValidateData, matrix_bad_column_type) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%mat33 = OpTypeMatrix %f32 3 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("Columns in a matrix must be of type vector")); +} + +TEST_F(ValidateData, matrix_data_type_int) { + std::string str = header + R"( +%int32 = OpTypeInt 32 1 +%vec3 = OpTypeVector %int32 3 +%mat33 = OpTypeMatrix %vec3 3 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("can only be parameterized with floating-point types")); +} + +TEST_F(ValidateData, matrix_data_type_bool) { + std::string str = header + R"( +%boolt = OpTypeBool +%vec3 = OpTypeVector %boolt 3 +%mat33 = OpTypeMatrix %vec3 3 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("can only be parameterized with floating-point types")); +} + +TEST_F(ValidateData, matrix_with_0_columns) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 0 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("can only be parameterized as having only 2, 3, or 4 columns")); +} + +TEST_F(ValidateData, matrix_with_1_column) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 1 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("can only be parameterized as having only 2, 3, or 4 columns")); +} + +TEST_F(ValidateData, matrix_with_2_columns) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 2 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, matrix_with_3_columns) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = 
OpTypeMatrix %vec3 3 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, matrix_with_4_columns) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 4 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, matrix_with_5_column) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 5 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("can only be parameterized as having only 2, 3, or 4 columns")); +} + +TEST_F(ValidateData, specialize_int) { + std::string str = header + R"( +%i32 = OpTypeInt 32 1 +%len = OpSpecConstant %i32 2)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, specialize_float) { + std::string str = header + R"( +%f32 = OpTypeFloat 32 +%len = OpSpecConstant %f32 2)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, specialize_boolean) { + std::string str = header + R"( +%2 = OpTypeBool +%3 = OpSpecConstantTrue %2 +%4 = OpSpecConstantFalse %2)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, specialize_boolean_to_int) { + std::string str = header + R"( +%2 = OpTypeInt 32 1 +%3 = OpSpecConstantTrue %2 +%4 = OpSpecConstantFalse %2)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Specialization constant must be a boolean")); +} + +TEST_F(ValidateData, missing_forward_pointer_decl) { + std::string str = header_with_addresses + R"( +%uintt = OpTypeInt 32 0 +%3 = OpTypeStruct %fwd_ptrt %uintt +)"; + 
CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must first be declared using OpTypeForwardPointer")); +} + +TEST_F(ValidateData, missing_forward_pointer_decl_self_reference) { + std::string str = header_with_addresses + R"( +%uintt = OpTypeInt 32 0 +%3 = OpTypeStruct %3 %uintt +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must first be declared using OpTypeForwardPointer")); +} + +TEST_F(ValidateData, forward_pointer_missing_definition) { + std::string str = header_with_addresses + R"( +OpTypeForwardPointer %_ptr_Generic_struct_A Generic +%uintt = OpTypeInt 32 0 +%struct_B = OpTypeStruct %uintt %_ptr_Generic_struct_A +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("forward referenced IDs have not been defined")); +} + +TEST_F(ValidateData, forward_ref_bad_type) { + std::string str = header_with_addresses + R"( +OpTypeForwardPointer %_ptr_Generic_struct_A Generic +%uintt = OpTypeInt 32 0 +%struct_B = OpTypeStruct %uintt %_ptr_Generic_struct_A +%_ptr_Generic_struct_A = OpTypeFloat 32 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Pointer type in OpTypeForwardPointer is not a pointer " + "type.\n OpTypeForwardPointer %float Generic\n")); +} + +TEST_F(ValidateData, forward_ref_points_to_non_struct) { + std::string str = header_with_addresses + R"( +OpTypeForwardPointer %_ptr_Generic_struct_A Generic +%uintt = OpTypeInt 32 0 +%struct_B = OpTypeStruct %uintt %_ptr_Generic_struct_A +%_ptr_Generic_struct_A = OpTypePointer Generic %uintt +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("A forward reference operand in an OpTypeStruct must " + "be an OpTypePointer that points to an OpTypeStruct. " + "Found OpTypePointer that points to OpTypeInt.")); +} + +TEST_F(ValidateData, struct_forward_pointer_good) { + std::string str = header_with_addresses + R"( +OpTypeForwardPointer %_ptr_Generic_struct_A Generic +%uintt = OpTypeInt 32 0 +%struct_B = OpTypeStruct %uintt %_ptr_Generic_struct_A +%struct_C = OpTypeStruct %uintt %struct_B +%struct_A = OpTypeStruct %uintt %struct_C +%_ptr_Generic_struct_A = OpTypePointer Generic %struct_C +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateData, ext_16bit_storage_caps_allow_free_fp_rounding_mode) { + for (const char* cap : {"StorageUniform16", "StorageUniformBufferBlock16", + "StoragePushConstant16", "StorageInputOutput16"}) { + for (const char* mode : {"RTE", "RTZ", "RTP", "RTN"}) { + std::string str = std::string(R"( + OpCapability Shader + OpCapability Linkage + OpCapability )") + + cap + R"( + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_KHR_variable_pointers" + OpExtension "SPV_KHR_16bit_storage" + OpMemoryModel Logical GLSL450 + OpDecorate %_ FPRoundingMode )" + mode + R"( + %half = OpTypeFloat 16 + %float = OpTypeFloat 32 + %float_1_25 = OpConstant %float 1.25 + %half_ptr = OpTypePointer StorageBuffer %half + %half_ptr_var = OpVariable %half_ptr StorageBuffer + %void = OpTypeVoid + %func = OpTypeFunction %void + %main = OpFunction %void None %func + %main_entry = OpLabel + %_ = OpFConvert %half %float_1_25 + OpStore %half_ptr_var %_ + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } + } +} + +TEST_F(ValidateData, vulkan_disallow_free_fp_rounding_mode) { + for (const char* mode : {"RTE", "RTZ"}) { + for (const auto env : {SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1}) { + std::string str = std::string(R"( + OpCapability Shader + OpExtension 
"SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpDecorate %_ FPRoundingMode )") + + mode + R"( + %half = OpTypeFloat 16 + %float = OpTypeFloat 32 + %float_1_25 = OpConstant %float 1.25 + %half_ptr = OpTypePointer StorageBuffer %half + %half_ptr_var = OpVariable %half_ptr StorageBuffer + %void = OpTypeVoid + %func = OpTypeFunction %void + %main = OpFunction %void None %func + %main_entry = OpLabel + %_ = OpFConvert %half %float_1_25 + OpStore %half_ptr_var %_ + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions(env)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Operand 2 of Decorate requires one of these capabilities: " + "StorageBuffer16BitAccess StorageUniform16 " + "StoragePushConstant16 StorageInputOutput16")); + } + } +} + +TEST_F(ValidateData, void_array) { + std::string str = header + R"( + %void = OpTypeVoid + %int = OpTypeInt 32 0 + %int_5 = OpConstant %int 5 + %array = OpTypeArray %void %int_5 + )"; + + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTypeArray Element Type '1[%void]' is a void type.")); +} + +TEST_F(ValidateData, void_runtime_array) { + std::string str = header + R"( + %void = OpTypeVoid + %array = OpTypeRuntimeArray %void + )"; + + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpTypeRuntimeArray Element Type '1[%void]' is a void type.")); +} + +TEST_F(ValidateData, vulkan_RTA_array_at_end_of_struct) { + std::string str = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %func "func" + OpExecutionMode %func OriginUpperLeft + OpDecorate %array_t ArrayStride 4 + OpMemberDecorate %struct_t 0 Offset 0 + OpMemberDecorate %struct_t 1 Offset 4 + OpDecorate 
%struct_t Block + %uint_t = OpTypeInt 32 0 + %array_t = OpTypeRuntimeArray %uint_t + %struct_t = OpTypeStruct %uint_t %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t + %2 = OpVariable %struct_ptr StorageBuffer + %void = OpTypeVoid + %func_t = OpTypeFunction %void + %func = OpFunction %void None %func_t + %1 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str.c_str(), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateData, vulkan_RTA_not_at_end_of_struct) { + std::string str = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %func "func" + OpExecutionMode %func OriginUpperLeft + OpDecorate %array_t ArrayStride 4 + OpMemberDecorate %struct_t 0 Offset 0 + OpMemberDecorate %struct_t 1 Offset 4 + OpDecorate %struct_t Block + %uint_t = OpTypeInt 32 0 + %array_t = OpTypeRuntimeArray %uint_t + %struct_t = OpTypeStruct %array_t %uint_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t + %2 = OpVariable %struct_ptr StorageBuffer + %void = OpTypeVoid + %func_t = OpTypeFunction %void + %func = OpFunction %void None %func_t + %1 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str.c_str(), SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("In Vulkan, OpTypeRuntimeArray must only be used for " + "the last member of an OpTypeStruct\n %_struct_3 = " + "OpTypeStruct %_runtimearr_uint %uint\n")); +} + +TEST_F(ValidateData, webgpu_RTA_array_at_end_of_struct) { + std::string str = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpEntryPoint Fragment %func "func" + OpExecutionMode %func OriginUpperLeft + OpDecorate %array_t ArrayStride 4 + OpMemberDecorate %struct_t 0 Offset 0 + OpMemberDecorate %struct_t 1 Offset 4 + OpDecorate %struct_t Block + %uint_t = 
OpTypeInt 32 0 + %array_t = OpTypeRuntimeArray %uint_t + %struct_t = OpTypeStruct %uint_t %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t + %2 = OpVariable %struct_ptr StorageBuffer + %void = OpTypeVoid + %func_t = OpTypeFunction %void + %func = OpFunction %void None %func_t + %1 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateData, webgpu_RTA_not_at_end_of_struct) { + std::string str = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpEntryPoint Fragment %func "func" + OpExecutionMode %func OriginUpperLeft + OpDecorate %array_t ArrayStride 4 + OpMemberDecorate %struct_t 0 Offset 0 + OpMemberDecorate %struct_t 1 Offset 4 + OpDecorate %struct_t Block + %uint_t = OpTypeInt 32 0 + %array_t = OpTypeRuntimeArray %uint_t + %struct_t = OpTypeStruct %array_t %uint_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t + %2 = OpVariable %struct_ptr StorageBuffer + %void = OpTypeVoid + %func_t = OpTypeFunction %void + %func = OpFunction %void None %func_t + %1 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str.c_str(), SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("In WebGPU, OpTypeRuntimeArray must only be used for " + "the last member of an OpTypeStruct\n %_struct_3 = " + "OpTypeStruct %_runtimearr_uint %uint\n")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_decoration_test.cpp b/test/val/val_decoration_test.cpp new file mode 100644 index 000000000..827ebf185 --- /dev/null +++ b/test/val/val_decoration_test.cpp @@ -0,0 +1,5307 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for decorations + +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "source/val/decoration.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::Eq; +using ::testing::HasSubstr; + +using ValidateDecorations = spvtest::ValidateBase<bool>; + +TEST_F(ValidateDecorations, ValidateOpDecorateRegistration) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %1 ArrayStride 4 + OpDecorate %1 RelaxedPrecision + %2 = OpTypeFloat 32 + %1 = OpTypeRuntimeArray %2 + ; Since %1 is used first in Decoration, it gets id 1. +)"; + const uint32_t id = 1; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + // Must have 2 decorations.
+ EXPECT_THAT( + vstate_->id_decorations(id), + Eq(std::vector<Decoration>{Decoration(SpvDecorationArrayStride, {4}), + Decoration(SpvDecorationRelaxedPrecision)})); +} + +TEST_F(ValidateDecorations, ValidateOpMemberDecorateRegistration) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %_arr_double_uint_6 ArrayStride 4 + OpMemberDecorate %_struct_115 2 NonReadable + OpMemberDecorate %_struct_115 2 Offset 2 + OpDecorate %_struct_115 BufferBlock + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %uint_6 = OpConstant %uint 6 + %_arr_double_uint_6 = OpTypeArray %float %uint_6 + %_struct_115 = OpTypeStruct %float %float %_arr_double_uint_6 +)"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + + // The array must have 1 decoration. + const uint32_t arr_id = 1; + EXPECT_THAT( + vstate_->id_decorations(arr_id), + Eq(std::vector<Decoration>{Decoration(SpvDecorationArrayStride, {4})})); + + // The struct must have 3 decorations. + const uint32_t struct_id = 2; + EXPECT_THAT( + vstate_->id_decorations(struct_id), + Eq(std::vector<Decoration>{Decoration(SpvDecorationNonReadable, {}, 2), + Decoration(SpvDecorationOffset, {2}, 2), + Decoration(SpvDecorationBufferBlock)})); +} + +TEST_F(ValidateDecorations, ValidateOpMemberDecorateOutOfBound) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "Main" + OpExecutionMode %1 OriginUpperLeft + OpMemberDecorate %_struct_2 1 RelaxedPrecision + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 + %_struct_2 = OpTypeStruct %float + %1 = OpFunction %void None %4 + %6 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Index 1 provided in OpMemberDecorate for struct " + "2[%_struct_2] is out of bounds. The structure has 1 " + "members. 
Largest valid index is 0.")); +} + +TEST_F(ValidateDecorations, ValidateGroupDecorateRegistration) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %1 DescriptorSet 0 + OpDecorate %1 NonWritable + OpDecorate %1 Restrict + %1 = OpDecorationGroup + OpGroupDecorate %1 %2 %3 + OpGroupDecorate %1 %4 + %float = OpTypeFloat 32 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_9 = OpTypeStruct %_runtimearr_float +%_ptr_Uniform__struct_9 = OpTypePointer Uniform %_struct_9 + %2 = OpVariable %_ptr_Uniform__struct_9 Uniform + %_struct_10 = OpTypeStruct %_runtimearr_float +%_ptr_Uniform__struct_10 = OpTypePointer Uniform %_struct_10 + %3 = OpVariable %_ptr_Uniform__struct_10 Uniform + %_struct_11 = OpTypeStruct %_runtimearr_float +%_ptr_Uniform__struct_11 = OpTypePointer Uniform %_struct_11 + %4 = OpVariable %_ptr_Uniform__struct_11 Uniform + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + + // Decoration group has 3 decorations. + auto expected_decorations = std::vector<Decoration>{ + Decoration(SpvDecorationDescriptorSet, {0}), + Decoration(SpvDecorationNonWritable), Decoration(SpvDecorationRestrict)}; + + // Decoration group is applied to id 1, 2, 3, and 4. Note that id 1 (which is + // the decoration group id) also has all the decorations. 
+ EXPECT_THAT(vstate_->id_decorations(1), Eq(expected_decorations)); + EXPECT_THAT(vstate_->id_decorations(2), Eq(expected_decorations)); + EXPECT_THAT(vstate_->id_decorations(3), Eq(expected_decorations)); + EXPECT_THAT(vstate_->id_decorations(4), Eq(expected_decorations)); +} + +TEST_F(ValidateDecorations, WebGPUOpDecorationGroupBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpDecorate %1 DescriptorSet 0 + OpDecorate %1 NonWritable + OpDecorate %1 Restrict + %1 = OpDecorationGroup + OpGroupDecorate %1 %2 %3 + OpGroupDecorate %1 %4 + %float = OpTypeFloat 32 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_9 = OpTypeStruct %_runtimearr_float +%_ptr_Uniform__struct_9 = OpTypePointer Uniform %_struct_9 + %2 = OpVariable %_ptr_Uniform__struct_9 Uniform + %_struct_10 = OpTypeStruct %_runtimearr_float +%_ptr_Uniform__struct_10 = OpTypePointer Uniform %_struct_10 + %3 = OpVariable %_ptr_Uniform__struct_10 Uniform + %_struct_11 = OpTypeStruct %_runtimearr_float +%_ptr_Uniform__struct_11 = OpTypePointer Uniform %_struct_11 + %4 = OpVariable %_ptr_Uniform__struct_11 Uniform + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpDecorationGroup is not allowed in the WebGPU " + "execution environment.\n %1 = OpDecorationGroup\n")); +} + +// For WebGPU, OpGroupDecorate does not have a test case, because it requires +// being preceded by OpDecorationGroup, which will cause a validation error. + +// For WebGPU, OpGroupMemberDecorate does not have a test case, because it +// requires being preceded by OpDecorationGroup, which will cause a validation +// error. 
+ +TEST_F(ValidateDecorations, ValidateGroupMemberDecorateRegistration) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %1 Offset 3 + %1 = OpDecorationGroup + OpGroupMemberDecorate %1 %_struct_1 3 %_struct_2 3 %_struct_3 3 + %float = OpTypeFloat 32 +%_runtimearr = OpTypeRuntimeArray %float + %_struct_1 = OpTypeStruct %float %float %float %_runtimearr + %_struct_2 = OpTypeStruct %float %float %float %_runtimearr + %_struct_3 = OpTypeStruct %float %float %float %_runtimearr + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + // Decoration group has 1 decoration. + auto expected_decorations = + std::vector<Decoration>{Decoration(SpvDecorationOffset, {3}, 3)}; + + // Decoration group is applied to id 2, 3, and 4. + EXPECT_THAT(vstate_->id_decorations(2), Eq(expected_decorations)); + EXPECT_THAT(vstate_->id_decorations(3), Eq(expected_decorations)); + EXPECT_THAT(vstate_->id_decorations(4), Eq(expected_decorations)); +} + +TEST_F(ValidateDecorations, LinkageImportUsedForInitializedVariableBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %target LinkageAttributes "link_ptr" Import + %float = OpTypeFloat 32 + %_ptr_float = OpTypePointer Uniform %float + %zero = OpConstantNull %float + %target = OpVariable %_ptr_float Uniform %zero + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("A module-scope OpVariable with initialization value " + "cannot be marked with the Import Linkage Type.")); +} +TEST_F(ValidateDecorations, LinkageExportUsedForInitializedVariableGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %target LinkageAttributes "link_ptr" Export + %float = OpTypeFloat 32 + %_ptr_float = OpTypePointer Uniform 
%float + %zero = OpConstantNull %float + %target = OpVariable %_ptr_float Uniform %zero + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, StructAllMembersHaveBuiltInDecorationsGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpMemberDecorate %_struct_1 0 BuiltIn Position + OpMemberDecorate %_struct_1 1 BuiltIn Position + OpMemberDecorate %_struct_1 2 BuiltIn Position + OpMemberDecorate %_struct_1 3 BuiltIn Position + %float = OpTypeFloat 32 +%_runtimearr = OpTypeRuntimeArray %float + %_struct_1 = OpTypeStruct %float %float %float %_runtimearr + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, MixedBuiltInDecorationsBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpMemberDecorate %_struct_1 0 BuiltIn Position + OpMemberDecorate %_struct_1 1 BuiltIn Position + %float = OpTypeFloat 32 +%_runtimearr = OpTypeRuntimeArray %float + %_struct_1 = OpTypeStruct %float %float %float %_runtimearr + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("When BuiltIn decoration is applied to a structure-type " + "member, all members of that structure type must also be " + "decorated with BuiltIn (No allowed mixing of built-in " + "variables and non-built-in variables within a single " + "structure). 
Structure id 1 does not meet this requirement.")); +} + +TEST_F(ValidateDecorations, StructContainsBuiltInStructBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpMemberDecorate %_struct_1 0 BuiltIn Position + OpMemberDecorate %_struct_1 1 BuiltIn Position + OpMemberDecorate %_struct_1 2 BuiltIn Position + OpMemberDecorate %_struct_1 3 BuiltIn Position + %float = OpTypeFloat 32 +%_runtimearr = OpTypeRuntimeArray %float + %_struct_1 = OpTypeStruct %float %float %float %_runtimearr + %_struct_2 = OpTypeStruct %_struct_1 + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Structure 1[%_struct_1] contains members with " + "BuiltIn decoration. Therefore this structure may not " + "be contained as a member of another structure type. " + "Structure 4[%_struct_4] contains structure " + "1[%_struct_1].")); +} + +TEST_F(ValidateDecorations, StructContainsNonBuiltInStructGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %float = OpTypeFloat 32 + %_struct_1 = OpTypeStruct %float + %_struct_2 = OpTypeStruct %_struct_1 + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, MultipleBuiltInObjectsConsumedByOpEntryPointBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Geometry + OpMemoryModel Logical GLSL450 + OpEntryPoint Geometry %main "main" %in_1 %in_2 + OpExecutionMode %main InputPoints + OpExecutionMode %main OutputPoints + OpMemberDecorate %struct_1 0 BuiltIn InvocationId + OpMemberDecorate %struct_2 0 BuiltIn Position + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %func = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct_1 = OpTypeStruct %int + %struct_2 = OpTypeStruct %float +%ptr_builtin_1 = OpTypePointer Input %struct_1 +%ptr_builtin_2 = 
OpTypePointer Input %struct_2 +%in_1 = OpVariable %ptr_builtin_1 Input +%in_2 = OpVariable %ptr_builtin_2 Input + %main = OpFunction %void None %func + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("There must be at most one object per Storage Class " + "that can contain a structure type containing members " + "decorated with BuiltIn, consumed per entry-point.")); +} + +TEST_F(ValidateDecorations, + OneBuiltInObjectPerStorageClassConsumedByOpEntryPointGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Geometry + OpMemoryModel Logical GLSL450 + OpEntryPoint Geometry %main "main" %in_1 %out_1 + OpExecutionMode %main InputPoints + OpExecutionMode %main OutputPoints + OpMemberDecorate %struct_1 0 BuiltIn InvocationId + OpMemberDecorate %struct_2 0 BuiltIn Position + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %func = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct_1 = OpTypeStruct %int + %struct_2 = OpTypeStruct %float +%ptr_builtin_1 = OpTypePointer Input %struct_1 +%ptr_builtin_2 = OpTypePointer Output %struct_2 +%in_1 = OpVariable %ptr_builtin_1 Input +%out_1 = OpVariable %ptr_builtin_2 Output + %main = OpFunction %void None %func + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, NoBuiltInObjectsConsumedByOpEntryPointGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Geometry + OpMemoryModel Logical GLSL450 + OpEntryPoint Geometry %main "main" %in_1 %out_1 + OpExecutionMode %main InputPoints + OpExecutionMode %main OutputPoints + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %func = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct_1 = OpTypeStruct %int + %struct_2 = OpTypeStruct %float +%ptr_builtin_1 = OpTypePointer Input %struct_1 
+%ptr_builtin_2 = OpTypePointer Output %struct_2 +%in_1 = OpVariable %ptr_builtin_1 Input +%out_1 = OpVariable %ptr_builtin_2 Output + %main = OpFunction %void None %func + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, EntryPointFunctionHasLinkageAttributeBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpDecorate %main LinkageAttributes "import_main" Import +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%main = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("The LinkageAttributes Decoration (Linkage name: import_main) " + "cannot be applied to function id 1 because it is targeted by " + "an OpEntryPoint instruction.")); +} + +TEST_F(ValidateDecorations, FunctionDeclarationWithoutImportLinkageBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %void = OpTypeVoid + %func = OpTypeFunction %void + %main = OpFunction %void None %func + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Function declaration (id 3) must have a LinkageAttributes " + "decoration with the Import Linkage type.")); +} + +TEST_F(ValidateDecorations, FunctionDeclarationWithImportLinkageGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %main LinkageAttributes "link_fn" Import + %void = OpTypeVoid + %func = OpTypeFunction %void + %main = OpFunction %void None %func + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, 
ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, FunctionDeclarationWithExportLinkageBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %main LinkageAttributes "link_fn" Export + %void = OpTypeVoid + %func = OpTypeFunction %void + %main = OpFunction %void None %func + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Function declaration (id 1) must have a LinkageAttributes " + "decoration with the Import Linkage type.")); +} + +TEST_F(ValidateDecorations, FunctionDefinitionWithImportLinkageBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %main LinkageAttributes "link_fn" Import + %void = OpTypeVoid + %func = OpTypeFunction %void + %main = OpFunction %void None %func + %label = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Function definition (id 1) may not be decorated with " + "Import Linkage type.")); +} + +TEST_F(ValidateDecorations, FunctionDefinitionWithoutImportLinkageGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %void = OpTypeVoid + %func = OpTypeFunction %void + %main = OpFunction %void None %func + %label = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, BuiltinVariablesGoodVulkan) { + const spv_target_env env = SPV_ENV_VULKAN_1_0; + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragCoord %_entryPointOutput +OpExecutionMode %main OriginUpperLeft +OpSource 
HLSL 500 +OpDecorate %gl_FragCoord BuiltIn FragCoord +OpDecorate %_entryPointOutput Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%float_0 = OpConstant %float 0 +%14 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %3 +%5 = OpLabel +OpStore %_entryPointOutput %14 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, env); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(env)); +} + +TEST_F(ValidateDecorations, BuiltinVariablesWithLocationDecorationVulkan) { + const spv_target_env env = SPV_ENV_VULKAN_1_0; + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragCoord %_entryPointOutput +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 500 +OpDecorate %gl_FragCoord BuiltIn FragCoord +OpDecorate %gl_FragCoord Location 0 +OpDecorate %_entryPointOutput Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%float_0 = OpConstant %float 0 +%14 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %3 +%5 = OpLabel +OpStore %_entryPointOutput %14 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, env); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("A BuiltIn variable (id 2) cannot have any Location or " + "Component decorations")); +} 
+TEST_F(ValidateDecorations, BuiltinVariablesWithComponentDecorationVulkan) { + const spv_target_env env = SPV_ENV_VULKAN_1_0; + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %gl_FragCoord %_entryPointOutput +OpExecutionMode %main OriginUpperLeft +OpSource HLSL 500 +OpDecorate %gl_FragCoord BuiltIn FragCoord +OpDecorate %gl_FragCoord Component 0 +OpDecorate %_entryPointOutput Location 0 +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%float_0 = OpConstant %float 0 +%14 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%gl_FragCoord = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%_entryPointOutput = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %3 +%5 = OpLabel +OpStore %_entryPointOutput %14 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, env); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("A BuiltIn variable (id 2) cannot have any Location or " + "Component decorations")); +} + +// #version 440 +// #extension GL_EXT_nonuniform_qualifier : enable +// layout(binding = 1) uniform sampler2D s2d[]; +// layout(location = 0) in nonuniformEXT int i; +// void main() +// { +// vec4 v = texture(s2d[i], vec2(0.3)); +// } +TEST_F(ValidateDecorations, RuntimeArrayOfDescriptorSetsIsAllowed) { + const spv_target_env env = SPV_ENV_VULKAN_1_0; + std::string spirv = R"( + OpCapability Shader + OpCapability ShaderNonUniformEXT + OpCapability RuntimeDescriptorArrayEXT + OpCapability SampledImageArrayNonUniformIndexingEXT + OpExtension "SPV_EXT_descriptor_indexing" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %i + OpSource GLSL 440 + OpSourceExtension 
"GL_EXT_nonuniform_qualifier" + OpName %main "main" + OpName %v "v" + OpName %s2d "s2d" + OpName %i "i" + OpDecorate %s2d DescriptorSet 0 + OpDecorate %s2d Binding 1 + OpDecorate %i Location 0 + OpDecorate %i NonUniformEXT + OpDecorate %18 NonUniformEXT + OpDecorate %21 NonUniformEXT + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown + %11 = OpTypeSampledImage %10 +%_runtimearr_11 = OpTypeRuntimeArray %11 +%_ptr_Uniform__runtimearr_11 = OpTypePointer Uniform %_runtimearr_11 + %s2d = OpVariable %_ptr_Uniform__runtimearr_11 Uniform + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %i = OpVariable %_ptr_Input_int Input +%_ptr_Uniform_11 = OpTypePointer Uniform %11 + %v2float = OpTypeVector %float 2 +%float_0_300000012 = OpConstant %float 0.300000012 + %24 = OpConstantComposite %v2float %float_0_300000012 %float_0_300000012 + %float_0 = OpConstant %float 0 + %main = OpFunction %void None %3 + %5 = OpLabel + %v = OpVariable %_ptr_Function_v4float Function + %18 = OpLoad %int %i + %20 = OpAccessChain %_ptr_Uniform_11 %s2d %18 + %21 = OpLoad %11 %20 + %26 = OpImageSampleExplicitLod %v4float %21 %24 Lod %float_0 + OpStore %v %26 + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv, env); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, BlockMissingOffsetBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void 
None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be explicitly laid out with Offset decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockMissingOffsetBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be explicitly laid out with Offset decorations")); +} + +TEST_F(ValidateDecorations, BlockNestedStructMissingOffsetBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %v3float %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + 
CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be explicitly laid out with Offset decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructMissingOffsetBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %v3float %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must be explicitly laid out with Offset decorations")); +} + +TEST_F(ValidateDecorations, BlockGLSLSharedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpDecorate %Output GLSLShared + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + 
)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLShared decoration")); +} + +TEST_F(ValidateDecorations, BufferBlockGLSLSharedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpDecorate %Output GLSLShared + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLShared decoration")); +} + +TEST_F(ValidateDecorations, BlockNestedStructGLSLSharedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %S GLSLShared + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + 
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLShared decoration")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructGLSLSharedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %S GLSLShared + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLShared decoration")); +} + +TEST_F(ValidateDecorations, BlockGLSLPackedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpDecorate %Output GLSLPacked + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLPacked decoration")); +} + +TEST_F(ValidateDecorations, BufferBlockGLSLPackedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpDecorate %Output GLSLPacked + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %Output = OpTypeStruct %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLPacked decoration")); +} + +TEST_F(ValidateDecorations, BlockNestedStructGLSLPackedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %S GLSLPacked + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLPacked decoration")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructGLSLPackedBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpDecorate %S GLSLPacked + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %S = OpTypeStruct %int + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("must not use GLSLPacked decoration")); +} + +TEST_F(ValidateDecorations, BlockMissingArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %float %int_3 + %Output = OpTypeStruct %array +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with ArrayStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockMissingArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %float %int_3 + %Output = OpTypeStruct %array +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with ArrayStride decorations")); +} + +TEST_F(ValidateDecorations, BlockNestedStructMissingArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %float %int_3 + %S = OpTypeStruct %array + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None 
%3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with ArrayStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructMissingArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %float %int_3 + %S = OpTypeStruct %array + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with ArrayStride decorations")); +} + +TEST_F(ValidateDecorations, BlockMissingMatrixStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %matrix = OpTypeMatrix %v3float 4 + %Output = OpTypeStruct %matrix +%_ptr_Uniform_Output = OpTypePointer 
Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockMissingMatrixStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %matrix = OpTypeMatrix %v3float 4 + %Output = OpTypeStruct %matrix +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BlockMissingMatrixStrideArrayBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output Block + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %matrix = OpTypeMatrix %v3float 4 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %matrix %int_3 + %Output = OpTypeStruct %matrix +%_ptr_Uniform_Output = OpTypePointer Uniform %Output 
+ %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockMissingMatrixStrideArrayBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %Output BufferBlock + OpMemberDecorate %Output 0 Offset 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %matrix = OpTypeMatrix %v3float 4 + %int = OpTypeInt 32 1 + %int_3 = OpConstant %int 3 + %array = OpTypeArray %matrix %int_3 + %Output = OpTypeStruct %matrix +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BlockNestedStructMissingMatrixStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %v4float = 
OpTypeVector %float 4 + %matrix = OpTypeMatrix %v3float 4 + %S = OpTypeStruct %matrix + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BufferBlockNestedStructMissingMatrixStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 16 + OpMemberDecorate %Output 2 Offset 32 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %v4float = OpTypeVector %float 4 + %matrix = OpTypeMatrix %v3float 4 + %S = OpTypeStruct %matrix + %Output = OpTypeStruct %float %v4float %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("must be explicitly laid out with MatrixStride decorations")); +} + +TEST_F(ValidateDecorations, BlockStandardUniformBufferLayout) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 
Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 48 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 32 + OpMemberDecorate %O 3 Offset 64 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 80 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 176 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 64 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 96 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 128 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, BlockLayoutPermitsTightVec3ScalarPackingGood) { + // See 
https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 12 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %v3float %float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, BlockLayoutForbidsTightScalarVec3PackingBad) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 2 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout " + "rules: member 1 at offset 4 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsTightScalarVec3PackingWithRelaxedLayoutGood) { + // Same as previous test, but with 
explicit option to relax block layout. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetRelaxBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsTightScalarVec3PackingBadOffsetWithRelaxedLayoutBad) { + // Same as previous test, but with the vector not aligned to its scalar + // element. Use offset 5 instead of a multiple of 4. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 5 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetRelaxBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 2 decorated as Block for variable in Uniform storage " + "class must follow relaxed uniform buffer layout rules: member 1 at " + "offset 5 is not aligned to scalar element size 4")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsTightScalarVec3PackingWithVulkan1_1Good) { + // Same as previous test, but with Vulkan 1.1. Vulkan 1.1 included + // VK_KHR_relaxed_block_layout in core. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsTightScalarVec3PackingWithScalarLayoutGood) { + // Same as previous test, but with scalar block layout. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetScalarBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsScalarAlignedArrayWithScalarLayoutGood) { + // The array at offset 4 is ok with scalar block layout. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + OpDecorate %arr_float ArrayStride 4 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_3 = OpConstant %uint 3 + %float = OpTypeFloat 32 + %arr_float = OpTypeArray %float %uint_3 + %S = OpTypeStruct %float %arr_float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetScalarBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsScalarAlignedArrayOfVec3WithScalarLayoutGood) { + // The array at offset 4 is ok with scalar block layout, even though + // its elements are vec3. + // This is the same as the previous case, but the array elements are vec3 + // instead of float. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + OpDecorate %arr_vec3 ArrayStride 12 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %uint_3 = OpConstant %uint 3 + %float = OpTypeFloat 32 + %vec3 = OpTypeVector %float 3 + %arr_vec3 = OpTypeArray %vec3 %uint_3 + %S = OpTypeStruct %float %arr_vec3 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetScalarBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, + BlockLayoutPermitsScalarAlignedStructWithScalarLayoutGood) { + // Scalar block layout permits the struct at offset 4, even though + // it contains a vector with base alignment 8 and scalar alignment 4. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpMemberDecorate %st 0 Offset 0 + OpMemberDecorate %st 1 Offset 8 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %vec2 = OpTypeVector %float 2 + %st = OpTypeStruct %vec2 %float + %S = OpTypeStruct %float %st +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetScalarBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F( + ValidateDecorations, + BlockLayoutPermitsFieldsInBaseAlignmentPaddingAtEndOfStructWithScalarLayoutGood) { + // Scalar block layout permits fields in what would normally be the padding at + // the end of a struct. 
+ std::string spirv = R"( + OpCapability Shader + OpCapability Float64 + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %st 0 Offset 0 + OpMemberDecorate %st 1 Offset 8 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 12 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %double = OpTypeFloat 64 + %st = OpTypeStruct %double %float + %S = OpTypeStruct %st %float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetScalarBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F( + ValidateDecorations, + BlockLayoutPermitsStraddlingVectorWithScalarLayoutOverrideRelaxBlockLayoutGood) { + // A vec4 at offset 4 straddles a 16-byte boundary. Relaxed block layout is + // set first, then scalar block layout. Scalar layout always wins.
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %vec4 = OpTypeVector %float 4 + %S = OpTypeStruct %float %vec4 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetRelaxBlockLayout(getValidatorOptions(), true); + spvValidatorOptionsSetScalarBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F( + ValidateDecorations, + BlockLayoutPermitsStraddlingVectorWithRelaxedLayoutOverridenByScalarBlockLayoutGood) { + // Same as previous, but set scalar block layout first. Scalar layout always + // wins. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %vec4 = OpTypeVector %float 4 + %S = OpTypeStruct %float %vec4 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetScalarBlockLayout(getValidatorOptions(), true); + spvValidatorOptionsSetRelaxBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, BufferBlock16bitStandardStorageBufferLayout) { + std::string spirv = R"( + OpCapability Shader + OpCapability StorageUniform16 + OpExtension "SPV_KHR_16bit_storage" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %f32arr ArrayStride 4 + OpDecorate %f16arr ArrayStride 2 + OpMemberDecorate %SSBO32 0 Offset 0 + OpMemberDecorate %SSBO16 0 Offset 0 + OpDecorate %SSBO32 BufferBlock + OpDecorate %SSBO16 BufferBlock + %void = OpTypeVoid + %voidf = OpTypeFunction %void + %u32 = OpTypeInt 32 0 + %i32 = OpTypeInt 32 1 + %f32 = OpTypeFloat 32 + %uvec3 = OpTypeVector %u32 3 + %c_i32_32 = OpConstant %i32 32 +%c_i32_128 = OpConstant %i32 128 + %f32arr = OpTypeArray %f32 %c_i32_128 + %f16 = OpTypeFloat 16 + %f16arr = OpTypeArray %f16 %c_i32_128 + %SSBO32 = OpTypeStruct %f32arr + %SSBO16 = OpTypeStruct %f16arr +%_ptr_Uniform_SSBO32 = OpTypePointer Uniform %SSBO32 + %varSSBO32 = OpVariable %_ptr_Uniform_SSBO32 Uniform +%_ptr_Uniform_SSBO16 = OpTypePointer Uniform %SSBO16 + %varSSBO16 = OpVariable 
%_ptr_Uniform_SSBO16 Uniform + %main = OpFunction %void None %voidf + %label = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, BlockArrayBaseAlignmentGood) { + // For uniform buffer, Array base alignment is 16, and ArrayStride + // must be a multiple of 16. The variable is declared in Uniform storage + // so that the uniform buffer rules are the ones being checked. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 16 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %u = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, BlockArrayBadAlignmentBad) { + // For uniform buffer, Array base alignment is 16.
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %u = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout rules: " + "member 1 at offset 8 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, BlockArrayBadAlignmentWithRelaxedLayoutStillBad) { + // For uniform buffer, Array base alignment is 16, and ArrayStride + // must be a multiple of 16. This case uses relaxed block layout. Relaxed + // layout only relaxes rules for vector alignment, not array alignment. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %u DescriptorSet 0 + OpDecorate %u Binding 0 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %u = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + // Enable relaxed block layout before validating, as the other relaxed + // layout tests do. Setting it after validation would leave this run on the + // standard rules and the relaxed case untested. + spvValidatorOptionsSetRelaxBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 4 decorated as Block for variable in Uniform " + "storage class must follow relaxed uniform buffer layout rules: " + "member 1 at offset 8 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, BlockArrayBadAlignmentWithVulkan1_1StillBad) { + // Same as previous test, but with Vulkan 1.1, which includes + // VK_KHR_relaxed_block_layout in core.
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %u DescriptorSet 0 + OpDecorate %u Binding 0 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer Uniform %S + %u = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 4 decorated as Block for variable in Uniform " + "storage class must follow relaxed uniform buffer layout rules: " + "member 1 at offset 8 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, PushConstantArrayBaseAlignmentGood) { + // Tests https://github.com/KhronosGroup/SPIRV-Tools/issues/1664 + // From GLSL vertex shader: + // #version 450 + // layout(push_constant) uniform S { vec2 v; float arr[2]; } u; + // void main() { } + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 8 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_PushConstant_S = OpTypePointer PushConstant %S + %u = OpVariable 
%_ptr_PushConstant_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, PushConstantArrayBadAlignmentBad) { + // Like the previous test, but with offset 7 instead of 8. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 7 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_PushConstant_S = OpTypePointer PushConstant %S + %u = OpVariable %_ptr_PushConstant_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in PushConstant " + "storage class must follow standard storage buffer layout rules: " + "member 1 at offset 7 is not aligned to 4")); +} + +TEST_F(ValidateDecorations, + PushConstantLayoutPermitsTightVec3ScalarPackingGood) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 12 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %v3float %float +%_ptr_PushConstant_S = OpTypePointer 
PushConstant %S + %B = OpVariable %_ptr_PushConstant_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, + PushConstantLayoutForbidsTightScalarVec3PackingBad) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer PushConstant %S + %B = OpVariable %_ptr_Uniform_S PushConstant + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 2 decorated as Block for variable in PushConstant " + "storage class must follow standard storage buffer layout " + "rules: member 1 at offset 4 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, PushConstantMissingBlockGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %pc = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + 
+TEST_F(ValidateDecorations, VulkanPushConstantMissingBlockBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %pc = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("PushConstant id '2' is missing Block decoration.\n" + "From Vulkan spec, section 14.5.1:\n" + "Such variables must be identified with a Block " + "decoration")); +} + +TEST_F(ValidateDecorations, MultiplePushConstantsSingleEntryPointGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %ptr_float = OpTypePointer PushConstant %float + %pc1 = OpVariable %ptr PushConstant + %pc2 = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + %2 = OpAccessChain %ptr_float %pc1 %int_0 + %3 = OpLoad %float %2 + %4 = OpAccessChain %ptr_float %pc2 %int_0 + %5 = OpLoad %float %4 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, + VulkanMultiplePushConstantsDifferentEntryPointGood) { + std::string spirv = R"( + 
OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "func1" + OpEntryPoint Fragment %2 "func2" + OpExecutionMode %2 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %ptr_float = OpTypePointer PushConstant %float + %pc1 = OpVariable %ptr PushConstant + %pc2 = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label1 = OpLabel + %3 = OpAccessChain %ptr_float %pc1 %int_0 + %4 = OpLoad %float %3 + OpReturn + OpFunctionEnd + + %2 = OpFunction %void None %voidfn + %label2 = OpLabel + %5 = OpAccessChain %ptr_float %pc2 %int_0 + %6 = OpLoad %float %5 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, + VulkanMultiplePushConstantsUnusedSingleEntryPointGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %ptr_float = OpTypePointer PushConstant %float + %pc1 = OpVariable %ptr PushConstant + %pc2 = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, 
VulkanMultiplePushConstantsSingleEntryPointBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %ptr_float = OpTypePointer PushConstant %float + %pc1 = OpVariable %ptr PushConstant + %pc2 = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + %2 = OpAccessChain %ptr_float %pc1 %int_0 + %3 = OpLoad %float %2 + %4 = OpAccessChain %ptr_float %pc2 %int_0 + %5 = OpLoad %float %4 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Entry point id '1' uses more than one PushConstant interface.\n" + "From Vulkan spec, section 14.5.1:\n" + "There must be no more than one push constant block " + "statically used per shader entry point.")); +} + +TEST_F(ValidateDecorations, + VulkanMultiplePushConstantsDifferentEntryPointSubFunctionGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "func1" + OpEntryPoint Fragment %2 "func2" + OpExecutionMode %2 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %ptr_float = OpTypePointer PushConstant %float + %pc1 = OpVariable %ptr PushConstant + %pc2 = OpVariable %ptr PushConstant + + %sub1 = OpFunction %void None %voidfn + %label_sub1 = OpLabel + %3 
= OpAccessChain %ptr_float %pc1 %int_0 + %4 = OpLoad %float %3 + OpReturn + OpFunctionEnd + + %sub2 = OpFunction %void None %voidfn + %label_sub2 = OpLabel + %5 = OpAccessChain %ptr_float %pc2 %int_0 + %6 = OpLoad %float %5 + OpReturn + OpFunctionEnd + + %1 = OpFunction %void None %voidfn + %label1 = OpLabel + %call1 = OpFunctionCall %void %sub1 + OpReturn + OpFunctionEnd + + %2 = OpFunction %void None %voidfn + %label2 = OpLabel + %call2 = OpFunctionCall %void %sub2 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, + VulkanMultiplePushConstantsSingleEntryPointSubFunctionBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %ptr_float = OpTypePointer PushConstant %float + %pc1 = OpVariable %ptr PushConstant + %pc2 = OpVariable %ptr PushConstant + + %sub1 = OpFunction %void None %voidfn + %label_sub1 = OpLabel + %3 = OpAccessChain %ptr_float %pc1 %int_0 + %4 = OpLoad %float %3 + OpReturn + OpFunctionEnd + + %sub2 = OpFunction %void None %voidfn + %label_sub2 = OpLabel + %5 = OpAccessChain %ptr_float %pc2 %int_0 + %6 = OpLoad %float %5 + OpReturn + OpFunctionEnd + + %1 = OpFunction %void None %voidfn + %label1 = OpLabel + %call1 = OpFunctionCall %void %sub1 + %call2 = OpFunctionCall %void %sub2 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + 
HasSubstr( + "Entry point id '1' uses more than one PushConstant interface.\n" + "From Vulkan spec, section 14.5.1:\n" + "There must be no more than one push constant block " + "statically used per shader entry point.")); +} + +TEST_F(ValidateDecorations, VulkanUniformMissingDescriptorSetBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + OpDecorate %var Binding 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer Uniform %struct +%ptr_float = OpTypePointer Uniform %float + %var = OpVariable %ptr Uniform + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + %2 = OpAccessChain %ptr_float %var %int_0 + %3 = OpLoad %float %2 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Uniform id '3' is missing DescriptorSet decoration.\n" + "From Vulkan spec, section 14.5.2:\n" + "These variables must have DescriptorSet and Binding " + "decorations specified")); +} + +TEST_F(ValidateDecorations, VulkanUniformMissingBindingBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + OpDecorate %var DescriptorSet 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer Uniform %struct +%ptr_float = OpTypePointer Uniform %float + %var = OpVariable %ptr Uniform + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + + %1 = OpFunction %void None %voidfn 
+ %label = OpLabel + %2 = OpAccessChain %ptr_float %var %int_0 + %3 = OpLoad %float %2 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Uniform id '3' is missing Binding decoration.\n" + "From Vulkan spec, section 14.5.2:\n" + "These variables must have DescriptorSet and Binding " + "decorations specified")); +} + +TEST_F(ValidateDecorations, VulkanUniformConstantMissingDescriptorSetBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %var Binding 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %sampler = OpTypeSampler + %ptr = OpTypePointer UniformConstant %sampler + %var = OpVariable %ptr UniformConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + %2 = OpLoad %sampler %var + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("UniformConstant id '2' is missing DescriptorSet decoration.\n" + "From Vulkan spec, section 14.5.2:\n" + "These variables must have DescriptorSet and Binding " + "decorations specified")); +} + +TEST_F(ValidateDecorations, VulkanUniformConstantMissingBindingBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %var DescriptorSet 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %sampler = OpTypeSampler + %ptr = OpTypePointer UniformConstant %sampler + %var = OpVariable %ptr UniformConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + %2 = OpLoad %sampler %var + OpReturn + OpFunctionEnd +)"; + + 
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("UniformConstant id '2' is missing Binding decoration.\n" + "From Vulkan spec, section 14.5.2:\n" + "These variables must have DescriptorSet and Binding " + "decorations specified")); +} + +TEST_F(ValidateDecorations, VulkanStorageBufferMissingDescriptorSetBad) { + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpDecorate %var Binding 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer StorageBuffer %struct + %var = OpVariable %ptr StorageBuffer +%ptr_float = OpTypePointer StorageBuffer %float + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + %2 = OpAccessChain %ptr_float %var %int_0 + %3 = OpLoad %float %2 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("StorageBuffer id '3' is missing DescriptorSet decoration.\n" + "From Vulkan spec, section 14.5.2:\n" + "These variables must have DescriptorSet and Binding " + "decorations specified")); +} + +TEST_F(ValidateDecorations, VulkanStorageBufferMissingBindingBad) { + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpDecorate %var DescriptorSet 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = 
OpTypeStruct %float + %ptr = OpTypePointer StorageBuffer %struct + %var = OpVariable %ptr StorageBuffer +%ptr_float = OpTypePointer StorageBuffer %float + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + %2 = OpAccessChain %ptr_float %var %int_0 + %3 = OpLoad %float %2 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("StorageBuffer id '3' is missing Binding decoration.\n" + "From Vulkan spec, section 14.5.2:\n" + "These variables must have DescriptorSet and Binding " + "decorations specified")); +} + +TEST_F(ValidateDecorations, + VulkanStorageBufferMissingDescriptorSetSubFunctionBad) { + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpDecorate %var Binding 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer StorageBuffer %struct + %var = OpVariable %ptr StorageBuffer +%ptr_float = OpTypePointer StorageBuffer %float + %int = OpTypeInt 32 0 + %int_0 = OpConstant %int 0 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + %call = OpFunctionCall %void %2 + OpReturn + OpFunctionEnd + %2 = OpFunction %void None %voidfn + %label2 = OpLabel + %3 = OpAccessChain %ptr_float %var %int_0 + %4 = OpLoad %float %3 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("StorageBuffer id '3' is missing DescriptorSet decoration.\n" + "From Vulkan spec, section 14.5.2:\n" + "These variables must have 
DescriptorSet and Binding " + "decorations specified")); +} + +TEST_F(ValidateDecorations, + VulkanStorageBufferMissingDescriptorAndBindingUnusedGood) { + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct BufferBlock + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer StorageBuffer %struct + %var = OpVariable %ptr StorageBuffer + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateDecorations, UniformMissingDescriptorSetGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + OpDecorate %var Binding 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer Uniform %struct + %var = OpVariable %ptr Uniform + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, UniformMissingBindingGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + OpDecorate %var DescriptorSet 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer Uniform %struct + 
%var = OpVariable %ptr Uniform + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, UniformConstantMissingDescriptorSetGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %var Binding 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %sampler = OpTypeSampler + %ptr = OpTypePointer UniformConstant %sampler + %var = OpVariable %ptr UniformConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, UniformConstantMissingBindingGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %var DescriptorSet 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %sampler = OpTypeSampler + %ptr = OpTypePointer UniformConstant %sampler + %var = OpVariable %ptr UniformConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, StorageBufferMissingDescriptorSetGood) { + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct BufferBlock + OpDecorate %var Binding 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = 
OpTypePointer StorageBuffer %struct + %var = OpVariable %ptr StorageBuffer + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, StorageBufferMissingBindingGood) { + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct BufferBlock + OpDecorate %var DescriptorSet 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer StorageBuffer %struct + %var = OpVariable %ptr StorageBuffer + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, StorageBufferStorageClassArrayBaseAlignmentGood) { + // Spot check buffer rules when using StorageBuffer storage class with Block + // decoration. 
+  // Shader under test: struct %S { v2float; float[2] } decorated Block and
+  // used by a variable in the StorageBuffer storage class.  Member 1 (the
+  // array, ArrayStride 4) is placed at Offset 8.
+  // NOTE(review): the pointer type id is spelled %_ptr_Uniform_S although it
+  // is declared with the StorageBuffer storage class; the name is only a
+  // label inside the assembly text.
+  std::string spirv = R"(
+               OpCapability Shader
+               OpExtension "SPV_KHR_storage_buffer_storage_class"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %main "main"
+               OpSource GLSL 450
+               OpDecorate %_arr_float_uint_2 ArrayStride 4
+               OpMemberDecorate %S 0 Offset 0
+               OpMemberDecorate %S 1 Offset 8
+               OpDecorate %S Block
+               OpDecorate %u DescriptorSet 0
+               OpDecorate %u Binding 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v2float = OpTypeVector %float 2
+       %uint = OpTypeInt 32 0
+     %uint_2 = OpConstant %uint 2
+%_arr_float_uint_2 = OpTypeArray %float %uint_2
+          %S = OpTypeStruct %v2float %_arr_float_uint_2
+%_ptr_Uniform_S = OpTypePointer StorageBuffer %S
+          %u = OpVariable %_ptr_Uniform_S StorageBuffer
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  // Assemble, then validate without passing a target environment.  The
+  // layout above is expected to be accepted; if it is not, the validator's
+  // diagnostic text is streamed into the gtest failure message.
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState())
+      << getDiagnosticString();
+}
+
+TEST_F(ValidateDecorations, StorageBufferStorageClassArrayBadAlignmentBad) {
+  // Like the previous test, but with offset 7.
+ std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 7 + OpDecorate %S Block + OpDecorate %u DescriptorSet 0 + OpDecorate %u Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %S = OpTypeStruct %v2float %_arr_float_uint_2 +%_ptr_Uniform_S = OpTypePointer StorageBuffer %S + %u = OpVariable %_ptr_Uniform_S StorageBuffer + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in StorageBuffer " + "storage class must follow standard storage buffer layout rules: " + "member 1 at offset 7 is not aligned to 4")); +} + +TEST_F(ValidateDecorations, BufferBlockStandardStorageBufferLayout) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 4 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 48 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 24 + OpMemberDecorate %O 3 Offset 32 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 48 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 144 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + 
OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 52 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 64 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 96 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, + StorageBufferLayoutPermitsTightVec3ScalarPackingGood) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 12 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = 
OpTypeStruct %v3float %float +%_ptr_StorageBuffer_S = OpTypePointer StorageBuffer %S + %B = OpVariable %_ptr_StorageBuffer_S StorageBuffer + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, + StorageBufferLayoutForbidsTightScalarVec3PackingBad) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/1666 + std::string spirv = R"( + OpCapability Shader + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 0 + OpMemberDecorate %S 1 Offset 4 + OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_StorageBuffer_S = OpTypePointer StorageBuffer %S + %B = OpVariable %_ptr_StorageBuffer_S StorageBuffer + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 2 decorated as Block for variable in StorageBuffer " + "storage class must follow standard storage buffer layout " + "rules: member 1 at offset 4 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, + BlockStandardUniformBufferLayoutIncorrectOffset0Bad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 48 + 
OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 24 + OpMemberDecorate %O 3 Offset 33 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 80 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 176 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 64 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 96 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 128 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 6 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout " + "rules: member 2 at offset 24 is not aligned to 16")); +} + 
+TEST_F(ValidateDecorations, + BlockStandardUniformBufferLayoutIncorrectOffset1Bad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 48 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 32 + OpMemberDecorate %O 3 Offset 64 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 80 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 176 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 71 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 96 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 128 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = 
OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 8 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout " + "rules: member 5 at offset 71 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, BlockUniformBufferLayoutIncorrectArrayStrideBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %F 0 Offset 0 + OpMemberDecorate %F 1 Offset 8 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpDecorate %_arr_mat3v3float_uint_2 ArrayStride 49 + OpMemberDecorate %O 0 Offset 0 + OpMemberDecorate %O 1 Offset 16 + OpMemberDecorate %O 2 Offset 32 + OpMemberDecorate %O 3 Offset 64 + OpMemberDecorate %O 4 ColMajor + OpMemberDecorate %O 4 Offset 80 + OpMemberDecorate %O 4 MatrixStride 16 + OpDecorate %_arr_O_uint_2 ArrayStride 176 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpMemberDecorate %Output 2 Offset 16 + OpMemberDecorate %Output 3 Offset 32 + OpMemberDecorate %Output 4 Offset 48 + OpMemberDecorate %Output 5 Offset 64 + OpMemberDecorate %Output 6 ColMajor + OpMemberDecorate %Output 6 Offset 96 + OpMemberDecorate %Output 6 MatrixStride 16 + OpMemberDecorate %Output 7 Offset 128 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %F = OpTypeStruct %int %v2uint + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 +%mat2v3float = 
OpTypeMatrix %v3float 2 + %v3uint = OpTypeVector %uint 3 +%mat3v3float = OpTypeMatrix %v3float 3 +%_arr_mat3v3float_uint_2 = OpTypeArray %mat3v3float %uint_2 + %O = OpTypeStruct %v3uint %v2float %_arr_float_uint_2 %v2float %_arr_mat3v3float_uint_2 +%_arr_O_uint_2 = OpTypeArray %O %uint_2 + %Output = OpTypeStruct %float %v2float %v3float %F %float %_arr_float_uint_2 %mat2v3float %_arr_O_uint_2 +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 6 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 4 is " + "an array with stride 49 not satisfying alignment to 16")); +} + +TEST_F(ValidateDecorations, + BufferBlockStandardStorageBufferLayoutImproperStraddleBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 8 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %Output = OpTypeStruct %float %v3float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 3 decorated as BufferBlock for variable in " + "Uniform storage class must follow standard storage buffer " + 
"layout rules: member 1 at offset 8 is not aligned to 16")); +} + +TEST_F(ValidateDecorations, + BlockUniformBufferLayoutOffsetInsideArrayPaddingBad) { + // In this case the 2nd member fits entirely within the padding. + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpDecorate %_arr_float_uint_2 ArrayStride 16 + OpMemberDecorate %Output 0 Offset 0 + OpMemberDecorate %Output 1 Offset 20 + OpDecorate %Output Block + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %uint_2 = OpConstant %uint 2 +%_arr_float_uint_2 = OpTypeArray %float %uint_2 + %Output = OpTypeStruct %_arr_float_uint_2 %float +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 4 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 20 overlaps previous member ending at offset 31")); +} + +TEST_F(ValidateDecorations, + BlockUniformBufferLayoutOffsetInsideStructPaddingBad) { + // In this case the 2nd member fits entirely within the padding. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpMemberDecorate %_struct_6 0 Offset 0 + OpMemberDecorate %_struct_2 0 Offset 0 + OpMemberDecorate %_struct_2 1 Offset 4 + OpDecorate %_struct_2 Block + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 + %_struct_6 = OpTypeStruct %float + %_struct_2 = OpTypeStruct %_struct_6 %float +%_ptr_Uniform__struct_2 = OpTypePointer Uniform %_struct_2 + %8 = OpVariable %_ptr_Uniform__struct_2 Uniform + %1 = OpFunction %void None %4 + %9 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 4 overlaps previous member ending at offset 15")); +} + +TEST_F(ValidateDecorations, BlockLayoutOffsetOutOfOrderGoodUniversal1_0) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpMemberDecorate %Outer 0 Offset 4 + OpMemberDecorate %Outer 1 Offset 0 + OpDecorate %Outer Block + OpDecorate %O DescriptorSet 0 + OpDecorate %O Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %Outer = OpTypeStruct %uint %uint +%_ptr_Uniform_Outer = OpTypePointer Uniform %Outer + %O = OpVariable %_ptr_Uniform_Outer Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_UNIVERSAL_1_0)); +} + +TEST_F(ValidateDecorations, BlockLayoutOffsetOutOfOrderGoodOpenGL4_5) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + 
OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpMemberDecorate %Outer 0 Offset 4 + OpMemberDecorate %Outer 1 Offset 0 + OpDecorate %Outer Block + OpDecorate %O DescriptorSet 0 + OpDecorate %O Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %Outer = OpTypeStruct %uint %uint +%_ptr_Uniform_Outer = OpTypePointer Uniform %Outer + %O = OpVariable %_ptr_Uniform_Outer Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_OPENGL_4_5)); +} + +TEST_F(ValidateDecorations, BlockLayoutOffsetOutOfOrderGoodVulkan1_1) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpMemberDecorate %Outer 0 Offset 4 + OpMemberDecorate %Outer 1 Offset 0 + OpDecorate %Outer Block + OpDecorate %O DescriptorSet 0 + OpDecorate %O Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %Outer = OpTypeStruct %uint %uint +%_ptr_Uniform_Outer = OpTypePointer Uniform %Outer + %O = OpVariable %_ptr_Uniform_Outer Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)) + << getDiagnosticString(); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, BlockLayoutOffsetOverlapBad) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpMemberDecorate %Outer 0 Offset 0 + OpMemberDecorate %Outer 1 Offset 16 + OpMemberDecorate %Inner 0 Offset 0 + OpMemberDecorate %Inner 1 Offset 16 + 
OpDecorate %Outer Block + OpDecorate %O DescriptorSet 0 + OpDecorate %O Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %Inner = OpTypeStruct %uint %uint + %Outer = OpTypeStruct %Inner %uint +%_ptr_Uniform_Outer = OpTypePointer Uniform %Outer + %O = OpVariable %_ptr_Uniform_Outer Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 3 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 16 overlaps previous member ending at offset 31")); +} + +TEST_F(ValidateDecorations, BufferBlockEmptyStruct) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpMemberDecorate %Output 0 Offset 0 + OpDecorate %Output BufferBlock + %void = OpTypeVoid + %3 = OpTypeFunction %void + %S = OpTypeStruct + %Output = OpTypeStruct %S +%_ptr_Uniform_Output = OpTypePointer Uniform %Output + %dataOutput = OpVariable %_ptr_Uniform_Output Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, RowMajorMatrixTightPackingGood) { + // Row major matrix rule: + // A row-major matrix of C columns has a base alignment equal to + // the base alignment of a vector of C matrix components. + // Note: The "matrix component" is the scalar element type. + + // The matrix has 3 columns and 2 rows (C=3, R=2). + // So the base alignment of b is the same as a vector of 3 floats, which is 16 + // bytes. 
The matrix consists of two of these, and therefore occupies 2 x 16 + // bytes, or 32 bytes. + // + // So the offsets can be: + // a -> 0 + // b -> 16 + // c -> 48 + // d -> 60 ; d fits at bytes 12-15 after offset of c. Tight (vec3;float) + // packing + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpSource GLSL 450 + OpMemberDecorate %_struct_2 0 Offset 0 + OpMemberDecorate %_struct_2 1 RowMajor + OpMemberDecorate %_struct_2 1 Offset 16 + OpMemberDecorate %_struct_2 1 MatrixStride 16 + OpMemberDecorate %_struct_2 2 Offset 48 + OpMemberDecorate %_struct_2 3 Offset 60 + OpDecorate %_struct_2 Block + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%mat3v2float = OpTypeMatrix %v2float 3 + %v3float = OpTypeVector %float 3 + %_struct_2 = OpTypeStruct %v4float %mat3v2float %v3float %float +%_ptr_Uniform__struct_2 = OpTypePointer Uniform %_struct_2 + %3 = OpVariable %_ptr_Uniform__struct_2 Uniform + %1 = OpFunction %void None %5 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, ArrayArrayRowMajorMatrixTightPackingGood) { + // Like the previous case, but we have an array of arrays of matrices. + // The RowMajor decoration goes on the struct member (surprisingly). 
+ + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpSource GLSL 450 + OpMemberDecorate %_struct_2 0 Offset 0 + OpMemberDecorate %_struct_2 1 RowMajor + OpMemberDecorate %_struct_2 1 Offset 16 + OpMemberDecorate %_struct_2 1 MatrixStride 16 + OpMemberDecorate %_struct_2 2 Offset 80 + OpMemberDecorate %_struct_2 3 Offset 92 + OpDecorate %arr_mat ArrayStride 32 + OpDecorate %arr_arr_mat ArrayStride 32 + OpDecorate %_struct_2 Block + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%mat3v2float = OpTypeMatrix %v2float 3 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 + %arr_mat = OpTypeArray %mat3v2float %uint_1 +%arr_arr_mat = OpTypeArray %arr_mat %uint_2 + %v3float = OpTypeVector %float 3 + %_struct_2 = OpTypeStruct %v4float %arr_arr_mat %v3float %float +%_ptr_Uniform__struct_2 = OpTypePointer Uniform %_struct_2 + %3 = OpVariable %_ptr_Uniform__struct_2 Uniform + %1 = OpFunction %void None %5 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()) + << getDiagnosticString(); +} + +TEST_F(ValidateDecorations, ArrayArrayRowMajorMatrixNextMemberOverlapsBad) { + // Like the previous case, but the offset of member 2 overlaps the matrix. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpSource GLSL 450 + OpMemberDecorate %_struct_2 0 Offset 0 + OpMemberDecorate %_struct_2 1 RowMajor + OpMemberDecorate %_struct_2 1 Offset 16 + OpMemberDecorate %_struct_2 1 MatrixStride 16 + OpMemberDecorate %_struct_2 2 Offset 64 + OpMemberDecorate %_struct_2 3 Offset 92 + OpDecorate %arr_mat ArrayStride 32 + OpDecorate %arr_arr_mat ArrayStride 32 + OpDecorate %_struct_2 Block + OpDecorate %3 DescriptorSet 0 + OpDecorate %3 Binding 0 + %void = OpTypeVoid + %5 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 +%mat3v2float = OpTypeMatrix %v2float 3 +%uint = OpTypeInt 32 0 +%uint_1 = OpConstant %uint 1 +%uint_2 = OpConstant %uint 2 + %arr_mat = OpTypeArray %mat3v2float %uint_1 +%arr_arr_mat = OpTypeArray %arr_mat %uint_2 + %v3float = OpTypeVector %float 3 + %_struct_2 = OpTypeStruct %v4float %arr_arr_mat %v3float %float +%_ptr_Uniform__struct_2 = OpTypePointer Uniform %_struct_2 + %3 = OpVariable %_ptr_Uniform__struct_2 Uniform + %1 = OpFunction %void None %5 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 2 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 2 at " + "offset 64 overlaps previous member ending at offset 79")); +} + +TEST_F(ValidateDecorations, StorageBufferArraySizeCalculationPackGood) { + // Original GLSL + + // #version 450 + // layout (set=0,binding=0) buffer S { + // uvec3 arr[2][2]; // first 3 elements are 16 bytes, last is 12 + // uint i; // Can have offset 60 = 3x16 + 12 + // } B; + // void main() {} + + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpDecorate 
%_arr_v3uint_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v3uint_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 60 + OpDecorate %_struct_4 BufferBlock + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 +%_arr_v3uint_uint_2 = OpTypeArray %v3uint %uint_2 +%_arr__arr_v3uint_uint_2_uint_2 = OpTypeArray %_arr_v3uint_uint_2 %uint_2 + %_struct_4 = OpTypeStruct %_arr__arr_v3uint_uint_2_uint_2 %uint +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %5 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 = OpFunction %void None %7 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, StorageBufferArraySizeCalculationPackBad) { + // Like previous but, the offset of the second member is too small. 
+ + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpDecorate %_arr_v3uint_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v3uint_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 56 + OpDecorate %_struct_4 BufferBlock + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 +%_arr_v3uint_uint_2 = OpTypeArray %v3uint %uint_2 +%_arr__arr_v3uint_uint_2_uint_2 = OpTypeArray %_arr_v3uint_uint_2 %uint_2 + %_struct_4 = OpTypeStruct %_arr__arr_v3uint_uint_2_uint_2 %uint +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %5 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 = OpFunction %void None %7 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Structure id 4 decorated as BufferBlock for variable " + "in Uniform storage class must follow standard storage " + "buffer layout rules: member 1 at offset 56 overlaps " + "previous member ending at offset 59")); +} + +TEST_F(ValidateDecorations, UniformBufferArraySizeCalculationPackGood) { + // Like the corresponding buffer block case, but the array padding must + // count for the last element as well, and so the offset of the second + // member must be at least 64. 
+ std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpDecorate %_arr_v3uint_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v3uint_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 64 + OpDecorate %_struct_4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 +%_arr_v3uint_uint_2 = OpTypeArray %v3uint %uint_2 +%_arr__arr_v3uint_uint_2_uint_2 = OpTypeArray %_arr_v3uint_uint_2 %uint_2 + %_struct_4 = OpTypeStruct %_arr__arr_v3uint_uint_2_uint_2 %uint +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %5 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 = OpFunction %void None %7 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, UniformBufferArraySizeCalculationPackBad) { + // Like previous but, the offset of the second member is too small. 
+ + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %1 "main" + OpDecorate %_arr_v3uint_uint_2 ArrayStride 16 + OpDecorate %_arr__arr_v3uint_uint_2_uint_2 ArrayStride 32 + OpMemberDecorate %_struct_4 0 Offset 0 + OpMemberDecorate %_struct_4 1 Offset 60 + OpDecorate %_struct_4 Block + OpDecorate %5 DescriptorSet 0 + OpDecorate %5 Binding 0 + %void = OpTypeVoid + %7 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 +%_arr_v3uint_uint_2 = OpTypeArray %v3uint %uint_2 +%_arr__arr_v3uint_uint_2_uint_2 = OpTypeArray %_arr_v3uint_uint_2 %uint_2 + %_struct_4 = OpTypeStruct %_arr__arr_v3uint_uint_2_uint_2 %uint +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %5 = OpVariable %_ptr_Uniform__struct_4 Uniform + %1 = OpFunction %void None %7 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure id 4 decorated as Block for variable in Uniform storage " + "class must follow standard uniform buffer layout rules: member 1 at " + "offset 60 overlaps previous member ending at offset 63")); +} + +TEST_F(ValidateDecorations, LayoutNotCheckedWhenSkipBlockLayout) { + // Checks that block layout is not verified in skipping block layout mode. + // Even for obviously wrong layout. + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 450 + OpMemberDecorate %S 0 Offset 3 ; wrong alignment + OpMemberDecorate %S 1 Offset 3 ; same offset as before! 
+ OpDecorate %S Block + OpDecorate %B DescriptorSet 0 + OpDecorate %B Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %S = OpTypeStruct %float %v3float +%_ptr_Uniform_S = OpTypePointer Uniform %S + %B = OpVariable %_ptr_Uniform_S Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetSkipBlockLayout(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_0)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, EntryPointVariableWrongStorageClass) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" %var +OpExecutionMode %1 OriginUpperLeft +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%ptr_int_Workgroup = OpTypePointer Workgroup %int +%var = OpVariable %ptr_int_Workgroup Workgroup +%func_ty = OpTypeFunction %void +%1 = OpFunction %void None %func_ty +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpEntryPoint interfaces must be OpVariables with " + "Storage Class of Input(1) or Output(3). 
Found Storage " + "Class 4 for Entry Point id 1.")); +} + +TEST_F(ValidateDecorations, VulkanMemoryModelNonCoherent) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical VulkanKHR +OpDecorate %1 Coherent +%2 = OpTypeInt 32 0 +%3 = OpTypePointer StorageBuffer %2 +%1 = OpVariable %3 StorageBuffer +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Coherent decoration targeting 1[%1] is " + "banned when using the Vulkan memory model.")); +} + +TEST_F(ValidateDecorations, VulkanMemoryModelNoCoherentMember) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpMemberDecorate %1 0 Coherent +%2 = OpTypeInt 32 0 +%1 = OpTypeStruct %2 %2 +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Coherent decoration targeting 1[%_struct_1] (member index 0) " + "is banned when using the Vulkan memory model.")); +} + +TEST_F(ValidateDecorations, VulkanMemoryModelNoVolatile) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical VulkanKHR +OpDecorate %1 Volatile +%2 = OpTypeInt 32 0 +%3 = OpTypePointer StorageBuffer %2 +%1 = OpVariable %3 StorageBuffer +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("Volatile decoration targeting 1[%1] is banned when " + "using the Vulkan memory model.")); +} + +TEST_F(ValidateDecorations, VulkanMemoryModelNoVolatileMember) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpMemberDecorate %1 1 Volatile +%2 = OpTypeInt 32 0 +%1 = OpTypeStruct %2 %2 +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Volatile decoration targeting 1[%_struct_1] (member " + "index 1) is banned when using the Vulkan memory " + "model.")); +} + +TEST_F(ValidateDecorations, FPRoundingModeGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +%half = OpTypeFloat 16 +%float = OpTypeFloat 32 +%float_1_25 = OpConstant %float 1.25 +%half_ptr = OpTypePointer StorageBuffer %half +%half_ptr_var = OpVariable %half_ptr StorageBuffer +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpFConvert %half %float_1_25 +OpStore %half_ptr_var %_ +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, FPRoundingModeVectorGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" 
+OpDecorate %_ FPRoundingMode RTE +%half = OpTypeFloat 16 +%float = OpTypeFloat 32 +%v2half = OpTypeVector %half 2 +%v2float = OpTypeVector %float 2 +%float_1_25 = OpConstant %float 1.25 +%floats = OpConstantComposite %v2float %float_1_25 %float_1_25 +%halfs_ptr = OpTypePointer StorageBuffer %v2half +%halfs_ptr_var = OpVariable %halfs_ptr StorageBuffer +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpFConvert %v2half %floats +OpStore %halfs_ptr_var %_ +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, FPRoundingModeNotOpFConvert) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +%short = OpTypeInt 16 1 +%int = OpTypeInt 32 1 +%int_17 = OpConstant %int 17 +%short_ptr = OpTypePointer StorageBuffer %short +%short_ptr_var = OpVariable %short_ptr StorageBuffer +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpSConvert %short %int_17 +OpStore %short_ptr_var %_ +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("FPRoundingMode decoration can be applied only to a " + "width-only conversion instruction for floating-point " + "object.")); +} + +TEST_F(ValidateDecorations, FPRoundingModeNoOpStoreGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension 
"SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +%half = OpTypeFloat 16 +%float = OpTypeFloat 32 +%float_1_25 = OpConstant %float 1.25 +%half_ptr = OpTypePointer StorageBuffer %half +%half_ptr_var = OpVariable %half_ptr StorageBuffer +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpFConvert %half %float_1_25 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, FPRoundingModeFConvert64to16Good) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess +OpCapability Float64 +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +%half = OpTypeFloat 16 +%double = OpTypeFloat 64 +%double_1_25 = OpConstant %double 1.25 +%half_ptr = OpTypePointer StorageBuffer %half +%half_ptr_var = OpVariable %half_ptr StorageBuffer +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpFConvert %half %double_1_25 +OpStore %half_ptr_var %_ +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, FPRoundingModeNotStoreInFloat16) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess +OpCapability Float64 +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +%float = OpTypeFloat 32 +%double = OpTypeFloat 64 +%double_1_25 
= OpConstant %double 1.25 +%float_ptr = OpTypePointer StorageBuffer %float +%float_ptr_var = OpVariable %float_ptr StorageBuffer +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpFConvert %float %double_1_25 +OpStore %float_ptr_var %_ +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("FPRoundingMode decoration can be applied only to the " + "Object operand of an OpStore storing through a " + "pointer to a 16-bit floating-point scalar or vector object.")); +} + +TEST_F(ValidateDecorations, FPRoundingModeBadStorageClass) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +%half = OpTypeFloat 16 +%float = OpTypeFloat 32 +%float_1_25 = OpConstant %float 1.25 +%half_ptr = OpTypePointer Private %half +%half_ptr_var = OpVariable %half_ptr Private +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpFConvert %half %float_1_25 +OpStore %half_ptr_var %_ +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("FPRoundingMode decoration can be applied only to the " + "Object operand of an OpStore in the StorageBuffer, " + "PhysicalStorageBufferEXT, Uniform, " + "PushConstant, Input, or Output Storage Classes.")); +} + +TEST_F(ValidateDecorations, FPRoundingModeMultipleOpStoreGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess 
+OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +%half = OpTypeFloat 16 +%float = OpTypeFloat 32 +%float_1_25 = OpConstant %float 1.25 +%half_ptr = OpTypePointer StorageBuffer %half +%half_ptr_var_0 = OpVariable %half_ptr StorageBuffer +%half_ptr_var_1 = OpVariable %half_ptr StorageBuffer +%half_ptr_var_2 = OpVariable %half_ptr StorageBuffer +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpFConvert %half %float_1_25 +OpStore %half_ptr_var_0 %_ +OpStore %half_ptr_var_1 %_ +OpStore %half_ptr_var_2 %_ +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateDecorations, FPRoundingModeMultipleUsesBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +%half = OpTypeFloat 16 +%float = OpTypeFloat 32 +%float_1_25 = OpConstant %float 1.25 +%half_ptr = OpTypePointer StorageBuffer %half +%half_ptr_var_0 = OpVariable %half_ptr StorageBuffer +%half_ptr_var_1 = OpVariable %half_ptr StorageBuffer +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%_ = OpFConvert %half %float_1_25 +OpStore %half_ptr_var_0 %_ +%result = OpFAdd %half %_ %_ +OpStore %half_ptr_var_1 %_ +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("FPRoundingMode decoration can be applied only to 
the " + "Object operand of an OpStore.")); +} + +TEST_F(ValidateDecorations, GroupDecorateTargetsDecorationGroup) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpDecorationGroup +OpGroupDecorate %1 %1 +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpGroupDecorate may not target OpDecorationGroup " + "'1[%1]'")); +} + +TEST_F(ValidateDecorations, GroupDecorateTargetsDecorationGroup2) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%1 = OpDecorationGroup +OpGroupDecorate %1 %2 %1 +%2 = OpTypeVoid +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpGroupDecorate may not target OpDecorationGroup " + "'1[%1]'")); +} + +TEST_F(ValidateDecorations, RecurseThroughRuntimeArray) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +OpDecorate %outer Block +OpMemberDecorate %inner 0 Offset 0 +OpMemberDecorate %inner 1 Offset 1 +OpDecorate %runtime ArrayStride 16 +OpMemberDecorate %outer 0 Offset 0 +%int = OpTypeInt 32 0 +%inner = OpTypeStruct %int %int +%runtime = OpTypeRuntimeArray %inner +%outer = OpTypeStruct %runtime +%outer_ptr = OpTypePointer Uniform %outer +%var = OpVariable %outer_ptr Uniform +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Structure id 2 decorated as Block for variable in Uniform " + "storage class must follow standard uniform buffer layout " + "rules: member 1 at offset 1 is not aligned to 4")); +} + +TEST_F(ValidateDecorations, EmptyStructAtNonZeroOffsetGood) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" 
+OpExecutionMode %main LocalSize 1 1 1 +OpDecorate %struct Block +OpMemberDecorate %struct 0 Offset 0 +OpMemberDecorate %struct 1 Offset 16 +OpDecorate %var DescriptorSet 0 +OpDecorate %var Binding 0 +%void = OpTypeVoid +%float = OpTypeFloat 32 +%empty = OpTypeStruct +%struct = OpTypeStruct %float %empty +%ptr_struct_ubo = OpTypePointer Uniform %struct +%var = OpVariable %ptr_struct_ubo Uniform +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Uniform decoration + +TEST_F(ValidateDecorations, UniformDecorationGood) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpDecorate %int0 Uniform +OpDecorate %var Uniform +OpDecorate %val Uniform +%void = OpTypeVoid +%int = OpTypeInt 32 1 +%int0 = OpConstantNull %int +%intptr = OpTypePointer Private %int +%var = OpVariable %intptr Private +%fn = OpTypeFunction %void +%main = OpFunction %void None %fn +%entry = OpLabel +%val = OpLoad %int %var +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, UniformDecorationTargetsTypeBad) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpDecorate %fn Uniform +%void = OpTypeVoid +%fn = OpTypeFunction %void +%main = OpFunction %void None %fn +%entry = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Uniform decoration applied to a non-object")); + EXPECT_THAT(getDiagnosticString(), HasSubstr("%2 = OpTypeFunction %void")); +} + 
+TEST_F(ValidateDecorations, UniformDecorationTargetsVoidValueBad) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +OpName %call "call" +OpName %myfunc "myfunc" +OpDecorate %call Uniform +%void = OpTypeVoid +%fnty = OpTypeFunction %void +%myfunc = OpFunction %void None %fnty +%myfuncentry = OpLabel +OpReturn +OpFunctionEnd +%main = OpFunction %void None %fnty +%entry = OpLabel +%call = OpFunctionCall %void %myfunc +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Uniform decoration applied to a value with void type\n" + " %call = OpFunctionCall %void %myfunc")); +} + +TEST_F(ValidateDecorations, MultipleOffsetDecorationsOnSameID) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID '2', member '0' decorated with Offset multiple " + "times is not allowed.")); +} + +TEST_F(ValidateDecorations, MultipleArrayStrideDecorationsOnSameID) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %array ArrayStride 4 + OpDecorate %array ArrayStride 4 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %uint_4 = OpConstant %uint 4 + %array = 
OpTypeArray %float %uint_4 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID '2' decorated with ArrayStride multiple " + "times is not allowed.")); +} + +TEST_F(ValidateDecorations, MultipleMatrixStrideDecorationsOnSameID) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 0 ColMajor + OpMemberDecorate %struct 0 MatrixStride 16 + OpMemberDecorate %struct 0 MatrixStride 16 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %fvec4 = OpTypeVector %float 4 + %fmat4 = OpTypeMatrix %fvec4 4 + %struct = OpTypeStruct %fmat4 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID '2', member '0' decorated with MatrixStride " + "multiple times is not allowed.")); +} + +TEST_F(ValidateDecorations, MultipleRowMajorDecorationsOnSameID) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 0 MatrixStride 16 + OpMemberDecorate %struct 0 RowMajor + OpMemberDecorate %struct 0 RowMajor + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %fvec4 = OpTypeVector %float 4 + %fmat4 = OpTypeMatrix %fvec4 4 + %struct = OpTypeStruct %fmat4 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID '2', member '0' decorated with RowMajor multiple " + "times is not allowed.")); +} + +TEST_F(ValidateDecorations, MultipleColMajorDecorationsOnSameID) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 0 MatrixStride 16 + OpMemberDecorate %struct 0 ColMajor + OpMemberDecorate %struct 0 ColMajor + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %fvec4 = OpTypeVector %float 4 + %fmat4 = OpTypeMatrix %fvec4 4 + %struct = OpTypeStruct %fmat4 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID '2', member '0' decorated with ColMajor multiple " + "times is not allowed.")); +} + +TEST_F(ValidateDecorations, RowMajorAndColMajorDecorationsOnSameID) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 0 MatrixStride 16 + OpMemberDecorate %struct 0 ColMajor + OpMemberDecorate %struct 0 RowMajor + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %fvec4 = OpTypeVector %float 4 + %fmat4 = OpTypeMatrix %fvec4 4 + %struct = OpTypeStruct %fmat4 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID '2', member '0' decorated with both RowMajor and " + "ColMajor is not allowed.")); +} + +TEST_F(ValidateDecorations, 
BlockAndBufferBlockDecorationsOnSameID) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpDecorate %struct BufferBlock + OpMemberDecorate %struct 0 Offset 0 + OpMemberDecorate %struct 0 MatrixStride 16 + OpMemberDecorate %struct 0 RowMajor + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %fvec4 = OpTypeVector %float 4 + %fmat4 = OpTypeMatrix %fvec4 4 + %struct = OpTypeStruct %fmat4 + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateAndRetrieveValidationState()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "ID '2' decorated with both BufferBlock and Block is not allowed.")); +} + +std::string MakeIntegerShader( + const std::string& decoration, const std::string& inst, + const std::string& extension = + "OpExtension \"SPV_KHR_no_integer_wrap_decoration\"") { + return R"( +OpCapability Shader +OpCapability Linkage +)" + extension + + R"( +%glsl = OpExtInstImport "GLSL.std.450" +%opencl = OpExtInstImport "OpenCL.std" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpName %entry "entry" +)" + decoration + + R"( + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %int = OpTypeInt 32 1 + %zero = OpConstantNull %int + %float = OpTypeFloat 32 + %float0 = OpConstantNull %float + %main = OpFunction %void None %voidfn + %entry = OpLabel +)" + inst + + R"( +OpReturn +OpFunctionEnd)"; +} + +// NoSignedWrap + +TEST_F(ValidateDecorations, NoSignedWrapOnTypeBad) { + std::string spirv = MakeIntegerShader("OpDecorate %void NoSignedWrap", ""); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("NoSignedWrap decoration may not be applied to TypeVoid")); +} + 
+TEST_F(ValidateDecorations, NoSignedWrapOnLabelBad) { + std::string spirv = MakeIntegerShader("OpDecorate %entry NoSignedWrap", ""); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NoSignedWrap decoration may not be applied to Label")); +} + +TEST_F(ValidateDecorations, NoSignedWrapRequiresExtensionBad) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpIAdd %int %zero %zero", ""); + + CompileSuccessfully(spirv); + EXPECT_NE(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("requires one of these extensions: " + "SPV_KHR_no_integer_wrap_decoration")); +} + +TEST_F(ValidateDecorations, NoSignedWrapIAddGood) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpIAdd %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoSignedWrapISubGood) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpISub %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoSignedWrapIMulGood) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpIMul %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoSignedWrapShiftLeftLogicalGood) { + std::string spirv = + MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpShiftLeftLogical %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoSignedWrapSNegateGood) { + 
std::string spirv = MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpSNegate %int %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoSignedWrapSRemBad) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpSRem %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NoSignedWrap decoration may not be applied to SRem")); +} + +TEST_F(ValidateDecorations, NoSignedWrapFAddBad) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpFAdd %float %float0 %float0"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NoSignedWrap decoration may not be applied to FAdd")); +} + +TEST_F(ValidateDecorations, NoSignedWrapExtInstOpenCLGood) { + std::string spirv = + MakeIntegerShader("OpDecorate %val NoSignedWrap", + "%val = OpExtInst %int %opencl s_abs %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoSignedWrapExtInstGLSLGood) { + std::string spirv = MakeIntegerShader( + "OpDecorate %val NoSignedWrap", "%val = OpExtInst %int %glsl SAbs %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +// TODO(dneto): For NoSignedWrap and NoUnsignedWrap, permit +// "OpExtInst for instruction numbers specified in the extended +// instruction-set specifications as accepting this decoration." 
+ +// NoUnsignedWrap + +TEST_F(ValidateDecorations, NoUnsignedWrapOnTypeBad) { + std::string spirv = MakeIntegerShader("OpDecorate %void NoUnsignedWrap", ""); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("NoUnsignedWrap decoration may not be applied to TypeVoid")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapOnLabelBad) { + std::string spirv = MakeIntegerShader("OpDecorate %entry NoUnsignedWrap", ""); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("NoUnsignedWrap decoration may not be applied to Label")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapRequiresExtensionBad) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpIAdd %int %zero %zero", ""); + + CompileSuccessfully(spirv); + EXPECT_NE(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("requires one of these extensions: " + "SPV_KHR_no_integer_wrap_decoration")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapIAddGood) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpIAdd %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapISubGood) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpISub %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapIMulGood) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpIMul %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), 
Eq("")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapShiftLeftLogicalGood) { + std::string spirv = + MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpShiftLeftLogical %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapSNegateGood) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpSNegate %int %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapSRemBad) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpSRem %int %zero %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("NoUnsignedWrap decoration may not be applied to SRem")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapFAddBad) { + std::string spirv = MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpFAdd %float %float0 %float0"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("NoUnsignedWrap decoration may not be applied to FAdd")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapExtInstOpenCLGood) { + std::string spirv = + MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpExtInst %int %opencl s_abs %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateDecorations, NoUnsignedWrapExtInstGLSLGood) { + std::string spirv = + MakeIntegerShader("OpDecorate %val NoUnsignedWrap", + "%val = OpExtInst %int %glsl SAbs %zero"); + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +// TODO(dneto): For NoSignedWrap and NoUnsignedWrap, permit +// "OpExtInst for instruction numbers specified in the extended +// instruction-set specifications as accepting this decoration." + +TEST_F(ValidateDecorations, PSBAliasedRestrictPointerSuccess) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 RestrictPointerEXT +%uint64 = OpTypeInt 64 0 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateDecorations, PSBAliasedRestrictPointerMissing) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%uint64 = OpTypeInt 64 0 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("expected AliasedPointerEXT or RestrictPointerEXT for " + "PhysicalStorageBufferEXT pointer")); +} + +TEST_F(ValidateDecorations, PSBAliasedRestrictPointerBoth) { + const 
std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 RestrictPointerEXT +OpDecorate %val1 AliasedPointerEXT +%uint64 = OpTypeInt 64 0 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("can't specify both AliasedPointerEXT and RestrictPointerEXT " + "for PhysicalStorageBufferEXT pointer")); +} + +TEST_F(ValidateDecorations, PSBAliasedRestrictFunctionParamSuccess) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %fparam Restrict +%uint64 = OpTypeInt 64 0 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%fnptr = OpTypeFunction %void %ptr +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd +%fn = OpFunction %void None %fnptr +%fparam = OpFunctionParameter %ptr +%lab = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateDecorations, PSBAliasedRestrictFunctionParamMissing) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension 
"SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%uint64 = OpTypeInt 64 0 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%fnptr = OpTypeFunction %void %ptr +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd +%fn = OpFunction %void None %fnptr +%fparam = OpFunctionParameter %ptr +%lab = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("expected Aliased or Restrict for " + "PhysicalStorageBufferEXT pointer")); +} + +TEST_F(ValidateDecorations, PSBAliasedRestrictFunctionParamBoth) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %fparam Restrict +OpDecorate %fparam Aliased +%uint64 = OpTypeInt 64 0 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%fnptr = OpTypeFunction %void %ptr +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd +%fn = OpFunction %void None %fnptr +%fparam = OpFunctionParameter %ptr +%lab = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("can't specify both Aliased and Restrict for " + "PhysicalStorageBufferEXT pointer")); +} + +TEST_F(ValidateDecorations, PSBFPRoundingModeSuccess) { + std::string spirv = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Shader +OpCapability Linkage +OpCapability StorageBuffer16BitAccess 
+OpExtension "SPV_EXT_physical_storage_buffer" +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpExtension "SPV_KHR_variable_pointers" +OpExtension "SPV_KHR_16bit_storage" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %_ FPRoundingMode RTE +OpDecorate %half_ptr_var AliasedPointerEXT +%half = OpTypeFloat 16 +%float = OpTypeFloat 32 +%float_1_25 = OpConstant %float 1.25 +%half_ptr = OpTypePointer PhysicalStorageBufferEXT %half +%half_pptr_f = OpTypePointer Function %half_ptr +%void = OpTypeVoid +%func = OpTypeFunction %void +%main = OpFunction %void None %func +%main_entry = OpLabel +%half_ptr_var = OpVariable %half_pptr_f Function +%val1 = OpLoad %half_ptr %half_ptr_var +%_ = OpFConvert %half %float_1_25 +OpStore %val1 %_ Aligned 2 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_derivatives_test.cpp b/test/val/val_derivatives_test.cpp new file mode 100644 index 000000000..480042a7d --- /dev/null +++ b/test/val/val_derivatives_test.cpp @@ -0,0 +1,156 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Not;
+
+// Fixture shared by every derivative-instruction validation test below.
+using ValidateDerivatives = spvtest::ValidateBase<bool>;
+
+// Returns SPIR-V assembly for a complete module whose %main function contains
+// |body|.
+//
+// |capabilities_and_extensions| is emitted verbatim right after the fixed
+// OpCapability lines. |execution_model| is the execution model named in
+// OpEntryPoint; the OriginUpperLeft execution mode is declared only when it is
+// "Fragment". The module always declares the Input variables %f32_var_input
+// and %f32vec4_var_input, which |body| may load from.
+std::string GenerateShaderCode(
+    const std::string& body,
+    const std::string& capabilities_and_extensions = "",
+    const std::string& execution_model = "Fragment") {
+  std::stringstream ss;
+  ss << R"(
+OpCapability Shader
+OpCapability DerivativeControl
+)";
+
+  ss << capabilities_and_extensions;
+  ss << "OpMemoryModel Logical GLSL450\n";
+  ss << "OpEntryPoint " << execution_model << " %main \"main\""
+     << " %f32_var_input"
+     << " %f32vec4_var_input"
+     << "\n";
+  if (execution_model == "Fragment") {
+    ss << "OpExecutionMode %main OriginUpperLeft\n";
+  }
+
+  ss << R"(
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+%f32vec4 = OpTypeVector %f32 4
+
+%f32_ptr_input = OpTypePointer Input %f32
+%f32_var_input = OpVariable %f32_ptr_input Input
+
+%f32vec4_ptr_input = OpTypePointer Input %f32vec4
+%f32vec4_var_input = OpVariable %f32vec4_ptr_input Input
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+)";
+
+  ss << body;
+
+  ss << R"(
+OpReturn
+OpFunctionEnd)";
+
+  return ss.str();
+}
+
+// All nine derivative opcodes applied to a scalar float must validate.
+TEST_F(ValidateDerivatives, ScalarSuccess) {
+  const std::string body = R"(
+%f32_var = OpLoad %f32 %f32_var_input
+%val1 = OpDPdx %f32 %f32_var
+%val2 = OpDPdy %f32 %f32_var
+%val3 = OpFwidth %f32 %f32_var
+%val4 = OpDPdxFine %f32 %f32_var
+%val5 = OpDPdyFine %f32 %f32_var
+%val6 = OpFwidthFine %f32 %f32_var
+%val7 = OpDPdxCoarse %f32 %f32_var
+%val8 = OpDPdyCoarse %f32 %f32_var
+%val9 = OpFwidthCoarse %f32 %f32_var
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Same as ScalarSuccess, but with a 4-component float vector operand.
+TEST_F(ValidateDerivatives, VectorSuccess) {
+  const std::string body = R"(
+%f32vec4_var = OpLoad %f32vec4 %f32vec4_var_input
+%val1 = OpDPdx %f32vec4 %f32vec4_var
+%val2 = OpDPdy %f32vec4 %f32vec4_var
+%val3 = OpFwidth %f32vec4 %f32vec4_var
+%val4 = OpDPdxFine %f32vec4 %f32vec4_var
+%val5 = OpDPdyFine %f32vec4 %f32vec4_var
+%val6 = OpFwidthFine %f32vec4 %f32vec4_var
+%val7 = OpDPdxCoarse %f32vec4 %f32vec4_var
+%val8 = OpDPdyCoarse %f32vec4 %f32vec4_var
+%val9 = OpFwidthCoarse %f32vec4 %f32vec4_var
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// The P operand here is the type id %f32vec4 rather than a value, so the
+// diagnostic checked below is the "cannot be a type" operand error.
+TEST_F(ValidateDerivatives, OpDPdxWrongResultType) {
+  const std::string body = R"(
+%f32_var = OpLoad %f32 %f32_var_input
+%val1 = OpDPdx %u32 %f32vec4
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 10[%v4float] cannot "
+                                               "be a type"));
+}
+
+// Result type (%f32) differs from the type of P (%f32vec4).
+TEST_F(ValidateDerivatives, OpDPdxWrongPType) {
+  const std::string body = R"(
+%f32vec4_var = OpLoad %f32vec4 %f32vec4_var_input
+%val1 = OpDPdx %f32 %f32vec4_var
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Expected P type and Result Type to be the same: "
+                        "DPdx"));
+}
+
+// A derivative instruction placed in a Vertex entry point must be rejected.
+TEST_F(ValidateDerivatives, OpDPdxWrongExecutionModel) {
+  const std::string body = R"(
+%f32vec4_var = OpLoad %f32vec4 %f32vec4_var_input
+%val1 = OpDPdx %f32vec4 %f32vec4_var
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Derivative instructions require Fragment execution model: DPdx"));
+}
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/val_explicit_reserved_test.cpp b/test/val/val_explicit_reserved_test.cpp
new file mode 100644
index 000000000..f01e933fa
--- /dev/null
+++ b/test/val/val_explicit_reserved_test.cpp
@@ -0,0 +1,122 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validation tests for illegal instructions
+
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::Eq;
+using ::testing::HasSubstr;
+
+// Fixture for the reserved sparse-sampling opcode tests below.
+using ReservedSamplingInstTest = spvtest::ValidateBase<std::string>;
+
+// Generate a shader for use with validation tests for sparse sampling
+// instructions.
+//
+// |instruction_under_test| is spliced into the body of the entry point right
+// after the OpLoad that produces the sampled image %17. The module also
+// provides %13 (a v2float constant), %float_0 and %_struct_15 (a struct of
+// int and v4float) for the instruction to reference.
+std::string ShaderAssembly(const std::string& instruction_under_test) {
+  std::ostringstream os;
+  os << R"(    OpCapability Shader
+               OpCapability SparseResidency
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "main"
+               OpExecutionMode %1 OriginUpperLeft
+               OpSource GLSL 450
+               OpDecorate %2 DescriptorSet 0
+               OpDecorate %2 Binding 0
+       %void = OpTypeVoid
+          %4 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %float_0 = OpConstant %float 0
+          %8 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+          %9 = OpTypeImage %float 2D 0 0 0 1 Unknown
+         %10 = OpTypeSampledImage %9
+%_ptr_UniformConstant_10 = OpTypePointer UniformConstant %10
+          %2 = OpVariable %_ptr_UniformConstant_10 UniformConstant
+    %v2float = OpTypeVector %float 2
+         %13 = OpConstantComposite %v2float %float_0 %float_0
+        %int = OpTypeInt 32 1
+ %_struct_15 = OpTypeStruct %int %v4float
+          %1 = OpFunction %void None %4
+         %16 = OpLabel
+         %17 = OpLoad %10 %2
+)" << instruction_under_test
+     << R"(
+               OpReturn
+               OpFunctionEnd
+)";
+
+  return os.str();
+}
+
+// Each test below assembles one OpImageSparseSampleProj* instruction and
+// expects validation to fail with SPV_ERROR_INVALID_BINARY and an
+// "Invalid Opcode name" diagnostic naming that opcode.
+
+TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjImplicitLod) {
+  const std::string input = ShaderAssembly(
+      "%result = OpImageSparseSampleProjImplicitLod %_struct_15 %17 %13");
+  CompileSuccessfully(input);
+
+  EXPECT_THAT(ValidateInstructions(), Eq(SPV_ERROR_INVALID_BINARY));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Invalid Opcode name 'OpImageSparseSampleProjImplicitLod'"));
+}
+
+TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjExplicitLod) {
+  const std::string input = ShaderAssembly(
+      "%result = OpImageSparseSampleProjExplicitLod %_struct_15 %17 %13 Lod "
+      "%float_0\n");
+  CompileSuccessfully(input);
+
+  EXPECT_THAT(ValidateInstructions(), Eq(SPV_ERROR_INVALID_BINARY));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Invalid Opcode name 'OpImageSparseSampleProjExplicitLod'"));
+}
+
+TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjDrefImplicitLod) {
+  const std::string input = ShaderAssembly(
+      "%result = OpImageSparseSampleProjDrefImplicitLod %_struct_15 %17 %13 "
+      "%float_0\n");
+  CompileSuccessfully(input);
+
+  EXPECT_THAT(ValidateInstructions(), Eq(SPV_ERROR_INVALID_BINARY));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Invalid Opcode name 'OpImageSparseSampleProjDrefImplicitLod'"));
+}
+
+TEST_F(ReservedSamplingInstTest, OpImageSparseSampleProjDrefExplicitLod) {
+  const std::string input = ShaderAssembly(
+      "%result = OpImageSparseSampleProjDrefExplicitLod %_struct_15 %17 %13 "
+      "%float_0 Lod "
+      "%float_0\n");
+  CompileSuccessfully(input);
+
+  EXPECT_THAT(ValidateInstructions(), Eq(SPV_ERROR_INVALID_BINARY));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Invalid Opcode name 'OpImageSparseSampleProjDrefExplicitLod'"));
+}
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/val_ext_inst_test.cpp b/test/val/val_ext_inst_test.cpp
new file mode 100644
index 000000000..d1505da45
--- /dev/null
+++ b/test/val/val_ext_inst_test.cpp
@@ -0,0 +1,5819 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests validation rules of GLSL.450.std and OpenCL.std extended instructions.
+// Doesn't test OpenCL.std vector size 2, 3, 4, 8 or 16 rules (not supported
+// by standard SPIR-V).
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::Eq;
+using ::testing::HasSubstr;
+using ::testing::Not;
+
+// ValidateExtInst is used with TEST_F. Every other fixture is value-
+// parameterized (TEST_P) on the extended instruction name, which the tests
+// read back through GetParam() as a std::string.
+using ValidateExtInst = spvtest::ValidateBase<bool>;
+using ValidateGlslStd450SqrtLike = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450FMinLike = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450FClampLike = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450SAbsLike = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450UMinLike = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450UClampLike = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450SinLike = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450PowLike = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450Pack = spvtest::ValidateBase<std::string>;
+using ValidateGlslStd450Unpack = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdSqrtLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdFMinLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdFClampLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdSAbsLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdUMinLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdUClampLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdUMul24Like = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdUMad24Like = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdLengthLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdDistanceLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdNormalizeLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdVStoreHalfLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdVLoadHalfLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdFractLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdFrexpLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdLdexpLike = spvtest::ValidateBase<std::string>;
+using ValidateOpenCLStdUpsampleLike = spvtest::ValidateBase<std::string>;
+
+// Returns number of components in Pack/Unpack extended instructions.
+// |ext_inst_name| is expected to be of the format "PackHalf2x16".
+// Number of components is assumed to be single-digit.
+uint32_t GetPackedNumComponents(const std::string& ext_inst_name) {
+  const size_t x_index = ext_inst_name.find_last_of('x');
+  // substr() takes (pos, count), not (begin, end): extract exactly the one
+  // digit that precedes the last 'x'.
+  const std::string num_components_str =
+      ext_inst_name.substr(x_index - 1, 1);
+  return uint32_t(std::stoul(num_components_str));
+}
+
+// Returns packed bit width in Pack/Unpack extended instructions.
+// |ext_inst_name| is expected to be of the format "PackHalf2x16".
+// Everything after the last 'x' is parsed as the bit width.
+uint32_t GetPackedBitWidth(const std::string& ext_inst_name) {
+  const size_t x_index = ext_inst_name.find_last_of('x');
+  const std::string packed_bit_width_str = ext_inst_name.substr(x_index + 1);
+  return uint32_t(std::stoul(packed_bit_width_str));
+}
+
+// Returns SPIR-V assembly for a GLSL450 shader module that imports
+// "GLSL.std.450" as %extinst and whose %main function contains |body|.
+//
+// |capabilities_and_extensions| is emitted verbatim after the fixed
+// OpCapability lines. |execution_model| is the execution model named in
+// OpEntryPoint; OriginUpperLeft is declared only when it is "Fragment".
+// The prelude declares the scalar/vector/matrix/struct types, constants and
+// Input/Output variables that the test bodies reference by name.
+std::string GenerateShaderCode(
+    const std::string& body,
+    const std::string& capabilities_and_extensions = "",
+    const std::string& execution_model = "Fragment") {
+  std::ostringstream ss;
+  ss << R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Float64
+OpCapability Int16
+OpCapability Int64
+)";
+
+  ss << capabilities_and_extensions;
+  ss << "%extinst = OpExtInstImport \"GLSL.std.450\"\n";
+  ss << "OpMemoryModel Logical GLSL450\n";
+  ss << "OpEntryPoint " << execution_model << " %main \"main\""
+     << " %f32_output"
+     << " %f32vec2_output"
+     << " %u32_output"
+     << " %u32vec2_output"
+     << " %u64_output"
+     << " %f32_input"
+     << " %f32vec2_input"
+     << " %u32_input"
+     << " %u32vec2_input"
+     << " %u64_input"
+     << "\n";
+  if (execution_model == "Fragment") {
+    ss << "OpExecutionMode %main OriginUpperLeft\n";
+  }
+
+  ss << R"(
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%f64 = OpTypeFloat 64
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+%u64 = OpTypeInt 64 0
+%s64 = OpTypeInt 64 1
+%u16 = OpTypeInt 16 0
+%s16 = OpTypeInt 16 1
+%f32vec2 = OpTypeVector %f32 2
+%f32vec3 = OpTypeVector %f32 3
+%f32vec4 = OpTypeVector %f32 4
+%f64vec2 = OpTypeVector %f64 2
+%f64vec3 = OpTypeVector %f64 3
+%f64vec4 = OpTypeVector %f64 4
+%u32vec2 = OpTypeVector %u32 2
+%u32vec3 = OpTypeVector %u32 3
+%s32vec2 = OpTypeVector %s32 2
+%u32vec4 = OpTypeVector %u32 4
+%s32vec4 = OpTypeVector %s32 4
+%u64vec2 = OpTypeVector %u64 2
+%s64vec2 = OpTypeVector %s64 2
+%f64mat22 = OpTypeMatrix %f64vec2 2
+%f32mat22 = OpTypeMatrix %f32vec2 2
+%f32mat23 = OpTypeMatrix %f32vec2 3
+%f32mat32 = OpTypeMatrix %f32vec3 2
+%f32mat33 = OpTypeMatrix %f32vec3 3
+
+%f32_0 = OpConstant %f32 0
+%f32_1 = OpConstant %f32 1
+%f32_2 = OpConstant %f32 2
+%f32_3 = OpConstant %f32 3
+%f32_4 = OpConstant %f32 4
+%f32_h = OpConstant %f32 0.5
+%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1
+%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2
+%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2
+%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3
+%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3
+%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4
+
+%f64_0 = OpConstant %f64 0
+%f64_1 = OpConstant %f64 1
+%f64_2 = OpConstant %f64 2
+%f64_3 = OpConstant %f64 3
+%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1
+%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2
+%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
+
+%f16_0 = OpConstant %f16 0
+%f16_1 = OpConstant %f16 1
+%f16_h = OpConstant %f16 0.5
+
+%u32_0 = OpConstant %u32 0
+%u32_1 = OpConstant %u32 1
+%u32_2 = OpConstant %u32 2
+%u32_3 = OpConstant %u32 3
+
+%s32_0 = OpConstant %s32 0
+%s32_1 = OpConstant %s32 1
+%s32_2 = OpConstant %s32 2
+%s32_3 = OpConstant %s32 3
+
+%u64_0 = OpConstant %u64 0
+%u64_1 = OpConstant %u64 1
+%u64_2 = OpConstant %u64 2
+%u64_3 = OpConstant %u64 3
+
+%s64_0 = OpConstant %s64 0
+%s64_1 = OpConstant %s64 1
+%s64_2 = OpConstant %s64 2
+%s64_3 = OpConstant %s64 3
+
+%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1
+%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1
+
+%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2
+%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2
+
+%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3
+%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
+
+%s64vec2_01 = OpConstantComposite %s64vec2 %s64_0 %s64_1
+%u64vec2_01 = OpConstantComposite %u64vec2 %u64_0 %u64_1
+
+%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
+%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
+
+%f32_ptr_output = OpTypePointer Output %f32
+%f32vec2_ptr_output = OpTypePointer Output %f32vec2
+
+%u32_ptr_output = OpTypePointer Output %u32
+%u32vec2_ptr_output = OpTypePointer Output %u32vec2
+
+%u64_ptr_output = OpTypePointer Output %u64
+
+%f32_output = OpVariable %f32_ptr_output Output
+%f32vec2_output = OpVariable %f32vec2_ptr_output Output
+
+%u32_output = OpVariable %u32_ptr_output Output
+%u32vec2_output = OpVariable %u32vec2_ptr_output Output
+
+%u64_output = OpVariable %u64_ptr_output Output
+
+%f32_ptr_input = OpTypePointer Input %f32
+%f32vec2_ptr_input = OpTypePointer Input %f32vec2
+
+%u32_ptr_input = OpTypePointer Input %u32
+%u32vec2_ptr_input = OpTypePointer Input %u32vec2
+
+%u64_ptr_input = OpTypePointer Input %u64
+
+%f32_input = OpVariable %f32_ptr_input Input
+%f32vec2_input = OpVariable %f32vec2_ptr_input Input
+
+%u32_input = OpVariable %u32_ptr_input Input
+%u32vec2_input = OpVariable %u32vec2_ptr_input Input
+
+%u64_input = OpVariable %u64_ptr_input Input
+
+%struct_f16_u16 = OpTypeStruct %f16 %u16
+%struct_f32_f32 = OpTypeStruct %f32 %f32
+%struct_f32_f32_f32 = OpTypeStruct %f32 %f32 %f32
+%struct_f32_u32 = OpTypeStruct %f32 %u32
+%struct_f32_u32_f32 = OpTypeStruct %f32 %u32 %f32
+%struct_u32_f32 = OpTypeStruct %u32 %f32
+%struct_u32_u32 = OpTypeStruct %u32 %u32
+%struct_f32_f64 = OpTypeStruct %f32 %f64
+%struct_f32vec2_f32vec2 = OpTypeStruct %f32vec2 %f32vec2
+%struct_f32vec2_u32vec2 = OpTypeStruct %f32vec2 %u32vec2
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+)";
+
+  ss << body;
+
+  ss << R"(
+OpReturn
+OpFunctionEnd)";
+
+  return ss.str();
+}
+
+// Returns SPIR-V assembly for an OpenCL kernel module that imports
+// "OpenCL.std" as %extinst and whose %main function contains |body|.
+//
+// |capabilities_and_extensions| is emitted verbatim after the fixed
+// OpCapability lines. |memory_model| is the addressing model operand of
+// OpMemoryModel (the memory model itself is always OpenCL). No entry point
+// is declared; the module relies on the Linkage capability.
+std::string GenerateKernelCode(
+    const std::string& body,
+    const std::string& capabilities_and_extensions = "",
+    const std::string& memory_model = "Physical32") {
+  std::ostringstream ss;
+  ss << R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Linkage
+OpCapability GenericPointer
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+OpCapability Vector16
+OpCapability Matrix
+)";
+
+  ss << capabilities_and_extensions;
+  ss << "%extinst = OpExtInstImport \"OpenCL.std\"\n";
+  ss << "OpMemoryModel " << memory_model << " OpenCL\n";
+
+  ss << R"(
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%f64 = OpTypeFloat 64
+%u32 = OpTypeInt 32 0
+%u64 = OpTypeInt 64 0
+%u16 = OpTypeInt 16 0
+%u8 = OpTypeInt 8 0
+%f32vec2 = OpTypeVector %f32 2
+%f32vec3 = OpTypeVector %f32 3
+%f32vec4 = OpTypeVector %f32 4
+%f32vec8 = OpTypeVector %f32 8
+%f16vec8 = OpTypeVector %f16 8
+%f32vec16 = OpTypeVector %f32 16
+%f64vec2 = OpTypeVector %f64 2
+%f64vec3 = OpTypeVector %f64 3
+%f64vec4 = OpTypeVector %f64 4
+%u32vec2 = OpTypeVector %u32 2
+%u32vec3 = OpTypeVector %u32 3
+%u32vec4 = OpTypeVector %u32 4
+%u32vec8 = OpTypeVector %u32 8
+%u64vec2 = OpTypeVector %u64 2
+%f64mat22 = OpTypeMatrix %f64vec2 2
+%f32mat22 = OpTypeMatrix %f32vec2 2
+%f32mat23 = OpTypeMatrix %f32vec2 3
+%f32mat32 = OpTypeMatrix %f32vec3 2
+%f32mat33 = OpTypeMatrix %f32vec3 3
+
+%f32_0 = OpConstant %f32 0
+%f32_1 = OpConstant %f32 1
+%f32_2 = OpConstant %f32 2
+%f32_3 = OpConstant %f32 3
+%f32_4 = OpConstant %f32 4
+%f32_h = OpConstant %f32 0.5
+%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1
+%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2
+%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2
+%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3
+%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3
+%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4
+%f32vec8_01010101 = OpConstantComposite %f32vec8 %f32_0 %f32_1 %f32_0 %f32_1 %f32_0 %f32_1 %f32_0 %f32_1
+
+%f64_0 = OpConstant %f64 0
+%f64_1 = OpConstant %f64 1
+%f64_2 = OpConstant %f64 2
+%f64_3 = OpConstant %f64 3
+%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1
+%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2
+%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
+
+%f16_0 = OpConstant %f16 0
+%f16_1 = OpConstant %f16 1
+
+%u8_0 = OpConstant %u8 0
+%u8_1 = OpConstant %u8 1
+%u8_2 = OpConstant %u8 2
+%u8_3 = OpConstant %u8 3
+
+%u16_0 = OpConstant %u16 0
+%u16_1 = OpConstant %u16 1
+%u16_2 = OpConstant %u16 2
+%u16_3 = OpConstant %u16 3
+
+%u32_0 = OpConstant %u32 0
+%u32_1 = OpConstant %u32 1
+%u32_2 = OpConstant %u32 2
+%u32_3 = OpConstant %u32 3
+%u32_256 = OpConstant %u32 256
+
+%u64_0 = OpConstant %u64 0
+%u64_1 = OpConstant %u64 1
+%u64_2 = OpConstant %u64 2
+%u64_3 = OpConstant %u64 3
+%u64_256 = OpConstant %u64 256
+
+%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1
+%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2
+%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2
+%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
+
+%u64vec2_01 = OpConstantComposite %u64vec2 %u64_0 %u64_1
+
+%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
+%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
+
+%struct_f32_f32 = OpTypeStruct %f32 %f32
+%struct_f32_f32_f32 = OpTypeStruct %f32 %f32 %f32
+%struct_f32_u32 = OpTypeStruct %f32 %u32
+%struct_f32_u32_f32 = OpTypeStruct %f32 %u32 %f32
+%struct_u32_f32 = OpTypeStruct %u32 %f32
+%struct_u32_u32 = OpTypeStruct %u32 %u32
+%struct_f32_f64 = OpTypeStruct %f32 %f64
+%struct_f32vec2_f32vec2 = OpTypeStruct %f32vec2 %f32vec2
+%struct_f32vec2_u32vec2 = OpTypeStruct %f32vec2 %u32vec2
+
+%f16vec8_ptr_workgroup = OpTypePointer Workgroup %f16vec8
+%f16vec8_workgroup = OpVariable %f16vec8_ptr_workgroup Workgroup
+%f16_ptr_workgroup = OpTypePointer Workgroup %f16
+
+%u32vec8_ptr_workgroup = OpTypePointer Workgroup %u32vec8
+%u32vec8_workgroup = OpVariable %u32vec8_ptr_workgroup Workgroup
+%u32_ptr_workgroup = OpTypePointer Workgroup %u32
+
+%f32vec8_ptr_workgroup = OpTypePointer Workgroup %f32vec8
+%f32vec8_workgroup = OpVariable %f32vec8_ptr_workgroup Workgroup
+%f32_ptr_workgroup = OpTypePointer Workgroup %f32
+
+%u32arr = OpTypeArray %u32 %u32_256
+%u32arr_ptr_cross_workgroup = OpTypePointer CrossWorkgroup %u32arr
+%u32arr_cross_workgroup = OpVariable %u32arr_ptr_cross_workgroup CrossWorkgroup
+%u32_ptr_cross_workgroup = OpTypePointer CrossWorkgroup %u32
+
+%f32arr = OpTypeArray %f32 %u32_256
+%f32arr_ptr_cross_workgroup = OpTypePointer CrossWorkgroup %f32arr
+%f32arr_cross_workgroup = OpVariable %f32arr_ptr_cross_workgroup CrossWorkgroup
+%f32_ptr_cross_workgroup = OpTypePointer CrossWorkgroup %f32
+
+%f32vec2arr = OpTypeArray %f32vec2 %u32_256
+%f32vec2arr_ptr_cross_workgroup = OpTypePointer CrossWorkgroup %f32vec2arr
+%f32vec2arr_cross_workgroup = OpVariable %f32vec2arr_ptr_cross_workgroup CrossWorkgroup
+%f32vec2_ptr_cross_workgroup = OpTypePointer CrossWorkgroup %f32vec2
+
+%struct_arr = OpTypeArray %struct_f32_f32 %u32_256
+%struct_arr_ptr_cross_workgroup = OpTypePointer CrossWorkgroup %struct_arr
+%struct_arr_cross_workgroup = OpVariable %struct_arr_ptr_cross_workgroup CrossWorkgroup
+%struct_ptr_cross_workgroup = OpTypePointer CrossWorkgroup %struct_f32_f32
+
+%f16vec8_ptr_uniform_constant = OpTypePointer UniformConstant %f16vec8
+%f16vec8_uniform_constant = OpVariable %f16vec8_ptr_uniform_constant UniformConstant
+%f16_ptr_uniform_constant = OpTypePointer UniformConstant %f16
+
+%u32vec8_ptr_uniform_constant = OpTypePointer UniformConstant %u32vec8
+%u32vec8_uniform_constant = OpVariable %u32vec8_ptr_uniform_constant UniformConstant
+%u32_ptr_uniform_constant = OpTypePointer UniformConstant %u32
+
+%f32vec8_ptr_uniform_constant = OpTypePointer UniformConstant %f32vec8
+%f32vec8_uniform_constant = OpVariable %f32vec8_ptr_uniform_constant UniformConstant
+%f32_ptr_uniform_constant = OpTypePointer UniformConstant %f32
+
+%f16vec8_ptr_input = OpTypePointer Input %f16vec8
+%f16vec8_input = OpVariable %f16vec8_ptr_input Input
+%f16_ptr_input = OpTypePointer Input %f16
+
+%f32_ptr_generic = OpTypePointer Generic %f32
+%u32_ptr_generic = OpTypePointer Generic %u32
+
+%f32_ptr_function = OpTypePointer Function %f32
+%f32vec2_ptr_function = OpTypePointer Function %f32vec2
+%u32_ptr_function = OpTypePointer Function %u32
+%u64_ptr_function = OpTypePointer Function %u64
+%u32vec2_ptr_function = OpTypePointer Function %u32vec2
+
+%u8arr = OpTypeArray %u8 %u32_256
+%u8arr_ptr_uniform_constant = OpTypePointer UniformConstant %u8arr
+%u8arr_uniform_constant = OpVariable %u8arr_ptr_uniform_constant UniformConstant
+%u8_ptr_uniform_constant = OpTypePointer UniformConstant %u8
+%u8_ptr_generic = OpTypePointer Generic %u8
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+)";
+
+  ss << body;
+
+  ss << R"(
+OpReturn
+OpFunctionEnd)";
+
+  return ss.str();
+}
+
+// GLSL.std.450 instructions taking one float operand whose type must equal
+// the (float scalar or vector) Result Type.
+TEST_P(ValidateGlslStd450SqrtLike, Success) {
+  const std::string ext_inst_name = GetParam();
+  std::ostringstream ss;
+  ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name << " %f32_0\n";
+  ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name
+     << " %f32vec2_01\n";
+  ss << "%val3 = OpExtInst %f64 %extinst " << ext_inst_name << " %f64_0\n";
+  CompileSuccessfully(GenerateShaderCode(ss.str()));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateGlslStd450SqrtLike, IntResultType) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_0\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected Result Type to be a float scalar "
+                        "or vector type"));
+}
+
+TEST_P(ValidateGlslStd450SqrtLike, IntOperand) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected types of all operands to be equal to "
+                        "Result Type"));
+}
+
+INSTANTIATE_TEST_CASE_P(AllSqrtLike, ValidateGlslStd450SqrtLike,
+                        ::testing::ValuesIn(std::vector<std::string>{
+                            "Round",
+                            "RoundEven",
+                            "FAbs",
+                            "Trunc",
+                            "FSign",
+                            "Floor",
+                            "Ceil",
+                            "Fract",
+                            "Sqrt",
+                            "InverseSqrt",
+                            "Normalize",
+                        }), );
+
+// GLSL.std.450 instructions taking two float operands of the Result Type.
+TEST_P(ValidateGlslStd450FMinLike, Success) {
+  const std::string ext_inst_name = GetParam();
+  std::ostringstream ss;
+  ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name
+     << " %f32_0 %f32_1\n";
+  ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name
+     << " %f32vec2_01 %f32vec2_12\n";
+  ss << "%val3 = OpExtInst %f64 %extinst " << ext_inst_name
+     << " %f64_0 %f64_0\n";
+  CompileSuccessfully(GenerateShaderCode(ss.str()));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateGlslStd450FMinLike, IntResultType) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_0 %f32_1\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected Result Type to be a float scalar "
+                        "or vector type"));
+}
+
+TEST_P(ValidateGlslStd450FMinLike, IntOperand1) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0 %f32_1\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected types of all operands to be equal to "
+                        "Result Type"));
+}
+
+TEST_P(ValidateGlslStd450FMinLike, IntOperand2) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %f32_0 %u32_1\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected types of all operands to be equal to "
+                        "Result Type"));
+}
+
+INSTANTIATE_TEST_CASE_P(AllFMinLike, ValidateGlslStd450FMinLike,
+                        ::testing::ValuesIn(std::vector<std::string>{
+                            "FMin",
+                            "FMax",
+                            "Step",
+                            "Reflect",
+                            "NMin",
+                            "NMax",
+                        }), );
+
+// GLSL.std.450 instructions taking three float operands of the Result Type.
+TEST_P(ValidateGlslStd450FClampLike, Success) {
+  const std::string ext_inst_name = GetParam();
+  std::ostringstream ss;
+  ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name
+     << " %f32_0 %f32_1 %f32_2\n";
+  ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name
+     << " %f32vec2_01 %f32vec2_01 %f32vec2_12\n";
+  ss << "%val3 = OpExtInst %f64 %extinst " << ext_inst_name
+     << " %f64_0 %f64_0 %f64_1\n";
+  CompileSuccessfully(GenerateShaderCode(ss.str()));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateGlslStd450FClampLike, IntResultType) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name +
+                           " %f32_0 %f32_1 %f32_2\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected Result Type to be a float scalar "
+                        "or vector type"));
+}
+
+TEST_P(ValidateGlslStd450FClampLike, IntOperand1) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name +
+                           " %u32_0 %f32_0 %f32_1\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected types of all operands to be equal to "
+                        "Result Type"));
+}
+
+TEST_P(ValidateGlslStd450FClampLike, IntOperand2) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name +
+                           " %f32_0 %u32_0 %f32_1\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected types of all operands to be equal to "
+                        "Result Type"));
+}
+
+TEST_P(ValidateGlslStd450FClampLike, IntOperand3) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name +
+                           " %f32_1 %f32_0 %u32_2\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected types of all operands to be equal to "
+                        "Result Type"));
+}
+
+INSTANTIATE_TEST_CASE_P(AllFClampLike, ValidateGlslStd450FClampLike,
+                        ::testing::ValuesIn(std::vector<std::string>{
+                            "FClamp",
+                            "FMix",
+                            "SmoothStep",
+                            "Fma",
+                            "FaceForward",
+                            "NClamp",
+                        }), );
+
+// GLSL.std.450 instructions taking one int operand; the Success case mixes
+// signed and unsigned operand/result types of the same width and dimension.
+TEST_P(ValidateGlslStd450SAbsLike, Success) {
+  const std::string ext_inst_name = GetParam();
+  std::ostringstream ss;
+  ss << "%val1 = OpExtInst %s32 %extinst " << ext_inst_name << " %u32_1\n";
+  ss << "%val2 = OpExtInst %s32 %extinst " << ext_inst_name << " %s32_1\n";
+  ss << "%val3 = OpExtInst %u32 %extinst " << ext_inst_name << " %u32_1\n";
+  ss << "%val4 = OpExtInst %u32 %extinst " << ext_inst_name << " %s32_1\n";
+  ss << "%val5 = OpExtInst %s32vec2 %extinst " << ext_inst_name
+     << " %s32vec2_01\n";
+  ss << "%val6 = OpExtInst %u32vec2 %extinst " << ext_inst_name
+     << " %u32vec2_01\n";
+  ss << "%val7 = OpExtInst %u32vec2 %extinst " << ext_inst_name
+     << " %s32vec2_01\n";
+  ss << "%val8 = OpExtInst %s32vec2 %extinst " << ext_inst_name
+     << " %u32vec2_01\n";
+  CompileSuccessfully(GenerateShaderCode(ss.str()));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateGlslStd450SAbsLike, FloatResultType) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected Result Type to be an int scalar "
+                        "or vector type"));
+}
+
+TEST_P(ValidateGlslStd450SAbsLike, FloatOperand) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + " %f32_0\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected all operands to be int scalars or "
+                        "vectors"));
+}
+
+TEST_P(ValidateGlslStd450SAbsLike, WrongDimOperand) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + " %s32vec2_01\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected all operands to have the same dimension as "
+                        "Result Type"));
+}
+
+TEST_P(ValidateGlslStd450SAbsLike, WrongBitWidthOperand) {
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %s64 %extinst " + ext_inst_name + " %s32_0\n";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 " + ext_inst_name +
+                        ": expected all operands to have the same bit width as "
+                        "Result Type"));
+}
+
+INSTANTIATE_TEST_CASE_P(AllSAbsLike, ValidateGlslStd450SAbsLike,
+                        ::testing::ValuesIn(std::vector<std::string>{
+                            "SAbs",
+                            "SSign",
+                            "FindILsb",
+                            "FindUMsb",
+                            "FindSMsb",
+                        }), );
+
+// FindUMsb / FindSMsb with 64-bit operands must be rejected as not 32-bit.
+TEST_F(ValidateExtInst, FindUMsbNot32Bit) {
+  const std::string body = R"(
+%val1 = OpExtInst %s64 %extinst FindUMsb %u64_1
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 FindUMsb: this instruction is currently "
+                        "limited to 32-bit width components"));
+}
+
+TEST_F(ValidateExtInst, FindSMsbNot32Bit) {
+  const std::string body = R"(
+%val1 = OpExtInst %s64 %extinst FindSMsb %u64_1
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 FindSMsb: this instruction is currently "
+                        "limited to 32-bit width components"));
+}
+
+TEST_P(ValidateGlslStd450UMinLike, Success) {
+  const std::string ext_inst_name = GetParam();
+  std::ostringstream ss;
+  ss << "%val1 = OpExtInst %s32 %extinst " << ext_inst_name
+     << " %u32_1 %s32_2\n";
+  ss << "%val2 = OpExtInst %s32 %extinst " << ext_inst_name
+     << " %s32_1 %u32_2\n";
+  ss << "%val3 = OpExtInst %u32 %extinst
" << ext_inst_name + << " %u32_1 %s32_2\n"; + ss << "%val4 = OpExtInst %u32 %extinst " << ext_inst_name + << " %s32_1 %u32_2\n"; + ss << "%val5 = OpExtInst %s32vec2 %extinst " << ext_inst_name + << " %s32vec2_01 %u32vec2_01\n"; + ss << "%val6 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %s32vec2_01\n"; + ss << "%val7 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %s32vec2_01 %u32vec2_01\n"; + ss << "%val8 = OpExtInst %s32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %s32vec2_01\n"; + ss << "%val9 = OpExtInst %s64 %extinst " << ext_inst_name + << " %u64_1 %s64_0\n"; + CompileSuccessfully(GenerateShaderCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateGlslStd450UMinLike, FloatResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0 %u32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected Result Type to be an int scalar " + "or vector type")); +} + +TEST_P(ValidateGlslStd450UMinLike, FloatOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + " %f32_0 %u32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to be int scalars or " + "vectors")); +} + +TEST_P(ValidateGlslStd450UMinLike, FloatOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + " %u32_0 %f32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to be int scalars or " + "vectors")); +} + +TEST_P(ValidateGlslStd450UMinLike, WrongDimOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + + " %s32vec2_01 %s32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same dimension as " + "Result Type")); +} + +TEST_P(ValidateGlslStd450UMinLike, WrongDimOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + + " %s32_0 %s32vec2_01\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same dimension as " + "Result Type")); +} + +TEST_P(ValidateGlslStd450UMinLike, WrongBitWidthOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %s64 %extinst " + ext_inst_name + " %s32_0 %s64_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same bit width as " + "Result Type")); +} + +TEST_P(ValidateGlslStd450UMinLike, WrongBitWidthOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %s64 %extinst " + ext_inst_name + " %s64_0 %s32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same bit width as " + "Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllUMinLike, ValidateGlslStd450UMinLike, + ::testing::ValuesIn(std::vector{ + "UMin", + "SMin", + "UMax", + "SMax", + }), ); + +TEST_P(ValidateGlslStd450UClampLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %s32 %extinst " << ext_inst_name + << " %s32_0 %u32_1 %s32_2\n"; + ss << "%val2 = OpExtInst %s32 %extinst " << ext_inst_name + << " %u32_0 %s32_1 %u32_2\n"; + ss << "%val3 = OpExtInst %u32 %extinst " << ext_inst_name + << " %s32_0 %u32_1 %s32_2\n"; + ss << "%val4 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_0 %s32_1 %u32_2\n"; + ss << "%val5 = OpExtInst %s32vec2 %extinst " << ext_inst_name + << " %s32vec2_01 %u32vec2_01 %u32vec2_12\n"; + ss << "%val6 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %s32vec2_01 %s32vec2_12\n"; + ss << "%val7 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %s32vec2_01 %u32vec2_01 %u32vec2_12\n"; + ss << "%val8 = OpExtInst %s32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %s32vec2_01 %s32vec2_12\n"; + ss << "%val9 = OpExtInst %s64 %extinst " << ext_inst_name + << " %u64_1 %s64_0 %s64_1\n"; + CompileSuccessfully(GenerateShaderCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateGlslStd450UClampLike, FloatResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + + " %u32_0 %u32_0 %u32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected Result Type to be an int scalar " + "or vector type")); +} + +TEST_P(ValidateGlslStd450UClampLike, FloatOperand1) { + const std::string ext_inst_name = GetParam(); 
+ const std::string body = "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + + " %f32_0 %u32_0 %u32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to be int scalars or " + "vectors")); +} + +TEST_P(ValidateGlslStd450UClampLike, FloatOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + + " %u32_0 %f32_0 %u32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to be int scalars or " + "vectors")); +} + +TEST_P(ValidateGlslStd450UClampLike, FloatOperand3) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + + " %u32_0 %u32_0 %f32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to be int scalars or " + "vectors")); +} + +TEST_P(ValidateGlslStd450UClampLike, WrongDimOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + + " %s32vec2_01 %s32_0 %u32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same dimension as " + "Result Type")); +} + +TEST_P(ValidateGlslStd450UClampLike, WrongDimOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s32 %extinst " + 
ext_inst_name + + " %s32_0 %s32vec2_01 %u32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same dimension as " + "Result Type")); +} + +TEST_P(ValidateGlslStd450UClampLike, WrongDimOperand3) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s32 %extinst " + ext_inst_name + + " %s32_0 %u32_1 %s32vec2_01\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same dimension as " + "Result Type")); +} + +TEST_P(ValidateGlslStd450UClampLike, WrongBitWidthOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s64 %extinst " + ext_inst_name + + " %s32_0 %s64_0 %s64_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same bit width as " + "Result Type")); +} + +TEST_P(ValidateGlslStd450UClampLike, WrongBitWidthOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s64 %extinst " + ext_inst_name + + " %s64_0 %s32_0 %s64_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same bit width as " + "Result Type")); +} + +TEST_P(ValidateGlslStd450UClampLike, WrongBitWidthOperand3) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %s64 %extinst " + 
ext_inst_name + + " %s64_0 %s64_0 %s32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected all operands to have the same bit width as " + "Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllUClampLike, ValidateGlslStd450UClampLike, + ::testing::ValuesIn(std::vector{ + "UClamp", + "SClamp", + }), ); + +TEST_P(ValidateGlslStd450SinLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name << " %f32_0\n"; + ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01\n"; + CompileSuccessfully(GenerateShaderCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateGlslStd450SinLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected Result Type to be a 16 or 32-bit scalar " + "or vector float type")); +} + +TEST_P(ValidateGlslStd450SinLike, F64ResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f64 %extinst " + ext_inst_name + " %f32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected Result Type to be a 16 or 32-bit scalar " + "or vector float type")); +} + +TEST_P(ValidateGlslStd450SinLike, IntOperand) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " 
%u32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected types of all operands to be equal to " + "Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllSinLike, ValidateGlslStd450SinLike, + ::testing::ValuesIn(std::vector{ + "Radians", + "Degrees", + "Sin", + "Cos", + "Tan", + "Asin", + "Acos", + "Atan", + "Sinh", + "Cosh", + "Tanh", + "Asinh", + "Acosh", + "Atanh", + "Exp", + "Exp2", + "Log", + "Log2", + }), ); + +TEST_P(ValidateGlslStd450PowLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_1 %f32_1\n"; + ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01 %f32vec2_12\n"; + CompileSuccessfully(GenerateShaderCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateGlslStd450PowLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_1 %f32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected Result Type to be a 16 or 32-bit scalar " + "or vector float type")); +} + +TEST_P(ValidateGlslStd450PowLike, F64ResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f64 %extinst " + ext_inst_name + " %f32_1 %f32_0\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected Result Type to be a 16 or 32-bit scalar " + "or vector float type")); +} + +TEST_P(ValidateGlslStd450PowLike, 
IntOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0 %f32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected types of all operands to be equal to " + "Result Type")); +} + +TEST_P(ValidateGlslStd450PowLike, IntOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %f32_0 %u32_1\n"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 " + ext_inst_name + + ": expected types of all operands to be equal to " + "Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllPowLike, ValidateGlslStd450PowLike, + ::testing::ValuesIn(std::vector{ + "Atan2", + "Pow", + }), ); + +TEST_F(ValidateExtInst, GlslStd450DeterminantSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Determinant %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450DeterminantIncompatibleResultType) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst Determinant %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Determinant: " + "expected operand X component type to be equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450DeterminantNotMatrix) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Determinant %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Determinant: " + "expected operand X to be a square matrix")); +} + +TEST_F(ValidateExtInst, GlslStd450DeterminantMatrixNotSquare) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Determinant %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Determinant: " + "expected operand X to be a square matrix")); +} + +TEST_F(ValidateExtInst, GlslStd450MatrixInverseSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32mat22 %extinst MatrixInverse %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450MatrixInverseIncompatibleResultType) { + const std::string body = R"( +%val1 = OpExtInst %f32mat33 %extinst MatrixInverse %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 MatrixInverse: " + "expected operand X type to be equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450MatrixInverseNotMatrix) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst MatrixInverse %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 MatrixInverse: " + "expected Result Type to be a square matrix")); +} + +TEST_F(ValidateExtInst, GlslStd450MatrixInverseMatrixNotSquare) { + const std::string body = R"( +%val1 = OpExtInst %f32mat23 %extinst MatrixInverse %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 
MatrixInverse: " + "expected Result Type to be a square matrix")); +} + +TEST_F(ValidateExtInst, GlslStd450ModfSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Modf %f32_h %f32_output +%val2 = OpExtInst %f32vec2 %extinst Modf %f32vec2_01 %f32vec2_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450ModfIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst Modf %f32_h %f32_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Modf: " + "expected Result Type to be a scalar or vector " + "float type")); +} + +TEST_F(ValidateExtInst, GlslStd450ModfXNotOfResultType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Modf %f64_0 %f32_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Modf: " + "expected operand X type to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450ModfINotPointer) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Modf %f32_h %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Modf: " + "expected operand I to be a pointer")); +} + +TEST_F(ValidateExtInst, GlslStd450ModfIDataNotOfResultType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Modf %f32_h %f32vec2_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Modf: " + "expected operand I data type to be equal to " + "Result Type")); +} + 
+TEST_F(ValidateExtInst, GlslStd450ModfStructSuccess) { + const std::string body = R"( +%val1 = OpExtInst %struct_f32_f32 %extinst ModfStruct %f32_h +%val2 = OpExtInst %struct_f32vec2_f32vec2 %extinst ModfStruct %f32vec2_01 +)"; +/* Positive case: ModfStruct with result types %struct_f32_f32 and %struct_f32vec2_f32vec2 is expected to validate with SPV_SUCCESS. */ + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +/* Negative case: ModfStruct with the non-struct result type %f32 must fail with SPV_ERROR_INVALID_DATA and the struct-with-two-identical-members diagnostic. */ +TEST_F(ValidateExtInst, GlslStd450ModfStructResultTypeNotStruct) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst ModfStruct %f32_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 ModfStruct: " + "expected Result Type to be a struct with two " + "identical scalar or vector float type members")); +} +/* Negative case: result type %struct_f32_f32_f32 (presumably a three-member struct; declared in GenerateShaderCode, not visible here) must produce the same ModfStruct result-type diagnostic. */ +TEST_F(ValidateExtInst, GlslStd450ModfStructResultTypeStructWrongSize) { + const std::string body = R"( +%val1 = OpExtInst %struct_f32_f32_f32 %extinst ModfStruct %f32_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 ModfStruct: " + "expected Result Type to be a struct with two " + "identical scalar or vector float type members")); +} +/* Negative case: result type %struct_u32_f32 (first member presumably not a float) must produce the ModfStruct result-type diagnostic. */ +TEST_F(ValidateExtInst, GlslStd450ModfStructResultTypeStructWrongFirstMember) { + const std::string body = R"( +%val1 = OpExtInst %struct_u32_f32 %extinst ModfStruct %f32_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 ModfStruct: " + "expected Result Type to be a struct with two " + "identical scalar or vector float type members")); +} +/* Negative case: result type %struct_f32_f64 (members presumably of different float widths) must produce the ModfStruct result-type diagnostic. */ +TEST_F(ValidateExtInst, GlslStd450ModfStructResultTypeStructMembersNotEqual) { + const std::string body = R"( +%val1 = OpExtInst %struct_f32_f64 %extinst ModfStruct %f32_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 ModfStruct: " + "expected Result Type to be a struct with two " + "identical scalar or vector float type members")); +} +/* Negative case: operand X is %f64_0 while the result type is %struct_f32_f32; the validator must report that X does not match the struct members. */ +TEST_F(ValidateExtInst, GlslStd450ModfStructXWrongType) { + const std::string body = R"( +%val1 = OpExtInst %struct_f32_f32 %extinst ModfStruct %f64_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 ModfStruct: " + "expected operand X type to be equal to members of " + "Result Type struct")); +} +/* Positive case: Frexp with (%f32, %f32_h, %u32_output) and (%f32vec2, %f32vec2_01, %u32vec2_output) is expected to validate with SPV_SUCCESS. */ +TEST_F(ValidateExtInst, GlslStd450FrexpSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Frexp %f32_h %u32_output +%val2 = OpExtInst %f32vec2 %extinst Frexp %f32vec2_01 %u32vec2_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +/* Negative case: Frexp with result type %u32 must be rejected because the result type is not a scalar or vector float type. */ +TEST_F(ValidateExtInst, GlslStd450FrexpIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst Frexp %f32_h %u32_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Frexp: " + "expected Result Type to be a scalar or vector " + "float type")); +} +/* Negative case: Frexp operand X is %u32_1 while the result type is %f32; the X-type-equals-Result-Type diagnostic is expected. */ +TEST_F(ValidateExtInst, GlslStd450FrexpWrongXType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Frexp %u32_1 %u32_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Frexp: " + "expected operand X type to be equal to Result Type")); +} +/* Negative case: Frexp operand Exp is %u32_1 (presumably a constant rather than a pointer id); the Exp-must-be-a-pointer diagnostic is expected. */ +TEST_F(ValidateExtInst, GlslStd450FrexpExpNotPointer) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Frexp %f32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); 
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Frexp: " + "expected operand Exp to be a pointer")); +} +/* Negative case: Frexp operand Exp is %f32_output (presumably a pointer to float); the validator must require a 32-bit int scalar or vector pointee. */ +TEST_F(ValidateExtInst, GlslStd450FrexpExpNotInt32Pointer) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Frexp %f32_1 %f32_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Frexp: " + "expected operand Exp data type to be a 32-bit int " + "scalar or vector type")); +} +/* Negative case: Frexp with result type %f32vec2 but Exp %u32_output; the component-number mismatch diagnostic is expected. */ +TEST_F(ValidateExtInst, GlslStd450FrexpExpWrongComponentNumber) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Frexp %f32vec2_01 %u32_output +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Frexp: " + "expected operand Exp data type to have the same " + "component number as Result Type")); +} +/* Positive case: three Ldexp forms are expected to validate. NOTE(review): %val3 passes %u64_1 as Exp and still expects SPV_SUCCESS, although the LdexpFloatExp diagnostic text speaks of a 32-bit int; confirm against the validator that the width is intentionally not enforced. */ +TEST_F(ValidateExtInst, GlslStd450LdexpSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Ldexp %f32_h %u32_2 +%val2 = OpExtInst %f32vec2 %extinst Ldexp %f32vec2_01 %u32vec2_12 +%val3 = OpExtInst %f32 %extinst Ldexp %f32_h %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +/* Negative case: Ldexp with result type %u32 must be rejected because the result type is not a scalar or vector float type. */ +TEST_F(ValidateExtInst, GlslStd450LdexpIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst Ldexp %f32_h %u32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Ldexp: " + "expected Result Type to be a scalar or vector " + "float type")); +} +/* Negative case: Ldexp operand X is %u32_1 while the result type is %f32; the X-type-equals-Result-Type diagnostic is expected. */ +TEST_F(ValidateExtInst, GlslStd450LdexpWrongXType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Ldexp %u32_1 %u32_2 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Ldexp: " + "expected operand X type to be equal to Result Type")); +} +/* Negative case: Ldexp operand Exp is %f32_2; the validator must report that Exp has to be an int scalar or vector. */ +TEST_F(ValidateExtInst, GlslStd450LdexpFloatExp) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Ldexp %f32_1 %f32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Ldexp: " + "expected operand Exp to be a 32-bit int scalar " + "or vector type")); +} +/* Negative case: Ldexp with result type %f32vec2 but Exp %u32_2; the component-number mismatch diagnostic is expected. */ +TEST_F(ValidateExtInst, GlslStd450LdexpExpWrongSize) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Ldexp %f32vec2_12 %u32_2 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Ldexp: " + "expected operand Exp to have the same component " + "number as Result Type")); +} +/* Positive case: FrexpStruct with result types %struct_f32_u32 and %struct_f32vec2_u32vec2 is expected to validate with SPV_SUCCESS. */ +TEST_F(ValidateExtInst, GlslStd450FrexpStructSuccess) { + const std::string body = R"( +%val1 = OpExtInst %struct_f32_u32 %extinst FrexpStruct %f32_h +%val2 = OpExtInst %struct_f32vec2_u32vec2 %extinst FrexpStruct %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +/* Negative case: FrexpStruct with the non-struct result type %f32 must fail with the full two-member (float, 32-bit int) struct diagnostic. */ +TEST_F(ValidateExtInst, GlslStd450FrexpStructResultTypeNotStruct) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst FrexpStruct %f32_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 FrexpStruct: " + "expected Result Type to be a struct with two members, " + "first member a float scalar or vector, second member " + "a 32-bit int scalar or vector with the same number of " + "components as the first member")); +} + 
+TEST_F(ValidateExtInst, GlslStd450FrexpStructResultTypeStructWrongSize) { + const std::string body = R"( +%val1 = OpExtInst %struct_f32_u32_f32 %extinst FrexpStruct %f32_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 FrexpStruct: " + "expected Result Type to be a struct with two members, " + "first member a float scalar or vector, second member " + "a 32-bit int scalar or vector with the same number of " + "components as the first member")); +} + +TEST_F(ValidateExtInst, GlslStd450FrexpStructResultTypeStructWrongMember1) { + const std::string body = R"( +%val1 = OpExtInst %struct_u32_u32 %extinst FrexpStruct %f32_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 FrexpStruct: " + "expected Result Type to be a struct with two members, " + "first member a float scalar or vector, second member " + "a 32-bit int scalar or vector with the same number of " + "components as the first member")); +} + +TEST_F(ValidateExtInst, GlslStd450FrexpStructResultTypeStructWrongMember2) { + const std::string body = R"( +%val1 = OpExtInst %struct_f32_f32 %extinst FrexpStruct %f32_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 FrexpStruct: " + "expected Result Type to be a struct with two members, " + "first member a float scalar or vector, second member " + "a 32-bit int scalar or vector with the same number of " + "components as the first member")); +} + +TEST_F(ValidateExtInst, GlslStd450FrexpStructXWrongType) { + const std::string body = R"( +%val1 = OpExtInst %struct_f32_u32 %extinst FrexpStruct %f64_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 FrexpStruct: " + "expected operand X type to be equal to the first " + "member of Result Type struct")); +} + +TEST_F(ValidateExtInst, + GlslStd450FrexpStructResultTypeStructRightInt16Member2) { + const std::string body = R"( +%val1 = OpExtInst %struct_f16_u16 %extinst FrexpStruct %f16_h +)"; + + const std::string extension = R"( +OpExtension "SPV_AMD_gpu_shader_int16" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extension)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, + GlslStd450FrexpStructResultTypeStructWrongInt16Member2) { + const std::string body = R"( +%val1 = OpExtInst %struct_f16_u16 %extinst FrexpStruct %f16_h +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 FrexpStruct: " + "expected Result Type to be a struct with two members, " + "first member a float scalar or vector, second member " + "a 32-bit int scalar or vector with the same number of " + "components as the first member")); +} + +TEST_P(ValidateGlslStd450Pack, Success) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string vec_str = + num_components == 2 ? 
" %f32vec2_01\n" : " %f32vec4_0123\n"; + + std::ostringstream body; + body << "%val1 = OpExtInst %u" << total_bit_width << " %extinst " + << ext_inst_name << vec_str; + body << "%val2 = OpExtInst %s" << total_bit_width << " %extinst " + << ext_inst_name << vec_str; + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateGlslStd450Pack, Float32ResultType) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string vec_str = + num_components == 2 ? " %f32vec2_01\n" : " %f32vec4_0123\n"; + + std::ostringstream body; + body << "%val1 = OpExtInst %f" << total_bit_width << " %extinst " + << ext_inst_name << vec_str; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected Result Type to be " << total_bit_width + << "-bit int scalar type"; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Pack, Int16ResultType) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string vec_str = + num_components == 2 ? 
" %f32vec2_01\n" : " %f32vec4_0123\n"; + + std::ostringstream body; + body << "%val1 = OpExtInst %u16 %extinst " << ext_inst_name << vec_str; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected Result Type to be " << total_bit_width + << "-bit int scalar type"; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Pack, VNotVector) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + + std::ostringstream body; + body << "%val1 = OpExtInst %u" << total_bit_width << " %extinst " + << ext_inst_name << " %f32_1\n"; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected operand V to be a 32-bit float vector of size " + << num_components; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Pack, VNotFloatVector) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string vec_str = + num_components == 2 ? 
" %u32vec2_01\n" : " %u32vec4_0123\n"; + + std::ostringstream body; + body << "%val1 = OpExtInst %u" << total_bit_width << " %extinst " + << ext_inst_name << vec_str; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected operand V to be a 32-bit float vector of size " + << num_components; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Pack, VNotFloat32Vector) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string vec_str = + num_components == 2 ? " %f64vec2_01\n" : " %f64vec4_0123\n"; + + std::ostringstream body; + body << "%val1 = OpExtInst %u" << total_bit_width << " %extinst " + << ext_inst_name << vec_str; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected operand V to be a 32-bit float vector of size " + << num_components; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Pack, VWrongSizeVector) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string vec_str = + num_components == 4 ? 
" %f32vec2_01\n" : " %f32vec4_0123\n"; + + std::ostringstream body; + body << "%val1 = OpExtInst %u" << total_bit_width << " %extinst " + << ext_inst_name << vec_str; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected operand V to be a 32-bit float vector of size " + << num_components; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +INSTANTIATE_TEST_CASE_P(AllPack, ValidateGlslStd450Pack, + ::testing::ValuesIn(std::vector{ + "PackSnorm4x8", + "PackUnorm4x8", + "PackSnorm2x16", + "PackUnorm2x16", + "PackHalf2x16", + }), ); + +TEST_F(ValidateExtInst, PackDouble2x32Success) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst PackDouble2x32 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, PackDouble2x32Float32ResultType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst PackDouble2x32 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 PackDouble2x32: expected Result Type to " + "be 64-bit float scalar type")); +} + +TEST_F(ValidateExtInst, PackDouble2x32Int64ResultType) { + const std::string body = R"( +%val1 = OpExtInst %u64 %extinst PackDouble2x32 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 PackDouble2x32: expected Result Type to " + "be 64-bit float scalar type")); +} + +TEST_F(ValidateExtInst, PackDouble2x32VNotVector) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst PackDouble2x32 %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 PackDouble2x32: expected operand V to be " + "a 32-bit int vector of size 2")); +} + +TEST_F(ValidateExtInst, PackDouble2x32VNotIntVector) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst PackDouble2x32 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 PackDouble2x32: expected operand V to be " + "a 32-bit int vector of size 2")); +} + +TEST_F(ValidateExtInst, PackDouble2x32VNotInt32Vector) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst PackDouble2x32 %u64vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 PackDouble2x32: expected operand V to be " + "a 32-bit int vector of size 2")); +} + +TEST_F(ValidateExtInst, PackDouble2x32VWrongSize) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst PackDouble2x32 %u32vec4_0123 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 PackDouble2x32: expected operand V to be " + "a 32-bit int vector of size 2")); +} + +TEST_P(ValidateGlslStd450Unpack, Success) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string result_type_str = + num_components == 2 ? 
"%f32vec2" : " %f32vec4"; + + std::ostringstream body; + body << "%val1 = OpExtInst " << result_type_str << " %extinst " + << ext_inst_name << " %u" << total_bit_width << "_1\n"; + body << "%val2 = OpExtInst " << result_type_str << " %extinst " + << ext_inst_name << " %s" << total_bit_width << "_1\n"; + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateGlslStd450Unpack, ResultTypeNotVector) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string result_type_str = "%f32"; + + std::ostringstream body; + body << "%val1 = OpExtInst " << result_type_str << " %extinst " + << ext_inst_name << " %u" << total_bit_width << "_1\n"; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected Result Type to be a 32-bit float vector of size " + << num_components; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Unpack, ResultTypeNotFloatVector) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string result_type_str = + num_components == 2 ? 
"%u32vec2" : " %u32vec4"; + + std::ostringstream body; + body << "%val1 = OpExtInst " << result_type_str << " %extinst " + << ext_inst_name << " %u" << total_bit_width << "_1\n"; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected Result Type to be a 32-bit float vector of size " + << num_components; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Unpack, ResultTypeNotFloat32Vector) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string result_type_str = + num_components == 2 ? "%f64vec2" : " %f64vec4"; + + std::ostringstream body; + body << "%val1 = OpExtInst " << result_type_str << " %extinst " + << ext_inst_name << " %u" << total_bit_width << "_1\n"; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected Result Type to be a 32-bit float vector of size " + << num_components; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Unpack, ResultTypeWrongSize) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string result_type_str = + num_components == 4 ? 
"%f32vec2" : " %f32vec4"; + + std::ostringstream body; + body << "%val1 = OpExtInst " << result_type_str << " %extinst " + << ext_inst_name << " %u" << total_bit_width << "_1\n"; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected Result Type to be a 32-bit float vector of size " + << num_components; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Unpack, ResultPNotInt) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const std::string result_type_str = + num_components == 2 ? "%f32vec2" : " %f32vec4"; + + std::ostringstream body; + body << "%val1 = OpExtInst " << result_type_str << " %extinst " + << ext_inst_name << " %f" << total_bit_width << "_1\n"; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected operand P to be a " << total_bit_width + << "-bit int scalar"; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +TEST_P(ValidateGlslStd450Unpack, ResultPWrongBitWidth) { + const std::string ext_inst_name = GetParam(); + const uint32_t num_components = GetPackedNumComponents(ext_inst_name); + const uint32_t packed_bit_width = GetPackedBitWidth(ext_inst_name); + const uint32_t total_bit_width = num_components * packed_bit_width; + const uint32_t wrong_bit_width = total_bit_width == 32 ? 64 : 32; + const std::string result_type_str = + num_components == 2 ? 
"%f32vec2" : " %f32vec4"; + + std::ostringstream body; + body << "%val1 = OpExtInst " << result_type_str << " %extinst " + << ext_inst_name << " %u" << wrong_bit_width << "_1\n"; + + std::ostringstream expected; + expected << "GLSL.std.450 " << ext_inst_name + << ": expected operand P to be a " << total_bit_width + << "-bit int scalar"; + + CompileSuccessfully(GenerateShaderCode(body.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected.str())); +} + +INSTANTIATE_TEST_CASE_P(AllUnpack, ValidateGlslStd450Unpack, + ::testing::ValuesIn(std::vector{ + "UnpackSnorm4x8", + "UnpackUnorm4x8", + "UnpackSnorm2x16", + "UnpackUnorm2x16", + "UnpackHalf2x16", + }), ); + +TEST_F(ValidateExtInst, UnpackDouble2x32Success) { + const std::string body = R"( +%val1 = OpExtInst %u32vec2 %extinst UnpackDouble2x32 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, UnpackDouble2x32ResultTypeNotVector) { + const std::string body = R"( +%val1 = OpExtInst %u64 %extinst UnpackDouble2x32 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 UnpackDouble2x32: expected Result Type " + "to be a 32-bit int vector of size 2")); +} + +TEST_F(ValidateExtInst, UnpackDouble2x32ResultTypeNotIntVector) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst UnpackDouble2x32 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 UnpackDouble2x32: expected Result Type " + "to be a 32-bit int vector of size 2")); +} + +TEST_F(ValidateExtInst, UnpackDouble2x32ResultTypeNotInt32Vector) { + const std::string body = R"( +%val1 = OpExtInst %u64vec2 %extinst UnpackDouble2x32 
%f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 UnpackDouble2x32: expected Result Type " + "to be a 32-bit int vector of size 2")); +} + +TEST_F(ValidateExtInst, UnpackDouble2x32ResultTypeWrongSize) { + const std::string body = R"( +%val1 = OpExtInst %u32vec4 %extinst UnpackDouble2x32 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 UnpackDouble2x32: expected Result Type " + "to be a 32-bit int vector of size 2")); +} + +TEST_F(ValidateExtInst, UnpackDouble2x32VNotFloat) { + const std::string body = R"( +%val1 = OpExtInst %u32vec2 %extinst UnpackDouble2x32 %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 UnpackDouble2x32: expected operand V to " + "be a 64-bit float scalar")); +} + +TEST_F(ValidateExtInst, UnpackDouble2x32VNotFloat64) { + const std::string body = R"( +%val1 = OpExtInst %u32vec2 %extinst UnpackDouble2x32 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 UnpackDouble2x32: expected operand V to " + "be a 64-bit float scalar")); +} + +TEST_F(ValidateExtInst, GlslStd450LengthSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Length %f32_1 +%val2 = OpExtInst %f32 %extinst Length %f32vec2_01 +%val3 = OpExtInst %f32 %extinst Length %f32vec4_0123 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450LengthIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst Length 
%f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Length: " + "expected Result Type to be a float scalar type")); +} + +TEST_F(ValidateExtInst, GlslStd450LengthIntX) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Length %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Length: " + "expected operand X to be of float scalar or " + "vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450LengthDifferentType) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst Length %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Length: " + "expected operand X component type to be equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450DistanceSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Distance %f32_0 %f32_1 +%val2 = OpExtInst %f32 %extinst Distance %f32vec2_01 %f32vec2_12 +%val3 = OpExtInst %f32 %extinst Distance %f32vec4_0123 %f32vec4_1234 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450DistanceIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst Distance %f32vec2_01 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Distance: " + "expected Result Type to be a float scalar type")); +} + +TEST_F(ValidateExtInst, GlslStd450DistanceIntP0) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Distance %u32_0 %f32_1 
+)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Distance: " + "expected operand P0 to be of float scalar or " + "vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450DistanceF64VectorP0) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Distance %f64vec2_01 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Distance: " + "expected operand P0 component type to be equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450DistanceIntP1) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Distance %f32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Distance: " + "expected operand P1 to be of float scalar or " + "vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450DistanceF64VectorP1) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Distance %f32vec2_12 %f64vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Distance: " + "expected operand P1 component type to be equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450DistanceDifferentSize) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Distance %f32vec2_01 %f32vec4_0123 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Distance: " + "expected operands P0 and P1 to have the same number " + "of components")); +} + +TEST_F(ValidateExtInst, 
GlslStd450CrossSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32vec3 %extinst Cross %f32vec3_012 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450CrossIntVectorResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32vec3 %extinst Cross %f32vec3_012 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Cross: " + "expected Result Type to be a float vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450CrossResultTypeWrongSize) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Cross %f32vec3_012 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Cross: " + "expected Result Type to have 3 components")); +} + +TEST_F(ValidateExtInst, GlslStd450CrossXWrongType) { + const std::string body = R"( +%val1 = OpExtInst %f32vec3 %extinst Cross %f64vec3_012 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Cross: " + "expected operand X type to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450CrossYWrongType) { + const std::string body = R"( +%val1 = OpExtInst %f32vec3 %extinst Cross %f32vec3_123 %f64vec3_012 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Cross: " + "expected operand Y type to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450RefractSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst Refract %f32_1 
%f32_1 %f32_1 +%val2 = OpExtInst %f32vec2 %extinst Refract %f32vec2_01 %f32vec2_01 %f16_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450RefractIntVectorResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32vec2 %extinst Refract %f32vec2_01 %f32vec2_01 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Refract: " + "expected Result Type to be a float scalar or " + "vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450RefractIntVectorI) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Refract %u32vec2_01 %f32vec2_01 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Refract: " + "expected operand I to be of type equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450RefractIntVectorN) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Refract %f32vec2_01 %u32vec2_01 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Refract: " + "expected operand N to be of type equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450RefractIntEta) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Refract %f32vec2_01 %f32vec2_01 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Refract: " + "expected operand Eta to be a float scalar")); +} + +TEST_F(ValidateExtInst, GlslStd450RefractFloat64Eta) { + // SPIR-V issue 337: Eta can be 64-bit float 
scalar. + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Refract %f32vec2_01 %f32vec2_01 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateExtInst, GlslStd450RefractVectorEta) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst Refract %f32vec2_01 %f32vec2_01 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 Refract: " + "expected operand Eta to be a float scalar")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %f32_input +%val2 = OpExtInst %f32vec2 %extinst InterpolateAtCentroid %f32vec2_input +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidNoCapability) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %f32_input +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtCentroid requires " + "capability InterpolationFunction")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst InterpolateAtCentroid %f32_input +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtCentroid: " + "expected Result Type to be a 32-bit float scalar " + "or vector 
type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidF64ResultType) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst InterpolateAtCentroid %f32_input +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtCentroid: " + "expected Result Type to be a 32-bit float scalar " + "or vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidNotPointer) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %f32_1 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtCentroid: " + "expected Interpolant to be a pointer")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidWrongDataType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %f32vec2_input +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtCentroid: " + "expected Interpolant data type to be equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidWrongStorageClass) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %f32_output +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtCentroid: " + "expected Interpolant storage class to be Input")); +} + +TEST_F(ValidateExtInst, 
GlslStd450InterpolateAtCentroidWrongExecutionModel) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %f32_input +)"; + + CompileSuccessfully(GenerateShaderCode( + body, "OpCapability InterpolationFunction\n", "Vertex")); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtCentroid requires " + "Fragment execution model")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32_input %u32_1 +%val2 = OpExtInst %f32vec2 %extinst InterpolateAtSample %f32vec2_input %u32_1 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleNoCapability) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32_input %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample requires " + "capability InterpolationFunction")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst InterpolateAtSample %f32_input %u32_1 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample: " + "expected Result Type to be a 32-bit float scalar " + "or vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleF64ResultType) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst InterpolateAtSample %f32_input %u32_1 +)"; + + CompileSuccessfully( + 
GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample: " + "expected Result Type to be a 32-bit float scalar " + "or vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleNotPointer) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32_1 %u32_1 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample: " + "expected Interpolant to be a pointer")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleWrongDataType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32vec2_input %u32_1 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample: " + "expected Interpolant data type to be equal to " + "Result Type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleWrongStorageClass) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32_output %u32_1 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample: " + "expected Interpolant storage class to be Input")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleFloatSample) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32_input %f32_1 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability 
InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample: " + "expected Sample to be 32-bit integer")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleU64Sample) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32_input %u64_1 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample: " + "expected Sample to be 32-bit integer")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleWrongExecutionModel) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32_input %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode( + body, "OpCapability InterpolationFunction\n", "Vertex")); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtSample requires " + "Fragment execution model")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_input %f32vec2_01 +%val2 = OpExtInst %f32vec2 %extinst InterpolateAtOffset %f32vec2_input %f32vec2_01 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetNoCapability) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_input %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset requires " + "capability 
InterpolationFunction")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst InterpolateAtOffset %f32_input %f32vec2_01 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Result Type to be a 32-bit float scalar " + "or vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetF64ResultType) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst InterpolateAtOffset %f32_input %f32vec2_01 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Result Type to be a 32-bit float scalar " + "or vector type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetNotPointer) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_1 %f32vec2_01 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Interpolant to be a pointer")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetWrongDataType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32vec2_input %f32vec2_01 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Interpolant data type to be equal to " + "Result 
Type")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetWrongStorageClass) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_output %f32vec2_01 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Interpolant storage class to be Input")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetOffsetNotVector) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_input %f32_0 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Offset to be a vector of 2 32-bit floats")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetOffsetNotVector2) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_input %f32vec3_012 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Offset to be a vector of 2 32-bit floats")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetOffsetNotFloatVector) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_input %u32vec2_01 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Offset to be a vector of 2 32-bit floats")); +} + +TEST_F(ValidateExtInst, 
GlslStd450InterpolateAtOffsetOffsetNotFloat32Vector) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_input %f64vec2_01 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, "OpCapability InterpolationFunction\n")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset: " + "expected Offset to be a vector of 2 32-bit floats")); +} + +TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetWrongExecutionModel) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_input %f32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode( + body, "OpCapability InterpolationFunction\n", "Vertex")); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("GLSL.std.450 InterpolateAtOffset requires " + "Fragment execution model")); +} + +TEST_P(ValidateOpenCLStdSqrtLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name << " %f32_0\n"; + ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01\n"; + ss << "%val3 = OpExtInst %f32vec4 %extinst " << ext_inst_name + << " %f32vec4_0123\n"; + ss << "%val4 = OpExtInst %f64 %extinst " << ext_inst_name << " %f64_0\n"; + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdSqrtLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be a float scalar " + "or vector type")); +} + 
+TEST_P(ValidateOpenCLStdSqrtLike, IntOperand) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to " + "Result Type")); +} + +INSTANTIATE_TEST_CASE_P( + AllSqrtLike, ValidateOpenCLStdSqrtLike, + ::testing::ValuesIn(std::vector{ + "acos", "acosh", "acospi", "asin", + "asinh", "asinpi", "atan", "atanh", + "atanpi", "cbrt", "ceil", "cos", + "cosh", "cospi", "erfc", "erf", + "exp", "exp2", "exp10", "expm1", + "fabs", "floor", "log", "log2", + "log10", "log1p", "logb", "rint", + "round", "rsqrt", "sin", "sinh", + "sinpi", "sqrt", "tan", "tanh", + "tanpi", "tgamma", "trunc", "half_cos", + "half_exp", "half_exp2", "half_exp10", "half_log", + "half_log2", "half_log10", "half_recip", "half_rsqrt", + "half_sin", "half_sqrt", "half_tan", "lgamma", + "native_cos", "native_exp", "native_exp2", "native_exp10", + "native_log", "native_log2", "native_log10", "native_recip", + "native_rsqrt", "native_sin", "native_sqrt", "native_tan", + "degrees", "radians", "sign", + }), ); + +TEST_P(ValidateOpenCLStdFMinLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %f32_1\n"; + ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01 %f32vec2_12\n"; + ss << "%val3 = OpExtInst %f64 %extinst " << ext_inst_name + << " %f64_0 %f64_0\n"; + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdFMinLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_0 
%f32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be a float scalar " + "or vector type")); +} + +TEST_P(ValidateOpenCLStdFMinLike, IntOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0 %f32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to " + "Result Type")); +} + +TEST_P(ValidateOpenCLStdFMinLike, IntOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %f32_0 %u32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to " + "Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllFMinLike, ValidateOpenCLStdFMinLike, + ::testing::ValuesIn(std::vector{ + "atan2", "atan2pi", "copysign", + "fdim", "fmax", "fmin", + "fmod", "maxmag", "minmag", + "hypot", "nextafter", "pow", + "powr", "remainder", "half_divide", + "half_powr", "native_divide", "native_powr", + "step", "fmax_common", "fmin_common", + }), ); + +TEST_P(ValidateOpenCLStdFClampLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %f32_1 %f32_2\n"; + ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01 %f32vec2_01 %f32vec2_12\n"; + ss << "%val3 = OpExtInst %f64 %extinst " << ext_inst_name + << " %f64_0 %f64_0 %f64_1\n"; + 
CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdFClampLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + + " %f32_0 %f32_1 %f32_2\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be a float scalar " + "or vector type")); +} + +TEST_P(ValidateOpenCLStdFClampLike, IntOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + + " %u32_0 %f32_0 %f32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to " + "Result Type")); +} + +TEST_P(ValidateOpenCLStdFClampLike, IntOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + + " %f32_0 %u32_0 %f32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to " + "Result Type")); +} + +TEST_P(ValidateOpenCLStdFClampLike, IntOperand3) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + + " %f32_1 %f32_0 %u32_2\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to " 
+ "Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllFClampLike, ValidateOpenCLStdFClampLike, + ::testing::ValuesIn(std::vector{ + "fma", + "mad", + "fclamp", + "mix", + "smoothstep", + }), ); + +TEST_P(ValidateOpenCLStdSAbsLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u32 %extinst " << ext_inst_name << " %u32_1\n"; + ss << "%val2 = OpExtInst %u32 %extinst " << ext_inst_name << " %u32_1\n"; + ss << "%val3 = OpExtInst %u32 %extinst " << ext_inst_name << " %u32_1\n"; + ss << "%val4 = OpExtInst %u32 %extinst " << ext_inst_name << " %u32_1\n"; + ss << "%val5 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01\n"; + ss << "%val6 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01\n"; + ss << "%val7 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01\n"; + ss << "%val8 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01\n"; + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdSAbsLike, FloatResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be an int scalar " + "or vector type")); +} + +TEST_P(ValidateOpenCLStdSAbsLike, FloatOperand) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + 
+TEST_P(ValidateOpenCLStdSAbsLike, U64Operand) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %u64_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllSAbsLike, ValidateOpenCLStdSAbsLike, + ::testing::ValuesIn(std::vector{ + "s_abs", + "clz", + "ctz", + "popcount", + "u_abs", + }), ); + +TEST_P(ValidateOpenCLStdUMinLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val2 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val3 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val4 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val5 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + ss << "%val6 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + ss << "%val7 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + ss << "%val8 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + ss << "%val9 = OpExtInst %u64 %extinst " << ext_inst_name + << " %u64_1 %u64_0\n"; + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdUMinLike, FloatResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0 %u32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be an int scalar " + "or vector type")); +} + +TEST_P(ValidateOpenCLStdUMinLike, FloatOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_0 %u32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUMinLike, FloatOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %u32_0 %f32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUMinLike, U64Operand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %u64_0 %u32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUMinLike, U64Operand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %u32_0 %u64_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be 
equal to Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllUMinLike, ValidateOpenCLStdUMinLike, + ::testing::ValuesIn(std::vector{ + "s_max", + "u_max", + "s_min", + "u_min", + "s_abs_diff", + "s_add_sat", + "u_add_sat", + "s_mul_hi", + "rotate", + "s_sub_sat", + "u_sub_sat", + "s_hadd", + "u_hadd", + "s_rhadd", + "u_rhadd", + "u_abs_diff", + "u_mul_hi", + }), ); + +TEST_P(ValidateOpenCLStdUClampLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_0 %u32_1 %u32_2\n"; + ss << "%val2 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_0 %u32_1 %u32_2\n"; + ss << "%val3 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_0 %u32_1 %u32_2\n"; + ss << "%val4 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_0 %u32_1 %u32_2\n"; + ss << "%val5 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01 %u32vec2_12\n"; + ss << "%val6 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01 %u32vec2_12\n"; + ss << "%val7 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01 %u32vec2_12\n"; + ss << "%val8 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01 %u32vec2_12\n"; + ss << "%val9 = OpExtInst %u64 %extinst " << ext_inst_name + << " %u64_1 %u64_0 %u64_1\n"; + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdUClampLike, FloatResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + + " %u32_0 %u32_0 %u32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be an int scalar " + "or vector type")); +} + 
+TEST_P(ValidateOpenCLStdUClampLike, FloatOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + + " %f32_0 %u32_0 %u32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUClampLike, FloatOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + + " %u32_0 %f32_0 %u32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUClampLike, FloatOperand3) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + + " %u32_0 %u32_0 %f32_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUClampLike, U64Operand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + + " %f32_0 %u32_0 %u64_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUClampLike, U64Operand2) { + const std::string ext_inst_name 
= GetParam(); + const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + + " %u32_0 %f32_0 %u64_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUClampLike, U64Operand3) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + + " %u32_0 %u32_0 %u64_1\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllUClampLike, ValidateOpenCLStdUClampLike, + ::testing::ValuesIn(std::vector{ + "s_clamp", + "u_clamp", + "s_mad_hi", + "u_mad_sat", + "s_mad_sat", + "u_mad_hi", + }), ); + +// ------------------------------------------------------------- +TEST_P(ValidateOpenCLStdUMul24Like, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val2 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val3 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val4 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val5 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + ss << "%val6 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + ss << "%val7 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + ss << "%val8 = OpExtInst %u32vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + 
CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdUMul24Like, FloatResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32_0 %u32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std " + ext_inst_name + + ": expected Result Type to be a 32-bit int scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdUMul24Like, U64ResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u64 %extinst " + ext_inst_name + " %u64_0 %u64_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std " + ext_inst_name + + ": expected Result Type to be a 32-bit int scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdUMul24Like, FloatOperand1) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_0 %u32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdUMul24Like, FloatOperand2) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %u32_0 %f32_0\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected types of all operands to be equal to Result 
Type"));
+}
+
+TEST_P(ValidateOpenCLStdUMul24Like, U64Operand1) {  // operand 1 is u64 while Result Type is u32
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %u64_0 %u32_0\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": expected types of all operands to be equal to Result Type"));
+}
+
+TEST_P(ValidateOpenCLStdUMul24Like, U64Operand2) {  // operand 2 is u64 while Result Type is u32
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %u32_0 %u64_0\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": expected types of all operands to be equal to Result Type"));
+}
+
+INSTANTIATE_TEST_CASE_P(AllUMul24Like, ValidateOpenCLStdUMul24Like,  // runs the suite once per name below
+                        ::testing::ValuesIn(std::vector<std::string>{
+                            "s_mul24",
+                            "u_mul24",
+                        }), );
+
+TEST_P(ValidateOpenCLStdUMad24Like, Success) {  // u32 scalar and u32vec2 forms must validate
+  const std::string ext_inst_name = GetParam();
+  std::ostringstream ss;
+  ss << "%val1 = OpExtInst %u32 %extinst " << ext_inst_name
+     << " %u32_0 %u32_1 %u32_2\n";
+  ss << "%val2 = OpExtInst %u32 %extinst " << ext_inst_name
+     << " %u32_0 %u32_1 %u32_2\n";
+  ss << "%val3 = OpExtInst %u32 %extinst " << ext_inst_name
+     << " %u32_0 %u32_1 %u32_2\n";
+  ss << "%val4 = OpExtInst %u32 %extinst " << ext_inst_name
+     << " %u32_0 %u32_1 %u32_2\n";
+  ss << "%val5 = OpExtInst %u32vec2 %extinst " << ext_inst_name
+     << " %u32vec2_01 %u32vec2_01 %u32vec2_12\n";
+  ss << "%val6 = OpExtInst %u32vec2 %extinst " << ext_inst_name
+     << " %u32vec2_01 %u32vec2_01 %u32vec2_12\n";
+  ss << "%val7 = OpExtInst %u32vec2 %extinst " << ext_inst_name
+     << " %u32vec2_01 %u32vec2_01 %u32vec2_12\n";
+  ss << "%val8 = OpExtInst %u32vec2 %extinst " << ext_inst_name
+     << " %u32vec2_01 %u32vec2_01 %u32vec2_12\n";
+  CompileSuccessfully(GenerateKernelCode(ss.str()));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateOpenCLStdUMad24Like, FloatResultType) {  // Result Type is f32
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name +
+                           " %u32_0 %u32_0 %u32_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "OpenCL.std " + ext_inst_name +
+          ": expected Result Type to be a 32-bit int scalar or vector type"));
+}
+
+TEST_P(ValidateOpenCLStdUMad24Like, U64ResultType) {  // Result Type and all operands are u64
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u64 %extinst " + ext_inst_name +
+                           " %u64_0 %u64_0 %u64_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "OpenCL.std " + ext_inst_name +
+          ": expected Result Type to be a 32-bit int scalar or vector type"));
+}
+
+TEST_P(ValidateOpenCLStdUMad24Like, FloatOperand1) {  // only operand 1 is f32
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name +
+                           " %f32_0 %u32_0 %u32_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": expected types of all operands to be equal to Result Type"));
+}
+
+TEST_P(ValidateOpenCLStdUMad24Like, FloatOperand2) {  // only operand 2 is f32
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name +
+                           " %u32_0 %f32_0 %u32_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": expected types of all operands to be equal to Result Type"));
+}
+
+TEST_P(ValidateOpenCLStdUMad24Like, FloatOperand3) {  // only operand 3 is f32
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name +
+                           " %u32_0 %u32_0 %f32_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": expected types of all operands to be equal to Result Type"));
+}
+
+TEST_P(ValidateOpenCLStdUMad24Like, U64Operand1) {  // only operand 1 is u64, so the mismatch is isolated to it
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name +
+                           " %u64_0 %u32_0 %u32_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": expected types of all operands to be equal to Result Type"));
+}
+
+TEST_P(ValidateOpenCLStdUMad24Like, U64Operand2) {  // only operand 2 is u64, so the mismatch is isolated to it
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name +
+                           " %u32_0 %u64_0 %u32_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": expected types of all operands to be equal to Result Type"));
+}
+
+TEST_P(ValidateOpenCLStdUMad24Like, U64Operand3) {  // only operand 3 is u64
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name +
+                           " %u32_0 %u32_0 %u64_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": expected types of all operands to be equal to Result Type"));
+}
+
+INSTANTIATE_TEST_CASE_P(AllUMad24Like, ValidateOpenCLStdUMad24Like,  // runs the suite once per name below
+                        ::testing::ValuesIn(std::vector<std::string>{
+                            "s_mad24",
+                            "u_mad24",
+                        }), );
+
+TEST_F(ValidateExtInst, OpenCLStdCrossSuccess) {  // 3- and 4-component float vectors must validate
+  const std::string body = R"(
+%val1 = OpExtInst %f32vec3 %extinst cross %f32vec3_012 %f32vec3_123
+%val2 = OpExtInst %f32vec4 %extinst cross %f32vec4_0123 %f32vec4_0123
+)";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateExtInst, OpenCLStdCrossIntVectorResultType) {  // Result Type is an int vector
+  const std::string body = R"(
+%val1 = OpExtInst %u32vec3 %extinst cross %f32vec3_012 %f32vec3_123
+)";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpenCL.std cross: "
+                        "expected Result Type to be a float vector type"));
+}
+
+TEST_F(ValidateExtInst, OpenCLStdCrossResultTypeWrongSize) {  // Result Type has 2 components
+  const std::string body = R"(
+%val1 = OpExtInst %f32vec2 %extinst cross %f32vec3_012 %f32vec3_123
+)";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpenCL.std cross: "
+                        "expected Result Type to have 3 or 4 components"));
+}
+
+TEST_F(ValidateExtInst, OpenCLStdCrossXWrongType) {  // first operand is f64vec3, Result Type is f32vec3
+  const std::string body = R"(
+%val1 = OpExtInst %f32vec3 %extinst cross %f64vec3_012 %f32vec3_123
+)";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpenCL.std cross: "
+                        "expected operand X type to be equal to Result Type"));
+}
+
+TEST_F(ValidateExtInst, OpenCLStdCrossYWrongType) {  // second operand is f64vec3, Result Type is f32vec3
+  const std::string body = R"(
+%val1 = OpExtInst %f32vec3 %extinst cross %f32vec3_123 %f64vec3_012
+)";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpenCL.std cross: "
+                        "expected operand Y type to be equal to Result Type"));
+}
+
+TEST_P(ValidateOpenCLStdLengthLike, Success) {  // f32 result from 2- and 4-component f32 vectors
+  const std::string ext_inst_name = GetParam();
+  std::ostringstream ss;
+  ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name << " %f32vec2_01\n";
+  ss << "%val2 = OpExtInst %f32 %extinst " << ext_inst_name
+     << " %f32vec4_0123\n";
+
+  CompileSuccessfully(GenerateKernelCode(ss.str()));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateOpenCLStdLengthLike, IntResultType) {  // Result Type is u32
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32vec2_01\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpenCL.std " + ext_inst_name +
+                        ": "
+                        "expected Result Type to be a float scalar type"));
+}
+
+TEST_P(ValidateOpenCLStdLengthLike, IntX) {  // the operand is a u32 vector
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + " %u32vec2_01\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpenCL.std " + ext_inst_name +
+                        ": "
+                        "expected operand P to be a float scalar or vector"));
+}
+
+TEST_P(ValidateOpenCLStdLengthLike, VectorTooBig) {  // the operand has 8 components
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name +
+                           " %f32vec8_01010101\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": "
+                "expected operand P to have no more than 4 components"));
+}
+
+TEST_P(ValidateOpenCLStdLengthLike, DifferentType) {  // f64 Result Type with an f32 vector operand
+  const std::string ext_inst_name = GetParam();
+  const std::string body =
+      "%val1 = OpExtInst %f64 %extinst " + ext_inst_name + " %f32vec2_01\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpenCL.std " + ext_inst_name +
+                        ": "
+                        "expected operand P component type to be equal to "
+                        "Result Type"));
+}
+
+INSTANTIATE_TEST_CASE_P(AllLengthLike, ValidateOpenCLStdLengthLike,  // runs the suite once per name below
+                        ::testing::ValuesIn(std::vector<std::string>{
+                            "length",
+                            "fast_length",
+                        }), );
+
+TEST_P(ValidateOpenCLStdDistanceLike, Success) {  // matching f32 vector pairs and an f32 scalar pair
+  const std::string ext_inst_name = GetParam();
+  std::ostringstream ss;
+  ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name
+     << " %f32vec2_01 %f32vec2_01\n";
+  ss << "%val2 = OpExtInst %f32 %extinst " << ext_inst_name
+     << " %f32vec4_0123 %f32vec4_1234\n";
+  ss << "%val3 = OpExtInst %f32 %extinst " << ext_inst_name
+     << " %f32_0 %f32_1\n";
+
+  CompileSuccessfully(GenerateKernelCode(ss.str()));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateOpenCLStdDistanceLike, IntResultType) {  // Result Type is u32
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %u32 %extinst " + ext_inst_name +
+                           " %f32vec2_01 %f32vec2_12\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpenCL.std " + ext_inst_name +
+                        ": "
+                        "expected Result Type to be a float scalar type"));
+}
+
+TEST_P(ValidateOpenCLStdDistanceLike, IntP0) {  // first operand is a u32 vector
+  const std::string ext_inst_name = GetParam();
+  const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name +
+                           " %u32vec2_01 %f32vec2_12\n";
+
+  CompileSuccessfully(GenerateKernelCode(body));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpenCL.std " + ext_inst_name +
+                ": "
+                "expected operand P0 to be of float scalar or vector type"));
+}
+
+TEST_P(ValidateOpenCLStdDistanceLike, VectorTooBig) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + + " %f32vec8_01010101 %f32vec8_01010101\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": " + "expected operand P0 to have no more than 4 components")); +} + +TEST_P(ValidateOpenCLStdDistanceLike, F64P0) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f32 %extinst " + ext_inst_name + + " %f64vec2_01 %f32vec2_12\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std " + ext_inst_name + + ": " + "expected operand P0 component type to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdDistanceLike, DifferentOperands) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f64 %extinst " + ext_inst_name + + " %f64vec2_01 %f32vec2_12\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": " + "expected operands P0 and P1 to be of the same type")); +} + +INSTANTIATE_TEST_CASE_P(AllDistanceLike, ValidateOpenCLStdDistanceLike, + ::testing::ValuesIn(std::vector{ + "distance", + "fast_distance", + }), ); + +TEST_P(ValidateOpenCLStdNormalizeLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01\n"; + ss << "%val2 = OpExtInst %f32vec4 %extinst " << ext_inst_name + << " %f32vec4_0123\n"; + ss << "%val3 = OpExtInst %f32 %extinst " << ext_inst_name << " %f32_2\n"; + + 
CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdNormalizeLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %u32 %extinst " + ext_inst_name + " %f32_2\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": " + "expected Result Type to be a float scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdNormalizeLike, VectorTooBig) { + const std::string ext_inst_name = GetParam(); + const std::string body = "%val1 = OpExtInst %f32vec8 %extinst " + + ext_inst_name + " %f32vec8_01010101\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": " + "expected Result Type to have no more than 4 components")); +} + +TEST_P(ValidateOpenCLStdNormalizeLike, DifferentType) { + const std::string ext_inst_name = GetParam(); + const std::string body = + "%val1 = OpExtInst %f64vec2 %extinst " + ext_inst_name + " %f32vec2_01\n"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": " + "expected operand P type to be equal to Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllNormalizeLike, ValidateOpenCLStdNormalizeLike, + ::testing::ValuesIn(std::vector{ + "normalize", + "fast_normalize", + }), ); + +TEST_F(ValidateExtInst, OpenCLStdBitselectSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst bitselect %f32_2 %f32_1 %f32_1 +%val2 = OpExtInst %f32vec4 %extinst bitselect %f32vec4_0123 %f32vec4_1234 %f32vec4_0123 +%val3 = OpExtInst %u32 %extinst bitselect %u32_2 %u32_1 %u32_1 +%val4 
= OpExtInst %u32vec4 %extinst bitselect %u32vec4_0123 %u32vec4_0123 %u32vec4_0123 +%val5 = OpExtInst %u64 %extinst bitselect %u64_2 %u64_1 %u64_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdBitselectWrongResultType) { + const std::string body = R"( +%val3 = OpExtInst %struct_f32_f32 %extinst bitselect %u32_2 %u32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std bitselect: " + "expected Result Type to be an int or float scalar or vector type")); +} + +TEST_F(ValidateExtInst, OpenCLStdBitselectAWrongType) { + const std::string body = R"( +%val3 = OpExtInst %u32 %extinst bitselect %f32_2 %u32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std bitselect: " + "expected types of all operands to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdBitselectBWrongType) { + const std::string body = R"( +%val3 = OpExtInst %u32 %extinst bitselect %u32_2 %f32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std bitselect: " + "expected types of all operands to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdBitselectCWrongType) { + const std::string body = R"( +%val3 = OpExtInst %u32 %extinst bitselect %u32_2 %u32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std bitselect: " + "expected types of all operands to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdSelectSuccess) { + 
const std::string body = R"( +%val1 = OpExtInst %f32 %extinst select %f32_2 %f32_1 %u32_1 +%val2 = OpExtInst %f32vec4 %extinst select %f32vec4_0123 %f32vec4_1234 %u32vec4_0123 +%val3 = OpExtInst %u32 %extinst select %u32_2 %u32_1 %u32_1 +%val4 = OpExtInst %u32vec4 %extinst select %u32vec4_0123 %u32vec4_0123 %u32vec4_0123 +%val5 = OpExtInst %u64 %extinst select %u64_2 %u64_1 %u64_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdSelectWrongResultType) { + const std::string body = R"( +%val3 = OpExtInst %struct_f32_f32 %extinst select %u32_2 %u32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std select: " + "expected Result Type to be an int or float scalar or vector type")); +} + +TEST_F(ValidateExtInst, OpenCLStdSelectAWrongType) { + const std::string body = R"( +%val3 = OpExtInst %u32 %extinst select %f32_2 %u32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std select: " + "expected operand A type to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdSelectBWrongType) { + const std::string body = R"( +%val3 = OpExtInst %u32 %extinst select %u32_2 %f32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std select: " + "expected operand B type to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdSelectCWrongType) { + const std::string body = R"( +%val3 = OpExtInst %f32 %extinst select %f32_2 %f32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std select: " + "expected operand C to be an int scalar or vector")); +} + +TEST_F(ValidateExtInst, OpenCLStdSelectCWrongComponentNumber) { + const std::string body = R"( +%val3 = OpExtInst %f32vec2 %extinst select %f32vec2_12 %f32vec2_01 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std select: " + "expected operand C to have the same number of " + "components as Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdSelectCWrongBitWidth) { + const std::string body = R"( +%val3 = OpExtInst %f32vec2 %extinst select %f32vec2_12 %f32vec2_01 %u64vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std select: " + "expected operand C to have the same bit width as Result Type")); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, SuccessPhysical32) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? 
" RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32_1 %u32_1 %ptr" << rounding_mode << "\n"; + ss << "%val2 = OpExtInst %void %extinst " << ext_inst_name + << " %f64_0 %u32_2 %ptr" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec2_01 %u32_1 %ptr" << rounding_mode << "\n"; + ss << "%val2 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec4_0123 %u32_0 %ptr" << rounding_mode << "\n"; + ss << "%val3 = OpExtInst %void %extinst " << ext_inst_name + << " %f64vec2_01 %u32_2 %ptr" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, SuccessPhysical64) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? 
" RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32_1 %u64_1 %ptr" << rounding_mode << "\n"; + ss << "%val2 = OpExtInst %void %extinst " << ext_inst_name + << " %f64_0 %u64_2 %ptr" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec2_01 %u64_1 %ptr" << rounding_mode << "\n"; + ss << "%val2 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec4_0123 %u64_0 %ptr" << rounding_mode << "\n"; + ss << "%val3 = OpExtInst %void %extinst " << ext_inst_name + << " %f64vec2_01 %u64_2 %ptr" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, NonVoidResultType) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? " RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_1 %u32_1 %ptr" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32vec2_01 %u32_1 %ptr" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be void")); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, WrongDataType) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? 
" RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f64vec2_01 %u32_1 %ptr" << rounding_mode << "\n"; + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Data to be a 32 or 64-bit float scalar")); + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f64_0 %u32_1 %ptr" << rounding_mode << "\n"; + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Data to be a 32 or 64-bit float vector")); + } +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, AddressingModelLogical) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? 
" RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32_0 %u32_1 %ptr" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec2_01 %u32_1 %ptr" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Logical")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + " can only be used with physical addressing models")); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, OffsetNotSizeT) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? " RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32_0 %u32_1 %ptr" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec2_01 %u32_1 %ptr" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": " + "expected operand Offset to be of type size_t (64-bit integer " + "for the addressing model used in the module)")); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, PNotPointer) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? 
" RTE" : ""; + + std::ostringstream ss; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32_0 %u32_1 %f16_ptr_workgroup" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec2_01 %u32_1 %f16_ptr_workgroup" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 89[%_ptr_Workgroup_half] cannot be a type")); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, ConstPointer) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? " RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_uniform_constant " + "%f16vec8_uniform_constant %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32_0 %u32_1 %ptr" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec2_01 %u32_1 %ptr" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected operand P storage class to be Generic, " + "CrossWorkgroup, Workgroup or Function")); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, PDataTypeInt) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? 
" RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %u32_ptr_workgroup %u32vec8_workgroup %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32_0 %u32_1 %ptr" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec2_01 %u32_1 %ptr" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected operand P data type to be 16-bit float scalar")); +} + +TEST_P(ValidateOpenCLStdVStoreHalfLike, PDataTypeFloat32) { + const std::string ext_inst_name = GetParam(); + const std::string rounding_mode = + ext_inst_name.substr(ext_inst_name.length() - 2) == "_r" ? " RTE" : ""; + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + if (std::string::npos == ext_inst_name.find("halfn")) { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32_0 %u32_1 %ptr" << rounding_mode << "\n"; + } else { + ss << "%val1 = OpExtInst %void %extinst " << ext_inst_name + << " %f32vec2_01 %u32_1 %ptr" << rounding_mode << "\n"; + } + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected operand P data type to be 16-bit float scalar")); +} + +INSTANTIATE_TEST_CASE_P(AllVStoreHalfLike, ValidateOpenCLStdVStoreHalfLike, + ::testing::ValuesIn(std::vector{ + "vstore_half", + "vstore_half_r", + "vstore_halfn", + "vstore_halfn_r", + "vstorea_halfn", + "vstorea_halfn_r", + }), ); + +TEST_P(ValidateOpenCLStdVLoadHalfLike, SuccessPhysical32) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain 
%f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u32_1 %ptr 2\n"; + ss << "%val2 = OpExtInst %f32vec3 %extinst " << ext_inst_name + << " %u32_1 %ptr 3\n"; + ss << "%val3 = OpExtInst %f32vec4 %extinst " << ext_inst_name + << " %u32_1 %ptr 4\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, SuccessPhysical64) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u64_1 %ptr 2\n"; + ss << "%val2 = OpExtInst %f32vec3 %extinst " << ext_inst_name + << " %u64_1 %ptr 3\n"; + ss << "%val3 = OpExtInst %f32vec4 %extinst " << ext_inst_name + << " %u64_1 %ptr 4\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, ResultTypeNotFloatVector) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %u32_1 %ptr 1\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be a float vector type")); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, AddressingModelLogical) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u32_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Logical")); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + " can only be used with physical addressing models")); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, OffsetNotSizeT) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u64_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected operand Offset to be of type size_t (32-bit " + "integer for the addressing model used in the module)")); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, PNotPointer) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u32_1 %f16_ptr_workgroup 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 89[%_ptr_Workgroup_half] cannot be a type")); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, OffsetWrongStorageType) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_input %f16vec8_input %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u32_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected operand P storage class to be UniformConstant, " + "Generic, CrossWorkgroup, Workgroup or Function")); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, PDataTypeInt) { + const std::string ext_inst_name = 
GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %u32_ptr_workgroup %u32vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u32_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected operand P data type to be 16-bit float scalar")); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, PDataTypeFloat32) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u32_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected operand P data type to be 16-bit float scalar")); +} + +TEST_P(ValidateOpenCLStdVLoadHalfLike, WrongN) { + const std::string ext_inst_name = GetParam(); + + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_workgroup %f16vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %u32_1 %ptr 3\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected literal N to be equal to the number of " + "components of Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllVLoadHalfLike, ValidateOpenCLStdVLoadHalfLike, + ::testing::ValuesIn(std::vector{ + "vload_halfn", + "vloada_halfn", + }), ); + +TEST_F(ValidateExtInst, VLoadNSuccessFloatPhysical32) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst 
%f32vec2 %extinst vloadn %u32_1 %ptr 2\n"; + ss << "%val2 = OpExtInst %f32vec3 %extinst vloadn %u32_1 %ptr 3\n"; + ss << "%val3 = OpExtInst %f32vec4 %extinst vloadn %u32_1 %ptr 4\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VLoadNSuccessIntPhysical32) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %u32_ptr_uniform_constant " + "%u32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %u32vec2 %extinst vloadn %u32_1 %ptr 2\n"; + ss << "%val2 = OpExtInst %u32vec3 %extinst vloadn %u32_1 %ptr 3\n"; + ss << "%val3 = OpExtInst %u32vec4 %extinst vloadn %u32_1 %ptr 4\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VLoadNSuccessFloatPhysical64) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst vloadn %u64_1 %ptr 2\n"; + ss << "%val2 = OpExtInst %f32vec3 %extinst vloadn %u64_1 %ptr 3\n"; + ss << "%val3 = OpExtInst %f32vec4 %extinst vloadn %u64_1 %ptr 4\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VLoadNSuccessIntPhysical64) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %u32_ptr_uniform_constant " + "%u32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %u32vec2 %extinst vloadn %u64_1 %ptr 2\n"; + ss << "%val2 = OpExtInst %u32vec3 %extinst vloadn %u64_1 %ptr 3\n"; + ss << "%val3 = OpExtInst %u32vec4 %extinst vloadn %u64_1 %ptr 4\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VLoadNWrongResultType) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + 
ss << "%val1 = OpExtInst %f32 %extinst vloadn %u32_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std vloadn: " + "expected Result Type to be an int or float vector type")); +} + +TEST_F(ValidateExtInst, VLoadNAddressingModelLogical) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst vloadn %u32_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Logical")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vloadn can only be used with physical " + "addressing models")); +} + +TEST_F(ValidateExtInst, VLoadNOffsetNotSizeT) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst vloadn %u64_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std vloadn: expected operand Offset to be of type size_t " + "(32-bit integer for the addressing model used in the module)")); +} + +TEST_F(ValidateExtInst, VLoadNPNotPointer) { + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32vec2 %extinst vloadn %u32_1 " + "%f32_ptr_uniform_constant 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 120[%_ptr_UniformConstant_float] cannot be a " + "type")); +} + +TEST_F(ValidateExtInst, VLoadNWrongStorageClass) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %u32_ptr_workgroup %u32vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %u32vec2 %extinst vloadn 
%u32_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vloadn: expected operand P storage class " + "to be UniformConstant or Generic")); +} + +TEST_F(ValidateExtInst, VLoadNWrongComponentType) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %u32vec2 %extinst vloadn %u32_1 %ptr 2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vloadn: expected operand P data type to be " + "equal to component type of Result Type")); +} + +TEST_F(ValidateExtInst, VLoadNWrongN) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst vloadn %u32_1 %ptr 3\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vloadn: expected literal N to be equal to " + "the number of components of Result Type")); +} + +TEST_F(ValidateExtInst, VLoadHalfSuccessPhysical32) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_uniform_constant " + "%f16vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst vload_half %u32_1 %ptr\n"; + ss << "%val2 = OpExtInst %f64 %extinst vload_half %u32_1 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VLoadHalfSuccessPhysical64) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_uniform_constant " + "%f16vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst vload_half %u64_1 %ptr\n"; + ss << "%val2 = 
OpExtInst %f64 %extinst vload_half %u64_1 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VLoadHalfWrongResultType) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_uniform_constant " + "%f16vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %u32 %extinst vload_half %u32_1 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vload_half: " + "expected Result Type to be a float scalar type")); +} + +TEST_F(ValidateExtInst, VLoadHalfAddressingModelLogical) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_uniform_constant " + "%f16vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst vload_half %u32_1 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Logical")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vload_half can only be used with physical " + "addressing models")); +} + +TEST_F(ValidateExtInst, VLoadHalfOffsetNotSizeT) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_uniform_constant " + "%f16vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst vload_half %u64_1 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std vload_half: expected operand Offset to be of type size_t " + "(32-bit integer for the addressing model used in the module)")); +} + +TEST_F(ValidateExtInst, VLoadHalfPNotPointer) { + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst vload_half %u32_1 " + "%f16_ptr_uniform_constant\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + 
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 114[%_ptr_UniformConstant_half] cannot be a " + "type")); +} + +TEST_F(ValidateExtInst, VLoadHalfWrongStorageClass) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f16_ptr_input %f16vec8_input %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst vload_half %u32_1 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std vload_half: expected operand P storage class to be " + "UniformConstant, Generic, CrossWorkgroup, Workgroup or Function")); +} + +TEST_F(ValidateExtInst, VLoadHalfPDataTypeInt) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %u32_ptr_uniform_constant " + "%u32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst vload_half %u32_1 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vload_half: expected operand P data type " + "to be 16-bit float scalar")); +} + +TEST_F(ValidateExtInst, VLoadHalfPDataTypeFloat32) { + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst vload_half %u32_1 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vload_half: expected operand P data type " + "to be 16-bit float scalar")); +} + +TEST_F(ValidateExtInst, VStoreNSuccessFloatPhysical32) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %f32_ptr_generic %ptr_w\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren 
%f32vec2_01 %u32_1 %ptr_g\n"; + ss << "%val2 = OpExtInst %void %extinst vstoren %f32vec4_0123 %u32_1 " + "%ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VStoreNSuccessFloatPhysical64) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %f32_ptr_generic %ptr_w\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren %f32vec2_01 %u64_1 %ptr_g\n"; + ss << "%val2 = OpExtInst %void %extinst vstoren %f32vec4_0123 %u64_1 " + "%ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VStoreNSuccessIntPhysical32) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %u32_ptr_workgroup %u32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %u32_ptr_generic %ptr_w\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren %u32vec2_01 %u32_1 %ptr_g\n"; + ss << "%val2 = OpExtInst %void %extinst vstoren %u32vec4_0123 %u32_1 " + "%ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VStoreNSuccessIntPhysical64) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %u32_ptr_workgroup %u32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %u32_ptr_generic %ptr_w\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren %u32vec2_01 %u64_1 %ptr_g\n"; + ss << "%val2 = OpExtInst %void %extinst vstoren %u32vec4_0123 %u64_1 " + "%ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, VStoreNResultTypeNotVoid) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %f32_ptr_generic 
%ptr_w\n"; + ss << "%val1 = OpExtInst %f32 %extinst vstoren %f32vec2_01 %u32_1 %ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vstoren: expected Result Type to be void")); +} + +TEST_F(ValidateExtInst, VStoreNDataWrongType) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %f32_ptr_generic %ptr_w\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren %f32_1 %u32_1 %ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std vstoren: expected Data to be an int or float vector")); +} + +TEST_F(ValidateExtInst, VStoreNAddressingModelLogical) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %f32_ptr_generic %ptr_w\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren %f32vec2_01 %u32_1 %ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Logical")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vstoren can only be used with physical " + "addressing models")); +} + +TEST_F(ValidateExtInst, VStoreNOffsetNotSizeT) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %f32_ptr_generic %ptr_w\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren %f32vec2_01 %u32_1 %ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str(), "", "Physical64")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std vstoren: expected operand Offset to be of type size_t " + "(64-bit 
integer for the addressing model used in the module)")); +} + +TEST_F(ValidateExtInst, VStoreNPNotPointer) { + std::ostringstream ss; + ss << "%val1 = OpExtInst %void %extinst vstoren %f32vec2_01 %u32_1 " + "%f32_ptr_generic\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 124[%_ptr_Generic_float] cannot be a type")); +} + +TEST_F(ValidateExtInst, VStoreNPNotGeneric) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren %f32vec2_01 %u32_1 %ptr_w\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vstoren: expected operand P storage class " + "to be Generic")); +} + +TEST_F(ValidateExtInst, VStorePWrongDataType) { + std::ostringstream ss; + ss << "%ptr_w = OpAccessChain %f32_ptr_workgroup %f32vec8_workgroup %u32_1\n"; + ss << "%ptr_g = OpPtrCastToGeneric %f32_ptr_generic %ptr_w\n"; + ss << "%val1 = OpExtInst %void %extinst vstoren %u32vec2_01 %u32_1 %ptr_g\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std vstoren: expected operand P data type to " + "be equal to the type of operand Data components")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffleSuccess) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle %f32vec4_0123 %u32vec2_01 +%val2 = OpExtInst %f32vec4 %extinst shuffle %f32vec4_0123 %u32vec4_0123 +%val3 = OpExtInst %u32vec2 %extinst shuffle %u32vec4_0123 %u32vec2_01 +%val4 = OpExtInst %u32vec4 %extinst shuffle %u32vec4_0123 %u32vec4_0123 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + 
+TEST_F(ValidateExtInst, OpenCLStdShuffleWrongResultType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst shuffle %f32vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std shuffle: " + "expected Result Type to be an int or float vector type")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffleResultTypeInvalidNumComponents) { + const std::string body = R"( +%val1 = OpExtInst %f32vec3 %extinst shuffle %f32vec4_0123 %u32vec3_012 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std shuffle: " + "expected Result Type to have 2, 4, 8 or 16 components")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffleXWrongType) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle %f32_0 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle: " + "expected operand X to be an int or float vector")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffleXInvalidNumComponents) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle %f32vec3_012 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle: " + "expected operand X to have 2, 4, 8 or 16 components")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffleXInvalidComponentType) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle %f64vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + 
HasSubstr( + "OpenCL.std shuffle: " + "expected operand X and Result Type to have equal component types")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffleShuffleMaskNotIntVector) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle %f32vec4_0123 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle: " + "expected operand Shuffle Mask to be an int vector")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffleShuffleMaskInvalidNumComponents) { + const std::string body = R"( +%val1 = OpExtInst %f32vec4 %extinst shuffle %f32vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle: " + "expected operand Shuffle Mask to have the same number " + "of components as Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffleShuffleMaskInvalidBitWidth) { + const std::string body = R"( +%val1 = OpExtInst %f64vec2 %extinst shuffle %f64vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle: " + "expected operand Shuffle Mask components to have the " + "same bit width as Result Type components")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2Success) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle2 %f32vec4_0123 %f32vec4_0123 %u32vec2_01 +%val2 = OpExtInst %f32vec4 %extinst shuffle2 %f32vec4_0123 %f32vec4_0123 %u32vec4_0123 +%val3 = OpExtInst %u32vec2 %extinst shuffle2 %u32vec4_0123 %u32vec4_0123 %u32vec2_01 +%val4 = OpExtInst %u32vec4 %extinst shuffle2 %u32vec4_0123 %u32vec4_0123 %u32vec4_0123 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, 
ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2WrongResultType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst shuffle2 %f32vec4_0123 %f32vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std shuffle2: " + "expected Result Type to be an int or float vector type")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2ResultTypeInvalidNumComponents) { + const std::string body = R"( +%val1 = OpExtInst %f32vec3 %extinst shuffle2 %f32vec4_0123 %f32vec4_0123 %u32vec3_012 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std shuffle2: " + "expected Result Type to have 2, 4, 8 or 16 components")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2XWrongType) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle2 %f32_0 %f32_0 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle2: " + "expected operand X to be an int or float vector")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2YTypeDifferentFromX) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle2 %f32vec2_01 %f32vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle2: " + "expected operands X and Y to be of the same type")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2XInvalidNumComponents) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle2 %f32vec3_012 %f32vec3_012 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle2: " + "expected operand X to have 2, 4, 8 or 16 components")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2XInvalidComponentType) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle2 %f64vec4_0123 %f64vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std shuffle2: " + "expected operand X and Result Type to have equal component types")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2ShuffleMaskNotIntVector) { + const std::string body = R"( +%val1 = OpExtInst %f32vec2 %extinst shuffle2 %f32vec4_0123 %f32vec4_0123 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle2: " + "expected operand Shuffle Mask to be an int vector")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2ShuffleMaskInvalidNumComponents) { + const std::string body = R"( +%val1 = OpExtInst %f32vec4 %extinst shuffle2 %f32vec4_0123 %f32vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle2: " + "expected operand Shuffle Mask to have the same number " + "of components as Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdShuffle2ShuffleMaskInvalidBitWidth) { + const std::string body = R"( +%val1 = OpExtInst %f64vec2 %extinst shuffle2 %f64vec4_0123 %f64vec4_0123 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std shuffle2: " + "expected operand Shuffle Mask components 
to have the " + "same bit width as Result Type components")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrintfSuccess) { + const std::string body = R"( +%format = OpAccessChain %u8_ptr_uniform_constant %u8arr_uniform_constant %u32_0 +%val1 = OpExtInst %u32 %extinst printf %format %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdPrintfBoolResultType) { + const std::string body = R"( +%format = OpAccessChain %u8_ptr_uniform_constant %u8arr_uniform_constant %u32_0 +%val1 = OpExtInst %bool %extinst printf %format %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std printf: expected Result Type to be a 32-bit int type")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrintfU64ResultType) { + const std::string body = R"( +%format = OpAccessChain %u8_ptr_uniform_constant %u8arr_uniform_constant %u32_0 +%val1 = OpExtInst %u64 %extinst printf %format %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std printf: expected Result Type to be a 32-bit int type")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrintfFormatNotPointer) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst printf %u8_ptr_uniform_constant %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 134[%_ptr_UniformConstant_uchar] cannot be a " + "type")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrintfFormatNotUniformConstStorageClass) { + const std::string body = R"( +%format_const = OpAccessChain %u8_ptr_uniform_constant %u8arr_uniform_constant %u32_0 +%format = OpBitcast %u8_ptr_generic 
%format_const +%val1 = OpExtInst %u32 %extinst printf %format %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std printf: expected Format storage class to " + "be UniformConstant")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrintfFormatNotU8Pointer) { + const std::string body = R"( +%format = OpAccessChain %u32_ptr_uniform_constant %u32vec8_uniform_constant %u32_0 +%val1 = OpExtInst %u32 %extinst printf %format %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std printf: expected Format data type to be 8-bit int")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchU32Success) { + const std::string body = R"( +%ptr = OpAccessChain %u32_ptr_cross_workgroup %u32arr_cross_workgroup %u32_0 +%val1 = OpExtInst %void %extinst prefetch %ptr %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchU32Physical64Success) { + const std::string body = R"( +%ptr = OpAccessChain %u32_ptr_cross_workgroup %u32arr_cross_workgroup %u32_0 +%val1 = OpExtInst %void %extinst prefetch %ptr %u64_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body, "", "Physical64")); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchF32Success) { + const std::string body = R"( +%ptr = OpAccessChain %f32_ptr_cross_workgroup %f32arr_cross_workgroup %u32_0 +%val1 = OpExtInst %void %extinst prefetch %ptr %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchF32Vec2Success) { + const std::string body = R"( +%ptr = OpAccessChain %f32vec2_ptr_cross_workgroup 
%f32vec2arr_cross_workgroup %u32_0 +%val1 = OpExtInst %void %extinst prefetch %ptr %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchResultTypeNotVoid) { + const std::string body = R"( +%ptr = OpAccessChain %u32_ptr_cross_workgroup %u32arr_cross_workgroup %u32_0 +%val1 = OpExtInst %u32 %extinst prefetch %ptr %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std prefetch: expected Result Type to be void")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchPtrNotPointer) { + const std::string body = R"( +%val1 = OpExtInst %void %extinst prefetch %u32_ptr_cross_workgroup %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 99[%_ptr_CrossWorkgroup_uint] cannot be a " + "type")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchPtrNotCrossWorkgroup) { + const std::string body = R"( +%ptr = OpAccessChain %u8_ptr_uniform_constant %u8arr_uniform_constant %u32_0 +%val1 = OpExtInst %void %extinst prefetch %ptr %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std prefetch: expected operand Ptr storage " + "class to be CrossWorkgroup")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchInvalidDataType) { + const std::string body = R"( +%ptr = OpAccessChain %struct_ptr_cross_workgroup %struct_arr_cross_workgroup %u32_0 +%val1 = OpExtInst %void %extinst prefetch %ptr %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std prefetch: expected Ptr data type 
to be int " + "or float scalar or vector")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchAddressingModelLogical) { + const std::string body = R"( +%ptr = OpAccessChain %u32_ptr_cross_workgroup %u32arr_cross_workgroup %u32_0 +%val1 = OpExtInst %void %extinst prefetch %ptr %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body, "", "Logical")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std prefetch can only be used with physical " + "addressing models")); +} + +TEST_F(ValidateExtInst, OpenCLStdPrefetchNumElementsNotSizeT) { + const std::string body = R"( +%ptr = OpAccessChain %f32_ptr_cross_workgroup %f32arr_cross_workgroup %u32_0 +%val1 = OpExtInst %void %extinst prefetch %ptr %u32_256 +)"; + + CompileSuccessfully(GenerateKernelCode(body, "", "Physical64")); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std prefetch: expected operand Num Elements to " + "be of type size_t (64-bit integer for the addressing " + "model used in the module)")); +} + +TEST_P(ValidateOpenCLStdFractLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_f32 = OpVariable %f32_ptr_function Function\n"; + ss << "%var_f32vec2 = OpVariable %f32vec2_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %var_f32\n"; + ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01 %var_f32vec2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdFractLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_f32 = OpVariable %f32_ptr_function Function\n"; + ss << "%val1 = OpExtInst %u32 %extinst " << ext_inst_name + << " %f32_0 %var_f32\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be a float scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdFractLike, XWrongType) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_f32 = OpVariable %f32_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f64_0 %var_f32\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected type of operand X to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdFractLike, NotPointer) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_f32 = OpVariable %f32_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %f32_1\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected the last operand to be a pointer")); +} + +TEST_P(ValidateOpenCLStdFractLike, PointerInvalidStorageClass) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name << " %f32_0 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected storage class of the pointer to be " + "Generic, CrossWorkgroup, Workgroup or Function")); +} + +TEST_P(ValidateOpenCLStdFractLike, PointerWrongDataType) { + const std::string ext_inst_name = 
GetParam(); + std::ostringstream ss; + ss << "%var_u32 = OpVariable %u32_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %var_u32\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std " + ext_inst_name + + ": expected data type of the pointer to be equal to Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllFractLike, ValidateOpenCLStdFractLike, + ::testing::ValuesIn(std::vector{ + "fract", + "modf", + "sincos", + }), ); + +TEST_F(ValidateExtInst, OpenCLStdRemquoSuccess) { + const std::string body = R"( +%var_f32 = OpVariable %f32_ptr_function Function +%var_f32vec2 = OpVariable %f32vec2_ptr_function Function +%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %var_f32 +%val2 = OpExtInst %f32vec2 %extinst remquo %f32vec2_01 %f32vec2_12 %var_f32vec2 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdRemquoIntResultType) { + const std::string body = R"( +%var_f32 = OpVariable %f32_ptr_function Function +%val1 = OpExtInst %u32 %extinst remquo %f32_3 %f32_2 %var_f32 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std remquo: " + "expected Result Type to be a float scalar or vector type")); +} + +TEST_F(ValidateExtInst, OpenCLStdRemquoXWrongType) { + const std::string body = R"( +%var_f32 = OpVariable %f32_ptr_function Function +%val1 = OpExtInst %f32 %extinst remquo %u32_3 %f32_2 %var_f32 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std remquo: " + "expected type of operand X to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, 
OpenCLStdRemquoYWrongType) { + const std::string body = R"( +%var_f32 = OpVariable %f32_ptr_function Function +%val1 = OpExtInst %f32 %extinst remquo %f32_3 %u32_2 %var_f32 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std remquo: " + "expected type of operand Y to be equal to Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdRemquoNotPointer) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std remquo: " + "expected the last operand to be a pointer")); +} + +TEST_F(ValidateExtInst, OpenCLStdRemquoPointerWrongStorageClass) { + const std::string body = R"( +%ptr = OpAccessChain %f32_ptr_uniform_constant %f32vec8_uniform_constant %u32_1 +%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %ptr +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std remquo: " + "expected storage class of the pointer to be Generic, " + "CrossWorkgroup, Workgroup or Function")); +} + +TEST_F(ValidateExtInst, OpenCLStdRemquoPointerWrongDataType) { + const std::string body = R"( +%var_u32 = OpVariable %u32_ptr_function Function +%val1 = OpExtInst %f32 %extinst remquo %f32_3 %f32_2 %var_u32 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std remquo: " + "expected data type of the pointer to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdFrexpLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_u32 = OpVariable %u32_ptr_function 
Function\n"; + ss << "%var_u32vec2 = OpVariable %u32vec2_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %var_u32\n"; + ss << "%val2 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01 %var_u32vec2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdFrexpLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_u32 = OpVariable %u32_ptr_function Function\n"; + ss << "%val1 = OpExtInst %u32 %extinst " << ext_inst_name + << " %f32_0 %var_u32\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be a float scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdFrexpLike, XWrongType) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_u32 = OpVariable %u32_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f64_0 %var_u32\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected type of operand X to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdFrexpLike, NotPointer) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %u32_1\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected the last operand to be a pointer")); +} + +TEST_P(ValidateOpenCLStdFrexpLike, PointerInvalidStorageClass) { + 
const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%ptr = OpAccessChain %f32_ptr_uniform_constant " + "%f32vec8_uniform_constant %u32_1\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name << " %f32_0 %ptr\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected storage class of the pointer to be " + "Generic, CrossWorkgroup, Workgroup or Function")); +} + +TEST_P(ValidateOpenCLStdFrexpLike, PointerDataTypeFloat) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_f32 = OpVariable %f32_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %var_f32\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected data type of the pointer to be a 32-bit " + "int scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdFrexpLike, PointerDataTypeU64) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_u64 = OpVariable %u64_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %var_u64\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected data type of the pointer to be a 32-bit " + "int scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdFrexpLike, PointerDataTypeDiffSize) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%var_u32 = OpVariable %u32_ptr_function Function\n"; + ss << "%val1 = OpExtInst %f32vec2 %extinst " << ext_inst_name + << " %f32vec2_01 %var_u32\n"; + + 
CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected data type of the pointer to have the same " + "number of components as Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllFrexpLike, ValidateOpenCLStdFrexpLike, + ::testing::ValuesIn(std::vector{ + "frexp", + "lgamma_r", + }), ); + +TEST_F(ValidateExtInst, OpenCLStdIlogbSuccess) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst ilogb %f32_3 +%val2 = OpExtInst %u32vec2 %extinst ilogb %f32vec2_12 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdIlogbFloatResultType) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst ilogb %f32_3 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std ilogb: " + "expected Result Type to be a 32-bit int scalar or vector type")); +} + +TEST_F(ValidateExtInst, OpenCLStdIlogbIntX) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst ilogb %u32_3 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std ilogb: " + "expected operand X to be a float scalar or vector")); +} + +TEST_F(ValidateExtInst, OpenCLStdIlogbDiffSize) { + const std::string body = R"( +%val2 = OpExtInst %u32vec2 %extinst ilogb %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std ilogb: " + "expected operand X to have the same number of " + "components as Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdNanSuccess) { + const std::string body = R"( +%val1 
= OpExtInst %f32 %extinst nan %u32_3 +%val2 = OpExtInst %f32vec2 %extinst nan %u32vec2_12 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtInst, OpenCLStdNanIntResultType) { + const std::string body = R"( +%val1 = OpExtInst %u32 %extinst nan %u32_3 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std nan: " + "expected Result Type to be a float scalar or vector type")); +} + +TEST_F(ValidateExtInst, OpenCLStdNanFloatNancode) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst nan %f32_3 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std nan: " + "expected Nancode to be an int scalar or vector type")); +} + +TEST_F(ValidateExtInst, OpenCLStdNanFloatDiffSize) { + const std::string body = R"( +%val1 = OpExtInst %f32 %extinst nan %u32vec2_12 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std nan: " + "expected Nancode to have the same number of " + "components as Result Type")); +} + +TEST_F(ValidateExtInst, OpenCLStdNanFloatDiffBitWidth) { + const std::string body = R"( +%val1 = OpExtInst %f64 %extinst nan %u32_2 +)"; + + CompileSuccessfully(GenerateKernelCode(body)); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std nan: " + "expected Nancode to have the same bit width as Result Type")); +} + +TEST_P(ValidateOpenCLStdLdexpLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %u32_1\n"; + ss << "%val2 = OpExtInst 
%f32vec2 %extinst " << ext_inst_name + << " %f32vec2_12 %u32vec2_12\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdLdexpLike, IntResultType) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u32 %extinst " << ext_inst_name + << " %f32_0 %u32_1\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be a float scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdLdexpLike, XWrongType) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %u32_0 %u32_1\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected type of operand X to be equal to Result Type")); +} + +TEST_P(ValidateOpenCLStdLdexpLike, ExponentNotInt) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %f32_1\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected the exponent to be a 32-bit int scalar or vector")); +} + +TEST_P(ValidateOpenCLStdLdexpLike, ExponentNotInt32) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %u64_1\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + 
HasSubstr("OpenCL.std " + ext_inst_name + + ": expected the exponent to be a 32-bit int scalar or vector")); +} + +TEST_P(ValidateOpenCLStdLdexpLike, ExponentWrongSize) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f32 %extinst " << ext_inst_name + << " %f32_0 %u32vec2_01\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected the exponent to have the same number of " + "components as Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllLdexpLike, ValidateOpenCLStdLdexpLike, + ::testing::ValuesIn(std::vector{ + "ldexp", + "pown", + "rootn", + }), ); + +TEST_P(ValidateOpenCLStdUpsampleLike, Success) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u16 %extinst " << ext_inst_name << " %u8_1 %u8_2\n"; + ss << "%val2 = OpExtInst %u32 %extinst " << ext_inst_name + << " %u16_1 %u16_2\n"; + ss << "%val3 = OpExtInst %u64 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + ss << "%val4 = OpExtInst %u64vec2 %extinst " << ext_inst_name + << " %u32vec2_01 %u32vec2_01\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateOpenCLStdUpsampleLike, FloatResultType) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %f64 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Result Type to be an int scalar or vector type")); +} + +TEST_P(ValidateOpenCLStdUpsampleLike, InvalidResultTypeBitWidth) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = 
OpExtInst %u8 %extinst " << ext_inst_name << " %u8_1 %u8_2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpenCL.std " + ext_inst_name + + ": expected bit width of Result Type components to be 16, 32 or 64")); +} + +TEST_P(ValidateOpenCLStdUpsampleLike, LoHiDiffType) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u64 %extinst " << ext_inst_name + << " %u32_1 %u16_2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Hi and Lo operands to have the same type")); +} + +TEST_P(ValidateOpenCLStdUpsampleLike, DiffNumberOfComponents) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u64vec2 %extinst " << ext_inst_name + << " %u32_1 %u32_2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected Hi and Lo operands to have the same number " + "of components as Result Type")); +} + +TEST_P(ValidateOpenCLStdUpsampleLike, HiLoWrongBitWidth) { + const std::string ext_inst_name = GetParam(); + std::ostringstream ss; + ss << "%val1 = OpExtInst %u64 %extinst " << ext_inst_name + << " %u16_1 %u16_2\n"; + + CompileSuccessfully(GenerateKernelCode(ss.str())); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpenCL.std " + ext_inst_name + + ": expected bit width of components of Hi and Lo operands to " + "be half of the bit width of components of Result Type")); +} + +INSTANTIATE_TEST_CASE_P(AllUpsampleLike, ValidateOpenCLStdUpsampleLike, + ::testing::ValuesIn(std::vector{ 
+ "u_upsample", + "s_upsample", + }), ); + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_extensions_test.cpp b/test/val/val_extensions_test.cpp new file mode 100644 index 000000000..3d6466d5c --- /dev/null +++ b/test/val/val_extensions_test.cpp @@ -0,0 +1,345 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for OpExtension validator rules. + +#include +#include + +#include "gmock/gmock.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/spirv_target_env.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; +using ::testing::Values; +using ::testing::ValuesIn; + +using ValidateKnownExtensions = spvtest::ValidateBase; +using ValidateUnknownExtensions = spvtest::ValidateBase; +using ValidateExtensionCapabilities = spvtest::ValidateBase; + +// Returns expected error string if |extension| is not recognized. +std::string GetErrorString(const std::string& extension) { + return "Found unrecognized extension " + extension; +} + +INSTANTIATE_TEST_CASE_P( + ExpectSuccess, ValidateKnownExtensions, + Values( + // Match the order as published on the SPIR-V Registry. 
+ "SPV_AMD_shader_explicit_vertex_parameter", + "SPV_AMD_shader_trinary_minmax", "SPV_AMD_gcn_shader", + "SPV_KHR_shader_ballot", "SPV_AMD_shader_ballot", + "SPV_AMD_gpu_shader_half_float", "SPV_KHR_shader_draw_parameters", + "SPV_KHR_subgroup_vote", "SPV_KHR_16bit_storage", + "SPV_KHR_device_group", "SPV_KHR_multiview", + "SPV_NVX_multiview_per_view_attributes", "SPV_NV_viewport_array2", + "SPV_NV_stereo_view_rendering", "SPV_NV_sample_mask_override_coverage", + "SPV_NV_geometry_shader_passthrough", "SPV_AMD_texture_gather_bias_lod", + "SPV_KHR_storage_buffer_storage_class", "SPV_KHR_variable_pointers", + "SPV_AMD_gpu_shader_int16", "SPV_KHR_post_depth_coverage", + "SPV_KHR_shader_atomic_counter_ops", "SPV_EXT_shader_stencil_export", + "SPV_EXT_shader_viewport_index_layer", + "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_fragment_mask", + "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1", + "SPV_NV_shader_subgroup_partitioned", "SPV_EXT_descriptor_indexing")); + +INSTANTIATE_TEST_CASE_P(FailSilently, ValidateUnknownExtensions, + Values("ERROR_unknown_extension", "SPV_KHR_", + "SPV_KHR_shader_ballot_ERROR")); + +TEST_P(ValidateKnownExtensions, ExpectSuccess) { + const std::string extension = GetParam(); + const std::string str = + "OpCapability Shader\nOpCapability Linkage\nOpExtension \"" + extension + + "\"\nOpMemoryModel Logical GLSL450"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), Not(HasSubstr(GetErrorString(extension)))); +} + +TEST_P(ValidateUnknownExtensions, FailSilently) { + const std::string extension = GetParam(); + const std::string str = + "OpCapability Shader\nOpCapability Linkage\nOpExtension \"" + extension + + "\"\nOpMemoryModel Logical GLSL450"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(GetErrorString(extension))); +} + +TEST_F(ValidateUnknownExtensions, 
HitMaxNumOfWarnings) { + const std::string str = + std::string("OpCapability Shader\n") + "OpCapability Linkage\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpExtension \"bad_ext\"\n" + "OpExtension \"bad_ext\"\n" + + "OpMemoryModel Logical GLSL450"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Other warnings have been suppressed.")); +} + +TEST_F(ValidateExtensionCapabilities, DeclCapabilitySuccess) { + const std::string str = + "OpCapability Shader\nOpCapability Linkage\nOpCapability DeviceGroup\n" + "OpExtension \"SPV_KHR_device_group\"" + "\nOpMemoryModel Logical GLSL450"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateExtensionCapabilities, DeclCapabilityFailure) { + const std::string str = + "OpCapability Shader\nOpCapability Linkage\nOpCapability DeviceGroup\n" + "\nOpMemoryModel Logical GLSL450"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_ERROR_MISSING_EXTENSION, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("1st operand of Capability")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("requires one of these extensions")); + EXPECT_THAT(getDiagnosticString(), 
HasSubstr("SPV_KHR_device_group")); +} + +using ValidateAMDShaderBallotCapabilities = spvtest::ValidateBase; + +// Returns a vector of strings for the prefix of a SPIR-V assembly shader +// that can use the group instructions introduced by SPV_AMD_shader_ballot. +std::vector ShaderPartsForAMDShaderBallot() { + return std::vector{R"( + OpCapability Shader + OpCapability Linkage + )", + R"( + OpMemoryModel Logical GLSL450 + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %int = OpTypeInt 32 1 + %scope = OpConstant %uint 3 + %uint_const = OpConstant %uint 42 + %int_const = OpConstant %uint 45 + %float_const = OpConstant %float 3.5 + + %void = OpTypeVoid + %fn_ty = OpTypeFunction %void + %fn = OpFunction %void None %fn_ty + %entry = OpLabel + )"}; +} + +// Returns a list of SPIR-V assembly strings, where each uses only types +// and IDs that can fit with a shader made from parts from the result +// of ShaderPartsForAMDShaderBallot. +std::vector AMDShaderBallotGroupInstructions() { + return std::vector{ + "%iadd_reduce = OpGroupIAddNonUniformAMD %uint %scope Reduce %uint_const", + "%iadd_iscan = OpGroupIAddNonUniformAMD %uint %scope InclusiveScan " + "%uint_const", + "%iadd_escan = OpGroupIAddNonUniformAMD %uint %scope ExclusiveScan " + "%uint_const", + + "%fadd_reduce = OpGroupFAddNonUniformAMD %float %scope Reduce " + "%float_const", + "%fadd_iscan = OpGroupFAddNonUniformAMD %float %scope InclusiveScan " + "%float_const", + "%fadd_escan = OpGroupFAddNonUniformAMD %float %scope ExclusiveScan " + "%float_const", + + "%fmin_reduce = OpGroupFMinNonUniformAMD %float %scope Reduce " + "%float_const", + "%fmin_iscan = OpGroupFMinNonUniformAMD %float %scope InclusiveScan " + "%float_const", + "%fmin_escan = OpGroupFMinNonUniformAMD %float %scope ExclusiveScan " + "%float_const", + + "%umin_reduce = OpGroupUMinNonUniformAMD %uint %scope Reduce %uint_const", + "%umin_iscan = OpGroupUMinNonUniformAMD %uint %scope InclusiveScan " + "%uint_const", + "%umin_escan = 
OpGroupUMinNonUniformAMD %uint %scope ExclusiveScan " + "%uint_const", + + "%smin_reduce = OpGroupUMinNonUniformAMD %int %scope Reduce %int_const", + "%smin_iscan = OpGroupUMinNonUniformAMD %int %scope InclusiveScan " + "%int_const", + "%smin_escan = OpGroupUMinNonUniformAMD %int %scope ExclusiveScan " + "%int_const", + + "%fmax_reduce = OpGroupFMaxNonUniformAMD %float %scope Reduce " + "%float_const", + "%fmax_iscan = OpGroupFMaxNonUniformAMD %float %scope InclusiveScan " + "%float_const", + "%fmax_escan = OpGroupFMaxNonUniformAMD %float %scope ExclusiveScan " + "%float_const", + + "%umax_reduce = OpGroupUMaxNonUniformAMD %uint %scope Reduce %uint_const", + "%umax_iscan = OpGroupUMaxNonUniformAMD %uint %scope InclusiveScan " + "%uint_const", + "%umax_escan = OpGroupUMaxNonUniformAMD %uint %scope ExclusiveScan " + "%uint_const", + + "%smax_reduce = OpGroupUMaxNonUniformAMD %int %scope Reduce %int_const", + "%smax_iscan = OpGroupUMaxNonUniformAMD %int %scope InclusiveScan " + "%int_const", + "%smax_escan = OpGroupUMaxNonUniformAMD %int %scope ExclusiveScan " + "%int_const"}; +} + +TEST_P(ValidateAMDShaderBallotCapabilities, ExpectSuccess) { + // Succeed because the module specifies the SPV_AMD_shader_ballot extension. + auto parts = ShaderPartsForAMDShaderBallot(); + + const std::string assembly = + parts[0] + "OpExtension \"SPV_AMD_shader_ballot\"\n" + parts[1] + + GetParam() + "\nOpReturn OpFunctionEnd"; + + CompileSuccessfully(assembly.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +INSTANTIATE_TEST_CASE_P(ExpectSuccess, ValidateAMDShaderBallotCapabilities, + ValuesIn(AMDShaderBallotGroupInstructions())); + +TEST_P(ValidateAMDShaderBallotCapabilities, ExpectFailure) { + // Fail because the module does not specify the SPV_AMD_shader_ballot + // extension. 
+ auto parts = ShaderPartsForAMDShaderBallot(); + + const std::string assembly = + parts[0] + parts[1] + GetParam() + "\nOpReturn OpFunctionEnd"; + + CompileSuccessfully(assembly.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions()); + + // Make sure we get an appropriate error message. + // Find just the opcode name, skipping over the "Op" part. + auto prefix_with_opcode = GetParam().substr(GetParam().find("Group")); + auto opcode = prefix_with_opcode.substr(0, prefix_with_opcode.find(' ')); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr(std::string("Opcode " + opcode + + " requires one of these capabilities: Groups"))); +} + +INSTANTIATE_TEST_CASE_P(ExpectFailure, ValidateAMDShaderBallotCapabilities, + ValuesIn(AMDShaderBallotGroupInstructions())); + +struct ExtIntoCoreCase { + const char* ext; + const char* cap; + const char* builtin; + spv_target_env env; + bool success; +}; + +using ValidateExtIntoCore = spvtest::ValidateBase; + +// Make sure that we don't panic about missing extensions when using +// functionality that was introduced in extensions but became core SPIR-V later. 
+ +TEST_P(ValidateExtIntoCore, DoNotAskForExtensionInLaterVersion) { + const std::string code = std::string(R"( + OpCapability Shader + OpCapability )") + + GetParam().cap + R"( + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %builtin + OpDecorate %builtin BuiltIn )" + GetParam().builtin + R"( + %void = OpTypeVoid + %3 = OpTypeFunction %void + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %builtin = OpVariable %_ptr_Input_int Input + %main = OpFunction %void None %3 + %5 = OpLabel + %18 = OpLoad %int %builtin + OpReturn + OpFunctionEnd)"; + + CompileSuccessfully(code.c_str(), GetParam().env); + if (GetParam().success) { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(GetParam().env)); + } else { + ASSERT_NE(SPV_SUCCESS, ValidateInstructions(GetParam().env)); + const std::string message = getDiagnosticString(); + if (spvIsVulkanEnv(GetParam().env)) { + EXPECT_THAT(message, HasSubstr(std::string(GetParam().cap) + + " is not allowed by Vulkan")); + EXPECT_THAT(message, HasSubstr(std::string("or requires extension"))); + } else { + EXPECT_THAT(message, + HasSubstr(std::string("requires one of these extensions: ") + + GetParam().ext)); + } + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P( + KHR_extensions, ValidateExtIntoCore, + ValuesIn(std::vector{ + // Cases are {ext, cap, builtin, env, success}. SPV_KHR_shader_draw_parameters became core in SPIR-V 1.3. + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "BaseVertex", SPV_ENV_UNIVERSAL_1_3, true}, + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "BaseVertex", SPV_ENV_UNIVERSAL_1_2, false}, + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "BaseVertex", SPV_ENV_UNIVERSAL_1_1, false}, + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "BaseVertex", SPV_ENV_UNIVERSAL_1_0, false}, + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "BaseVertex", SPV_ENV_VULKAN_1_1, true}, + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "BaseVertex", SPV_ENV_VULKAN_1_0, false}, + + 
{"SPV_KHR_shader_draw_parameters", "DrawParameters", "BaseInstance", SPV_ENV_UNIVERSAL_1_3, true}, + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "BaseInstance", SPV_ENV_VULKAN_1_0, false}, + + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "DrawIndex", SPV_ENV_UNIVERSAL_1_3, true}, + {"SPV_KHR_shader_draw_parameters", "DrawParameters", "DrawIndex", SPV_ENV_UNIVERSAL_1_1, false}, + + // SPV_KHR_multiview became core in SPIR-V 1.3; only the 1.3 and Vulkan 1.1 cases expect success. + {"SPV_KHR_multiview", "MultiView", "ViewIndex", SPV_ENV_UNIVERSAL_1_3, true}, + {"SPV_KHR_multiview", "MultiView", "ViewIndex", SPV_ENV_UNIVERSAL_1_2, false}, + {"SPV_KHR_multiview", "MultiView", "ViewIndex", SPV_ENV_UNIVERSAL_1_1, false}, + {"SPV_KHR_multiview", "MultiView", "ViewIndex", SPV_ENV_UNIVERSAL_1_0, false}, + {"SPV_KHR_multiview", "MultiView", "ViewIndex", SPV_ENV_VULKAN_1_1, true}, + {"SPV_KHR_multiview", "MultiView", "ViewIndex", SPV_ENV_VULKAN_1_0, false}, + + // SPV_KHR_device_group became core in SPIR-V 1.3; only the 1.3 and Vulkan 1.1 cases expect success. + {"SPV_KHR_device_group", "DeviceGroup", "DeviceIndex", SPV_ENV_UNIVERSAL_1_3, true}, + {"SPV_KHR_device_group", "DeviceGroup", "DeviceIndex", SPV_ENV_UNIVERSAL_1_2, false}, + {"SPV_KHR_device_group", "DeviceGroup", "DeviceIndex", SPV_ENV_UNIVERSAL_1_1, false}, + {"SPV_KHR_device_group", "DeviceGroup", "DeviceIndex", SPV_ENV_UNIVERSAL_1_0, false}, + {"SPV_KHR_device_group", "DeviceGroup", "DeviceIndex", SPV_ENV_VULKAN_1_1, true}, + {"SPV_KHR_device_group", "DeviceGroup", "DeviceIndex", SPV_ENV_VULKAN_1_0, false}, + })); +// clang-format on + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_fixtures.h b/test/val/val_fixtures.h new file mode 100644 index 000000000..69bc9d82f --- /dev/null +++ b/test/val/val_fixtures.h @@ -0,0 +1,159 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Common validation fixtures for unit tests + +#ifndef TEST_VAL_VAL_FIXTURES_H_ +#define TEST_VAL_VAL_FIXTURES_H_ + +#include +#include + +#include "source/val/validation_state.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" + +namespace spvtest { + +template +class ValidateBase : public ::testing::Test, + public ::testing::WithParamInterface { + public: + ValidateBase(); + + virtual void TearDown(); + + // Returns the a spv_const_binary struct + spv_const_binary get_const_binary(); + + // Checks that 'code' is valid SPIR-V text representation and stores the + // binary version for further method calls. + void CompileSuccessfully(std::string code, + spv_target_env env = SPV_ENV_UNIVERSAL_1_0); + + // Overwrites the word at index 'index' with the given word. + // For testing purposes, it is often useful to be able to manipulate the + // assembled binary before running the validator on it. + // This function overwrites the word at the given index with a new word. + void OverwriteAssembledBinary(uint32_t index, uint32_t word); + + // Performs validation on the SPIR-V code. + spv_result_t ValidateInstructions(spv_target_env env = SPV_ENV_UNIVERSAL_1_0); + + // Performs validation. Returns the status and stores validation state into + // the vstate_ member. + spv_result_t ValidateAndRetrieveValidationState( + spv_target_env env = SPV_ENV_UNIVERSAL_1_0); + + // Destroys the stored binary. + void DestroyBinary() { + spvBinaryDestroy(binary_); + binary_ = nullptr; + } + + // Destroys the stored diagnostic. 
+ void DestroyDiagnostic() { + spvDiagnosticDestroy(diagnostic_); + diagnostic_ = nullptr; + } + + std::string getDiagnosticString(); + spv_position_t getErrorPosition(); + spv_validator_options getValidatorOptions(); + + spv_binary binary_; + spv_diagnostic diagnostic_; + spv_validator_options options_; + std::unique_ptr vstate_; +}; + +template +ValidateBase::ValidateBase() : binary_(nullptr), diagnostic_(nullptr) { + // Initialize to default command line options. Different tests can then + // specialize specific options as necessary. + options_ = spvValidatorOptionsCreate(); +} + +template +spv_const_binary ValidateBase::get_const_binary() { + return spv_const_binary(binary_); +} + +template +void ValidateBase::TearDown() { + if (diagnostic_) { + spvDiagnosticPrint(diagnostic_); + } + DestroyBinary(); + DestroyDiagnostic(); + spvValidatorOptionsDestroy(options_); +} + +template +void ValidateBase::CompileSuccessfully(std::string code, + spv_target_env env) { + DestroyBinary(); + spv_diagnostic diagnostic = nullptr; + ASSERT_EQ(SPV_SUCCESS, + spvTextToBinary(ScopedContext(env).context, code.c_str(), + code.size(), &binary_, &diagnostic)) + << "ERROR: " << diagnostic->error + << "\nSPIR-V could not be compiled into binary:\n" + << code; + spvDiagnosticDestroy(diagnostic); +} + +template +void ValidateBase::OverwriteAssembledBinary(uint32_t index, uint32_t word) { + ASSERT_TRUE(index < binary_->wordCount) + << "OverwriteAssembledBinary: The given index is larger than the binary " + "word count."; + binary_->code[index] = word; +} + +template +spv_result_t ValidateBase::ValidateInstructions(spv_target_env env) { + DestroyDiagnostic(); + return spvValidateWithOptions(ScopedContext(env).context, options_, + get_const_binary(), &diagnostic_); +} + +template +spv_result_t ValidateBase::ValidateAndRetrieveValidationState( + spv_target_env env) { + DestroyDiagnostic(); + return spvtools::val::ValidateBinaryAndKeepValidationState( + ScopedContext(env).context, options_, 
get_const_binary()->code, + get_const_binary()->wordCount, &diagnostic_, &vstate_); +} + +template +std::string ValidateBase::getDiagnosticString() { + return diagnostic_ == nullptr ? std::string() + : std::string(diagnostic_->error); +} + +template +spv_validator_options ValidateBase::getValidatorOptions() { + return options_; +} + +template +spv_position_t ValidateBase::getErrorPosition() { + return diagnostic_ == nullptr ? spv_position_t() : diagnostic_->position; +} + +} // namespace spvtest + +#endif // TEST_VAL_VAL_FIXTURES_H_ diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp new file mode 100644 index 000000000..a09761e02 --- /dev/null +++ b/test/val/val_id_test.cpp @@ -0,0 +1,6403 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +// NOTE: The tests in this file are ONLY testing ID usage, there for the input +// SPIR-V does not follow the logical layout rules from the spec in all cases in +// order to makes the tests smaller. Validation of the whole module is handled +// in stages, ID validation is only one of these stages. All validation stages +// are stand alone. 
+ +namespace spvtools { +namespace val { +namespace { + +using spvtest::ScopedContext; +using ::testing::HasSubstr; +using ::testing::ValuesIn; + +using ValidateIdWithMessage = spvtest::ValidateBase; + +std::string kOpCapabilitySetupWithoutVector16 = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Addresses + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float64 + OpCapability LiteralSampler + OpCapability Pipes + OpCapability DeviceEnqueue +)"; + +std::string kOpCapabilitySetup = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Addresses + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float64 + OpCapability LiteralSampler + OpCapability Pipes + OpCapability DeviceEnqueue + OpCapability Vector16 +)"; + +std::string kOpVariablePtrSetUp = R"( + OpCapability VariablePointers + OpExtension "SPV_KHR_variable_pointers" +)"; + +std::string kGLSL450MemoryModel = + kOpCapabilitySetup + kOpVariablePtrSetUp + R"( + OpMemoryModel Logical GLSL450 +)"; + +std::string kGLSL450MemoryModelWithoutVector16 = + kOpCapabilitySetupWithoutVector16 + kOpVariablePtrSetUp + R"( + OpMemoryModel Logical GLSL450 +)"; + +std::string kNoKernelGLSL450MemoryModel = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Addresses + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float64 + OpMemoryModel Logical GLSL450 +)"; + +std::string kOpenCLMemoryModel32 = R"( + OpCapability Addresses + OpCapability Linkage + OpCapability Kernel +%1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical32 OpenCL +)"; + +std::string kOpenCLMemoryModel64 = R"( + OpCapability Addresses + OpCapability Linkage + OpCapability Kernel + OpCapability Int64 +%1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL +)"; + +std::string sampledImageSetup = R"( + %void = OpTypeVoid + %typeFuncVoid = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + 
%image_type = OpTypeImage %float 2D 0 0 0 1 Unknown +%_ptr_UniformConstant_img = OpTypePointer UniformConstant %image_type + %tex = OpVariable %_ptr_UniformConstant_img UniformConstant + %sampler_type = OpTypeSampler +%_ptr_UniformConstant_sam = OpTypePointer UniformConstant %sampler_type + %s = OpVariable %_ptr_UniformConstant_sam UniformConstant + %sampled_image_type = OpTypeSampledImage %image_type + %v2float = OpTypeVector %float 2 + %float_1 = OpConstant %float 1 + %float_2 = OpConstant %float 2 + %const_vec_1_1 = OpConstantComposite %v2float %float_1 %float_1 + %const_vec_2_2 = OpConstantComposite %v2float %float_2 %float_2 + %bool_type = OpTypeBool + %spec_true = OpSpecConstantTrue %bool_type + %main = OpFunction %void None %typeFuncVoid + %label_1 = OpLabel + %image_inst = OpLoad %image_type %tex + %sampler_inst = OpLoad %sampler_type %s +)"; + +std::string BranchConditionalSetup = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 140 + OpName %main "main" + + ; type definitions + %bool = OpTypeBool + %uint = OpTypeInt 32 0 + %int = OpTypeInt 32 1 + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + + ; constants + %true = OpConstantTrue %bool + %i0 = OpConstant %int 0 + %i1 = OpConstant %int 1 + %f0 = OpConstant %float 0 + %f1 = OpConstant %float 1 + + + ; main function header + %void = OpTypeVoid + %voidfunc = OpTypeFunction %void + %main = OpFunction %void None %voidfunc + %lmain = OpLabel +)"; + +std::string BranchConditionalTail = R"( + %target_t = OpLabel + OpNop + OpBranch %end + %target_f = OpLabel + OpNop + OpBranch %end + + %end = OpLabel + + OpReturn + OpFunctionEnd +)"; + +// TODO: OpUndef + +TEST_F(ValidateIdWithMessage, OpName) { + std::string spirv = kGLSL450MemoryModel + R"( + OpName %2 "name" +%1 = OpTypeInt 32 0 +%2 = OpTypePointer UniformConstant %1 +%3 = OpVariable %2 UniformConstant)"; + 
CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpMemberNameGood) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberName %2 0 "foo" +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpMemberNameTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberName %1 0 "foo" +%1 = OpTypeInt 32 0)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpMemberName Type '1[%uint]' is not a struct type.")); +} +TEST_F(ValidateIdWithMessage, OpMemberNameMemberBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberName %1 1 "foo" +%2 = OpTypeInt 32 0 +%1 = OpTypeStruct %2)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpMemberName Member '1[%_struct_1]' index is larger " + "than Type '1[%_struct_1]'s member count.")); +} + +TEST_F(ValidateIdWithMessage, OpLineGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpString "/path/to/source.file" + OpLine %1 0 0 +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Input %2 +%4 = OpVariable %3 Input)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpLineFileBad) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeInt 32 0 + OpLine %1 0 0 + )"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpLine Target '1[%uint]' is not an OpString.")); +} + +TEST_F(ValidateIdWithMessage, OpDecorateGood) { + std::string spirv = kGLSL450MemoryModel + R"( + OpDecorate %2 GLSLShared +%1 = OpTypeInt 64 0 +%2 = OpTypeStruct %1 %1)"; + 
CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpDecorateBad) { + std::string spirv = kGLSL450MemoryModel + R"( +OpDecorate %1 GLSLShared)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("forward referenced IDs have not been defined")); +} + +TEST_F(ValidateIdWithMessage, OpMemberDecorateGood) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %2 0 RelaxedPrecision +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct %1 %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpMemberDecorateBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %1 0 RelaxedPrecision +%1 = OpTypeInt 32 0)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpMemberDecorate Structure type '1[%uint]' is " + "not a struct type.")); +} +TEST_F(ValidateIdWithMessage, OpMemberDecorateMemberBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %1 3 RelaxedPrecision +%int = OpTypeInt 32 0 +%1 = OpTypeStruct %int %int)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Index 3 provided in OpMemberDecorate for struct " + "1[%_struct_1] is out of bounds. The structure has 2 " + "members. 
Largest valid index is 1.")); +} + +TEST_F(ValidateIdWithMessage, OpGroupDecorateGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpDecorationGroup + OpDecorate %1 RelaxedPrecision + OpDecorate %1 GLSLShared + OpGroupDecorate %1 %3 %4 +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 42 +%4 = OpConstant %2 23)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpDecorationGroupBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpDecorationGroup + OpDecorate %1 RelaxedPrecision + OpDecorate %1 GLSLShared + OpMemberDecorate %1 0 Constant + )"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result id of OpDecorationGroup can only " + "be targeted by OpName, OpGroupDecorate, " + "OpDecorate, and OpGroupMemberDecorate")); +} +TEST_F(ValidateIdWithMessage, OpGroupDecorateDecorationGroupBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpGroupDecorate %1 %2 %3 +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 42)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpGroupDecorate Decoration group '1[%1]' is not " + "a decoration group.")); +} +TEST_F(ValidateIdWithMessage, OpGroupDecorateTargetBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpDecorationGroup + OpDecorate %1 RelaxedPrecision + OpDecorate %1 GLSLShared + OpGroupDecorate %1 %3 +%2 = OpTypeInt 32 0)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("forward referenced IDs have not been defined")); +} +TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateDecorationGroupBad) { + std::string spirv = R"( + OpCapability Shader + 
OpCapability Linkage + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpGroupMemberDecorate %1 %2 0 +%2 = OpTypeInt 32 0)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpGroupMemberDecorate Decoration group '1[%1]' " + "is not a decoration group.")); +} +TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateIdNotStructBad) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpDecorationGroup + OpGroupMemberDecorate %1 %2 0 +%2 = OpTypeInt 32 0)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpGroupMemberDecorate Structure type '2[%uint]' " + "is not a struct type.")); +} +TEST_F(ValidateIdWithMessage, OpGroupMemberDecorateIndexOutOfBoundBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpDecorate %1 Offset 0 + %1 = OpDecorationGroup + OpGroupMemberDecorate %1 %struct 3 +%float = OpTypeFloat 32 +%struct = OpTypeStruct %float %float %float +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Index 3 provided in OpGroupMemberDecorate for struct " + " 2[%_struct_2] is out of bounds. The structure " + "has 3 members. 
Largest valid index is 2.")); +} + +// TODO: OpExtInst + +TEST_F(ValidateIdWithMessage, OpEntryPointGood) { + std::string spirv = kGLSL450MemoryModel + R"( + OpEntryPoint GLCompute %3 "" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpEntryPointFunctionBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpEntryPoint GLCompute %1 "" +%1 = OpTypeVoid)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpEntryPoint Entry Point '1[%void]' is not a " + "function.")); +} +TEST_F(ValidateIdWithMessage, OpEntryPointParameterCountBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpEntryPoint GLCompute %3 "" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpEntryPoint Entry Point '1[%1]'s function " + "parameter count is not zero")); +} +TEST_F(ValidateIdWithMessage, OpEntryPointReturnTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpEntryPoint GLCompute %3 "" +%1 = OpTypeInt 32 0 +%ret = OpConstant %1 0 +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel + OpReturnValue %ret + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpEntryPoint Entry Point '1[%1]'s function " + "return type is not void.")); +} + +TEST_F(ValidateIdWithMessage, OpEntryPointInterfaceIsNotVariableTypeBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Geometry + OpMemoryModel Logical GLSL450 + OpEntryPoint Geometry %main "main" 
%ptr_builtin_1 + OpExecutionMode %main InputPoints + OpExecutionMode %main OutputPoints + OpMemberDecorate %struct_1 0 BuiltIn InvocationId + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %func = OpTypeFunction %void + %struct_1 = OpTypeStruct %int +%ptr_builtin_1 = OpTypePointer Input %struct_1 + %main = OpFunction %void None %func + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Interfaces passed to OpEntryPoint must be of type " + "OpTypeVariable. Found OpTypePointer.")); +} + +TEST_F(ValidateIdWithMessage, OpEntryPointInterfaceStorageClassBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability Geometry + OpMemoryModel Logical GLSL450 + OpEntryPoint Geometry %main "main" %in_1 + OpExecutionMode %main InputPoints + OpExecutionMode %main OutputPoints + OpMemberDecorate %struct_1 0 BuiltIn InvocationId + %int = OpTypeInt 32 1 + %void = OpTypeVoid + %func = OpTypeFunction %void + %struct_1 = OpTypeStruct %int +%ptr_builtin_1 = OpTypePointer Uniform %struct_1 + %in_1 = OpVariable %ptr_builtin_1 Uniform + %main = OpFunction %void None %func + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpEntryPoint interfaces must be OpVariables with " + "Storage Class of Input(1) or Output(3). 
Found Storage " + "Class 2 for Entry Point id 1.")); +} + +TEST_F(ValidateIdWithMessage, OpExecutionModeGood) { + std::string spirv = kGLSL450MemoryModel + R"( + OpEntryPoint GLCompute %3 "" + OpExecutionMode %3 LocalSize 1 1 1 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpExecutionModeEntryPointMissing) { + std::string spirv = kGLSL450MemoryModel + R"( + OpExecutionMode %3 LocalSize 1 1 1 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpExecutionMode Entry Point '1[%1]' is not the " + "Entry Point operand of an OpEntryPoint.")); +} + +TEST_F(ValidateIdWithMessage, OpExecutionModeEntryPointBad) { + std::string spirv = kGLSL450MemoryModel + R"( + OpEntryPoint GLCompute %3 "" %a + OpExecutionMode %a LocalSize 1 1 1 +%void = OpTypeVoid +%ptr = OpTypePointer Input %void +%a = OpVariable %ptr Input +%2 = OpTypeFunction %void +%3 = OpFunction %void None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpExecutionMode Entry Point '2[%2]' is not the " + "Entry Point operand of an OpEntryPoint.")); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorFloat) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 4)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorInt) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeVector %1 4)"; + CompileSuccessfully(spirv.c_str()); + 
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorUInt) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 64 0 +%2 = OpTypeVector %1 4)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorBool) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeBool +%2 = OpTypeVector %1 4)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorComponentTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypePointer UniformConstant %1 +%3 = OpTypeVector %2 4)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTypeVector Component Type " + "'2[%_ptr_UniformConstant_float]' is not a scalar type.")); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorColumnCountLessThanTwoBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Illegal number of components (1) for TypeVector\n %v1float = " + "OpTypeVector %float 1\n")); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorColumnCountGreaterThanFourBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 5)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Illegal number of components (5) for TypeVector\n %v5float = " + "OpTypeVector %float 5\n")); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorColumnCountEightWithoutVector16Bad) { + std::string spirv = kGLSL450MemoryModelWithoutVector16 + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector 
%1 8)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Having 8 components for TypeVector requires the Vector16 " + "capability\n %v8float = OpTypeVector %float 8\n")); +} + +TEST_F(ValidateIdWithMessage, + OpTypeVectorColumnCountSixteenWithoutVector16Bad) { + std::string spirv = kGLSL450MemoryModelWithoutVector16 + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 16)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Having 16 components for TypeVector requires the Vector16 " + "capability\n %v16float = OpTypeVector %float 16\n")); +} + +TEST_F(ValidateIdWithMessage, OpTypeVectorColumnCountOfEightWithVector16Good) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 8)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, + OpTypeVectorColumnCountOfSixteenWithVector16Good) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 16)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeMatrixGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 2 +%3 = OpTypeMatrix %2 3)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeMatrixColumnTypeNonVectorBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeMatrix %1 3)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("olumns in a matrix must be of type vector.\n %mat3float = " + "OpTypeMatrix %float 3\n")); +} + 
+TEST_F(ValidateIdWithMessage, OpTypeMatrixVectorTypeNonFloatBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 16 0 +%2 = OpTypeVector %1 2 +%3 = OpTypeMatrix %2 2)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Matrix types can only be parameterized with floating-point " + "types.\n %mat2v2ushort = OpTypeMatrix %v2ushort 2\n")); +} + +TEST_F(ValidateIdWithMessage, OpTypeMatrixColumnCountLessThanTwoBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 2 +%3 = OpTypeMatrix %2 1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Matrix types can only be parameterized as having only 2, 3, " + "or 4 columns.\n %mat1v2float = OpTypeMatrix %v2float 1\n")); +} + +TEST_F(ValidateIdWithMessage, OpTypeMatrixColumnCountGreaterThanFourBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 2 +%3 = OpTypeMatrix %2 8)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Matrix types can only be parameterized as having only 2, 3, " + "or 4 columns.\n %mat8v2float = OpTypeMatrix %v2float 8\n")); +} + +TEST_F(ValidateIdWithMessage, OpTypeSamplerGood) { + // In Rev31, OpTypeSampler takes no arguments. 
+ std::string spirv = kGLSL450MemoryModel + R"( +%s = OpTypeSampler)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeArrayGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpConstant %1 1 +%3 = OpTypeArray %1 %2)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeArrayElementTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpConstant %1 1 +%3 = OpTypeArray %2 %2)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeArray Element Type '2[%uint_1]' is not a " + "type.")); +} + +// Signed or unsigned. +enum Signed { kSigned, kUnsigned }; + +// Creates an assembly snippet declaring OpTypeArray with the given length. +std::string MakeArrayLength(const std::string& len, Signed isSigned, + int width) { + std::ostringstream ss; + ss << R"( + OpCapability Shader + OpCapability Linkage + OpCapability Int16 + OpCapability Int64 + )"; + ss << "OpMemoryModel Logical GLSL450\n"; + ss << " %t = OpTypeInt " << width << (isSigned == kSigned ? " 1" : " 0"); + ss << " %l = OpConstant %t " << len; + ss << " %a = OpTypeArray %t %l"; + return ss.str(); +} + +// Tests OpTypeArray. Parameter is the width (in bits) of the array-length's +// type. +class OpTypeArrayLengthTest + : public spvtest::TextToBinaryTestBase<::testing::TestWithParam> { + protected: + OpTypeArrayLengthTest() + : position_(spv_position_t{0, 0, 0}), + diagnostic_(spvDiagnosticCreate(&position_, "")) {} + + ~OpTypeArrayLengthTest() { spvDiagnosticDestroy(diagnostic_); } + + // Runs spvValidate() on v, printing any errors via spvDiagnosticPrint(). 
+ spv_result_t Val(const SpirvVector& v, const std::string& expected_err = "") { + spv_const_binary_t cbinary{v.data(), v.size()}; + spvDiagnosticDestroy(diagnostic_); + diagnostic_ = nullptr; + const auto status = + spvValidate(ScopedContext().context, &cbinary, &diagnostic_); + if (status != SPV_SUCCESS) { + spvDiagnosticPrint(diagnostic_); + EXPECT_THAT(std::string(diagnostic_->error), + testing::ContainsRegex(expected_err)); + } + return status; + } + + private: + spv_position_t position_; // For creating diagnostic_. + spv_diagnostic diagnostic_; +}; + +TEST_P(OpTypeArrayLengthTest, LengthPositive) { + const int width = GetParam(); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("1", kSigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("1", kUnsigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("2", kSigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("2", kUnsigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("55", kSigned, width)))); + EXPECT_EQ(SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("55", kUnsigned, width)))); + const std::string fpad(width / 4 - 1, 'F'); + EXPECT_EQ( + SPV_SUCCESS, + Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width)))); + EXPECT_EQ(SPV_SUCCESS, Val(CompileSuccessfully( + MakeArrayLength("0xF" + fpad, kUnsigned, width)))); +} + +TEST_P(OpTypeArrayLengthTest, LengthZero) { + const int width = GetParam(); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("0", kSigned, width)), + "OpTypeArray Length '2\\[%.*\\]' default value must be at " + "least 1.")); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("0", kUnsigned, width)), + "OpTypeArray Length '2\\[%.*\\]' default value must be at " + "least 1.")); +} + +TEST_P(OpTypeArrayLengthTest, LengthNegative) { + const int width = GetParam(); + 
EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-1", kSigned, width)), + "OpTypeArray Length '2\\[%.*\\]' default value must be at " + "least 1.")); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-2", kSigned, width)), + "OpTypeArray Length '2\\[%.*\\]' default value must be at " + "least 1.")); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width)), + "OpTypeArray Length '2\\[%.*\\]' default value must be at " + "least 1.")); + const std::string neg_max = "0x8" + std::string(width / 4 - 1, '0'); + EXPECT_EQ(SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width)), + "OpTypeArray Length '2\\[%.*\\]' default value must be at " + "least 1.")); +} + +// The only valid widths for integers are 8, 16, 32, and 64. +// Since the Int8 capability requires the Kernel capability, and the Kernel +// capability prohibits usage of signed integers, we can skip 8-bit integers +// here since the purpose of these tests is to check the validity of +// OpTypeArray, not OpTypeInt. 
+INSTANTIATE_TEST_CASE_P(Widths, OpTypeArrayLengthTest, + ValuesIn(std::vector{16, 32, 64})); + +TEST_F(ValidateIdWithMessage, OpTypeArrayLengthNull) { + std::string spirv = kGLSL450MemoryModel + R"( +%i32 = OpTypeInt 32 0 +%len = OpConstantNull %i32 +%ary = OpTypeArray %i32 %len)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpTypeArray Length '2[%2]' default value must be at least 1.")); +} + +TEST_F(ValidateIdWithMessage, OpTypeArrayLengthSpecConst) { + std::string spirv = kGLSL450MemoryModel + R"( +%i32 = OpTypeInt 32 0 +%len = OpSpecConstant %i32 2 +%ary = OpTypeArray %i32 %len)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeArrayLengthSpecConstOp) { + std::string spirv = kGLSL450MemoryModel + R"( +%i32 = OpTypeInt 32 0 +%c1 = OpConstant %i32 1 +%c2 = OpConstant %i32 2 +%len = OpSpecConstantOp %i32 IAdd %c1 %c2 +%ary = OpTypeArray %i32 %len)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpTypeRuntimeArrayGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeRuntimeArray %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpTypeRuntimeArrayBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpConstant %1 0 +%3 = OpTypeRuntimeArray %2)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTypeRuntimeArray Element Type '2[%uint_0]' is not a " + "type.")); +} +// TODO: Object of this type can only be created with OpVariable using the +// Unifrom Storage Class + +TEST_F(ValidateIdWithMessage, OpTypeStructGood) { + std::string spirv = kGLSL450MemoryModel + R"( 
+%1 = OpTypeInt 32 0 +%2 = OpTypeFloat 64 +%3 = OpTypePointer Input %1 +%4 = OpTypeStruct %1 %2 %3)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpTypeStructMemberTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeFloat 64 +%3 = OpConstant %2 0.0 +%4 = OpTypeStruct %1 %2 %3)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeStruct Member Type '3[%double_0]' is not " + "a type.")); +} + +TEST_F(ValidateIdWithMessage, OpTypePointerGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypePointer Input %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpTypePointerBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpConstant %1 0 +%3 = OpTypePointer Input %2)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypePointer Type '2[%uint_0]' is not a " + "type.")); +} + +TEST_F(ValidateIdWithMessage, OpTypeFunctionGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpTypeFunctionReturnTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpConstant %1 0 +%3 = OpTypeFunction %2)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeFunction Return Type '2[%uint_0]' is not " + "a type.")); +} +TEST_F(ValidateIdWithMessage, OpTypeFunctionParameterBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 
= OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypeFunction %1 %2 %3)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTypeFunction Parameter Type '3[%uint_0]' is not a " + "type.")); +} + +TEST_F(ValidateIdWithMessage, OpTypeFunctionParameterTypeVoidBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%4 = OpTypeFunction %1 %2 %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeFunction Parameter Type '1[%void]' cannot " + "be OpTypeVoid.")); +} + +TEST_F(ValidateIdWithMessage, OpTypePipeGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 16 +%3 = OpTypePipe ReadOnly)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpConstantTrueGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeBool +%2 = OpConstantTrue %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpConstantTrueBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpConstantTrue %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpConstantTrue Result Type '1[%void]' is not a boolean " + "type.")); +} + +TEST_F(ValidateIdWithMessage, OpConstantFalseGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeBool +%2 = OpConstantTrue %1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpConstantFalseBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpConstantFalse %1)"; + CompileSuccessfully(spirv.c_str()); + 
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpConstantFalse Result Type '1[%void]' is not a boolean " + "type.")); +} + +TEST_F(ValidateIdWithMessage, OpConstantGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpConstant %1 1)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpConstantBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpConstant !1 !0)"; + // The expected failure code is implementation dependent (currently + // INVALID_BINARY because the binary parser catches these cases) and may + // change over time, but this must always fail. + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 4 +%3 = OpConstant %1 3.14 +%4 = OpConstantComposite %2 %3 %3 %3 %3)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorWithUndefGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 4 +%3 = OpConstant %1 3.14 +%9 = OpUndef %1 +%4 = OpConstantComposite %2 %3 %3 %3 %9)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorResultTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeFloat 32 +%2 = OpTypeVector %1 4 +%3 = OpConstant %1 3.14 +%4 = OpConstantComposite %1 %3 %3 %3 %3)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpConstantComposite Result Type '1[%float]' is not a " + "composite type.")); +} 
+// Invalid: one constituent is a uint constant but the vector element type is
+// float.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeVectorConstituentTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 4
+%4 = OpTypeInt 32 0
+%3 = OpConstant %1 3.14
+%5 = OpConstant %4 42 ; bad type for constant value
+%6 = OpConstantComposite %2 %3 %5 %3 %3)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpConstantComposite Constituent '5[%uint_42]'s type "
+                "does not match Result Type '2[%v4float]'s vector "
+                "element type."));
+}
+// Invalid: one constituent is an OpUndef of uint type but the vector element
+// type is float.
+TEST_F(ValidateIdWithMessage,
+       OpConstantCompositeVectorConstituentUndefTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 4
+%4 = OpTypeInt 32 0
+%3 = OpConstant %1 3.14
+%5 = OpUndef %4 ; bad type for undef value
+%6 = OpConstantComposite %2 %3 %5 %3 %3)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpConstantComposite Constituent '5[%5]'s type does not "
+                "match Result Type '2[%v4float]'s vector element type."));
+}
+// Valid: 4x4 matrix built from four 4-component float vectors.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+ %3 = OpTypeMatrix %2 4
+ %4 = OpConstant %1 1.0
+ %5 = OpConstant %1 0.0
+ %6 = OpConstantComposite %2 %4 %5 %5 %5
+ %7 = OpConstantComposite %2 %5 %4 %5 %5
+ %8 = OpConstantComposite %2 %5 %5 %4 %5
+ %9 = OpConstantComposite %2 %5 %5 %5 %4
+%10 = OpConstantComposite %3 %6 %7 %8 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Valid: matrix in which one column is an OpUndef of the column type.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixUndefGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+ %3 = OpTypeMatrix %2 4
+ %4 = OpConstant %1 1.0
+ %5 = OpConstant %1 0.0
+ %6 = OpConstantComposite %2 %4 %5 %5 %5
+ %7 = OpConstantComposite %2 %5 %4 %5 %5
+ %8 = OpConstantComposite %2 %5 %5 %4 %5
+ %9 = OpUndef %2
+%10 = OpConstantComposite %3 %6 %7 %8 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: the last column is a 3-component vector while the matrix column
+// type has 4 components.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeMatrixConstituentTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+%11 = OpTypeVector %1 3
+ %3 = OpTypeMatrix %2 4
+ %4 = OpConstant %1 1.0
+ %5 = OpConstant %1 0.0
+ %6 = OpConstantComposite %2 %4 %5 %5 %5
+ %7 = OpConstantComposite %2 %5 %4 %5 %5
+ %8 = OpConstantComposite %2 %5 %5 %4 %5
+ %9 = OpConstantComposite %11 %5 %5 %5
+%10 = OpConstantComposite %3 %6 %7 %8 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpConstantComposite Constituent '10[%10]' vector "
+                        "component count does not match Result Type "
+                        "'4[%mat4v4float]'s vector component count."));
+}
+// Invalid: the last column is an OpUndef of a 3-component vector type while
+// the matrix column type has 4 components.
+TEST_F(ValidateIdWithMessage,
+       OpConstantCompositeMatrixConstituentUndefTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+%11 = OpTypeVector %1 3
+ %3 = OpTypeMatrix %2 4
+ %4 = OpConstant %1 1.0
+ %5 = OpConstant %1 0.0
+ %6 = OpConstantComposite %2 %4 %5 %5 %5
+ %7 = OpConstantComposite %2 %5 %4 %5 %5
+ %8 = OpConstantComposite %2 %5 %5 %4 %5
+ %9 = OpUndef %11
+%10 = OpConstantComposite %3 %6 %7 %8 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpConstantComposite Constituent '10[%10]' vector "
+                        "component count does not match Result Type "
+                        "'4[%mat4v4float]'s vector component count."));
+}
+// Valid: array of four uint constants.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpConstant %1 4
+%3 = OpTypeArray %1 %2
+%4 = OpConstantComposite %3 %2 %2 %2 %2)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Valid: array whose last constituent is an OpUndef of the element type.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayWithUndefGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpConstant %1 4
+%9 = OpUndef %1
+%3 = OpTypeArray %1 %2
+%4 = OpConstantComposite %3 %2 %2 %2 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: the last constituent is a type id rather than a value.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstConstituentTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpConstant %1 4
+%3 = OpTypeArray %1 %2
+%4 = OpConstantComposite %3 %2 %2 %2 %1)";  // Uses a type as operand
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 1[%uint] cannot be a "
+                                               "type"));
+}
+// Invalid: the last constituent is an OpVariable, not a constant or undef.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstConstituentBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpConstant %1 4
+%3 = OpTypeArray %1 %2
+%4 = OpTypePointer Uniform %1
+%5 = OpVariable %4 Uniform
+%6 = OpConstantComposite %3 %2 %2 %2 %5)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpConstantComposite Constituent '5[%5]' is not a "
+                        "constant or undef."));
+}
+// Invalid: the last constituent is a float constant but the array element
+// type is uint.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstituentTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpConstant %1 4
+%3 = OpTypeArray %1 %2
+%5 = OpTypeFloat 32
+%6 = OpConstant %5 3.14 ; bad type for const value
+%4 = OpConstantComposite %3 %2 %2 %2 %6)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpConstantComposite Constituent "
+                        "'5[%float_3_1400001]'s type does not match Result "
+                        "Type '3[%_arr_uint_uint_4]'s array element "
+                        "type."));
+}
+// Invalid: the last constituent is an OpUndef of float type but the array
+// element type is uint.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeArrayConstituentUndefTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpConstant %1 4
+%3 = OpTypeArray %1 %2
+%5 = OpTypeFloat 32
+%6 = OpUndef %5 ; bad type for undef
+%4 = OpConstantComposite %3 %2 %2 %2 %6)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpConstantComposite Constituent "
+                        "'5[%5]'s type does not match Result "
+                        "Type '3[%_arr_uint_uint_4]'s array element "
+                        "type."));
+}
+// Valid: struct of {uint32, uint32, uint64} with matching constants.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeStructGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpTypeInt 64 0
+%3 = OpTypeStruct %1 %1 %2
+%4 = OpConstant %1 42
+%5 = OpConstant %2 4300000000
+%6 = OpConstantComposite %3 %4 %4 %5)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Valid: struct whose last member is supplied by an OpUndef of the member
+// type.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeStructUndefGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpTypeInt 64 0
+%3 = OpTypeStruct %1 %1 %2
+%4 = OpConstant %1 42
+%5 = OpUndef %2
+%6 = OpConstantComposite %3 %4 %4 %5)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: the second constituent is a 64-bit constant but the second member
+// is a 32-bit integer.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeStructMemberTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpTypeInt 64 0
+%3 = OpTypeStruct %1 %1 %2
+%4 = OpConstant %1 42
+%5 = OpConstant %2 4300000000
+%6 = OpConstantComposite %3 %4 %5 %4)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpConstantComposite Constituent "
+                        "'5[%ulong_4300000000]' type does not match the "
+                        "Result Type '3[%_struct_3]'s member type."));
+}
+
+// Invalid: the second constituent is an OpUndef of the 64-bit type but the
+// second member is a 32-bit integer.
+TEST_F(ValidateIdWithMessage, OpConstantCompositeStructMemberUndefTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpTypeInt 64 0
+%3 = OpTypeStruct %1 %1 %2
+%4 = OpConstant %1 42
+%5 = OpUndef %2
+%6 = OpConstantComposite %3 %4 %5 %4)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpConstantComposite Constituent '5[%5]' type "
+                        "does not match the Result Type '3[%_struct_3]'s "
+                        "member type."));
+}
+
+// Valid: OpConstantSampler with a sampler result type.
+TEST_F(ValidateIdWithMessage, OpConstantSamplerGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%float = OpTypeFloat 32
+%samplerType = OpTypeSampler
+%3 = OpConstantSampler %samplerType ClampToEdge 0 Nearest)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: OpConstantSampler with a float result type.
+TEST_F(ValidateIdWithMessage, OpConstantSamplerResultTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpConstantSampler %1 Clamp 0 Nearest)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "OpConstantSampler Result Type '1[%float]' is not a sampler "
+          "type."));
+}
+
+// Valid: OpConstantNull of bool, int, float, pointer, event, device event,
+// reserve id, queue, vector, matrix, array and struct types.
+TEST_F(ValidateIdWithMessage, OpConstantNullGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeBool
+ %2 = OpConstantNull %1
+ %3 = OpTypeInt 32 0
+ %4 = OpConstantNull %3
+ %5 = OpTypeFloat 32
+ %6 = OpConstantNull %5
+ %7 = OpTypePointer UniformConstant %3
+ %8 = OpConstantNull %7
+ %9 = OpTypeEvent
+%10 = OpConstantNull %9
+%11 = OpTypeDeviceEvent
+%12 = OpConstantNull %11
+%13 = OpTypeReserveId
+%14 = OpConstantNull %13
+%15 = OpTypeQueue
+%16 = OpConstantNull %15
+%17 = OpTypeVector %5 2
+%18 = OpConstantNull %17
+%19 = OpTypeMatrix %17 2
+%20 = OpConstantNull %19
+%25 = OpConstant %3 8
+%21 = OpTypeArray %3 %25
+%22 = OpConstantNull %21
+%23 = OpTypeStruct %3 %5 %1
+%24 = OpConstantNull %23
+%26 = OpTypeArray %17 %25
+%27 = OpConstantNull %26
+%28 = OpTypeStruct %7 %26 %26 %1
+%29 = OpConstantNull %28
+)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: OpConstantNull of void type.
+TEST_F(ValidateIdWithMessage, OpConstantNullBasicBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpConstantNull %1)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpConstantNull Result Type '1[%void]' cannot have a null "
+                "value."));
+}
+
+// Invalid: OpConstantNull of an array whose element type is a sampler.
+TEST_F(ValidateIdWithMessage, OpConstantNullArrayBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%2 = OpTypeInt 32 0
+%3 = OpTypeSampler
+%4 = OpConstant %2 4
+%5 = OpTypeArray %3 %4
+%6 = OpConstantNull %5)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "OpConstantNull Result Type '4[%_arr_2_uint_4]' cannot have a "
+          "null value."));
+}
+
+// Invalid: OpConstantNull of a struct whose members are samplers.
+TEST_F(ValidateIdWithMessage, OpConstantNullStructBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%2 = OpTypeSampler
+%3 = OpTypeStruct %2 %2
+%4 = OpConstantNull %3)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpConstantNull Result Type '2[%_struct_2]' "
+                        "cannot have a null value."));
+}
+
+// Invalid: OpConstantNull of a runtime array type.
+TEST_F(ValidateIdWithMessage, OpConstantNullRuntimeArrayBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%bool = OpTypeBool
+%array = OpTypeRuntimeArray %bool
+%null = OpConstantNull %array)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "OpConstantNull Result Type '2[%_runtimearr_bool]' cannot have "
+          "a null value."));
+}
+
+// Valid: OpSpecConstantTrue with a boolean result type.
+TEST_F(ValidateIdWithMessage, OpSpecConstantTrueGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeBool
+%2 = OpSpecConstantTrue %1)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: OpSpecConstantTrue with a void result type.
+TEST_F(ValidateIdWithMessage, OpSpecConstantTrueBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpSpecConstantTrue %1)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Specialization constant must be a boolean type."));
+}
+
+// Valid: OpSpecConstantFalse with a boolean result type.
+TEST_F(ValidateIdWithMessage, OpSpecConstantFalseGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeBool
+%2 = OpSpecConstantFalse %1)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: OpSpecConstantFalse with a void result type.
+TEST_F(ValidateIdWithMessage, OpSpecConstantFalseBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpSpecConstantFalse %1)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Specialization constant must be a boolean type."));
+}
+
+// Valid: OpSpecConstant of a float type.
+TEST_F(ValidateIdWithMessage, OpSpecConstantGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpSpecConstant %1 42)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: OpSpecConstant whose operands are given as raw words (!1 !4) after
+// a void type declaration.
+TEST_F(ValidateIdWithMessage, OpSpecConstantBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpSpecConstant !1 !4)";
+  // The expected failure code is implementation dependent (currently
+  // INVALID_BINARY because the binary parser catches these cases) and may
+  // change over time, but this must always fail.
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Type Id 1 is not a scalar numeric type"));
+}
+
+// Valid: SpecConstantComposite specializes to a vector.
+// The vector mixes OpSpecConstant and OpConstant float constituents.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 4
+%3 = OpSpecConstant %1 3.14
+%4 = OpConstant %1 3.14
+%5 = OpSpecConstantComposite %2 %3 %3 %4 %4)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Valid: Vector of floats and Undefs.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorWithUndefGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 4
+%3 = OpSpecConstant %1 3.14
+%5 = OpConstant %1 3.14
+%9 = OpUndef %1
+%4 = OpSpecConstantComposite %2 %3 %5 %3 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: result type is float.
+// Only the tail of the diagnostic is matched here.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorResultTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 4
+%3 = OpSpecConstant %1 3.14
+%4 = OpSpecConstantComposite %1 %3 %3 %3 %3)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("is not a composite type"));
+}
+
+// Invalid: Vector contains a mix of Int and Float.
+// The second constituent is a uint constant; the vector element type is float.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorConstituentTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 4
+%4 = OpTypeInt 32 0
+%3 = OpSpecConstant %1 3.14
+%5 = OpConstant %4 42 ; bad type for constant value
+%6 = OpSpecConstantComposite %2 %3 %5 %3 %3)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent "
+                        "'5[%uint_42]'s type does not match Result Type "
+                        "'2[%v4float]'s vector element type."));
+}
+
+// Invalid: Constituent is not a constant
+// (the first constituent is an OpVariable).
+TEST_F(ValidateIdWithMessage,
+       OpSpecConstantCompositeVectorConstituentNotConstantBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 4
+%3 = OpTypeInt 32 0
+%4 = OpSpecConstant %1 3.14
+%5 = OpTypePointer Uniform %1
+%6 = OpVariable %5 Uniform
+%7 = OpSpecConstantComposite %2 %6 %4 %4 %4)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent '6[%6]' is "
+                        "not a constant or undef."));
+}
+
+// Invalid: Vector contains a mix of Undef-int and Float.
+TEST_F(ValidateIdWithMessage,
+       OpSpecConstantCompositeVectorConstituentUndefTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 4
+%4 = OpTypeInt 32 0
+%3 = OpSpecConstant %1 3.14
+%5 = OpUndef %4 ; bad type for undef value
+%6 = OpSpecConstantComposite %2 %3 %5 %3 %3)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent '5[%5]'s "
+                        "type does not match Result Type '2[%v4float]'s "
+                        "vector element type."));
+}
+
+// Invalid: Vector expects 3 components, but 4 specified.
+// The result type is a 3-component vector; four constituents are supplied.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeVectorNumComponentsBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeVector %1 3
+%3 = OpConstant %1 3.14
+%5 = OpSpecConstant %1 4.0
+%6 = OpSpecConstantComposite %2 %3 %5 %3 %3)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent count does "
+                        "not match Result Type '2[%v3float]'s vector "
+                        "component count."));
+}
+
+// Valid: 4x4 matrix of floats
+// (columns are spec-constant composites mixing OpConstant and OpSpecConstant).
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+ %3 = OpTypeMatrix %2 4
+ %4 = OpConstant %1 1.0
+ %5 = OpSpecConstant %1 0.0
+ %6 = OpSpecConstantComposite %2 %4 %5 %5 %5
+ %7 = OpSpecConstantComposite %2 %5 %4 %5 %5
+ %8 = OpSpecConstantComposite %2 %5 %5 %4 %5
+ %9 = OpSpecConstantComposite %2 %5 %5 %5 %4
+%10 = OpSpecConstantComposite %3 %6 %7 %8 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Valid: Matrix in which one column is Undef
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixUndefGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+ %3 = OpTypeMatrix %2 4
+ %4 = OpConstant %1 1.0
+ %5 = OpSpecConstant %1 0.0
+ %6 = OpSpecConstantComposite %2 %4 %5 %5 %5
+ %7 = OpSpecConstantComposite %2 %5 %4 %5 %5
+ %8 = OpSpecConstantComposite %2 %5 %5 %4 %5
+ %9 = OpUndef %2
+%10 = OpSpecConstantComposite %3 %6 %7 %8 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: Matrix in which the sizes of column vectors are not equal.
+// The last column is a 3-component vector; the matrix column type has 4.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixConstituentTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+ %3 = OpTypeVector %1 3
+ %4 = OpTypeMatrix %2 4
+ %5 = OpSpecConstant %1 1.0
+ %6 = OpConstant %1 0.0
+ %7 = OpSpecConstantComposite %2 %5 %6 %6 %6
+ %8 = OpSpecConstantComposite %2 %6 %5 %6 %6
+ %9 = OpSpecConstantComposite %2 %6 %6 %5 %6
+ %10 = OpSpecConstantComposite %3 %6 %6 %6
+%11 = OpSpecConstantComposite %4 %7 %8 %9 %10)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent '10[%10]' "
+                        "vector component count does not match Result Type "
+                        " '4[%mat4v4float]'s vector component count."));
+}
+
+// Invalid: Matrix type expects 4 columns but only 3 specified.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixNumColsBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+ %3 = OpTypeMatrix %2 4
+ %4 = OpSpecConstant %1 1.0
+ %5 = OpConstant %1 0.0
+ %6 = OpSpecConstantComposite %2 %4 %5 %5 %5
+ %7 = OpSpecConstantComposite %2 %5 %4 %5 %5
+ %8 = OpSpecConstantComposite %2 %5 %5 %4 %5
+%10 = OpSpecConstantComposite %3 %6 %7 %8)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpSpecConstantComposite Constituent count does "
+                "not match Result Type '3[%mat4v4float]'s matrix column "
+                "count."));
+}
+
+// Invalid: Composite contains a non-const/undef component
+// (the last column operand is an OpVariable).
+TEST_F(ValidateIdWithMessage,
+       OpSpecConstantCompositeMatrixConstituentNotConstBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpConstant %1 0.0
+ %3 = OpTypeVector %1 4
+ %4 = OpTypeMatrix %3 4
+ %5 = OpSpecConstantComposite %3 %2 %2 %2 %2
+ %6 = OpTypePointer Uniform %1
+ %7 = OpVariable %6 Uniform
+ %8 = OpSpecConstantComposite %4 %5 %5 %5 %7)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent '7[%7]' is "
+                        "not a constant composite or undef."));
+}
+
+// Invalid: Composite contains a column that is *not* a vector (it's an array)
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixColTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeInt 32 0
+ %3 = OpSpecConstant %2 4
+ %4 = OpConstant %1 0.0
+ %5 = OpTypeVector %1 4
+ %6 = OpTypeArray %2 %3
+ %7 = OpTypeMatrix %5 4
+ %8 = OpSpecConstantComposite %6 %3 %3 %3 %3
+ %9 = OpSpecConstantComposite %5 %4 %4 %4 %4
+ %10 = OpSpecConstantComposite %7 %9 %9 %9 %8)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent '8[%8]' type "
+                        "does not match Result Type '7[%mat4v4float]'s "
+                        "matrix column type."));
+}
+
+// Invalid: Matrix with an Undef column of the wrong size.
+// The last column is an OpUndef of a 3-component vector type; the matrix
+// column type has 4 components.
+TEST_F(ValidateIdWithMessage,
+       OpSpecConstantCompositeMatrixConstituentUndefTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeFloat 32
+ %2 = OpTypeVector %1 4
+ %3 = OpTypeVector %1 3
+ %4 = OpTypeMatrix %2 4
+ %5 = OpSpecConstant %1 1.0
+ %6 = OpSpecConstant %1 0.0
+ %7 = OpSpecConstantComposite %2 %5 %6 %6 %6
+ %8 = OpSpecConstantComposite %2 %6 %5 %6 %6
+ %9 = OpSpecConstantComposite %2 %6 %6 %5 %6
+ %10 = OpUndef %3
+ %11 = OpSpecConstantComposite %4 %7 %8 %9 %10)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent '10[%10]' "
+                        "vector component count does not match Result Type "
+                        " '4[%mat4v4float]'s vector component count."));
+}
+
+// Invalid: Matrix in which some columns are Int and some are Float.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeMatrixColumnTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+ %1 = OpTypeInt 32 0
+ %2 = OpTypeFloat 32
+ %3 = OpTypeVector %1 2
+ %4 = OpTypeVector %2 2
+ %5 = OpTypeMatrix %4 2
+ %6 = OpSpecConstant %1 42
+ %7 = OpConstant %2 3.14
+ %8 = OpSpecConstantComposite %3 %6 %6
+ %9 = OpSpecConstantComposite %4 %7 %7
+%10 = OpSpecConstantComposite %5 %8 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent '8[%8]' "
+                        "component type does not match Result Type "
+                        "'5[%mat2v2float]'s matrix column component type."));
+}
+
+// Valid: Array of integers
+// (constituents are OpSpecConstant in one composite and OpConstant in the
+// other).
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpSpecConstant %1 4
+%5 = OpConstant %1 5
+%3 = OpTypeArray %1 %2
+%6 = OpTypeArray %1 %5
+%4 = OpSpecConstantComposite %3 %2 %2 %2 %2
+%7 = OpSpecConstantComposite %3 %5 %5 %5 %5)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: Expecting an array of 4 components, but 3 specified.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayNumComponentsBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpConstant %1 4
+%3 = OpTypeArray %1 %2
+%4 = OpSpecConstantComposite %3 %2 %2 %2)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent count does not "
+                        "match Result Type '3[%_arr_uint_uint_4]'s array "
+                        "length."));
+}
+
+// Valid: Array of Integers and Undef-int
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayWithUndefGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpSpecConstant %1 4
+%9 = OpUndef %1
+%3 = OpTypeArray %1 %2
+%4 = OpSpecConstantComposite %3 %2 %2 %2 %9)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: Array constituent is an OpVariable, not a constant or undef.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayConstConstituentBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpConstant %1 4
+%3 = OpTypeArray %1 %2
+%4 = OpTypePointer Uniform %1
+%5 = OpVariable %4 Uniform
+%6 = OpSpecConstantComposite %3 %2 %2 %2 %5)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantComposite Constituent '5[%5]' is "
+                        "not a constant or undef."));
+}
+
+// Invalid: Array has a mix of Int and Float components.
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeArrayConstituentTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpConstant %1 4 +%3 = OpTypeArray %1 %2 +%4 = OpTypeFloat 32 +%5 = OpSpecConstant %4 3.14 ; bad type for const value +%6 = OpSpecConstantComposite %3 %2 %2 %2 %5)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpSpecConstantComposite Constituent '5[%5]'s " + "type does not match Result Type " + "'3[%_arr_uint_uint_4]'s array element type.")); +} + +// Invalid: Array has a mix of Int and Undef-float. +TEST_F(ValidateIdWithMessage, + OpSpecConstantCompositeArrayConstituentUndefTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpSpecConstant %1 4 +%3 = OpTypeArray %1 %2 +%5 = OpTypeFloat 32 +%6 = OpUndef %5 ; bad type for undef +%4 = OpSpecConstantComposite %3 %2 %2 %2 %6)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpSpecConstantComposite Constituent '5[%5]'s " + "type does not match Result Type " + "'3[%_arr_uint_2]'s array element type.")); +} + +// Valid: Struct of {Int32,Int32,Int64}. +TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeInt 64 0 +%3 = OpTypeStruct %1 %1 %2 +%4 = OpConstant %1 42 +%5 = OpSpecConstant %2 4300000000 +%6 = OpSpecConstantComposite %3 %4 %4 %5)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: missing one int32 struct member. 
+TEST_F(ValidateIdWithMessage, + OpSpecConstantCompositeStructMissingComponentBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%3 = OpTypeStruct %1 %1 %1 +%4 = OpConstant %1 42 +%5 = OpSpecConstant %1 430 +%6 = OpSpecConstantComposite %3 %4 %5)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpSpecConstantComposite Constituent " + "'2[%_struct_2]' count does not match Result Type " + " '2[%_struct_2]'s struct member count.")); +} + +// Valid: Struct uses Undef-int64. +TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructUndefGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeInt 64 0 +%3 = OpTypeStruct %1 %1 %2 +%4 = OpSpecConstant %1 42 +%5 = OpUndef %2 +%6 = OpSpecConstantComposite %3 %4 %4 %5)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: Composite contains non-const/undef component. +TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructNonConstBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeInt 64 0 +%3 = OpTypeStruct %1 %1 %2 +%4 = OpSpecConstant %1 42 +%5 = OpUndef %2 +%6 = OpTypePointer Uniform %1 +%7 = OpVariable %6 Uniform +%8 = OpSpecConstantComposite %3 %4 %7 %5)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpSpecConstantComposite Constituent '7[%7]' is " + "not a constant or undef.")); +} + +// Invalid: Struct component type does not match expected specialization type. +// Second component was expected to be Int32, but got Int64. 
+TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructMemberTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeInt 64 0 +%3 = OpTypeStruct %1 %1 %2 +%4 = OpConstant %1 42 +%5 = OpSpecConstant %2 4300000000 +%6 = OpSpecConstantComposite %3 %4 %5 %4)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpSpecConstantComposite Constituent '5[%5]' type " + "does not match the Result Type '3[%_struct_3]'s " + "member type.")); +} + +// Invalid: Undef-int64 used when Int32 was expected. +TEST_F(ValidateIdWithMessage, OpSpecConstantCompositeStructMemberUndefTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeInt 64 0 +%3 = OpTypeStruct %1 %1 %2 +%4 = OpSpecConstant %1 42 +%5 = OpUndef %2 +%6 = OpSpecConstantComposite %3 %4 %5 %4)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpSpecConstantComposite Constituent '5[%5]' type " + "does not match the Result Type '3[%_struct_3]'s " + "member type.")); +} + +// TODO: OpSpecConstantOp + +TEST_F(ValidateIdWithMessage, OpVariableGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypePointer Input %1 +%3 = OpVariable %2 Input)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpVariableInitializerConstantGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypePointer Input %1 +%3 = OpConstant %1 42 +%4 = OpVariable %2 Input %3)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpVariableInitializerGlobalVariableGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypePointer Uniform %1 +%3 = OpVariable %2 
Uniform +%4 = OpTypePointer Private %2 ; pointer to pointer +%5 = OpVariable %4 Private %3 +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +// TODO: Positive test OpVariable with OpConstantNull of OpTypePointer +TEST_F(ValidateIdWithMessage, OpVariableResultTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpVariable %1 Input)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpVariable Result Type '1[%uint]' is not a pointer " + "type.")); +} +TEST_F(ValidateIdWithMessage, OpVariableInitializerIsTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeInt 32 0 +%2 = OpTypePointer Input %1 +%3 = OpVariable %2 Input %2)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 2[%_ptr_Input_uint] " + "cannot be a type")); +} + +TEST_F(ValidateIdWithMessage, OpVariableInitializerIsFunctionVarBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%int = OpTypeInt 32 0 +%ptrint = OpTypePointer Function %int +%ptrptrint = OpTypePointer Function %ptrint +%void = OpTypeVoid +%fnty = OpTypeFunction %void +%main = OpFunction %void None %fnty +%entry = OpLabel +%var = OpVariable %ptrint Function +%varinit = OpVariable %ptrptrint Function %var ; Can't initialize function variable. 
+OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpVariable Initializer '8[%8]' is not a constant " + "or module-scope variable")); +} + +TEST_F(ValidateIdWithMessage, OpVariableInitializerIsModuleVarGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%int = OpTypeInt 32 0 +%ptrint = OpTypePointer Uniform %int +%mvar = OpVariable %ptrint Uniform +%ptrptrint = OpTypePointer Function %ptrint +%void = OpTypeVoid +%fnty = OpTypeFunction %void +%main = OpFunction %void None %fnty +%entry = OpLabel +%goodvar = OpVariable %ptrptrint Function %mvar ; This is ok +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpVariableContainsBoolBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%bool = OpTypeBool +%int = OpTypeInt 32 0 +%block = OpTypeStruct %bool %int +%_ptr_Uniform_block = OpTypePointer Uniform %block +%var = OpVariable %_ptr_Uniform_block Uniform +%void = OpTypeVoid +%fnty = OpTypeFunction %void +%main = OpFunction %void None %fnty +%entry = OpLabel +%load = OpLoad %block %var +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("If OpTypeBool is stored in conjunction with OpVariable" + ", it can only be used with non-externally visible " + "shader Storage Classes: Workgroup, CrossWorkgroup, " + "Private, and Function")); +} + +TEST_F(ValidateIdWithMessage, OpVariableContainsBoolPointerGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%bool = OpTypeBool +%boolptr = OpTypePointer Uniform %bool +%int = OpTypeInt 32 0 +%block = OpTypeStruct %boolptr %int +%_ptr_Uniform_block = OpTypePointer Uniform %block +%var = OpVariable %_ptr_Uniform_block Uniform +%void = OpTypeVoid +%fnty = OpTypeFunction %void 
+%main = OpFunction %void None %fnty +%entry = OpLabel +%load = OpLoad %block %var +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpVariableContainsBuiltinBoolGood) { + std::string spirv = kGLSL450MemoryModel + R"( +OpMemberDecorate %input 0 BuiltIn FrontFacing +%bool = OpTypeBool +%input = OpTypeStruct %bool +%_ptr_input = OpTypePointer Input %input +%var = OpVariable %_ptr_input Input +%void = OpTypeVoid +%fnty = OpTypeFunction %void +%main = OpFunction %void None %fnty +%entry = OpLabel +%load = OpLoad %input %var +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpVariableContainsRayPayloadBoolGood) { + std::string spirv = R"( +OpCapability RayTracingNV +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_NV_ray_tracing" +OpMemoryModel Logical GLSL450 +%bool = OpTypeBool +%PerRayData = OpTypeStruct %bool +%_ptr_PerRayData = OpTypePointer RayPayloadNV %PerRayData +%var = OpVariable %_ptr_PerRayData RayPayloadNV +%void = OpTypeVoid +%fnty = OpTypeFunction %void +%main = OpFunction %void None %fnty +%entry = OpLabel +%load = OpLoad %PerRayData %var +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpVariablePointerNoVariablePointersBad) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%_ptr_workgroup_int = OpTypePointer Workgroup %int +%_ptr_function_ptr = OpTypePointer Function %_ptr_workgroup_int +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%var = OpVariable %_ptr_function_ptr Function +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "In Logical addressing, variables may not allocate a pointer type")); +} + +TEST_F(ValidateIdWithMessage, + OpVariablePointerNoVariablePointersRelaxedLogicalGood) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%_ptr_workgroup_int = OpTypePointer Workgroup %int +%_ptr_function_ptr = OpTypePointer Function %_ptr_workgroup_int +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%var = OpVariable %_ptr_function_ptr Function +OpReturn +OpFunctionEnd +)"; + + auto options = getValidatorOptions(); + options->relax_logical_pointer = true; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, + OpVariablePointerVariablePointersStorageBufferGood) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VariablePointersStorageBuffer +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%_ptr_workgroup_int = OpTypePointer Workgroup %int +%_ptr_function_ptr = OpTypePointer Function %_ptr_workgroup_int +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%var = OpVariable %_ptr_function_ptr Function +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpVariablePointerVariablePointersGood) { + const std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%_ptr_workgroup_int = OpTypePointer Workgroup %int +%_ptr_function_ptr = OpTypePointer Function %_ptr_workgroup_int +%voidfn = OpTypeFunction %void +%func = OpFunction 
%void None %voidfn +%entry = OpLabel +%var = OpVariable %_ptr_function_ptr Function +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpVariablePointerVariablePointersBad) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%_ptr_workgroup_int = OpTypePointer Workgroup %int +%_ptr_uniform_ptr = OpTypePointer Uniform %_ptr_workgroup_int +%var = OpVariable %_ptr_uniform_ptr Uniform +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("In Logical addressing with variable pointers, " + "variables that allocate pointers must be in Function " + "or Private storage classes")); +} + +TEST_F(ValidateIdWithMessage, OpLoadGood) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeVoid + %2 = OpTypeInt 32 0 + %3 = OpTypePointer UniformConstant %2 + %4 = OpTypeFunction %1 + %5 = OpVariable %3 UniformConstant + %6 = OpFunction %1 None %4 + %7 = OpLabel + %8 = OpLoad %2 %5 + %9 = OpReturn +%10 = OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// TODO: Add tests that exercise VariablePointersStorageBuffer instead of +// VariablePointers. 
+void createVariablePointerSpirvProgram(std::ostringstream* spirv, + std::string result_strategy, + bool use_varptr_cap, + bool add_helper_function) { + *spirv << "OpCapability Shader "; + if (use_varptr_cap) { + *spirv << "OpCapability VariablePointers "; + *spirv << "OpExtension \"SPV_KHR_variable_pointers\" "; + } + *spirv << R"( + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + %void = OpTypeVoid + %voidf = OpTypeFunction %void + %bool = OpTypeBool + %i32 = OpTypeInt 32 1 + %f32 = OpTypeFloat 32 + %f32ptr = OpTypePointer Uniform %f32 + %i = OpConstant %i32 1 + %zero = OpConstant %i32 0 + %float_1 = OpConstant %f32 1.0 + %ptr1 = OpVariable %f32ptr Uniform + %ptr2 = OpVariable %f32ptr Uniform + )"; + if (add_helper_function) { + *spirv << R"( + ; //////////////////////////////////////////////////////////// + ;;;; Function that returns a pointer + ; //////////////////////////////////////////////////////////// + %selector_func_type = OpTypeFunction %f32ptr %bool %f32ptr %f32ptr + %choose_input_func = OpFunction %f32ptr None %selector_func_type + %is_neg_param = OpFunctionParameter %bool + %first_ptr_param = OpFunctionParameter %f32ptr + %second_ptr_param = OpFunctionParameter %f32ptr + %selector_func_begin = OpLabel + %result_ptr = OpSelect %f32ptr %is_neg_param %first_ptr_param %second_ptr_param + OpReturnValue %result_ptr + OpFunctionEnd + )"; + } + *spirv << R"( + %main = OpFunction %void None %voidf + %label = OpLabel + )"; + *spirv << result_strategy; + *spirv << R"( + OpReturn + OpFunctionEnd + )"; +} + +// With the VariablePointer Capability, OpLoad should allow loading a +// VariablePointer.
In this test the variable pointer is obtained by an OpSelect +TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpSelectGood) { + std::string result_strategy = R"( + %isneg = OpSLessThan %bool %i %zero + %varptr = OpSelect %f32ptr %isneg %ptr1 %ptr2 + %result = OpLoad %f32 %varptr + )"; + + std::ostringstream spirv; + createVariablePointerSpirvProgram(&spirv, result_strategy, + true /* Add VariablePointers Capability? */, + false /* Use Helper Function? */); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Without the VariablePointers Capability, OpLoad will not allow loading +// through a variable pointer. +// Disabled since using OpSelect with pointers without VariablePointers will +// fail LogicalsPass. +TEST_F(ValidateIdWithMessage, DISABLED_OpLoadVarPtrOpSelectBad) { + std::string result_strategy = R"( + %isneg = OpSLessThan %bool %i %zero + %varptr = OpSelect %f32ptr %isneg %ptr1 %ptr2 + %result = OpLoad %f32 %varptr + )"; + + std::ostringstream spirv; + createVariablePointerSpirvProgram(&spirv, result_strategy, + false /* Add VariablePointers Capability?*/, + false /* Use Helper Function? */); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("is not a logical pointer.")); +} + +// With the VariablePointer Capability, OpLoad should allow loading a +// VariablePointer.
In this test the variable pointer is obtained by an OpPhi +TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpPhiGood) { + std::string result_strategy = R"( + %is_neg = OpSLessThan %bool %i %zero + OpSelectionMerge %end_label None + OpBranchConditional %is_neg %take_ptr_1 %take_ptr_2 + %take_ptr_1 = OpLabel + OpBranch %end_label + %take_ptr_2 = OpLabel + OpBranch %end_label + %end_label = OpLabel + %varptr = OpPhi %f32ptr %ptr1 %take_ptr_1 %ptr2 %take_ptr_2 + %result = OpLoad %f32 %varptr + )"; + + std::ostringstream spirv; + createVariablePointerSpirvProgram(&spirv, result_strategy, + true /* Add VariablePointers Capability?*/, + false /* Use Helper Function? */); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Without the VariablePointers Capability, OpPhi cannot have a pointer result +// type. +TEST_F(ValidateIdWithMessage, OpPhiBad) { + std::string result_strategy = R"( + %is_neg = OpSLessThan %bool %i %zero + OpSelectionMerge %end_label None + OpBranchConditional %is_neg %take_ptr_1 %take_ptr_2 + %take_ptr_1 = OpLabel + OpBranch %end_label + %take_ptr_2 = OpLabel + OpBranch %end_label + %end_label = OpLabel + %varptr = OpPhi %f32ptr %ptr1 %take_ptr_1 %ptr2 %take_ptr_2 + %result = OpLoad %f32 %varptr + )"; + + std::ostringstream spirv; + createVariablePointerSpirvProgram(&spirv, result_strategy, + false /* Add VariablePointers Capability?*/, + false /* Use Helper Function? */); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Using pointers with OpPhi requires capability " + "VariablePointers or VariablePointersStorageBuffer")); +} + +// With the VariablePointer Capability, OpLoad should allow loading through a +// VariablePointer.
In this test the variable pointer is obtained from an +// OpFunctionCall (return value from a function) +TEST_F(ValidateIdWithMessage, OpLoadVarPtrOpFunctionCallGood) { + std::ostringstream spirv; + std::string result_strategy = R"( + %isneg = OpSLessThan %bool %i %zero + %varptr = OpFunctionCall %f32ptr %choose_input_func %isneg %ptr1 %ptr2 + %result = OpLoad %f32 %varptr + )"; + + createVariablePointerSpirvProgram(&spirv, result_strategy, + true /* Add VariablePointers Capability?*/, + true /* Use Helper Function? */); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpLoadResultTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer UniformConstant %2 +%4 = OpTypeFunction %1 +%5 = OpVariable %3 UniformConstant +%6 = OpFunction %1 None %4 +%7 = OpLabel +%8 = OpLoad %3 %5 + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpLoad Result Type " + "'3[%_ptr_UniformConstant_uint]' does not match " + "Pointer '5[%5]'s type.")); +} + +TEST_F(ValidateIdWithMessage, OpLoadPointerBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer UniformConstant %2 +%4 = OpTypeFunction %1 +%5 = OpFunction %1 None %4 +%6 = OpLabel +%7 = OpLoad %2 %8 + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + // Prove that SSA checks trigger for a bad Id value. + // The next test case show the not-a-logical-pointer case. + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 8[%8] has not been " + "defined")); +} + +// Disabled as bitcasting type to object is now not valid. 
+TEST_F(ValidateIdWithMessage, DISABLED_OpLoadLogicalPointerBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeFloat 32 +%4 = OpTypePointer UniformConstant %2 +%5 = OpTypePointer UniformConstant %3 +%6 = OpTypeFunction %1 +%7 = OpFunction %1 None %6 +%8 = OpLabel +%9 = OpBitcast %5 %4 ; Not valid in logical addressing +%10 = OpLoad %3 %9 ; Should trigger message + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + // Once we start checking bitcasts, we might catch that + // as the error first, instead of catching it here. + // I don't know if it's possible to generate a bad case + // if/when the validator is complete. + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpLoad Pointer '9' is not a logical pointer.")); +} + +TEST_F(ValidateIdWithMessage, OpStoreGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 +%5 = OpConstant %2 42 +%6 = OpVariable %3 Uniform +%7 = OpFunction %1 None %4 +%8 = OpLabel + OpStore %6 %5 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpStorePointerBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer UniformConstant %2 +%4 = OpTypeFunction %1 +%5 = OpConstant %2 42 +%6 = OpVariable %3 UniformConstant +%7 = OpConstant %2 0 +%8 = OpFunction %1 None %4 +%9 = OpLabel + OpStore %7 %5 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpStore Pointer '7[%uint_0]' is not a logical " + "pointer.")); +} + +// Disabled as bitcasting type to object is now not valid. 
+TEST_F(ValidateIdWithMessage, DISABLED_OpStoreLogicalPointerBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeFloat 32 +%4 = OpTypePointer UniformConstant %2 +%5 = OpTypePointer UniformConstant %3 +%6 = OpTypeFunction %1 +%7 = OpConstantNull %5 +%8 = OpFunction %1 None %6 +%9 = OpLabel +%10 = OpBitcast %5 %4 ; Not valid in logical addressing +%11 = OpStore %10 %7 ; Should trigger message + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpStore Pointer '10' is not a logical pointer.")); +} + +// Without the VariablePointer Capability, OpStore may not store +// through a variable pointer. +// Disabled since using OpSelect with pointers without VariablePointers will +// fail LogicalsPass. +TEST_F(ValidateIdWithMessage, DISABLED_OpStoreVarPtrBad) { + std::string result_strategy = R"( + %isneg = OpSLessThan %bool %i %zero + %varptr = OpSelect %f32ptr %isneg %ptr1 %ptr2 + OpStore %varptr %float_1 + )"; + + std::ostringstream spirv; + createVariablePointerSpirvProgram( + &spirv, result_strategy, false /* Add VariablePointers Capability? */, + false /* Use Helper Function? */); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("is not a logical pointer.")); +} + +// With the VariablePointer Capability, OpStore should allow storing through a +// variable pointer. +TEST_F(ValidateIdWithMessage, OpStoreVarPtrGood) { + std::string result_strategy = R"( + %isneg = OpSLessThan %bool %i %zero + %varptr = OpSelect %f32ptr %isneg %ptr1 %ptr2 + OpStore %varptr %float_1 + )"; + + std::ostringstream spirv; + createVariablePointerSpirvProgram(&spirv, result_strategy, + true /* Add VariablePointers Capability? */, + false /* Use Helper Function?
*/); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpStoreObjectGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 +%5 = OpConstant %2 42 +%6 = OpVariable %3 Uniform +%7 = OpFunction %1 None %4 +%8 = OpLabel +%9 = OpUndef %1 + OpStore %6 %9 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpStore Object '9[%9]'s type is void.")); +} +TEST_F(ValidateIdWithMessage, OpStoreTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%9 = OpTypeFloat 32 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 +%5 = OpConstant %9 3.14 +%6 = OpVariable %3 Uniform +%7 = OpFunction %1 None %4 +%8 = OpLabel + OpStore %6 %5 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpStore Pointer '7[%7]'s type does not match " + "Object '6[%float_3_1400001]'s type.")); +} + +// The next series of tests checks a relaxation of the rules for stores to +// structs. The first test checks that we get a failure when the option is not +// set to relax the rule. +// TODO: Add tests for layout compatible arrays and matrices when the validator +// relaxes the rules for them as well. Also need tests to check for layout +// decorations specific to those types.
+TEST_F(ValidateIdWithMessage, OpStoreTypeBadStruct) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %1 0 Offset 0 + OpMemberDecorate %1 1 Offset 4 + OpMemberDecorate %2 0 Offset 0 + OpMemberDecorate %2 1 Offset 4 +%3 = OpTypeVoid +%4 = OpTypeFloat 32 +%1 = OpTypeStruct %4 %4 +%5 = OpTypePointer Uniform %1 +%2 = OpTypeStruct %4 %4 +%6 = OpTypeFunction %3 +%7 = OpConstant %4 3.14 +%8 = OpVariable %5 Uniform +%9 = OpFunction %3 None %6 +%10 = OpLabel +%11 = OpCompositeConstruct %2 %7 %7 + OpStore %8 %11 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpStore Pointer '8[%8]'s type does not match " + "Object '11[%11]'s type.")); +} + +// Same code as the last test. The difference is that we relax the rule, +// which is valid because the structs %1 and %2 are defined the same way. +TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedStruct) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %1 0 Offset 0 + OpMemberDecorate %1 1 Offset 4 + OpMemberDecorate %2 0 Offset 0 + OpMemberDecorate %2 1 Offset 4 +%3 = OpTypeVoid +%4 = OpTypeFloat 32 +%1 = OpTypeStruct %4 %4 +%5 = OpTypePointer Uniform %1 +%2 = OpTypeStruct %4 %4 +%6 = OpTypeFunction %3 +%7 = OpConstant %4 3.14 +%8 = OpVariable %5 Uniform +%9 = OpFunction %3 None %6 +%10 = OpLabel +%11 = OpCompositeConstruct %2 %7 %7 + OpStore %8 %11 + OpReturn + OpFunctionEnd)"; + spvValidatorOptionsSetRelaxStoreStruct(options_, true); + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Same code as the last test except for an extra decoration on one of the +// members. With the relaxed rules, the code is still valid.
+TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedStructWithExtraDecoration) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %1 0 Offset 0 + OpMemberDecorate %1 1 Offset 4 + OpMemberDecorate %1 0 RelaxedPrecision + OpMemberDecorate %2 0 Offset 0 + OpMemberDecorate %2 1 Offset 4 +%3 = OpTypeVoid +%4 = OpTypeFloat 32 +%1 = OpTypeStruct %4 %4 +%5 = OpTypePointer Uniform %1 +%2 = OpTypeStruct %4 %4 +%6 = OpTypeFunction %3 +%7 = OpConstant %4 3.14 +%8 = OpVariable %5 Uniform +%9 = OpFunction %3 None %6 +%10 = OpLabel +%11 = OpCompositeConstruct %2 %7 %7 + OpStore %8 %11 + OpReturn + OpFunctionEnd)"; + spvValidatorOptionsSetRelaxStoreStruct(options_, true); + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// This test checks that we recursively traverse the structs to check if they are +// interchangeable. +TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedNestedStruct) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %1 0 Offset 0 + OpMemberDecorate %1 1 Offset 4 + OpMemberDecorate %2 0 Offset 0 + OpMemberDecorate %2 1 Offset 8 + OpMemberDecorate %3 0 Offset 0 + OpMemberDecorate %3 1 Offset 4 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate %4 1 Offset 8 +%5 = OpTypeVoid +%6 = OpTypeInt 32 0 +%7 = OpTypeFloat 32 +%1 = OpTypeStruct %7 %6 +%2 = OpTypeStruct %1 %1 +%8 = OpTypePointer Uniform %2 +%3 = OpTypeStruct %7 %6 +%4 = OpTypeStruct %3 %3 +%9 = OpTypeFunction %5 +%10 = OpConstant %6 7 +%11 = OpConstant %7 3.14 +%12 = OpConstantComposite %3 %11 %10 +%13 = OpVariable %8 Uniform +%14 = OpFunction %5 None %9 +%15 = OpLabel +%16 = OpCompositeConstruct %4 %12 %12 + OpStore %13 %16 + OpReturn + OpFunctionEnd)"; + spvValidatorOptionsSetRelaxStoreStruct(options_, true); + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// This test checks that even with the relaxed rules an error is identified +// if the members of the struct are in a different
order. +TEST_F(ValidateIdWithMessage, OpStoreTypeBadRelaxedStruct1) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %1 0 Offset 0 + OpMemberDecorate %1 1 Offset 4 + OpMemberDecorate %2 0 Offset 0 + OpMemberDecorate %2 1 Offset 8 + OpMemberDecorate %3 0 Offset 0 + OpMemberDecorate %3 1 Offset 4 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate %4 1 Offset 8 +%5 = OpTypeVoid +%6 = OpTypeInt 32 0 +%7 = OpTypeFloat 32 +%1 = OpTypeStruct %6 %7 +%2 = OpTypeStruct %1 %1 +%8 = OpTypePointer Uniform %2 +%3 = OpTypeStruct %7 %6 +%4 = OpTypeStruct %3 %3 +%9 = OpTypeFunction %5 +%10 = OpConstant %6 7 +%11 = OpConstant %7 3.14 +%12 = OpConstantComposite %3 %11 %10 +%13 = OpVariable %8 Uniform +%14 = OpFunction %5 None %9 +%15 = OpLabel +%16 = OpCompositeConstruct %4 %12 %12 + OpStore %13 %16 + OpReturn + OpFunctionEnd)"; + spvValidatorOptionsSetRelaxStoreStruct(options_, true); + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpStore Pointer '13[%13]'s layout does not match Object " + " '16[%16]'s layout.")); +} + +// This test check that the even with the relaxed rules an error is identified +// if the members of the struct are at different offsets. 
+TEST_F(ValidateIdWithMessage, OpStoreTypeBadRelaxedStruct2) { + std::string spirv = kGLSL450MemoryModel + R"( + OpMemberDecorate %1 0 Offset 4 + OpMemberDecorate %1 1 Offset 0 + OpMemberDecorate %2 0 Offset 0 + OpMemberDecorate %2 1 Offset 8 + OpMemberDecorate %3 0 Offset 0 + OpMemberDecorate %3 1 Offset 4 + OpMemberDecorate %4 0 Offset 0 + OpMemberDecorate %4 1 Offset 8 +%5 = OpTypeVoid +%6 = OpTypeInt 32 0 +%7 = OpTypeFloat 32 +%1 = OpTypeStruct %7 %6 +%2 = OpTypeStruct %1 %1 +%8 = OpTypePointer Uniform %2 +%3 = OpTypeStruct %7 %6 +%4 = OpTypeStruct %3 %3 +%9 = OpTypeFunction %5 +%10 = OpConstant %6 7 +%11 = OpConstant %7 3.14 +%12 = OpConstantComposite %3 %11 %10 +%13 = OpVariable %8 Uniform +%14 = OpFunction %5 None %9 +%15 = OpLabel +%16 = OpCompositeConstruct %4 %12 %12 + OpStore %13 %16 + OpReturn + OpFunctionEnd)"; + spvValidatorOptionsSetRelaxStoreStruct(options_, true); + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpStore Pointer '13[%13]'s layout does not match Object " + " '16[%16]'s layout.")); +} + +// With the relaxed logical-pointer option set, a function that takes and +// returns a Function-storage pointer validates successfully. +TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedLogicalPointerReturnPointer) { + const std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%1 = OpTypeInt 32 1 +%2 = OpTypePointer Function %1 +%3 = OpTypeFunction %2 %2 +%4 = OpFunction %2 None %3 +%5 = OpFunctionParameter %2 +%6 = OpLabel + OpReturnValue %5 + OpFunctionEnd)"; + + spvValidatorOptionsSetRelaxLogicalPointer(options_, true); + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// With the relaxed logical-pointer option set, Private and Function variables +// whose pointee type is itself a pointer validate successfully. +TEST_F(ValidateIdWithMessage, OpStoreTypeRelaxedLogicalPointerAllocPointer) { + const std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + %1 = OpTypeVoid + %2 = OpTypeInt 32 1 + %3 = OpTypeFunction %1 ; void(void) + %4 = OpTypePointer Uniform %2 ; int* + %5 = OpTypePointer
Private %4 ; int** (Private) + %6 = OpTypePointer Function %4 ; int** (Function) + %7 = OpVariable %5 Private + %8 = OpFunction %1 None %3 + %9 = OpLabel +%10 = OpVariable %6 Function + OpReturn + OpFunctionEnd)"; + + spvValidatorOptionsSetRelaxLogicalPointer(options_, true); + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: the object being stored is the result of calling a void function. +TEST_F(ValidateIdWithMessage, OpStoreVoid) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 +%6 = OpVariable %3 Uniform +%7 = OpFunction %1 None %4 +%8 = OpLabel +%9 = OpFunctionCall %1 %7 + OpStore %6 %9 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpStore Object '8[%8]'s type is void.")); +} + +// Invalid: the object operand of OpStore is a label. +TEST_F(ValidateIdWithMessage, OpStoreLabel) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 +%6 = OpVariable %3 Uniform +%7 = OpFunction %1 None %4 +%8 = OpLabel + OpStore %6 %8 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 7[%7] requires a type")); +} + +// TODO: enable when this bug is fixed: +// https://cvs.khronos.org/bugzilla/show_bug.cgi?id=15404 +TEST_F(ValidateIdWithMessage, DISABLED_OpStoreFunction) { + std::string spirv = kGLSL450MemoryModel + R"( +%2 = OpTypeInt 32 0 +%3 = OpTypePointer UniformConstant %2 +%4 = OpTypeFunction %2 +%5 = OpConstant %2 123 +%6 = OpVariable %3 UniformConstant +%7 = OpFunction %2 None %4 +%8 = OpLabel + OpStore %6 %7 + OpReturnValue %5 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); +} + +// Invalid: stores to a variable declared in the Input storage class. +TEST_F(ValidateIdWithMessage, OpStoreBuiltin) { + std::string
spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationID + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 450 + OpName %main "main" + + OpName %gl_GlobalInvocationID "gl_GlobalInvocationID" + OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId + + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 +%_ptr_Input_v3uint = OpTypePointer Input %v3uint +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input + + %zero = OpConstant %uint 0 + %v3uint_000 = OpConstantComposite %v3uint %zero %zero %zero + + %void = OpTypeVoid + %voidfunc = OpTypeFunction %void + %main = OpFunction %void None %voidfunc + %lmain = OpLabel + + OpStore %gl_GlobalInvocationID %v3uint_000 + + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("storage class is read-only")); +} + +// Valid: target and source both point to the same 32-bit integer type. +TEST_F(ValidateIdWithMessage, OpCopyMemoryGood) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeVoid + %2 = OpTypeInt 32 0 + %3 = OpTypePointer UniformConstant %2 + %4 = OpConstant %2 42 + %5 = OpVariable %3 UniformConstant %4 + %6 = OpTypePointer Function %2 + %7 = OpTypeFunction %1 + %8 = OpFunction %1 None %7 + %9 = OpLabel +%10 = OpVariable %6 Function + OpCopyMemory %10 %5 None + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: the target operand is an integer value, not a pointer. +TEST_F(ValidateIdWithMessage, OpCopyMemoryNonPointerTarget) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 %2 %3 +%5 = OpFunction %1 None %4 +%6 = OpFunctionParameter %2 +%7 = OpFunctionParameter %3 +%8 = OpLabel +OpCopyMemory %6 %7 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID,
ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Target operand '6[%6]' is not a pointer.")); +} + +// Invalid: the source operand is an integer value, not a pointer. +TEST_F(ValidateIdWithMessage, OpCopyMemoryNonPointerSource) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpTypeFunction %1 %2 %3 +%5 = OpFunction %1 None %4 +%6 = OpFunctionParameter %2 +%7 = OpFunctionParameter %3 +%8 = OpLabel +OpCopyMemory %7 %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Source operand '6[%6]' is not a pointer.")); +} + +// Invalid: the target points to a float while the source points to an integer. +TEST_F(ValidateIdWithMessage, OpCopyMemoryBad) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeVoid + %2 = OpTypeInt 32 0 + %3 = OpTypePointer UniformConstant %2 + %4 = OpConstant %2 42 + %5 = OpVariable %3 UniformConstant %4 +%11 = OpTypeFloat 32 + %6 = OpTypePointer Function %11 + %7 = OpTypeFunction %1 + %8 = OpFunction %1 None %7 + %9 = OpLabel +%10 = OpVariable %6 Function + OpCopyMemory %10 %5 None + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Target '5[%5]'s type does not match " + "Source '2[%uint]'s type.")); +} + +// Invalid: the target is a pointer to void. +TEST_F(ValidateIdWithMessage, OpCopyMemoryVoidTarget) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %1 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFunction %1 %3 %4 +%6 = OpFunction %1 None %5 +%7 = OpFunctionParameter %3 +%8 = OpFunctionParameter %4 +%9 = OpLabel +OpCopyMemory %7 %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Target operand '7[%7]' cannot be a void " + "pointer.")); +} + +// Invalid: the source is a pointer to void. +TEST_F(ValidateIdWithMessage,
OpCopyMemoryVoidSource) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %1 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFunction %1 %3 %4 +%6 = OpFunction %1 None %5 +%7 = OpFunctionParameter %3 +%8 = OpFunctionParameter %4 +%9 = OpLabel +OpCopyMemory %8 %7 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Source operand '7[%7]' cannot be a void " + "pointer.")); +} + +// Valid: pointer target, pointer source and an integer constant size of 4. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedGood) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeVoid + %2 = OpTypeInt 32 0 + %3 = OpTypePointer UniformConstant %2 + %4 = OpTypePointer Function %2 + %5 = OpConstant %2 4 + %6 = OpVariable %3 UniformConstant %5 + %7 = OpTypeFunction %1 + %8 = OpFunction %1 None %7 + %9 = OpLabel +%10 = OpVariable %4 Function + OpCopyMemorySized %10 %6 %5 None + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +// Invalid: the target operand is an integer constant, not a pointer. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedTargetBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer UniformConstant %2 +%4 = OpTypePointer Function %2 +%5 = OpConstant %2 4 +%6 = OpVariable %3 UniformConstant %5 +%7 = OpTypeFunction %1 +%8 = OpFunction %1 None %7 +%9 = OpLabel + OpCopyMemorySized %5 %5 %5 None + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Target operand '5[%uint_4]' is not a pointer.")); +} +// Invalid: the source operand is an integer constant, not a pointer. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSourceBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer UniformConstant %2 +%4 = OpTypePointer Function %2 +%5 = OpConstant %2 4 +%6 = OpTypeFunction %1 +%7 = OpFunction %1 None %6
+%8 = OpLabel +%9 = OpVariable %4 Function + OpCopyMemorySized %9 %5 %5 None + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Source operand '5[%uint_4]' is not a pointer.")); +} +// Invalid: the size operand is a pointer variable rather than an integer. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeBad) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeVoid + %2 = OpTypeInt 32 0 + %3 = OpTypePointer UniformConstant %2 + %4 = OpTypePointer Function %2 + %5 = OpConstant %2 4 + %6 = OpVariable %3 UniformConstant %5 + %7 = OpTypeFunction %1 + %8 = OpFunction %1 None %7 + %9 = OpLabel +%10 = OpVariable %4 Function + OpCopyMemorySized %10 %6 %6 None + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Size operand '6[%6]' must be a scalar integer type.")); +} +// Invalid: the size operand is a float constant rather than an integer. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeTypeBad) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeVoid + %2 = OpTypeInt 32 0 + %3 = OpTypePointer UniformConstant %2 + %4 = OpTypePointer Function %2 + %5 = OpConstant %2 4 + %6 = OpVariable %3 UniformConstant %5 + %7 = OpTypeFunction %1 +%11 = OpTypeFloat 32 +%12 = OpConstant %11 1.0 + %8 = OpFunction %1 None %7 + %9 = OpLabel +%10 = OpVariable %4 Function + OpCopyMemorySized %10 %6 %12 None + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Size operand '9[%float_1]' must be a scalar integer " + "type.")); +} + +// Invalid: the size operand is an OpConstantNull of integer type. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantNull) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstantNull %2 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 =
OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Size operand '3[%3]' cannot be a constant " + "zero.")); +} + +// Invalid: the size operand is a 32-bit integer constant equal to zero. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantZero) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 0 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Size operand '3[%uint_0]' cannot be a constant " + "zero.")); +} + +// Invalid: the size operand is a 64-bit integer constant equal to zero. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantZero64) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 64 0 +%3 = OpConstant %2 0 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Size operand '3[%ulong_0]' cannot be a constant " + "zero.")); +} + +// Invalid: the size operand is a signed 32-bit constant with value -1. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantNegative) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 1 +%3 = OpConstant %2 -1 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 =
OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Size operand '3[%int_n1]' cannot have the sign bit set " + "to 1.")); +} + +// Invalid: the size operand is a signed 64-bit constant with value -1. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeConstantNegative64) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 64 1 +%3 = OpConstant %2 -1 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Size operand '3[%long_n1]' cannot have the sign bit set " + "to 1.")); +} + +// Valid: an unsigned 32-bit size constant with its top bit set (2147483648) +// is accepted. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeUnsignedNegative) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpConstant %2 2147483648 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 = OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Valid: an unsigned 64-bit size constant with its top bit set +// (9223372036854775808) is accepted. +TEST_F(ValidateIdWithMessage, OpCopyMemorySizedSizeUnsignedNegative64) { + const std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 64 0 +%3 = OpConstant %2 9223372036854775808 +%4 = OpTypePointer Uniform %2 +%5 = OpTypeFloat 32 +%6 = OpTypePointer UniformConstant %5 +%7 = OpTypeFunction %1 %4 %6 +%8 =
OpFunction %1 None %7 +%9 = OpFunctionParameter %4 +%10 = OpFunctionParameter %6 +%11 = OpLabel +OpCopyMemorySized %9 %10 %3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Shared SPIR-V preamble for the access chain tests: declares the types, +// constants and variables they use and opens a function through its first +// label. Each test appends its own instructions and closes the function. +const char kDeeplyNestedStructureSetup[] = R"( +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%int = OpTypeInt 32 0 +%float = OpTypeFloat 32 +%v3float = OpTypeVector %float 3 +%mat4x3 = OpTypeMatrix %v3float 4 +%_ptr_Private_mat4x3 = OpTypePointer Private %mat4x3 +%_ptr_Private_float = OpTypePointer Private %float +%my_matrix = OpVariable %_ptr_Private_mat4x3 Private +%my_float_var = OpVariable %_ptr_Private_float Private +%_ptr_Function_float = OpTypePointer Function %float +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%int_2 = OpConstant %int 2 +%int_3 = OpConstant %int 3 +%int_5 = OpConstant %int 5 + +; Making the following nested structures. +; +; struct S { +; bool b; +; vec4 v[5]; +; int i; +; mat4x3 m[5]; +; } +; uniform blockName { +; S s; +; bool cond; +; RunTimeArray arr; +; } + +%f32arr = OpTypeRuntimeArray %float +%v4float = OpTypeVector %float 4 +%array5_mat4x3 = OpTypeArray %mat4x3 %int_5 +%array5_vec4 = OpTypeArray %v4float %int_5 +%_ptr_Uniform_float = OpTypePointer Uniform %float +%_ptr_Function_vec4 = OpTypePointer Function %v4float +%_ptr_Uniform_vec4 = OpTypePointer Uniform %v4float +%struct_s = OpTypeStruct %int %array5_vec4 %int %array5_mat4x3 +%struct_blockName = OpTypeStruct %struct_s %int %f32arr +%_ptr_Uniform_blockName = OpTypePointer Uniform %struct_blockName +%_ptr_Uniform_struct_s = OpTypePointer Uniform %struct_s +%_ptr_Uniform_array5_mat4x3 = OpTypePointer Uniform %array5_mat4x3 +%_ptr_Uniform_mat4x3 = OpTypePointer Uniform %mat4x3 +%_ptr_Uniform_v3float = OpTypePointer Uniform %v3float +%blockName_var = OpVariable %_ptr_Uniform_blockName Uniform +%spec_int = OpSpecConstant %int 2 +%float_0 = OpConstant %float 0 +%func = OpFunction %void None %void_f +%my_label = OpLabel +)"; +
+// In what follows, Access Chain Instruction refers to one of the following: +// OpAccessChain, OpInBoundsAccessChain, OpPtrAccessChain, and +// OpInBoundsPtrAccessChain +// The fixture is parameterized on the opcode name as a std::string, which each +// test reads back through GetParam(). +using AccessChainInstructionTest = spvtest::ValidateBase<std::string>; + +// Determines whether the access chain instruction requires the 'element id' +// argument. +bool AccessChainRequiresElemId(const std::string& instr) { + return (instr == "OpPtrAccessChain" || instr == "OpInBoundsPtrAccessChain"); +} + +// Valid: Access a float in a matrix using an access chain instruction. +TEST_P(AccessChainInstructionTest, AccessChainGood) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + + "%float_entry = " + instr + + R"( %_ptr_Private_float %my_matrix )" + elem + + R"(%int_0 %int_1 + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid. The result type of an access chain instruction must be a pointer. +TEST_P(AccessChainInstructionTest, AccessChainResultTypeBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%float_entry = )" + + instr + + R"( %float %my_matrix )" + elem + + R"(%int_0 %int_1 +OpReturn +OpFunctionEnd + )"; + + const std::string expected_err = "The Result Type of " + instr + + " '36[%36]' must be " + "OpTypePointer. Found OpTypeFloat."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid. The base of an access chain instruction must be a pointer value; +// here the type id %void is passed as the base operand. +TEST_P(AccessChainInstructionTest, AccessChainBaseTypeVoidBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ?
"%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%float_entry = )" + + instr + " %_ptr_Private_float %void " + elem + + R"(%int_0 %int_1 +OpReturn +OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 1[%void] cannot be a " + "type")); +} + +// Invalid. The base of an access chain instruction must be a pointer value; +// here the pointer type id %_ptr_Private_float is passed as the base operand. +TEST_P(AccessChainInstructionTest, AccessChainBaseTypeNonPtrVariableBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Private_float %_ptr_Private_float )" + + elem + + R"(%int_0 %int_1 +OpReturn +OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 8[%_ptr_Private_float] cannot be a type")); +} + +// Invalid: The storage class of Base and Result do not match. +TEST_P(AccessChainInstructionTest, + AccessChainResultAndBaseStorageClassDoesntMatchBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Function_float %my_matrix )" + elem + + R"(%int_0 %int_1 +OpReturn +OpFunctionEnd + )"; + const std::string expected_err = + "The result pointer storage class and base pointer storage class in " + + instr + " do not match."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid. The base type of an access chain instruction must point to a +// composite object.
+TEST_P(AccessChainInstructionTest, + AccessChainBasePtrNotPointingToCompositeBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Private_float %my_float_var )" + elem + + R"(%int_0 +OpReturn +OpFunctionEnd + )"; + const std::string expected_err = instr + + " reached non-composite type while " + "indexes still remain to be traversed."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Valid. No Indexes were passed to the access chain instruction. The Result +// Type is the same as the Base type. +TEST_P(AccessChainInstructionTest, AccessChainNoIndexesGood) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Private_float %my_float_var )" + elem + + R"( +OpReturn +OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid. No Indexes were passed to the access chain instruction, but the +// Result Type is different from the Base type. +TEST_P(AccessChainInstructionTest, AccessChainNoIndexesBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ?
"%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Private_mat4x3 %my_float_var )" + elem + + R"( +OpReturn +OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("result type (OpTypeMatrix) does not match the type that " + "results from indexing into the base (OpTypeFloat).")); +} + +// Valid: 255 indexes passed to the access chain instruction, reaching the float +// at the bottom of 255 nested structs. Limit is 255. +TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesGood) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : ""; + int depth = 255; + std::string header = kGLSL450MemoryModel + kDeeplyNestedStructureSetup; + // Drop the function prologue from the shared setup so that the extra types + // and variable below are emitted before the function is reopened. + header.erase(header.find("%func")); + std::ostringstream spirv; + spirv << header << "\n"; + + // Build nested structures. Struct 'i' contains struct 'i-1' + spirv << "%s_depth_1 = OpTypeStruct %float\n"; + for (int i = 2; i <= depth; ++i) { + spirv << "%s_depth_" << i << " = OpTypeStruct %s_depth_" << i - 1 << "\n"; + } + + // Define Pointer and Variable to use for the AccessChain instruction. + spirv << "%_ptr_Uniform_deep_struct = OpTypePointer Uniform %s_depth_" + << depth << "\n"; + spirv << "%deep_var = OpVariable %_ptr_Uniform_deep_struct Uniform\n"; + + // Function Start + spirv << R"( + %func = OpFunction %void None %void_f + %my_label = OpLabel + )"; + + // AccessChain with 'n' indexes (n = depth) + spirv << "%entry = " << instr << " %_ptr_Uniform_float %deep_var" << elem; + for (int i = 0; i < depth; ++i) { + spirv << " %int_0"; + } + + // Function end + spirv << R"( + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: 256 indexes passed to the access chain instruction. Limit is 255.
+TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : ""; + std::ostringstream spirv; + spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup; + spirv << "%entry = " << instr << " %_ptr_Private_float %my_matrix" << elem; + for (int i = 0; i < 256; ++i) { + spirv << " %int_0"; + } + spirv << R"( + OpReturn + OpFunctionEnd + )"; + const std::string expected_err = "The number of indexes in " + instr + + " may not exceed 255. Found 256 indexes."; + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Valid: 10 indexes passed to the access chain instruction. (Custom limit: 10) +TEST_P(AccessChainInstructionTest, CustomizedAccessChainTooManyIndexesGood) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : ""; + int depth = 10; + std::string header = kGLSL450MemoryModel + kDeeplyNestedStructureSetup; + // Drop the function prologue from the shared setup so that the extra types + // and variable below are emitted before the function is reopened. + header.erase(header.find("%func")); + std::ostringstream spirv; + spirv << header << "\n"; + + // Build nested structures. Struct 'i' contains struct 'i-1' + spirv << "%s_depth_1 = OpTypeStruct %float\n"; + for (int i = 2; i <= depth; ++i) { + spirv << "%s_depth_" << i << " = OpTypeStruct %s_depth_" << i - 1 << "\n"; + } + + // Define Pointer and Variable to use for the AccessChain instruction.
+ spirv << "%_ptr_Uniform_deep_struct = OpTypePointer Uniform %s_depth_" + << depth << "\n"; + spirv << "%deep_var = OpVariable %_ptr_Uniform_deep_struct Uniform\n"; + + // Function Start + spirv << R"( + %func = OpFunction %void None %void_f + %my_label = OpLabel + )"; + + // AccessChain with 'n' indexes (n = depth) + spirv << "%entry = " << instr << " %_ptr_Uniform_float %deep_var" << elem; + for (int i = 0; i < depth; ++i) { + spirv << " %int_0"; + } + + // Function end + spirv << R"( + OpReturn + OpFunctionEnd + )"; + + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_access_chain_indexes, 10u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: 11 indexes passed to the access chain instruction. Custom Limit:10 +TEST_P(AccessChainInstructionTest, CustomizedAccessChainTooManyIndexesBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : ""; + std::ostringstream spirv; + spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup; + spirv << "%entry = " << instr << " %_ptr_Private_float %my_matrix" << elem; + for (int i = 0; i < 11; ++i) { + spirv << " %int_0"; + } + spirv << R"( + OpReturn + OpFunctionEnd + )"; + const std::string expected_err = "The number of indexes in " + instr + + " may not exceed 10. Found 11 indexes."; + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_access_chain_indexes, 10u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid: Index passed to the access chain instruction is float (must be +// integer). +TEST_P(AccessChainInstructionTest, AccessChainUndefinedIndexBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ?
"%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Private_float %my_matrix )" + elem + + R"(%float_0 %int_1 +OpReturn +OpFunctionEnd + )"; + const std::string expected_err = + "Indexes passed to " + instr + " must be of type integer."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid: The index argument that indexes into a struct must be of type +// OpConstant. +TEST_P(AccessChainInstructionTest, AccessChainStructIndexNotConstantBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%f = )" + + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_0 %spec_int %int_2 +OpReturn +OpFunctionEnd + )"; + // NOTE(review): "The passed to" reads as if a word was dropped from this + // expected message; confirm it against the validator's actual diagnostic. + const std::string expected_err = + "The passed to " + instr + + " to index into a structure must be an OpConstant."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid: Indexing up to a vec4 granularity, but result type expected float. +TEST_P(AccessChainInstructionTest, + AccessChainStructResultTypeDoesntMatchIndexedTypeBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ?
"%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_0 %int_1 %int_2 +OpReturn +OpFunctionEnd + )"; + const std::string expected_err = instr + + " result type (OpTypeFloat) does not match " + "the type that results from indexing into " + "the base (OpTypeVector)."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid: Reach non-composite type (the %int member at index 2 of %struct_s) +// when unused indexes remain. +TEST_P(AccessChainInstructionTest, AccessChainStructTooManyIndexesBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_0 %int_2 %int_2 +OpReturn +OpFunctionEnd + )"; + const std::string expected_err = instr + + " reached non-composite type while " + "indexes still remain to be traversed."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid: Trying to find index 3 of the struct that has only 3 members. +TEST_P(AccessChainInstructionTest, AccessChainStructIndexOutOfBoundBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_3 %int_2 %int_2 +OpReturn +OpFunctionEnd + )"; + const std::string expected_err = "Index is out of bounds: " + instr + + " can not find index 3 into the structure " + " '25[%_struct_25]'. This structure " + "has 3 members.
Largest valid index is 2."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Valid: Tests that we can index into Struct, Array, Matrix, and Vector! +TEST_P(AccessChainInstructionTest, AccessChainIndexIntoAllTypesGood) { + // indexes that we are passing are: 0, 3, 1, 2, 0 + // 0 will select the struct_s within the base struct (blockName) + // 3 will select the Array that contains 5 matrices + // 1 will select the Matrix that is at index 1 of the array + // 2 will select the column (which is a vector) within the matrix at index 2 + // 0 will select the element at the index 0 of the vector. (which is a float). + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::ostringstream spirv; + spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup << std::endl; + spirv << "%ss = " << instr << " %_ptr_Uniform_struct_s %blockName_var " + << elem << "%int_0" << std::endl; + spirv << "%sa = " << instr << " %_ptr_Uniform_array5_mat4x3 %blockName_var " + << elem << "%int_0 %int_3" << std::endl; + spirv << "%sm = " << instr << " %_ptr_Uniform_mat4x3 %blockName_var " << elem + << "%int_0 %int_3 %int_1" << std::endl; + spirv << "%sc = " << instr << " %_ptr_Uniform_v3float %blockName_var " << elem + << "%int_0 %int_3 %int_1 %int_2" << std::endl; + spirv << "%entry = " << instr << " %_ptr_Uniform_float %blockName_var " + << elem << "%int_0 %int_3 %int_1 %int_2 %int_0" << std::endl; + spirv << R"( +OpReturn +OpFunctionEnd + )"; + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Valid: Access an element of OpTypeRuntimeArray. +TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayGood) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ?
"%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%runtime_arr_entry = )" + + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_2 %int_0 +OpReturn +OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: Unused index when accessing OpTypeRuntimeArray. +TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%runtime_arr_entry = )" + + instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + R"(%int_2 %int_0 %int_1 +OpReturn +OpFunctionEnd + )"; + const std::string expected_err = + instr + + " reached non-composite type while indexes still remain to be traversed."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid: Reached scalar type before arguments to the access chain instruction +// finished. +TEST_P(AccessChainInstructionTest, AccessChainMatrixMoreArgsThanNeededBad) { + const std::string instr = GetParam(); + const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( +%entry = )" + + instr + R"( %_ptr_Private_float %my_matrix )" + elem + + R"(%int_0 %int_1 %int_0 +OpReturn +OpFunctionEnd + )"; + const std::string expected_err = instr + + " reached non-composite type while " + "indexes still remain to be traversed."; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err)); +} + +// Invalid: The result type and the type indexed into do not match.
+TEST_P(AccessChainInstructionTest,
+       AccessChainResultTypeDoesntMatchIndexedTypeBad) {
+  // The result type points at a matrix, while the expected diagnostic below
+  // reports that the indexes resolve to a float.
+  const std::string instr = GetParam();
+  // Some access-chain variants take an extra leading operand; prepend it
+  // only when the helper says this instruction needs one.
+  const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
+  std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
+%entry = )" +
+                      instr + R"( %_ptr_Private_mat4x3 %my_matrix )" + elem +
+                      R"(%int_0 %int_1
+OpReturn
+OpFunctionEnd
+  )";
+  // The diagnostic tags the base operand with the SPIR-V "<id>" marker.
+  const std::string expected_err = instr +
+                                   " result type (OpTypeMatrix) does not match "
+                                   "the type that results from indexing into "
+                                   "the base <id> (OpTypeFloat).";
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_err));
+}
+
+// Run tests for Access Chain Instructions.
+INSTANTIATE_TEST_CASE_P(
+    CheckAccessChainInstructions, AccessChainInstructionTest,
+    ::testing::Values("OpAccessChain", "OpInBoundsAccessChain",
+                      "OpPtrAccessChain", "OpInBoundsPtrAccessChain"));
+
+// TODO: OpArrayLength
+// TODO: OpImagePointer
+// TODO: OpGenericPtrMemSemantics
+
+// Valid: OpFunction result type %1 is the return type of function type %3.
+TEST_F(ValidateIdWithMessage, OpFunctionGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %1 %2 %2
+%4 = OpFunction %1 None %3
+%5 = OpLabel
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: OpFunction result type is %2 (uint) but function type %4 returns
+// %1 (void).
+TEST_F(ValidateIdWithMessage, OpFunctionResultTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpConstant %2 42
+%4 = OpTypeFunction %1 %2 %2
+%5 = OpFunction %2 None %4
+%6 = OpLabel
+ OpReturnValue %3
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpFunction Result Type <id> '2[%uint]' does not "
+                        "match the Function Type's return type <id> "
+                        "'1[%void]'."));
+}
+// Invalid: returns float constant %3 from a function whose return type is
+// %1 (uint).
+TEST_F(ValidateIdWithMessage, OpReturnValueTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeInt 32 0
+%2 = OpTypeFloat 32
+%3 = OpConstant %2 0
+%4 = OpTypeFunction %1
+%5 = OpFunction %1 None %4
+%6 = OpLabel
+ OpReturnValue %3
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpReturnValue Value <id> '3[%float_0]'s type does "
+                        "not match OpFunction's return type."));
+}
+// Invalid: the Function Type operand (%2) is an OpTypeInt, not an
+// OpTypeFunction.
+TEST_F(ValidateIdWithMessage, OpFunctionFunctionTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%4 = OpFunction %1 None %2
+%5 = OpLabel
+ OpReturn
+OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpFunction Function Type <id> '2[%uint]' is not a function "
+                "type."));
+}
+
+// Invalid: the function's own result id (%3) is used as the OpReturnValue
+// operand.
+TEST_F(ValidateIdWithMessage, OpFunctionUseBad) {
+  const std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeFloat 32
+%2 = OpTypeFunction %1
+%3 = OpFunction %1 None %2
+%4 = OpLabel
+OpReturnValue %3
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Invalid use of function result id 3[%3]."));
+}
+
+// Valid: one OpFunctionParameter of type %2, matching the single parameter
+// declared by function type %3.
+TEST_F(ValidateIdWithMessage, OpFunctionParameterGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %1 %2
+%4 = OpFunction %1 None %3
+%5 = OpFunctionParameter %2
+%6 = OpLabel
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Valid: two OpFunctionParameters of type %2, matching both parameters
+// declared by function type %3.
+TEST_F(ValidateIdWithMessage, OpFunctionParameterMultipleGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %1 %2 %2
+%4 = OpFunction %1 None %3
+%5 = OpFunctionParameter %2
+%6 = OpFunctionParameter %2
+%7 = OpLabel
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: the parameter is declared with type %1 (void) while function type
+// %3 declares its parameter as %2 (uint).
+TEST_F(ValidateIdWithMessage, OpFunctionParameterResultTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %1 %2
+%4 = OpFunction %1 None %3
+%5 = OpFunctionParameter %1
+%6 = OpLabel
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpFunctionParameter Result Type <id> '1[%void]' does not "
+                "match the OpTypeFunction parameter type of the same index."));
+}
+
+// Valid: calls %6 (uint -> uint) with uint constant %5 and a uint result
+// type.
+TEST_F(ValidateIdWithMessage, OpFunctionCallGood) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %2 %2
+%4 = OpTypeFunction %1
+%5 = OpConstant %2 42 ;21
+
+%6 = OpFunction %2 None %3
+%7 = OpFunctionParameter %2
+%8 = OpLabel
+ OpReturnValue %7
+ OpFunctionEnd
+
+%10 = OpFunction %1 None %4
+%11 = OpLabel
+%12 = OpFunctionCall %2 %6 %5
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+// Invalid: the call's result type is %1 (void) but callee %6 returns
+// %2 (uint).
+TEST_F(ValidateIdWithMessage, OpFunctionCallResultTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %2 %2
+%4 = OpTypeFunction %1
+%5 = OpConstant %2 42 ;21
+
+%6 = OpFunction %2 None %3
+%7 = OpFunctionParameter %2
+%8 = OpLabel
+%9 = OpIAdd %2 %7 %7
+ OpReturnValue %9
+ OpFunctionEnd
+
+%10 = OpFunction %1 None %4
+%11 = OpLabel
+%12 = OpFunctionCall %1 %6 %5
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpFunctionCall Result Type <id> '1[%void]'s type "
+                        "does not match Function <id> '2[%uint]'s return "
+                        "type."));
+}
+// Invalid: the callee operand (%5) is an OpConstant, not an OpFunction.
+TEST_F(ValidateIdWithMessage, OpFunctionCallFunctionBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %2 %2
+%4 = OpTypeFunction %1
+%5 = OpConstant %2 42 ;21
+
+%10 = OpFunction %1 None %4
+%11 = OpLabel
+%12 = OpFunctionCall %2 %5 %5
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpFunctionCall Function <id> '5[%uint_42]' is not a "
+                        "function."));
+}
+// Invalid: passes float constant %14 to callee %6, whose function type %3
+// declares a uint (%2) parameter.
+TEST_F(ValidateIdWithMessage, OpFunctionCallArgumentTypeBad) {
+  std::string spirv = kGLSL450MemoryModel + R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %2 %2
+%4 = OpTypeFunction %1
+%5 = OpConstant %2 42
+
+%13 = OpTypeFloat 32
+%14 = OpConstant %13 3.14
+
+%6 = OpFunction %2 None %3
+%7 = OpFunctionParameter %2
+%8 = OpLabel
+%9 = OpIAdd %2 %7 %7
+ OpReturnValue %9
+ OpFunctionEnd
+
+%10 = OpFunction %1 None %4
+%11 = OpLabel
+%12 = OpFunctionCall %2 %6 %14
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpFunctionCall Argument <id> '7[%float_3_1400001]'s "
+                        "type does not match Function <id> '2[%uint]'s "
+                        "parameter type."));
+}
+
+// Valid: OpSampledImage result is used in the same block by
+// OpImageSampleImplicitLod
+TEST_F(ValidateIdWithMessage, OpSampledImageGood) {
+  std::string spirv = kGLSL450MemoryModel + sampledImageSetup + R"(
+%smpld_img = OpSampledImage %sampled_image_type %image_inst %sampler_inst
+%si_lod = OpImageSampleImplicitLod %v4float %smpld_img %const_vec_1_1
+ OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Invalid: OpSampledImage result is defined in one block and used in a
+// different block.
+TEST_F(ValidateIdWithMessage, OpSampledImageUsedInDifferentBlockBad) {
+  // %smpld_img is created before the OpBranch and consumed after %label_2,
+  // i.e. in a different basic block from the one that defines it.
+  std::string spirv = kGLSL450MemoryModel + sampledImageSetup + R"(
+%smpld_img = OpSampledImage %sampled_image_type %image_inst %sampler_inst
+OpBranch %label_2
+%label_2 = OpLabel
+%si_lod = OpImageSampleImplicitLod %v4float %smpld_img %const_vec_1_1
+OpReturn
+OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  // The diagnostic tags each operand with the SPIR-V "<id>" marker.
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("All OpSampledImage instructions must be in the same block in "
+                "which their Result <id> are consumed. OpSampledImage Result "
+                "Type <id> '23[%23]' has a consumer in a different basic "
+                "block. The consumer instruction <id> is '25[%25]'."));
+}
+
+// Invalid: OpSampledImage result <id> is used by OpSelect
+// Note: According to the Spec, OpSelect parameters must be either a scalar or a
+// vector. Therefore, OpTypeSampledImage is an illegal parameter for OpSelect.
+// However, the OpSelect validation does not catch this today. Therefore, it is
+// caught by the OpSampledImage validation. If the OpSelect validation code is
+// updated, the error message for this test may change.
+//
+// Disabled since OpSelect catches this now.
+TEST_F(ValidateIdWithMessage, DISABLED_OpSampledImageUsedInOpSelectBad) {
+  // Both value operands of the OpSelect are the sampled-image result.
+  std::string spirv = kGLSL450MemoryModel + sampledImageSetup + R"(
+%smpld_img = OpSampledImage %sampled_image_type %image_inst %sampler_inst
+%select_img = OpSelect %sampled_image_type %spec_true %smpld_img %smpld_img
+OpReturn
+OpFunctionEnd)";
+  CompileSuccessfully(spirv.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Result <id> from OpSampledImage instruction must not "
+                        "appear as operands of OpSelect. Found result <id> "
+                        "'23' as an operand of <id> '24'."));
+}
+
+// Valid: Get a float in a matrix using CompositeExtract.
+// Valid: Insert float into a matrix using CompositeInsert.
+TEST_F(ValidateIdWithMessage, CompositeExtractInsertGood) {
+  // Build the module text incrementally: shared preamble first, then the
+  // instructions under test, then the function terminator.
+  std::ostringstream spirv;
+  spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup << std::endl;
+  // Load the matrix value, then extract the element at literal indexes 0, 1.
+  spirv << "%matrix = OpLoad %mat4x3 %my_matrix" << std::endl;
+  spirv << "%float_entry = OpCompositeExtract %float %matrix 0 1" << std::endl;
+
+  // To test CompositeInsert, insert the object back in after extraction.
+  spirv << "%new_composite = OpCompositeInsert %mat4x3 %float_entry %matrix 0 1"
+        << std::endl;
+  spirv << R"(OpReturn
+ OpFunctionEnd)";
+  CompileSuccessfully(spirv.str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Compiled out. NOTE(review): this block relies on a CHECK helper that is not
+// visible in this part of the file, and its module text has no memory-model
+// preamble, unlike the enabled tests -- confirm both before re-enabling.
+// The "Bar" suffix in the test name is presumably a typo for "Bad".
+#if 0
+TEST_F(ValidateIdWithMessage, OpFunctionCallArgumentCountBar) {
+  const char *spirv = R"(
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeFunction %2 %2
+%4 = OpTypeFunction %1
+%5 = OpConstant %2 42 ;21
+
+%6 = OpFunction %2 None %3
+%7 = OpFunctionParameter %2
+%8 = OpLabel
+%9 = OpLoad %2 %7
+ OpReturnValue %9
+ OpFunctionEnd
+
+%10 = OpFunction %1 None %4
+%11 = OpLabel
+ OpReturn
+%12 = OpFunctionCall %2 %6 %5
+ OpFunctionEnd)";
+  CHECK(spirv, SPV_ERROR_INVALID_ID);
+}
+#endif
+
+// TODO: The many things that changed with how images are used.
+// TODO: OpTextureSample +// TODO: OpTextureSampleDref +// TODO: OpTextureSampleLod +// TODO: OpTextureSampleProj +// TODO: OpTextureSampleGrad +// TODO: OpTextureSampleOffset +// TODO: OpTextureSampleProjLod +// TODO: OpTextureSampleProjGrad +// TODO: OpTextureSampleLodOffset +// TODO: OpTextureSampleProjOffset +// TODO: OpTextureSampleGradOffset +// TODO: OpTextureSampleProjLodOffset +// TODO: OpTextureSampleProjGradOffset +// TODO: OpTextureFetchTexelLod +// TODO: OpTextureFetchTexelOffset +// TODO: OpTextureFetchSample +// TODO: OpTextureFetchTexel +// TODO: OpTextureGather +// TODO: OpTextureGatherOffset +// TODO: OpTextureGatherOffsets +// TODO: OpTextureQuerySizeLod +// TODO: OpTextureQuerySize +// TODO: OpTextureQueryLevels +// TODO: OpTextureQuerySamples +// TODO: OpConvertUToF +// TODO: OpConvertFToS +// TODO: OpConvertSToF +// TODO: OpConvertUToF +// TODO: OpUConvert +// TODO: OpSConvert +// TODO: OpFConvert +// TODO: OpConvertPtrToU +// TODO: OpConvertUToPtr +// TODO: OpPtrCastToGeneric +// TODO: OpGenericCastToPtr +// TODO: OpBitcast +// TODO: OpGenericCastToPtrExplicit +// TODO: OpSatConvertSToU +// TODO: OpSatConvertUToS +// TODO: OpVectorExtractDynamic +// TODO: OpVectorInsertDynamic + +TEST_F(ValidateIdWithMessage, OpVectorShuffleIntGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%int = OpTypeInt 32 0 +%ivec3 = OpTypeVector %int 3 +%ivec4 = OpTypeVector %int 4 +%ptr_ivec3 = OpTypePointer Function %ivec3 +%undef = OpUndef %ivec4 +%int_42 = OpConstant %int 42 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%1 = OpConstantComposite %ivec3 %int_42 %int_0 %int_2 +%2 = OpTypeFunction %ivec3 +%3 = OpFunction %ivec3 None %2 +%4 = OpLabel +%var = OpVariable %ptr_ivec3 Function %1 +%5 = OpLoad %ivec3 %var +%6 = OpVectorShuffle %ivec3 %5 %undef 2 1 0 + OpReturnValue %6 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpVectorShuffleFloatGood) { + 
std::string spirv = kGLSL450MemoryModel + R"( +%float = OpTypeFloat 32 +%vec2 = OpTypeVector %float 2 +%vec3 = OpTypeVector %float 3 +%vec4 = OpTypeVector %float 4 +%ptr_vec2 = OpTypePointer Function %vec2 +%ptr_vec3 = OpTypePointer Function %vec3 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%1 = OpConstantComposite %vec2 %float_2 %float_1 +%2 = OpConstantComposite %vec3 %float_1 %float_2 %float_2 +%3 = OpTypeFunction %vec4 +%4 = OpFunction %vec4 None %3 +%5 = OpLabel +%var = OpVariable %ptr_vec2 Function %1 +%var2 = OpVariable %ptr_vec3 Function %2 +%6 = OpLoad %vec2 %var +%7 = OpLoad %vec3 %var2 +%8 = OpVectorShuffle %vec4 %6 %7 4 3 1 0xffffffff + OpReturnValue %8 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpVectorShuffleScalarResultType) { + std::string spirv = kGLSL450MemoryModel + R"( +%float = OpTypeFloat 32 +%vec2 = OpTypeVector %float 2 +%ptr_vec2 = OpTypePointer Function %vec2 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%1 = OpConstantComposite %vec2 %float_2 %float_1 +%2 = OpTypeFunction %float +%3 = OpFunction %float None %2 +%4 = OpLabel +%var = OpVariable %ptr_vec2 Function %1 +%5 = OpLoad %vec2 %var +%6 = OpVectorShuffle %float %5 %5 0 + OpReturnValue %6 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Result Type of OpVectorShuffle must be OpTypeVector.")); +} + +TEST_F(ValidateIdWithMessage, OpVectorShuffleComponentCount) { + std::string spirv = kGLSL450MemoryModel + R"( +%int = OpTypeInt 32 0 +%ivec3 = OpTypeVector %int 3 +%ptr_ivec3 = OpTypePointer Function %ivec3 +%int_42 = OpConstant %int 42 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%1 = OpConstantComposite %ivec3 %int_42 %int_0 %int_2 +%2 = OpTypeFunction %ivec3 +%3 = OpFunction %ivec3 None %2 +%4 = OpLabel +%var = 
OpVariable %ptr_ivec3 Function %1 +%5 = OpLoad %ivec3 %var +%6 = OpVectorShuffle %ivec3 %5 %5 0 1 + OpReturnValue %6 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpVectorShuffle component literals count does not match " + "Result Type '2[%v3uint]'s vector component count.")); +} + +TEST_F(ValidateIdWithMessage, OpVectorShuffleVector1Type) { + std::string spirv = kGLSL450MemoryModel + R"( +%int = OpTypeInt 32 0 +%ivec2 = OpTypeVector %int 2 +%ptr_int = OpTypePointer Function %int +%undef = OpUndef %ivec2 +%int_42 = OpConstant %int 42 +%2 = OpTypeFunction %ivec2 +%3 = OpFunction %ivec2 None %2 +%4 = OpLabel +%var = OpVariable %ptr_int Function %int_42 +%5 = OpLoad %int %var +%6 = OpVectorShuffle %ivec2 %5 %undef 0 0 + OpReturnValue %6 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The type of Vector 1 must be OpTypeVector.")); +} + +TEST_F(ValidateIdWithMessage, OpVectorShuffleVector2Type) { + std::string spirv = kGLSL450MemoryModel + R"( +%int = OpTypeInt 32 0 +%ivec2 = OpTypeVector %int 2 +%ptr_ivec2 = OpTypePointer Function %ivec2 +%undef = OpUndef %int +%int_42 = OpConstant %int 42 +%1 = OpConstantComposite %ivec2 %int_42 %int_42 +%2 = OpTypeFunction %ivec2 +%3 = OpFunction %ivec2 None %2 +%4 = OpLabel +%var = OpVariable %ptr_ivec2 Function %1 +%5 = OpLoad %ivec2 %var +%6 = OpVectorShuffle %ivec2 %5 %undef 0 1 + OpReturnValue %6 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The type of Vector 2 must be OpTypeVector.")); +} + +TEST_F(ValidateIdWithMessage, OpVectorShuffleVector1ComponentType) { + std::string spirv = kGLSL450MemoryModel + R"( +%int = OpTypeInt 32 0 +%ivec3 = OpTypeVector %int 3 +%ptr_ivec3 = 
OpTypePointer Function %ivec3 +%int_42 = OpConstant %int 42 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%float = OpTypeFloat 32 +%vec3 = OpTypeVector %float 3 +%vec4 = OpTypeVector %float 4 +%ptr_vec3 = OpTypePointer Function %vec3 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%1 = OpConstantComposite %ivec3 %int_42 %int_0 %int_2 +%2 = OpConstantComposite %vec3 %float_1 %float_2 %float_2 +%3 = OpTypeFunction %vec4 +%4 = OpFunction %vec4 None %3 +%5 = OpLabel +%var = OpVariable %ptr_ivec3 Function %1 +%var2 = OpVariable %ptr_vec3 Function %2 +%6 = OpLoad %ivec3 %var +%7 = OpLoad %vec3 %var2 +%8 = OpVectorShuffle %vec4 %6 %7 4 3 1 0 + OpReturnValue %8 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Component Type of Vector 1 must be the same as " + "ResultType.")); +} + +TEST_F(ValidateIdWithMessage, OpVectorShuffleVector2ComponentType) { + std::string spirv = kGLSL450MemoryModel + R"( +%int = OpTypeInt 32 0 +%ivec3 = OpTypeVector %int 3 +%ptr_ivec3 = OpTypePointer Function %ivec3 +%int_42 = OpConstant %int 42 +%int_0 = OpConstant %int 0 +%int_2 = OpConstant %int 2 +%float = OpTypeFloat 32 +%vec3 = OpTypeVector %float 3 +%vec4 = OpTypeVector %float 4 +%ptr_vec3 = OpTypePointer Function %vec3 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%1 = OpConstantComposite %ivec3 %int_42 %int_0 %int_2 +%2 = OpConstantComposite %vec3 %float_1 %float_2 %float_2 +%3 = OpTypeFunction %vec4 +%4 = OpFunction %vec4 None %3 +%5 = OpLabel +%var = OpVariable %ptr_ivec3 Function %1 +%var2 = OpVariable %ptr_vec3 Function %2 +%6 = OpLoad %vec3 %var2 +%7 = OpLoad %ivec3 %var +%8 = OpVectorShuffle %vec4 %6 %7 4 3 1 0 + OpReturnValue %8 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Component Type of 
Vector 2 must be the same as " + "ResultType.")); +} + +TEST_F(ValidateIdWithMessage, OpVectorShuffleLiterals) { + std::string spirv = kGLSL450MemoryModel + R"( +%float = OpTypeFloat 32 +%vec2 = OpTypeVector %float 2 +%vec3 = OpTypeVector %float 3 +%vec4 = OpTypeVector %float 4 +%ptr_vec2 = OpTypePointer Function %vec2 +%ptr_vec3 = OpTypePointer Function %vec3 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%1 = OpConstantComposite %vec2 %float_2 %float_1 +%2 = OpConstantComposite %vec3 %float_1 %float_2 %float_2 +%3 = OpTypeFunction %vec4 +%4 = OpFunction %vec4 None %3 +%5 = OpLabel +%var = OpVariable %ptr_vec2 Function %1 +%var2 = OpVariable %ptr_vec3 Function %2 +%6 = OpLoad %vec2 %var +%7 = OpLoad %vec3 %var2 +%8 = OpVectorShuffle %vec4 %6 %7 0 8 2 6 + OpReturnValue %8 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Component index 8 is out of bounds for combined (Vector1 + Vector2) " + "size of 5.")); +} + +TEST_F(ValidateIdWithMessage, WebGPUOpVectorShuffle0xFFFFFFFFLiteralBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR +%float = OpTypeFloat 32 +%vec2 = OpTypeVector %float 2 +%vec3 = OpTypeVector %float 3 +%vec4 = OpTypeVector %float 4 +%ptr_vec2 = OpTypePointer Function %vec2 +%ptr_vec3 = OpTypePointer Function %vec3 +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +%1 = OpConstantComposite %vec2 %float_2 %float_1 +%2 = OpConstantComposite %vec3 %float_1 %float_2 %float_2 +%3 = OpTypeFunction %vec4 +%4 = OpFunction %vec4 None %3 +%5 = OpLabel +%var = OpVariable %ptr_vec2 Function %1 +%var2 = OpVariable %ptr_vec3 Function %2 +%6 = OpLoad %vec2 %var +%7 = OpLoad %vec3 %var2 +%8 = OpVectorShuffle %vec4 %6 %7 4 3 1 0xffffffff + OpReturnValue %8 + OpFunctionEnd)"; + 
CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Component literal at operand 3 cannot be 0xFFFFFFFF in" + " WebGPU execution environment.")); +} + +// TODO: OpCompositeConstruct +// TODO: OpCompositeExtract +// TODO: OpCompositeInsert +// TODO: OpCopyObject +// TODO: OpTranspose +// TODO: OpSNegate +// TODO: OpFNegate +// TODO: OpNot +// TODO: OpIAdd +// TODO: OpFAdd +// TODO: OpISub +// TODO: OpFSub +// TODO: OpIMul +// TODO: OpFMul +// TODO: OpUDiv +// TODO: OpSDiv +// TODO: OpFDiv +// TODO: OpUMod +// TODO: OpSRem +// TODO: OpSMod +// TODO: OpFRem +// TODO: OpFMod +// TODO: OpVectorTimesScalar +// TODO: OpMatrixTimesScalar +// TODO: OpVectorTimesMatrix +// TODO: OpMatrixTimesVector +// TODO: OpMatrixTimesMatrix +// TODO: OpOuterProduct +// TODO: OpDot +// TODO: OpShiftRightLogical +// TODO: OpShiftRightArithmetic +// TODO: OpShiftLeftLogical +// TODO: OpBitwiseOr +// TODO: OpBitwiseXor +// TODO: OpBitwiseAnd +// TODO: OpAny +// TODO: OpAll +// TODO: OpIsNan +// TODO: OpIsInf +// TODO: OpIsFinite +// TODO: OpIsNormal +// TODO: OpSignBitSet +// TODO: OpLessOrGreater +// TODO: OpOrdered +// TODO: OpUnordered +// TODO: OpLogicalOr +// TODO: OpLogicalXor +// TODO: OpLogicalAnd +// TODO: OpSelect +// TODO: OpIEqual +// TODO: OpFOrdEqual +// TODO: OpFUnordEqual +// TODO: OpINotEqual +// TODO: OpFOrdNotEqual +// TODO: OpFUnordNotEqual +// TODO: OpULessThan +// TODO: OpSLessThan +// TODO: OpFOrdLessThan +// TODO: OpFUnordLessThan +// TODO: OpUGreaterThan +// TODO: OpSGreaterThan +// TODO: OpFOrdGreaterThan +// TODO: OpFUnordGreaterThan +// TODO: OpULessThanEqual +// TODO: OpSLessThanEqual +// TODO: OpFOrdLessThanEqual +// TODO: OpFUnordLessThanEqual +// TODO: OpUGreaterThanEqual +// TODO: OpSGreaterThanEqual +// TODO: OpFOrdGreaterThanEqual +// TODO: OpFUnordGreaterThanEqual +// TODO: OpDPdx +// TODO: OpDPdy +// TODO: OpFWidth +// TODO: 
OpDPdxFine +// TODO: OpDPdyFine +// TODO: OpFwidthFine +// TODO: OpDPdxCoarse +// TODO: OpDPdyCoarse +// TODO: OpFwidthCoarse +// TODO: OpLoopMerge +// TODO: OpSelectionMerge +// TODO: OpBranch + +TEST_F(ValidateIdWithMessage, OpPhiNotAType) { + std::string spirv = kOpenCLMemoryModel32 + R"( +%2 = OpTypeBool +%3 = OpConstantTrue %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %3 %3 %7 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 3[%true] is not a type " + "id")); +} + +TEST_F(ValidateIdWithMessage, OpPhiSamePredecessor) { + std::string spirv = kOpenCLMemoryModel32 + R"( +%2 = OpTypeBool +%3 = OpConstantTrue %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +OpBranchConditional %3 %8 %8 +%8 = OpLabel +%9 = OpPhi %2 %3 %7 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpPhiOddArgumentNumber) { + std::string spirv = kOpenCLMemoryModel32 + R"( +%2 = OpTypeBool +%3 = OpConstantTrue %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %2 %3 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi does not have an equal number of incoming " + "values and basic blocks.")); +} + +TEST_F(ValidateIdWithMessage, OpPhiTooFewPredecessors) { + std::string spirv = kOpenCLMemoryModel32 + R"( +%2 = OpTypeBool +%3 = OpConstantTrue %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +OpBranch %8 +%8 = OpLabel +%9 = OpPhi %2 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + 
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi's number of incoming blocks (0) does not match " + "block's predecessor count (1).")); +} + +TEST_F(ValidateIdWithMessage, OpPhiTooManyPredecessors) { + std::string spirv = kOpenCLMemoryModel32 + R"( +%2 = OpTypeBool +%3 = OpConstantTrue %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +OpBranch %8 +%9 = OpLabel +OpReturn +%8 = OpLabel +%10 = OpPhi %2 %3 %7 %3 %9 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi's number of incoming blocks (2) does not match " + "block's predecessor count (1).")); +} + +TEST_F(ValidateIdWithMessage, OpPhiMismatchedTypes) { + std::string spirv = kOpenCLMemoryModel32 + R"( +%2 = OpTypeBool +%3 = OpConstantTrue %2 +%4 = OpTypeVoid +%5 = OpTypeInt 32 0 +%6 = OpConstant %5 0 +%7 = OpTypeFunction %4 +%8 = OpFunction %4 None %7 +%9 = OpLabel +OpBranchConditional %3 %10 %11 +%11 = OpLabel +OpBranch %10 +%10 = OpLabel +%12 = OpPhi %2 %3 %9 %6 %11 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi's result type 2[%bool] does not match " + "incoming value 6[%uint_0] type " + "5[%uint].")); +} + +TEST_F(ValidateIdWithMessage, OpPhiPredecessorNotABlock) { + std::string spirv = kOpenCLMemoryModel32 + R"( +%2 = OpTypeBool +%3 = OpConstantTrue %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +OpBranchConditional %3 %8 %9 +%9 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %8 +%8 = OpLabel +%10 = OpPhi %2 %3 %7 %3 %3 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("OpPhi's incoming basic block 3[%true] is not an " + "OpLabel.")); +} + +TEST_F(ValidateIdWithMessage, OpPhiNotAPredecessor) { + std::string spirv = kOpenCLMemoryModel32 + R"( +%2 = OpTypeBool +%3 = OpConstantTrue %2 +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpFunction %4 None %5 +%7 = OpLabel +OpBranchConditional %3 %8 %9 +%9 = OpLabel +OpBranch %11 +%11 = OpLabel +OpBranch %8 +%8 = OpLabel +%10 = OpPhi %2 %3 %7 %3 %9 +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPhi's incoming basic block 9[%9] is not a " + "predecessor of 8[%8].")); +} + +TEST_F(ValidateIdWithMessage, OpBranchConditionalGood) { + std::string spirv = BranchConditionalSetup + R"( + %branch_cond = OpINotEqual %bool %i0 %i1 + OpSelectionMerge %end None + OpBranchConditional %branch_cond %target_t %target_f + )" + BranchConditionalTail; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateIdWithMessage, OpBranchConditionalWithWeightsGood) { + std::string spirv = BranchConditionalSetup + R"( + %branch_cond = OpINotEqual %bool %i0 %i1 + OpSelectionMerge %end None + OpBranchConditional %branch_cond %target_t %target_f 1 1 + )" + BranchConditionalTail; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateIdWithMessage, OpBranchConditional_CondIsScalarInt) { + std::string spirv = BranchConditionalSetup + R"( + OpSelectionMerge %end None + OpBranchConditional %i0 %target_t %target_f + )" + BranchConditionalTail; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Condition operand for OpBranchConditional must be of boolean type")); +} + +TEST_F(ValidateIdWithMessage, OpBranchConditional_TrueTargetIsNotLabel) { + 
std::string spirv = BranchConditionalSetup + R"( + OpSelectionMerge %end None + OpBranchConditional %true %i0 %target_f + )" + BranchConditionalTail; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The 'True Label' operand for OpBranchConditional must " + "be the ID of an OpLabel instruction")); +} + +TEST_F(ValidateIdWithMessage, OpBranchConditional_FalseTargetIsNotLabel) { + std::string spirv = BranchConditionalSetup + R"( + OpSelectionMerge %end None + OpBranchConditional %true %target_t %i0 + )" + BranchConditionalTail; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The 'False Label' operand for OpBranchConditional " + "must be the ID of an OpLabel instruction")); +} + +TEST_F(ValidateIdWithMessage, OpBranchConditional_NotEnoughWeights) { + std::string spirv = BranchConditionalSetup + R"( + %branch_cond = OpINotEqual %bool %i0 %i1 + OpSelectionMerge %end None + OpBranchConditional %branch_cond %target_t %target_f 1 + )" + BranchConditionalTail; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpBranchConditional requires either 3 or 5 parameters")); +} + +TEST_F(ValidateIdWithMessage, OpBranchConditional_TooManyWeights) { + std::string spirv = BranchConditionalSetup + R"( + %branch_cond = OpINotEqual %bool %i0 %i1 + OpSelectionMerge %end None + OpBranchConditional %branch_cond %target_t %target_f 1 2 3 + )" + BranchConditionalTail; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpBranchConditional requires either 3 or 5 parameters")); +} + +TEST_F(ValidateIdWithMessage, OpBranchConditional_ConditionIsAType) { + std::string spirv = BranchConditionalSetup + 
R"( +OpBranchConditional %bool %target_t %target_f +)" + BranchConditionalTail; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 3[%bool] cannot be a " + "type")); +} + +// TODO: OpSwitch + +TEST_F(ValidateIdWithMessage, OpReturnValueConstantGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeFunction %2 +%4 = OpConstant %2 42 +%5 = OpFunction %2 None %3 +%6 = OpLabel + OpReturnValue %4 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpReturnValueVariableGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 ;10 +%3 = OpTypeFunction %2 +%8 = OpTypePointer Function %2 ;18 +%4 = OpConstant %2 42 ;22 +%5 = OpFunction %2 None %3 ;27 +%6 = OpLabel ;29 +%7 = OpVariable %8 Function %4 ;34 +%9 = OpLoad %2 %7 + OpReturnValue %9 ;36 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpReturnValueExpressionGood) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeFunction %2 +%4 = OpConstant %2 42 +%5 = OpFunction %2 None %3 +%6 = OpLabel +%7 = OpIAdd %2 %4 %4 + OpReturnValue %7 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpReturnValueIsType) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeFunction %2 +%5 = OpFunction %2 None %3 +%6 = OpLabel + OpReturnValue %1 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 1[%void] cannot be a " + "type")); +} + +TEST_F(ValidateIdWithMessage, 
OpReturnValueIsLabel) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeFunction %2 +%5 = OpFunction %2 None %3 +%6 = OpLabel + OpReturnValue %6 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 5[%5] requires a type")); +} + +TEST_F(ValidateIdWithMessage, OpReturnValueIsVoid) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeFunction %1 +%5 = OpFunction %1 None %3 +%6 = OpLabel +%7 = OpFunctionCall %1 %5 + OpReturnValue %7 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpReturnValue value's type '1[%void]' is missing or " + "void.")); +} + +TEST_F(ValidateIdWithMessage, OpReturnValueIsVariableInPhysical) { + // It's valid to return a pointer in a physical addressing model. + std::string spirv = kOpCapabilitySetup + R"( + OpMemoryModel Physical32 OpenCL +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Function %2 +%4 = OpTypeFunction %3 +%5 = OpFunction %3 None %4 +%6 = OpLabel +%7 = OpVariable %3 Function + OpReturnValue %7 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpReturnValueIsVariableInLogical) { + // It's invalid to return a pointer in a logical addressing model.
+ std::string spirv = kOpCapabilitySetup + R"( + OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Function %2 +%4 = OpTypeFunction %3 +%5 = OpFunction %3 None %4 +%6 = OpLabel +%7 = OpVariable %3 Function + OpReturnValue %7 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpReturnValue value's type " + "'3[%_ptr_Function_uint]' is a pointer, which is " + "invalid in the Logical addressing model.")); +} + +// With the VariablePointer Capability, the return value of a function is +// allowed to be a pointer. +TEST_F(ValidateIdWithMessage, OpReturnValueVarPtrGood) { + std::ostringstream spirv; + createVariablePointerSpirvProgram(&spirv, + "" /* Instructions to add to "main" */, + true /* Add VariablePointers Capability?*/, + true /* Use Helper Function? */); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Without the VariablePointer Capability, the return value of a function is +// *not* allowed to be a pointer. +// Disabled since using OpSelect with pointers without VariablePointers will +// fail LogicalsPass. +TEST_F(ValidateIdWithMessage, DISABLED_OpReturnValueVarPtrBad) { + std::ostringstream spirv; + createVariablePointerSpirvProgram(&spirv, + "" /* Instructions to add to "main" */, + false /* Add VariablePointers Capability?*/, + true /* Use Helper Function? 
*/); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpReturnValue value's type '7' is a pointer, " + "which is invalid in the Logical addressing model.")); +} + +// TODO: enable when this bug is fixed: +// https://cvs.khronos.org/bugzilla/show_bug.cgi?id=15404 +TEST_F(ValidateIdWithMessage, DISABLED_OpReturnValueIsFunction) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypeFunction %2 +%5 = OpFunction %2 None %3 +%6 = OpLabel + OpReturnValue %5 + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, UndefinedTypeId) { + std::string spirv = kGLSL450MemoryModel + R"( +%s = OpTypeStruct %i32 +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Forward reference operands in an OpTypeStruct must " + "first be declared using OpTypeForwardPointer.")); +} + +TEST_F(ValidateIdWithMessage, UndefinedIdScope) { + std::string spirv = kGLSL450MemoryModel + R"( +%u32 = OpTypeInt 32 0 +%memsem = OpConstant %u32 0 +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%f = OpFunction %void None %void_f +%l = OpLabel + OpMemoryBarrier %undef %memsem + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 7[%7] has not been " + "defined")); +} + +TEST_F(ValidateIdWithMessage, UndefinedIdMemSem) { + std::string spirv = kGLSL450MemoryModel + R"( +%u32 = OpTypeInt 32 0 +%scope = OpConstant %u32 0 +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%f = OpFunction %void None %void_f +%l = OpLabel + OpMemoryBarrier %scope %undef + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 7[%7] has not been " + "defined")); +} + +TEST_F(ValidateIdWithMessage, + KernelOpEntryPointAndOpInBoundsPtrAccessChainGood) { + std::string spirv = kOpenCLMemoryModel32 + R"( + OpEntryPoint Kernel %2 "simple_kernel" + OpSource OpenCL_C 200000 + OpDecorate %3 BuiltIn GlobalInvocationId + OpDecorate %3 Constant + OpDecorate %4 FuncParamAttr NoCapture + OpDecorate %3 LinkageAttributes "__spirv_GlobalInvocationId" Import + %5 = OpTypeInt 32 0 + %6 = OpTypeVector %5 3 + %7 = OpTypePointer UniformConstant %6 + %3 = OpVariable %7 UniformConstant + %8 = OpTypeVoid + %9 = OpTypeStruct %5 +%10 = OpTypePointer CrossWorkgroup %9 +%11 = OpTypeFunction %8 %10 +%12 = OpConstant %5 0 +%13 = OpTypePointer CrossWorkgroup %5 +%14 = OpConstant %5 42 + %2 = OpFunction %8 None %11 + %4 = OpFunctionParameter %10 +%15 = OpLabel +%16 = OpLoad %6 %3 Aligned 0 +%17 = OpCompositeExtract %5 %16 0 +%18 = OpInBoundsPtrAccessChain %13 %4 %17 %12 + OpStore %18 %14 Aligned 4 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpPtrAccessChainGood) { + std::string spirv = kOpenCLMemoryModel64 + R"( + OpEntryPoint Kernel %2 "another_kernel" + OpSource OpenCL_C 200000 + OpDecorate %3 BuiltIn GlobalInvocationId + OpDecorate %3 Constant + OpDecorate %4 FuncParamAttr NoCapture + OpDecorate %3 LinkageAttributes "__spirv_GlobalInvocationId" Import + %5 = OpTypeInt 64 0 + %6 = OpTypeVector %5 3 + %7 = OpTypePointer UniformConstant %6 + %3 = OpVariable %7 UniformConstant + %8 = OpTypeVoid + %9 = OpTypeInt 32 0 +%10 = OpTypeStruct %9 +%11 = OpTypePointer CrossWorkgroup %10 +%12 = OpTypeFunction %8 %11 +%13 = OpConstant %5 4294967295 +%14 = OpConstant %9 0 +%15 = OpTypePointer CrossWorkgroup %9 +%16 = OpConstant %9 42 + %2 = OpFunction %8 None %12 + %4 = OpFunctionParameter %11 +%17 = OpLabel +%18 = OpLoad %6 %3 Aligned 0 +%19 = 
OpCompositeExtract %5 %18 0 +%20 = OpBitwiseAnd %5 %19 %13 +%21 = OpPtrAccessChain %15 %4 %20 %14 + OpStore %21 %16 Aligned 4 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, StgBufOpPtrAccessChainGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpCapability VariablePointersStorageBuffer + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %3 "" +%int = OpTypeInt 32 0 +%int_2 = OpConstant %int 2 +%int_4 = OpConstant %int 4 +%struct = OpTypeStruct %int +%array = OpTypeArray %struct %int_4 +%ptr = OpTypePointer StorageBuffer %array +%var = OpVariable %ptr StorageBuffer +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel +%5 = OpPtrAccessChain %ptr %var %int_2 + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateIdWithMessage, OpLoadBitcastPointerGood) { + std::string spirv = kOpenCLMemoryModel64 + R"( +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeFloat 32 +%5 = OpTypePointer UniformConstant %3 +%6 = OpTypePointer UniformConstant %4 +%7 = OpVariable %5 UniformConstant +%8 = OpTypeFunction %2 +%9 = OpFunction %2 None %8 +%10 = OpLabel +%11 = OpBitcast %6 %7 +%12 = OpLoad %4 %11 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpLoadBitcastNonPointerBad) { + std::string spirv = kOpenCLMemoryModel64 + R"( +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeFloat 32 +%5 = OpTypePointer UniformConstant %3 +%6 = OpTypeFunction %2 +%7 = OpVariable %5 UniformConstant +%8 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpLoad %3 %7 +%11 = OpBitcast %4 %10 +%12 = OpLoad %3 %11 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpLoad type for pointer '11[%11]' is not a pointer " + "type.")); +} +TEST_F(ValidateIdWithMessage, OpStoreBitcastPointerGood) { + std::string spirv = kOpenCLMemoryModel64 + R"( +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeFloat 32 +%5 = OpTypePointer Function %3 +%6 = OpTypePointer Function %4 +%7 = OpTypeFunction %2 +%8 = OpConstant %3 42 +%9 = OpFunction %2 None %7 +%10 = OpLabel +%11 = OpVariable %6 Function +%12 = OpBitcast %5 %11 + OpStore %12 %8 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} +TEST_F(ValidateIdWithMessage, OpStoreBitcastNonPointerBad) { + std::string spirv = kOpenCLMemoryModel64 + R"( +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeFloat 32 +%5 = OpTypePointer Function %4 +%6 = OpTypeFunction %2 +%7 = OpConstant %4 42 +%8 = OpFunction %2 None %6 +%9 = OpLabel +%10 = OpVariable %5 Function +%11 = OpBitcast %3 %7 + OpStore %11 %7 + OpReturn + OpFunctionEnd)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpStore type for pointer '11[%11]' is not a pointer " + "type.")); +} + +// Result resulting from an instruction within a function may not be used +// outside that function. 
+TEST_F(ValidateIdWithMessage, ResultIdUsedOutsideOfFunctionBad) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Function %3 +%5 = OpFunction %1 None %2 +%6 = OpLabel +%7 = OpVariable %4 Function +OpReturn +OpFunctionEnd +%8 = OpFunction %1 None %2 +%9 = OpLabel +%10 = OpLoad %3 %7 +OpReturn +OpFunctionEnd + )"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "ID 7[%7] defined in block 6[%6] does not dominate its use in block " + "9[%9]")); +} + +TEST_F(ValidateIdWithMessage, SpecIdTargetNotSpecializationConstant) { + std::string spirv = kGLSL450MemoryModel + R"( +OpDecorate %1 SpecId 200 +%void = OpTypeVoid +%2 = OpTypeFunction %void +%int = OpTypeInt 32 0 +%1 = OpConstant %int 3 +%main = OpFunction %void None %2 +%4 = OpLabel +OpReturnValue %1 +OpFunctionEnd + )"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpDecorate SpecId decoration target " + "'1[%uint_3]' is not a scalar specialization " + "constant.")); +} + +TEST_F(ValidateIdWithMessage, SpecIdTargetOpSpecConstantOpBad) { + std::string spirv = kGLSL450MemoryModel + R"( +OpDecorate %1 SpecId 200 +%void = OpTypeVoid +%2 = OpTypeFunction %void +%int = OpTypeInt 32 0 +%3 = OpConstant %int 1 +%4 = OpConstant %int 2 +%1 = OpSpecConstantOp %int IAdd %3 %4 +%main = OpFunction %void None %2 +%6 = OpLabel +OpReturnValue %3 +OpFunctionEnd + )"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpDecorate SpecId decoration target '1[%1]' is " + "not a scalar specialization constant.")); +} + +TEST_F(ValidateIdWithMessage, SpecIdTargetOpSpecConstantCompositeBad) { + std::string spirv = kGLSL450MemoryModel + R"( +OpDecorate %1 SpecId 200 
+%void = OpTypeVoid +%2 = OpTypeFunction %void +%int = OpTypeInt 32 0 +%3 = OpConstant %int 1 +%1 = OpSpecConstantComposite %int +%main = OpFunction %void None %2 +%4 = OpLabel +OpReturnValue %3 +OpFunctionEnd + )"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpDecorate SpecId decoration target '1[%1]' is " + "not a scalar specialization constant.")); +} + +TEST_F(ValidateIdWithMessage, SpecIdTargetGood) { + std::string spirv = kGLSL450MemoryModel + R"( +OpDecorate %3 SpecId 200 +OpDecorate %4 SpecId 201 +OpDecorate %5 SpecId 202 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%int = OpTypeInt 32 0 +%bool = OpTypeBool +%3 = OpSpecConstant %int 3 +%4 = OpSpecConstantTrue %bool +%5 = OpSpecConstantFalse %bool +%main = OpFunction %1 None %2 +%6 = OpLabel +OpReturn +OpFunctionEnd + )"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidateIdWithMessage, CorrectErrorForShuffle) { + std::string spirv = kGLSL450MemoryModel + R"( + %uint = OpTypeInt 32 0 + %float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%v2float = OpTypeVector %float 2 + %void = OpTypeVoid + %548 = OpTypeFunction %void + %CS = OpFunction %void None %548 + %550 = OpLabel + %6275 = OpUndef %v2float + %6280 = OpUndef %v2float + %6282 = OpVectorShuffle %v4float %6275 %6280 0 1 4 5 + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Component index 4 is out of bounds for combined (Vector1 + Vector2) " + "size of 4.")); + EXPECT_EQ(25, getErrorPosition().index); +} + +TEST_F(ValidateIdWithMessage, VoidStructMember) { + const std::string spirv = kGLSL450MemoryModel + R"( +%void = OpTypeVoid +%struct = OpTypeStruct %void +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Structures cannot contain a void type.")); +} + +TEST_F(ValidateIdWithMessage, TypeFunctionBadUse) { + std::string spirv = kGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypePointer Function %2 +%4 = OpFunction %1 None %2 +%5 = OpLabel + OpReturn + OpFunctionEnd)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid use of function type result id 2[%2].")); +} + +TEST_F(ValidateIdWithMessage, BadTypeId) { + std::string spirv = kGLSL450MemoryModel + R"( + %1 = OpTypeVoid + %2 = OpTypeFunction %1 + %3 = OpTypeFloat 32 + %4 = OpConstant %3 0 + %5 = OpFunction %1 None %2 + %6 = OpLabel + %7 = OpUndef %4 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ID 4[%float_0] is not a type " + "id")); +} + +TEST_F(ValidateIdWithMessage, VulkanMemoryModelLoadMakePointerVisibleGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypeFunction %1 +%6 = OpConstant %2 2 +%7 = OpFunction %1 None %5 +%8 = OpLabel +%9 = OpLoad %2 %4 NonPrivatePointerKHR|MakePointerVisibleKHR %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelLoadMakePointerVisibleMissingNonPrivatePointer) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = 
OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypeFunction %1 +%6 = OpConstant %2 2 +%7 = OpFunction %1 None %5 +%8 = OpLabel +%9 = OpLoad %2 %4 MakePointerVisibleKHR %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR must be specified if " + "MakePointerVisibleKHR is specified.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelLoadNonPrivatePointerBadStorageClass) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Private %2 +%4 = OpVariable %3 Private +%5 = OpTypeFunction %1 +%6 = OpConstant %2 2 +%7 = OpFunction %1 None %5 +%8 = OpLabel +%9 = OpLoad %2 %4 NonPrivatePointerKHR +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR requires a pointer in Uniform, " + "Workgroup, CrossWorkgroup, Generic, Image or " + "StorageBuffer storage classes.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelLoadMakePointerAvailableCannotBeUsed) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypeFunction %1 +%6 = OpConstant %2 2 +%7 = OpFunction %1 None %5 +%8 = OpLabel +%9 = OpLoad %2 %4 NonPrivatePointerKHR|MakePointerAvailableKHR %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + 
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad")); +} + +TEST_F(ValidateIdWithMessage, VulkanMemoryModelStoreMakePointerAvailableGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpVariable %3 Uniform +%5 = OpTypeFunction %1 +%6 = OpConstant %2 5 +%7 = OpFunction %1 None %5 +%8 = OpLabel +OpStore %4 %6 NonPrivatePointerKHR|MakePointerAvailableKHR %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelStoreMakePointerAvailableMissingNonPrivatePointer) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpVariable %3 Uniform +%5 = OpTypeFunction %1 +%6 = OpConstant %2 5 +%7 = OpFunction %1 None %5 +%8 = OpLabel +OpStore %4 %6 MakePointerAvailableKHR %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR must be specified if " + "MakePointerAvailableKHR is specified.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelStoreNonPrivatePointerBadStorageClass) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = 
OpTypePointer Output %2 +%4 = OpVariable %3 Output +%5 = OpTypeFunction %1 +%6 = OpConstant %2 5 +%7 = OpFunction %1 None %5 +%8 = OpLabel +OpStore %4 %6 NonPrivatePointerKHR +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR requires a pointer in Uniform, " + "Workgroup, CrossWorkgroup, Generic, Image or " + "StorageBuffer storage classes.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelStoreMakePointerVisibleCannotBeUsed) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Uniform %2 +%4 = OpVariable %3 Uniform +%5 = OpTypeFunction %1 +%6 = OpConstant %2 5 +%7 = OpFunction %1 None %5 +%8 = OpLabel +OpStore %4 %6 NonPrivatePointerKHR|MakePointerVisibleKHR %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MakePointerVisibleKHR cannot be used with OpStore.")); +} + +TEST_F(ValidateIdWithMessage, VulkanMemoryModelCopyMemoryAvailable) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemory %4 %6 NonPrivatePointerKHR|MakePointerAvailableKHR %7 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, 
SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, VulkanMemoryModelCopyMemoryVisible) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemory %4 %6 NonPrivatePointerKHR|MakePointerVisibleKHR %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, VulkanMemoryModelCopyMemoryAvailableAndVisible) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemory %4 %6 NonPrivatePointerKHR|MakePointerAvailableKHR|MakePointerVisibleKHR %7 %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemoryAvailableMissingNonPrivatePointer) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable 
%3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemory %4 %6 MakePointerAvailableKHR %7 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR must be specified if " + "MakePointerAvailableKHR is specified.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemoryVisibleMissingNonPrivatePointer) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemory %4 %6 MakePointerVisibleKHR %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR must be specified if " + "MakePointerVisibleKHR is specified.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemoryAvailableBadStorageClass) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Output %2 +%4 = OpVariable %3 Output +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemory %4 %6 
NonPrivatePointerKHR +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR requires a pointer in Uniform, " + "Workgroup, CrossWorkgroup, Generic, Image or " + "StorageBuffer storage classes.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemoryVisibleBadStorageClass) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Input %2 +%6 = OpVariable %5 Input +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemory %4 %6 NonPrivatePointerKHR +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR requires a pointer in Uniform, " + "Workgroup, CrossWorkgroup, Generic, Image or " + "StorageBuffer storage classes.")); +} + +TEST_F(ValidateIdWithMessage, VulkanMemoryModelCopyMemorySizedAvailable) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemorySized %4 %6 %7 NonPrivatePointerKHR|MakePointerAvailableKHR %7 +OpReturn +OpFunctionEnd +)"; + + 
CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, VulkanMemoryModelCopyMemorySizedVisible) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemorySized %4 %6 %7 NonPrivatePointerKHR|MakePointerVisibleKHR %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemorySizedAvailableAndVisible) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemorySized %4 %6 %7 NonPrivatePointerKHR|MakePointerAvailableKHR|MakePointerVisibleKHR %7 %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemorySizedAvailableMissingNonPrivatePointer) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpCapability VulkanMemoryModelKHR +OpExtension 
"SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemorySized %4 %6 %7 MakePointerAvailableKHR %7 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR must be specified if " + "MakePointerAvailableKHR is specified.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemorySizedVisibleMissingNonPrivatePointer) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemorySized %4 %6 %7 MakePointerVisibleKHR %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR must be specified if " + "MakePointerVisibleKHR is specified.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemorySizedAvailableBadStorageClass) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Output %2 
+%4 = OpVariable %3 Output +%5 = OpTypePointer Uniform %2 +%6 = OpVariable %5 Uniform +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemorySized %4 %6 %7 NonPrivatePointerKHR +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR requires a pointer in Uniform, " + "Workgroup, CrossWorkgroup, Generic, Image or " + "StorageBuffer storage classes.")); +} + +TEST_F(ValidateIdWithMessage, + VulkanMemoryModelCopyMemorySizedVisibleBadStorageClass) { + std::string spirv = R"( +OpCapability Shader +OpCapability Linkage +OpCapability Addresses +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%1 = OpTypeVoid +%2 = OpTypeInt 32 0 +%3 = OpTypePointer Workgroup %2 +%4 = OpVariable %3 Workgroup +%5 = OpTypePointer Input %2 +%6 = OpVariable %5 Input +%7 = OpConstant %2 2 +%8 = OpConstant %2 5 +%9 = OpTypeFunction %1 +%10 = OpFunction %1 None %9 +%11 = OpLabel +OpCopyMemorySized %4 %6 %7 NonPrivatePointerKHR +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("NonPrivatePointerKHR requires a pointer in Uniform, " + "Workgroup, CrossWorkgroup, Generic, Image or " + "StorageBuffer storage classes.")); +} + +TEST_F(ValidateIdWithMessage, IdDefInUnreachableBlock1) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeFloat 32 +%4 = OpTypeFunction %3 +%5 = OpFunction %1 None %2 +%6 = OpLabel +%7 = OpFunctionCall %3 %8 +OpUnreachable +OpFunctionEnd +%8 = OpFunction %3 None %4 +%9 = OpLabel +OpReturnValue %7 +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, 
SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID 7[%7] defined in block 6[%6] does not dominate its " + "use in block 9[%9]\n %9 = OpLabel")); +} + +TEST_F(ValidateIdWithMessage, IdDefInUnreachableBlock2) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeFloat 32 +%4 = OpTypeFunction %3 +%5 = OpFunction %1 None %2 +%6 = OpLabel +OpReturn +%7 = OpLabel +%8 = OpFunctionCall %3 %9 +OpUnreachable +OpFunctionEnd +%9 = OpFunction %3 None %4 +%10 = OpLabel +OpReturnValue %8 +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID 8[%8] defined in block 7[%7] does not dominate its " + "use in block 10[%10]\n %10 = OpLabel")); +} + +TEST_F(ValidateIdWithMessage, IdDefInUnreachableBlock3) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeFloat 32 +%4 = OpTypeFunction %3 +%5 = OpFunction %1 None %2 +%6 = OpLabel +OpReturn +%7 = OpLabel +%8 = OpFunctionCall %3 %9 +OpReturn +OpFunctionEnd +%9 = OpFunction %3 None %4 +%10 = OpLabel +OpReturnValue %8 +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID 8[%8] defined in block 7[%7] does not dominate its " + "use in block 10[%10]\n %10 = OpLabel")); +} + +TEST_F(ValidateIdWithMessage, IdDefInUnreachableBlock4) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeFloat 32 +%4 = OpTypeFunction %3 +%5 = OpFunction %1 None %2 +%6 = OpLabel +OpReturn +%7 = OpLabel +%8 = OpUndef %3 +%9 = OpCopyObject %3 %8 +OpReturn +OpFunctionEnd +)"; + 
+ CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, IdDefInUnreachableBlock5) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeFloat 32 +%4 = OpTypeFunction %3 +%5 = OpFunction %1 None %2 +%6 = OpLabel +OpReturn +%7 = OpLabel +%8 = OpUndef %3 +OpBranch %9 +%9 = OpLabel +%10 = OpCopyObject %3 %8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, IdDefInUnreachableBlock6) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeFloat 32 +%4 = OpTypeFunction %3 +%5 = OpFunction %1 None %2 +%6 = OpLabel +OpBranch %7 +%8 = OpLabel +%9 = OpUndef %3 +OpBranch %7 +%7 = OpLabel +%10 = OpCopyObject %3 %9 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("ID 9[%9] defined in block 8[%8] does not dominate its " + "use in block 7[%7]\n %7 = OpLabel")); +} + +TEST_F(ValidateIdWithMessage, ReachableDefUnreachableUse) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeFloat 32 +%4 = OpTypeFunction %3 +%5 = OpFunction %1 None %2 +%6 = OpLabel +%7 = OpUndef %3 +OpReturn +%8 = OpLabel +%9 = OpCopyObject %3 %7 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateIdWithMessage, UnreachableDefUsedInPhi) { + const std::string spirv = kNoKernelGLSL450MemoryModel + R"( + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %bool = OpTypeBool + %6 = 
OpTypeFunction %float + %1 = OpFunction %void None %3 + %7 = OpLabel + %8 = OpUndef %bool + OpSelectionMerge %9 None + OpBranchConditional %8 %10 %9 + %10 = OpLabel + %11 = OpUndef %float + OpBranch %9 + %12 = OpLabel + %13 = OpUndef %float + OpUnreachable + %9 = OpLabel + %14 = OpPhi %float %11 %10 %13 %7 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("In OpPhi instruction 14[%14], ID 13[%13] definition does not " + "dominate its parent 7[%7]\n %14 = OpPhi %float %11 %10 %13 " + "%7")); +} + +TEST_F(ValidateIdWithMessage, OpTypeForwardPointerNotAPointerType) { + std::string spirv = R"( + OpCapability GenericPointer + OpCapability VariablePointersStorageBuffer + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginLowerLeft + OpTypeForwardPointer %2 CrossWorkgroup +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 DontInline %3 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Pointer type in OpTypeForwardPointer is not a pointer " + "type.\n OpTypeForwardPointer %void CrossWorkgroup")); +} + +TEST_F(ValidateIdWithMessage, OpTypeForwardPointerWrongStorageClass) { + std::string spirv = R"( + OpCapability GenericPointer + OpCapability VariablePointersStorageBuffer + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginLowerLeft + OpTypeForwardPointer %2 CrossWorkgroup +%int = OpTypeInt 32 1 +%2 = OpTypePointer Function %int +%void = OpTypeVoid +%3 = OpTypeFunction %void +%1 = OpFunction %void None %3 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Storage class in OpTypeForwardPointer does not match the " + "pointer definition.\n OpTypeForwardPointer " + "%_ptr_Function_int CrossWorkgroup")); +} +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_image_test.cpp b/test/val/val_image_test.cpp new file mode 100644 index 000000000..79aecb25b --- /dev/null +++ b/test/val/val_image_test.cpp @@ -0,0 +1,4457 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for image type and image instruction validation rules. 
+ +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateImage = spvtest::ValidateBase; + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& execution_model = "Fragment", + const spv_target_env env = SPV_ENV_UNIVERSAL_1_0, + const std::string& memory_model = "GLSL450") { + std::ostringstream ss; + ss << R"( +OpCapability Shader +OpCapability InputAttachment +OpCapability ImageGatherExtended +OpCapability MinLod +OpCapability Sampled1D +OpCapability ImageQuery +OpCapability Int64 +OpCapability Float64 +OpCapability SparseResidency +OpCapability ImageBuffer +)"; + + if (env == SPV_ENV_UNIVERSAL_1_0) { + ss << "OpCapability SampledRect\n"; + } + + ss << capabilities_and_extensions; + ss << "OpMemoryModel Logical " << memory_model << "\n"; + ss << "OpEntryPoint " << execution_model << " %main \"main\"\n"; + if (execution_model == "Fragment") { + ss << "OpExecutionMode %main OriginUpperLeft\n"; + } + + if (env == SPV_ENV_VULKAN_1_0) { + ss << R"( +OpDecorate %uniform_image_f32_1d_0001 DescriptorSet 0 +OpDecorate %uniform_image_f32_1d_0001 Binding 0 +OpDecorate %uniform_image_f32_1d_0002_rgba32f DescriptorSet 0 +OpDecorate %uniform_image_f32_1d_0002_rgba32f Binding 1 +OpDecorate %uniform_image_f32_2d_0001 DescriptorSet 0 +OpDecorate %uniform_image_f32_2d_0001 Binding 2 +OpDecorate %uniform_image_f32_2d_0010 DescriptorSet 0 +OpDecorate %uniform_image_f32_2d_0010 Binding 3 +OpDecorate %uniform_image_u32_2d_0001 DescriptorSet 1 +OpDecorate %uniform_image_u32_2d_0001 Binding 0 +OpDecorate %uniform_image_u32_2d_0000 DescriptorSet 1 +OpDecorate %uniform_image_u32_2d_0000 Binding 1 +OpDecorate %uniform_image_s32_3d_0001 DescriptorSet 1 +OpDecorate %uniform_image_s32_3d_0001 Binding 2 +OpDecorate 
%uniform_image_f32_2d_0002 DescriptorSet 1 +OpDecorate %uniform_image_f32_2d_0002 Binding 3 +OpDecorate %uniform_image_f32_spd_0002 DescriptorSet 2 +OpDecorate %uniform_image_f32_spd_0002 Binding 0 +OpDecorate %uniform_image_f32_3d_0111 DescriptorSet 2 +OpDecorate %uniform_image_f32_3d_0111 Binding 1 +OpDecorate %uniform_image_f32_cube_0101 DescriptorSet 2 +OpDecorate %uniform_image_f32_cube_0101 Binding 2 +OpDecorate %uniform_image_f32_cube_0102_rgba32f DescriptorSet 2 +OpDecorate %uniform_image_f32_cube_0102_rgba32f Binding 3 +OpDecorate %uniform_sampler DescriptorSet 3 +OpDecorate %uniform_sampler Binding 0 +)"; + } + + ss << R"( +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%f64 = OpTypeFloat 64 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%u64 = OpTypeInt 64 0 +%s32vec2 = OpTypeVector %s32 2 +%u32vec2 = OpTypeVector %u32 2 +%f32vec2 = OpTypeVector %f32 2 +%u32vec3 = OpTypeVector %u32 3 +%s32vec3 = OpTypeVector %s32 3 +%f32vec3 = OpTypeVector %f32 3 +%u32vec4 = OpTypeVector %u32 4 +%s32vec4 = OpTypeVector %s32 4 +%f32vec4 = OpTypeVector %f32 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_0_5 = OpConstant %f32 0.5 +%f32_0_25 = OpConstant %f32 0.25 +%f32_0_75 = OpConstant %f32 0.75 + +%f64_0 = OpConstant %f64 0 +%f64_1 = OpConstant %f64 1 + +%s32_0 = OpConstant %s32 0 +%s32_1 = OpConstant %s32 1 +%s32_2 = OpConstant %s32 2 +%s32_3 = OpConstant %s32 3 +%s32_4 = OpConstant %s32 4 +%s32_m1 = OpConstant %s32 -1 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%u64_0 = OpConstant %u64 0 + +%u32vec2arr4 = OpTypeArray %u32vec2 %u32_4 +%u32vec2arr3 = OpTypeArray %u32vec2 %u32_3 +%u32arr4 = OpTypeArray %u32 %u32_4 +%u32vec3arr4 = OpTypeArray %u32vec3 %u32_4 + +%struct_u32_f32vec4 = OpTypeStruct %u32 %f32vec4 +%struct_u64_f32vec4 = OpTypeStruct %u64 %f32vec4 +%struct_u32_u32vec4 = OpTypeStruct %u32 %u32vec4 
+%struct_u32_f32vec3 = OpTypeStruct %u32 %f32vec3 +%struct_f32_f32vec4 = OpTypeStruct %f32 %f32vec4 +%struct_u32_u32 = OpTypeStruct %u32 %u32 +%struct_f32_f32 = OpTypeStruct %f32 %f32 +%struct_u32 = OpTypeStruct %u32 +%struct_u32_f32_u32 = OpTypeStruct %u32 %f32 %u32 +%struct_u32_f32vec4_u32 = OpTypeStruct %u32 %f32vec4 %u32 +%struct_u32_u32arr4 = OpTypeStruct %u32 %u32arr4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + +%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1 +%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2 +%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2 +%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3 +%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3 +%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4 + +%f32vec2_00 = OpConstantComposite %f32vec2 %f32_0 %f32_0 +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_10 = OpConstantComposite %f32vec2 %f32_1 %f32_0 +%f32vec2_11 = OpConstantComposite %f32vec2 %f32_1 %f32_1 +%f32vec2_hh = OpConstantComposite %f32vec2 %f32_0_5 %f32_0_5 + +%f32vec3_000 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0 +%f32vec3_hhh = OpConstantComposite %f32vec3 %f32_0_5 %f32_0_5 %f32_0_5 + +%f32vec4_0000 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0 + +%const_offsets = OpConstantComposite %u32vec2arr4 %u32vec2_01 %u32vec2_12 %u32vec2_01 %u32vec2_12 +%const_offsets3x2 = OpConstantComposite %u32vec2arr3 %u32vec2_01 %u32vec2_12 %u32vec2_01 +%const_offsets4xu = OpConstantComposite %u32arr4 %u32_0 %u32_0 %u32_0 %u32_0 +%const_offsets4x3 = OpConstantComposite 
%u32vec3arr4 %u32vec3_012 %u32vec3_012 %u32vec3_012 %u32vec3_012 + +%type_image_f32_1d_0001 = OpTypeImage %f32 1D 0 0 0 1 Unknown +%ptr_image_f32_1d_0001 = OpTypePointer UniformConstant %type_image_f32_1d_0001 +%uniform_image_f32_1d_0001 = OpVariable %ptr_image_f32_1d_0001 UniformConstant +%type_sampled_image_f32_1d_0001 = OpTypeSampledImage %type_image_f32_1d_0001 + +%type_image_f32_1d_0002_rgba32f = OpTypeImage %f32 1D 0 0 0 2 Rgba32f +%ptr_image_f32_1d_0002_rgba32f = OpTypePointer UniformConstant %type_image_f32_1d_0002_rgba32f +%uniform_image_f32_1d_0002_rgba32f = OpVariable %ptr_image_f32_1d_0002_rgba32f UniformConstant +%type_sampled_image_f32_1d_0002_rgba32f = OpTypeSampledImage %type_image_f32_1d_0002_rgba32f + +%type_image_f32_2d_0001 = OpTypeImage %f32 2D 0 0 0 1 Unknown +%ptr_image_f32_2d_0001 = OpTypePointer UniformConstant %type_image_f32_2d_0001 +%uniform_image_f32_2d_0001 = OpVariable %ptr_image_f32_2d_0001 UniformConstant +%type_sampled_image_f32_2d_0001 = OpTypeSampledImage %type_image_f32_2d_0001 + +%type_image_f32_2d_0010 = OpTypeImage %f32 2D 0 0 1 0 Unknown +%ptr_image_f32_2d_0010 = OpTypePointer UniformConstant %type_image_f32_2d_0010 +%uniform_image_f32_2d_0010 = OpVariable %ptr_image_f32_2d_0010 UniformConstant +%type_sampled_image_f32_2d_0010 = OpTypeSampledImage %type_image_f32_2d_0010 + +%type_image_u32_2d_0001 = OpTypeImage %u32 2D 0 0 0 1 Unknown +%ptr_image_u32_2d_0001 = OpTypePointer UniformConstant %type_image_u32_2d_0001 +%uniform_image_u32_2d_0001 = OpVariable %ptr_image_u32_2d_0001 UniformConstant +%type_sampled_image_u32_2d_0001 = OpTypeSampledImage %type_image_u32_2d_0001 + +%type_image_u32_2d_0000 = OpTypeImage %u32 2D 0 0 0 0 Unknown +%ptr_image_u32_2d_0000 = OpTypePointer UniformConstant %type_image_u32_2d_0000 +%uniform_image_u32_2d_0000 = OpVariable %ptr_image_u32_2d_0000 UniformConstant +%type_sampled_image_u32_2d_0000 = OpTypeSampledImage %type_image_u32_2d_0000 + +%type_image_s32_3d_0001 = OpTypeImage %s32 3D 0 0 0 1 
Unknown +%ptr_image_s32_3d_0001 = OpTypePointer UniformConstant %type_image_s32_3d_0001 +%uniform_image_s32_3d_0001 = OpVariable %ptr_image_s32_3d_0001 UniformConstant +%type_sampled_image_s32_3d_0001 = OpTypeSampledImage %type_image_s32_3d_0001 + +%type_image_f32_2d_0002 = OpTypeImage %f32 2D 0 0 0 2 Unknown +%ptr_image_f32_2d_0002 = OpTypePointer UniformConstant %type_image_f32_2d_0002 +%uniform_image_f32_2d_0002 = OpVariable %ptr_image_f32_2d_0002 UniformConstant +%type_sampled_image_f32_2d_0002 = OpTypeSampledImage %type_image_f32_2d_0002 + +%type_image_f32_spd_0002 = OpTypeImage %f32 SubpassData 0 0 0 2 Unknown +%ptr_image_f32_spd_0002 = OpTypePointer UniformConstant %type_image_f32_spd_0002 +%uniform_image_f32_spd_0002 = OpVariable %ptr_image_f32_spd_0002 UniformConstant +%type_sampled_image_f32_spd_0002 = OpTypeSampledImage %type_image_f32_spd_0002 + +%type_image_f32_3d_0111 = OpTypeImage %f32 3D 0 1 1 1 Unknown +%ptr_image_f32_3d_0111 = OpTypePointer UniformConstant %type_image_f32_3d_0111 +%uniform_image_f32_3d_0111 = OpVariable %ptr_image_f32_3d_0111 UniformConstant +%type_sampled_image_f32_3d_0111 = OpTypeSampledImage %type_image_f32_3d_0111 + +%type_image_f32_cube_0101 = OpTypeImage %f32 Cube 0 1 0 1 Unknown +%ptr_image_f32_cube_0101 = OpTypePointer UniformConstant %type_image_f32_cube_0101 +%uniform_image_f32_cube_0101 = OpVariable %ptr_image_f32_cube_0101 UniformConstant +%type_sampled_image_f32_cube_0101 = OpTypeSampledImage %type_image_f32_cube_0101 + +%type_image_f32_cube_0102_rgba32f = OpTypeImage %f32 Cube 0 1 0 2 Rgba32f +%ptr_image_f32_cube_0102_rgba32f = OpTypePointer UniformConstant %type_image_f32_cube_0102_rgba32f +%uniform_image_f32_cube_0102_rgba32f = OpVariable %ptr_image_f32_cube_0102_rgba32f UniformConstant +%type_sampled_image_f32_cube_0102_rgba32f = OpTypeSampledImage %type_image_f32_cube_0102_rgba32f + +%type_sampler = OpTypeSampler +%ptr_sampler = OpTypePointer UniformConstant %type_sampler +%uniform_sampler = OpVariable 
%ptr_sampler UniformConstant + +%type_image_u32_buffer_0002_r32ui = OpTypeImage %u32 Buffer 0 0 0 2 R32ui +%ptr_Image_u32 = OpTypePointer Image %u32 +%ptr_image_u32_buffer_0002_r32ui = OpTypePointer Private %type_image_u32_buffer_0002_r32ui +%private_image_u32_buffer_0002_r32ui = OpVariable %ptr_image_u32_buffer_0002_r32ui Private + +%ptr_Image_u32arr4 = OpTypePointer Image %u32arr4 + +%type_image_u32_spd_0002 = OpTypeImage %u32 SubpassData 0 0 0 2 Unknown +%ptr_image_u32_spd_0002 = OpTypePointer Private %type_image_u32_spd_0002 +%private_image_u32_spd_0002 = OpVariable %ptr_image_u32_spd_0002 Private + +%type_image_f32_buffer_0002_r32ui = OpTypeImage %f32 Buffer 0 0 0 2 R32ui +%ptr_Image_f32 = OpTypePointer Image %f32 +%ptr_image_f32_buffer_0002_r32ui = OpTypePointer Private %type_image_f32_buffer_0002_r32ui +%private_image_f32_buffer_0002_r32ui = OpVariable %ptr_image_f32_buffer_0002_r32ui Private +)"; + + if (env == SPV_ENV_UNIVERSAL_1_0) { + ss << R"( +%type_image_void_2d_0001 = OpTypeImage %void 2D 0 0 0 1 Unknown +%ptr_image_void_2d_0001 = OpTypePointer UniformConstant %type_image_void_2d_0001 +%uniform_image_void_2d_0001 = OpVariable %ptr_image_void_2d_0001 UniformConstant +%type_sampled_image_void_2d_0001 = OpTypeSampledImage %type_image_void_2d_0001 + +%type_image_void_2d_0002 = OpTypeImage %void 2D 0 0 0 2 Unknown +%ptr_image_void_2d_0002 = OpTypePointer UniformConstant %type_image_void_2d_0002 +%uniform_image_void_2d_0002 = OpVariable %ptr_image_void_2d_0002 UniformConstant +%type_sampled_image_void_2d_0002 = OpTypeSampledImage %type_image_void_2d_0002 + +%type_image_f32_rect_0001 = OpTypeImage %f32 Rect 0 0 0 1 Unknown +%ptr_image_f32_rect_0001 = OpTypePointer UniformConstant %type_image_f32_rect_0001 +%uniform_image_f32_rect_0001 = OpVariable %ptr_image_f32_rect_0001 UniformConstant +%type_sampled_image_f32_rect_0001 = OpTypeSampledImage %type_image_f32_rect_0001 +)"; + } + + ss << R"( +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + 
+ ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +std::string GenerateKernelCode( + const std::string& body, + const std::string& capabilities_and_extensions = "") { + std::ostringstream ss; + ss << R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpCapability ImageQuery +OpCapability ImageGatherExtended +OpCapability InputAttachment +OpCapability SampledRect +)"; + + ss << capabilities_and_extensions; + ss << R"( +OpMemoryModel Physical32 OpenCL +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%u32vec2 = OpTypeVector %u32 2 +%f32vec2 = OpTypeVector %f32 2 +%u32vec3 = OpTypeVector %u32 3 +%f32vec3 = OpTypeVector %f32 3 +%u32vec4 = OpTypeVector %u32 4 +%f32vec4 = OpTypeVector %f32 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_0_5 = OpConstant %f32 0.5 +%f32_0_25 = OpConstant %f32 0.25 +%f32_0_75 = OpConstant %f32 0.75 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + +%f32vec2_00 = OpConstantComposite %f32vec2 %f32_0 %f32_0 +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_10 = OpConstantComposite %f32vec2 %f32_1 %f32_0 +%f32vec2_11 = OpConstantComposite %f32vec2 %f32_1 %f32_1 +%f32vec2_hh = OpConstantComposite %f32vec2 %f32_0_5 %f32_0_5 + +%f32vec3_000 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0 +%f32vec3_hhh = OpConstantComposite %f32vec3 %f32_0_5 %f32_0_5 %f32_0_5 + +%f32vec4_0000 = OpConstantComposite %f32vec4 
%f32_0 %f32_0 %f32_0 %f32_0 + +%type_image_f32_2d_0001 = OpTypeImage %f32 2D 0 0 0 1 Unknown +%ptr_image_f32_2d_0001 = OpTypePointer UniformConstant %type_image_f32_2d_0001 +%uniform_image_f32_2d_0001 = OpVariable %ptr_image_f32_2d_0001 UniformConstant +%type_sampled_image_f32_2d_0001 = OpTypeSampledImage %type_image_f32_2d_0001 + +%type_image_f32_2d_0010 = OpTypeImage %f32 2D 0 0 1 0 Unknown +%ptr_image_f32_2d_0010 = OpTypePointer UniformConstant %type_image_f32_2d_0010 +%uniform_image_f32_2d_0010 = OpVariable %ptr_image_f32_2d_0010 UniformConstant +%type_sampled_image_f32_2d_0010 = OpTypeSampledImage %type_image_f32_2d_0010 + +%type_image_f32_3d_0010 = OpTypeImage %f32 3D 0 0 1 0 Unknown +%ptr_image_f32_3d_0010 = OpTypePointer UniformConstant %type_image_f32_3d_0010 +%uniform_image_f32_3d_0010 = OpVariable %ptr_image_f32_3d_0010 UniformConstant +%type_sampled_image_f32_3d_0010 = OpTypeSampledImage %type_image_f32_3d_0010 + +%type_image_f32_rect_0001 = OpTypeImage %f32 Rect 0 0 0 1 Unknown +%ptr_image_f32_rect_0001 = OpTypePointer UniformConstant %type_image_f32_rect_0001 +%uniform_image_f32_rect_0001 = OpVariable %ptr_image_f32_rect_0001 UniformConstant +%type_sampled_image_f32_rect_0001 = OpTypeSampledImage %type_image_f32_rect_0001 + +%type_sampler = OpTypeSampler +%ptr_sampler = OpTypePointer UniformConstant %type_sampler +%uniform_sampler = OpVariable %ptr_sampler UniformConstant + +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + + ss << body; + ss << R"( +OpReturn +OpFunctionEnd)"; + + return ss.str(); +} + +std::string GetShaderHeader(const std::string& capabilities_and_extensions = "", + bool include_entry_point = true) { + std::ostringstream ss; + ss << R"( +OpCapability Shader +OpCapability Int64 +)"; + + ss << capabilities_and_extensions; + if (!include_entry_point) { + ss << "OpCapability Linkage"; + } + + ss << R"( +OpMemoryModel Logical GLSL450 +)"; + + if (include_entry_point) { + ss << "OpEntryPoint Fragment %main \"main\"\n"; + 
ss << "OpExecutionMode %main OriginUpperLeft"; + } + ss << R"( +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%u64 = OpTypeInt 64 0 +%s32 = OpTypeInt 32 1 +)"; + + return ss.str(); +} + +TEST_F(ValidateImage, TypeImageWrongSampledType) { + const std::string code = GetShaderHeader("", false) + R"( +%img_type = OpTypeImage %bool 2D 0 0 0 1 Unknown +)"; + + CompileSuccessfully(code.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Sampled Type to be either void or " + "numerical scalar " + "type")); +} + +TEST_F(ValidateImage, TypeImageVoidSampledTypeVulkan) { + const std::string code = GetShaderHeader() + R"( +%img_type = OpTypeImage %void 2D 0 0 0 1 Unknown +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +OpReturn +OpFunctionEnd +)"; + + const spv_target_env env = SPV_ENV_VULKAN_1_0; + CompileSuccessfully(code, env); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Sampled Type to be a 32-bit int " + "or float scalar type for Vulkan environment")); +} + +TEST_F(ValidateImage, TypeImageU64SampledTypeVulkan) { + const std::string code = GetShaderHeader() + R"( +%img_type = OpTypeImage %u64 2D 0 0 0 1 Unknown +%void_func = OpTypeFunction %void +%main = OpFunction %void None %void_func +%main_lab = OpLabel +OpReturn +OpFunctionEnd +)"; + + const spv_target_env env = SPV_ENV_VULKAN_1_0; + CompileSuccessfully(code, env); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Sampled Type to be a 32-bit int " + "or float scalar type for Vulkan environment")); +} + +TEST_F(ValidateImage, TypeImageWrongDepth) { + const std::string code = GetShaderHeader("", false) + R"( +%img_type = OpTypeImage %f32 2D 3 0 0 1 Unknown +)"; + + 
CompileSuccessfully(code.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid Depth 3 (must be 0, 1 or 2)")); +} + +TEST_F(ValidateImage, TypeImageWrongArrayed) { + const std::string code = GetShaderHeader("", false) + R"( +%img_type = OpTypeImage %f32 2D 0 2 0 1 Unknown +)"; + + CompileSuccessfully(code.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid Arrayed 2 (must be 0 or 1)")); +} + +TEST_F(ValidateImage, TypeImageWrongMS) { + const std::string code = GetShaderHeader("", false) + R"( +%img_type = OpTypeImage %f32 2D 0 0 2 1 Unknown +)"; + + CompileSuccessfully(code.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid MS 2 (must be 0 or 1)")); +} + +TEST_F(ValidateImage, TypeImageWrongSampled) { + const std::string code = GetShaderHeader("", false) + R"( +%img_type = OpTypeImage %f32 2D 0 0 0 3 Unknown +)"; + + CompileSuccessfully(code.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid Sampled 3 (must be 0, 1 or 2)")); +} + +TEST_F(ValidateImage, TypeImageWrongSampledForSubpassData) { + const std::string code = + GetShaderHeader("OpCapability InputAttachment\n", false) + + R"( +%img_type = OpTypeImage %f32 SubpassData 0 0 0 1 Unknown +)"; + + CompileSuccessfully(code.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Dim SubpassData requires Sampled to be 2")); +} + +TEST_F(ValidateImage, TypeImageWrongFormatForSubpassData) { + const std::string code = + GetShaderHeader("OpCapability InputAttachment\n", false) + + R"( +%img_type = OpTypeImage %f32 SubpassData 0 0 0 2 Rgba32f +)"; + + CompileSuccessfully(code.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Dim SubpassData requires format Unknown")); +} + +TEST_F(ValidateImage, TypeSampledImageNotImage) { + const std::string code = GetShaderHeader("", false) + R"( +%simg_type = OpTypeSampledImage %f32 +)"; + + CompileSuccessfully(code.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, SampledImageSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampledImageVulkanSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +)"; + + const spv_target_env env = SPV_ENV_VULKAN_1_0; + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment", env), env); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(env)); +} + +TEST_F(ValidateImage, SampledImageWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_image_f32_2d_0001 %img %sampler +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampledImageNotImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg1 = 
OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%simg2 = OpSampledImage %type_sampled_image_f32_2d_0001 %simg1 %sampler +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, SampledImageImageNotForSampling) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0002 %img %sampler +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to be 0 or 1")); +} + +TEST_F(ValidateImage, SampledImageVulkanUnknownSampled) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_u32_2d_0000 %img %sampler +)"; + + const spv_target_env env = SPV_ENV_VULKAN_1_0; + CompileSuccessfully(GenerateShaderCode(body, "", "Fragment", env), env); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to " + "be 1 for Vulkan environment.")); +} + +TEST_F(ValidateImage, SampledImageNotSampler) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %img +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Sampler to be of type OpTypeSampler")); +} + +TEST_F(ValidateImage, 
ImageTexelPointerSuccess) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32 %private_image_u32_buffer_0002_r32ui %u32_0 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, ImageTexelPointerResultTypeNotPointer) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %type_image_u32_buffer_0002_r32ui %private_image_u32_buffer_0002_r32ui %u32_0 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypePointer")); +} + +TEST_F(ValidateImage, ImageTexelPointerResultTypeNotImageClass) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_image_f32_cube_0101 %private_image_u32_buffer_0002_r32ui %u32_0 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypePointer whose " + "Storage Class operand is Image")); +} + +TEST_F(ValidateImage, ImageTexelPointerResultTypeNotNumericNorVoid) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32arr4 %private_image_u32_buffer_0002_r32ui %u32_0 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypePointer whose Type operand " + "must be a scalar numerical type or OpTypeVoid")); +} + +TEST_F(ValidateImage, ImageTexelPointerImageNotResultTypePointer) { + 
const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32 %type_image_f32_buffer_0002_r32ui %u32_0 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 136[%136] cannot be a " + "type")); +} + +TEST_F(ValidateImage, ImageTexelPointerImageNotImage) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32 %uniform_sampler %u32_0 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image to be OpTypePointer with Type OpTypeImage")); +} + +TEST_F(ValidateImage, ImageTexelPointerImageSampledNotResultType) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32 %uniform_image_f32_cube_0101 %u32_0 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as the " + "Type pointed to by Result Type")); +} + +TEST_F(ValidateImage, ImageTexelPointerImageDimSubpassDataBad) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32 %private_image_u32_spd_0002 %u32_0 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Image Dim SubpassData cannot be used with OpImageTexelPointer")); +} + +TEST_F(ValidateImage, ImageTexelPointerImageCoordTypeBad) { + const std::string body = R"( +%texel_ptr = 
OpImageTexelPointer %ptr_Image_f32 %private_image_f32_buffer_0002_r32ui %f32_0 %f32_0 +%sum = OpAtomicIAdd %f32 %texel_ptr %f32_1 %f32_0 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be integer scalar or vector")); +} + +TEST_F(ValidateImage, ImageTexelPointerImageCoordSizeBad) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32 %uniform_image_u32_2d_0000 %u32vec3_012 %u32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Coordinate to have 2 components, but given 3")); +} + +TEST_F(ValidateImage, ImageTexelPointerSampleNotIntScalar) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32 %private_image_u32_buffer_0002_r32ui %u32_0 %f32_0 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Sample to be integer scalar")); +} + +TEST_F(ValidateImage, ImageTexelPointerSampleNotZeroForImageWithMSZero) { + const std::string body = R"( +%texel_ptr = OpImageTexelPointer %ptr_Image_u32 %private_image_u32_buffer_0002_r32ui %u32_0 %u32_1 +%sum = OpAtomicIAdd %u32 %texel_ptr %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Sample for Image with MS 0 to be a valid " + "<id> for the value 0")); +} + +TEST_F(ValidateImage, SampleImplicitLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001
%uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh Bias %f32_0_25 +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh ConstOffset %s32vec2_01 +%res5 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh Offset %s32vec2_01 +%res6 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh MinLod %f32_0_5 +%res7 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh Bias|Offset|MinLod %f32_0_25 %s32vec2_01 %f32_0_5 +%res8 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SampleImplicitLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %f32 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float vector type")); +} + +TEST_F(ValidateImage, SampleImplicitLodWrongNumComponentsResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec3 %simg %f32vec2_hh +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); +} + +TEST_F(ValidateImage, SampleImplicitLodNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageSampleImplicitLod %f32vec4 %img %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampleImplicitLodWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %u32vec4 %simg %f32vec2_00 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type components")); +} + +TEST_F(ValidateImage, SampleImplicitLodVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %u32vec4 %simg %f32vec2_00 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleImplicitLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 
%img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %img +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, SampleImplicitLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 2 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, SampleExplicitLodSuccessShader) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec4_0000 Lod %f32_1 +%res2 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_hh Grad %f32vec2_10 %f32vec2_01 +%res3 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_hh ConstOffset %s32vec2_01 +%res4 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec3_hhh Offset %s32vec2_01 +%res5 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_hh Grad|Offset|MinLod %f32vec2_10 %f32vec2_01 %s32vec2_01 %f32_0_5 +%res6 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec4_0000 Lod|NonPrivateTexelKHR %f32_1 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, 
ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SampleExplicitLodSuccessKernel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %u32vec4_0123 Lod %f32_1 +%res2 = OpImageSampleExplicitLod %f32vec4 %simg %u32vec2_01 Grad %f32vec2_10 %f32vec2_01 +%res3 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_hh ConstOffset %u32vec2_01 +%res4 = OpImageSampleExplicitLod %f32vec4 %simg %u32vec2_01 Offset %u32vec2_01 +%res5 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_hh Grad|Offset %f32vec2_10 %f32vec2_01 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleExplicitLodSuccessCubeArrayed) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec4_0000 Grad %f32vec3_hhh %f32vec3_hhh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleExplicitLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32 %simg %f32vec2_hh Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float vector type")); +} + +TEST_F(ValidateImage, 
SampleExplicitLodWrongNumComponentsResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec3 %simg %f32vec2_hh Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); +} + +TEST_F(ValidateImage, SampleExplicitLodNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageSampleExplicitLod %f32vec4 %img %f32vec2_hh Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampleExplicitLodWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %u32vec4 %simg %f32vec2_00 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type components")); +} + +TEST_F(ValidateImage, SampleExplicitLodVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %u32vec4 %simg %f32vec2_00 Lod %f32_1 +)"; 
+ + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleExplicitLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %img Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, SampleExplicitLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32_0_5 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 2 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, SampleExplicitLodBias) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_00 Bias|Lod %f32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Image Operand Bias can only be used with ImplicitLod opcodes")); +} + +TEST_F(ValidateImage, LodAndGrad) { + const std::string body = R"( +%img = OpLoad 
%type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_00 Lod|Grad %f32_1 %f32vec2_hh %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Image Operand bits Lod and Grad cannot be set at the same time")); +} + +TEST_F(ValidateImage, ImplicitLodWithLod) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh Lod %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image Operand Lod can only be used with ExplicitLod opcodes " + "and OpImageFetch")); +} + +TEST_F(ValidateImage, LodWrongType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_00 Lod %f32vec2_hh)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image Operand Lod to be float scalar when " + "used with ExplicitLod")); +} + +TEST_F(ValidateImage, LodWrongDim) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_rect_0001 %img %sampler +%res1 = 
OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_00 Lod %f32_0)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand Lod requires 'Dim' parameter to be 1D, " + "2D, 3D or Cube")); +} + +TEST_F(ValidateImage, LodMultisampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0010 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_00 Lod %f32_0)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand Lod requires 'MS' parameter to be 0")); +} + +TEST_F(ValidateImage, MinLodIncompatible) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_00 Lod|MinLod %f32_0 %f32_0)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Image Operand MinLod can only be used with ImplicitLod opcodes or " + "together with Image Operand Grad")); +} + +TEST_F(ValidateImage, ImplicitLodWithGrad) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh Grad %f32vec2_hh %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); 
+ EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Image Operand Grad can only be used with ExplicitLod opcodes")); +} + +TEST_F(ValidateImage, SampleImplicitLod3DArrayedMultisampledSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 ConstOffset %s32vec3_012 +%res3 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 Offset %s32vec3_012 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleImplicitLodCubeArrayedSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 Bias %f32_0_25 +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 MinLod %f32_0_5 +%res5 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 Bias|MinLod %f32_0_25 %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleImplicitLodBiasWrongType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh Bias %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("Expected Image Operand Bias to be float scalar")); +} + +TEST_F(ValidateImage, SampleImplicitLodBiasWrongDim) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_rect_0001 %img %sampler +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh Bias %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand Bias requires 'Dim' parameter to be 1D, " + "2D, 3D or Cube")); +} + +TEST_F(ValidateImage, SampleImplicitLodBiasMultisampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 Bias %f32_0_25 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand Bias requires 'MS' parameter to be 0")); +} + +TEST_F(ValidateImage, SampleExplicitLodGradDxWrongType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec4_0000 Grad %s32vec3_012 %f32vec3_hhh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected both Image Operand Grad ids to be float " + "scalars or vectors")); +} + +TEST_F(ValidateImage, SampleExplicitLodGradDyWrongType) { + const std::string body = R"( +%img = 
OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec4_0000 Grad %f32vec3_hhh %s32vec3_012 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected both Image Operand Grad ids to be float " + "scalars or vectors")); +} + +TEST_F(ValidateImage, SampleExplicitLodGradDxWrongSize) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec4_0000 Grad %f32vec2_00 %f32vec3_hhh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Image Operand Grad dx to have 3 components, but given 2")); +} + +TEST_F(ValidateImage, SampleExplicitLodGradDyWrongSize) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec4_0000 Grad %f32vec3_hhh %f32vec2_00 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Image Operand Grad dy to have 3 components, but given 2")); +} + +TEST_F(ValidateImage, SampleExplicitLodGradMultisampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler 
%uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec4_0000 Grad %f32vec3_000 %f32vec3_000 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand Grad requires 'MS' parameter to be 0")); +} + +TEST_F(ValidateImage, SampleImplicitLodConstOffsetCubeDim) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 ConstOffset %s32vec3_012 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Image Operand ConstOffset cannot be used with Cube Image 'Dim'")); +} + +TEST_F(ValidateImage, SampleImplicitLodConstOffsetWrongType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 ConstOffset %f32vec3_000 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Image Operand ConstOffset to be int scalar or vector")); +} + +TEST_F(ValidateImage, SampleImplicitLodConstOffsetWrongSize) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res4 = OpImageSampleImplicitLod %f32vec4 %simg 
%f32vec4_0000 ConstOffset %s32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image Operand ConstOffset to have 3 " + "components, but given 2")); +} + +TEST_F(ValidateImage, SampleImplicitLodConstOffsetNotConst) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%offset = OpSNegate %s32vec3 %s32vec3_012 +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 ConstOffset %offset +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image Operand ConstOffset to be a const object")); +} + +TEST_F(ValidateImage, SampleImplicitLodOffsetCubeDim) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 Offset %s32vec3_012 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image Operand Offset cannot be used with Cube Image 'Dim'")); +} + +TEST_F(ValidateImage, SampleImplicitLodOffsetWrongType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 Offset %f32vec3_000 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image Operand Offset to be int scalar or vector")); +} + +TEST_F(ValidateImage, SampleImplicitLodOffsetWrongSize) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 Offset %s32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Image Operand Offset to have 3 components, but given 2")); +} + +TEST_F(ValidateImage, SampleImplicitLodMoreThanOneOffset) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res4 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 ConstOffset|Offset %s32vec3_012 %s32vec3_012 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operands Offset, ConstOffset, ConstOffsets " + "cannot be used together")); +} + +TEST_F(ValidateImage, SampleImplicitLodMinLodWrongType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 MinLod %s32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image Operand 
MinLod to be float scalar")); +} + +TEST_F(ValidateImage, SampleImplicitLodMinLodWrongDim) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_rect_0001 %img %sampler +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh MinLod %f32_0_25 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand MinLod requires 'Dim' parameter to be " + "1D, 2D, 3D or Cube")); +} + +TEST_F(ValidateImage, SampleImplicitLodMinLodMultisampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_3d_0111 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec4_0000 MinLod %f32_0_25 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image Operand MinLod requires 'MS' parameter to be 0")); +} + +TEST_F(ValidateImage, SampleProjExplicitLodSuccess2D) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec3_hhh Lod %f32_1 +%res3 = OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec3_hhh Grad %f32vec2_10 %f32vec2_01 +%res4 = OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec3_hhh ConstOffset %s32vec2_01 +%res5 = OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec3_hhh Offset %s32vec2_01 +%res7 = OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec3_hhh Grad|Offset %f32vec2_10 %f32vec2_01 %s32vec2_01 +%res8 = 
OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec3_hhh Lod|NonPrivateTexelKHR %f32_1 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SampleProjExplicitLodSuccessRect) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_rect_0001 %img %sampler +%res1 = OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec3_hhh Grad %f32vec2_10 %f32vec2_01 +%res2 = OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec3_hhh Grad|Offset %f32vec2_10 %f32vec2_01 %s32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleProjExplicitLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjExplicitLod %f32 %simg %f32vec3_hhh Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float vector type")); +} + +TEST_F(ValidateImage, SampleProjExplicitLodWrongNumComponentsResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjExplicitLod %f32vec3 %simg %f32vec3_hhh Lod %f32_1 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); +} + +TEST_F(ValidateImage, SampleProjExplicitLodNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageSampleProjExplicitLod %f32vec4 %img %f32vec3_hhh Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampleProjExplicitLodWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjExplicitLod %u32vec4 %simg %f32vec3_hhh Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type components")); +} + +TEST_F(ValidateImage, SampleProjExplicitLodVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleProjExplicitLod %u32vec4 %simg %f32vec3_hhh Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleProjExplicitLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler 
%uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjExplicitLod %f32vec4 %simg %img Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, SampleProjExplicitLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjExplicitLod %f32vec4 %simg %f32vec2_hh Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 3 components, " + "but given only 2")); +} + +TEST_F(ValidateImage, SampleProjImplicitLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjImplicitLod %f32vec4 %simg %f32vec3_hhh +%res2 = OpImageSampleProjImplicitLod %f32vec4 %simg %f32vec3_hhh Bias %f32_0_25 +%res4 = OpImageSampleProjImplicitLod %f32vec4 %simg %f32vec3_hhh ConstOffset %s32vec2_01 +%res5 = OpImageSampleProjImplicitLod %f32vec4 %simg %f32vec3_hhh Offset %s32vec2_01 +%res6 = OpImageSampleProjImplicitLod %f32vec4 %simg %f32vec3_hhh MinLod %f32_0_5 +%res7 = OpImageSampleProjImplicitLod %f32vec4 %simg %f32vec3_hhh Bias|Offset|MinLod %f32_0_25 %s32vec2_01 %f32_0_5 +%res8 = OpImageSampleProjImplicitLod %f32vec4 %simg %f32vec3_hhh NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension 
"SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SampleProjImplicitLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjImplicitLod %f32 %simg %f32vec3_hhh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float vector type")); +} + +TEST_F(ValidateImage, SampleProjImplicitLodWrongNumComponentsResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjImplicitLod %f32vec3 %simg %f32vec3_hhh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); +} + +TEST_F(ValidateImage, SampleProjImplicitLodNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageSampleProjImplicitLod %f32vec4 %img %f32vec3_hhh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampleProjImplicitLodWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 
%uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjImplicitLod %u32vec4 %simg %f32vec3_hhh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type components")); +} + +TEST_F(ValidateImage, SampleProjImplicitLodVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleProjImplicitLod %u32vec4 %simg %f32vec3_hhh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleProjImplicitLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjImplicitLod %f32vec4 %simg %img +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, SampleProjImplicitLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjImplicitLod %f32vec4 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 3 components, " + "but given only 2")); +} + +TEST_F(ValidateImage, SampleDrefImplicitLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0001 %uniform_image_u32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_u32_2d_0001 %img %sampler +%res1 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_hh %f32_1 +%res2 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_hh %f32_1 Bias %f32_0_25 +%res4 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_hh %f32_1 ConstOffset %s32vec2_01 +%res5 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_hh %f32_1 Offset %s32vec2_01 +%res6 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_hh %f32_1 MinLod %f32_0_5 +%res7 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_hh %f32_1 Bias|Offset|MinLod %f32_0_25 %s32vec2_01 %f32_0_5 +%res8 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_hh %f32_1 NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SampleDrefImplicitLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleDrefImplicitLod %void %simg %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float scalar type")); +} + +TEST_F(ValidateImage, SampleDrefImplicitLodNotSampledImage) { + const 
std::string body = R"( +%img = OpLoad %type_image_u32_2d_0001 %uniform_image_u32_2d_0001 +%res1 = OpImageSampleDrefImplicitLod %u32 %img %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampleDrefImplicitLodWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0001 %uniform_image_u32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_u32_2d_0001 %img %sampler +%res1 = OpImageSampleDrefImplicitLod %f32 %simg %f32vec2_00 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); +} + +TEST_F(ValidateImage, SampleDrefImplicitLodVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_00 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); +} + +TEST_F(ValidateImage, SampleDrefImplicitLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0001 %uniform_image_u32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_u32_2d_0001 %img %sampler +%res1 = OpImageSampleDrefImplicitLod %u32 %simg %img %u32_1 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, SampleDrefImplicitLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleDrefImplicitLod %f32 %simg %f32_0_5 %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 2 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, SampleDrefImplicitLodWrongDrefType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0001 %uniform_image_u32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_u32_2d_0001 %img %sampler +%res1 = OpImageSampleDrefImplicitLod %u32 %simg %f32vec2_00 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Dref to be of 32-bit float type")); +} + +TEST_F(ValidateImage, SampleDrefExplicitLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_s32_3d_0001 %uniform_image_s32_3d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_s32_3d_0001 %img %sampler +%res1 = OpImageSampleDrefExplicitLod %s32 %simg %f32vec4_0000 %f32_1 Lod %f32_1 +%res3 = OpImageSampleDrefExplicitLod %s32 %simg %f32vec3_hhh %f32_1 Grad %f32vec3_hhh %f32vec3_hhh +%res4 = OpImageSampleDrefExplicitLod %s32 %simg %f32vec3_hhh %f32_1 ConstOffset %s32vec3_012 +%res5 = OpImageSampleDrefExplicitLod %s32 
%simg %f32vec4_0000 %f32_1 Offset %s32vec3_012 +%res7 = OpImageSampleDrefExplicitLod %s32 %simg %f32vec3_hhh %f32_1 Grad|Offset %f32vec3_hhh %f32vec3_hhh %s32vec3_012 +%res8 = OpImageSampleDrefExplicitLod %s32 %simg %f32vec4_0000 %f32_1 Lod|NonPrivateTexelKHR %f32_1 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SampleDrefExplicitLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_s32_3d_0001 %uniform_image_s32_3d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_s32_3d_0001 %img %sampler +%res1 = OpImageSampleDrefExplicitLod %bool %simg %f32vec3_hhh %s32_1 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float scalar type")); +} + +TEST_F(ValidateImage, SampleDrefExplicitLodNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_s32_3d_0001 %uniform_image_s32_3d_0001 +%res1 = OpImageSampleDrefExplicitLod %s32 %img %f32vec3_hhh %s32_1 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampleDrefExplicitLodWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_s32_3d_0001 %uniform_image_s32_3d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_s32_3d_0001 %img %sampler +%res1 = OpImageSampleDrefExplicitLod %f32 %simg 
%f32vec3_hhh %s32_1 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); +} + +TEST_F(ValidateImage, SampleDrefExplicitLodVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleDrefExplicitLod %u32 %simg %f32vec2_00 %s32_1 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); +} + +TEST_F(ValidateImage, SampleDrefExplicitLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_s32_3d_0001 %uniform_image_s32_3d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_s32_3d_0001 %img %sampler +%res1 = OpImageSampleDrefExplicitLod %s32 %simg %img %s32_1 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, SampleDrefExplicitLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_s32_3d_0001 %uniform_image_s32_3d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_s32_3d_0001 %img %sampler +%res1 = OpImageSampleDrefExplicitLod %s32 %simg %f32vec2_hh %s32_1 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 3 components, " + "but given only 2")); +} + +TEST_F(ValidateImage, SampleDrefExplicitLodWrongDrefType) { + const std::string body = R"( +%img = OpLoad %type_image_s32_3d_0001 %uniform_image_s32_3d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_s32_3d_0001 %img %sampler +%res1 = OpImageSampleDrefExplicitLod %s32 %simg %f32vec3_hhh %u32_1 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Dref to be of 32-bit float type")); +} + +TEST_F(ValidateImage, SampleProjDrefImplicitLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjDrefImplicitLod %f32 %simg %f32vec3_hhh %f32_0_5 +%res2 = OpImageSampleProjDrefImplicitLod %f32 %simg %f32vec3_hhh %f32_0_5 Bias %f32_0_25 +%res4 = OpImageSampleProjDrefImplicitLod %f32 %simg %f32vec3_hhh %f32_0_5 ConstOffset %s32vec2_01 +%res5 = OpImageSampleProjDrefImplicitLod %f32 %simg %f32vec3_hhh %f32_0_5 Offset %s32vec2_01 +%res6 = OpImageSampleProjDrefImplicitLod %f32 %simg %f32vec3_hhh %f32_0_5 MinLod %f32_0_5 +%res7 = OpImageSampleProjDrefImplicitLod %f32 %simg %f32vec3_hhh %f32_0_5 Bias|Offset|MinLod %f32_0_25 %s32vec2_01 %f32_0_5 +%res8 = OpImageSampleProjDrefImplicitLod %f32 %simg %f32vec3_hhh %f32_0_5 NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, 
SampleProjDrefImplicitLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjDrefImplicitLod %void %simg %f32vec3_hhh %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float scalar type")); +} + +TEST_F(ValidateImage, SampleProjDrefImplicitLodNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageSampleProjDrefImplicitLod %f32 %img %f32vec3_hhh %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjDrefImplicitLod %u32 %simg %f32vec3_hhh %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); +} + +TEST_F(ValidateImage, SampleProjDrefImplicitLodVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleProjDrefImplicitLod %u32 
%simg %f32vec3_hhh %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); +} + +TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjDrefImplicitLod %f32 %simg %img %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, SampleProjDrefImplicitLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleProjDrefImplicitLod %f32 %simg %f32vec2_hh %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 3 components, " + "but given only 2")); +} + +TEST_F(ValidateImage, SampleProjDrefImplicitLodWrongDrefType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0001 %uniform_image_u32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_u32_2d_0001 %img %sampler +%res1 = OpImageSampleProjDrefImplicitLod %u32 %simg %f32vec3_hhh %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Dref to be of 32-bit float type")); +} + +TEST_F(ValidateImage, SampleProjDrefExplicitLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0001 %uniform_image_f32_1d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_1d_0001 %img %sampler +%res1 = OpImageSampleProjDrefExplicitLod %f32 %simg %f32vec2_hh %f32_0_5 Lod %f32_1 +%res2 = OpImageSampleProjDrefExplicitLod %f32 %simg %f32vec3_hhh %f32_0_5 Grad %f32_0_5 %f32_0_5 +%res3 = OpImageSampleProjDrefExplicitLod %f32 %simg %f32vec2_hh %f32_0_5 ConstOffset %s32_1 +%res4 = OpImageSampleProjDrefExplicitLod %f32 %simg %f32vec2_hh %f32_0_5 Offset %s32_1 +%res5 = OpImageSampleProjDrefExplicitLod %f32 %simg %f32vec2_hh %f32_0_5 Grad|Offset %f32_0_5 %f32_0_5 %s32_1 +%res6 = OpImageSampleProjDrefExplicitLod %f32 %simg %f32vec2_hh %f32_0_5 Lod|NonPrivateTexelKHR %f32_1 +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SampleProjDrefExplicitLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0001 %uniform_image_f32_1d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_1d_0001 %img %sampler +%res1 = OpImageSampleProjDrefExplicitLod %bool %simg %f32vec2_hh %f32_0_5 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float scalar type")); +} + +TEST_F(ValidateImage, SampleProjDrefExplicitLodNotSampledImage) { + const std::string body = R"( +%img = OpLoad 
%type_image_f32_1d_0001 %uniform_image_f32_1d_0001 +%res1 = OpImageSampleProjDrefExplicitLod %f32 %img %f32vec2_hh %f32_0_5 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, SampleProjDrefExplicitLodWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0001 %uniform_image_f32_1d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_1d_0001 %img %sampler +%res1 = OpImageSampleProjDrefExplicitLod %u32 %simg %f32vec2_hh %f32_0_5 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); +} + +TEST_F(ValidateImage, SampleProjDrefExplicitLodVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageSampleProjDrefExplicitLod %u32 %simg %f32vec3_hhh %f32_0_5 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as Result Type")); +} + +TEST_F(ValidateImage, SampleProjDrefExplicitLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0001 %uniform_image_f32_1d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_1d_0001 %img %sampler +%res1 = OpImageSampleProjDrefExplicitLod %f32 %simg %img %f32_0_5 Lod %f32_1 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, SampleProjDrefExplicitLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0001 %uniform_image_f32_1d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_1d_0001 %img %sampler +%res1 = OpImageSampleProjDrefExplicitLod %f32 %simg %f32_0_5 %f32_0_5 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 2 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, FetchSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0001 %uniform_image_f32_1d_0001 +%res1 = OpImageFetch %f32vec4 %img %u32vec2_01 +%res2 = OpImageFetch %f32vec4 %img %u32vec2_01 NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, FetchWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageFetch %f32 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float vector type")); +} + +TEST_F(ValidateImage, FetchWrongNumComponentsResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 
%uniform_image_f32_rect_0001 +%res1 = OpImageFetch %f32vec3 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); +} + +TEST_F(ValidateImage, FetchNotImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageFetch %f32vec4 %simg %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, FetchNotSampled) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageFetch %u32vec4 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to be 1")); +} + +TEST_F(ValidateImage, FetchCube) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%res1 = OpImageFetch %f32vec4 %img %u32vec3_012 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Image 'Dim' cannot be Cube")); +} + +TEST_F(ValidateImage, FetchWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageFetch %u32vec4 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type components")); +} + +TEST_F(ValidateImage, FetchVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%res1 = OpImageFetch %f32vec4 %img %u32vec2_01 +%res2 = OpImageFetch %u32vec4 %img %u32vec2_01 +%res3 = OpImageFetch %s32vec4 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, FetchWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageFetch %f32vec4 %img %f32vec2_00 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be int scalar or vector")); +} + +TEST_F(ValidateImage, FetchCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageFetch %f32vec4 %img %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 2 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, FetchLodNotInt) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageFetch %f32vec4 %img %u32vec2_01 Lod %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image Operand Lod to be int scalar when used " + "with OpImageFetch")); +} + +TEST_F(ValidateImage, GatherSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 
%uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 +%res2 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 ConstOffsets %const_offsets +%res3 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, GatherWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageGather %f32 %simg %f32vec4_0000 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float vector type")); +} + +TEST_F(ValidateImage, GatherWrongNumComponentsResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageGather %f32vec3 %simg %f32vec4_0000 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); +} + +TEST_F(ValidateImage, GatherNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%res1 = OpImageGather %f32vec4 
%img %f32vec4_0000 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sampled Image to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, GatherWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageGather %u32vec4 %simg %f32vec4_0000 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type components")); +} + +TEST_F(ValidateImage, GatherVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageGather %u32vec4 %simg %f32vec2_00 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, GatherWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %u32vec4_0123 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, GatherCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 
%uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32_0_5 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 4 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, GatherWrongComponentType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Component to be 32-bit int scalar")); +} + +TEST_F(ValidateImage, GatherComponentNot32Bit) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u64_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Component to be 32-bit int scalar")); +} + +TEST_F(ValidateImage, GatherDimCube) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 ConstOffsets %const_offsets +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Image Operand ConstOffsets cannot be used with Cube Image 'Dim'")); +} + +TEST_F(ValidateImage, GatherConstOffsetsNotArray) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 ConstOffsets %u32vec4_0123 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Image Operand ConstOffsets to be an array of size 4")); +} + +TEST_F(ValidateImage, GatherConstOffsetsArrayWrongSize) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 ConstOffsets %const_offsets3x2 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Image Operand ConstOffsets to be an array of size 4")); +} + +TEST_F(ValidateImage, GatherConstOffsetsArrayNotVector) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 ConstOffsets %const_offsets4xu +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image Operand ConstOffsets array 
componenets " + "to be int vectors of size 2")); +} + +TEST_F(ValidateImage, GatherConstOffsetsArrayVectorWrongSize) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 ConstOffsets %const_offsets4x3 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image Operand ConstOffsets array componenets " + "to be int vectors of size 2")); +} + +TEST_F(ValidateImage, GatherConstOffsetsArrayNotConst) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%offsets = OpUndef %u32vec2arr4 +%res1 = OpImageGather %f32vec4 %simg %f32vec4_0000 %u32_1 ConstOffsets %offsets +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image Operand ConstOffsets to be a const object")); +} + +TEST_F(ValidateImage, NotGatherWithConstOffsets) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res2 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh ConstOffsets %const_offsets +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Image Operand ConstOffsets can only be used with OpImageGather " + "and OpImageDrefGather")); +} + +TEST_F(ValidateImage, 
DrefGatherSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageDrefGather %f32vec4 %simg %f32vec4_0000 %f32_0_5 +%res2 = OpImageDrefGather %f32vec4 %simg %f32vec4_0000 %f32_0_5 ConstOffsets %const_offsets +%res3 = OpImageDrefGather %f32vec4 %simg %f32vec4_0000 %f32_0_5 NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, DrefGatherVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0001 %uniform_image_void_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_void_2d_0001 %img %sampler +%res1 = OpImageDrefGather %u32vec4 %simg %f32vec2_00 %f32_0_5 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type components")); +} + +TEST_F(ValidateImage, DrefGatherWrongDrefType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0101 %uniform_image_f32_cube_0101 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_cube_0101 %img %sampler +%res1 = OpImageDrefGather %f32vec4 %simg %f32vec4_0000 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Dref to be of 32-bit float type")); +} + +TEST_F(ValidateImage, ReadSuccess1) { + const 
std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %u32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, ReadSuccess2) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0002_rgba32f %uniform_image_f32_1d_0002_rgba32f +%res1 = OpImageRead %f32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability Image1D\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, ReadSuccess3) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0102_rgba32f %uniform_image_f32_cube_0102_rgba32f +%res1 = OpImageRead %f32vec4 %img %u32vec3_012 +)"; + + const std::string extra = "\nOpCapability ImageCubeArray\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, ReadSuccess4) { + const std::string body = R"( +%img = OpLoad %type_image_f32_spd_0002 %uniform_image_f32_spd_0002 +%res1 = OpImageRead %f32vec4 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, ReadNeedCapabilityStorageImageReadWithoutFormat) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %u32vec4 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Capability StorageImageReadWithoutFormat is required " + "to read storage image")); +} + +TEST_F(ValidateImage, ReadNeedCapabilityImage1D) { + const std::string body = R"( +%img = OpLoad 
%type_image_f32_1d_0002_rgba32f %uniform_image_f32_1d_0002_rgba32f +%res1 = OpImageRead %f32vec4 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Capability Image1D is required to access storage image")); +} + +TEST_F(ValidateImage, ReadNeedCapabilityImageCubeArray) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0102_rgba32f %uniform_image_f32_cube_0102_rgba32f +%res1 = OpImageRead %f32vec4 %img %u32vec3_012 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Capability ImageCubeArray is required to access storage image")); +} + +// TODO(atgoo@github.com) Disabled until the spec is clarified. +TEST_F(ValidateImage, DISABLED_ReadWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %f32 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int or float vector type")); +} + +// TODO(atgoo@github.com) Disabled until the spec is clarified. 
+TEST_F(ValidateImage, DISABLED_ReadWrongNumComponentsResultType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %f32vec3 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 4 components")); +} + +TEST_F(ValidateImage, ReadNotImage) { + const std::string body = R"( +%sampler = OpLoad %type_sampler %uniform_sampler +%res1 = OpImageRead %f32vec4 %sampler %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, ReadImageSampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageRead %f32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to be 0 or 2")); +} + +TEST_F(ValidateImage, ReadWrongSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %f32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the 
same as " + "Result Type components")); +} + +TEST_F(ValidateImage, ReadVoidSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_void_2d_0002 %uniform_image_void_2d_0002 +%res1 = OpImageRead %f32vec4 %img %u32vec2_01 +%res2 = OpImageRead %u32vec4 %img %u32vec2_01 +%res3 = OpImageRead %s32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, ReadWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %u32vec4 %img %f32vec2_00 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be int scalar or vector")); +} + +TEST_F(ValidateImage, ReadCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %u32vec4 %img %u32_1 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 2 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, WriteSuccess1) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %u32vec4_0123 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, 
WriteSuccess2) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0002_rgba32f %uniform_image_f32_1d_0002_rgba32f +%res1 = OpImageWrite %img %u32_1 %f32vec4_0000 +)"; + + const std::string extra = "\nOpCapability Image1D\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, WriteSuccess3) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0102_rgba32f %uniform_image_f32_cube_0102_rgba32f +%res1 = OpImageWrite %img %u32vec3_012 %f32vec4_0000 +)"; + + const std::string extra = "\nOpCapability ImageCubeArray\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, WriteSuccess4) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +;TODO(atgoo@github.com) Is it legal to write to MS image without sample index? +%res1 = OpImageWrite %img %u32vec2_01 %f32vec4_0000 +%res2 = OpImageWrite %img %u32vec2_01 %f32vec4_0000 Sample %u32_1 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, WriteSubpassData) { + const std::string body = R"( +%img = OpLoad %type_image_f32_spd_0002 %uniform_image_f32_spd_0002 +%res1 = OpImageWrite %img %u32vec2_01 %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image 'Dim' cannot be SubpassData")); +} + +TEST_F(ValidateImage, WriteNeedCapabilityStorageImageWriteWithoutFormat) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %u32vec4_0123 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Capability StorageImageWriteWithoutFormat is required to write to " + "storage image")); +} + +TEST_F(ValidateImage, WriteNeedCapabilityImage1D) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0002_rgba32f %uniform_image_f32_1d_0002_rgba32f +%res1 = OpImageWrite %img %u32vec2_01 %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Capability Image1D is required to access storage " + "image")); +} + +TEST_F(ValidateImage, WriteNeedCapabilityImageCubeArray) { + const std::string body = R"( +%img = OpLoad %type_image_f32_cube_0102_rgba32f %uniform_image_f32_cube_0102_rgba32f +%res1 = OpImageWrite %img %u32vec3_012 %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Capability ImageCubeArray is required to access storage image")); +} + +TEST_F(ValidateImage, WriteNotImage) { + const std::string body = R"( +%sampler = OpLoad %type_sampler %uniform_sampler +%res1 = OpImageWrite %sampler %u32vec2_01 %f32vec4_0000 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, WriteImageSampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageWrite %img %u32vec2_01 %f32vec4_0000 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, 
ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled' parameter to be 0 or 2")); +} + +TEST_F(ValidateImage, WriteWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %f32vec2_00 %u32vec4_0123 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be int scalar or vector")); +} + +TEST_F(ValidateImage, WriteCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32_1 %u32vec4_0123 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 2 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, WriteTexelWrongType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %img +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Texel to be int or float vector or scalar")); +} + +TEST_F(ValidateImage, DISABLED_WriteTexelNotVector4) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %u32vec3_012 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + 
CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Texel to have 4 components")); +} + +TEST_F(ValidateImage, WriteTexelWrongComponentType) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %f32vec4_0000 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected Image 'Sampled Type' to be the same as Texel components")); +} + +TEST_F(ValidateImage, WriteSampleNotInteger) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +%res1 = OpImageWrite %img %u32vec2_01 %f32vec4_0000 Sample %f32_1 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image Operand Sample to be int scalar")); +} + +TEST_F(ValidateImage, SampleNotMultisampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res2 = OpImageWrite %img %u32vec2_01 %f32vec4_0000 Sample %u32_1 +)"; + + const std::string extra = "\nOpCapability StorageImageWriteWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image Operand Sample requires non-zero 'MS' parameter")); +} + +TEST_F(ValidateImage, SampleWrongOpcode) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +%sampler = OpLoad 
%type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0010 %img %sampler +%res1 = OpImageSampleExplicitLod %f32vec4 %simg %f32vec2_00 Sample %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand Sample can only be used with " + "OpImageFetch, OpImageRead, OpImageWrite, " + "OpImageSparseFetch and OpImageSparseRead")); +} + +TEST_F(ValidateImage, SampleImageToImageSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%img2 = OpImage %type_image_f32_2d_0001 %simg +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SampleImageToImageWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%img2 = OpImage %type_sampled_image_f32_2d_0001 %simg +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeImage")); +} + +TEST_F(ValidateImage, SampleImageToImageNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%img2 = OpImage %type_image_f32_2d_0001 %img +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Sample Image to be of type OpTypeSampleImage")); +} + +TEST_F(ValidateImage, SampleImageToImageNotTheSameImageType) { + 
const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%img2 = OpImage %type_image_f32_2d_0002 %simg +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Sample Image image type to be equal to " + "Result Type")); +} + +TEST_F(ValidateImage, QueryFormatSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQueryFormat %u32 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, QueryFormatWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQueryFormat %bool %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar type")); +} + +TEST_F(ValidateImage, QueryFormatNotImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryFormat %u32 %simg +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected operand to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, QueryOrderSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQueryOrder %u32 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + 
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, QueryOrderWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQueryOrder %bool %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar type")); +} + +TEST_F(ValidateImage, QueryOrderNotImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryOrder %u32 %simg +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected operand to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, QuerySizeLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQuerySizeLod %u32vec2 %img %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, QuerySizeLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQuerySizeLod %f32vec2 %img %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar or vector type")); +} + +TEST_F(ValidateImage, QuerySizeLodResultTypeWrongSize) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQuerySizeLod %u32 %img %u32_1 +)"; + + 
CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result Type has 1 components, but 2 expected")); +} + +TEST_F(ValidateImage, QuerySizeLodNotImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQuerySizeLod %u32vec2 %simg %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, QuerySizeLodWrongImageDim) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageQuerySizeLod %u32vec2 %img %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube")); +} + +TEST_F(ValidateImage, QuerySizeLodMultisampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +%res1 = OpImageQuerySizeLod %u32vec2 %img %u32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Image 'MS' must be 0")); +} + +TEST_F(ValidateImage, QuerySizeLodWrongLodType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQuerySizeLod %u32vec2 %img %f32_0 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Level of Detail to be int 
scalar")); +} + +TEST_F(ValidateImage, QuerySizeSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +%res1 = OpImageQuerySize %u32vec2 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, QuerySizeWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +%res1 = OpImageQuerySize %f32vec2 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar or vector type")); +} + +TEST_F(ValidateImage, QuerySizeNotImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQuerySize %u32vec2 %simg +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, QuerySizeDimSubpassDataBad) { + const std::string body = R"( +%img = OpLoad %type_image_f32_spd_0002 %uniform_image_f32_spd_0002 +%res1 = OpImageQuerySize %u32vec2 %img +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image 'Dim' must be 1D, Buffer, 2D, Cube, 3D or Rect")); +} + +TEST_F(ValidateImage, QuerySizeWrongSampling) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQuerySize %u32vec2 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, 
ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image must have either 'MS'=1 or 'Sampled'=0 or 'Sampled'=2")); +} + +TEST_F(ValidateImage, QuerySizeWrongNumberOfComponents) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0111 %uniform_image_f32_3d_0111 +%res1 = OpImageQuerySize %u32vec2 %img +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Result Type has 2 components, but 4 expected")); +} + +TEST_F(ValidateImage, QueryLodSuccessKernel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %f32vec2_hh +%res2 = OpImageQueryLod %f32vec2 %simg %u32vec2_01 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, QueryLodSuccessShader) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, QueryLodWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %u32vec2 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected 
Result Type to be float vector type")); +} + +TEST_F(ValidateImage, QueryLodResultTypeWrongSize) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec3 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to have 2 components")); +} + +TEST_F(ValidateImage, QueryLodNotSampledImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQueryLod %f32vec2 %img %f32vec2_hh +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Image operand to be of type OpTypeSampledImage")); +} + +TEST_F(ValidateImage, QueryLodWrongDim) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_rect_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube")); +} + +TEST_F(ValidateImage, QueryLodWrongCoordinateType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, 
ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to be float scalar or vector")); +} + +TEST_F(ValidateImage, QueryLodCoordinateSizeTooSmall) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Coordinate to have at least 2 components, " + "but given only 1")); +} + +TEST_F(ValidateImage, QueryLevelsSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQueryLevels %u32 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, QueryLevelsWrongResultType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQueryLevels %f32 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be int scalar type")); +} + +TEST_F(ValidateImage, QueryLevelsNotImage) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLevels %u32 %simg +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image to be of type OpTypeImage")); +} + +TEST_F(ValidateImage, QueryLevelsWrongDim) { + 
const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageQueryLevels %u32 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image 'Dim' must be 1D, 2D, 3D or Cube")); +} + +TEST_F(ValidateImage, QuerySamplesSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0010 %uniform_image_f32_2d_0010 +%res1 = OpImageQuerySamples %u32 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, QuerySamplesNot2D) { + const std::string body = R"( +%img = OpLoad %type_image_f32_3d_0010 %uniform_image_f32_3d_0010 +%res1 = OpImageQuerySamples %u32 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Image 'Dim' must be 2D")); +} + +TEST_F(ValidateImage, QuerySamplesNotMultisampled) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%res1 = OpImageQuerySamples %u32 %img +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Image 'MS' must be 1")); +} + +TEST_F(ValidateImage, QueryLodWrongExecutionModel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpImageQueryLod requires Fragment execution 
model")); +} + +TEST_F(ValidateImage, QueryLodWrongExecutionModelWithFunc) { + const std::string body = R"( +%call_ret = OpFunctionCall %void %my_func +OpReturn +OpFunctionEnd +%my_func = OpFunction %void None %func +%my_func_entry = OpLabel +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageQueryLod %f32vec2 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpImageQueryLod requires Fragment execution model")); +} + +TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ImplicitLod instructions require Fragment execution model")); +} + +TEST_F(ValidateImage, ReadSubpassDataWrongExecutionModel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_spd_0002 %uniform_image_f32_spd_0002 +%res1 = OpImageRead %f32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Vertex").c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Dim SubpassData requires Fragment execution model")); +} + +TEST_F(ValidateImage, SparseSampleImplicitLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 
%uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleImplicitLod %struct_u32_f32vec4 %simg %f32vec2_hh +%res2 = OpImageSparseSampleImplicitLod %struct_u32_f32vec4 %simg %f32vec2_hh Bias %f32_0_25 +%res4 = OpImageSparseSampleImplicitLod %struct_u32_f32vec4 %simg %f32vec2_hh ConstOffset %s32vec2_01 +%res5 = OpImageSparseSampleImplicitLod %struct_u32_f32vec4 %simg %f32vec2_hh Offset %s32vec2_01 +%res6 = OpImageSparseSampleImplicitLod %struct_u32_f32vec4 %simg %f32vec2_hh MinLod %f32_0_5 +%res7 = OpImageSparseSampleImplicitLod %struct_u64_f32vec4 %simg %f32vec2_hh Bias|Offset|MinLod %f32_0_25 %s32vec2_01 %f32_0_5 +%res8 = OpImageSparseSampleImplicitLod %struct_u32_f32vec4 %simg %f32vec2_hh NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeNotStruct) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleImplicitLod %f32 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeStruct")); +} + +TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeNotTwoMembers1) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage 
%type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleImplicitLod %struct_u32 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an int " + "scalar and a texel")); +} + +TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeNotTwoMembers2) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleImplicitLod %struct_u32_f32vec4_u32 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeFirstMemberNotInt) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleImplicitLod %struct_f32_f32vec4 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseSampleImplicitLodResultTypeTexelNotVector) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleImplicitLod %struct_u32_u32 %simg 
%f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type's second member to be int or " + "float vector type")); +} + +TEST_F(ValidateImage, SparseSampleImplicitLodWrongNumComponentsTexel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleImplicitLod %struct_u32_f32vec3 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type's second member to have 4 " + "components")); +} + +TEST_F(ValidateImage, SparseSampleImplicitLodWrongComponentTypeTexel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleImplicitLod %struct_u32_u32vec4 %simg %f32vec2_hh +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type's second member components")); +} + +TEST_F(ValidateImage, SparseSampleDrefImplicitLodSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0001 %uniform_image_u32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_u32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleDrefImplicitLod %struct_u32_u32 %simg %f32vec2_hh %f32_1 +%res2 = OpImageSparseSampleDrefImplicitLod %struct_u32_u32 %simg %f32vec2_hh %f32_1 Bias %f32_0_25 +%res4 = 
OpImageSparseSampleDrefImplicitLod %struct_u32_u32 %simg %f32vec2_hh %f32_1 ConstOffset %s32vec2_01 +%res5 = OpImageSparseSampleDrefImplicitLod %struct_u32_u32 %simg %f32vec2_hh %f32_1 Offset %s32vec2_01 +%res6 = OpImageSparseSampleDrefImplicitLod %struct_u32_u32 %simg %f32vec2_hh %f32_1 MinLod %f32_0_5 +%res7 = OpImageSparseSampleDrefImplicitLod %struct_u32_u32 %simg %f32vec2_hh %f32_1 Bias|Offset|MinLod %f32_0_25 %s32vec2_01 %f32_0_5 +%res8 = OpImageSparseSampleDrefImplicitLod %struct_u32_u32 %simg %f32vec2_hh %f32_1 NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeNotStruct) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleDrefImplicitLod %f32 %simg %f32vec2_hh %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeStruct")); +} + +TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeNotTwoMembers1) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleDrefImplicitLod %struct_u32 %simg %f32vec2_hh %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + 
HasSubstr("Expected Result Type to be a struct containing an int scalar " + "and a texel")); +} + +TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeNotTwoMembers2) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleDrefImplicitLod %struct_u32_f32_u32 %simg %f32vec2_hh %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an int scalar " + "and a texel")); +} + +TEST_F(ValidateImage, SparseSampleDrefImplicitLodResultTypeFirstMemberNotInt) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleDrefImplicitLod %struct_f32_f32 %simg %f32vec2_hh %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an int scalar " + "and a texel")); +} + +TEST_F(ValidateImage, SparseSampleDrefImplicitLodDifferentSampledType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseSampleDrefImplicitLod %struct_u32_u32 %simg %f32vec2_hh %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + 
"Result Type's second member")); +} + +TEST_F(ValidateImage, SparseFetchSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_1d_0001 %uniform_image_f32_1d_0001 +%res1 = OpImageSparseFetch %struct_u32_f32vec4 %img %u32vec2_01 +%res2 = OpImageSparseFetch %struct_u32_f32vec4 %img %u32vec2_01 NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SparseFetchResultTypeNotStruct) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageSparseFetch %f32 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeStruct")); +} + +TEST_F(ValidateImage, SparseFetchResultTypeNotTwoMembers1) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageSparseFetch %struct_u32 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseFetchResultTypeNotTwoMembers2) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageSparseFetch %struct_u32_f32vec4_u32 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a 
struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseFetchResultTypeFirstMemberNotInt) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageSparseFetch %struct_f32_f32vec4 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseFetchResultTypeTexelNotVector) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageSparseFetch %struct_u32_u32 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type's second member to be int or " + "float vector type")); +} + +TEST_F(ValidateImage, SparseFetchWrongNumComponentsTexel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageSparseFetch %struct_u32_f32vec3 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type's second member to have 4 " + "components")); +} + +TEST_F(ValidateImage, SparseFetchWrongComponentTypeTexel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_rect_0001 %uniform_image_f32_rect_0001 +%res1 = OpImageSparseFetch %struct_u32_u32vec4 %img %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type's second member 
components")); +} + +TEST_F(ValidateImage, SparseReadSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res1 = OpImageSparseRead %struct_u32_f32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SparseReadResultTypeNotStruct) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res1 = OpImageSparseRead %f32 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeStruct")); +} + +TEST_F(ValidateImage, SparseReadResultTypeNotTwoMembers1) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res1 = OpImageSparseRead %struct_u32 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseReadResultTypeNotTwoMembers2) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res1 = OpImageSparseRead %struct_u32_f32vec4_u32 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected 
Result Type to be a struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseReadResultTypeFirstMemberNotInt) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res1 = OpImageSparseRead %struct_f32_f32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseReadResultTypeTexelWrongType) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res1 = OpImageSparseRead %struct_u32_u32arr4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type's second member to be int or " + "float scalar or vector type")); +} + +TEST_F(ValidateImage, SparseReadWrongComponentTypeTexel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res1 = OpImageSparseRead %struct_u32_u32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type's second member components")); +} + +TEST_F(ValidateImage, SparseReadSubpassDataNotAllowed) { + const std::string body = R"( +%img = OpLoad %type_image_f32_spd_0002 %uniform_image_f32_spd_0002 +%res1 = OpImageSparseRead 
%struct_u32_f32vec4 %img %u32vec2_01 +)"; + + const std::string extra = "\nOpCapability StorageImageReadWithoutFormat\n"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment").c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image Dim SubpassData cannot be used with ImageSparseRead")); +} + +TEST_F(ValidateImage, SparseGatherSuccess) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %struct_u32_f32vec4 %simg %f32vec4_0000 %u32_1 +%res2 = OpImageSparseGather %struct_u32_f32vec4 %simg %f32vec4_0000 %u32_1 NonPrivateTexelKHR +)"; + + const std::string extra = R"( +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, SparseGatherResultTypeNotStruct) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %f32 %simg %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be OpTypeStruct")); +} + +TEST_F(ValidateImage, SparseGatherResultTypeNotTwoMembers1) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather 
%struct_u32 %simg %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an int " + "scalar and a texel")); +} + +TEST_F(ValidateImage, SparseGatherResultTypeNotTwoMembers2) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %struct_u32_f32vec4_u32 %simg %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an int " + "scalar and a texel")); +} + +TEST_F(ValidateImage, SparseGatherResultTypeFirstMemberNotInt) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %struct_f32_f32vec4 %simg %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be a struct containing an " + "int scalar and a texel")); +} + +TEST_F(ValidateImage, SparseGatherResultTypeTexelNotVector) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %struct_u32_u32 %simg %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, 
ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type's second member to be int or " + "float vector type")); +} + +TEST_F(ValidateImage, SparseGatherWrongNumComponentsTexel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %struct_u32_f32vec3 %simg %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type's second member to have 4 " + "components")); +} + +TEST_F(ValidateImage, SparseGatherWrongComponentTypeTexel) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSparseGather %struct_u32_u32vec4 %simg %f32vec2_hh %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Image 'Sampled Type' to be the same as " + "Result Type's second member components")); +} + +TEST_F(ValidateImage, SparseTexelsResidentSuccess) { + const std::string body = R"( +%res1 = OpImageSparseTexelsResident %bool %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateImage, SparseTexelsResidentResultTypeNotBool) { + const std::string body = R"( +%res1 = OpImageSparseTexelsResident %u32 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Result Type to be bool scalar 
type")); +} + +TEST_F(ValidateImage, MakeTexelVisibleKHRSuccessImageRead) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %u32vec4 %img %u32vec2_01 MakeTexelVisibleKHR|NonPrivateTexelKHR %u32_2 +)"; + + const std::string extra = R"( +OpCapability StorageImageReadWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, MakeTexelVisibleKHRSuccessImageSparseRead) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0002 %uniform_image_f32_2d_0002 +%res1 = OpImageSparseRead %struct_u32_f32vec4 %img %u32vec2_01 MakeTexelVisibleKHR|NonPrivateTexelKHR %u32_2 +)"; + + const std::string extra = R"( +OpCapability StorageImageReadWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, MakeTexelVisibleKHRFailureOpcode) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh MakeTexelVisibleKHR|NonPrivateTexelKHR %u32_1 +)"; + + const std::string extra = R"( +OpCapability StorageImageReadWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + 
ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Image Operand MakeTexelVisibleKHR can only be used with " + "OpImageRead or OpImageSparseRead: OpImageSampleImplicitLod")); +} + +TEST_F(ValidateImage, MakeTexelVisibleKHRFailureMissingNonPrivate) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %u32vec4 %img %u32vec2_01 MakeTexelVisibleKHR %u32_1 +)"; + + const std::string extra = R"( +OpCapability StorageImageReadWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand MakeTexelVisibleKHR requires " + "NonPrivateTexelKHR is also specified: OpImageRead")); +} + +TEST_F(ValidateImage, MakeTexelAvailableKHRSuccessImageWrite) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %u32vec4_0123 MakeTexelAvailableKHR|NonPrivateTexelKHR %u32_2 +)"; + + const std::string extra = R"( +OpCapability StorageImageWriteWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, MakeTexelAvailableKHRFailureOpcode) { + const std::string body = R"( +%img = OpLoad %type_image_f32_2d_0001 %uniform_image_f32_2d_0001 +%sampler = OpLoad %type_sampler %uniform_sampler +%simg = OpSampledImage %type_sampled_image_f32_2d_0001 %img %sampler +%res1 = OpImageSampleImplicitLod %f32vec4 %simg %f32vec2_hh MakeTexelAvailableKHR|NonPrivateTexelKHR 
%u32_1 +)"; + + const std::string extra = R"( +OpCapability StorageImageReadWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand MakeTexelAvailableKHR can only be used " + "with OpImageWrite: OpImageSampleImplicitLod")); +} + +TEST_F(ValidateImage, MakeTexelAvailableKHRFailureMissingNonPrivate) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %u32vec4_0123 MakeTexelAvailableKHR %u32_1 +)"; + + const std::string extra = R"( +OpCapability StorageImageWriteWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Image Operand MakeTexelAvailableKHR requires " + "NonPrivateTexelKHR is also specified: OpImageWrite")); +} + +TEST_F(ValidateImage, VulkanMemoryModelDeviceScopeImageWriteBad) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %u32vec4_0123 MakeTexelAvailableKHR|NonPrivateTexelKHR %u32_1 +)"; + + const std::string extra = R"( +OpCapability StorageImageWriteWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + 
HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateImage, VulkanMemoryModelDeviceScopeImageWriteGood) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageWrite %img %u32vec2_01 %u32vec4_0123 MakeTexelAvailableKHR|NonPrivateTexelKHR %u32_1 +)"; + + const std::string extra = R"( +OpCapability StorageImageWriteWithoutFormat +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateImage, VulkanMemoryModelDeviceScopeImageReadBad) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %u32vec4 %img %u32vec2_01 MakeTexelVisibleKHR|NonPrivateTexelKHR %u32_1 +)"; + + const std::string extra = R"( +OpCapability StorageImageReadWithoutFormat +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateImage, VulkanMemoryModelDeviceScopeImageReadGood) { + const std::string body = R"( +%img = OpLoad %type_image_u32_2d_0000 %uniform_image_u32_2d_0000 +%res1 = OpImageRead %u32vec4 %img %u32vec2_01 MakeTexelVisibleKHR|NonPrivateTexelKHR %u32_1 +)"; + + const std::string extra = R"( +OpCapability StorageImageReadWithoutFormat +OpCapability VulkanMemoryModelKHR +OpCapability 
VulkanMemoryModelDeviceScopeKHR +OpExtension "SPV_KHR_vulkan_memory_model" +)"; + CompileSuccessfully(GenerateShaderCode(body, extra, "Fragment", + SPV_ENV_UNIVERSAL_1_3, "VulkanKHR") + .c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_interfaces_test.cpp b/test/val/val_interfaces_test.cpp new file mode 100644 index 000000000..ce430f615 --- /dev/null +++ b/test/val/val_interfaces_test.cpp @@ -0,0 +1,169 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; + +using ValidateInterfacesTest = spvtest::ValidateBase; + +TEST_F(ValidateInterfacesTest, EntryPointMissingInput) { + std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Input %3 +%5 = OpVariable %4 Input +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +%8 = OpLoad %3 %5 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Input variable id <5> is used by entry point 'func' id <1>, " + "but is not listed as an interface")); +} + +TEST_F(ValidateInterfacesTest, EntryPointMissingOutput) { + std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Output %3 +%5 = OpVariable %4 Output +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +%8 = OpLoad %3 %5 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Output variable id <5> is used by entry point 'func' id <1>, " + "but is not listed as an interface")); +} + +TEST_F(ValidateInterfacesTest, InterfaceMissingUseInSubfunction) { + std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Input %3 +%5 = OpVariable %4 Input +%6 = OpTypeFunction %2 +%1 = OpFunction %2 None %6 +%7 = OpLabel +%8 = OpFunctionCall %2 %9 +OpReturn 
+OpFunctionEnd +%9 = OpFunction %2 None %6 +%10 = OpLabel +%11 = OpLoad %3 %5 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Input variable id <5> is used by entry point 'func' id <1>, " + "but is not listed as an interface")); +} + +TEST_F(ValidateInterfacesTest, TwoEntryPointsOneFunction) { + std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" %2 +OpEntryPoint Fragment %1 "func2" +OpExecutionMode %1 OriginUpperLeft +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%5 = OpTypePointer Input %4 +%2 = OpVariable %5 Input +%6 = OpTypeFunction %3 +%1 = OpFunction %3 None %6 +%7 = OpLabel +%8 = OpLoad %4 %2 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Input variable id <2> is used by entry point 'func2' id <1>, " + "but is not listed as an interface")); +} + +TEST_F(ValidateInterfacesTest, MissingInterfaceThroughInitializer) { + const std::string text = R"( +OpCapability Shader +OpCapability VariablePointers +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "func" +OpExecutionMode %1 OriginUpperLeft +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypePointer Input %3 +%5 = OpTypePointer Function %4 +%6 = OpVariable %4 Input +%7 = OpTypeFunction %2 +%1 = OpFunction %2 None %7 +%8 = OpLabel +%9 = OpVariable %5 Function %6 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text, SPV_ENV_UNIVERSAL_1_3); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Input variable id <6> is used by entry point 'func' id <1>, " + "but is not listed as an interface")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_layout_test.cpp b/test/val/val_layout_test.cpp 
new file mode 100644 index 000000000..e502d8c24 --- /dev/null +++ b/test/val/val_layout_test.cpp @@ -0,0 +1,706 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for Logical Layout + +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/diagnostic.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::StrEq; + +using pred_type = std::function; +using ValidateLayout = spvtest::ValidateBase< + std::tuple>>; + +// returns true if order is equal to VAL +template +spv_result_t Equals(int order) { + return order == VAL ? SPV_SUCCESS : RET; +} + +// returns true if order is between MIN and MAX(inclusive) +template +struct Range { + explicit Range(bool inverse = false) : inverse_(inverse) {} + spv_result_t operator()(int order) { + return (inverse_ ^ (order >= MIN && order <= MAX)) ? 
SPV_SUCCESS : RET; + } + + private: + bool inverse_; +}; + +template +spv_result_t InvalidSet(int order) { + for (spv_result_t val : {T(true)(order)...}) + if (val != SPV_SUCCESS) return val; + return SPV_SUCCESS; +} + +// SPIRV source used to test the logical layout +const std::vector& getInstructions() { + // clang-format off + static const std::vector instructions = { + "OpCapability Shader", + "OpExtension \"TestExtension\"", + "%inst = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint GLCompute %func \"\"", + "OpExecutionMode %func LocalSize 1 1 1", + "OpExecutionModeId %func LocalSizeId %one %one %one", + "%str = OpString \"Test String\"", + "%str2 = OpString \"blabla\"", + "OpSource GLSL 450 %str \"uniform vec3 var = vec3(4.0);\"", + "OpSourceContinued \"void main(){return;}\"", + "OpSourceExtension \"Test extension\"", + "OpName %func \"MyFunction\"", + "OpMemberName %struct 1 \"my_member\"", + "OpDecorate %dgrp RowMajor", + "OpMemberDecorate %struct 1 RowMajor", + "%dgrp = OpDecorationGroup", + "OpGroupDecorate %dgrp %mat33 %mat44", + "%intt = OpTypeInt 32 1", + "%floatt = OpTypeFloat 32", + "%voidt = OpTypeVoid", + "%boolt = OpTypeBool", + "%vec4 = OpTypeVector %floatt 4", + "%vec3 = OpTypeVector %floatt 3", + "%mat33 = OpTypeMatrix %vec3 3", + "%mat44 = OpTypeMatrix %vec4 4", + "%struct = OpTypeStruct %intt %mat33", + "%vfunct = OpTypeFunction %voidt", + "%viifunct = OpTypeFunction %voidt %intt %intt", + "%one = OpConstant %intt 1", + // TODO(umar): OpConstant fails because the type is not defined + // TODO(umar): OpGroupMemberDecorate + "OpLine %str 3 4", + "OpNoLine", + "%func = OpFunction %voidt None %vfunct", + "%l = OpLabel", + "OpReturn ; %func return", + "OpFunctionEnd ; %func end", + "%func2 = OpFunction %voidt None %viifunct", + "%funcp1 = OpFunctionParameter %intt", + "%funcp2 = OpFunctionParameter %intt", + "%fLabel = OpLabel", + "OpNop", + "OpReturn ; %func2 return", + "OpFunctionEnd" + }; + return 
instructions; +} + +static const int kRangeEnd = 1000; +pred_type All = Range<0, kRangeEnd>(); + +INSTANTIATE_TEST_CASE_P(InstructionsOrder, + ValidateLayout, + ::testing::Combine(::testing::Range((int)0, (int)getInstructions().size()), + // Note: Because of ID dependencies between instructions, some instructions + // are not free to be placed anywhere without triggering an non-layout + // validation error. Therefore, "Lines to compile" for some instructions + // are not "All" in the below. + // + // | Instruction | Line(s) valid | Lines to compile + ::testing::Values(std::make_tuple(std::string("OpCapability") , Equals<0> , Range<0, 2>()) + , std::make_tuple(std::string("OpExtension") , Equals<1> , All) + , std::make_tuple(std::string("OpExtInstImport") , Equals<2> , All) + , std::make_tuple(std::string("OpMemoryModel") , Equals<3> , Range<1, kRangeEnd>()) + , std::make_tuple(std::string("OpEntryPoint") , Equals<4> , All) + , std::make_tuple(std::string("OpExecutionMode ") , Range<5, 6>() , All) + , std::make_tuple(std::string("OpExecutionModeId") , Range<5, 6>() , All) + , std::make_tuple(std::string("OpSource ") , Range<7, 11>() , Range<8, kRangeEnd>()) + , std::make_tuple(std::string("OpSourceContinued ") , Range<7, 11>() , All) + , std::make_tuple(std::string("OpSourceExtension ") , Range<7, 11>() , All) + , std::make_tuple(std::string("%str2 = OpString ") , Range<7, 11>() , All) + , std::make_tuple(std::string("OpName ") , Range<12, 13>() , All) + , std::make_tuple(std::string("OpMemberName ") , Range<12, 13>() , All) + , std::make_tuple(std::string("OpDecorate ") , Range<14, 17>() , All) + , std::make_tuple(std::string("OpMemberDecorate ") , Range<14, 17>() , All) + , std::make_tuple(std::string("OpGroupDecorate ") , Range<14, 17>() , Range<17, kRangeEnd>()) + , std::make_tuple(std::string("OpDecorationGroup") , Range<14, 17>() , Range<0, 16>()) + , std::make_tuple(std::string("OpTypeBool") , Range<18, 31>() , All) + , 
std::make_tuple(std::string("OpTypeVoid") , Range<18, 31>() , Range<0, 26>()) + , std::make_tuple(std::string("OpTypeFloat") , Range<18, 31>() , Range<0,21>()) + , std::make_tuple(std::string("OpTypeInt") , Range<18, 31>() , Range<0, 21>()) + , std::make_tuple(std::string("OpTypeVector %floatt 4") , Range<18, 31>() , Range<20, 24>()) + , std::make_tuple(std::string("OpTypeMatrix %vec4 4") , Range<18, 31>() , Range<23, kRangeEnd>()) + , std::make_tuple(std::string("OpTypeStruct") , Range<18, 31>() , Range<25, kRangeEnd>()) + , std::make_tuple(std::string("%vfunct = OpTypeFunction"), Range<18, 31>() , Range<21, 31>()) + , std::make_tuple(std::string("OpConstant") , Range<18, 31>() , Range<21, kRangeEnd>()) + , std::make_tuple(std::string("OpLine ") , Range<18, kRangeEnd>() , Range<8, kRangeEnd>()) + , std::make_tuple(std::string("OpNoLine") , Range<18, kRangeEnd>() , All) + , std::make_tuple(std::string("%fLabel = OpLabel") , Equals<39> , All) + , std::make_tuple(std::string("OpNop") , Equals<40> , Range<40,kRangeEnd>()) + , std::make_tuple(std::string("OpReturn ; %func2 return") , Equals<41> , All) + )),); +// clang-format on + +// Creates a new vector which removes the string if the substr is found in the +// instructions vector and reinserts it in the location specified by order. 
+// NOTE: This will not work correctly if there are two instances of substr in +// instructions +std::vector GenerateCode(std::string substr, int order) { + std::vector code(getInstructions().size()); + std::vector inst(1); + partition_copy(std::begin(getInstructions()), std::end(getInstructions()), + std::begin(code), std::begin(inst), + [=](const std::string& str) { + return std::string::npos == str.find(substr); + }); + + code.insert(std::begin(code) + order, inst.front()); + return code; +} + +// This test will check the logical layout of a binary by removing each +// instruction in the pair of the INSTANTIATE_TEST_CASE_P call and moving it in +// the SPIRV source formed by combining the vector "instructions". +TEST_P(ValidateLayout, Layout) { + int order; + std::string instruction; + pred_type pred; + pred_type test_pred; // Predicate to determine if the test should be built + std::tuple testCase; + + std::tie(order, testCase) = GetParam(); + std::tie(instruction, pred, test_pred) = testCase; + + // Skip tests which would break the code generation + if (test_pred(order)) return; + + std::vector code = GenerateCode(instruction, order); + + std::stringstream ss; + std::copy(std::begin(code), std::end(code), + std::ostream_iterator(ss, "\n")); + + const auto env = SPV_ENV_UNIVERSAL_1_3; + // printf("code: \n%s\n", ss.str().c_str()); + CompileSuccessfully(ss.str(), env); + spv_result_t result; + // clang-format off + ASSERT_EQ(pred(order), result = ValidateInstructions(env)) + << "Actual: " << spvResultToString(result) + << "\nExpected: " << spvResultToString(pred(order)) + << "\nOrder: " << order + << "\nInstruction: " << instruction + << "\nCode: \n" << ss.str(); + // clang-format on +} + +TEST_F(ValidateLayout, MemoryModelMissingBeforeEntryPoint) { + std::string str = R"( + OpCapability Matrix + OpExtension "TestExtension" + %inst = OpExtInstImport "GLSL.std.450" + OpEntryPoint GLCompute %func "" + OpExecutionMode %func LocalSize 1 1 1 + )"; + + 
CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "EntryPoint cannot appear before the memory model instruction")); +} + +TEST_F(ValidateLayout, MemoryModelMissing) { + char str[] = R"(OpCapability Linkage)"; + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Missing required OpMemoryModel instruction")); +} + +TEST_F(ValidateLayout, MemoryModelSpecifiedTwice) { + char str[] = R"( + OpCapability Linkage + OpCapability Shader + OpMemoryModel Logical Simple + OpMemoryModel Logical Simple + )"; + + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpMemoryModel should only be provided once")); +} + +TEST_F(ValidateLayout, FunctionDefinitionBeforeDeclarationBad) { + char str[] = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpDecorate %var Restrict +%intt = OpTypeInt 32 1 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%vifunct = OpTypeFunction %voidt %intt +%ptrt = OpTypePointer Function %intt +%func = OpFunction %voidt None %vfunct +%funcl = OpLabel + OpNop + OpReturn + OpFunctionEnd +%func2 = OpFunction %voidt None %vifunct ; must appear before definition +%func2p = OpFunctionParameter %intt + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Function declarations must appear before function definitions.")); +} + +// TODO(umar): Passes but gives incorrect error message. 
Should be fixed after +// type checking +TEST_F(ValidateLayout, LabelBeforeFunctionParameterBad) { + char str[] = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpDecorate %var Restrict +%intt = OpTypeInt 32 1 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%vifunct = OpTypeFunction %voidt %intt +%ptrt = OpTypePointer Function %intt +%func = OpFunction %voidt None %vifunct +%funcl = OpLabel ; Label appears before function parameter +%func2p = OpFunctionParameter %intt + OpNop + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Function parameters must only appear immediately " + "after the function definition")); +} + +TEST_F(ValidateLayout, FuncParameterNotImmediatlyAfterFuncBad) { + char str[] = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpDecorate %var Restrict +%intt = OpTypeInt 32 1 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%vifunct = OpTypeFunction %voidt %intt +%ptrt = OpTypePointer Function %intt +%func = OpFunction %voidt None %vifunct +%funcl = OpLabel + OpNop + OpBranch %next +%func2p = OpFunctionParameter %intt ;FunctionParameter appears in a function but not immediately afterwards +%next = OpLabel + OpNop + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Function parameters must only appear immediately " + "after the function definition")); +} + +TEST_F(ValidateLayout, OpUndefCanAppearInTypeDeclarationSection) { + std::string str = R"( + OpCapability Kernel + OpCapability Linkage + OpMemoryModel Logical OpenCL +%voidt = OpTypeVoid +%uintt = OpTypeInt 32 0 +%funct = OpTypeFunction %voidt +%udef = OpUndef %uintt +%func = OpFunction %voidt None %funct +%entry = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, 
ValidateInstructions()); +} + +TEST_F(ValidateLayout, OpUndefCanAppearInBlock) { + std::string str = R"( + OpCapability Kernel + OpCapability Linkage + OpMemoryModel Logical OpenCL +%voidt = OpTypeVoid +%uintt = OpTypeInt 32 0 +%funct = OpTypeFunction %voidt +%func = OpFunction %voidt None %funct +%entry = OpLabel +%udef = OpUndef %uintt + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLayout, MissingFunctionEndForFunctionWithBody) { + const auto s = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%tf = OpTypeFunction %void +%f = OpFunction %void None %tf +%l = OpLabel +OpReturn +)"; + + CompileSuccessfully(s); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + StrEq("Missing OpFunctionEnd at end of module.")); +} + +TEST_F(ValidateLayout, MissingFunctionEndForFunctionPrototype) { + const auto s = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%tf = OpTypeFunction %void +%f = OpFunction %void None %tf +)"; + + CompileSuccessfully(s); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + StrEq("Missing OpFunctionEnd at end of module.")); +} + +using ValidateOpFunctionParameter = spvtest::ValidateBase; + +// Valid: OpLine and OpNoLine may appear between OpFunctionParameter +// instructions. +TEST_F(ValidateOpFunctionParameter, OpLineBetweenParameters) { + const auto s = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%foo_frag = OpString "foo.frag" +%i32 = OpTypeInt 32 1 +%tf = OpTypeFunction %i32 %i32 %i32 +%c = OpConstant %i32 123 +%f = OpFunction %i32 None %tf +OpLine %foo_frag 1 1 +%p1 = OpFunctionParameter %i32 +OpNoLine +%p2 = OpFunctionParameter %i32 +%l = OpLabel +OpReturnValue %c +OpFunctionEnd +)"; + CompileSuccessfully(s); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: %tf lists two parameter types but the function supplies seven +// OpFunctionParameter instructions. +TEST_F(ValidateOpFunctionParameter, TooManyParameters) 
{ + const auto s = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%i32 = OpTypeInt 32 1 +%tf = OpTypeFunction %i32 %i32 %i32 +%c = OpConstant %i32 123 +%f = OpFunction %i32 None %tf +%p1 = OpFunctionParameter %i32 +%p2 = OpFunctionParameter %i32 +%xp3 = OpFunctionParameter %i32 +%xp4 = OpFunctionParameter %i32 +%xp5 = OpFunctionParameter %i32 +%xp6 = OpFunctionParameter %i32 +%xp7 = OpFunctionParameter %i32 +%l = OpLabel +OpReturnValue %c +OpFunctionEnd +)"; + CompileSuccessfully(s); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); +} + +using ValidateEntryPoint = spvtest::ValidateBase; + +// Tests that not having OpEntryPoint causes an error. +TEST_F(ValidateEntryPoint, NoEntryPointBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450)"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("No OpEntryPoint instruction was found. This is only " + "allowed if the Linkage capability is being used.")); +} + +// Invalid. A function may not be a target of both OpEntryPoint and +// OpFunctionCall. +TEST_F(ValidateEntryPoint, FunctionIsTargetOfEntryPointAndFunctionCallBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %foo "foo" + OpExecutionMode %foo OriginUpperLeft +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%foo = OpFunction %voidt None %funct +%entry = OpLabel +%recurse = OpFunctionCall %voidt %foo + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("A function (1) may not be targeted by both an OpEntryPoint " + "instruction and an OpFunctionCall instruction.")); +} + +// Invalid. Must be within a function to make a function call. 
+TEST_F(ValidateEntryPoint, FunctionCallOutsideFunctionBody) { + std::string spirv = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpName %variableName "variableName" + %34 = OpFunctionCall %variableName %1 + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("FunctionCall must happen within a function body.")); +} + +// Valid. Module with a function but no entry point is valid when Linkage +// Capability is used. +TEST_F(ValidateEntryPoint, NoEntryPointWithLinkageCapGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%foo = OpFunction %voidt None %funct +%entry = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLayout, ModuleProcessedInvalidIn10) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %void "void" + OpModuleProcessed "this is ok in 1.1 and later" + OpDecorate %void Volatile ; bogus, but makes the example short +%void = OpTypeVoid +)"; + + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_WRONG_VERSION, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_0)); + // In a 1.0 environment the version check fails. 
+ EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid SPIR-V binary version 1.1 for target " + "environment SPIR-V 1.0.")); +} + +TEST_F(ValidateLayout, ModuleProcessedValidIn11) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %void "void" + OpModuleProcessed "this is ok in 1.1 and later" + OpDecorate %void Volatile ; bogus, but makes the example short +%void = OpTypeVoid +)"; + + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), Eq("")); +} + +TEST_F(ValidateLayout, ModuleProcessedBeforeLastNameIsTooEarly) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpModuleProcessed "this is too early" + OpName %void "void" +%void = OpTypeVoid +)"; + + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + // By the mechanics of the validator, we assume ModuleProcessed is in the + // right spot, but then that OpName is in the wrong spot. 
+ EXPECT_THAT(getDiagnosticString(), + HasSubstr("Name cannot appear in a function declaration")); +} + +TEST_F(ValidateLayout, ModuleProcessedInvalidAfterFirstAnnotation) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %void Volatile ; this is bogus, but keeps the example short + OpModuleProcessed "this is too late" +%void = OpTypeVoid +)"; + + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ModuleProcessed cannot appear in a function declaration")); +} + +TEST_F(ValidateLayout, ModuleProcessedInvalidInFunctionBeforeLabel) { + char str[] = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn + OpModuleProcessed "this is too late, in function before label" +%entry = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ModuleProcessed cannot appear in a function declaration")); +} + +TEST_F(ValidateLayout, ModuleProcessedInvalidInBasicBlock) { + char str[] = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel + OpModuleProcessed "this is too late, in basic block" + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str, SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("ModuleProcessed cannot appear in a function declaration")); +} + +// Invalid: in the WebGPU environment %main calls %callee before %callee is +// defined. +TEST_F(ValidateLayout, WebGPUCallerBeforeCalleeBad) { + char 
str[] = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%1 = OpLabel +%2 = OpFunctionCall %void %callee + OpReturn + OpFunctionEnd +%callee = OpFunction %void None %voidfn +%3 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str, SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For WebGPU, functions need to be defined before being " + "called.\n %5 = OpFunctionCall %void %6\n")); +} + +// Valid: in the WebGPU environment %callee is defined before %main calls it. +TEST_F(ValidateLayout, WebGPUCalleeBeforeCallerGood) { + char str[] = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpEntryPoint GLCompute %main "main" +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%callee = OpFunction %void None %voidfn +%3 = OpLabel + OpReturn + OpFunctionEnd +%main = OpFunction %void None %voidfn +%1 = OpLabel +%2 = OpFunctionCall %void %callee + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str, SPV_ENV_WEBGPU_0); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +// TODO(umar): Test optional instructions + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_limits_test.cpp b/test/val/val_limits_test.cpp new file mode 100644 index 000000000..791b709ff --- /dev/null +++ b/test/val/val_limits_test.cpp @@ -0,0 +1,774 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for Universal Limits. (Section 2.17 of the SPIR-V Spec) + +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; + +using ValidateLimits = spvtest::ValidateBase; + +std::string header = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +)"; + +TEST_F(ValidateLimits, IdLargerThanBoundBad) { + std::string str = header + R"( +; %i32 has ID 1 +%i32 = OpTypeInt 32 1 +%c = OpConstant %i32 100 + +; Fake an instruction with 64 as the result id. +; !64 = OpConstantNull %i32 +!0x3002e !1 !64 +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Result '64' must be less than the ID bound '3'.")); +} + +TEST_F(ValidateLimits, IdEqualToBoundBad) { + std::string str = header + R"( +; %i32 has ID 1 +%i32 = OpTypeInt 32 1 +%c = OpConstant %i32 100 + +; Fake an instruction with 64 as the result id. +; !64 = OpConstantNull %i32 +!0x3002e !1 !64 +)"; + + CompileSuccessfully(str); + + // The largest ID used in this program is 64. Let's overwrite the ID bound in + // the header to be 64. This should result in an error because all IDs must + // satisfy: 0 < id < bound. 
+ OverwriteAssembledBinary(3, 64); + + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Result '64' must be less than the ID bound '64'.")); +} + +TEST_F(ValidateLimits, IdBoundTooBigDeaultLimit) { + std::string str = header; + + CompileSuccessfully(str); + + // Overwrite the ID bound in the header with 0x4FFFFF. No custom limit is + // set, so this exceeds the default max id bound (4194303, per the expected + // diagnostic below) and validation must fail. + OverwriteAssembledBinary(3, 0x4FFFFF); + + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid SPIR-V. The id bound is larger than the max " + "id bound 4194303.")); +} + +TEST_F(ValidateLimits, IdBoundAtSetLimit) { + std::string str = header; + + CompileSuccessfully(str); + + // Overwrite the ID bound in the header with 0x4FFFFF and raise the + // validator's max_id_bound option to the same value. A bound equal to the + // configured limit must be accepted. + uint32_t id_bound = 0x4FFFFF; + + OverwriteAssembledBinary(3, id_bound); + getValidatorOptions()->universal_limits_.max_id_bound = id_bound; + + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLimits, IdBoundJustAboveSetLimit) { + std::string str = header; + + CompileSuccessfully(str); + + // Overwrite the ID bound in the header with 5242878 and set the validator's + // max_id_bound option to one less than that. A bound just above the + // configured limit must be rejected. + uint32_t id_bound = 5242878; + + OverwriteAssembledBinary(3, id_bound); + getValidatorOptions()->universal_limits_.max_id_bound = id_bound - 1; + + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid SPIR-V. 
The id bound is larger than the max " + "id bound 5242877.")); +} + +TEST_F(ValidateLimits, IdBoundAtInMaxLimit) { + std::string str = header; + + CompileSuccessfully(str); + + uint32_t id_bound = std::numeric_limits::max(); + + OverwriteAssembledBinary(3, id_bound); + getValidatorOptions()->universal_limits_.max_id_bound = id_bound; + + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLimits, StructNumMembersGood) { + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct)"; + for (int i = 0; i < 16383; ++i) { + spirv << " %1"; + } + CompileSuccessfully(spirv.str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLimits, StructNumMembersExceededBad) { + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct)"; + for (int i = 0; i < 16384; ++i) { + spirv << " %1"; + } + CompileSuccessfully(spirv.str()); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Number of OpTypeStruct members (16384) has exceeded " + "the limit (16383).")); +} + +TEST_F(ValidateLimits, CustomizedStructNumMembersGood) { + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct)"; + for (int i = 0; i < 32000; ++i) { + spirv << " %1"; + } + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_struct_members, 32000u); + CompileSuccessfully(spirv.str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLimits, CustomizedStructNumMembersBad) { + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeStruct)"; + for (int i = 0; i < 32001; ++i) { + spirv << " %1"; + } + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_struct_members, 32000u); + CompileSuccessfully(spirv.str()); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + 
HasSubstr("Number of OpTypeStruct members (32001) has exceeded " + "the limit (32000).")); +} + +// Valid: Switch statement has 16,383 branches. +TEST_F(ValidateLimits, SwitchNumBranchesGood) { + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 1234 +%5 = OpFunction %1 None %2 +%7 = OpLabel +%8 = OpIAdd %3 %4 %4 +%9 = OpSwitch %4 %10)"; + + // Now add the (literal, label) pairs + for (int i = 0; i < 16383; ++i) { + spirv << " 1 %10"; + } + + spirv << R"( +%10 = OpLabel +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: Switch statement has 16,384 branches. +TEST_F(ValidateLimits, SwitchNumBranchesBad) { + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 1234 +%5 = OpFunction %1 None %2 +%7 = OpLabel +%8 = OpIAdd %3 %4 %4 +%9 = OpSwitch %4 %10)"; + + // Now add the (literal, label) pairs + for (int i = 0; i < 16384; ++i) { + spirv << " 1 %10"; + } + + spirv << R"( +%10 = OpLabel +OpReturn +OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Number of (literal, label) pairs in OpSwitch (16384) " + "exceeds the limit (16383).")); +} + +// Valid: Switch statement has 10 branches (limit is 10) +TEST_F(ValidateLimits, CustomizedSwitchNumBranchesGood) { + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 1234 +%5 = OpFunction %1 None %2 +%7 = OpLabel +%8 = OpIAdd %3 %4 %4 +%9 = OpSwitch %4 %10)"; + + // Now add the (literal, label) pairs + for (int i = 0; i < 10; ++i) { + spirv << " 1 %10"; + } + + spirv << R"( +%10 = OpLabel +OpReturn +OpFunctionEnd + )"; + + spvValidatorOptionsSetUniversalLimit( + options_, 
spv_validator_limit_max_switch_branches, 10u); + CompileSuccessfully(spirv.str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: Switch statement has 11 branches (limit is 10) +TEST_F(ValidateLimits, CustomizedSwitchNumBranchesBad) { + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 1234 +%5 = OpFunction %1 None %2 +%7 = OpLabel +%8 = OpIAdd %3 %4 %4 +%9 = OpSwitch %4 %10)"; + + // Now add the (literal, label) pairs + for (int i = 0; i < 11; ++i) { + spirv << " 1 %10"; + } + + spirv << R"( +%10 = OpLabel +OpReturn +OpFunctionEnd + )"; + + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_switch_branches, 10u); + CompileSuccessfully(spirv.str()); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Number of (literal, label) pairs in OpSwitch (11) " + "exceeds the limit (10).")); +} + +// Valid: OpTypeFunction with 255 arguments. +TEST_F(ValidateLimits, OpTypeFunctionGood) { + int num_args = 255; + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeFunction %1)"; + // add parameters + for (int i = 0; i < num_args; ++i) { + spirv << " %1"; + } + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: OpTypeFunction with 256 arguments (limit is 255 according to the +// spec, Universal Limits (2.17)). +TEST_F(ValidateLimits, OpTypeFunctionBad) { + int num_args = 256; + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeFunction %1)"; + for (int i = 0; i < num_args; ++i) { + spirv << " %1"; + } + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeFunction may not take more than 255 arguments. 
" + "OpTypeFunction '2[%2]' has 256 arguments.")); +} + +// Valid: OpTypeFunction with 100 arguments (Custom limit: 100) +TEST_F(ValidateLimits, CustomizedOpTypeFunctionGood) { + int num_args = 100; + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeFunction %1)"; + // add parameters + for (int i = 0; i < num_args; ++i) { + spirv << " %1"; + } + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_function_args, 100u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: OpTypeFunction with 101 arguments. (Custom limit: 100) +TEST_F(ValidateLimits, CustomizedOpTypeFunctionBad) { + int num_args = 101; + std::ostringstream spirv; + spirv << header << R"( +%1 = OpTypeInt 32 0 +%2 = OpTypeFunction %1)"; + for (int i = 0; i < num_args; ++i) { + spirv << " %1"; + } + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_function_args, 100u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpTypeFunction may not take more than 100 arguments. " + "OpTypeFunction '2[%2]' has 101 arguments.")); +} + +// Valid: module has 65,535 global variables. +TEST_F(ValidateLimits, NumGlobalVarsGood) { + int num_globals = 65535; + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 +%_ptr_int = OpTypePointer Input %int + )"; + + for (int i = 0; i < num_globals; ++i) { + spirv << "%var_" << i << " = OpVariable %_ptr_int Input\n"; + } + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: module has 65,536 global variables (limit is 65,535). 
+TEST_F(ValidateLimits, NumGlobalVarsBad) { + int num_globals = 65536; + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 +%_ptr_int = OpTypePointer Input %int + )"; + + for (int i = 0; i < num_globals; ++i) { + spirv << "%var_" << i << " = OpVariable %_ptr_int Input\n"; + } + + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Number of Global Variables (Storage Class other than " + "'Function') exceeded the valid limit (65535).")); +} + +// Valid: module has 50 global variables (limit is 50) +TEST_F(ValidateLimits, CustomizedNumGlobalVarsGood) { + int num_globals = 50; + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 +%_ptr_int = OpTypePointer Input %int + )"; + + for (int i = 0; i < num_globals; ++i) { + spirv << "%var_" << i << " = OpVariable %_ptr_int Input\n"; + } + + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_global_variables, 50u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: module has 51 global variables (limit is 50). +TEST_F(ValidateLimits, CustomizedNumGlobalVarsBad) { + int num_globals = 51; + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 +%_ptr_int = OpTypePointer Input %int + )"; + + for (int i = 0; i < num_globals; ++i) { + spirv << "%var_" << i << " = OpVariable %_ptr_int Input\n"; + } + + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_global_variables, 50u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Number of Global Variables (Storage Class other than " + "'Function') exceeded the valid limit (50).")); +} + +// Valid: module has 524,287 local variables. +// Note: AppVeyor limits process time to 300s. 
 For a VisualStudio Debug +// build, going up to 524287 local variables gets too close to that +// limit. So test with an artificially lowered limit. +TEST_F(ValidateLimits, NumLocalVarsGoodArtificiallyLowLimit5K) { + int num_locals = 5000; + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 + %_ptr_int = OpTypePointer Function %int + %voidt = OpTypeVoid + %funct = OpTypeFunction %voidt + %main = OpFunction %voidt None %funct + %entry = OpLabel + )"; + + for (int i = 0; i < num_locals; ++i) { + spirv << "%var_" << i << " = OpVariable %_ptr_int Function\n"; + } + + spirv << R"( + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + // Artificially lower the limit to exactly num_locals (5000). + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_local_variables, num_locals); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: module has 524,288 local variables (limit is 524,287). +// Artificially lower the limit to 5000 and declare 5001 local variables. +TEST_F(ValidateLimits, NumLocalVarsBadArtificiallyLowLimit5K) { + int num_locals = 5001; + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 + %_ptr_int = OpTypePointer Function %int + %voidt = OpTypeVoid + %funct = OpTypeFunction %voidt + %main = OpFunction %voidt None %funct + %entry = OpLabel + )"; + + for (int i = 0; i < num_locals; ++i) { + spirv << "%var_" << i << " = OpVariable %_ptr_int Function\n"; + } + + spirv << R"( + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(spirv.str()); + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_local_variables, 5000u); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Number of local variables ('Function' Storage Class) " + "exceeded the valid limit (5000).")); +} + +// Valid: module has 100 local variables (limit is 100). 
+TEST_F(ValidateLimits, CustomizedNumLocalVarsGood) { + int num_locals = 100; + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 + %_ptr_int = OpTypePointer Function %int + %voidt = OpTypeVoid + %funct = OpTypeFunction %voidt + %main = OpFunction %voidt None %funct + %entry = OpLabel + )"; + + for (int i = 0; i < num_locals; ++i) { + spirv << "%var_" << i << " = OpVariable %_ptr_int Function\n"; + } + + spirv << R"( + OpReturn + OpFunctionEnd + )"; + + // Lower the local-variable limit to 100, which equals num_locals. + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_local_variables, 100u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: module has 101 local variables (limit is 100). +TEST_F(ValidateLimits, CustomizedNumLocalVarsBad) { + int num_locals = 101; + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 + %_ptr_int = OpTypePointer Function %int + %voidt = OpTypeVoid + %funct = OpTypeFunction %voidt + %main = OpFunction %voidt None %funct + %entry = OpLabel + )"; + + for (int i = 0; i < num_locals; ++i) { + spirv << "%var_" << i << " = OpVariable %_ptr_int Function\n"; + } + + spirv << R"( + OpReturn + OpFunctionEnd + )"; + + // Lower the local-variable limit to 100, one fewer than num_locals. + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_local_variables, 100u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Number of local variables ('Function' Storage Class) " + "exceeded the valid limit (100).")); +} + +// Valid: Structure nesting depth of 255. 
+TEST_F(ValidateLimits, StructNestingDepthGood) { + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 + %s_depth_1 = OpTypeStruct %int + )"; + // Each %s_depth_N has %s_depth_(N-1) as a member, adding one nesting level. + for (auto i = 2; i <= 255; ++i) { + spirv << "%s_depth_" << i << " = OpTypeStruct %int %s_depth_" << i - 1; + spirv << "\n"; + } + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: Structure nesting depth of 256. +TEST_F(ValidateLimits, StructNestingDepthBad) { + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 + %s_depth_1 = OpTypeStruct %int + )"; + // Each %s_depth_N has %s_depth_(N-1) as a member, adding one nesting level. + for (auto i = 2; i <= 256; ++i) { + spirv << "%s_depth_" << i << " = OpTypeStruct %int %s_depth_" << i - 1; + spirv << "\n"; + } + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure Nesting Depth may not be larger than 255. Found 256.")); +} + +// Valid: Structure nesting depth of 100 (limit is 100). +TEST_F(ValidateLimits, CustomizedStructNestingDepthGood) { + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 + %s_depth_1 = OpTypeStruct %int + )"; + // Each %s_depth_N has %s_depth_(N-1) as a member, adding one nesting level. + for (auto i = 2; i <= 100; ++i) { + spirv << "%s_depth_" << i << " = OpTypeStruct %int %s_depth_" << i - 1; + spirv << "\n"; + } + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_struct_depth, 100u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: Structure nesting depth of 101 (limit is 100). 
+TEST_F(ValidateLimits, CustomizedStructNestingDepthBad) { + std::ostringstream spirv; + spirv << header << R"( + %int = OpTypeInt 32 0 + %s_depth_1 = OpTypeStruct %int + )"; + for (auto i = 2; i <= 101; ++i) { + spirv << "%s_depth_" << i << " = OpTypeStruct %int %s_depth_" << i - 1; + spirv << "\n"; + } + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_struct_depth, 100u); + CompileSuccessfully(spirv.str()); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Structure Nesting Depth may not be larger than 100. Found 101.")); +} + +// clang-format off +// Stores in |str| a SPIR-V program with the given control flow nesting depth. +void GenerateSpirvProgramWithCfgNestingDepth(std::string& str, int depth) { + std::ostringstream spirv; + spirv << header << R"( + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %12 = OpConstantTrue %bool + %main = OpFunction %void None %3 + %5 = OpLabel + OpBranch %6 + %6 = OpLabel + OpLoopMerge %8 %9 None + OpBranch %10 + %10 = OpLabel + OpBranchConditional %12 %7 %8 + %7 = OpLabel + )"; + int first_id = 13; + int last_id = 14; + // The loop already gives 1 level of nesting; each nested 'if' adds one more. 
+ int num_if_conditions = depth-1; + int largest_index = first_id + 2*num_if_conditions - 2; + for (int i = first_id; i <= largest_index; i = i + 2) { + spirv << "OpSelectionMerge %" << i+1 << " None" << "\n"; + spirv << "OpBranchConditional %12 " << "%" << i << " %" << i+1 << "\n"; + spirv << "%" << i << " = OpLabel" << "\n"; + } + spirv << "OpBranch %9" << "\n"; + + for (int i = largest_index+1; i > last_id; i = i - 2) { + spirv << "%" << i << " = OpLabel" << "\n"; + spirv << "OpBranch %" << i-2 << "\n"; + } + spirv << "%" << last_id << " = OpLabel" << "\n"; + spirv << "OpBranch %9" << "\n"; + spirv << R"( + %9 = OpLabel + OpBranch %6 + %8 = OpLabel + OpReturn + OpFunctionEnd + )"; + str = spirv.str(); +} +// clang-format on + +// Invalid: Control Flow Nesting depth is 1024 (default limit: 1023). +TEST_F(ValidateLimits, ControlFlowDepthBad) { + std::string spirv; + GenerateSpirvProgramWithCfgNestingDepth(spirv, 1024); + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Maximum Control Flow nesting depth exceeded.")); +} + +// Valid: Control Flow Nesting depth is 10 (custom limit: 10). +TEST_F(ValidateLimits, CustomizedControlFlowDepthGood) { + std::string spirv; + GenerateSpirvProgramWithCfgNestingDepth(spirv, 10); + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_control_flow_nesting_depth, 10u); + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// Invalid: Control Flow Nesting depth is 11 (custom limit: 10). 
+TEST_F(ValidateLimits, CustomizedControlFlowDepthBad) { + std::string spirv; + GenerateSpirvProgramWithCfgNestingDepth(spirv, 11); + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_control_flow_nesting_depth, 10u); + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Maximum Control Flow nesting depth exceeded.")); +} + +// Valid. The purpose here is to test the CFG depth calculation code when a loop +// continue target is the loop itself. It also exercises the case where a loop +// is unreachable. +TEST_F(ValidateLimits, ControlFlowNoEntryToLoopGood) { + std::string str = header + R"( + OpName %entry "entry" + OpName %loop "loop" + OpName %exit "exit" +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%main = OpFunction %voidt None %funct +%entry = OpLabel + OpBranch %exit +%loop = OpLabel + OpLoopMerge %loop %loop None + OpBranch %loop +%exit = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_literals_test.cpp b/test/val/val_literals_test.cpp new file mode 100644 index 000000000..cbdbdd10e --- /dev/null +++ b/test/val/val_literals_test.cpp @@ -0,0 +1,142 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Validation tests for illegal literals + +#include +#include + +#include "gmock/gmock.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; + +using ValidateLiterals = spvtest::ValidateBase; +using ValidateLiteralsShader = spvtest::ValidateBase; +using ValidateLiteralsKernel = spvtest::ValidateBase; + +std::string GenerateShaderCode() { + std::string str = R"( + OpCapability Shader + OpCapability Linkage + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + OpMemoryModel Logical GLSL450 +%int16 = OpTypeInt 16 1 +%uint16 = OpTypeInt 16 0 +%int32 = OpTypeInt 32 1 +%uint32 = OpTypeInt 32 0 +%int64 = OpTypeInt 64 1 +%uint64 = OpTypeInt 64 0 +%half = OpTypeFloat 16 +%float = OpTypeFloat 32 +%double = OpTypeFloat 64 +%10 = OpTypeVoid + )"; + return str; +} + +std::string GenerateKernelCode() { + std::string str = R"( + OpCapability Kernel + OpCapability Addresses + OpCapability Linkage + OpCapability Int8 + OpMemoryModel Physical64 OpenCL +%uint8 = OpTypeInt 8 0 + )"; + return str; +} + +TEST_F(ValidateLiterals, LiteralsShaderGood) { + std::string str = GenerateShaderCode() + R"( +%11 = OpConstant %int16 !0x00007FFF +%12 = OpConstant %int16 !0xFFFF8000 +%13 = OpConstant %int16 !0xFFFFABCD +%14 = OpConstant %uint16 !0x0000ABCD +%15 = OpConstant %int16 -32768 +%16 = OpConstant %uint16 65535 +%17 = OpConstant %int32 -2147483648 +%18 = OpConstant %uint32 4294967295 +%19 = OpConstant %int64 -9223372036854775808 +%20 = OpConstant %uint64 18446744073709551615 +%21 = OpConstant %half !0x0000FFFF +%22 = OpConstant %float !0xFFFFFFFF +%23 = OpConstant %double !0xFFFFFFFF !0xFFFFFFFF + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateLiteralsShader, LiteralsShaderBad) { + std::string str = GenerateShaderCode() + GetParam(); + std::string inst_id = "11"; + CompileSuccessfully(str); + EXPECT_EQ(SPV_ERROR_INVALID_VALUE, 
ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("The high-order bits of a literal number in instruction " + + inst_id + + " must be 0 for a floating-point type, " + "or 0 for an integer type with Signedness of 0, " + "or sign extended when Signedness is 1")); +} + +INSTANTIATE_TEST_CASE_P( + LiteralsShaderCases, ValidateLiteralsShader, + ::testing::Values("%11 = OpConstant %int16 !0xFFFF0000", // Sign bit is 0 + "%11 = OpConstant %int16 !0x00008000", // Sign bit is 1 + "%11 = OpConstant %int16 !0xABCD8000", // Sign bit is 1 + "%11 = OpConstant %int16 !0xABCD0000", + "%11 = OpConstant %uint16 !0xABCD0000", + "%11 = OpConstant %half !0xABCD0000", + "%11 = OpConstant %half !0x00010000")); + +TEST_F(ValidateLiterals, LiteralsKernelGood) { + std::string str = GenerateKernelCode() + R"( +%4 = OpConstant %uint8 !0x000000AB +%6 = OpConstant %uint8 255 + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateLiteralsKernel, LiteralsKernelBad) { + std::string str = GenerateKernelCode() + GetParam(); + std::string inst_id = "2"; + CompileSuccessfully(str); + EXPECT_EQ(SPV_ERROR_INVALID_VALUE, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("The high-order bits of a literal number in instruction " + + inst_id + + " must be 0 for a floating-point type, " + "or 0 for an integer type with Signedness of 0, " + "or sign extended when Signedness is 1")); +} + +INSTANTIATE_TEST_CASE_P( + LiteralsKernelCases, ValidateLiteralsKernel, + ::testing::Values("%2 = OpConstant %uint8 !0xABCDEF00", + "%2 = OpConstant %uint8 !0xABCDEFFF")); + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_logicals_test.cpp b/test/val/val_logicals_test.cpp new file mode 100644 index 000000000..da8e7d9e9 --- /dev/null +++ b/test/val/val_logicals_test.cpp @@ -0,0 +1,954 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for the logical instructions validator. + +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateLogicals = spvtest::ValidateBase; + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "") { + const std::string capabilities = + R"( +OpCapability Shader +OpCapability Int64 +OpCapability Float64)"; + + const std::string after_extension_before_body = + R"( +%ext_inst = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%f64 = OpTypeFloat 64 +%u64 = OpTypeInt 64 0 +%s64 = OpTypeInt 64 1 +%boolvec2 = OpTypeVector %bool 2 +%s32vec2 = OpTypeVector %s32 2 +%u32vec2 = OpTypeVector %u32 2 +%u64vec2 = OpTypeVector %u64 2 +%f32vec2 = OpTypeVector %f32 2 +%f64vec2 = OpTypeVector %f64 2 +%boolvec3 = OpTypeVector %bool 3 +%u32vec3 = OpTypeVector %u32 3 +%u64vec3 = OpTypeVector %u64 3 +%s32vec3 = OpTypeVector %s32 3 +%f32vec3 = OpTypeVector %f32 3 +%f64vec3 = OpTypeVector %f64 3 +%boolvec4 = OpTypeVector %bool 4 +%u32vec4 = OpTypeVector %u32 4 +%u64vec4 = 
OpTypeVector %u64 4 +%s32vec4 = OpTypeVector %s32 4 +%f32vec4 = OpTypeVector %f32 4 +%f64vec4 = OpTypeVector %f64 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +%f32_3 = OpConstant %f32 3 +%f32_4 = OpConstant %f32 4 + +%s32_0 = OpConstant %s32 0 +%s32_1 = OpConstant %s32 1 +%s32_2 = OpConstant %s32 2 +%s32_3 = OpConstant %s32 3 +%s32_4 = OpConstant %s32 4 +%s32_m1 = OpConstant %s32 -1 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%f64_0 = OpConstant %f64 0 +%f64_1 = OpConstant %f64 1 +%f64_2 = OpConstant %f64 2 +%f64_3 = OpConstant %f64 3 +%f64_4 = OpConstant %f64 4 + +%s64_0 = OpConstant %s64 0 +%s64_1 = OpConstant %s64 1 +%s64_2 = OpConstant %s64 2 +%s64_3 = OpConstant %s64 3 +%s64_4 = OpConstant %s64 4 +%s64_m1 = OpConstant %s64 -1 + +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant %u64 1 +%u64_2 = OpConstant %u64 2 +%u64_3 = OpConstant %u64 3 +%u64_4 = OpConstant %u64 4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + +%s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1 +%s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2 +%s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2 +%s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3 +%s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3 +%s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4 + +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 
%f32_1 %f32_2 +%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4 + +%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1 +%f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2 +%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2 +%f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3 +%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3 +%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4 + +%true = OpConstantTrue %bool +%false = OpConstantFalse %bool +%boolvec2_tf = OpConstantComposite %boolvec2 %true %false +%boolvec3_tft = OpConstantComposite %boolvec3 %true %false %true +%boolvec4_tftf = OpConstantComposite %boolvec4 %true %false %true %false + +%f32vec4ptr = OpTypePointer Function %f32vec4 + +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string after_body = + R"( +OpReturn +OpFunctionEnd)"; + + return capabilities + capabilities_and_extensions + + after_extension_before_body + body + after_body; +} + +std::string GenerateKernelCode( + const std::string& body, + const std::string& capabilities_and_extensions = "") { + const std::string capabilities = + R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpCapability Int64 +OpCapability Float64)"; + + const std::string after_extension_before_body = + R"( +OpMemoryModel Physical32 OpenCL +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%u32 = OpTypeInt 32 0 +%f64 = OpTypeFloat 64 +%u64 = OpTypeInt 64 0 +%boolvec2 = OpTypeVector %bool 2 +%u32vec2 = OpTypeVector %u32 2 +%u64vec2 = OpTypeVector %u64 2 +%f32vec2 = OpTypeVector %f32 2 +%f64vec2 = OpTypeVector %f64 2 +%boolvec3 = OpTypeVector %bool 3 +%u32vec3 = OpTypeVector %u32 3 +%u64vec3 = OpTypeVector %u64 3 +%f32vec3 = OpTypeVector %f32 3 
+%f64vec3 = OpTypeVector %f64 3 +%boolvec4 = OpTypeVector %bool 4 +%u32vec4 = OpTypeVector %u32 4 +%u64vec4 = OpTypeVector %u64 4 +%f32vec4 = OpTypeVector %f32 4 +%f64vec4 = OpTypeVector %f64 4 + +%f32_0 = OpConstant %f32 0 +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +%f32_3 = OpConstant %f32 3 +%f32_4 = OpConstant %f32 4 + +%u32_0 = OpConstant %u32 0 +%u32_1 = OpConstant %u32 1 +%u32_2 = OpConstant %u32 2 +%u32_3 = OpConstant %u32 3 +%u32_4 = OpConstant %u32 4 + +%f64_0 = OpConstant %f64 0 +%f64_1 = OpConstant %f64 1 +%f64_2 = OpConstant %f64 2 +%f64_3 = OpConstant %f64 3 +%f64_4 = OpConstant %f64 4 + +%u64_0 = OpConstant %u64 0 +%u64_1 = OpConstant %u64 1 +%u64_2 = OpConstant %u64 2 +%u64_3 = OpConstant %u64 3 +%u64_4 = OpConstant %u64 4 + +%u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1 +%u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2 +%u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2 +%u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3 +%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3 +%u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4 + +%f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1 +%f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2 +%f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2 +%f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3 +%f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3 +%f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4 + +%f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1 +%f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2 +%f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2 +%f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3 +%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3 +%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4 + +%true = OpConstantTrue %bool +%false = 
OpConstantFalse %bool +%boolvec2_tf = OpConstantComposite %boolvec2 %true %false +%boolvec3_tft = OpConstantComposite %boolvec3 %true %false %true +%boolvec4_tftf = OpConstantComposite %boolvec4 %true %false %true %false + +%f32vec4ptr = OpTypePointer Function %f32vec4 + +%main = OpFunction %void None %func +%main_entry = OpLabel)"; + + const std::string after_body = + R"( +OpReturn +OpFunctionEnd)"; + + return capabilities + capabilities_and_extensions + + after_extension_before_body + body + after_body; +} + +TEST_F(ValidateLogicals, OpAnySuccess) { + const std::string body = R"( +%val1 = OpAny %bool %boolvec2_tf +%val2 = OpAny %bool %boolvec3_tft +%val3 = OpAny %bool %boolvec4_tftf +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpAnyWrongTypeId) { + const std::string body = R"( +%val = OpAny %u32 %boolvec2_tf +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected bool scalar type as Result Type: Any")); +} + +TEST_F(ValidateLogicals, OpAnyWrongOperand) { + const std::string body = R"( +%val = OpAny %bool %u32vec3_123 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected operand to be vector bool: Any")); +} + +TEST_F(ValidateLogicals, OpIsNanSuccess) { + const std::string body = R"( +%val1 = OpIsNan %bool %f32_1 +%val2 = OpIsNan %bool %f64_0 +%val3 = OpIsNan %boolvec2 %f32vec2_12 +%val4 = OpIsNan %boolvec3 %f32vec3_123 +%val5 = OpIsNan %boolvec4 %f32vec4_1234 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpIsNanWrongTypeId) { + const std::string body = R"( +%val1 = OpIsNan %u32 %f32_1 +)"; + + 
CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected bool scalar or vector type as Result Type: IsNan")); +} + +TEST_F(ValidateLogicals, OpIsNanOperandNotFloat) { + const std::string body = R"( +%val1 = OpIsNan %bool %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operand to be scalar or vector float: IsNan")); +} + +TEST_F(ValidateLogicals, OpIsNanOperandWrongSize) { + const std::string body = R"( +%val1 = OpIsNan %bool %f32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected vector sizes of Result Type and the operand to be equal: " + "IsNan")); +} + +TEST_F(ValidateLogicals, OpLessOrGreaterSuccess) { + const std::string body = R"( +%val1 = OpLessOrGreater %bool %f32_0 %f32_1 +%val2 = OpLessOrGreater %bool %f64_0 %f64_0 +%val3 = OpLessOrGreater %boolvec2 %f32vec2_12 %f32vec2_12 +%val4 = OpLessOrGreater %boolvec3 %f32vec3_123 %f32vec3_123 +%val5 = OpLessOrGreater %boolvec4 %f32vec4_1234 %f32vec4_1234 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpLessOrGreaterWrongTypeId) { + const std::string body = R"( +%val1 = OpLessOrGreater %u32 %f32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected bool scalar or vector type as Result Type: LessOrGreater")); +} + +TEST_F(ValidateLogicals, OpLessOrGreaterLeftOperandNotFloat) { + const std::string body = R"( +%val1 = OpLessOrGreater %bool %u32_1 %f32_1 +)"; + + 
CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected operands to be scalar or vector float: LessOrGreater")); +} + +TEST_F(ValidateLogicals, OpLessOrGreaterLeftOperandWrongSize) { + const std::string body = R"( +%val1 = OpLessOrGreater %bool %f32vec2_12 %f32_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected vector sizes of Result Type and the operands to be equal: " + "LessOrGreater")); +} + +TEST_F(ValidateLogicals, OpLessOrGreaterOperandsDifferentType) { + const std::string body = R"( +%val1 = OpLessOrGreater %bool %f32_1 %f64_1 +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected left and right operands to have the same type: " + "LessOrGreater")); +} + +TEST_F(ValidateLogicals, OpFOrdEqualSuccess) { + const std::string body = R"( +%val1 = OpFOrdEqual %bool %f32_0 %f32_1 +%val2 = OpFOrdEqual %bool %f64_0 %f64_0 +%val3 = OpFOrdEqual %boolvec2 %f32vec2_12 %f32vec2_12 +%val4 = OpFOrdEqual %boolvec3 %f32vec3_123 %f32vec3_123 +%val5 = OpFOrdEqual %boolvec4 %f32vec4_1234 %f32vec4_1234 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpFOrdEqualWrongTypeId) { + const std::string body = R"( +%val1 = OpFOrdEqual %u32 %f32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected bool scalar or vector type as Result Type: FOrdEqual")); +} + +TEST_F(ValidateLogicals, OpFOrdEqualLeftOperandNotFloat) { + const std::string body = R"( +%val1 = 
OpFOrdEqual %bool %u32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to be scalar or vector float: FOrdEqual")); +} + +TEST_F(ValidateLogicals, OpFOrdEqualLeftOperandWrongSize) { + const std::string body = R"( +%val1 = OpFOrdEqual %bool %f32vec2_12 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected vector sizes of Result Type and the operands to be equal: " + "FOrdEqual")); +} + +TEST_F(ValidateLogicals, OpFOrdEqualOperandsDifferentType) { + const std::string body = R"( +%val1 = OpFOrdEqual %bool %f32_1 %f64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected left and right operands to have the same type: " + "FOrdEqual")); +} + +TEST_F(ValidateLogicals, OpLogicalEqualSuccess) { + const std::string body = R"( +%val1 = OpLogicalEqual %bool %true %false +%val2 = OpLogicalEqual %boolvec2 %boolvec2_tf %boolvec2_tf +%val3 = OpLogicalEqual %boolvec3 %boolvec3_tft %boolvec3_tft +%val4 = OpLogicalEqual %boolvec4 %boolvec4_tftf %boolvec4_tftf +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpLogicalEqualWrongTypeId) { + const std::string body = R"( +%val1 = OpLogicalEqual %u32 %true %false +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected bool scalar or vector type as Result Type: LogicalEqual")); +} + +TEST_F(ValidateLogicals, OpLogicalEqualWrongLeftOperand) { + const std::string body = R"( +%val1 = 
OpLogicalEqual %bool %boolvec2_tf %false +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected both operands to be of Result Type: LogicalEqual")); +} + +TEST_F(ValidateLogicals, OpLogicalEqualWrongRightOperand) { + const std::string body = R"( +%val1 = OpLogicalEqual %boolvec2 %boolvec2_tf %false +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected both operands to be of Result Type: LogicalEqual")); +} + +TEST_F(ValidateLogicals, OpLogicalNotSuccess) { + const std::string body = R"( +%val1 = OpLogicalNot %bool %true +%val2 = OpLogicalNot %boolvec2 %boolvec2_tf +%val3 = OpLogicalNot %boolvec3 %boolvec3_tft +%val4 = OpLogicalNot %boolvec4 %boolvec4_tftf +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpLogicalNotWrongTypeId) { + const std::string body = R"( +%val1 = OpLogicalNot %u32 %true +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected bool scalar or vector type as Result Type: LogicalNot")); +} + +TEST_F(ValidateLogicals, OpLogicalNotWrongOperand) { + const std::string body = R"( +%val1 = OpLogicalNot %bool %boolvec2_tf +)"; + + CompileSuccessfully(GenerateKernelCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected operand to be of Result Type: LogicalNot")); +} + +TEST_F(ValidateLogicals, OpSelectSuccess) { + const std::string body = R"( +%val1 = OpSelect %u32 %true %u32_0 %u32_1 +%val2 = OpSelect %f32 %true %f32_0 %f32_1 +%val3 = OpSelect %f64 %true %f64_0 %f64_1 +%val4 = 
OpSelect %f32vec2 %boolvec2_tf %f32vec2_01 %f32vec2_12 +%val5 = OpSelect %f32vec4 %boolvec4_tftf %f32vec4_0123 %f32vec4_1234 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpSelectWrongTypeId) { + const std::string body = R"( +%val1 = OpSelect %void %true %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected scalar or vector type as Result Type: Select")); +} + +TEST_F(ValidateLogicals, OpSelectPointerNoCapability) { + const std::string body = R"( +%x = OpVariable %f32vec4ptr Function +%y = OpVariable %f32vec4ptr Function +OpStore %x %f32vec4_0123 +OpStore %y %f32vec4_1234 +%val1 = OpSelect %f32vec4ptr %true %x %y +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Using pointers with OpSelect requires capability VariablePointers " + "or VariablePointersStorageBuffer")); +} + +TEST_F(ValidateLogicals, OpSelectPointerWithCapability1) { + const std::string body = R"( +%x = OpVariable %f32vec4ptr Function +%y = OpVariable %f32vec4ptr Function +OpStore %x %f32vec4_0123 +OpStore %y %f32vec4_1234 +%val1 = OpSelect %f32vec4ptr %true %x %y +)"; + + const std::string extra_cap_ext = R"( +OpCapability VariablePointers +OpExtension "SPV_KHR_variable_pointers" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra_cap_ext).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpSelectPointerWithCapability2) { + const std::string body = R"( +%x = OpVariable %f32vec4ptr Function +%y = OpVariable %f32vec4ptr Function +OpStore %x %f32vec4_0123 +OpStore %y %f32vec4_1234 +%val1 = OpSelect %f32vec4ptr %true %x %y +)"; + + const std::string extra_cap_ext = R"( +OpCapability 
VariablePointersStorageBuffer +OpExtension "SPV_KHR_variable_pointers" +)"; + + CompileSuccessfully(GenerateShaderCode(body, extra_cap_ext).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpSelectWrongCondition) { + const std::string body = R"( +%val1 = OpSelect %u32 %u32_1 %u32_0 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected bool scalar or vector type as condition: Select")); +} + +TEST_F(ValidateLogicals, OpSelectWrongConditionDimension) { + const std::string body = R"( +%val1 = OpSelect %u32vec2 %true %u32vec2_01 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected vector sizes of Result Type and the condition to be equal: " + "Select")); +} + +TEST_F(ValidateLogicals, OpSelectWrongLeftObject) { + const std::string body = R"( +%val1 = OpSelect %bool %true %u32vec2_01 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected both objects to be of Result Type: Select")); +} + +TEST_F(ValidateLogicals, OpSelectWrongRightObject) { + const std::string body = R"( +%val1 = OpSelect %bool %true %u32_1 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected both objects to be of Result Type: Select")); +} + +TEST_F(ValidateLogicals, OpIEqualSuccess) { + const std::string body = R"( +%val1 = OpIEqual %bool %u32_0 %s32_1 +%val2 = OpIEqual %bool %s64_0 %u64_0 +%val3 = OpIEqual %boolvec2 %s32vec2_12 %u32vec2_12 +%val4 = OpIEqual %boolvec3 %s32vec3_123 %u32vec3_123 +%val5 = 
OpIEqual %boolvec4 %s32vec4_1234 %u32vec4_1234 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpIEqualWrongTypeId) { + const std::string body = R"( +%val1 = OpIEqual %u32 %s32_1 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected bool scalar or vector type as Result Type: IEqual")); +} + +TEST_F(ValidateLogicals, OpIEqualLeftOperandNotInt) { + const std::string body = R"( +%val1 = OpIEqual %bool %f32_1 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to be scalar or vector int: IEqual")); +} + +TEST_F(ValidateLogicals, OpIEqualLeftOperandWrongSize) { + const std::string body = R"( +%val1 = OpIEqual %bool %s32vec2_12 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected vector sizes of Result Type and the operands to be equal: " + "IEqual")); +} + +TEST_F(ValidateLogicals, OpIEqualRightOperandNotInt) { + const std::string body = R"( +%val1 = OpIEqual %bool %u32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to be scalar or vector int: IEqual")); +} + +TEST_F(ValidateLogicals, OpIEqualDifferentBitWidth) { + const std::string body = R"( +%val1 = OpIEqual %bool %u32_1 %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected both operands to have the same 
component bit " + "width: IEqual")); +} + +TEST_F(ValidateLogicals, OpUGreaterThanSuccess) { + const std::string body = R"( +%val1 = OpUGreaterThan %bool %u32_0 %u32_1 +%val2 = OpUGreaterThan %bool %s32_0 %u32_1 +%val3 = OpUGreaterThan %bool %u64_0 %u64_0 +%val4 = OpUGreaterThan %bool %u64_0 %s64_0 +%val5 = OpUGreaterThan %boolvec2 %u32vec2_12 %u32vec2_12 +%val6 = OpUGreaterThan %boolvec3 %s32vec3_123 %u32vec3_123 +%val7 = OpUGreaterThan %boolvec4 %u32vec4_1234 %u32vec4_1234 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpUGreaterThanWrongTypeId) { + const std::string body = R"( +%val1 = OpUGreaterThan %u32 %u32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected bool scalar or vector type as Result Type: UGreaterThan")); +} + +TEST_F(ValidateLogicals, OpUGreaterThanLeftOperandNotInt) { + const std::string body = R"( +%val1 = OpUGreaterThan %bool %f32_1 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to be scalar or vector int: UGreaterThan")); +} + +TEST_F(ValidateLogicals, OpUGreaterThanLeftOperandWrongSize) { + const std::string body = R"( +%val1 = OpUGreaterThan %bool %u32vec2_12 %u32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected vector sizes of Result Type and the operands to be equal: " + "UGreaterThan")); +} + +TEST_F(ValidateLogicals, OpUGreaterThanRightOperandNotInt) { + const std::string body = R"( +%val1 = OpUGreaterThan %bool %u32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + 
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to be scalar or vector int: UGreaterThan")); +} + +TEST_F(ValidateLogicals, OpUGreaterThanDifferentBitWidth) { + const std::string body = R"( +%val1 = OpUGreaterThan %bool %u32_1 %u64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected both operands to have the same component bit width: " + "UGreaterThan")); +} + +TEST_F(ValidateLogicals, OpSGreaterThanSuccess) { + const std::string body = R"( +%val1 = OpSGreaterThan %bool %s32_0 %s32_1 +%val2 = OpSGreaterThan %bool %u32_0 %s32_1 +%val3 = OpSGreaterThan %bool %s64_0 %s64_0 +%val4 = OpSGreaterThan %bool %s64_0 %u64_0 +%val5 = OpSGreaterThan %boolvec2 %s32vec2_12 %s32vec2_12 +%val6 = OpSGreaterThan %boolvec3 %s32vec3_123 %u32vec3_123 +%val7 = OpSGreaterThan %boolvec4 %s32vec4_1234 %s32vec4_1234 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateLogicals, OpSGreaterThanWrongTypeId) { + const std::string body = R"( +%val1 = OpSGreaterThan %s32 %s32_1 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected bool scalar or vector type as Result Type: SGreaterThan")); +} + +TEST_F(ValidateLogicals, OpSGreaterThanLeftOperandNotInt) { + const std::string body = R"( +%val1 = OpSGreaterThan %bool %f32_1 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to be scalar or vector int: SGreaterThan")); +} + +TEST_F(ValidateLogicals, OpSGreaterThanLeftOperandWrongSize) { + const std::string body = R"( 
+%val1 = OpSGreaterThan %bool %s32vec2_12 %s32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Expected vector sizes of Result Type and the operands to be equal: " + "SGreaterThan")); +} + +TEST_F(ValidateLogicals, OpSGreaterThanRightOperandNotInt) { + const std::string body = R"( +%val1 = OpSGreaterThan %bool %s32_1 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Expected operands to be scalar or vector int: SGreaterThan")); +} + +TEST_F(ValidateLogicals, OpSGreaterThanDifferentBitWidth) { + const std::string body = R"( +%val1 = OpSGreaterThan %bool %s32_1 %s64_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected both operands to have the same component bit " + "width: SGreaterThan")); +} + +TEST_F(ValidateLogicals, PSBSelectSuccess) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 AliasedPointerEXT +%uint64 = OpTypeInt 64 0 +%bool = OpTypeBool +%true = OpConstantTrue %bool +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +%val2 = OpLoad %ptr %val1 +%val3 = OpSelect %ptr %true %val2 %val2 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +} // 
namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp new file mode 100644 index 000000000..097477c31 --- /dev/null +++ b/test/val/val_memory_test.cpp @@ -0,0 +1,2573 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for memory/storage + +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::Eq; +using ::testing::HasSubstr; + +using ValidateMemory = spvtest::ValidateBase; + +TEST_F(ValidateMemory, VulkanUniformConstantOnNonOpaqueResourceBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer UniformConstant %float +%2 = OpVariable %float_ptr UniformConstant +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("From Vulkan spec, section 14.5.2:\n" + "Variables identified with the UniformConstant storage class " + "are used only as handles to refer to opaque resources. 
Such " + "variables must be typed as OpTypeImage, OpTypeSampler, " + "OpTypeSampledImage, OpTypeAccelerationStructureNV, or an " + "array of one of these types.")); +} + +TEST_F(ValidateMemory, VulkanUniformConstantOnOpaqueResourceGood) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %2 DescriptorSet 0 +OpDecorate %2 Binding 0 +%sampler = OpTypeSampler +%sampler_ptr = OpTypePointer UniformConstant %sampler +%2 = OpVariable %sampler_ptr UniformConstant +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanUniformConstantOnNonOpaqueResourceArrayBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%uint = OpTypeInt 32 0 +%array_size = OpConstant %uint 5 +%array = OpTypeArray %float %array_size +%array_ptr = OpTypePointer UniformConstant %array +%2 = OpVariable %array_ptr UniformConstant +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("From Vulkan spec, section 14.5.2:\n" + "Variables identified with the UniformConstant storage class " + "are used only as handles to refer to opaque resources. 
Such " + "variables must be typed as OpTypeImage, OpTypeSampler, " + "OpTypeSampledImage, OpTypeAccelerationStructureNV, or an " + "array of one of these types.")); +} + +TEST_F(ValidateMemory, VulkanUniformConstantOnOpaqueResourceArrayGood) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %2 DescriptorSet 0 +OpDecorate %2 Binding 0 +%sampler = OpTypeSampler +%uint = OpTypeInt 32 0 +%array_size = OpConstant %uint 5 +%array = OpTypeArray %sampler %array_size +%array_ptr = OpTypePointer UniformConstant %array +%2 = OpVariable %array_ptr UniformConstant +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanUniformConstantOnOpaqueResourceRuntimeArrayGood) { + std::string spirv = R"( +OpCapability RuntimeDescriptorArrayEXT +OpCapability Shader +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %2 DescriptorSet 0 +OpDecorate %2 Binding 0 +%sampler = OpTypeSampler +%uint = OpTypeInt 32 0 +%array = OpTypeRuntimeArray %sampler +%array_ptr = OpTypePointer UniformConstant %array +%2 = OpVariable %array_ptr UniformConstant +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanUniformOnIntBad) { + char src[] = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %kernel "main" + OpExecutionMode %kernel LocalSize 1 1 1 + + OpDecorate %var 
DescriptorSet 0 + OpDecorate %var Binding 0 + + %voidty = OpTypeVoid +%kernelty = OpTypeFunction %voidty + %intty = OpTypeInt 32 0 + %varty = OpTypePointer Uniform %intty + %value = OpConstant %intty 42 + + %var = OpVariable %varty Uniform + + %kernel = OpFunction %voidty None %kernelty + %label = OpLabel + OpStore %var %value + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(src, SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("From Vulkan spec, section 14.5.2:\n" + "Variables identified with the Uniform storage class are used " + "to access transparent buffer backed resources. Such variables " + "must be typed as OpTypeStruct, or an array of this type")); +} + +// #version 440 +// #extension GL_EXT_nonuniform_qualifier : enable +// layout(binding = 1) uniform sampler2D s2d[][2]; +// layout(location = 0) in nonuniformEXT int i; +// void main() +// { +// vec4 v = texture(s2d[i][i], vec2(0.3)); +// } +TEST_F(ValidateMemory, VulkanUniformOnRuntimeArrayOfArrayBad) { + char src[] = R"( + OpCapability Shader + OpCapability ShaderNonUniformEXT + OpCapability RuntimeDescriptorArrayEXT + OpCapability SampledImageArrayNonUniformIndexingEXT + OpExtension "SPV_EXT_descriptor_indexing" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" %i + OpSource GLSL 440 + OpSourceExtension "GL_EXT_nonuniform_qualifier" + OpName %main "main" + OpName %v "v" + OpName %s2d "s2d" + OpName %i "i" + OpDecorate %s2d DescriptorSet 0 + OpDecorate %s2d Binding 1 + OpDecorate %i Location 0 + OpDecorate %i NonUniformEXT + OpDecorate %21 NonUniformEXT + OpDecorate %22 NonUniformEXT + OpDecorate %25 NonUniformEXT + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float + %10 = OpTypeImage %float 2D 0 0 0 1 Unknown + %11 = 
OpTypeSampledImage %10 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_11_uint_2 = OpTypeArray %11 %uint_2 +%_runtimearr__arr_11_uint_2 = OpTypeRuntimeArray %_arr_11_uint_2 +%_ptr_Uniform__runtimearr__arr_11_uint_2 = OpTypePointer Uniform %_runtimearr__arr_11_uint_2 + %s2d = OpVariable %_ptr_Uniform__runtimearr__arr_11_uint_2 Uniform + %int = OpTypeInt 32 1 +%_ptr_Input_int = OpTypePointer Input %int + %i = OpVariable %_ptr_Input_int Input +%_ptr_Uniform_11 = OpTypePointer Uniform %11 + %v2float = OpTypeVector %float 2 +%float_0_300000012 = OpConstant %float 0.300000012 + %28 = OpConstantComposite %v2float %float_0_300000012 %float_0_300000012 + %float_0 = OpConstant %float 0 + %main = OpFunction %void None %3 + %5 = OpLabel + %v = OpVariable %_ptr_Function_v4float Function + %21 = OpLoad %int %i + %22 = OpLoad %int %i + %24 = OpAccessChain %_ptr_Uniform_11 %s2d %21 %22 + %25 = OpLoad %11 %24 + %30 = OpImageSampleExplicitLod %v4float %25 %28 Lod %float_0 + OpStore %v %30 + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(src, SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("From Vulkan spec, section 14.5.2:\n" + "Variables identified with the Uniform storage class are used " + "to access transparent buffer backed resources. 
Such variables " + "must be typed as OpTypeStruct, or an array of this type")); +} + +// #version 440 +// layout (set=1, binding=1) uniform sampler2D variableName[2][2]; +// void main() { +// } +TEST_F(ValidateMemory, VulkanUniformOnArrayOfArrayBad) { + char src[] = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Vertex %main "main" + OpSource GLSL 440 + OpName %main "main" + OpName %variableName "variableName" + OpDecorate %variableName DescriptorSet 1 + OpDecorate %variableName Binding 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %7 = OpTypeImage %float 2D 0 0 0 1 Unknown + %8 = OpTypeSampledImage %7 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_8_uint_2 = OpTypeArray %8 %uint_2 +%_arr__arr_8_uint_2_uint_2 = OpTypeArray %_arr_8_uint_2 %uint_2 +%_ptr_Uniform__arr__arr_8_uint_2_uint_2 = OpTypePointer Uniform %_arr__arr_8_uint_2_uint_2 +%variableName = OpVariable %_ptr_Uniform__arr__arr_8_uint_2_uint_2 Uniform + %main = OpFunction %void None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(src, SPV_ENV_VULKAN_1_1); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("From Vulkan spec, section 14.5.2:\n" + "Variables identified with the Uniform storage class are used " + "to access transparent buffer backed resources. 
Such variables " + "must be typed as OpTypeStruct, or an array of this type")); +} + +TEST_F(ValidateMemory, MismatchingStorageClassesBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Uniform %float +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +%2 = OpVariable %float_ptr Function +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "From SPIR-V spec, section 3.32.8 on OpVariable:\n" + "Its Storage Class operand must be the same as the Storage Class " + "operand of the result type.")); +} + +TEST_F(ValidateMemory, MatchingStorageClassesGood) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Function %float +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +%2 = OpVariable %float_ptr Function +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, WebGPUInitializerWithOutputStorageClassesGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Output %float +%init_val = OpConstant %float 1.0 +%1 = OpVariable %float_ptr Output %init_val +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + 
CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateMemory, WebGPUInitializerWithFunctionStorageClassesGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Function %float +%init_val = OpConstant %float 1.0 +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +%2 = OpVariable %float_ptr Function %init_val +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateMemory, WebGPUInitializerWithPrivateStorageClassesGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Private %float +%init_val = OpConstant %float 1.0 +%1 = OpVariable %float_ptr Private %init_val +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateMemory, WebGPUInitializerWithDisallowedStorageClassesBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Uniform %float +%init_val = OpConstant %float 1.0 +%1 = OpVariable 
%float_ptr Uniform %init_val +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpVariable, '5[%5]', has a disallowed initializer & storage " + "class combination.\nFrom WebGPU execution environment spec:\n" + "Variable declarations that include initializers must have one of " + "the following storage classes: Output, Private, or Function\n" + " %5 = OpVariable %_ptr_Uniform_float Uniform %float_1\n")); +} + +TEST_F(ValidateMemory, WebGPUOutputStorageClassWithoutInitializerBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Output %float +%1 = OpVariable %float_ptr Output +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpVariable, '4[%4]', must have an initializer.\n" + "From WebGPU execution environment spec:\n" + "All variables in the following storage classes must have an " + "initializer: Output, Private, or Function\n" + " %4 = OpVariable %_ptr_Output_float Output\n")); +} + +TEST_F(ValidateMemory, WebGPUFunctionStorageClassWithoutInitializerBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = 
OpTypeFloat 32 +%float_ptr = OpTypePointer Function %float +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +%2 = OpVariable %float_ptr Function +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpVariable, '7[%7]', must have an initializer.\n" + "From WebGPU execution environment spec:\n" + "All variables in the following storage classes must have an " + "initializer: Output, Private, or Function\n" + " %7 = OpVariable %_ptr_Function_float Function\n")); +} + +TEST_F(ValidateMemory, WebGPUPrivateStorageClassWithoutInitializerBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Private %float +%1 = OpVariable %float_ptr Private +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpVariable, '4[%4]', must have an initializer.\n" + "From WebGPU execution environment spec:\n" + "All variables in the following storage classes must have an " + "initializer: Output, Private, or Function\n" + " %4 = OpVariable %_ptr_Private_float Private\n")); +} + +TEST_F(ValidateMemory, VulkanInitializerWithOutputStorageClassesGood) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Output %float +%init_val = OpConstant 
%float 1.0 +%1 = OpVariable %float_ptr Output %init_val +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanInitializerWithFunctionStorageClassesGood) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Function %float +%init_val = OpConstant %float 1.0 +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +%2 = OpVariable %float_ptr Function %init_val +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanInitializerWithPrivateStorageClassesGood) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Private %float +%init_val = OpConstant %float 1.0 +%1 = OpVariable %float_ptr Private %init_val +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanInitializerWithDisallowedStorageClassesBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%float = OpTypeFloat 32 +%float_ptr = OpTypePointer Input %float +%init_val = OpConstant %float 1.0 +%1 = OpVariable %float_ptr Input %init_val +%void = 
OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%2 = OpLabel +OpReturn +OpFunctionEnd +)"; + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpVariable, '5[%5]', has a disallowed initializer & storage " + "class combination.\nFrom Vulkan spec, Appendix A:\n" + "Variable declarations that include initializers must have one of " + "the following storage classes: Output, Private, or Function\n " + "%5 = OpVariable %_ptr_Input_float Input %float_1\n")); +} + +TEST_F(ValidateMemory, ArrayLenCorrectResultType) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %_runtimearr_float +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function__struct_7 Function + %11 = OpArrayLength %uint %10 0 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, ArrayLenIndexCorrectWith2Members) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %float %_runtimearr_float +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function__struct_7 Function + %11 = OpArrayLength %uint %10 1 + OpReturn + OpFunctionEnd + 
+)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, ArrayLenResultNotIntType) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_6 = OpTypeStruct %_runtimearr_float +%_ptr_Function__struct_6 = OpTypePointer Function %_struct_6 + %1 = OpFunction %void None %3 + %8 = OpLabel + %9 = OpVariable %_ptr_Function__struct_6 Function + %10 = OpArrayLength %float %9 0 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "The Result Type of OpArrayLength '10[%10]' must be OpTypeInt " + "with width 32 and signedness 0.\n %10 = OpArrayLength %float %9 " + "0\n")); +} + +TEST_F(ValidateMemory, ArrayLenResultNot32bits) { + std::string spirv = R"( + OpCapability Shader + OpCapability Int16 + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %ushort = OpTypeInt 16 0 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %_runtimearr_float +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function__struct_7 Function + %11 = OpArrayLength %ushort %10 0 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "The Result Type of OpArrayLength '11[%11]' must be OpTypeInt " + "with width 32 and signedness 0.\n %11 = OpArrayLength %ushort %10 " + "0\n")); +} + +TEST_F(ValidateMemory, 
ArrayLenResultSigned) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %int = OpTypeInt 32 1 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %_runtimearr_float +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function__struct_7 Function + %11 = OpArrayLength %int %10 0 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "The Result Type of OpArrayLength '11[%11]' must be OpTypeInt " + "with width 32 and signedness 0.\n %11 = OpArrayLength %int %10 " + "0\n")); +} + +TEST_F(ValidateMemory, ArrayLenInputNotStruct) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %_runtimearr_float +%_ptr_Function_float = OpTypePointer Function %float + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function_float Function + %11 = OpArrayLength %uint %10 0 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The Struture's type in OpArrayLength '11[%11]' " + "must be a pointer to an OpTypeStruct.")); +} + +TEST_F(ValidateMemory, ArrayLenInputLastMemberNoRTA) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = 
OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %float +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function__struct_7 Function + %11 = OpArrayLength %uint %10 0 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("The Struture's last member in OpArrayLength '11[%11]' " + "must be an OpTypeRuntimeArray.\n %11 = OpArrayLength %uint " + "%10 0\n")); +} + +TEST_F(ValidateMemory, ArrayLenInputLastMemberNoRTA2) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %_runtimearr_float %float +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function__struct_7 Function + %11 = OpArrayLength %uint %10 1 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("The Struture's last member in OpArrayLength '11[%11]' " + "must be an OpTypeRuntimeArray.\n %11 = OpArrayLength %uint " + "%10 1\n")); +} + +TEST_F(ValidateMemory, ArrayLenIndexNotLastMember) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %float 
%_runtimearr_float +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function__struct_7 Function + %11 = OpArrayLength %uint %10 0 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "The array member in OpArrayLength '11[%11]' must be an the " + "last member of the struct.\n %11 = OpArrayLength %uint %10 0\n")); +} + +TEST_F(ValidateMemory, ArrayLenIndexNotPointerToStruct) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 +%_runtimearr_float = OpTypeRuntimeArray %float + %_struct_7 = OpTypeStruct %float %_runtimearr_float +%_ptr_Function__struct_7 = OpTypePointer Function %_struct_7 + %1 = OpFunction %void None %3 + %9 = OpLabel + %10 = OpVariable %_ptr_Function__struct_7 Function + %11 = OpLoad %_struct_7 %10 + %12 = OpArrayLength %uint %11 0 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "The Struture's type in OpArrayLength '12[%12]' must be a " + "pointer to an OpTypeStruct.\n %12 = OpArrayLength %uint %11 0\n")); +} + +TEST_F(ValidateMemory, ArrayLenPointerIsAType) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %1 = OpFunction %void None %3 + %9 = OpLabel + %12 = OpArrayLength %uint %float 0 + OpReturn + OpFunctionEnd + +)"; + + CompileSuccessfully(spirv.c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, 
ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Operand 4[%float] cannot be a " + "type")); +} + +TEST_F(ValidateMemory, PushConstantNotStructGood) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %ptr = OpTypePointer PushConstant %float + %pc = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, VulkanPushConstantNotStructBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %ptr = OpTypePointer PushConstant %float + %pc = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("PushConstant OpVariable '6[%6]' has illegal " + "type.\nFrom Vulkan spec, section 14.5.1:\n" + "Such variables must be typed as OpTypeStruct, " + "or an array of this type")); +} + +TEST_F(ValidateMemory, VulkanPushConstant) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginUpperLeft + + OpDecorate %struct Block + OpMemberDecorate %struct 0 Offset 0 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + %float = OpTypeFloat 32 + %struct = OpTypeStruct %float + %ptr = OpTypePointer PushConstant %struct + %pc = OpVariable %ptr PushConstant + + %1 = OpFunction %void None %voidfn + %label = OpLabel 
+ OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeLoadBad1) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%load = OpLoad %int %var MakePointerVisibleKHR|NonPrivatePointerKHR %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeLoadBad2) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%load = OpLoad %int %var Aligned|MakePointerVisibleKHR|NonPrivatePointerKHR 4 %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR 
capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeLoadGood1) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%load = OpLoad %int %var MakePointerVisibleKHR|NonPrivatePointerKHR %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeLoadGood2) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +%load = OpLoad %int %var Aligned|MakePointerVisibleKHR|NonPrivatePointerKHR 4 %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeStoreBad1) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var = 
OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpStore %var %device MakePointerAvailableKHR|NonPrivatePointerKHR %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeStoreBad2) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpStore %var %device Aligned|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeStoreGood1) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction 
%void None %voidfn +%entry = OpLabel +OpStore %var %device MakePointerAvailableKHR|NonPrivatePointerKHR %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeStoreGood2) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpStore %var %device Aligned|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemoryBad1) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemory %var1 %var2 MakePointerAvailableKHR|NonPrivatePointerKHR %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the 
" + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemoryBad2) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemory %var1 %var2 Aligned|MakePointerVisibleKHR|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %device %workgroup +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemoryBad3) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemory %var1 %var2 Aligned|MakePointerVisibleKHR|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %workgroup %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); 
+ EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemoryGood1) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemory %var1 %var2 MakePointerAvailableKHR|NonPrivatePointerKHR %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemoryGood2) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 2 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemory %var1 %var2 Aligned|MakePointerVisibleKHR|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %device %workgroup +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, 
VulkanMemoryModelDeviceScopeCopyMemoryGood3) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 2 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemory %var1 %var2 Aligned|MakePointerVisibleKHR|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %workgroup %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemorySizedBad1) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpCapability Addresses +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemorySized %var1 %var2 %device MakePointerAvailableKHR|NonPrivatePointerKHR %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemorySizedBad2) { 
+ const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpCapability Addresses +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemorySized %var1 %var2 %device Aligned|MakePointerVisibleKHR|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %device %workgroup +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemorySizedBad3) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability Linkage +OpCapability Addresses +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemorySized %var1 %var2 %device Aligned|MakePointerVisibleKHR|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %workgroup %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + 
HasSubstr("Use of device scope with VulkanKHR memory model requires the " + "VulkanMemoryModelDeviceScopeKHR capability")); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemorySizedGood1) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpCapability Addresses +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemorySized %var1 %var2 %device MakePointerAvailableKHR|NonPrivatePointerKHR %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemorySizedGood2) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpCapability Addresses +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 2 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemorySized %var1 %var2 %device Aligned|MakePointerVisibleKHR|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %device %workgroup +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + 
+TEST_F(ValidateMemory, VulkanMemoryModelDeviceScopeCopyMemorySizedGood3) { + const std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpCapability VulkanMemoryModelDeviceScopeKHR +OpCapability Linkage +OpCapability Addresses +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +%void = OpTypeVoid +%int = OpTypeInt 32 0 +%device = OpConstant %int 1 +%workgroup = OpConstant %int 2 +%int_ptr_ssbo = OpTypePointer StorageBuffer %int +%var1 = OpVariable %int_ptr_ssbo StorageBuffer +%var2 = OpVariable %int_ptr_ssbo StorageBuffer +%voidfn = OpTypeFunction %void +%func = OpFunction %void None %voidfn +%entry = OpLabel +OpCopyMemorySized %var1 %var2 %device Aligned|MakePointerVisibleKHR|MakePointerAvailableKHR|NonPrivatePointerKHR 4 %workgroup %device +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); +} + +TEST_F(ValidateMemory, ArrayLengthStructIsLabel) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpName %20 "incorrect" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%uint = OpTypeInt 32 0 +%4 = OpFunction %void None %3 +%20 = OpLabel +%24 = OpArrayLength %uint %20 0 +%25 = OpLoad %v4float %24 +OpReturnValue %25 +OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 1[%incorrect] requires a type")); +} + +TEST_F(ValidateMemory, PSBLoadAlignedSuccess) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 AliasedPointerEXT +%uint64 = OpTypeInt 64 0 
+%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +%val2 = OpLoad %ptr %val1 +%val3 = OpLoad %uint64 %val2 Aligned 8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, PSBLoadAlignedMissing) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 AliasedPointerEXT +%uint64 = OpTypeInt 64 0 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +%val2 = OpLoad %ptr %val1 +%val3 = OpLoad %uint64 %val2 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Memory accesses with PhysicalStorageBufferEXT must use Aligned")); +} + +TEST_F(ValidateMemory, PSBStoreAlignedSuccess) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 AliasedPointerEXT +%uint64 = OpTypeInt 64 0 +%u64_1 = OpConstant %uint64 1 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None 
%voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +%val2 = OpLoad %ptr %val1 +OpStore %val2 %u64_1 Aligned 8 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMemory, PSBStoreAlignedMissing) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 AliasedPointerEXT +%uint64 = OpTypeInt 64 0 +%u64_1 = OpConstant %uint64 1 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%pptr_f = OpTypePointer Function %ptr +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +%val1 = OpVariable %pptr_f Function +%val2 = OpLoad %ptr %val1 +OpStore %val2 %u64_1 None +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "Memory accesses with PhysicalStorageBufferEXT must use Aligned")); +} + +TEST_F(ValidateMemory, PSBVariable) { + const std::string body = R"( +OpCapability PhysicalStorageBufferAddressesEXT +OpCapability Int64 +OpCapability Shader +OpExtension "SPV_EXT_physical_storage_buffer" +OpMemoryModel PhysicalStorageBuffer64EXT GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpDecorate %val1 AliasedPointerEXT +%uint64 = OpTypeInt 64 0 +%ptr = OpTypePointer PhysicalStorageBufferEXT %uint64 +%val1 = OpVariable %ptr PhysicalStorageBufferEXT +%void = OpTypeVoid +%voidfn = OpTypeFunction %void +%main = OpFunction %void None %voidfn +%entry = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + 
getDiagnosticString(), + HasSubstr("PhysicalStorageBufferEXT must not be used with OpVariable")); +} + +TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%sampler_t = OpTypeSampler +%array_t = OpTypeRuntimeArray %sampler_t +%array_ptr = OpTypePointer UniformConstant %array_t +%2 = OpVariable %array_ptr UniformConstant +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpVariable, '5[%5]', is attempting to create memory for an " + "illegal type, OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray " + "can only appear as the final member of an OpTypeStruct, thus cannot " + "be instantiated via OpVariable\n %5 = OpVariable " + "%_ptr_UniformConstant__runtimearr_2 UniformConstant\n")); +} + +TEST_F(ValidateMemory, WebGPURTAOutsideOfStructBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%sampler_t = OpTypeSampler +%array_t = OpTypeRuntimeArray %sampler_t +%array_ptr = OpTypePointer UniformConstant %array_t +%2 = OpVariable %array_ptr UniformConstant +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpVariable, '5[%5]', is attempting to create memory for an " + "illegal type, 
OpTypeRuntimeArray.\nFor WebGPU OpTypeRuntimeArray " + "can only appear as the final member of an OpTypeStruct, thus cannot " + "be instantiated via OpVariable\n %5 = OpVariable " + "%_ptr_UniformConstant__runtimearr_2 UniformConstant\n")); +} + +TEST_F(ValidateMemory, VulkanRTAOutsideOfStructWithRuntimeDescriptorArrayGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability RuntimeDescriptorArrayEXT +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%sampler_t = OpTypeSampler +%array_t = OpTypeRuntimeArray %sampler_t +%array_sb_ptr = OpTypePointer StorageBuffer %array_t +%2 = OpVariable %array_sb_ptr StorageBuffer +%array_uc_ptr = OpTypePointer UniformConstant %array_t +%3 = OpVariable %array_uc_ptr UniformConstant +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F( + ValidateMemory, + VulkanRTAOutsideOfStructWithRuntimeDescriptorArrayAndWrongStorageClassBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability RuntimeDescriptorArrayEXT +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%array_ptr = OpTypePointer Workgroup %array_t +%2 = OpVariable %array_ptr Workgroup +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("For Vulkan with RuntimeDescriptorArrayEXT, a variable " + "containing 
OpTypeRuntimeArray must have storage class of " + "StorageBuffer, Uniform, or UniformConstant.\n %5 = " + "OpVariable %_ptr_Workgroup__runtimearr_uint Workgroup\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideStorageBufferStructGood) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t +%2 = OpVariable %struct_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, WebGPURTAInsideStorageBufferStructGood) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t +%2 = OpVariable %struct_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateMemory, VulkanRTAInsideWrongStorageClassStructBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func 
"func" +OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer Workgroup %struct_t +%2 = OpVariable %struct_ptr Workgroup +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "For Vulkan, OpTypeStruct variables containing OpTypeRuntimeArray " + "must have storage class of StorageBuffer or Uniform.\n %6 = " + "OpVariable %_ptr_Workgroup__struct_4 Workgroup\n")); +} + +TEST_F(ValidateMemory, WebGPURTAInsideWrongStorageClassStructBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer Workgroup %struct_t +%2 = OpVariable %struct_ptr Workgroup +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("For WebGPU, OpTypeStruct variables containing " + "OpTypeRuntimeArray must have storage class of StorageBuffer\n " + " %6 = OpVariable %_ptr_Workgroup__struct_4 Workgroup\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideStorageBufferStructWithoutBlockBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 
32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t +%2 = OpVariable %struct_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For Vulkan, an OpTypeStruct variable containing an " + "OpTypeRuntimeArray must be decorated with Block if it " + "has storage class StorageBuffer.\n %6 = OpVariable " + "%_ptr_StorageBuffer__struct_4 StorageBuffer\n")); +} + +TEST_F(ValidateMemory, WebGPURTAInsideStorageBufferStructWithoutBlockBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t +%2 = OpVariable %struct_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For WebGPU, an OpTypeStruct variable containing an " + "OpTypeRuntimeArray must be decorated with Block if it " + "has storage class StorageBuffer.\n %6 = OpVariable " + "%_ptr_StorageBuffer__struct_4 StorageBuffer\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideUniformStructGood) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate 
%array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t BufferBlock +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer Uniform %struct_t +%2 = OpVariable %struct_ptr Uniform +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, WebGPURTAInsideUniformStructBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t BufferBlock +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer Uniform %struct_t +%2 = OpVariable %struct_ptr Uniform +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("For WebGPU, OpTypeStruct variables containing " + "OpTypeRuntimeArray must have storage class of StorageBuffer\n " + " %6 = OpVariable %_ptr_Uniform__struct_3 Uniform\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideUniformStructWithoutBufferBlockBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 32 0 +%array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer Uniform 
%struct_t +%2 = OpVariable %struct_ptr Uniform +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For Vulkan, an OpTypeStruct variable containing an " + "OpTypeRuntimeArray must be decorated with BufferBlock " + "if it has storage class Uniform.\n %6 = OpVariable " + "%_ptr_Uniform__struct_4 Uniform\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideRTABad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%sampler_t = OpTypeSampler +%inner_array_t = OpTypeRuntimeArray %sampler_t +%array_t = OpTypeRuntimeArray %inner_array_t +%array_ptr = OpTypePointer UniformConstant %array_t +%2 = OpVariable %array_ptr UniformConstant +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpTypeRuntimeArray Element Type '3[%_runtimearr_2]' is not " + "valid in Vulkan environment.\n %_runtimearr__runtimearr_2 = " + "OpTypeRuntimeArray %_runtimearr_2\n")); +} + +TEST_F(ValidateMemory, WebGPURTAInsideRTABad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%sampler_t = OpTypeSampler +%inner_array_t = OpTypeRuntimeArray %sampler_t +%array_t = OpTypeRuntimeArray %inner_array_t +%array_ptr = OpTypePointer UniformConstant %array_t +%2 = OpVariable %array_ptr 
UniformConstant +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpTypeRuntimeArray Element Type '3[%_runtimearr_2]' is not " + "valid in WebGPU environment.\n %_runtimearr__runtimearr_2 = " + "OpTypeRuntimeArray %_runtimearr_2\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideRTAWithRuntimeDescriptorArrayBad) { + std::string spirv = R"( +OpCapability RuntimeDescriptorArrayEXT +OpCapability Shader +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t Block +%uint_t = OpTypeInt 32 0 +%inner_array_t = OpTypeRuntimeArray %uint_t +%array_t = OpTypeRuntimeArray %inner_array_t +%array_ptr = OpTypePointer StorageBuffer %array_t +%2 = OpVariable %array_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpTypeRuntimeArray Element Type '4[%_runtimearr_uint]' is not " + "valid in Vulkan environment.\n %_runtimearr__runtimearr_uint = " + "OpTypeRuntimeArray %_runtimearr_uint\n")); +} + +TEST_F(ValidateMemory, + VulkanUniformStructInsideRTAWithRuntimeDescriptorArrayGood) { + std::string spirv = R"( +OpCapability RuntimeDescriptorArrayEXT +OpCapability Shader +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block 
+%uint_t = OpTypeInt 32 0 +%struct_t = OpTypeStruct %uint_t +%array_t = OpTypeRuntimeArray %struct_t +%array_ptr = OpTypePointer Uniform %array_t +%2 = OpVariable %array_ptr Uniform +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanRTAInsideRTAInsideStructBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block +%uint_t = OpTypeInt 32 0 +%inner_array_t = OpTypeRuntimeArray %uint_t +%array_t = OpTypeRuntimeArray %inner_array_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t +%2 = OpVariable %struct_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpTypeRuntimeArray Element Type '5[%_runtimearr_uint]' is not " + "valid in Vulkan environment.\n %_runtimearr__runtimearr_uint = " + "OpTypeRuntimeArray %_runtimearr_uint\n")); +} + +TEST_F(ValidateMemory, + VulkanRTAInsideRTAInsideStructWithRuntimeDescriptorArrayBad) { + std::string spirv = R"( +OpCapability RuntimeDescriptorArrayEXT +OpCapability Shader +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block +%uint_t = OpTypeInt 32 0 +%inner_array_t = 
OpTypeRuntimeArray %uint_t +%array_t = OpTypeRuntimeArray %inner_array_t +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t +%2 = OpVariable %struct_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpTypeRuntimeArray Element Type '5[%_runtimearr_uint]' is not " + "valid in Vulkan environment.\n %_runtimearr__runtimearr_uint = " + "OpTypeRuntimeArray %_runtimearr_uint\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideArrayBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 32 0 +%dim = OpConstant %uint_t 1 +%sampler_t = OpTypeSampler +%inner_array_t = OpTypeRuntimeArray %sampler_t +%array_t = OpTypeArray %inner_array_t %dim +%array_ptr = OpTypePointer UniformConstant %array_t +%2 = OpVariable %array_ptr UniformConstant +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTypeArray Element Type '5[%_runtimearr_4]' is not " + "valid in Vulkan environment.\n %_arr__runtimearr_4_uint_1 = " + "OpTypeArray %_runtimearr_4 %uint_1\n")); +} + +TEST_F(ValidateMemory, WebGPURTAInsideArrayBad) { + std::string spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%uint_t = OpTypeInt 32 0 +%dim 
= OpConstant %uint_t 1 +%sampler_t = OpTypeSampler +%inner_array_t = OpTypeRuntimeArray %sampler_t +%array_t = OpTypeArray %inner_array_t %dim +%array_ptr = OpTypePointer UniformConstant %array_t +%2 = OpVariable %array_ptr UniformConstant +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTypeArray Element Type '5[%_runtimearr_4]' is not " + "valid in WebGPU environment.\n %_arr__runtimearr_4_uint_1 = " + "OpTypeArray %_runtimearr_4 %uint_1\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideArrayWithRuntimeDescriptorArrayBad) { + std::string spirv = R"( +OpCapability RuntimeDescriptorArrayEXT +OpCapability Shader +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t Block +%uint_t = OpTypeInt 32 0 +%dim = OpConstant %uint_t 1 +%sampler_t = OpTypeSampler +%inner_array_t = OpTypeRuntimeArray %uint_t +%array_t = OpTypeRuntimeArray %inner_array_t +%array_ptr = OpTypePointer StorageBuffer %array_t +%2 = OpVariable %array_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "OpTypeRuntimeArray Element Type '6[%_runtimearr_uint]' is not " + "valid in Vulkan environment.\n %_runtimearr__runtimearr_uint = " + "OpTypeRuntimeArray %_runtimearr_uint\n")); +} + +TEST_F(ValidateMemory, VulkanRTAInsideArrayInsideStructBad) { + std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 
+OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block +%uint_t = OpTypeInt 32 0 +%dim = OpConstant %uint_t 1 +%inner_array_t = OpTypeRuntimeArray %uint_t +%array_t = OpTypeArray %inner_array_t %dim +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t +%2 = OpVariable %struct_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTypeArray Element Type '6[%_runtimearr_uint]' is not " + "valid in Vulkan environment.\n %_arr__runtimearr_uint_uint_1 " + "= OpTypeArray %_runtimearr_uint %uint_1\n")); +} + +TEST_F(ValidateMemory, + VulkanRTAInsideArrayInsideStructWithRuntimeDescriptorArrayBad) { + std::string spirv = R"( +OpCapability RuntimeDescriptorArrayEXT +OpCapability Shader +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block +%uint_t = OpTypeInt 32 0 +%dim = OpConstant %uint_t 1 +%inner_array_t = OpTypeRuntimeArray %uint_t +%array_t = OpTypeArray %inner_array_t %dim +%struct_t = OpTypeStruct %array_t +%struct_ptr = OpTypePointer StorageBuffer %struct_t +%2 = OpVariable %struct_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("OpTypeArray 
Element Type '6[%_runtimearr_uint]' is not " + "valid in Vulkan environment.\n %_arr__runtimearr_uint_uint_1 " + "= OpTypeArray %_runtimearr_uint %uint_1\n")); +} + +TEST_F(ValidateMemory, VulkanRTAStructInsideRTAWithRuntimeDescriptorArrayGood) { + std::string spirv = R"( +OpCapability RuntimeDescriptorArrayEXT +OpCapability Shader +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block +%uint_t = OpTypeInt 32 0 +%inner_array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %inner_array_t +%array_t = OpTypeRuntimeArray %struct_t +%array_ptr = OpTypePointer StorageBuffer %array_t +%2 = OpVariable %array_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidateMemory, VulkanRTAStructInsideArrayGood) { + std::string spirv = R"( +OpCapability RuntimeDescriptorArrayEXT +OpCapability Shader +OpExtension "SPV_EXT_descriptor_indexing" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +OpDecorate %array_t ArrayStride 4 +OpMemberDecorate %struct_t 0 Offset 0 +OpDecorate %struct_t Block +%uint_t = OpTypeInt 32 0 +%inner_array_t = OpTypeRuntimeArray %uint_t +%struct_t = OpTypeStruct %inner_array_t +%array_size = OpConstant %uint_t 5 +%array_t = OpTypeArray %struct_t %array_size +%array_ptr = OpTypePointer StorageBuffer %array_t +%2 = OpVariable %array_ptr StorageBuffer +%void = OpTypeVoid +%func_t = OpTypeFunction %void +%func = OpFunction %void None %func_t +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + 
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_modes_test.cpp b/test/val/val_modes_test.cpp new file mode 100644 index 000000000..7f1ef093d --- /dev/null +++ b/test/val/val_modes_test.cpp @@ -0,0 +1,801 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gmock/gmock.h" +#include "source/spirv_target_env.h" +#include "test/test_fixture.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::Combine; +using ::testing::HasSubstr; +using ::testing::Values; +using ::testing::ValuesIn; + +using ValidateMode = spvtest::ValidateBase; + +const std::string kVoidFunction = R"(%void = OpTypeVoid +%void_fn = OpTypeFunction %void +%main = OpFunction %void None %void_fn +%entry = OpLabel +OpReturn +OpFunctionEnd +)"; + +TEST_F(ValidateMode, GLComputeNoMode) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateMode, GLComputeNoModeVulkan) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +)" + kVoidFunction; + + spv_target_env env = 
SPV_ENV_VULKAN_1_0; + CompileSuccessfully(spirv, env); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("In the Vulkan environment, GLCompute execution model entry " + "points require either the LocalSize execution mode or an " + "object decorated with WorkgroupSize must be specified.")); +} + +TEST_F(ValidateMode, GLComputeNoModeVulkanWorkgroupSize) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpDecorate %int3_1 BuiltIn WorkgroupSize +%int = OpTypeInt 32 0 +%int3 = OpTypeVector %int 3 +%int_1 = OpConstant %int 1 +%int3_1 = OpConstantComposite %int3 %int_1 %int_1 %int_1 +)" + kVoidFunction; + + spv_target_env env = SPV_ENV_VULKAN_1_0; + CompileSuccessfully(spirv, env); + EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(env)); +} + +TEST_F(ValidateMode, GLComputeVulkanLocalSize) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 +)" + kVoidFunction; + + spv_target_env env = SPV_ENV_VULKAN_1_0; + CompileSuccessfully(spirv, env); + EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(env)); +} + +TEST_F(ValidateMode, FragmentOriginLowerLeftVulkan) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginLowerLeft +)" + kVoidFunction; + + spv_target_env env = SPV_ENV_VULKAN_1_0; + CompileSuccessfully(spirv, env); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("In the Vulkan environment, the OriginLowerLeft " + "execution mode must not be used.")); +} + +TEST_F(ValidateMode, FragmentPixelCenterIntegerVulkan) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft 
+OpExecutionMode %main PixelCenterInteger +)" + kVoidFunction; + + spv_target_env env = SPV_ENV_VULKAN_1_0; + CompileSuccessfully(spirv, env); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions(env)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("In the Vulkan environment, the PixelCenterInteger " + "execution mode must not be used.")); +} + +TEST_F(ValidateMode, GeometryNoOutputMode) { + const std::string spirv = R"( +OpCapability Geometry +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" +OpExecutionMode %main InputPoints +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Geometry execution model entry points must specify " + "exactly one of OutputPoints, OutputLineStrip or " + "OutputTriangleStrip execution modes.")); +} + +TEST_F(ValidateMode, GeometryNoInputMode) { + const std::string spirv = R"( +OpCapability Geometry +OpMemoryModel Logical GLSL450 +OpEntryPoint Geometry %main "main" +OpExecutionMode %main OutputPoints +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Geometry execution model entry points must specify exactly " + "one of InputPoints, InputLines, InputLinesAdjacency, " + "Triangles or InputTrianglesAdjacency execution modes.")); +} + +TEST_F(ValidateMode, FragmentNoOrigin) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Fragment execution model entry points require either an " + "OriginUpperLeft or OriginLowerLeft execution mode.")); +} + +TEST_F(ValidateMode, FragmentBothOrigins) { + const std::string spirv = R"( +OpCapability Shader 
+OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpExecutionMode %main OriginLowerLeft +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Fragment execution model entry points can only specify one of " + "OriginUpperLeft or OriginLowerLeft execution modes.")); +} + +TEST_F(ValidateMode, FragmentDepthGreaterAndLess) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpExecutionMode %main DepthGreater +OpExecutionMode %main DepthLess +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Fragment execution model entry points can specify at " + "most one of DepthGreater, DepthLess or DepthUnchanged " + "execution modes.")); +} + +TEST_F(ValidateMode, FragmentDepthGreaterAndUnchanged) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpExecutionMode %main DepthGreater +OpExecutionMode %main DepthUnchanged +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Fragment execution model entry points can specify at " + "most one of DepthGreater, DepthLess or DepthUnchanged " + "execution modes.")); +} + +TEST_F(ValidateMode, FragmentDepthLessAndUnchanged) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpExecutionMode %main DepthLess +OpExecutionMode %main DepthUnchanged +)" + kVoidFunction; + + CompileSuccessfully(spirv); + 
EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Fragment execution model entry points can specify at " + "most one of DepthGreater, DepthLess or DepthUnchanged " + "execution modes.")); +} + +TEST_F(ValidateMode, FragmentAllDepths) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpExecutionMode %main DepthGreater +OpExecutionMode %main DepthLess +OpExecutionMode %main DepthUnchanged +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Fragment execution model entry points can specify at " + "most one of DepthGreater, DepthLess or DepthUnchanged " + "execution modes.")); +} + +TEST_F(ValidateMode, TessellationControlSpacingEqualAndFractionalOdd) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %main "main" +OpExecutionMode %main SpacingEqual +OpExecutionMode %main SpacingFractionalOdd +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Tessellation execution model entry points can specify " + "at most one of SpacingEqual, SpacingFractionalOdd or " + "SpacingFractionalEven execution modes.")); +} + +TEST_F(ValidateMode, TessellationControlSpacingEqualAndSpacingFractionalEven) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %main "main" +OpExecutionMode %main SpacingEqual +OpExecutionMode %main SpacingFractionalEven +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Tessellation execution model entry points 
can specify " + "at most one of SpacingEqual, SpacingFractionalOdd or " + "SpacingFractionalEven execution modes.")); +} + +TEST_F(ValidateMode, + TessellationControlSpacingFractionalOddAndSpacingFractionalEven) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %main "main" +OpExecutionMode %main SpacingFractionalOdd +OpExecutionMode %main SpacingFractionalEven +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Tessellation execution model entry points can specify " + "at most one of SpacingEqual, SpacingFractionalOdd or " + "SpacingFractionalEven execution modes.")); +} + +TEST_F(ValidateMode, TessellationControlAllSpacing) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationControl %main "main" +OpExecutionMode %main SpacingEqual +OpExecutionMode %main SpacingFractionalOdd +OpExecutionMode %main SpacingFractionalEven +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Tessellation execution model entry points can specify " + "at most one of SpacingEqual, SpacingFractionalOdd or " + "SpacingFractionalEven execution modes.")); +} + +TEST_F(ValidateMode, + TessellationEvaluationSpacingEqualAndSpacingFractionalOdd) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationEvaluation %main "main" +OpExecutionMode %main SpacingEqual +OpExecutionMode %main SpacingFractionalOdd +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Tessellation execution model entry points can specify " + "at most one of SpacingEqual, SpacingFractionalOdd 
or " + "SpacingFractionalEven execution modes.")); +} + +TEST_F(ValidateMode, + TessellationEvaluationSpacingEqualAndSpacingFractionalEven) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationEvaluation %main "main" +OpExecutionMode %main SpacingEqual +OpExecutionMode %main SpacingFractionalEven +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Tessellation execution model entry points can specify " + "at most one of SpacingEqual, SpacingFractionalOdd or " + "SpacingFractionalEven execution modes.")); +} + +TEST_F(ValidateMode, + TessellationEvaluationSpacingFractionalOddAndSpacingFractionalEven) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationEvaluation %main "main" +OpExecutionMode %main SpacingFractionalOdd +OpExecutionMode %main SpacingFractionalEven +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Tessellation execution model entry points can specify " + "at most one of SpacingEqual, SpacingFractionalOdd or " + "SpacingFractionalEven execution modes.")); +} + +TEST_F(ValidateMode, TessellationEvaluationAllSpacing) { + const std::string spirv = R"( +OpCapability Tessellation +OpMemoryModel Logical GLSL450 +OpEntryPoint TessellationEvaluation %main "main" +OpExecutionMode %main SpacingEqual +OpExecutionMode %main SpacingFractionalOdd +OpExecutionMode %main SpacingFractionalEven +)" + kVoidFunction; + + CompileSuccessfully(spirv); + EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Tessellation execution model entry points can specify " + "at most one of SpacingEqual, SpacingFractionalOdd or " + "SpacingFractionalEven execution modes.")); +} 
+
+// A tessellation entry point that declares both VertexOrderCw and
+// VertexOrderCcw is expected to fail validation.
+TEST_F(ValidateMode, TessellationControlBothVertex) {
+  const std::string spirv = R"(
+OpCapability Tessellation
+OpMemoryModel Logical GLSL450
+OpEntryPoint TessellationControl %main "main"
+OpExecutionMode %main VertexOrderCw
+OpExecutionMode %main VertexOrderCcw
+)" + kVoidFunction;
+
+  CompileSuccessfully(spirv);
+  EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Tessellation execution model entry points can specify at most "
+                "one of VertexOrderCw or VertexOrderCcw execution modes."));
+}
+
+// Same check as above for the TessellationEvaluation execution model.
+TEST_F(ValidateMode, TessellationEvaluationBothVertex) {
+  const std::string spirv = R"(
+OpCapability Tessellation
+OpMemoryModel Logical GLSL450
+OpEntryPoint TessellationEvaluation %main "main"
+OpExecutionMode %main VertexOrderCw
+OpExecutionMode %main VertexOrderCcw
+)" + kVoidFunction;
+
+  CompileSuccessfully(spirv);
+  EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Tessellation execution model entry points can specify at most "
+                "one of VertexOrderCw or VertexOrderCcw execution modes."));
+}
+
+// Parameter: (five candidate input modes, three candidate output modes).
+// An empty string in any slot means that mode is not emitted.
+using ValidateModeGeometry = spvtest::ValidateBase<std::tuple<
+    std::tuple<std::string, std::string, std::string, std::string, std::string>,
+    std::tuple<std::string, std::string, std::string>>>;
+
+// Builds a Geometry entry point with every non-empty mode of the parameter and
+// expects success only when exactly one input mode and exactly one output mode
+// were emitted.
+TEST_P(ValidateModeGeometry, ExecutionMode) {
+  std::vector<std::string> input_modes;
+  std::vector<std::string> output_modes;
+  input_modes.push_back(std::get<0>(std::get<0>(GetParam())));
+  input_modes.push_back(std::get<1>(std::get<0>(GetParam())));
+  input_modes.push_back(std::get<2>(std::get<0>(GetParam())));
+  input_modes.push_back(std::get<3>(std::get<0>(GetParam())));
+  input_modes.push_back(std::get<4>(std::get<0>(GetParam())));
+  output_modes.push_back(std::get<0>(std::get<1>(GetParam())));
+  output_modes.push_back(std::get<1>(std::get<1>(GetParam())));
+  output_modes.push_back(std::get<2>(std::get<1>(GetParam())));
+
+  std::ostringstream sstr;
+  sstr << "OpCapability Geometry\n";
+  sstr << "OpMemoryModel Logical GLSL450\n";
+  sstr << "OpEntryPoint Geometry %main \"main\"\n";
+  size_t num_input_modes = 0;
+  for (auto input : input_modes) {
+    if (!input.empty()) {
+      num_input_modes++;
+      sstr << "OpExecutionMode %main " << input << "\n";
+    }
+  }
+  size_t num_output_modes = 0;
+  for (auto output : output_modes) {
+    if (!output.empty()) {
+      num_output_modes++;
+      sstr << "OpExecutionMode %main " << output << "\n";
+    }
+  }
+  sstr << "%void = OpTypeVoid\n";
+  sstr << "%void_fn = OpTypeFunction %void\n";
+  sstr << "%int = OpTypeInt 32 0\n";
+  sstr << "%int1 = OpConstant %int 1\n";
+  sstr << "%main = OpFunction %void None %void_fn\n";
+  sstr << "%entry = OpLabel\n";
+  sstr << "OpReturn\n";
+  sstr << "OpFunctionEnd\n";
+
+  CompileSuccessfully(sstr.str());
+  if (num_input_modes == 1 && num_output_modes == 1) {
+    EXPECT_THAT(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+    // When both counts are wrong only the input-mode diagnostic is checked.
+    if (num_input_modes != 1) {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("Geometry execution model entry points must "
+                            "specify exactly one of InputPoints, InputLines, "
+                            "InputLinesAdjacency, Triangles or "
+                            "InputTrianglesAdjacency execution modes."));
+    } else {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("Geometry execution model entry points must specify "
+                    "exactly one of OutputPoints, OutputLineStrip or "
+                    "OutputTriangleStrip execution modes."));
+    }
+  }
+}
+
+// Every present/absent combination of the five input and three output modes.
+INSTANTIATE_TEST_CASE_P(
+    GeometryRequiredModes, ValidateModeGeometry,
+    Combine(Combine(Values("InputPoints", ""), Values("InputLines", ""),
+                    Values("InputLinesAdjacency", ""), Values("Triangles", ""),
+                    Values("InputTrianglesAdjacency", "")),
+            Combine(Values("OutputPoints", ""), Values("OutputLineStrip", ""),
+                    Values("OutputTriangleStrip", ""))));
+
+// Parameter: (expected result, expected diagnostic substring, execution model,
+// execution mode including operands, target environment).
+using ValidateModeExecution =
+    spvtest::ValidateBase<std::tuple<spv_result_t, std::string, std::string,
+                                     std::string, spv_target_env>>;
+
+// Declares a single entry point of the given model with the given mode and
+// checks the validation result; the diagnostic is compared only when the
+// expectation is not SPV_SUCCESS.
+TEST_P(ValidateModeExecution, ExecutionMode) {
+  const spv_result_t expectation = std::get<0>(GetParam());
+  const std::string error = std::get<1>(GetParam());
+  const std::string model = std::get<2>(GetParam());
+  const std::string mode = std::get<3>(GetParam());
+  const spv_target_env env = std::get<4>(GetParam());
+
+  std::ostringstream sstr;
+  sstr << "OpCapability Shader\n";
+  sstr << "OpCapability Geometry\n";
+  sstr << "OpCapability Tessellation\n";
+  sstr << "OpCapability TransformFeedback\n";
+  if (!spvIsVulkanEnv(env)) {
+    sstr << "OpCapability Kernel\n";
+    if (env == SPV_ENV_UNIVERSAL_1_3) {
+      sstr << "OpCapability SubgroupDispatch\n";
+    }
+  }
+  sstr << "OpMemoryModel Logical GLSL450\n";
+  sstr << "OpEntryPoint " << model << " %main \"main\"\n";
+  // Modes whose text starts with one of the *Id names are emitted with
+  // OpExecutionModeId; everything else uses OpExecutionMode.
+  if (mode.find("LocalSizeId") == 0 || mode.find("LocalSizeHintId") == 0 ||
+      mode.find("SubgroupsPerWorkgroupId") == 0) {
+    sstr << "OpExecutionModeId %main " << mode << "\n";
+  } else {
+    sstr << "OpExecutionMode %main " << mode << "\n";
+  }
+  if (model == "Geometry") {
+    if (!(mode.find("InputPoints") == 0 || mode.find("InputLines") == 0 ||
+          mode.find("InputLinesAdjacency") == 0 ||
+          mode.find("Triangles") == 0 ||
+          mode.find("InputTrianglesAdjacency") == 0)) {
+      // Exactly one of the above modes is required for Geometry shaders.
+      sstr << "OpExecutionMode %main InputPoints\n";
+    }
+    if (!(mode.find("OutputPoints") == 0 || mode.find("OutputLineStrip") == 0 ||
+          mode.find("OutputTriangleStrip") == 0)) {
+      // Exactly one of the above modes is required for Geometry shaders.
+      sstr << "OpExecutionMode %main OutputPoints\n";
+    }
+  } else if (model == "Fragment") {
+    if (!(mode.find("OriginUpperLeft") == 0 ||
+          mode.find("OriginLowerLeft") == 0)) {
+      // Exactly one of the above modes is required for Fragment shaders.
+      sstr << "OpExecutionMode %main OriginUpperLeft\n";
+    }
+  }
+  sstr << "%void = OpTypeVoid\n";
+  sstr << "%void_fn = OpTypeFunction %void\n";
+  sstr << "%int = OpTypeInt 32 0\n";
+  sstr << "%int1 = OpConstant %int 1\n";
+  sstr << "%main = OpFunction %void None %void_fn\n";
+  sstr << "%entry = OpLabel\n";
+  sstr << "OpReturn\n";
+  sstr << "OpFunctionEnd\n";
+
+  CompileSuccessfully(sstr.str(), env);
+  EXPECT_THAT(expectation, ValidateInstructions(env));
+  if (expectation != SPV_SUCCESS) {
+    EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeGeometryOnlyGoodSpv10, ValidateModeExecution,
+    Combine(Values(SPV_SUCCESS), Values(""), Values("Geometry"),
+            Values("Invocations 3", "InputPoints", "InputLines",
+                   "InputLinesAdjacency", "InputTrianglesAdjacency",
+                   "OutputPoints", "OutputLineStrip", "OutputTriangleStrip"),
+            Values(SPV_ENV_UNIVERSAL_1_0)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeGeometryOnlyBadSpv10, ValidateModeExecution,
+    Combine(Values(SPV_ERROR_INVALID_DATA),
+            Values("Execution mode can only be used with the Geometry "
+                   "execution model."),
+            Values("Fragment", "TessellationEvaluation", "TessellationControl",
+                   "GLCompute", "Vertex", "Kernel"),
+            Values("Invocations 3", "InputPoints", "InputLines",
+                   "InputLinesAdjacency", "InputTrianglesAdjacency",
+                   "OutputPoints", "OutputLineStrip", "OutputTriangleStrip"),
+            Values(SPV_ENV_UNIVERSAL_1_0)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeTessellationOnlyGoodSpv10, ValidateModeExecution,
+    Combine(Values(SPV_SUCCESS), Values(""),
+            Values("TessellationControl", "TessellationEvaluation"),
+            Values("SpacingEqual", "SpacingFractionalEven",
+                   "SpacingFractionalOdd", "VertexOrderCw", "VertexOrderCcw",
+                   "PointMode", "Quads", "Isolines"),
+            Values(SPV_ENV_UNIVERSAL_1_0)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeTessellationOnlyBadSpv10, ValidateModeExecution,
+    Combine(Values(SPV_ERROR_INVALID_DATA),
+            Values("Execution mode can only be used with a tessellation "
+                   "execution model."),
+            Values("Fragment", "Geometry", "GLCompute", "Vertex", "Kernel"),
+            Values("SpacingEqual", "SpacingFractionalEven",
+                   "SpacingFractionalOdd", "VertexOrderCw", "VertexOrderCcw",
+                   "PointMode", "Quads", "Isolines"),
+            Values(SPV_ENV_UNIVERSAL_1_0)));
+
+INSTANTIATE_TEST_CASE_P(ValidateModeGeometryAndTessellationGoodSpv10,
+                        ValidateModeExecution,
+                        Combine(Values(SPV_SUCCESS), Values(""),
+                                Values("TessellationControl",
+                                       "TessellationEvaluation", "Geometry"),
+                                Values("Triangles", "OutputVertices 3"),
+                                Values(SPV_ENV_UNIVERSAL_1_0)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeGeometryAndTessellationBadSpv10, ValidateModeExecution,
+    Combine(Values(SPV_ERROR_INVALID_DATA),
+            Values("Execution mode can only be used with a Geometry or "
+                   "tessellation execution model."),
+            Values("Fragment", "GLCompute", "Vertex", "Kernel"),
+            Values("Triangles", "OutputVertices 3"),
+            Values(SPV_ENV_UNIVERSAL_1_0)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeFragmentOnlyGoodSpv10, ValidateModeExecution,
+    Combine(Values(SPV_SUCCESS), Values(""), Values("Fragment"),
+            Values("PixelCenterInteger", "OriginUpperLeft", "OriginLowerLeft",
+                   "EarlyFragmentTests", "DepthReplacing", "DepthLess",
+                   "DepthUnchanged"),
+            Values(SPV_ENV_UNIVERSAL_1_0)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeFragmentOnlyBadSpv10, ValidateModeExecution,
+    Combine(Values(SPV_ERROR_INVALID_DATA),
+            Values("Execution mode can only be used with the Fragment "
+                   "execution model."),
+            Values("Geometry", "TessellationControl", "TessellationEvaluation",
+                   "GLCompute", "Vertex", "Kernel"),
+            Values("PixelCenterInteger", "OriginUpperLeft", "OriginLowerLeft",
+                   "EarlyFragmentTests", "DepthReplacing", "DepthLess",
+                   "DepthUnchanged"),
+            Values(SPV_ENV_UNIVERSAL_1_0)));
+
+INSTANTIATE_TEST_CASE_P(ValidateModeKernelOnlyGoodSpv13, ValidateModeExecution,
+                        Combine(Values(SPV_SUCCESS), Values(""),
+                                Values("Kernel"),
+                                Values("LocalSizeHint 1 1 1", "VecTypeHint 4",
+                                       "ContractionOff",
+                                       "LocalSizeHintId %int1"),
+                                Values(SPV_ENV_UNIVERSAL_1_3)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeKernelOnlyBadSpv13, ValidateModeExecution,
+    Combine(
+        Values(SPV_ERROR_INVALID_DATA),
+        Values(
+            "Execution mode can only be used with the Kernel execution model."),
+        Values("Geometry", "TessellationControl", "TessellationEvaluation",
+               "GLCompute", "Vertex", "Fragment"),
+        Values("LocalSizeHint 1 1 1", "VecTypeHint 4", "ContractionOff",
+               "LocalSizeHintId %int1"),
+        Values(SPV_ENV_UNIVERSAL_1_3)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeGLComputeAndKernelGoodSpv13, ValidateModeExecution,
+    Combine(Values(SPV_SUCCESS), Values(""), Values("Kernel", "GLCompute"),
+            Values("LocalSize 1 1 1", "LocalSizeId %int1 %int1 %int1"),
+            Values(SPV_ENV_UNIVERSAL_1_3)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeGLComputeAndKernelBadSpv13, ValidateModeExecution,
+    Combine(Values(SPV_ERROR_INVALID_DATA),
+            Values("Execution mode can only be used with a Kernel or GLCompute "
+                   "execution model."),
+            Values("Geometry", "TessellationControl", "TessellationEvaluation",
+                   "Fragment", "Vertex"),
+            Values("LocalSize 1 1 1", "LocalSizeId %int1 %int1 %int1"),
+            Values(SPV_ENV_UNIVERSAL_1_3)));
+
+INSTANTIATE_TEST_CASE_P(
+    ValidateModeAllGoodSpv13, ValidateModeExecution,
+    Combine(Values(SPV_SUCCESS), Values(""),
+            Values("Kernel", "GLCompute", "Geometry", "TessellationControl",
+                   "TessellationEvaluation", "Fragment", "Vertex"),
+            Values("Xfb", "Initializer", "Finalizer", "SubgroupSize 1",
+                   "SubgroupsPerWorkgroup 1", "SubgroupsPerWorkgroupId %int1"),
+            Values(SPV_ENV_UNIVERSAL_1_3)));
+
+// LocalSize on a MeshNV entry point is expected to validate.
+TEST_F(ValidateModeExecution, MeshNVLocalSize) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpCapability MeshShadingNV
+OpExtension "SPV_NV_mesh_shader"
+OpMemoryModel Logical GLSL450
+OpEntryPoint MeshNV %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+)" + kVoidFunction;
+
+  CompileSuccessfully(spirv);
+  EXPECT_THAT(SPV_SUCCESS, ValidateInstructions());
+}
+
+// LocalSize on a TaskNV entry point is expected to validate.
+TEST_F(ValidateModeExecution, TaskNVLocalSize) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpCapability MeshShadingNV
+OpExtension "SPV_NV_mesh_shader"
+OpMemoryModel Logical GLSL450
+OpEntryPoint TaskNV %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+)" + kVoidFunction;
+
+  CompileSuccessfully(spirv);
+  EXPECT_THAT(SPV_SUCCESS, ValidateInstructions());
+}
+
+// OutputPoints on a MeshNV entry point is expected to validate.
+TEST_F(ValidateModeExecution, MeshNVOutputPoints) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpCapability MeshShadingNV
+OpExtension "SPV_NV_mesh_shader"
+OpMemoryModel Logical GLSL450
+OpEntryPoint MeshNV %main "main"
+OpExecutionMode %main OutputPoints
+)" + kVoidFunction;
+
+  CompileSuccessfully(spirv);
+  EXPECT_THAT(SPV_SUCCESS, ValidateInstructions());
+}
+
+// OutputVertices on a MeshNV entry point is expected to validate.
+TEST_F(ValidateModeExecution, MeshNVOutputVertices) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpCapability MeshShadingNV
+OpExtension "SPV_NV_mesh_shader"
+OpMemoryModel Logical GLSL450
+OpEntryPoint MeshNV %main "main"
+OpExecutionMode %main OutputVertices 42
+)" + kVoidFunction;
+
+  CompileSuccessfully(spirv);
+  EXPECT_THAT(SPV_SUCCESS, ValidateInstructions());
+}
+
+// LocalSizeId (via OpExecutionModeId) on a MeshNV entry point is expected to
+// validate; the test compiles and validates against SPV_ENV_UNIVERSAL_1_3.
+TEST_F(ValidateModeExecution, MeshNVLocalSizeId) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpCapability MeshShadingNV
+OpExtension "SPV_NV_mesh_shader"
+OpMemoryModel Logical GLSL450
+OpEntryPoint MeshNV %main "main"
+OpExecutionModeId %main LocalSizeId %int_1 %int_1 %int_1
+%int = OpTypeInt 32 0
+%int_1 = OpConstant %int 1
+)" + kVoidFunction;
+
+  spv_target_env env = SPV_ENV_UNIVERSAL_1_3;
+  CompileSuccessfully(spirv, env);
+  EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(env));
+}
+
+// Same as MeshNVLocalSizeId, for the TaskNV execution model.
+TEST_F(ValidateModeExecution, TaskNVLocalSizeId) {
+  const std::string spirv = R"(
+OpCapability Shader
+OpCapability MeshShadingNV
+OpExtension "SPV_NV_mesh_shader"
+OpMemoryModel Logical GLSL450
+OpEntryPoint TaskNV %main "main"
+OpExecutionModeId %main LocalSizeId %int_1 %int_1 %int_1
+%int = OpTypeInt 32 0
+%int_1 = OpConstant %int 1
+)" + kVoidFunction;
+
+  spv_target_env env = SPV_ENV_UNIVERSAL_1_3;
+  CompileSuccessfully(spirv, env);
+  EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(env));
+}
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/val_non_uniform_test.cpp b/test/val/val_non_uniform_test.cpp
new file mode 100644
index 000000000..a185d8782
--- /dev/null
+++ b/test/val/val_non_uniform_test.cpp
@@ -0,0 +1,293 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sstream>
+#include <string>
+#include <tuple>
+
+#include "gmock/gmock.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::Combine;
+using ::testing::HasSubstr;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+// Returns a SPIR-V assembly module whose %main function contains |body|.
+// The module declares Shader plus the GroupNonUniform* capabilities listed
+// below, then |capabilities_and_extensions|, an entry point of
+// |execution_model| (with LocalSize 1 1 1 when that is "GLCompute"), and the
+// shared types and constants the tests refer to by name.
+std::string GenerateShaderCode(
+    const std::string& body,
+    const std::string& capabilities_and_extensions = "",
+    const std::string& execution_model = "GLCompute") {
+  std::ostringstream ss;
+  ss << R"(
+OpCapability Shader
+OpCapability GroupNonUniform
+OpCapability GroupNonUniformVote
+OpCapability GroupNonUniformBallot
+OpCapability GroupNonUniformShuffle
+OpCapability GroupNonUniformShuffleRelative
+OpCapability GroupNonUniformArithmetic
+OpCapability GroupNonUniformClustered
+OpCapability GroupNonUniformQuad
+)";
+
+  ss << capabilities_and_extensions;
+  ss << "OpMemoryModel Logical GLSL450\n";
+  ss << "OpEntryPoint " << execution_model << " %main \"main\"\n";
+  if (execution_model == "GLCompute") {
+    ss << "OpExecutionMode %main LocalSize 1 1 1\n";
+  }
+
+  ss << R"(
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%u32 = OpTypeInt 32 0
+%int = OpTypeInt 32 1
+%float = OpTypeFloat 32
+%u32vec4 = OpTypeVector %u32 4
+%u32vec3 = OpTypeVector %u32 3
+
+%true = OpConstantTrue %bool
+%false = OpConstantFalse %bool
+
+%u32_0 = OpConstant %u32 0
+
+%float_0 = OpConstant %float 0
+
+%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
+%u32vec3_null = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0
+
+%cross_device = OpConstant %u32 0
+%device = OpConstant %u32 1
+%workgroup = OpConstant %u32 2
+%subgroup = OpConstant %u32 3
+%invocation = OpConstant %u32 4
+
+%reduce = OpConstant %u32 0
+%inclusive_scan = OpConstant %u32 1
+%exclusive_scan = OpConstant %u32 2
+%clustered_reduce = OpConstant %u32 3
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+)";
+
+  ss << body;
+
+  ss << R"(
+OpReturn
+OpFunctionEnd)";
+
+  return ss.str();
+}
+
+// Execution scopes every scope-parameterized instantiation iterates over.
+SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup,
+                     SpvScopeSubgroup, SpvScopeInvocation};
+
+// Parameter: (opcode, result type id, execution scope, remaining operands,
+// expected diagnostic substring). An empty diagnostic means the outcome
+// depends only on the execution scope.
+using GroupNonUniform = spvtest::ValidateBase<
+    std::tuple<std::string, std::string, SpvScope, std::string, std::string>>;
+
+// Maps a scope to the name of the matching constant declared by
+// GenerateShaderCode; returns an empty string for any other value.
+std::string ConvertScope(SpvScope scope) {
+  switch (scope) {
+    case SpvScopeCrossDevice:
+      return "%cross_device";
+    case SpvScopeDevice:
+      return "%device";
+    case SpvScopeWorkgroup:
+      return "%workgroup";
+    case SpvScopeSubgroup:
+      return "%subgroup";
+    case SpvScopeInvocation:
+      return "%invocation";
+    default:
+      return "";
+  }
+}
+
+// Under SPV_ENV_VULKAN_1_1, with no expected diagnostic, only Subgroup scope
+// is expected to validate; every other scope must report the Vulkan scope
+// limitation.
+TEST_P(GroupNonUniform, Vulkan1p1) {
+  std::string opcode = std::get<0>(GetParam());
+  std::string type = std::get<1>(GetParam());
+  SpvScope execution_scope = std::get<2>(GetParam());
+  std::string args = std::get<3>(GetParam());
+  std::string error = std::get<4>(GetParam());
+
+  std::ostringstream sstr;
+  sstr << "%result = " << opcode << " ";
+  sstr << type << " ";
+  sstr << ConvertScope(execution_scope) << " ";
+  sstr << args << "\n";
+
+  CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_VULKAN_1_1);
+  spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1);
+  if (error == "") {
+    if (execution_scope == SpvScopeSubgroup) {
+      EXPECT_EQ(SPV_SUCCESS, result);
+    } else {
+      EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr(
+              "in Vulkan environment Execution scope is limited to Subgroup"));
+    }
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
+    EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
+  }
+}
+
+// Under SPV_ENV_UNIVERSAL_1_3, with no expected diagnostic, Subgroup and
+// Workgroup scopes are expected to validate; other scopes must be rejected.
+TEST_P(GroupNonUniform, Spirv1p3) {
+  std::string opcode = std::get<0>(GetParam());
+  std::string type = std::get<1>(GetParam());
+  SpvScope execution_scope = std::get<2>(GetParam());
+  std::string args = std::get<3>(GetParam());
+  std::string error = std::get<4>(GetParam());
+
+  std::ostringstream sstr;
+  sstr << "%result = " << opcode << " ";
+  sstr << type << " ";
+  sstr << ConvertScope(execution_scope) << " ";
+  sstr << args << "\n";
+
+  CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_UNIVERSAL_1_3);
+  spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3);
+  if (error == "") {
+    if (execution_scope == SpvScopeSubgroup ||
+        execution_scope == SpvScopeWorkgroup) {
+      EXPECT_EQ(SPV_SUCCESS, result);
+    } else {
+      EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("Execution scope is limited to Subgroup or Workgroup"));
+    }
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
+    EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformElect, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformElect"),
+                                Values("%bool"), ValuesIn(scopes), Values(""),
+                                Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformVote, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformAll",
+                                       "OpGroupNonUniformAny",
+                                       "OpGroupNonUniformAllEqual"),
+                                Values("%bool"), ValuesIn(scopes),
+                                Values("%true"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcast, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBroadcast"),
+                                Values("%bool"), ValuesIn(scopes),
+                                Values("%true %u32_0"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcastFirst, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBroadcastFirst"),
+                                Values("%bool"), ValuesIn(scopes),
+                                Values("%true"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallot, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBallot"),
+                                Values("%u32vec4"), ValuesIn(scopes),
+                                Values("%true"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformInverseBallot, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformInverseBallot"),
+                                Values("%bool"), ValuesIn(scopes),
+                                Values("%u32vec4_null"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitExtract, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBallotBitExtract"),
+                                Values("%bool"), ValuesIn(scopes),
+                                Values("%u32vec4_null %u32_0"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCount, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBallotBitCount"),
+                                Values("%u32"), ValuesIn(scopes),
+                                Values("Reduce %u32vec4_null"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotFind, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBallotFindLSB",
+                                       "OpGroupNonUniformBallotFindMSB"),
+                                Values("%u32"), ValuesIn(scopes),
+                                Values("%u32vec4_null"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformShuffle, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformShuffle",
+                                       "OpGroupNonUniformShuffleXor",
+                                       "OpGroupNonUniformShuffleUp",
+                                       "OpGroupNonUniformShuffleDown"),
+                                Values("%u32"), ValuesIn(scopes),
+                                Values("%u32_0 %u32_0"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(
+    GroupNonUniformIntegerArithmetic, GroupNonUniform,
+    Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul",
+                   "OpGroupNonUniformSMin", "OpGroupNonUniformUMin",
+                   "OpGroupNonUniformSMax", "OpGroupNonUniformUMax",
+                   "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr",
+                   "OpGroupNonUniformBitwiseXor"),
+            Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0"),
+            Values("")));
+
+INSTANTIATE_TEST_CASE_P(
+    GroupNonUniformFloatArithmetic, GroupNonUniform,
+    Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul",
+                   "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"),
+            Values("%float"), ValuesIn(scopes), Values("Reduce %float_0"),
+            Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformLogicalArithmetic, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformLogicalAnd",
+                                       "OpGroupNonUniformLogicalOr",
+                                       "OpGroupNonUniformLogicalXor"),
+                                Values("%bool"), ValuesIn(scopes),
+                                Values("Reduce %true"), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformQuad, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformQuadBroadcast",
+                                       "OpGroupNonUniformQuadSwap"),
+                                Values("%u32"), ValuesIn(scopes),
+                                Values("%u32_0 %u32_0"), Values("")));
+
+// NOTE(review): the parameters below are identical to those of the
+// GroupNonUniformBallotBitCount instantiation above, so the same cases run
+// twice under two prefixes - confirm whether that is intended.
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCountScope, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBallotBitCount"),
+                                Values("%u32"), ValuesIn(scopes),
+                                Values("Reduce %u32vec4_null"), Values("")));
+
+// Non-empty diagnostic: both TEST_P bodies expect SPV_ERROR_INVALID_DATA with
+// this substring, here for %float and %int result types at Subgroup scope.
+INSTANTIATE_TEST_CASE_P(
+    GroupNonUniformBallotBitCountBadResultType, GroupNonUniform,
+    Combine(
+        Values("OpGroupNonUniformBallotBitCount"), Values("%float", "%int"),
+        Values(SpvScopeSubgroup), Values("Reduce %u32vec4_null"),
+        Values("Expected Result Type to be an unsigned integer type scalar.")));
+
+// Value operands that are a three-component vector, an integer scalar or a
+// float scalar are expected to be rejected with the diagnostic below.
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCountBadValue, GroupNonUniform,
+                        Combine(Values("OpGroupNonUniformBallotBitCount"),
+                                Values("%u32"), Values(SpvScopeSubgroup),
+                                Values("Reduce %u32vec3_null", "Reduce %u32_0",
+                                       "Reduce %float_0"),
+                                Values("Expected Value to be a vector of four "
+                                       "components of integer type scalar")));
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/val_primitives_test.cpp b/test/val/val_primitives_test.cpp
new file mode 100644
index 000000000..04d0a4f8a
--- /dev/null
+++ b/test/val/val_primitives_test.cpp
@@ -0,0 +1,321 @@
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Not;
+
+// NOTE(review): the template argument was lost from this copy of the patch;
+// <bool> is assumed for a fixture that never reads GetParam() - confirm
+// against the upstream file.
+using ValidatePrimitives = spvtest::ValidateBase<bool>;
+
+// Returns a SPIR-V assembly module whose %main function contains |body|.
+// |capabilities_and_extensions| is emitted first, followed by the memory
+// model and an entry point of |execution_model|; for "Geometry" the
+// InputPoints and OutputPoints execution modes are added as well.
+std::string GenerateShaderCode(
+    const std::string& body,
+    const std::string& capabilities_and_extensions =
+        "OpCapability GeometryStreams",
+    const std::string& execution_model = "Geometry") {
+  std::ostringstream ss;
+  ss << capabilities_and_extensions << "\n";
+  ss << "OpMemoryModel Logical GLSL450\n";
+  ss << "OpEntryPoint " << execution_model << " %main \"main\"\n";
+  if (execution_model == "Geometry") {
+    ss << "OpExecutionMode %main InputPoints\n";
+    ss << "OpExecutionMode %main OutputPoints\n";
+  }
+
+  ss << R"(
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+%u32vec4 = OpTypeVector %u32 4
+
+%f32_0 = OpConstant %f32 0
+%u32_0 = OpConstant %u32 0
+%u32_1 = OpConstant %u32 1
+%u32_2 = OpConstant %u32 2
+%u32_3 = OpConstant %u32 3
+%u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+)";
+
+  ss << body;
+
+  ss << R"(
+OpReturn
+OpFunctionEnd)";
+
+  return ss.str();
+}
+
+// Returns SPIR-V assembly fragment representing a function call,
+// the end of the caller body, and the preamble and body of the called
+// function with the given body, but missing the final return and
+// function-end. The result is of the form where it can be used in the
+// |body| argument to GenerateShaderCode.
+std::string CallAndCallee(const std::string& body) {
+  std::ostringstream ss;
+  ss << R"(
+%dummy = OpFunctionCall %void %foo
+OpReturn
+OpFunctionEnd
+
+%foo = OpFunction %void None %func
+%foo_entry = OpLabel
+)";
+
+  ss << body;
+
+  return ss.str();
+}
+
+// OpEmitVertex doesn't have any parameters, so other validation
+// is handled by the binary parser, and generic dominance checks.
+TEST_F(ValidatePrimitives, EmitVertexSuccess) {
+  CompileSuccessfully(
+      GenerateShaderCode("OpEmitVertex", "OpCapability Geometry"));
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Without the Geometry capability the opcode itself is expected to be
+// rejected.
+TEST_F(ValidatePrimitives, EmitVertexFailMissingCapability) {
+  CompileSuccessfully(
+      GenerateShaderCode("OpEmitVertex", "OpCapability Shader", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Opcode EmitVertex requires one of these capabilities: Geometry"));
+}
+
+// The entry point uses the Vertex execution model, so OpEmitVertex is
+// expected to be rejected despite the Geometry capability.
+TEST_F(ValidatePrimitives, EmitVertexFailWrongExecutionMode) {
+  CompileSuccessfully(
+      GenerateShaderCode("OpEmitVertex", "OpCapability Geometry", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("EmitVertex instructions require Geometry execution model"));
+}
+
+// Same as above, with OpEmitVertex placed in a function called from %main.
+TEST_F(ValidatePrimitives, EmitVertexFailWrongExecutionModeNestedFunction) {
+  CompileSuccessfully(GenerateShaderCode(CallAndCallee("OpEmitVertex"),
+                                         "OpCapability Geometry", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("EmitVertex instructions require Geometry execution model"));
+}
+
+// OpEndPrimitive doesn't have any parameters, so other validation
+// is handled by the binary parser, and generic dominance checks.
+TEST_F(ValidatePrimitives, EndPrimitiveSuccess) {
+  CompileSuccessfully(
+      GenerateShaderCode("OpEndPrimitive", "OpCapability Geometry"));
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Without the Geometry capability the opcode itself is expected to be
+// rejected.
+TEST_F(ValidatePrimitives, EndPrimitiveFailMissingCapability) {
+  CompileSuccessfully(
+      GenerateShaderCode("OpEndPrimitive", "OpCapability Shader", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Opcode EndPrimitive requires one of these capabilities: Geometry"));
+}
+
+// The entry point uses the Vertex execution model, so OpEndPrimitive is
+// expected to be rejected despite the Geometry capability.
+TEST_F(ValidatePrimitives, EndPrimitiveFailWrongExecutionMode) {
+  CompileSuccessfully(
+      GenerateShaderCode("OpEndPrimitive", "OpCapability Geometry", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("EndPrimitive instructions require Geometry execution model"));
+}
+
+// Same as above, with OpEndPrimitive placed in a function called from %main.
+TEST_F(ValidatePrimitives, EndPrimitiveFailWrongExecutionModeNestedFunction) {
+  CompileSuccessfully(GenerateShaderCode(CallAndCallee("OpEndPrimitive"),
+                                         "OpCapability Geometry", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("EndPrimitive instructions require Geometry execution model"));
+}
+
+// Uses the GenerateShaderCode defaults: GeometryStreams capability and the
+// Geometry execution model.
+TEST_F(ValidatePrimitives, EmitStreamVertexSuccess) {
+  const std::string body = R"(
+OpEmitStreamVertex %u32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Without the GeometryStreams capability the opcode itself is expected to be
+// rejected.
+TEST_F(ValidatePrimitives, EmitStreamVertexFailMissingCapability) {
+  CompileSuccessfully(GenerateShaderCode("OpEmitStreamVertex %u32_0",
+                                         "OpCapability Shader", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Opcode EmitStreamVertex requires one of these "
+                        "capabilities: GeometryStreams"));
+}
+
+// The entry point uses the Vertex execution model, so OpEmitStreamVertex is
+// expected to be rejected despite the GeometryStreams capability.
+TEST_F(ValidatePrimitives, EmitStreamVertexFailWrongExecutionMode) {
+  CompileSuccessfully(GenerateShaderCode(
+      "OpEmitStreamVertex %u32_0", "OpCapability GeometryStreams", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "EmitStreamVertex instructions require Geometry execution model"));
+}
+
+// Same as above, with the instruction placed in a function called from %main.
+TEST_F(ValidatePrimitives,
+       EmitStreamVertexFailWrongExecutionModeNestedFunction) {
+  CompileSuccessfully(
+      GenerateShaderCode(CallAndCallee("OpEmitStreamVertex %u32_0"),
+                         "OpCapability GeometryStreams", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "EmitStreamVertex instructions require Geometry execution model"));
+}
+
+// A float constant as the Stream operand is expected to be rejected.
+TEST_F(ValidatePrimitives, EmitStreamVertexNonInt) {
+  const std::string body = R"(
+OpEmitStreamVertex %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("EmitStreamVertex: "
+                        "expected Stream to be int scalar"));
+}
+
+// An integer vector constant as the Stream operand is expected to be rejected.
+TEST_F(ValidatePrimitives, EmitStreamVertexNonScalar) {
+  const std::string body = R"(
+OpEmitStreamVertex %u32vec4_0123
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("EmitStreamVertex: "
+                        "expected Stream to be int scalar"));
+}
+
+// The Stream operand here is the result of OpIAdd rather than a constant and
+// is expected to be rejected.
+TEST_F(ValidatePrimitives, EmitStreamVertexNonConstant) {
+  const std::string body = R"(
+%val1 = OpIAdd %u32 %u32_0 %u32_1
+OpEmitStreamVertex %val1
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("EmitStreamVertex: "
+                        "expected Stream to be constant instruction"));
+}
+
+// Uses the GenerateShaderCode defaults: GeometryStreams capability and the
+// Geometry execution model.
+TEST_F(ValidatePrimitives, EndStreamPrimitiveSuccess) {
+  const std::string body = R"(
+OpEndStreamPrimitive %u32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidatePrimitives, EndStreamPrimitiveFailMissingCapability) {  // Only Shader declared: GeometryStreams is required.
+  CompileSuccessfully(GenerateShaderCode("OpEndStreamPrimitive %u32_0",
+                                         "OpCapability Shader", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Opcode EndStreamPrimitive requires one of these "
+                        "capabilities: GeometryStreams"));
+}
+
+TEST_F(ValidatePrimitives, EndStreamPrimitiveFailWrongExecutionMode) {  // "Vertex" is presumably the execution model (helper not in view): rejected.
+  CompileSuccessfully(GenerateShaderCode(
+      "OpEndStreamPrimitive %u32_0", "OpCapability GeometryStreams", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "EndStreamPrimitive instructions require Geometry execution model"));
+}
+
+TEST_F(ValidatePrimitives,
+       EndStreamPrimitiveFailWrongExecutionModeNestedFunction) {  // As above, via CallAndCallee (helper not in view).
+  CompileSuccessfully(
+      GenerateShaderCode(CallAndCallee("OpEndStreamPrimitive %u32_0"),
+                         "OpCapability GeometryStreams", "Vertex"));
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "EndStreamPrimitive instructions require Geometry execution model"));
+}
+
+TEST_F(ValidatePrimitives, EndStreamPrimitiveNonInt) {  // Stream operand is %f32_0: must be an int scalar.
+  const std::string body = R"(
+OpEndStreamPrimitive %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("EndStreamPrimitive: "
+                        "expected Stream to be int scalar"));
+}
+
+TEST_F(ValidatePrimitives, EndStreamPrimitiveNonScalar) {  // Stream operand is %u32vec4_0123: must be an int scalar.
+  const std::string body = R"(
+OpEndStreamPrimitive %u32vec4_0123
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("EndStreamPrimitive: "
+                        "expected Stream to be int scalar"));
+}
+
+TEST_F(ValidatePrimitives, EndStreamPrimitiveNonConstant) {  // Stream operand is the result of OpIAdd: must be a constant.
+  const std::string body = R"(
+%val1 = OpIAdd %u32 %u32_0 %u32_1
+OpEndStreamPrimitive %val1
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body));
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("EndStreamPrimitive: "
+                        "expected Stream to be constant instruction"));
+}
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/val_ssa_test.cpp b/test/val/val_ssa_test.cpp
new file mode 100644
index 000000000..5d8fa4b7e
--- /dev/null
+++ b/test/val/val_ssa_test.cpp
@@ -0,0 +1,1449 @@
+// Copyright (c) 2015-2016 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +// Validation tests for SSA + +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::MatchesRegex; + +using ValidateSSA = spvtest::ValidateBase>; + +TEST_F(ValidateSSA, Default) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %3 "" + OpExecutionMode %3 LocalSize 1 1 1 +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, IdUndefinedBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %missing "missing" +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%func = OpFunction %vfunct None %missing +%flabel = OpLabel + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, IdRedefinedBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %2 "redefined" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%2 = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); +} + +TEST_F(ValidateSSA, DominateUsageBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %1 "not_dominant" +%2 = OpTypeFunction %1 ; uses %1 before it's definition +%1 = OpTypeVoid +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("not_dominant")); +} + +TEST_F(ValidateSSA, DominateUsageWithinBlockBad) { + 
char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %bad "bad" +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%uintt = OpTypeInt 32 0 +%one = OpConstant %uintt 1 +%func = OpFunction %voidt None %funct +%entry = OpLabel +%sum = OpIAdd %uintt %one %bad +%bad = OpCopyObject %uintt %sum + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("ID .\\[%bad\\] has not been defined\n" + " %8 = OpIAdd %uint %uint_1 %bad\n")); +} + +TEST_F(ValidateSSA, DominateUsageSameInstructionBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %sum "sum" +%voidt = OpTypeVoid +%funct = OpTypeFunction %voidt +%uintt = OpTypeInt 32 0 +%one = OpConstant %uintt 1 +%func = OpFunction %voidt None %funct +%entry = OpLabel +%sum = OpIAdd %uintt %one %sum + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("ID .\\[%sum\\] has not been defined\n" + " %sum = OpIAdd %uint %uint_1 %sum\n")); +} + +TEST_F(ValidateSSA, ForwardNameGood) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %3 "main" +%1 = OpTypeVoid +%2 = OpTypeFunction %1 +%3 = OpFunction %1 None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardNameMissingTargetBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %5 "main" ; Target never defined +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("main")); +} + +TEST_F(ValidateSSA, ForwardMemberNameGood) { + char str[] = R"( + 
OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpMemberName %struct 0 "value" + OpMemberName %struct 1 "size" +%intt = OpTypeInt 32 1 +%uintt = OpTypeInt 32 0 +%struct = OpTypeStruct %intt %uintt +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardMemberNameMissingTargetBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpMemberName %struct 0 "value" + OpMemberName %bad 1 "size" ; Target is not defined +%intt = OpTypeInt 32 1 +%uintt = OpTypeInt 32 0 +%struct = OpTypeStruct %intt %uintt +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("The following forward referenced IDs have not been " + "defined:\n2[%2]")); +} + +TEST_F(ValidateSSA, ForwardDecorateGood) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %var Restrict +%intt = OpTypeInt 32 1 +%ptrt = OpTypePointer UniformConstant %intt +%var = OpVariable %ptrt UniformConstant +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardDecorateInvalidIDBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %missing "missing" + OpDecorate %missing Restrict ;Missing ID +%voidt = OpTypeVoid +%intt = OpTypeInt 32 1 +%ptrt = OpTypePointer UniformConstant %intt +%var = OpVariable %ptrt UniformConstant +%2 = OpTypeFunction %voidt +%3 = OpFunction %voidt None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, ForwardMemberDecorateGood) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + 
OpMemberDecorate %struct 1 RowMajor +%intt = OpTypeInt 32 1 +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 3 +%struct = OpTypeStruct %intt %mat33 +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardMemberDecorateInvalidIdBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %missing "missing" + OpMemberDecorate %missing 1 RowMajor ; Target not defined +%intt = OpTypeInt 32 1 +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%mat33 = OpTypeMatrix %vec3 3 +%struct = OpTypeStruct %intt %mat33 +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, ForwardGroupDecorateGood) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpDecorate %dgrp RowMajor +%dgrp = OpDecorationGroup + OpGroupDecorate %dgrp %mat33 %mat44 +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%vec4 = OpTypeVector %f32 4 +%mat33 = OpTypeMatrix %vec3 3 +%mat44 = OpTypeMatrix %vec4 4 +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardGroupDecorateMissingGroupBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %missing "missing" + OpDecorate %dgrp RowMajor +%dgrp = OpDecorationGroup + OpGroupDecorate %missing %mat33 %mat44 ; Target not defined +%intt = OpTypeInt 32 1 +%vec3 = OpTypeVector %intt 3 +%vec4 = OpTypeVector %intt 4 +%mat33 = OpTypeMatrix %vec3 3 +%mat44 = OpTypeMatrix %vec4 4 +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, ForwardGroupDecorateMissingTargetBad) { + char str[] = R"( + OpCapability Shader + 
OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %missing "missing" + OpDecorate %dgrp RowMajor +%dgrp = OpDecorationGroup + OpGroupDecorate %dgrp %missing %mat44 ; Target not defined +%f32 = OpTypeFloat 32 +%vec3 = OpTypeVector %f32 3 +%vec4 = OpTypeVector %f32 4 +%mat33 = OpTypeMatrix %vec3 3 +%mat44 = OpTypeMatrix %vec4 4 +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, ForwardGroupDecorateDecorationGroupDominateBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %dgrp "group" + OpDecorate %dgrp RowMajor + OpGroupDecorate %dgrp %mat33 %mat44 ; Decoration group does not dominate usage +%dgrp = OpDecorationGroup +%intt = OpTypeInt 32 1 +%vec3 = OpTypeVector %intt 3 +%vec4 = OpTypeVector %intt 4 +%mat33 = OpTypeMatrix %vec3 3 +%mat44 = OpTypeMatrix %vec4 4 +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("group")); +} + +TEST_F(ValidateSSA, ForwardDecorateInvalidIdBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %missing "missing" + OpDecorate %missing Restrict ; Missing target +%voidt = OpTypeVoid +%intt = OpTypeInt 32 1 +%ptrt = OpTypePointer UniformConstant %intt +%var = OpVariable %ptrt UniformConstant +%2 = OpTypeFunction %voidt +%3 = OpFunction %voidt None %2 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, FunctionCallGood) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 1 +%3 = OpTypeInt 32 0 +%4 = OpTypeFunction %1 +%8 = OpTypeFunction %1 %2 %3 +%four = 
OpConstant %2 4 +%five = OpConstant %3 5 +%9 = OpFunction %1 None %8 +%10 = OpFunctionParameter %2 +%11 = OpFunctionParameter %3 +%12 = OpLabel + OpReturn + OpFunctionEnd +%5 = OpFunction %1 None %4 +%6 = OpLabel +%7 = OpFunctionCall %1 %9 %four %five + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardFunctionCallGood) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%1 = OpTypeVoid +%2 = OpTypeInt 32 1 +%3 = OpTypeInt 32 0 +%four = OpConstant %2 4 +%five = OpConstant %3 5 +%8 = OpTypeFunction %1 %2 %3 +%4 = OpTypeFunction %1 +%5 = OpFunction %1 None %4 +%6 = OpLabel +%7 = OpFunctionCall %1 %9 %four %five + OpReturn + OpFunctionEnd +%9 = OpFunction %1 None %8 +%10 = OpFunctionParameter %2 +%11 = OpFunctionParameter %3 +%12 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardBranchConditionalGood) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%voidt = OpTypeVoid +%boolt = OpTypeBool +%vfunct = OpTypeFunction %voidt +%true = OpConstantTrue %boolt +%main = OpFunction %voidt None %vfunct +%mainl = OpLabel + OpSelectionMerge %endl None + OpBranchConditional %true %truel %falsel +%truel = OpLabel + OpNop + OpBranch %endl +%falsel = OpLabel + OpNop + OpBranch %endl +%endl = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardBranchConditionalWithWeightsGood) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%voidt = OpTypeVoid +%boolt = OpTypeBool +%vfunct = OpTypeFunction %voidt +%true = OpConstantTrue %boolt +%main = OpFunction %voidt None %vfunct +%mainl = OpLabel + OpSelectionMerge %endl None + OpBranchConditional %true %truel 
%falsel 1 9 +%truel = OpLabel + OpNop + OpBranch %endl +%falsel = OpLabel + OpNop + OpBranch %endl +%endl = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardBranchConditionalNonDominantConditionBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %tcpy "conditional" +%voidt = OpTypeVoid +%boolt = OpTypeBool +%vfunct = OpTypeFunction %voidt +%true = OpConstantTrue %boolt +%main = OpFunction %voidt None %vfunct +%mainl = OpLabel + OpSelectionMerge %endl None + OpBranchConditional %tcpy %truel %falsel ; +%truel = OpLabel + OpNop + OpBranch %endl +%falsel = OpLabel + OpNop + OpBranch %endl +%endl = OpLabel +%tcpy = OpCopyObject %boolt %true + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("conditional")); +} + +TEST_F(ValidateSSA, ForwardBranchConditionalMissingTargetBad) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 + OpName %missing "missing" +%voidt = OpTypeVoid +%boolt = OpTypeBool +%vfunct = OpTypeFunction %voidt +%true = OpConstantTrue %boolt +%main = OpFunction %voidt None %vfunct +%mainl = OpLabel + OpSelectionMerge %endl None + OpBranchConditional %true %missing %falsel +%truel = OpLabel + OpNop + OpBranch %endl +%falsel = OpLabel + OpNop + OpBranch %endl +%endl = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +// Since Int8 requires the Kernel capability, the signedness of int types may +// not be "1". 
+const std::string kHeader = R"( +OpCapability Int8 +OpCapability DeviceEnqueue +OpCapability Linkage +OpMemoryModel Logical OpenCL +)"; + +const std::string kBasicTypes = R"( +%voidt = OpTypeVoid +%boolt = OpTypeBool +%int8t = OpTypeInt 8 0 +%uintt = OpTypeInt 32 0 +%vfunct = OpTypeFunction %voidt +%intptrt = OpTypePointer UniformConstant %uintt +%zero = OpConstant %uintt 0 +%one = OpConstant %uintt 1 +%ten = OpConstant %uintt 10 +%false = OpConstantFalse %boolt +)"; + +const std::string kKernelTypesAndConstants = R"( +%queuet = OpTypeQueue + +%three = OpConstant %uintt 3 +%arr3t = OpTypeArray %uintt %three +%ndt = OpTypeStruct %uintt %arr3t %arr3t %arr3t + +%eventt = OpTypeEvent + +%offset = OpConstant %uintt 0 +%local = OpConstant %uintt 1 +%gl = OpConstant %uintt 1 + +%nevent = OpConstant %uintt 0 +%event = OpConstantNull %eventt + +%firstp = OpConstant %int8t 0 +%psize = OpConstant %uintt 0 +%palign = OpConstant %uintt 32 +%lsize = OpConstant %uintt 1 +%flags = OpConstant %uintt 0 ; NoWait + +%kfunct = OpTypeFunction %voidt %intptrt +)"; + +const std::string kKernelSetup = R"( +%dqueue = OpGetDefaultQueue %queuet +%ndval = OpBuildNDRange %ndt %gl %local %offset +%revent = OpUndef %eventt + +)"; + +const std::string kKernelDefinition = R"( +%kfunc = OpFunction %voidt None %kfunct +%iparam = OpFunctionParameter %intptrt +%kfuncl = OpLabel + OpNop + OpReturn + OpFunctionEnd +)"; + +TEST_F(ValidateSSA, EnqueueKernelGood) { + std::string str = kHeader + kBasicTypes + kKernelTypesAndConstants + + kKernelDefinition + R"( + %main = OpFunction %voidt None %vfunct + %mainl = OpLabel + )" + kKernelSetup + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event %revent %kfunc %firstp %psize + %palign %lsize + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelGood) { + std::string str = kHeader + kBasicTypes + kKernelTypesAndConstants + R"( + %main = 
OpFunction %voidt None %vfunct + %mainl = OpLabel + )" + + kKernelSetup + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event %revent %kfunc %firstp %psize + %palign %lsize + OpReturn + OpFunctionEnd + )" + kKernelDefinition; + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, EnqueueMissingFunctionBad) { + std::string str = kHeader + "OpName %kfunc \"kfunc\"" + kBasicTypes + + kKernelTypesAndConstants + R"( + %main = OpFunction %voidt None %vfunct + %mainl = OpLabel + )" + kKernelSetup + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event %revent %kfunc %firstp %psize + %palign %lsize + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("kfunc")); +} + +std::string forwardKernelNonDominantParameterBaseCode( + std::string name = std::string()) { + std::string op_name; + if (name.empty()) { + op_name = ""; + } else { + op_name = "\nOpName %" + name + " \"" + name + "\"\n"; + } + std::string out = kHeader + op_name + kBasicTypes + kKernelTypesAndConstants + + kKernelDefinition + + R"( + %main = OpFunction %voidt None %vfunct + %mainl = OpLabel + )" + kKernelSetup; + return out; +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelMissingParameter1Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("missing") + R"( + %err = OpEnqueueKernel %missing %dqueue %flags %ndval + %nevent %event %revent %kfunc %firstp + %psize %palign %lsize + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter2Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("dqueue2") + R"( + %err = OpEnqueueKernel %uintt %dqueue2 %flags %ndval + %nevent %event %revent %kfunc + 
%firstp %psize %palign %lsize + %dqueue2 = OpGetDefaultQueue %queuet + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("dqueue2")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter3Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("ndval2") + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval2 + %nevent %event %revent %kfunc %firstp + %psize %palign %lsize + %ndval2 = OpBuildNDRange %ndt %gl %local %offset + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ndval2")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter4Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("nevent2") + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent2 + %event %revent %kfunc %firstp %psize + %palign %lsize + %nevent2 = OpCopyObject %uintt %nevent + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("nevent2")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter5Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("event2") + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event2 %revent %kfunc %firstp %psize + %palign %lsize + %event2 = OpCopyObject %eventt %event + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("event2")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter6Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("revent2") + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event %revent2 %kfunc %firstp %psize + %palign 
%lsize + %revent2 = OpCopyObject %eventt %revent + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("revent2")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter8Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("firstp2") + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event %revent %kfunc %firstp2 %psize + %palign %lsize + %firstp2 = OpCopyObject %int8t %firstp + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("firstp2")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter9Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("psize2") + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event %revent %kfunc %firstp %psize2 + %palign %lsize + %psize2 = OpCopyObject %uintt %psize + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("psize2")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter10Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("palign2") + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event %revent %kfunc %firstp %psize + %palign2 %lsize + %palign2 = OpCopyObject %uintt %palign + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("palign2")); +} + +TEST_F(ValidateSSA, ForwardEnqueueKernelNonDominantParameter11Bad) { + std::string str = forwardKernelNonDominantParameterBaseCode("lsize2") + R"( + %err = OpEnqueueKernel %uintt %dqueue %flags %ndval %nevent + %event %revent %kfunc %firstp %psize + %palign %lsize2 + %lsize2 = 
OpCopyObject %uintt %lsize + OpReturn + OpFunctionEnd + )"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("lsize2")); +} + +static const bool kWithNDrange = true; +static const bool kNoNDrange = false; +std::pair cases[] = { + {"OpGetKernelNDrangeSubGroupCount", kWithNDrange}, + {"OpGetKernelNDrangeMaxSubGroupSize", kWithNDrange}, + {"OpGetKernelWorkGroupSize", kNoNDrange}, + {"OpGetKernelPreferredWorkGroupSizeMultiple", kNoNDrange}}; + +INSTANTIATE_TEST_CASE_P(KernelArgs, ValidateSSA, ::testing::ValuesIn(cases), ); + +static const std::string return_instructions = R"( + OpReturn + OpFunctionEnd +)"; + +TEST_P(ValidateSSA, GetKernelGood) { + std::string instruction = GetParam().first; + bool with_ndrange = GetParam().second; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; + + std::stringstream ss; + // clang-format off + ss << forwardKernelNonDominantParameterBaseCode() + " %numsg = " + << instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize %palign" + << return_instructions; + // clang-format on + + CompileSuccessfully(ss.str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateSSA, ForwardGetKernelGood) { + std::string instruction = GetParam().first; + bool with_ndrange = GetParam().second; + std::string ndrange_param = with_ndrange ? 
" %ndval " : " "; + + // clang-format off + std::string str = kHeader + kBasicTypes + kKernelTypesAndConstants + + R"( + %main = OpFunction %voidt None %vfunct + %mainl = OpLabel + )" + + kKernelSetup + " %numsg = " + + instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize %palign" + + return_instructions + kKernelDefinition; + // clang-format on + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(ValidateSSA, ForwardGetKernelMissingDefinitionBad) { + std::string instruction = GetParam().first; + bool with_ndrange = GetParam().second; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; + + std::stringstream ss; + // clang-format off + ss << forwardKernelNonDominantParameterBaseCode("missing") + " %numsg = " + << instruction + " %uintt" + ndrange_param + "%missing %firstp %psize %palign" + << return_instructions; + // clang-format on + + CompileSuccessfully(ss.str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_P(ValidateSSA, ForwardGetKernelNDrangeSubGroupCountMissingParameter1Bad) { + std::string instruction = GetParam().first; + bool with_ndrange = GetParam().second; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; + + std::stringstream ss; + // clang-format off + ss << forwardKernelNonDominantParameterBaseCode("missing") + " %numsg = " + << instruction + " %missing" + ndrange_param + "%kfunc %firstp %psize %palign" + << return_instructions; + // clang-format on + + CompileSuccessfully(ss.str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_P(ValidateSSA, + ForwardGetKernelNDrangeSubGroupCountNonDominantParameter2Bad) { + std::string instruction = GetParam().first; + bool with_ndrange = GetParam().second; + std::string ndrange_param = with_ndrange ? 
" %ndval2 " : " "; + + std::stringstream ss; + // clang-format off + ss << forwardKernelNonDominantParameterBaseCode("ndval2") + " %numsg = " + << instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize %palign" + << "\n %ndval2 = OpBuildNDRange %ndt %gl %local %offset" + << return_instructions; + // clang-format on + + if (GetParam().second) { + CompileSuccessfully(ss.str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("ndval2")); + } +} + +TEST_P(ValidateSSA, + ForwardGetKernelNDrangeSubGroupCountNonDominantParameter4Bad) { + std::string instruction = GetParam().first; + bool with_ndrange = GetParam().second; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; + + std::stringstream ss; + // clang-format off + ss << forwardKernelNonDominantParameterBaseCode("firstp2") + " %numsg = " + << instruction + " %uintt" + ndrange_param + "%kfunc %firstp2 %psize %palign" + << "\n %firstp2 = OpCopyObject %int8t %firstp" + << return_instructions; + // clang-format on + + CompileSuccessfully(ss.str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("firstp2")); +} + +TEST_P(ValidateSSA, + ForwardGetKernelNDrangeSubGroupCountNonDominantParameter5Bad) { + std::string instruction = GetParam().first; + bool with_ndrange = GetParam().second; + std::string ndrange_param = with_ndrange ? 
" %ndval " : " "; + + std::stringstream ss; + // clang-format off + ss << forwardKernelNonDominantParameterBaseCode("psize2") + " %numsg = " + << instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize2 %palign" + << "\n %psize2 = OpCopyObject %uintt %psize" + << return_instructions; + // clang-format on + + CompileSuccessfully(ss.str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("psize2")); +} + +TEST_P(ValidateSSA, + ForwardGetKernelNDrangeSubGroupCountNonDominantParameter6Bad) { + std::string instruction = GetParam().first; + bool with_ndrange = GetParam().second; + std::string ndrange_param = with_ndrange ? " %ndval " : " "; + + std::stringstream ss; + // clang-format off + ss << forwardKernelNonDominantParameterBaseCode("palign2") + " %numsg = " + << instruction + " %uintt" + ndrange_param + "%kfunc %firstp %psize %palign2" + << "\n %palign2 = OpCopyObject %uintt %palign" + << return_instructions; + // clang-format on + + if (GetParam().second) { + CompileSuccessfully(ss.str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("palign2")); + } +} + +TEST_F(ValidateSSA, PhiGood) { + std::string str = kHeader + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%preheader = OpLabel +%init = OpCopyObject %uintt %zero + OpBranch %loop +%loop = OpLabel +%i = OpPhi %uintt %init %preheader %loopi %loop +%loopi = OpIAdd %uintt %i %one + OpNop +%cond = OpSLessThan %boolt %i %ten + OpLoopMerge %endl %loop None + OpBranchConditional %cond %loop %endl +%endl = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, PhiMissingTypeBad) { + std::string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%preheader = OpLabel +%init = OpCopyObject %uintt %zero + OpBranch %loop +%loop = OpLabel 
+%i = OpPhi %missing %init %preheader %loopi %loop +%loopi = OpIAdd %uintt %i %one + OpNop +%cond = OpSLessThan %boolt %i %ten + OpLoopMerge %endl %loop None + OpBranchConditional %cond %loop %endl +%endl = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, PhiMissingIdBad) { + std::string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%preheader = OpLabel +%init = OpCopyObject %uintt %zero + OpBranch %loop +%loop = OpLabel +%i = OpPhi %uintt %missing %preheader %loopi %loop +%loopi = OpIAdd %uintt %i %one + OpNop +%cond = OpSLessThan %boolt %i %ten + OpLoopMerge %endl %loop None + OpBranchConditional %cond %loop %endl +%endl = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, PhiMissingLabelBad) { + std::string str = kHeader + "OpName %missing \"missing\"" + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%preheader = OpLabel +%init = OpCopyObject %uintt %zero + OpBranch %loop +%loop = OpLabel +%i = OpPhi %uintt %init %missing %loopi %loop +%loopi = OpIAdd %uintt %i %one + OpNop +%cond = OpSLessThan %boolt %i %ten + OpLoopMerge %endl %loop None + OpBranchConditional %cond %loop %endl +%endl = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("missing")); +} + +TEST_F(ValidateSSA, IdDominatesItsUseGood) { + std::string str = kHeader + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel +%cond = OpSLessThan %boolt %one %ten +%eleven = OpIAdd %uintt %one %ten + OpSelectionMerge %merge None + OpBranchConditional %cond %t 
%f +%t = OpLabel +%twelve = OpIAdd %uintt %eleven %one + OpBranch %merge +%f = OpLabel +%twentytwo = OpIAdd %uintt %eleven %ten + OpBranch %merge +%merge = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, IdDoesNotDominateItsUseBad) { + std::string str = kHeader + + "OpName %eleven \"eleven\"\n" + "OpName %true_block \"true_block\"\n" + "OpName %false_block \"false_block\"" + + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel +%cond = OpSLessThan %boolt %one %ten + OpSelectionMerge %merge None + OpBranchConditional %cond %true_block %false_block +%true_block = OpLabel +%eleven = OpIAdd %uintt %one %ten +%twelve = OpIAdd %uintt %eleven %one + OpBranch %merge +%false_block = OpLabel +%twentytwo = OpIAdd %uintt %eleven %ten + OpBranch %merge +%merge = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("ID .\\[%eleven\\] defined in block .\\[%true_block\\] " + "does not dominate its use in block .\\[%false_block\\]\n" + " %false_block = OpLabel\n")); +} + +TEST_F(ValidateSSA, PhiUseDoesntDominateDefinitionGood) { + std::string str = kHeader + kBasicTypes + + R"( +%funcintptrt = OpTypePointer Function %uintt +%func = OpFunction %voidt None %vfunct +%entry = OpLabel +%var_one = OpVariable %funcintptrt Function %one +%one_val = OpLoad %uintt %var_one + OpBranch %loop +%loop = OpLabel +%i = OpPhi %uintt %one_val %entry %inew %cont +%cond = OpSLessThan %boolt %one %ten + OpLoopMerge %merge %cont None + OpBranchConditional %cond %body %merge +%body = OpLabel + OpBranch %cont +%cont = OpLabel +%inew = OpIAdd %uintt %i %one + OpBranch %loop +%merge = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, + 
PhiUseDoesntDominateUseOfPhiOperandUsedBeforeDefinitionBad) { + std::string str = kHeader + "OpName %inew \"inew\"" + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel +%var_one = OpVariable %intptrt Function %one +%one_val = OpLoad %uintt %var_one + OpBranch %loop +%loop = OpLabel +%i = OpPhi %uintt %one_val %entry %inew %cont +%bad = OpIAdd %uintt %inew %one +%cond = OpSLessThan %boolt %one %ten + OpLoopMerge %merge %cont None + OpBranchConditional %cond %body %merge +%body = OpLabel + OpBranch %cont +%cont = OpLabel +%inew = OpIAdd %uintt %i %one + OpBranch %loop +%merge = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("ID .\\[%inew\\] has not been defined\n" + " %19 = OpIAdd %uint %inew %uint_1\n")); +} + +TEST_F(ValidateSSA, PhiUseMayComeFromNonDominatingBlockGood) { + std::string str = kHeader + "OpName %if_true \"if_true\"\n" + + "OpName %exit \"exit\"\n" + "OpName %copy \"copy\"\n" + + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel + OpBranchConditional %false %if_true %exit + +%if_true = OpLabel +%copy = OpCopyObject %boolt %false + OpBranch %exit + +; The use of %copy here is ok, even though it was defined +; in a block that does not dominate %exit. That's the point +; of an OpPhi. +%exit = OpLabel +%value = OpPhi %boolt %false %entry %copy %if_true + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateSSA, PhiUsesItsOwnDefinitionGood) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/415 + // + // Non-phi instructions can't use their own definitions, as + // already checked in test DominateUsageSameInstructionBad. 
+ std::string str = kHeader + "OpName %loop \"loop\"\n" + + "OpName %value \"value\"\n" + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel + OpBranch %loop + +%loop = OpLabel +%value = OpPhi %boolt %false %entry %value %loop + OpBranch %loop + + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateSSA, PhiVariableDefNotDominatedByParentBlockBad) { + std::string str = kHeader + "OpName %if_true \"if_true\"\n" + + "OpName %if_false \"if_false\"\n" + + "OpName %exit \"exit\"\n" + "OpName %value \"phi\"\n" + + "OpName %true_copy \"true_copy\"\n" + + "OpName %false_copy \"false_copy\"\n" + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel + OpBranchConditional %false %if_true %if_false + +%if_true = OpLabel +%true_copy = OpCopyObject %boolt %false + OpBranch %exit + +%if_false = OpLabel +%false_copy = OpCopyObject %boolt %false + OpBranch %exit + +; The (variable,Id) pairs are swapped. 
+%exit = OpLabel +%value = OpPhi %boolt %true_copy %if_false %false_copy %if_true + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("In OpPhi instruction .\\[%phi\\], ID .\\[%true_copy\\] " + "definition does not dominate its parent .\\[%if_false\\]\n" + " %phi = OpPhi %bool %true_copy %if_false %false_copy " + "%if_true\n")); +} + +TEST_F(ValidateSSA, PhiVariableDefDominatesButNotDefinedInParentBlock) { + std::string str = kHeader + "OpName %if_true \"if_true\"\n" + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel + OpBranchConditional %false %if_true %if_false + +%if_true = OpLabel +%true_copy = OpCopyObject %boolt %false + OpBranch %if_tnext +%if_tnext = OpLabel + OpBranch %exit + +%if_false = OpLabel +%false_copy = OpCopyObject %boolt %false + OpBranch %if_fnext +%if_fnext = OpLabel + OpBranch %exit + +%exit = OpLabel +%value = OpPhi %boolt %true_copy %if_tnext %false_copy %if_fnext + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, + DominanceCheckIgnoresUsesInUnreachableBlocksDefInBlockGood) { + std::string str = kHeader + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel +%def = OpCopyObject %boolt %false + OpReturn + +%unreach = OpLabel +%use = OpCopyObject %boolt %def + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateSSA, PhiVariableUnreachableDefNotInParentBlock) { + std::string str = kHeader + "OpName %unreachable \"unreachable\"\n" + + kBasicTypes + + R"( +%func = OpFunction %voidt None %vfunct +%entry = OpLabel + OpBranch %if_false + +%unreachable = OpLabel +%copy = OpCopyObject %boolt %false + OpBranch %if_tnext +%if_tnext = OpLabel + OpBranch %exit + +%if_false = OpLabel 
+%false_copy = OpCopyObject %boolt %false + OpBranch %if_fnext +%if_fnext = OpLabel + OpBranch %exit + +%exit = OpLabel +%value = OpPhi %boolt %copy %if_tnext %false_copy %if_fnext + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, + DominanceCheckIgnoresUsesInUnreachableBlocksDefIsParamGood) { + std::string str = kHeader + kBasicTypes + + R"( +%void_fn_int = OpTypeFunction %voidt %uintt +%func = OpFunction %voidt None %void_fn_int +%int_param = OpFunctionParameter %uintt +%entry = OpLabel + OpReturn + +%unreach = OpLabel +%use = OpCopyObject %uintt %int_param + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()) << getDiagnosticString(); +} + +TEST_F(ValidateSSA, UseFunctionParameterFromOtherFunctionBad) { + std::string str = kHeader + + "OpName %first \"first\"\n" + "OpName %func \"func\"\n" + + "OpName %func2 \"func2\"\n" + kBasicTypes + + R"( +%viifunct = OpTypeFunction %voidt %uintt %uintt +%func = OpFunction %voidt None %viifunct +%first = OpFunctionParameter %uintt +%second = OpFunctionParameter %uintt + OpFunctionEnd +%func2 = OpFunction %voidt None %viifunct +%first2 = OpFunctionParameter %uintt +%second2 = OpFunctionParameter %uintt +%entry2 = OpLabel +%baduse = OpIAdd %uintt %first %first2 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("ID .\\[%first\\] used in function .\\[%func2\\] is used " + "outside of it's defining function .\\[%func\\]\n" + " %func = OpFunction %void None %14\n")); +} + +TEST_F(ValidateSSA, TypeForwardPointerForwardReference) { + // See https://github.com/KhronosGroup/SPIRV-Tools/issues/429 + // + // ForwardPointers can reference instructions that have not been defined + std::string str = R"( + OpCapability Kernel + OpCapability Addresses + OpCapability Linkage
+ OpMemoryModel Logical OpenCL + OpName %intptrt "intptrt" + OpTypeForwardPointer %intptrt UniformConstant + %uint = OpTypeInt 32 0 + %intptrt = OpTypePointer UniformConstant %uint +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateSSA, TypeStructForwardReference) { + std::string str = R"( + OpCapability Kernel + OpCapability Addresses + OpCapability Linkage + OpMemoryModel Logical OpenCL + OpName %structptr "structptr" + OpTypeForwardPointer %structptr UniformConstant + %uint = OpTypeInt 32 0 + %structt1 = OpTypeStruct %structptr %uint + %structt2 = OpTypeStruct %uint %structptr + %structt3 = OpTypeStruct %uint %uint %structptr + %structt4 = OpTypeStruct %uint %uint %uint %structptr + %structptr = OpTypePointer UniformConstant %structt1 +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// TODO(umar): OpGroupMemberDecorate + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_state_test.cpp b/test/val/val_state_test.cpp new file mode 100644 index 000000000..18a4ef99e --- /dev/null +++ b/test/val/val_state_test.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Unit tests for ValidationState_t. 
+ +#include + +#include "gtest/gtest.h" +#include "source/latest_version_spirv_header.h" + +#include "source/enum_set.h" +#include "source/extensions.h" +#include "source/spirv_validator_options.h" +#include "source/val/construct.h" +#include "source/val/function.h" +#include "source/val/validate.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +// This is all we need for these tests. +static uint32_t kFakeBinary[] = {0}; + +// A test with a ValidationState_t member transparently. +class ValidationStateTest : public testing::Test { + public: + ValidationStateTest() + : context_(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)), + options_(spvValidatorOptionsCreate()), + state_(context_, options_, kFakeBinary, 0, 1) {} + + ~ValidationStateTest() { + spvContextDestroy(context_); + spvValidatorOptionsDestroy(options_); + } + + protected: + spv_context context_; + spv_validator_options options_; + ValidationState_t state_; +}; + +// A test of ValidationState_t::HasAnyOfCapabilities(). 
+using ValidationState_HasAnyOfCapabilities = ValidationStateTest; + +TEST_F(ValidationState_HasAnyOfCapabilities, EmptyMask) { + EXPECT_TRUE(state_.HasAnyOfCapabilities({})); + state_.RegisterCapability(SpvCapabilityMatrix); + EXPECT_TRUE(state_.HasAnyOfCapabilities({})); + state_.RegisterCapability(SpvCapabilityImageMipmap); + EXPECT_TRUE(state_.HasAnyOfCapabilities({})); + state_.RegisterCapability(SpvCapabilityPipes); + EXPECT_TRUE(state_.HasAnyOfCapabilities({})); + state_.RegisterCapability(SpvCapabilityStorageImageArrayDynamicIndexing); + EXPECT_TRUE(state_.HasAnyOfCapabilities({})); + state_.RegisterCapability(SpvCapabilityClipDistance); + EXPECT_TRUE(state_.HasAnyOfCapabilities({})); + state_.RegisterCapability(SpvCapabilityStorageImageWriteWithoutFormat); + EXPECT_TRUE(state_.HasAnyOfCapabilities({})); +} + +TEST_F(ValidationState_HasAnyOfCapabilities, SingleCapMask) { + EXPECT_FALSE(state_.HasAnyOfCapabilities({SpvCapabilityMatrix})); + EXPECT_FALSE(state_.HasAnyOfCapabilities({SpvCapabilityImageMipmap})); + state_.RegisterCapability(SpvCapabilityMatrix); + EXPECT_TRUE(state_.HasAnyOfCapabilities({SpvCapabilityMatrix})); + EXPECT_FALSE(state_.HasAnyOfCapabilities({SpvCapabilityImageMipmap})); + state_.RegisterCapability(SpvCapabilityImageMipmap); + EXPECT_TRUE(state_.HasAnyOfCapabilities({SpvCapabilityMatrix})); + EXPECT_TRUE(state_.HasAnyOfCapabilities({SpvCapabilityImageMipmap})); +} + +TEST_F(ValidationState_HasAnyOfCapabilities, MultiCapMask) { + const auto set1 = + CapabilitySet{SpvCapabilitySampledRect, SpvCapabilityImageBuffer}; + const auto set2 = CapabilitySet{SpvCapabilityStorageImageWriteWithoutFormat, + SpvCapabilityStorageImageReadWithoutFormat, + SpvCapabilityGeometryStreams}; + EXPECT_FALSE(state_.HasAnyOfCapabilities(set1)); + EXPECT_FALSE(state_.HasAnyOfCapabilities(set2)); + state_.RegisterCapability(SpvCapabilityImageBuffer); + EXPECT_TRUE(state_.HasAnyOfCapabilities(set1)); + EXPECT_FALSE(state_.HasAnyOfCapabilities(set2)); +} + +// A 
test of ValidationState_t::HasAnyOfExtensions(). +using ValidationState_HasAnyOfExtensions = ValidationStateTest; + +TEST_F(ValidationState_HasAnyOfExtensions, EmptyMask) { + EXPECT_TRUE(state_.HasAnyOfExtensions({})); + state_.RegisterExtension(Extension::kSPV_KHR_shader_ballot); + EXPECT_TRUE(state_.HasAnyOfExtensions({})); + state_.RegisterExtension(Extension::kSPV_KHR_16bit_storage); + EXPECT_TRUE(state_.HasAnyOfExtensions({})); + state_.RegisterExtension(Extension::kSPV_NV_viewport_array2); + EXPECT_TRUE(state_.HasAnyOfExtensions({})); +} + +TEST_F(ValidationState_HasAnyOfExtensions, SingleCapMask) { + EXPECT_FALSE(state_.HasAnyOfExtensions({Extension::kSPV_KHR_shader_ballot})); + EXPECT_FALSE(state_.HasAnyOfExtensions({Extension::kSPV_KHR_16bit_storage})); + state_.RegisterExtension(Extension::kSPV_KHR_shader_ballot); + EXPECT_TRUE(state_.HasAnyOfExtensions({Extension::kSPV_KHR_shader_ballot})); + EXPECT_FALSE(state_.HasAnyOfExtensions({Extension::kSPV_KHR_16bit_storage})); + state_.RegisterExtension(Extension::kSPV_KHR_16bit_storage); + EXPECT_TRUE(state_.HasAnyOfExtensions({Extension::kSPV_KHR_shader_ballot})); + EXPECT_TRUE(state_.HasAnyOfExtensions({Extension::kSPV_KHR_16bit_storage})); +} + +TEST_F(ValidationState_HasAnyOfExtensions, MultiCapMask) { + const auto set1 = ExtensionSet{Extension::kSPV_KHR_multiview, + Extension::kSPV_KHR_16bit_storage}; + const auto set2 = ExtensionSet{Extension::kSPV_KHR_shader_draw_parameters, + Extension::kSPV_NV_stereo_view_rendering, + Extension::kSPV_KHR_shader_ballot}; + EXPECT_FALSE(state_.HasAnyOfExtensions(set1)); + EXPECT_FALSE(state_.HasAnyOfExtensions(set2)); + state_.RegisterExtension(Extension::kSPV_KHR_multiview); + EXPECT_TRUE(state_.HasAnyOfExtensions(set1)); + EXPECT_FALSE(state_.HasAnyOfExtensions(set2)); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_storage_test.cpp b/test/val/val_storage_test.cpp new file mode 100644 index 000000000..aa1eecda8 --- /dev/null 
+++ b/test/val/val_storage_test.cpp @@ -0,0 +1,191 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validation tests for OpVariable storage class + +#include +#include +#include + +#include "gmock/gmock.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ValidateStorage = spvtest::ValidateBase; + +TEST_F(ValidateStorage, FunctionStorageInsideFunction) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%intt = OpTypeInt 32 1 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%ptrt = OpTypePointer Function %intt +%func = OpFunction %voidt None %vfunct +%funcl = OpLabel +%var = OpVariable %ptrt Function + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateStorage, FunctionStorageOutsideFunction) { + char str[] = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%intt = OpTypeInt 32 1 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%ptrt = OpTypePointer Function %intt +%var = OpVariable %ptrt Function +%func = OpFunction %voidt None %vfunct +%funcl = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Variables can not have a 
function[7] storage class " + "outside of a function")); +} + +TEST_F(ValidateStorage, OtherStorageOutsideFunction) { + char str[] = R"( + OpCapability Shader + OpCapability Kernel + OpCapability AtomicStorage + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%intt = OpTypeInt 32 0 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%uniconptrt = OpTypePointer UniformConstant %intt +%unicon = OpVariable %uniconptrt UniformConstant +%inputptrt = OpTypePointer Input %intt +%input = OpVariable %inputptrt Input +%unifptrt = OpTypePointer Uniform %intt +%unif = OpVariable %unifptrt Uniform +%outputptrt = OpTypePointer Output %intt +%output = OpVariable %outputptrt Output +%wgroupptrt = OpTypePointer Workgroup %intt +%wgroup = OpVariable %wgroupptrt Workgroup +%xwgrpptrt = OpTypePointer CrossWorkgroup %intt +%xwgrp = OpVariable %xwgrpptrt CrossWorkgroup +%privptrt = OpTypePointer Private %intt +%priv = OpVariable %privptrt Private +%pushcoptrt = OpTypePointer PushConstant %intt +%pushco = OpVariable %pushcoptrt PushConstant +%atomcptrt = OpTypePointer AtomicCounter %intt +%atomct = OpVariable %atomcptrt AtomicCounter +%imageptrt = OpTypePointer Image %intt +%image = OpVariable %imageptrt Image +%func = OpFunction %voidt None %vfunct +%funcl = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +// clang-format off +TEST_P(ValidateStorage, OtherStorageInsideFunction) { + std::stringstream ss; + ss << R"( + OpCapability Shader + OpCapability Kernel + OpCapability AtomicStorage + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%intt = OpTypeInt 32 0 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%ptrt = OpTypePointer Function %intt +%func = OpFunction %voidt None %vfunct +%funcl = OpLabel +%var = OpVariable %ptrt )" << GetParam() << R"( + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(ss.str()); + ASSERT_EQ(SPV_ERROR_INVALID_LAYOUT, ValidateInstructions()); + 
EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Variables must have a function[7] storage class inside of a function")); +} + +INSTANTIATE_TEST_CASE_P(MatrixOp, ValidateStorage, + ::testing::Values( + "Input", + "Uniform", + "Output", + "Workgroup", + "CrossWorkgroup", + "Private", + "PushConstant", + "AtomicCounter", + "Image"),); +// clang-format on + +TEST_F(ValidateStorage, GenericVariableOutsideFunction) { + const auto str = R"( + OpCapability Kernel + OpCapability Linkage + OpMemoryModel Logical OpenCL +%intt = OpTypeInt 32 0 +%ptrt = OpTypePointer Function %intt +%var = OpVariable %ptrt Generic +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpVariable storage class cannot be Generic")); +} + +TEST_F(ValidateStorage, GenericVariableInsideFunction) { + const auto str = R"( + OpCapability Shader + OpCapability Linkage + OpMemoryModel Logical GLSL450 +%intt = OpTypeInt 32 1 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%ptrt = OpTypePointer Function %intt +%func = OpFunction %voidt None %vfunct +%funcl = OpLabel +%var = OpVariable %ptrt Generic + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(str); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpVariable storage class cannot be Generic")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_type_unique_test.cpp b/test/val/val_type_unique_test.cpp new file mode 100644 index 000000000..67ceaddb8 --- /dev/null +++ b/test/val/val_type_unique_test.cpp @@ -0,0 +1,269 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for unique type declaration rules validator. + +#include + +#include "gmock/gmock.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +using ValidateTypeUnique = spvtest::ValidateBase; + +const spv_result_t kDuplicateTypeError = SPV_ERROR_INVALID_DATA; + +const std::string& GetHeader() { + static const std::string header = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%floatt = OpTypeFloat 32 +%vec2t = OpTypeVector %floatt 2 +%vec3t = OpTypeVector %floatt 3 +%vec4t = OpTypeVector %floatt 4 +%mat22t = OpTypeMatrix %vec2t 2 +%mat33t = OpTypeMatrix %vec3t 3 +%mat44t = OpTypeMatrix %vec4t 4 +%intt = OpTypeInt 32 1 +%uintt = OpTypeInt 32 0 +%num3 = OpConstant %uintt 3 +%const3 = OpConstant %uintt 3 +%val3 = OpConstant %uintt 3 +%array = OpTypeArray %vec3t %num3 +%struct = OpTypeStruct %floatt %floatt %vec3t +%boolt = OpTypeBool +%array2 = OpTypeArray %vec3t %num3 +%voidt = OpTypeVoid +%vfunct = OpTypeFunction %voidt +%struct2 = OpTypeStruct %floatt %floatt %vec3t +%false = OpConstantFalse %boolt +%true = OpConstantTrue %boolt +%runtime_arrayt = OpTypeRuntimeArray %floatt +%runtime_arrayt2 = OpTypeRuntimeArray %floatt +)"; + + return header; +} + +const std::string& GetBody() { + static const std::string body = R"( +%main = OpFunction %voidt None %vfunct +%mainl = OpLabel +%a = OpIAdd %uintt %const3 %val3 +%b = OpIAdd %uintt %const3 %val3 +OpSelectionMerge %endl None 
+OpBranchConditional %true %truel %falsel +%truel = OpLabel +%add1 = OpIAdd %uintt %a %b +%add2 = OpIAdd %uintt %a %b +OpBranch %endl +%falsel = OpLabel +%sub1 = OpISub %uintt %a %b +%sub2 = OpISub %uintt %a %b +OpBranch %endl +%endl = OpLabel +OpReturn +OpFunctionEnd +)"; + + return body; +} + +// Returns expected error string if |opcode| produces a duplicate type +// declaration. +std::string GetErrorString(SpvOp opcode) { + return "Duplicate non-aggregate type declarations are not allowed. Opcode: " + + std::string(spvOpcodeString(opcode)); +} + +TEST_F(ValidateTypeUnique, success) { + std::string str = GetHeader() + GetBody(); + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateTypeUnique, duplicate_void) { + std::string str = GetHeader() + R"( +%boolt2 = OpTypeVoid +)" + GetBody(); + CompileSuccessfully(str.c_str()); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(GetErrorString(SpvOpTypeVoid))); +} + +TEST_F(ValidateTypeUnique, duplicate_bool) { + std::string str = GetHeader() + R"( +%boolt2 = OpTypeBool +)" + GetBody(); + CompileSuccessfully(str.c_str()); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(GetErrorString(SpvOpTypeBool))); +} + +TEST_F(ValidateTypeUnique, duplicate_int) { + std::string str = GetHeader() + R"( +%uintt2 = OpTypeInt 32 0 +)" + GetBody(); + CompileSuccessfully(str.c_str()); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(GetErrorString(SpvOpTypeInt))); +} + +TEST_F(ValidateTypeUnique, duplicate_float) { + std::string str = GetHeader() + R"( +%floatt2 = OpTypeFloat 32 +)" + GetBody(); + CompileSuccessfully(str.c_str()); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr(GetErrorString(SpvOpTypeFloat))); +} + +TEST_F(ValidateTypeUnique, duplicate_vec3) { + 
std::string str = GetHeader() + R"( +%vec3t2 = OpTypeVector %floatt 3 +)" + GetBody(); + CompileSuccessfully(str.c_str()); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr(GetErrorString(SpvOpTypeVector))); +} + +TEST_F(ValidateTypeUnique, duplicate_mat33) { + std::string str = GetHeader() + R"( +%mat33t2 = OpTypeMatrix %vec3t 3 +)" + GetBody(); + CompileSuccessfully(str.c_str()); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr(GetErrorString(SpvOpTypeMatrix))); +} + +TEST_F(ValidateTypeUnique, duplicate_vfunc) { + std::string str = GetHeader() + R"( +%vfunct2 = OpTypeFunction %voidt +)" + GetBody(); + CompileSuccessfully(str.c_str()); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr(GetErrorString(SpvOpTypeFunction))); +} + +TEST_F(ValidateTypeUnique, duplicate_pipe_storage) { + std::string str = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpCapability Pipes +OpCapability PipeStorage +OpMemoryModel Physical32 OpenCL +%ps = OpTypePipeStorage +%ps2 = OpTypePipeStorage +)"; + CompileSuccessfully(str.c_str(), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr(GetErrorString(SpvOpTypePipeStorage))); +} + +TEST_F(ValidateTypeUnique, duplicate_named_barrier) { + std::string str = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpCapability NamedBarrier +OpMemoryModel Physical32 OpenCL +%nb = OpTypeNamedBarrier +%nb2 = OpTypeNamedBarrier +)"; + CompileSuccessfully(str.c_str(), SPV_ENV_UNIVERSAL_1_1); + ASSERT_EQ(kDuplicateTypeError, ValidateInstructions(SPV_ENV_UNIVERSAL_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr(GetErrorString(SpvOpTypeNamedBarrier))); +} + +TEST_F(ValidateTypeUnique, duplicate_forward_pointer) { + std::string str = R"( 
+OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +OpMemoryModel Physical32 OpenCL +OpTypeForwardPointer %ptr Generic +OpTypeForwardPointer %ptr2 Generic +%intt = OpTypeInt 32 0 +%floatt = OpTypeFloat 32 +%ptr = OpTypePointer Generic %intt +%ptr2 = OpTypePointer Generic %floatt +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateTypeUnique, duplicate_void_with_extension) { + std::string str = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Linkage +OpCapability Pipes +OpExtension "SPV_VALIDATOR_ignore_type_decl_unique" +OpMemoryModel Physical32 OpenCL +%voidt = OpTypeVoid +%voidt2 = OpTypeVoid +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + Not(HasSubstr(GetErrorString(SpvOpTypeVoid)))); +} + +TEST_F(ValidateTypeUnique, DuplicatePointerTypesNoExtension) { + std::string str = R"( +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%u32 = OpTypeInt 32 0 +%ptr1 = OpTypePointer Input %u32 +%ptr2 = OpTypePointer Input %u32 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateTypeUnique, DuplicatePointerTypesWithExtension) { + std::string str = R"( +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_variable_pointers" +OpMemoryModel Logical GLSL450 +%u32 = OpTypeInt 32 0 +%ptr1 = OpTypePointer Input %u32 +%ptr2 = OpTypePointer Input %u32 +)"; + CompileSuccessfully(str.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + Not(HasSubstr(GetErrorString(SpvOpTypePointer)))); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_validation_state_test.cpp b/test/val/val_validation_state_test.cpp new file mode 100644 index 000000000..bf6509497 --- /dev/null +++ b/test/val/val_validation_state_test.cpp @@ 
-0,0 +1,359 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Basic tests for the ValidationState_t datastructure. + +#include + +#include "gmock/gmock.h" +#include "source/spirv_validator_options.h" +#include "test/unit_spirv.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; + +using ValidationStateTest = spvtest::ValidateBase; + +const char kHeader[] = + " OpCapability Shader" + " OpCapability Linkage" + " OpMemoryModel Logical GLSL450 "; + +const char kVulkanMemoryHeader[] = + " OpCapability Shader" + " OpCapability VulkanMemoryModelKHR" + " OpExtension \"SPV_KHR_vulkan_memory_model\"" + " OpMemoryModel Logical VulkanKHR "; + +const char kVoidFVoid[] = + " %void = OpTypeVoid" + " %void_f = OpTypeFunction %void" + " %func = OpFunction %void None %void_f" + " %label = OpLabel" + " OpReturn" + " OpFunctionEnd "; + +// k*RecursiveBody examples originally from test/opt/function_test.cpp +const char* kNonRecursiveBody = R"( +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%12 = OpFunction %_struct_6 None %7 +%13 = OpLabel +OpUnreachable +OpFunctionEnd +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %12 +OpUnreachable +OpFunctionEnd +%1 = OpFunction 
%void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +)"; + +const char* kDirectlyRecursiveBody = R"( +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +%1 = OpFunction %void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpUnreachable +OpFunctionEnd +)"; + +const char* kIndirectlyRecursiveBody = R"( +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +%void = OpTypeVoid +%4 = OpTypeFunction %void +%float = OpTypeFloat 32 +%_struct_6 = OpTypeStruct %float %float +%7 = OpTypeFunction %_struct_6 +%9 = OpFunction %_struct_6 None %7 +%10 = OpLabel +%11 = OpFunctionCall %_struct_6 %12 +OpUnreachable +OpFunctionEnd +%12 = OpFunction %_struct_6 None %7 +%13 = OpLabel +%14 = OpFunctionCall %_struct_6 %9 +OpUnreachable +OpFunctionEnd +%1 = OpFunction %void Pure|Const %4 +%8 = OpLabel +%2 = OpFunctionCall %_struct_6 %9 +OpKill +OpFunctionEnd +)"; + +// Tests that the instruction count in ValidationState is correct. +TEST_F(ValidationStateTest, CheckNumInstructions) { + std::string spirv = std::string(kHeader) + "%int = OpTypeInt 32 0"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size()); +} + +// Tests that the number of global variables in ValidationState is correct. 
+TEST_F(ValidationStateTest, CheckNumGlobalVars) { + std::string spirv = std::string(kHeader) + R"( + %int = OpTypeInt 32 0 +%_ptr_int = OpTypePointer Input %int + %var_1 = OpVariable %_ptr_int Input + %var_2 = OpVariable %_ptr_int Input + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(unsigned(2), vstate_->num_global_vars()); +} + +// Tests that the number of local variables in ValidationState is correct. +TEST_F(ValidationStateTest, CheckNumLocalVars) { + std::string spirv = std::string(kHeader) + R"( + %int = OpTypeInt 32 0 + %_ptr_int = OpTypePointer Function %int + %voidt = OpTypeVoid + %funct = OpTypeFunction %voidt + %main = OpFunction %voidt None %funct + %entry = OpLabel + %var_1 = OpVariable %_ptr_int Function + %var_2 = OpVariable %_ptr_int Function + %var_3 = OpVariable %_ptr_int Function + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(unsigned(3), vstate_->num_local_vars()); +} + +// Tests that the "id bound" in ValidationState is correct. +TEST_F(ValidationStateTest, CheckIdBound) { + std::string spirv = std::string(kHeader) + R"( + %int = OpTypeInt 32 0 + %voidt = OpTypeVoid + )"; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(unsigned(3), vstate_->getIdBound()); +} + +// Tests that the entry_points in ValidationState is correct. 
+TEST_F(ValidationStateTest, CheckEntryPoints) { + std::string spirv = std::string(kHeader) + + " OpEntryPoint Vertex %func \"shader\"" + + std::string(kVoidFVoid); + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); + EXPECT_EQ(size_t(1), vstate_->entry_points().size()); + EXPECT_EQ(SpvOpFunction, + vstate_->FindDef(vstate_->entry_points()[0])->opcode()); +} + +TEST_F(ValidationStateTest, CheckStructMemberLimitOption) { + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_struct_members, 32000u); + EXPECT_EQ(32000u, options_->universal_limits_.max_struct_members); +} + +TEST_F(ValidationStateTest, CheckNumGlobalVarsLimitOption) { + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_global_variables, 100u); + EXPECT_EQ(100u, options_->universal_limits_.max_global_variables); +} + +TEST_F(ValidationStateTest, CheckNumLocalVarsLimitOption) { + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_local_variables, 100u); + EXPECT_EQ(100u, options_->universal_limits_.max_local_variables); +} + +TEST_F(ValidationStateTest, CheckStructDepthLimitOption) { + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_struct_depth, 100u); + EXPECT_EQ(100u, options_->universal_limits_.max_struct_depth); +} + +TEST_F(ValidationStateTest, CheckSwitchBranchesLimitOption) { + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_switch_branches, 100u); + EXPECT_EQ(100u, options_->universal_limits_.max_switch_branches); +} + +TEST_F(ValidationStateTest, CheckFunctionArgsLimitOption) { + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_function_args, 100u); + EXPECT_EQ(100u, options_->universal_limits_.max_function_args); +} + +TEST_F(ValidationStateTest, CheckCFGDepthLimitOption) { + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_control_flow_nesting_depth, 100u); + EXPECT_EQ(100u, 
options_->universal_limits_.max_control_flow_nesting_depth); +} + +TEST_F(ValidationStateTest, CheckAccessChainIndexesLimitOption) { + spvValidatorOptionsSetUniversalLimit( + options_, spv_validator_limit_max_access_chain_indexes, 100u); + EXPECT_EQ(100u, options_->universal_limits_.max_access_chain_indexes); +} + +TEST_F(ValidationStateTest, CheckNonRecursiveBodyGood) { + std::string spirv = std::string(kHeader) + kNonRecursiveBody; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidationStateTest, CheckVulkanNonRecursiveBodyGood) { + std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody; + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); +} + +TEST_F(ValidationStateTest, CheckWebGPUNonRecursiveBodyGood) { + std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody; + CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidationStateTest, CheckDirectlyRecursiveBodyGood) { + std::string spirv = std::string(kHeader) + kDirectlyRecursiveBody; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidationStateTest, CheckVulkanDirectlyRecursiveBodyBad) { + std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody; + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Entry points may not have a call graph with cycles.\n " + " %1 = OpFunction %void Pure|Const %3\n")); +} + +TEST_F(ValidationStateTest, CheckWebGPUDirectlyRecursiveBodyBad) { + std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody; + CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0); + 
EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Entry points may not have a call graph with cycles.\n " + " %1 = OpFunction %void Pure|Const %3\n")); +} + +TEST_F(ValidationStateTest, CheckIndirectlyRecursiveBodyGood) { + std::string spirv = std::string(kHeader) + kIndirectlyRecursiveBody; + CompileSuccessfully(spirv); + EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState()); +} + +TEST_F(ValidationStateTest, CheckVulkanIndirectlyRecursiveBodyBad) { + std::string spirv = + std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody; + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Entry points may not have a call graph with cycles.\n " + " %1 = OpFunction %void Pure|Const %3\n")); +} + +// Indirectly recursive functions are caught by the function definition layout +// rules, because they cause a situation where there are 2 functions that have +// to be before each other, and layout is checked earlier. 
+TEST_F(ValidationStateTest, CheckWebGPUIndirectlyRecursiveBodyBad) { + std::string spirv = + std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody; + CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, + ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For WebGPU, functions need to be defined before being " + "called.\n %9 = OpFunctionCall %_struct_5 %10\n")); +} + +TEST_F(ValidationStateTest, + CheckWebGPUDuplicateEntryNamesDifferentFunctionsBad) { + std::string spirv = std::string(kVulkanMemoryHeader) + R"( +OpEntryPoint Fragment %func_1 "main" +OpEntryPoint Vertex %func_2 "main" +OpExecutionMode %func_1 OriginUpperLeft +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%func_1 = OpFunction %void None %void_f +%label_1 = OpLabel + OpReturn + OpFunctionEnd +%func_2 = OpFunction %void None %void_f +%label_2 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Entry point name \"main\" is not unique, which is not allow " + "in WebGPU env.\n %1 = OpFunction %void None %4\n")); +} + +TEST_F(ValidationStateTest, CheckWebGPUDuplicateEntryNamesSameFunctionBad) { + std::string spirv = std::string(kVulkanMemoryHeader) + R"( +OpEntryPoint GLCompute %func_1 "main" +OpEntryPoint Vertex %func_1 "main" +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%func_1 = OpFunction %void None %void_f +%label_1 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Entry point name \"main\" is not unique, which is not allow " + "in WebGPU env.\n %1 = OpFunction %void None %3\n")); +} + +} // namespace +} // namespace val +} 
// namespace spvtools diff --git a/test/val/val_version_test.cpp b/test/val/val_version_test.cpp new file mode 100644 index 000000000..eb9bbb960 --- /dev/null +++ b/test/val/val_version_test.cpp @@ -0,0 +1,294 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; + +using ValidateVersion = spvtest::ValidateBase< + std::tuple>; + +const std::string vulkan_spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string webgpu_spirv = R"( +OpCapability Shader +OpCapability VulkanMemoryModelKHR +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical VulkanKHR +OpEntryPoint Fragment %func "func" +OpExecutionMode %func OriginUpperLeft +%void = OpTypeVoid +%functy = OpTypeFunction %void +%func = OpFunction %void None %functy +%1 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string opencl_spirv = R"( +OpCapability Kernel +OpCapability Linkage +OpMemoryModel Logical OpenCL +)"; + +std::string version(spv_target_env env) { + switch (env) { + case SPV_ENV_UNIVERSAL_1_0: + case SPV_ENV_VULKAN_1_0: + case SPV_ENV_OPENGL_4_0: + 
case SPV_ENV_OPENGL_4_1: + case SPV_ENV_OPENGL_4_2: + case SPV_ENV_OPENGL_4_3: + case SPV_ENV_OPENGL_4_5: + case SPV_ENV_OPENCL_1_2: + case SPV_ENV_OPENCL_2_0: + case SPV_ENV_OPENCL_EMBEDDED_2_0: + return "1.0"; + case SPV_ENV_UNIVERSAL_1_1: + case SPV_ENV_OPENCL_2_1: + case SPV_ENV_OPENCL_EMBEDDED_2_1: + return "1.1"; + case SPV_ENV_UNIVERSAL_1_2: + case SPV_ENV_OPENCL_2_2: + case SPV_ENV_OPENCL_EMBEDDED_2_2: + return "1.2"; + case SPV_ENV_UNIVERSAL_1_3: + case SPV_ENV_VULKAN_1_1: + case SPV_ENV_WEBGPU_0: + return "1.3"; + default: + return "0"; + } +} + +TEST_P(ValidateVersion, version) { + CompileSuccessfully(std::get<2>(GetParam()), std::get<0>(GetParam())); + spv_result_t res = ValidateInstructions(std::get<1>(GetParam())); + if (std::get<3>(GetParam())) { + ASSERT_EQ(SPV_SUCCESS, res); + } else { + ASSERT_EQ(SPV_ERROR_WRONG_VERSION, res); + + std::string msg = "Invalid SPIR-V binary version "; + msg += version(std::get<0>(GetParam())); + msg += " for target environment "; + msg += spvTargetEnvDescription(std::get<1>(GetParam())); + EXPECT_THAT(getDiagnosticString(), HasSubstr(msg)); + } +} + +// clang-format off +INSTANTIATE_TEST_CASE_P(Universal, ValidateVersion, + ::testing::Values( + // Binary version, Target environment + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_VULKAN_1_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_2, vulkan_spirv, true), 
+ std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_OPENGL_4_5, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_WEBGPU_0, webgpu_spirv, true), + + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_OPENGL_4_5, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_1, SPV_ENV_WEBGPU_0, webgpu_spirv, true), + + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_2, vulkan_spirv, 
false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_OPENGL_4_5, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_2, SPV_ENV_WEBGPU_0, webgpu_spirv, true), + + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_OPENGL_4_5, vulkan_spirv, false), + std::make_tuple(SPV_ENV_UNIVERSAL_1_3, SPV_ENV_WEBGPU_0, webgpu_spirv, true) + ) +); + +INSTANTIATE_TEST_CASE_P(Vulkan, ValidateVersion, + ::testing::Values( + // Binary version, Target environment + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_0, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, 
SPV_ENV_OPENGL_4_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_2, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_0, SPV_ENV_OPENGL_4_5, vulkan_spirv, true), + + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_UNIVERSAL_1_3, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_VULKAN_1_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_VULKAN_1_1, vulkan_spirv, true), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_0, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_1, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_2, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_3, vulkan_spirv, false), + std::make_tuple(SPV_ENV_VULKAN_1_1, SPV_ENV_OPENGL_4_5, vulkan_spirv, false) + ) +); + +INSTANTIATE_TEST_CASE_P(OpenCL, ValidateVersion, + ::testing::Values( + // Binary version, Target environment + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, 
SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_0, SPV_ENV_OPENCL_1_2, opencl_spirv, true), + + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_1, SPV_ENV_OPENCL_1_2, opencl_spirv, true), + + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + 
std::make_tuple(SPV_ENV_OPENCL_2_2, SPV_ENV_OPENCL_1_2, opencl_spirv, false), + + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_1_2, SPV_ENV_OPENCL_1_2, opencl_spirv, true) + ) +); + +INSTANTIATE_TEST_CASE_P(OpenCLEmbedded, ValidateVersion, + ::testing::Values( + // Binary version, Target environment + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, 
SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_0, SPV_ENV_OPENCL_1_2, opencl_spirv, true), + + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_1, SPV_ENV_OPENCL_1_2, opencl_spirv, true), + + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_UNIVERSAL_1_3, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_1, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_0, opencl_spirv, false), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_1, opencl_spirv, false), + 
std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_EMBEDDED_2_2, opencl_spirv, true), + std::make_tuple(SPV_ENV_OPENCL_EMBEDDED_2_2, SPV_ENV_OPENCL_1_2, opencl_spirv, false) + ) +); +// clang-format on + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/test/val/val_webgpu_test.cpp b/test/val/val_webgpu_test.cpp new file mode 100644 index 000000000..48ea21db2 --- /dev/null +++ b/test/val/val_webgpu_test.cpp @@ -0,0 +1,281 @@ +// Copyright (c) 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Validation tests for WebGPU env specific checks + +#include + +#include "gmock/gmock.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using testing::HasSubstr; + +using ValidateWebGPU = spvtest::ValidateBase; + +TEST_F(ValidateWebGPU, OpUndefIsDisallowed) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpEntryPoint Vertex %func "shader" +%float = OpTypeFloat 32 +%1 = OpUndef %float +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%func = OpFunction %void None %void_f +%label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + + // Control case: OpUndef is allowed in SPIR-V 1.3 + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + + // Control case: OpUndef is disallowed in the WebGPU env + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), HasSubstr("OpUndef is disallowed")); +} + +TEST_F(ValidateWebGPU, OpNameIsDisallowed) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpName %1 "foo" +%1 = OpTypeFloat 32 +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Debugging instructions are not allowed in the WebGPU " + "execution environment.\n OpName %foo \"foo\"\n")); +} + +TEST_F(ValidateWebGPU, OpMemberNameIsDisallowed) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpMemberName %2 0 "foo" +%1 = OpTypeFloat 32 +%2 = OpTypeStruct %1 +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, 
ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Debugging instructions are not allowed in the WebGPU " + "execution environment.\n OpMemberName %_struct_1 0 " + "\"foo\"\n")); +} + +TEST_F(ValidateWebGPU, OpSourceIsDisallowed) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpSource GLSL 450 +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Debugging instructions are not allowed in the WebGPU " + "execution environment.\n OpSource GLSL 450\n")); +} + +// OpSourceContinued does not have a test case, because it requires being +// preceded by OpSource, which will cause a validation error. + +TEST_F(ValidateWebGPU, OpSourceExtensionIsDisallowed) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpSourceExtension "bar" +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Debugging instructions are not allowed in the WebGPU " + "execution environment.\n OpSourceExtension " + "\"bar\"\n")); +} + +TEST_F(ValidateWebGPU, OpStringIsDisallowed) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR +%1 = OpString "foo" +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Debugging instructions are not allowed in the WebGPU " + "execution environment.\n %1 = OpString \"foo\"\n")); +} + +// OpLine does not have a test case, because it requires being preceded 
by +// OpString, which will cause a validation error. + +TEST_F(ValidateWebGPU, OpNoLineDisallowed) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpNoLine +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Debugging instructions are not allowed in the WebGPU " + "execution environment.\n OpNoLine\n")); +} + +TEST_F(ValidateWebGPU, LogicalAddressingVulkanKHRMemoryGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR + OpEntryPoint Vertex %func "shader" +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%func = OpFunction %void None %void_f +%label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateWebGPU, NonLogicalAddressingModelBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Physical32 VulkanKHR +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Addressing model must be Logical for WebGPU " + "environment.\n OpMemoryModel Physical32 " + "VulkanKHR\n")); +} + +TEST_F(ValidateWebGPU, NonVulkanKHRMemoryModelBad) { + std::string spirv = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpNoLine +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Memory model must be VulkanKHR for WebGPU " + "environment.\n OpMemoryModel Logical GLSL450\n")); +} + +TEST_F(ValidateWebGPU, 
WhitelistedExtendedInstructionsImportGood) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" +%1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical VulkanKHR + OpEntryPoint Vertex %func "shader" +%void = OpTypeVoid +%void_f = OpTypeFunction %void +%func = OpFunction %void None %void_f +%label = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0)); +} + +TEST_F(ValidateWebGPU, NonWhitelistedExtendedInstructionsImportBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_vulkan_memory_model" +%1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Logical VulkanKHR +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For WebGPU, the only valid parameter to " + "OpExtInstImport is \"GLSL.std.450\".\n %1 = " + "OpExtInstImport \"OpenCL.std\"\n")); +} + +TEST_F(ValidateWebGPU, NonVulkanKHRMemoryModelExtensionBad) { + std::string spirv = R"( + OpCapability Shader + OpCapability VulkanMemoryModelKHR + OpExtension "SPV_KHR_8bit_storage" + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical VulkanKHR +)"; + + CompileSuccessfully(spirv); + + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("For WebGPU, the only valid parameter to OpExtension " + "is \"SPV_KHR_vulkan_memory_model\".\n OpExtension " + "\"SPV_KHR_8bit_storage\"\n")); +} + +} // namespace +} // namespace val +} // namespace spvtools diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt new file mode 100644 index 000000000..9fb3a91a0 --- /dev/null +++ b/tools/CMakeLists.txt @@ -0,0 +1,84 @@ +# Copyright (c) 2015-2016 The Khronos Group Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(lesspipe) +add_subdirectory(emacs) + +# Add a SPIR-V Tools command line tool. Signature: +# add_spvtools_tool( +# TARGET target_name +# SRCS src_file1.cpp src_file2.cpp +# LIBS lib_target1 lib_target2 +# ) +function(add_spvtools_tool) + set(one_value_args TARGET) + set(multi_value_args SRCS LIBS) + cmake_parse_arguments( + ARG "" "${one_value_args}" "${multi_value_args}" ${ARGN}) + + add_executable(${ARG_TARGET} ${ARG_SRCS}) + spvtools_default_compile_options(${ARG_TARGET}) + target_link_libraries(${ARG_TARGET} PRIVATE ${ARG_LIBS}) + target_include_directories(${ARG_TARGET} PRIVATE + ${spirv-tools_SOURCE_DIR} + ${spirv-tools_BINARY_DIR} + ) + set_property(TARGET ${ARG_TARGET} PROPERTY FOLDER "SPIRV-Tools executables") +endfunction() + +if (NOT ${SPIRV_SKIP_EXECUTABLES}) + add_spvtools_tool(TARGET spirv-as SRCS as/as.cpp LIBS ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-dis SRCS dis/dis.cpp LIBS ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-val SRCS val/val.cpp util/cli_consumer.cpp LIBS ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-opt SRCS opt/opt.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-reduce SRCS reduce/reduce.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-reduce ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-link SRCS link/linker.cpp LIBS SPIRV-Tools-link ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-stats + SRCS stats/stats.cpp + 
stats/stats_analyzer.cpp + stats/stats_analyzer.h + stats/spirv_stats.cpp + stats/spirv_stats.h + LIBS ${SPIRV_TOOLS}) + add_spvtools_tool(TARGET spirv-cfg + SRCS cfg/cfg.cpp + cfg/bin_to_dot.h + cfg/bin_to_dot.cpp + LIBS ${SPIRV_TOOLS}) + target_include_directories(spirv-cfg PRIVATE ${spirv-tools_SOURCE_DIR} + ${SPIRV_HEADER_INCLUDE_DIR}) + target_include_directories(spirv-stats PRIVATE ${spirv-tools_SOURCE_DIR} + ${SPIRV_HEADER_INCLUDE_DIR}) + + set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt spirv-stats + spirv-cfg spirv-link spirv-reduce) + + if(SPIRV_BUILD_COMPRESSION) + add_spvtools_tool(TARGET spirv-markv + SRCS comp/markv.cpp + comp/markv_model_factory.cpp + comp/markv_model_shader.cpp + LIBS SPIRV-Tools-comp SPIRV-Tools-opt ${SPIRV_TOOLS}) + target_include_directories(spirv-markv PRIVATE ${spirv-tools_SOURCE_DIR} + ${SPIRV_HEADER_INCLUDE_DIR}) + set(SPIRV_INSTALL_TARGETS ${SPIRV_INSTALL_TARGETS} spirv-markv) + endif(SPIRV_BUILD_COMPRESSION) + + if(ENABLE_SPIRV_TOOLS_INSTALL) + install(TARGETS ${SPIRV_INSTALL_TARGETS} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) + endif(ENABLE_SPIRV_TOOLS_INSTALL) +endif() diff --git a/tools/as/as.cpp b/tools/as/as.cpp new file mode 100644 index 000000000..287ba51f8 --- /dev/null +++ b/tools/as/as.cpp @@ -0,0 +1,154 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "source/spirv_target_env.h" +#include "spirv-tools/libspirv.h" +#include "tools/io.h" + +void print_usage(char* argv0) { + printf( + R"(%s - Create a SPIR-V binary module from SPIR-V assembly text + +Usage: %s [options] [] + +The SPIR-V assembly text is read from . If no file is specified, +or if the filename is "-", then the assembly text is read from standard input. +The SPIR-V binary module is written to file "out.spv", unless the -o option +is used. + +Options: + + -h, --help Print this help. + + -o Set the output filename. Use '-' to mean stdout. + --version Display assembler version information. + --preserve-numeric-ids + Numeric IDs in the binary will have the same values as in the + source. Non-numeric IDs are allocated by filling in the gaps, + starting with 1 and going up. + --target-env {vulkan1.0|vulkan1.1|spv1.0|spv1.1|spv1.2|spv1.3} + Use Vulkan 1.0, Vulkan 1.1, SPIR-V 1.0, SPIR-V 1.1, + SPIR-V 1.2, or SPIR-V 1.3 +)", + argv0, argv0); +} + +static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_3; + +int main(int argc, char** argv) { + const char* inFile = nullptr; + const char* outFile = nullptr; + uint32_t options = 0; + spv_target_env target_env = kDefaultEnvironment; + for (int argi = 1; argi < argc; ++argi) { + if ('-' == argv[argi][0]) { + switch (argv[argi][1]) { + case 'h': { + print_usage(argv[0]); + return 0; + } + case 'o': { + if (!outFile && argi + 1 < argc) { + outFile = argv[++argi]; + } else { + print_usage(argv[0]); + return 1; + } + } break; + case 0: { + // Setting a filename of "-" to indicate stdin. 
+ if (!inFile) { + inFile = argv[argi]; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + return 1; + } + } break; + case '-': { + // Long options + if (0 == strcmp(argv[argi], "--version")) { + printf("%s\n", spvSoftwareVersionDetailsString()); + printf("Target: %s\n", + spvTargetEnvDescription(kDefaultEnvironment)); + return 0; + } else if (0 == strcmp(argv[argi], "--help")) { + print_usage(argv[0]); + return 0; + } else if (0 == strcmp(argv[argi], "--preserve-numeric-ids")) { + options |= SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS; + } else if (0 == strcmp(argv[argi], "--target-env")) { + if (argi + 1 < argc) { + const auto env_str = argv[++argi]; + if (!spvParseTargetEnv(env_str, &target_env)) { + fprintf(stderr, "error: Unrecognized target env: %s\n", + env_str); + return 1; + } + } else { + fprintf(stderr, "error: Missing argument to --target-env\n"); + return 1; + } + } else { + fprintf(stderr, "error: Unrecognized option: %s\n\n", argv[argi]); + print_usage(argv[0]); + return 1; + } + } break; + default: + fprintf(stderr, "error: Unrecognized option: %s\n\n", argv[argi]); + print_usage(argv[0]); + return 1; + } + } else { + if (!inFile) { + inFile = argv[argi]; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + return 1; + } + } + } + + if (!outFile) { + outFile = "out.spv"; + } + + std::vector contents; + if (!ReadFile(inFile, "r", &contents)) return 1; + + spv_binary binary; + spv_diagnostic diagnostic = nullptr; + spv_context context = spvContextCreate(target_env); + spv_result_t error = spvTextToBinaryWithOptions( + context, contents.data(), contents.size(), options, &binary, &diagnostic); + spvContextDestroy(context); + if (error) { + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + return error; + } + + if (!WriteFile(outFile, "wb", binary->code, binary->wordCount)) { + spvBinaryDestroy(binary); + return 1; + } + + spvBinaryDestroy(binary); + + return 0; +} diff --git 
a/tools/cfg/bin_to_dot.cpp b/tools/cfg/bin_to_dot.cpp new file mode 100644 index 000000000..2561eea40 --- /dev/null +++ b/tools/cfg/bin_to_dot.cpp @@ -0,0 +1,187 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/cfg/bin_to_dot.h" + +#include +#include +#include +#include + +#include "source/assembly_grammar.h" +#include "source/name_mapper.h" + +namespace { + +const char* kMergeStyle = "style=dashed"; +const char* kContinueStyle = "style=dotted"; + +// A DotConverter can be used to dump the GraphViz "dot" graph for +// a SPIR-V module. +class DotConverter { + public: + DotConverter(spvtools::NameMapper name_mapper, std::iostream* out) + : name_mapper_(std::move(name_mapper)), out_(*out) {} + + // Emits the graph preamble. + void Begin() const { + out_ << "digraph {\n"; + // Emit a simple legend + out_ << "legend_merge_src [shape=plaintext, label=\"\"];\n" + << "legend_merge_dest [shape=plaintext, label=\"\"];\n" + << "legend_merge_src -> legend_merge_dest [label=\" merge\"," + << kMergeStyle << "];\n" + << "legend_continue_src [shape=plaintext, label=\"\"];\n" + << "legend_continue_dest [shape=plaintext, label=\"\"];\n" + << "legend_continue_src -> legend_continue_dest [label=\" continue\"," + << kContinueStyle << "];\n"; + } + // Emits the graph postamble. + void End() const { out_ << "}\n"; } + + // Emits the Dot commands for the given instruction. 
+ spv_result_t HandleInstruction(const spv_parsed_instruction_t& inst); + + private: + // Ends processing for the current block, emitting its dot code. + void FlushBlock(const std::vector& successors); + + // The ID of the current function, or 0 if outside of a function. + uint32_t current_function_id_ = 0; + + // The ID of the current basic block, or 0 if outside of a block. + uint32_t current_block_id_ = 0; + + // Have we completed processing for the entry block to this function? + bool seen_function_entry_block_ = false; + + // The Id of the merge block for this block if it exists, or 0 otherwise. + uint32_t merge_ = 0; + // The Id of the continue target block for this block if it exists, or 0 + // otherwise. + uint32_t continue_target_ = 0; + + // An object for mapping Ids to names. + spvtools::NameMapper name_mapper_; + + // The output stream. + std::ostream& out_; +}; + +spv_result_t DotConverter::HandleInstruction( + const spv_parsed_instruction_t& inst) { + switch (inst.opcode) { + case SpvOpFunction: + current_function_id_ = inst.result_id; + seen_function_entry_block_ = false; + break; + case SpvOpFunctionEnd: + current_function_id_ = 0; + break; + + case SpvOpLabel: + current_block_id_ = inst.result_id; + break; + + case SpvOpBranch: + FlushBlock({inst.words[1]}); + break; + case SpvOpBranchConditional: + FlushBlock({inst.words[2], inst.words[3]}); + break; + case SpvOpSwitch: { + std::vector successors{inst.words[2]}; + for (size_t i = 3; i < inst.num_operands; i += 2) { + successors.push_back(inst.words[inst.operands[i].offset]); + } + FlushBlock(successors); + } break; + + case SpvOpKill: + case SpvOpReturn: + case SpvOpUnreachable: + case SpvOpReturnValue: + FlushBlock({}); + break; + + case SpvOpLoopMerge: + merge_ = inst.words[1]; + continue_target_ = inst.words[2]; + break; + case SpvOpSelectionMerge: + merge_ = inst.words[1]; + break; + default: + break; + } + return SPV_SUCCESS; +} + +void DotConverter::FlushBlock(const std::vector& successors) { 
+ out_ << current_block_id_; + if (!seen_function_entry_block_) { + out_ << " [label=\"" << name_mapper_(current_block_id_) << "\nFn " + << name_mapper_(current_function_id_) << " entry\", shape=box];\n"; + } else { + out_ << " [label=\"" << name_mapper_(current_block_id_) << "\"];\n"; + } + + for (auto successor : successors) { + out_ << current_block_id_ << " -> " << successor << ";\n"; + } + + if (merge_) { + out_ << current_block_id_ << " -> " << merge_ << " [" << kMergeStyle + << "];\n"; + } + if (continue_target_) { + out_ << current_block_id_ << " -> " << continue_target_ << " [" + << kContinueStyle << "];\n"; + } + + // Reset the book-keeping for a block. + seen_function_entry_block_ = true; + merge_ = 0; + continue_target_ = 0; +} + +spv_result_t HandleInstruction( + void* user_data, const spv_parsed_instruction_t* parsed_instruction) { + assert(user_data); + auto converter = static_cast(user_data); + return converter->HandleInstruction(*parsed_instruction); +} + +} // anonymous namespace + +spv_result_t BinaryToDot(const spv_const_context context, const uint32_t* words, + size_t num_words, std::iostream* out, + spv_diagnostic* diagnostic) { + // Invalid arguments return error codes, but don't necessarily generate + // diagnostics. These are programmer errors, not user errors. 
+ if (!diagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; + const spvtools::AssemblyGrammar grammar(context); + if (!grammar.isValid()) return SPV_ERROR_INVALID_TABLE; + + spvtools::FriendlyNameMapper friendly_mapper(context, words, num_words); + DotConverter converter(friendly_mapper.GetNameMapper(), out); + converter.Begin(); + if (auto error = spvBinaryParse(context, &converter, words, num_words, + nullptr, HandleInstruction, diagnostic)) { + return error; + } + converter.End(); + + return SPV_SUCCESS; +} diff --git a/tools/cfg/bin_to_dot.h b/tools/cfg/bin_to_dot.h new file mode 100644 index 000000000..4de2e07fa --- /dev/null +++ b/tools/cfg/bin_to_dot.h @@ -0,0 +1,28 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOOLS_CFG_BIN_TO_DOT_H_ +#define TOOLS_CFG_BIN_TO_DOT_H_ + +#include + +#include "spirv-tools/libspirv.h" + +// Dumps the control flow graph for the given module to the output stream. +// Returns SPV_SUCCESS on success. +spv_result_t BinaryToDot(const spv_const_context context, const uint32_t* words, + size_t num_words, std::iostream* out, + spv_diagnostic* diagnostic); + +#endif // TOOLS_CFG_BIN_TO_DOT_H_ diff --git a/tools/cfg/cfg.cpp b/tools/cfg/cfg.cpp new file mode 100644 index 000000000..9e2c448ba --- /dev/null +++ b/tools/cfg/cfg.cpp @@ -0,0 +1,127 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "spirv-tools/libspirv.h" +#include "tools/cfg/bin_to_dot.h" +#include "tools/io.h" + +// Prints a program usage message to stdout. +static void print_usage(const char* argv0) { + printf( + R"(%s - Show the control flow graph in GraphiViz "dot" form. EXPERIMENTAL + +Usage: %s [options] [] + +The SPIR-V binary is read from . If no file is specified, +or if the filename is "-", then the binary is read from standard input. + +Options: + + -h, --help Print this help. + --version Display version information. + + -o Set the output filename. + Output goes to standard output if this option is + not specified, or if the filename is "-". +)", + argv0, argv0); +} + +static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_2; + +int main(int argc, char** argv) { + const char* inFile = nullptr; + const char* outFile = nullptr; // Stays nullptr if printing to stdout. 
+ + for (int argi = 1; argi < argc; ++argi) { + if ('-' == argv[argi][0]) { + switch (argv[argi][1]) { + case 'h': + print_usage(argv[0]); + return 0; + case 'o': { + if (!outFile && argi + 1 < argc) { + outFile = argv[++argi]; + } else { + print_usage(argv[0]); + return 1; + } + } break; + case '-': { + // Long options + if (0 == strcmp(argv[argi], "--help")) { + print_usage(argv[0]); + return 0; + } else if (0 == strcmp(argv[argi], "--version")) { + printf("%s EXPERIMENTAL\n", spvSoftwareVersionDetailsString()); + printf("Target: %s\n", + spvTargetEnvDescription(kDefaultEnvironment)); + return 0; + } else { + print_usage(argv[0]); + return 1; + } + } break; + case 0: { + // Setting a filename of "-" to indicate stdin. + if (!inFile) { + inFile = argv[argi]; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + return 1; + } + } break; + default: + print_usage(argv[0]); + return 1; + } + } else { + if (!inFile) { + inFile = argv[argi]; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + return 1; + } + } + } + + // Read the input binary. + std::vector contents; + if (!ReadFile(inFile, "rb", &contents)) return 1; + spv_context context = spvContextCreate(kDefaultEnvironment); + spv_diagnostic diagnostic = nullptr; + + std::stringstream ss; + auto error = + BinaryToDot(context, contents.data(), contents.size(), &ss, &diagnostic); + if (error) { + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(context); + return error; + } + std::string str = ss.str(); + WriteFile(outFile, "w", str.data(), str.size()); + + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(context); + + return 0; +} diff --git a/tools/comp/markv.cpp b/tools/comp/markv.cpp new file mode 100644 index 000000000..9a0a51808 --- /dev/null +++ b/tools/comp/markv.cpp @@ -0,0 +1,385 @@ +// Copyright (c) 2017 Google Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/comp/markv.h" +#include "source/spirv_target_env.h" +#include "source/table.h" +#include "spirv-tools/optimizer.hpp" +#include "tools/comp/markv_model_factory.h" +#include "tools/io.h" + +namespace { + +const auto kSpvEnv = SPV_ENV_UNIVERSAL_1_2; + +enum Task { + kNoTask = 0, + kEncode, + kDecode, + kTest, +}; + +struct ScopedContext { + ScopedContext(spv_target_env env) : context(spvContextCreate(env)) {} + ~ScopedContext() { spvContextDestroy(context); } + spv_context context; +}; + +void print_usage(char* argv0) { + printf( + R"(%s - Encodes or decodes a SPIR-V binary to or from a MARK-V binary. + +USAGE: %s [e|d|t] [options] [] + +The input binary is read from . If no file is specified, +or if the filename is "-", then the binary is read from standard input. + +If no output is specified then the output is printed to stdout in a human +readable format. + +WIP: MARK-V codec is in early stages of development. At the moment it only +can encode and decode some SPIR-V files and only if exacly the same build of +software is used (is doesn't write or handle version numbers yet). + +Tasks: + e Encode SPIR-V to MARK-V. + d Decode MARK-V to SPIR-V. + t Test the codec by first encoding the given SPIR-V file to + MARK-V, then decoding it back to SPIR-V and comparing results. 
+ +Options: + -h, --help Print this help. + --comments Write codec comments to stderr. + --version Display MARK-V codec version. + --validate Validate SPIR-V while encoding or decoding. + --model= + Compression model, possible values: + shader_lite - fast, poor compression ratio + shader_mid - balanced + shader_max - best compression ratio + Default: shader_lite + + -o Set the output filename. + Output goes to standard output if this option is + not specified, or if the filename is "-". + Not needed for 't' task (testing). +)", + argv0, argv0); +} + +void DiagnosticsMessageHandler(spv_message_level_t level, const char*, + const spv_position_t& position, + const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + std::cerr << "error: " << position.index << ": " << message << std::endl; + break; + case SPV_MSG_WARNING: + std::cerr << "warning: " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + std::cerr << "info: " << position.index << ": " << message << std::endl; + break; + default: + break; + } +} + +} // namespace + +int main(int argc, char** argv) { + const char* input_filename = nullptr; + const char* output_filename = nullptr; + + Task task = kNoTask; + + if (argc < 3) { + print_usage(argv[0]); + return 0; + } + + const char* task_char = argv[1]; + if (0 == strcmp("e", task_char)) { + task = kEncode; + } else if (0 == strcmp("d", task_char)) { + task = kDecode; + } else if (0 == strcmp("t", task_char)) { + task = kTest; + } + + if (task == kNoTask) { + print_usage(argv[0]); + return 1; + } + + bool want_comments = false; + bool validate_spirv_binary = false; + + spvtools::comp::MarkvModelType model_type = + spvtools::comp::kMarkvModelUnknown; + + for (int argi = 2; argi < argc; ++argi) { + if ('-' == argv[argi][0]) { + switch (argv[argi][1]) { + case 'h': + print_usage(argv[0]); + return 0; + case 'o': { + if (!output_filename && argi + 1 < argc && + (task == 
kEncode || task == kDecode)) { + output_filename = argv[++argi]; + } else { + print_usage(argv[0]); + return 1; + } + } break; + case '-': { + if (0 == strcmp(argv[argi], "--help")) { + print_usage(argv[0]); + return 0; + } else if (0 == strcmp(argv[argi], "--comments")) { + want_comments = true; + } else if (0 == strcmp(argv[argi], "--version")) { + fprintf(stderr, "error: Not implemented\n"); + return 1; + } else if (0 == strcmp(argv[argi], "--validate")) { + validate_spirv_binary = true; + } else if (0 == strcmp(argv[argi], "--model=shader_lite")) { + if (model_type != spvtools::comp::kMarkvModelUnknown) + fprintf(stderr, "error: More than one model specified\n"); + model_type = spvtools::comp::kMarkvModelShaderLite; + } else if (0 == strcmp(argv[argi], "--model=shader_mid")) { + if (model_type != spvtools::comp::kMarkvModelUnknown) + fprintf(stderr, "error: More than one model specified\n"); + model_type = spvtools::comp::kMarkvModelShaderMid; + } else if (0 == strcmp(argv[argi], "--model=shader_max")) { + if (model_type != spvtools::comp::kMarkvModelUnknown) + fprintf(stderr, "error: More than one model specified\n"); + model_type = spvtools::comp::kMarkvModelShaderMax; + } else { + print_usage(argv[0]); + return 1; + } + } break; + case '\0': { + // Setting a filename of "-" to indicate stdin. 
+ if (!input_filename) { + input_filename = argv[argi]; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + return 1; + } + } break; + default: + print_usage(argv[0]); + return 1; + } + } else { + if (!input_filename) { + input_filename = argv[argi]; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + return 1; + } + } + } + + if (model_type == spvtools::comp::kMarkvModelUnknown) + model_type = spvtools::comp::kMarkvModelShaderLite; + + const auto no_comments = spvtools::comp::MarkvLogConsumer(); + const auto output_to_stderr = [](const std::string& str) { + std::cerr << str; + }; + + ScopedContext ctx(kSpvEnv); + + std::unique_ptr model = + spvtools::comp::CreateMarkvModel(model_type); + + std::vector spirv; + std::vector markv; + + spvtools::comp::MarkvCodecOptions options; + options.validate_spirv_binary = validate_spirv_binary; + + if (task == kEncode) { + if (!ReadFile(input_filename, "rb", &spirv)) return 1; + assert(!spirv.empty()); + + if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv( + ctx.context, spirv, options, *model, + DiagnosticsMessageHandler, + want_comments ? output_to_stderr : no_comments, + spvtools::comp::MarkvDebugConsumer(), &markv)) { + std::cerr << "error: Failed to encode " << input_filename << " to MARK-V " + << std::endl; + return 1; + } + + if (!WriteFile(output_filename, "wb", markv.data(), markv.size())) + return 1; + } else if (task == kDecode) { + if (!ReadFile(input_filename, "rb", &markv)) return 1; + assert(!markv.empty()); + + if (SPV_SUCCESS != spvtools::comp::MarkvToSpirv( + ctx.context, markv, options, *model, + DiagnosticsMessageHandler, + want_comments ? 
output_to_stderr : no_comments, + spvtools::comp::MarkvDebugConsumer(), &spirv)) { + std::cerr << "error: Failed to decode " << input_filename << " to SPIR-V " + << std::endl; + return 1; + } + + if (!WriteFile(output_filename, "wb", spirv.data(), spirv.size())) + return 1; + } else if (task == kTest) { + if (!ReadFile(input_filename, "rb", &spirv)) return 1; + assert(!spirv.empty()); + + std::vector spirv_before; + spvtools::Optimizer optimizer(kSpvEnv); + optimizer.RegisterPass(spvtools::CreateCompactIdsPass()); + if (!optimizer.Run(spirv.data(), spirv.size(), &spirv_before)) { + std::cerr << "error: Optimizer failure on: " << input_filename + << std::endl; + } + + std::vector encoder_instruction_bits; + std::vector encoder_instruction_comments; + std::vector> encoder_instruction_words; + std::vector decoder_instruction_bits; + std::vector decoder_instruction_comments; + std::vector> decoder_instruction_words; + + const auto encoder_debug_consumer = [&](const std::vector& words, + const std::string& bits, + const std::string& comment) { + encoder_instruction_words.push_back(words); + encoder_instruction_bits.push_back(bits); + encoder_instruction_comments.push_back(comment); + return true; + }; + + if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv( + ctx.context, spirv_before, options, *model, + DiagnosticsMessageHandler, + want_comments ? 
output_to_stderr : no_comments, + encoder_debug_consumer, &markv)) { + std::cerr << "error: Failed to encode " << input_filename << " to MARK-V " + << std::endl; + return 1; + } + + const auto write_bug_report = [&]() { + for (size_t inst_index = 0; inst_index < decoder_instruction_words.size(); + ++inst_index) { + std::cerr << "\nInstruction #" << inst_index << std::endl; + std::cerr << "\nEncoder words: "; + for (uint32_t word : encoder_instruction_words[inst_index]) + std::cerr << word << " "; + std::cerr << "\nDecoder words: "; + for (uint32_t word : decoder_instruction_words[inst_index]) + std::cerr << word << " "; + std::cerr << std::endl; + + std::cerr << "\nEncoder bits: " << encoder_instruction_bits[inst_index]; + std::cerr << "\nDecoder bits: " << decoder_instruction_bits[inst_index]; + std::cerr << std::endl; + + std::cerr << "\nEncoder comments:\n" + << encoder_instruction_comments[inst_index]; + std::cerr << "Decoder comments:\n" + << decoder_instruction_comments[inst_index]; + std::cerr << std::endl; + } + }; + + const auto decoder_debug_consumer = [&](const std::vector& words, + const std::string& bits, + const std::string& comment) { + const size_t inst_index = decoder_instruction_words.size(); + if (inst_index >= encoder_instruction_words.size()) { + write_bug_report(); + std::cerr << "error: Decoder has more instructions than encoder: " + << input_filename << std::endl; + return false; + } + + decoder_instruction_words.push_back(words); + decoder_instruction_bits.push_back(bits); + decoder_instruction_comments.push_back(comment); + + if (encoder_instruction_words[inst_index] != + decoder_instruction_words[inst_index]) { + write_bug_report(); + std::cerr << "error: Words of the last decoded instruction differ from " + "reference: " + << input_filename << std::endl; + return false; + } + + if (encoder_instruction_bits[inst_index] != + decoder_instruction_bits[inst_index]) { + write_bug_report(); + std::cerr << "error: Bits of the last decoded 
instruction differ from " + "reference: " + << input_filename << std::endl; + return false; + } + return true; + }; + + std::vector spirv_after; + const spv_result_t decoding_result = spvtools::comp::MarkvToSpirv( + ctx.context, markv, options, *model, DiagnosticsMessageHandler, + want_comments ? output_to_stderr : no_comments, decoder_debug_consumer, + &spirv_after); + + if (decoding_result == SPV_REQUESTED_TERMINATION) { + std::cerr << "error: Decoding interrupted by the debugger: " + << input_filename << std::endl; + return 1; + } + + if (decoding_result != SPV_SUCCESS) { + std::cerr << "error: Failed to decode encoded " << input_filename + << " back to SPIR-V " << std::endl; + return 1; + } + + assert(spirv_before.size() == spirv_after.size()); + assert(std::mismatch(std::next(spirv_before.begin(), 5), spirv_before.end(), + std::next(spirv_after.begin(), 5)) == + std::make_pair(spirv_before.end(), spirv_after.end())); + } + + return 0; +} diff --git a/tools/comp/markv_model_factory.cpp b/tools/comp/markv_model_factory.cpp new file mode 100644 index 000000000..863fcf558 --- /dev/null +++ b/tools/comp/markv_model_factory.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tools/comp/markv_model_factory.h" + +#include "source/util/make_unique.h" +#include "tools/comp/markv_model_shader.h" + +namespace spvtools { +namespace comp { + +std::unique_ptr CreateMarkvModel(MarkvModelType type) { + std::unique_ptr model; + switch (type) { + case kMarkvModelShaderLite: { + model = MakeUnique(); + break; + } + case kMarkvModelShaderMid: { + model = MakeUnique(); + break; + } + case kMarkvModelShaderMax: { + model = MakeUnique(); + break; + } + case kMarkvModelUnknown: { + assert(0 && "kMarkvModelUnknown supplied to CreateMarkvModel"); + return model; + } + } + + model->SetModelType(static_cast(type)); + + return model; +} + +} // namespace comp +} // namespace spvtools diff --git a/tools/comp/markv_model_factory.h b/tools/comp/markv_model_factory.h new file mode 100644 index 000000000..c13898b98 --- /dev/null +++ b/tools/comp/markv_model_factory.h @@ -0,0 +1,37 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef TOOLS_COMP_MARKV_MODEL_FACTORY_H_ +#define TOOLS_COMP_MARKV_MODEL_FACTORY_H_ + +#include + +#include "source/comp/markv_model.h" + +namespace spvtools { +namespace comp { + +enum MarkvModelType { + kMarkvModelUnknown = 0, + kMarkvModelShaderLite, + kMarkvModelShaderMid, + kMarkvModelShaderMax, +}; + +std::unique_ptr CreateMarkvModel(MarkvModelType type); + +} // namespace comp +} // namespace spvtools + +#endif // TOOLS_COMP_MARKV_MODEL_FACTORY_H_ diff --git a/tools/comp/markv_model_shader.cpp b/tools/comp/markv_model_shader.cpp new file mode 100644 index 000000000..8e296cd8c --- /dev/null +++ b/tools/comp/markv_model_shader.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/comp/markv_model_shader.h" + +#include +#include +#include +#include +#include +#include + +#include "source/util/make_unique.h" + +namespace spvtools { +namespace comp { +namespace { + +// Signals that the value is not in the coding scheme and a fallback method +// needs to be used. 
+const uint64_t kMarkvNoneOfTheAbove = MarkvModel::GetMarkvNoneOfTheAbove(); + +inline uint32_t CombineOpcodeAndNumOperands(uint32_t opcode, + uint32_t num_operands) { + return opcode | (num_operands << 16); +} + +#include "tools/comp/markv_model_shader_default_autogen.inc" + +} // namespace + +MarkvModelShaderLite::MarkvModelShaderLite() { + const uint16_t kVersionNumber = 1; + SetModelVersion(kVersionNumber); + + opcode_and_num_operands_huffman_codec_ = + MakeUnique>(GetOpcodeAndNumOperandsHist()); + + id_fallback_strategy_ = IdFallbackStrategy::kShortDescriptor; +} + +MarkvModelShaderMid::MarkvModelShaderMid() { + const uint16_t kVersionNumber = 1; + SetModelVersion(kVersionNumber); + + opcode_and_num_operands_huffman_codec_ = + MakeUnique>(GetOpcodeAndNumOperandsHist()); + non_id_word_huffman_codecs_ = GetNonIdWordHuffmanCodecs(); + id_descriptor_huffman_codecs_ = GetIdDescriptorHuffmanCodecs(); + descriptors_with_coding_scheme_ = GetDescriptorsWithCodingScheme(); + literal_string_huffman_codecs_ = GetLiteralStringHuffmanCodecs(); + + id_fallback_strategy_ = IdFallbackStrategy::kShortDescriptor; +} + +MarkvModelShaderMax::MarkvModelShaderMax() { + const uint16_t kVersionNumber = 1; + SetModelVersion(kVersionNumber); + + opcode_and_num_operands_huffman_codec_ = + MakeUnique>(GetOpcodeAndNumOperandsHist()); + opcode_and_num_operands_markov_huffman_codecs_ = + GetOpcodeAndNumOperandsMarkovHuffmanCodecs(); + non_id_word_huffman_codecs_ = GetNonIdWordHuffmanCodecs(); + id_descriptor_huffman_codecs_ = GetIdDescriptorHuffmanCodecs(); + descriptors_with_coding_scheme_ = GetDescriptorsWithCodingScheme(); + literal_string_huffman_codecs_ = GetLiteralStringHuffmanCodecs(); + + id_fallback_strategy_ = IdFallbackStrategy::kRuleBased; +} + +} // namespace comp +} // namespace spvtools diff --git a/tools/comp/markv_model_shader.h b/tools/comp/markv_model_shader.h new file mode 100644 index 000000000..3a704571f --- /dev/null +++ b/tools/comp/markv_model_shader.h @@ -0,0 +1,47 
@@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOOLS_COMP_MARKV_MODEL_SHADER_H_ +#define TOOLS_COMP_MARKV_MODEL_SHADER_H_ + +#include "source/comp/markv_model.h" + +namespace spvtools { +namespace comp { + +// MARK-V shader compression model, which only uses fast and lightweight +// algorithms, which do not require training and are not heavily dependent on +// SPIR-V grammar. Compression ratio is worse than by other models. +class MarkvModelShaderLite : public MarkvModel { + public: + MarkvModelShaderLite(); +}; + +// MARK-V shader compression model with balanced compression ratio and runtime +// performance. +class MarkvModelShaderMid : public MarkvModel { + public: + MarkvModelShaderMid(); +}; + +// MARK-V shader compression model designed for maximum compression. 
+class MarkvModelShaderMax : public MarkvModel { + public: + MarkvModelShaderMax(); +}; + +} // namespace comp +} // namespace spvtools + +#endif // TOOLS_COMP_MARKV_MODEL_SHADER_H_ diff --git a/tools/comp/markv_model_shader_default_autogen.inc b/tools/comp/markv_model_shader_default_autogen.inc new file mode 100644 index 000000000..0093cf1c0 --- /dev/null +++ b/tools/comp/markv_model_shader_default_autogen.inc @@ -0,0 +1,14519 @@ + +std::map GetOpcodeAndNumOperandsHist() { + return std::map({ + { CombineOpcodeAndNumOperands(SpvOpExtInst, 7), 158282 }, + { CombineOpcodeAndNumOperands(SpvOpDot, 4), 151035 }, + { CombineOpcodeAndNumOperands(SpvOpVectorShuffle, 6), 183292 }, + { CombineOpcodeAndNumOperands(SpvOpImageSampleImplicitLod, 4), 126492 }, + { CombineOpcodeAndNumOperands(SpvOpExecutionMode, 2), 13311 }, + { CombineOpcodeAndNumOperands(SpvOpFNegate, 3), 29952 }, + { CombineOpcodeAndNumOperands(SpvOpExtInst, 5), 106847 }, + { CombineOpcodeAndNumOperands(SpvOpImageSampleExplicitLod, 7), 26350 }, + { CombineOpcodeAndNumOperands(SpvOpImageSampleExplicitLod, 6), 28186 }, + { CombineOpcodeAndNumOperands(SpvOpFDiv, 4), 41635 }, + { CombineOpcodeAndNumOperands(SpvOpFMul, 4), 412786 }, + { CombineOpcodeAndNumOperands(SpvOpFunction, 4), 62905 }, + { CombineOpcodeAndNumOperands(SpvOpVectorShuffle, 8), 118614 }, + { CombineOpcodeAndNumOperands(SpvOpDecorate, 2), 100735 }, + { CombineOpcodeAndNumOperands(SpvOpReturnValue, 1), 40852 }, + { CombineOpcodeAndNumOperands(SpvOpVectorTimesScalar, 4), 157091 }, + { CombineOpcodeAndNumOperands(SpvOpExtInst, 6), 122100 }, + { CombineOpcodeAndNumOperands(SpvOpAccessChain, 5), 82930 }, + { CombineOpcodeAndNumOperands(SpvOpFSub, 4), 161019 }, + { CombineOpcodeAndNumOperands(SpvOpConstant, 3), 466014 }, + { CombineOpcodeAndNumOperands(SpvOpCompositeExtract, 5), 107126 }, + { CombineOpcodeAndNumOperands(SpvOpTypeImage, 8), 34775 }, + { CombineOpcodeAndNumOperands(SpvOpImageSampleDrefExplicitLod, 7), 26146 }, + { 
CombineOpcodeAndNumOperands(SpvOpMemoryModel, 2), 18879 }, + { CombineOpcodeAndNumOperands(SpvOpDecorate, 3), 485251 }, + { CombineOpcodeAndNumOperands(SpvOpCompositeConstruct, 4), 78011 }, + { CombineOpcodeAndNumOperands(SpvOpTypeFloat, 2), 18879 }, + { CombineOpcodeAndNumOperands(SpvOpVectorTimesMatrix, 4), 15848 }, + { CombineOpcodeAndNumOperands(SpvOpTypeVector, 3), 69404 }, + { CombineOpcodeAndNumOperands(SpvOpTypeFunction, 3), 19998 }, + { CombineOpcodeAndNumOperands(SpvOpConstantComposite, 6), 40228 }, + { CombineOpcodeAndNumOperands(SpvOpCapability, 1), 22510 }, + { CombineOpcodeAndNumOperands(SpvOpTypeArray, 3), 37585 }, + { CombineOpcodeAndNumOperands(SpvOpTypeInt, 3), 30454 }, + { CombineOpcodeAndNumOperands(SpvOpFunctionCall, 4), 29021 }, + { CombineOpcodeAndNumOperands(SpvOpFAdd, 4), 342237 }, + { CombineOpcodeAndNumOperands(SpvOpTypeMatrix, 3), 24449 }, + { CombineOpcodeAndNumOperands(SpvOpLabel, 1), 129408 }, + { CombineOpcodeAndNumOperands(SpvOpTypePointer, 3), 246535 }, + { CombineOpcodeAndNumOperands(SpvOpAccessChain, 4), 503456 }, + { CombineOpcodeAndNumOperands(SpvOpTypeFunction, 2), 19779 }, + { CombineOpcodeAndNumOperands(SpvOpBranchConditional, 3), 24139 }, + { CombineOpcodeAndNumOperands(SpvOpVariable, 3), 697946 }, + { CombineOpcodeAndNumOperands(SpvOpConstantComposite, 5), 55769 }, + { CombineOpcodeAndNumOperands(SpvOpTypeVoid, 1), 18879 }, + { CombineOpcodeAndNumOperands(SpvOpCompositeConstruct, 6), 145508 }, + { CombineOpcodeAndNumOperands(SpvOpFunctionParameter, 2), 85583 }, + { CombineOpcodeAndNumOperands(SpvOpTypeSampledImage, 2), 34775 }, + { CombineOpcodeAndNumOperands(SpvOpConstantComposite, 4), 66362 }, + { CombineOpcodeAndNumOperands(SpvOpLoad, 3), 1272902 }, + { CombineOpcodeAndNumOperands(SpvOpReturn, 0), 22122 }, + { CombineOpcodeAndNumOperands(SpvOpCompositeExtract, 4), 861008 }, + { CombineOpcodeAndNumOperands(SpvOpFunctionEnd, 0), 62905 }, + { CombineOpcodeAndNumOperands(SpvOpExtInstImport, 2), 18879 }, + { 
CombineOpcodeAndNumOperands(SpvOpSelectionMerge, 2), 22009 }, + { CombineOpcodeAndNumOperands(SpvOpBranch, 1), 38275 }, + { CombineOpcodeAndNumOperands(SpvOpTypeBool, 1), 12208 }, + { CombineOpcodeAndNumOperands(SpvOpSampledImage, 4), 95518 }, + { CombineOpcodeAndNumOperands(SpvOpMemberDecorate, 3), 94887 }, + { CombineOpcodeAndNumOperands(SpvOpMemberDecorate, 4), 1942215 }, + { CombineOpcodeAndNumOperands(SpvOpCompositeConstruct, 5), 205266 }, + { CombineOpcodeAndNumOperands(SpvOpUndef, 2), 22157 }, + { CombineOpcodeAndNumOperands(SpvOpCompositeInsert, 5), 142749 }, + { CombineOpcodeAndNumOperands(SpvOpCompositeInsert, 6), 24420 }, + { CombineOpcodeAndNumOperands(SpvOpCompositeExtract, 6), 16896 }, + { CombineOpcodeAndNumOperands(SpvOpStore, 2), 604982 }, + { CombineOpcodeAndNumOperands(SpvOpIAdd, 4), 14471 }, + { CombineOpcodeAndNumOperands(SpvOpVectorShuffle, 7), 269658 }, + { kMarkvNoneOfTheAbove, 399895 }, + }); +} + +std::map>> +GetOpcodeAndNumOperandsMarkovHuffmanCodecs() { + std::map>> codecs; + { + std::unique_ptr> codec(new HuffmanCodec(35, { + {0, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {262209, 0, 0}, + {262221, 0, 0}, + {262225, 0, 0}, + {262230, 0, 0}, + {262273, 0, 0}, + {262277, 0, 0}, + {262286, 0, 0}, + {327745, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393295, 0, 0}, + {393304, 0, 0}, + {458831, 0, 0}, + {458840, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 11, 8}, + {0, 12, 19}, + {0, 18, 20}, + {0, 5, 21}, + {0, 15, 7}, + {0, 10, 1}, + {0, 23, 22}, + {0, 14, 24}, + {0, 6, 4}, + {0, 2, 17}, + {0, 13, 25}, + {0, 9, 26}, + {0, 28, 27}, + {0, 3, 29}, + {0, 30, 16}, + {0, 32, 31}, + {0, 34, 33}, + })); + + codecs.emplace(SpvOpImageSampleExplicitLod, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(55, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262231, 0, 0}, + {262273, 
0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393303, 0, 0}, + {393304, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 14, 5}, + {0, 29, 17}, + {0, 1, 30}, + {0, 10, 20}, + {0, 32, 31}, + {0, 33, 2}, + {0, 34, 23}, + {0, 8, 35}, + {0, 6, 36}, + {0, 19, 22}, + {0, 28, 25}, + {0, 38, 37}, + {0, 13, 39}, + {0, 40, 24}, + {0, 27, 21}, + {0, 26, 41}, + {0, 42, 12}, + {0, 15, 43}, + {0, 44, 18}, + {0, 45, 3}, + {0, 11, 7}, + {0, 16, 46}, + {0, 47, 9}, + {0, 4, 48}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + })); + + codecs.emplace(SpvOpFDiv, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(19, { + {0, 0, 0}, + {196669, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262231, 0, 0}, + {262286, 0, 0}, + {393295, 0, 0}, + {393304, 0, 0}, + {458840, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 10}, + {0, 11, 3}, + {0, 2, 9}, + {0, 4, 1}, + {0, 5, 6}, + {0, 13, 12}, + {0, 15, 14}, + {0, 16, 7}, + {0, 18, 17}, + })); + + codecs.emplace(SpvOpSampledImage, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(67, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {131319, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262231, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262285, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393281, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393297, 0, 0}, + {393298, 0, 0}, + {393304, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 10}, + {0, 30, 35}, + {0, 1, 36}, + {0, 11, 37}, + {0, 38, 6}, + 
{0, 16, 39}, + {0, 15, 40}, + {0, 25, 2}, + {0, 41, 20}, + {0, 26, 19}, + {0, 42, 29}, + {0, 28, 22}, + {0, 23, 34}, + {0, 44, 43}, + {0, 17, 45}, + {0, 24, 27}, + {0, 18, 33}, + {0, 47, 46}, + {0, 8, 48}, + {0, 50, 49}, + {0, 32, 51}, + {0, 31, 52}, + {0, 53, 21}, + {0, 54, 13}, + {0, 3, 55}, + {0, 7, 14}, + {0, 57, 56}, + {0, 58, 5}, + {0, 59, 9}, + {0, 61, 60}, + {0, 63, 62}, + {0, 64, 12}, + {0, 66, 65}, + })); + + codecs.emplace(SpvOpFMul, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(79, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262230, 0, 0}, + {262231, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262288, 0, 0}, + {262292, 0, 0}, + {262328, 0, 0}, + {262334, 0, 0}, + {327692, 0, 0}, + {327737, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393281, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393297, 0, 0}, + {393303, 0, 0}, + {393304, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {458840, 0, 0}, + {524345, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 38, 33}, + {0, 18, 41}, + {0, 42, 23}, + {0, 43, 6}, + {0, 34, 44}, + {0, 1, 45}, + {0, 31, 14}, + {0, 47, 46}, + {0, 48, 2}, + {0, 12, 21}, + {0, 49, 30}, + {0, 37, 50}, + {0, 51, 20}, + {0, 5, 24}, + {0, 40, 16}, + {0, 29, 13}, + {0, 26, 52}, + {0, 53, 17}, + {0, 36, 54}, + {0, 55, 28}, + {0, 57, 56}, + {0, 19, 25}, + {0, 39, 8}, + {0, 32, 58}, + {0, 59, 27}, + {0, 22, 10}, + {0, 35, 60}, + {0, 62, 61}, + {0, 63, 7}, + {0, 65, 64}, + {0, 4, 66}, + {0, 68, 67}, + {0, 11, 3}, + {0, 15, 69}, + {0, 9, 70}, + {0, 72, 71}, + {0, 74, 73}, + {0, 76, 75}, + {0, 78, 77}, + })); + + codecs.emplace(SpvOpFAdd, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(55, { + {0, 0, 0}, + {65556, 0, 0}, + {65562, 
0, 0}, + {131073, 0, 0}, + {131094, 0, 0}, + {131105, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196636, 0, 0}, + {196640, 0, 0}, + {196641, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {262177, 0, 0}, + {262188, 0, 0}, + {262198, 0, 0}, + {327713, 0, 0}, + {327724, 0, 0}, + {393249, 0, 0}, + {393260, 0, 0}, + {458785, 0, 0}, + {524313, 0, 0}, + {524321, 0, 0}, + {589857, 0, 0}, + {655393, 0, 0}, + {720929, 0, 0}, + {852001, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 26, 24}, + {0, 29, 27}, + {0, 4, 30}, + {0, 21, 9}, + {0, 31, 20}, + {0, 33, 32}, + {0, 34, 3}, + {0, 8, 35}, + {0, 36, 5}, + {0, 23, 16}, + {0, 38, 37}, + {0, 25, 2}, + {0, 39, 1}, + {0, 17, 40}, + {0, 41, 15}, + {0, 18, 42}, + {0, 43, 6}, + {0, 44, 14}, + {0, 28, 19}, + {0, 7, 45}, + {0, 46, 22}, + {0, 48, 47}, + {0, 49, 11}, + {0, 51, 50}, + {0, 12, 10}, + {0, 53, 52}, + {0, 13, 54}, + })); + + codecs.emplace(SpvOpTypePointer, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(57, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {262328, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393273, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 9, 23}, + {0, 1, 30}, + {0, 5, 31}, + {0, 32, 28}, + {0, 33, 25}, + {0, 34, 29}, + {0, 18, 24}, + {0, 27, 16}, + {0, 7, 13}, + {0, 14, 35}, + {0, 20, 10}, + {0, 36, 21}, + {0, 2, 37}, + {0, 38, 3}, + {0, 39, 22}, + {0, 40, 19}, + {0, 41, 11}, + {0, 6, 4}, + {0, 12, 42}, + {0, 43, 8}, + {0, 15, 26}, + {0, 45, 44}, + {0, 47, 46}, + {0, 48, 17}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + {0, 56, 55}, + })); + + codecs.emplace(SpvOpFSub, 
std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {65785, 0, 0}, + {131134, 0, 0}, + {196719, 0, 0}, + {262209, 0, 0}, + {262276, 0, 0}, + {327745, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 7, 4}, + {0, 2, 8}, + {0, 1, 9}, + {0, 5, 10}, + {0, 3, 6}, + {0, 12, 11}, + })); + + codecs.emplace(SpvOpIAdd, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(83, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {131319, 0, 0}, + {196669, 0, 0}, + {196732, 0, 0}, + {196735, 0, 0}, + {262209, 0, 0}, + {262221, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262230, 0, 0}, + {262231, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262288, 0, 0}, + {262292, 0, 0}, + {262328, 0, 0}, + {262334, 0, 0}, + {262340, 0, 0}, + {327692, 0, 0}, + {327737, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393273, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393297, 0, 0}, + {393298, 0, 0}, + {393304, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {458840, 0, 0}, + {458842, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 25, 2}, + {0, 31, 43}, + {0, 4, 44}, + {0, 26, 45}, + {0, 39, 46}, + {0, 34, 36}, + {0, 19, 47}, + {0, 6, 48}, + {0, 35, 9}, + {0, 12, 29}, + {0, 21, 49}, + {0, 22, 13}, + {0, 17, 50}, + {0, 23, 51}, + {0, 52, 7}, + {0, 37, 1}, + {0, 53, 3}, + {0, 54, 24}, + {0, 56, 55}, + {0, 32, 57}, + {0, 59, 58}, + {0, 42, 10}, + {0, 60, 8}, + {0, 5, 41}, + {0, 61, 20}, + {0, 62, 38}, + {0, 64, 63}, + {0, 40, 65}, + {0, 66, 18}, + {0, 15, 28}, + {0, 14, 67}, + {0, 68, 30}, + {0, 70, 69}, + {0, 72, 71}, + {0, 73, 27}, + {0, 16, 74}, + {0, 75, 33}, + {0, 77, 76}, + {0, 79, 78}, + {0, 81, 80}, + {0, 82, 11}, + })); + + codecs.emplace(SpvOpCompositeExtract, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(29, { + {0, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 
0}, + {262209, 0, 0}, + {262225, 0, 0}, + {262273, 0, 0}, + {262288, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393295, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 10, 6}, + {0, 16, 13}, + {0, 7, 17}, + {0, 15, 18}, + {0, 19, 12}, + {0, 20, 14}, + {0, 1, 4}, + {0, 22, 21}, + {0, 11, 8}, + {0, 2, 5}, + {0, 9, 23}, + {0, 3, 24}, + {0, 26, 25}, + {0, 28, 27}, + })); + + codecs.emplace(SpvOpVectorTimesMatrix, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {65784, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(SpvOpBranch, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {262198, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(SpvOpFunctionEnd, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {65784, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(SpvOpBranchConditional, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(53, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {131319, 0, 0}, + {196665, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262231, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262288, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {393228, 0, 0}, + {393295, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 25, 16}, + {0, 21, 28}, + {0, 18, 23}, + {0, 4, 29}, + {0, 10, 5}, + {0, 1, 30}, + {0, 32, 31}, + {0, 22, 33}, + {0, 34, 8}, + {0, 35, 15}, + {0, 13, 36}, + {0, 26, 17}, + {0, 38, 37}, + {0, 39, 11}, + {0, 40, 14}, + {0, 12, 27}, + {0, 19, 41}, + {0, 24, 42}, + {0, 44, 43}, + {0, 45, 7}, + {0, 20, 46}, + {0, 9, 47}, + {0, 48, 2}, + {0, 
50, 49}, + {0, 6, 3}, + {0, 52, 51}, + })); + + codecs.emplace(SpvOpFunctionCall, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(71, { + {0, 0, 0}, + {65556, 0, 0}, + {65562, 0, 0}, + {131073, 0, 0}, + {131094, 0, 0}, + {131099, 0, 0}, + {131134, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196636, 0, 0}, + {196640, 0, 0}, + {196651, 0, 0}, + {196665, 0, 0}, + {196667, 0, 0}, + {196669, 0, 0}, + {262188, 0, 0}, + {262198, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262225, 0, 0}, + {262275, 0, 0}, + {262280, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327724, 0, 0}, + {327737, 0, 0}, + {327745, 0, 0}, + {393228, 0, 0}, + {393260, 0, 0}, + {393273, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {458831, 0, 0}, + {524313, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 22, 4}, + {0, 32, 23}, + {0, 37, 30}, + {0, 21, 38}, + {0, 39, 31}, + {0, 41, 40}, + {0, 13, 42}, + {0, 43, 26}, + {0, 10, 44}, + {0, 28, 45}, + {0, 35, 18}, + {0, 20, 46}, + {0, 33, 47}, + {0, 24, 48}, + {0, 6, 49}, + {0, 3, 50}, + {0, 16, 51}, + {0, 27, 52}, + {0, 53, 1}, + {0, 9, 17}, + {0, 29, 54}, + {0, 19, 2}, + {0, 8, 36}, + {0, 55, 34}, + {0, 25, 56}, + {0, 7, 57}, + {0, 5, 58}, + {0, 60, 59}, + {0, 61, 15}, + {0, 63, 62}, + {0, 65, 64}, + {0, 66, 11}, + {0, 12, 67}, + {0, 69, 68}, + {0, 14, 70}, + })); + + codecs.emplace(SpvOpVariable, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 3}, + {0, 2, 4}, + })); + + codecs.emplace(SpvOpAccessChain, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(73, { + {0, 0, 0}, + {252, 0, 0}, + {253, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131073, 0, 0}, + {131134, 0, 0}, + {131319, 0, 0}, + {196665, 0, 0}, + {196667, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {196854, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262221, 0, 0}, + {262225, 0, 0}, + {262272, 0, 
0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262276, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {262321, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393298, 0, 0}, + {393461, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 28, 5}, + {0, 30, 8}, + {0, 13, 38}, + {0, 40, 39}, + {0, 41, 26}, + {0, 42, 19}, + {0, 43, 29}, + {0, 23, 44}, + {0, 36, 32}, + {0, 45, 22}, + {0, 2, 46}, + {0, 21, 20}, + {0, 48, 47}, + {0, 33, 49}, + {0, 4, 50}, + {0, 51, 24}, + {0, 18, 11}, + {0, 52, 12}, + {0, 25, 15}, + {0, 53, 17}, + {0, 37, 54}, + {0, 55, 35}, + {0, 7, 27}, + {0, 57, 56}, + {0, 58, 31}, + {0, 6, 59}, + {0, 1, 60}, + {0, 62, 61}, + {0, 63, 14}, + {0, 3, 16}, + {0, 34, 64}, + {0, 66, 65}, + {0, 68, 67}, + {0, 70, 69}, + {0, 10, 9}, + {0, 72, 71}, + })); + + codecs.emplace(SpvOpLabel, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {56, 0, 0}, + {65784, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 2}, + {0, 1, 4}, + })); + + codecs.emplace(SpvOpReturn, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {65784, 0, 0}, + {131127, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 3}, + {0, 2, 4}, + })); + + codecs.emplace(SpvOpFunction, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(31, { + {0, 0, 0}, + {65556, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196636, 0, 0}, + {196640, 0, 0}, + {196641, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {262177, 0, 0}, + {262188, 0, 0}, + {262198, 0, 0}, + {327713, 0, 0}, + {393260, 0, 0}, + {524313, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 12, 1}, + {0, 13, 5}, + {0, 18, 17}, + {0, 7, 19}, + {0, 9, 20}, + {0, 16, 21}, + {0, 15, 10}, + {0, 22, 4}, + {0, 24, 23}, + {0, 25, 14}, + {0, 8, 11}, + {0, 2, 26}, + {0, 28, 27}, + {0, 3, 6}, + {0, 30, 29}, + })); + 
+ codecs.emplace(SpvOpTypeVector, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {65784, 0, 0}, + {131127, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 4, 1}, + })); + + codecs.emplace(SpvOpFunctionParameter, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {56, 0, 0}, + {65784, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 2}, + {0, 1, 4}, + })); + + codecs.emplace(SpvOpReturnValue, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {131105, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(SpvOpTypeVoid, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(89, { + {0, 0, 0}, + {253, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {131319, 0, 0}, + {196665, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262288, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327737, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393273, 0, 0}, + {393281, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {458764, 0, 0}, + {458809, 0, 0}, + {458831, 0, 0}, + {524345, 0, 0}, + {524367, 0, 0}, + {589881, 0, 0}, + {655417, 0, 0}, + {720953, 0, 0}, + {786489, 0, 0}, + {852025, 0, 0}, + {917561, 0, 0}, + {983097, 0, 0}, + {1114169, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 40, 32}, + {0, 46, 29}, + {0, 38, 27}, + {0, 20, 47}, + {0, 49, 48}, + {0, 50, 44}, + {0, 51, 43}, + {0, 14, 5}, + {0, 42, 52}, + {0, 13, 19}, + {0, 3, 26}, + {0, 54, 53}, + {0, 56, 55}, + {0, 57, 6}, + {0, 39, 37}, + {0, 15, 58}, + {0, 18, 31}, + {0, 59, 21}, + {0, 60, 17}, + {0, 61, 41}, + {0, 62, 24}, + {0, 34, 63}, + {0, 35, 64}, + {0, 65, 8}, + {0, 66, 36}, + {0, 67, 30}, + {0, 16, 11}, + {0, 69, 
68}, + {0, 70, 28}, + {0, 22, 71}, + {0, 33, 72}, + {0, 45, 73}, + {0, 75, 74}, + {0, 77, 76}, + {0, 78, 12}, + {0, 1, 2}, + {0, 9, 79}, + {0, 25, 80}, + {0, 23, 81}, + {0, 4, 82}, + {0, 84, 83}, + {0, 86, 85}, + {0, 7, 10}, + {0, 88, 87}, + })); + + codecs.emplace(SpvOpStore, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {131075, 0, 0}, + {131088, 0, 0}, + {131143, 0, 0}, + {196624, 0, 0}, + {196679, 0, 0}, + {262216, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 1, 8}, + {0, 7, 9}, + {0, 6, 10}, + {0, 5, 11}, + {0, 2, 12}, + })); + + codecs.emplace(SpvOpEntryPoint, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(97, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {131319, 0, 0}, + {196665, 0, 0}, + {196669, 0, 0}, + {196732, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262230, 0, 0}, + {262231, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262276, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262288, 0, 0}, + {262292, 0, 0}, + {262326, 0, 0}, + {262328, 0, 0}, + {262330, 0, 0}, + {327692, 0, 0}, + {327737, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393273, 0, 0}, + {393281, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393297, 0, 0}, + {393304, 0, 0}, + {458764, 0, 0}, + {458809, 0, 0}, + {458817, 0, 0}, + {458831, 0, 0}, + {458840, 0, 0}, + {524345, 0, 0}, + {524367, 0, 0}, + {589881, 0, 0}, + {720953, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 42, 47}, + {0, 48, 50}, + {0, 45, 51}, + {0, 34, 52}, + {0, 53, 41}, + {0, 1, 54}, + {0, 55, 5}, + {0, 15, 4}, + {0, 56, 35}, + {0, 26, 24}, + {0, 18, 28}, + {0, 57, 38}, + {0, 59, 58}, + {0, 60, 25}, + {0, 20, 9}, + {0, 7, 61}, + {0, 62, 22}, + {0, 11, 31}, + {0, 63, 8}, + {0, 64, 40}, + {0, 66, 65}, + {0, 27, 44}, + {0, 29, 67}, + {0, 68, 39}, + {0, 69, 2}, + {0, 37, 49}, + 
{0, 71, 70}, + {0, 30, 72}, + {0, 73, 17}, + {0, 33, 74}, + {0, 23, 14}, + {0, 32, 75}, + {0, 21, 76}, + {0, 77, 16}, + {0, 46, 78}, + {0, 13, 79}, + {0, 80, 12}, + {0, 19, 81}, + {0, 43, 36}, + {0, 83, 82}, + {0, 10, 84}, + {0, 85, 3}, + {0, 6, 86}, + {0, 88, 87}, + {0, 90, 89}, + {0, 92, 91}, + {0, 94, 93}, + {0, 96, 95}, + })); + + codecs.emplace(SpvOpLoad, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(47, { + {0, 0, 0}, + {262159, 0, 0}, + {327695, 0, 0}, + {393231, 0, 0}, + {458767, 0, 0}, + {524303, 0, 0}, + {589839, 0, 0}, + {655375, 0, 0}, + {720911, 0, 0}, + {786447, 0, 0}, + {851983, 0, 0}, + {917519, 0, 0}, + {983055, 0, 0}, + {1048591, 0, 0}, + {1114127, 0, 0}, + {1179663, 0, 0}, + {1245199, 0, 0}, + {1310735, 0, 0}, + {1376271, 0, 0}, + {1441807, 0, 0}, + {1507343, 0, 0}, + {1572879, 0, 0}, + {1638415, 0, 0}, + {1703951, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 23}, + {0, 22, 25}, + {0, 21, 26}, + {0, 6, 20}, + {0, 19, 27}, + {0, 29, 28}, + {0, 24, 18}, + {0, 30, 13}, + {0, 31, 14}, + {0, 32, 7}, + {0, 17, 15}, + {0, 33, 2}, + {0, 34, 8}, + {0, 16, 12}, + {0, 35, 3}, + {0, 36, 5}, + {0, 9, 37}, + {0, 39, 38}, + {0, 11, 40}, + {0, 4, 10}, + {0, 42, 41}, + {0, 44, 43}, + {0, 46, 45}, + })); + + codecs.emplace(SpvOpMemoryModel, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {196631, 0, 0}, + {196640, 0, 0}, + {196641, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 4, 5}, + {0, 1, 6}, + })); + + codecs.emplace(SpvOpTypeFloat, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(69, { + {0, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262231, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262288, 0, 0}, + {262289, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + 
{327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {327849, 0, 0}, + {393228, 0, 0}, + {393281, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393304, 0, 0}, + {458764, 0, 0}, + {458809, 0, 0}, + {458831, 0, 0}, + {524345, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 33, 10}, + {0, 31, 36}, + {0, 26, 37}, + {0, 5, 38}, + {0, 20, 39}, + {0, 22, 40}, + {0, 24, 25}, + {0, 15, 41}, + {0, 9, 17}, + {0, 1, 42}, + {0, 4, 43}, + {0, 35, 44}, + {0, 34, 45}, + {0, 19, 46}, + {0, 7, 29}, + {0, 16, 47}, + {0, 48, 32}, + {0, 49, 27}, + {0, 11, 14}, + {0, 18, 28}, + {0, 23, 50}, + {0, 51, 12}, + {0, 52, 21}, + {0, 6, 53}, + {0, 55, 54}, + {0, 57, 56}, + {0, 3, 58}, + {0, 13, 59}, + {0, 60, 8}, + {0, 30, 61}, + {0, 62, 2}, + {0, 64, 63}, + {0, 66, 65}, + {0, 68, 67}, + })); + + codecs.emplace(SpvOpCompositeConstruct, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(39, { + {0, 0, 0}, + {65556, 0, 0}, + {131094, 0, 0}, + {131105, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196640, 0, 0}, + {196641, 0, 0}, + {262177, 0, 0}, + {327713, 0, 0}, + {393249, 0, 0}, + {458785, 0, 0}, + {524313, 0, 0}, + {524321, 0, 0}, + {589857, 0, 0}, + {655393, 0, 0}, + {786465, 0, 0}, + {917537, 0, 0}, + {1048609, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 19, 18}, + {0, 21, 15}, + {0, 1, 22}, + {0, 16, 23}, + {0, 14, 24}, + {0, 20, 25}, + {0, 13, 17}, + {0, 3, 26}, + {0, 6, 11}, + {0, 27, 12}, + {0, 4, 28}, + {0, 29, 10}, + {0, 9, 30}, + {0, 7, 31}, + {0, 33, 32}, + {0, 34, 5}, + {0, 8, 35}, + {0, 2, 36}, + {0, 38, 37}, + })); + + codecs.emplace(SpvOpTypeFunction, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {131086, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(SpvOpExtInstImport, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {131099, 0, 0}, + {196640, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + 
})); + + codecs.emplace(SpvOpTypeImage, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {131143, 0, 0}, + {196679, 0, 0}, + {196680, 0, 0}, + {262216, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 2}, + {0, 3, 6}, + {0, 7, 1}, + {0, 4, 8}, + })); + + codecs.emplace(SpvOpMemberDecorate, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {65553, 0, 0}, + {131083, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 3}, + {0, 2, 4}, + })); + + codecs.emplace(SpvOpCapability, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(17, { + {0, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196640, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {327713, 0, 0}, + {458785, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 7, 8}, + {0, 1, 10}, + {0, 6, 11}, + {0, 9, 12}, + {0, 4, 13}, + {0, 3, 14}, + {0, 15, 2}, + {0, 5, 16}, + })); + + codecs.emplace(SpvOpTypeInt, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(29, { + {0, 0, 0}, + {65556, 0, 0}, + {131073, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196636, 0, 0}, + {196640, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {262188, 0, 0}, + {262198, 0, 0}, + {327724, 0, 0}, + {393260, 0, 0}, + {524313, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 6}, + {0, 16, 3}, + {0, 11, 17}, + {0, 5, 18}, + {0, 15, 19}, + {0, 13, 20}, + {0, 1, 4}, + {0, 12, 21}, + {0, 7, 22}, + {0, 14, 23}, + {0, 24, 10}, + {0, 25, 9}, + {0, 27, 26}, + {0, 8, 28}, + })); + + codecs.emplace(SpvOpConstantComposite, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(15, { + {0, 0, 0}, + {65556, 0, 0}, + {196631, 0, 0}, + {196640, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {327724, 0, 0}, + {393260, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 7}, + {0, 1, 9}, + {0, 10, 8}, + {0, 2, 11}, + {0, 5, 12}, + {0, 13, 4}, + {0, 3, 14}, + })); + + codecs.emplace(SpvOpTypeSampledImage, 
std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(21, { + {0, 0, 0}, + {131073, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196636, 0, 0}, + {196640, 0, 0}, + {196641, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {262198, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 11, 12}, + {0, 8, 13}, + {0, 7, 14}, + {0, 4, 10}, + {0, 9, 2}, + {0, 16, 15}, + {0, 1, 17}, + {0, 19, 18}, + {0, 6, 20}, + })); + + codecs.emplace(SpvOpTypeStruct, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(49, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 20, 12}, + {0, 26, 24}, + {0, 21, 27}, + {0, 28, 16}, + {0, 10, 8}, + {0, 30, 29}, + {0, 31, 17}, + {0, 32, 13}, + {0, 25, 6}, + {0, 1, 33}, + {0, 14, 11}, + {0, 3, 34}, + {0, 18, 35}, + {0, 37, 36}, + {0, 23, 5}, + {0, 38, 2}, + {0, 39, 7}, + {0, 4, 9}, + {0, 40, 19}, + {0, 42, 41}, + {0, 43, 22}, + {0, 45, 44}, + {0, 46, 15}, + {0, 48, 47}, + })); + + codecs.emplace(SpvOpFNegate, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {65555, 0, 0}, + {131143, 0, 0}, + {196679, 0, 0}, + {196680, 0, 0}, + {262216, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 6}, + {0, 1, 2}, + {0, 8, 7}, + {0, 5, 9}, + {0, 3, 10}, + })); + + codecs.emplace(SpvOpDecorate, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(25, { + {0, 0, 0}, + {65562, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196636, 0, 0}, + {196640, 0, 0}, + {196641, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + 
{262177, 0, 0}, + {262198, 0, 0}, + {327713, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 12, 11}, + {0, 9, 14}, + {0, 10, 15}, + {0, 13, 16}, + {0, 4, 17}, + {0, 2, 1}, + {0, 18, 7}, + {0, 20, 19}, + {0, 21, 3}, + {0, 22, 6}, + {0, 5, 8}, + {0, 24, 23}, + })); + + codecs.emplace(SpvOpTypeMatrix, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(31, { + {0, 0, 0}, + {65556, 0, 0}, + {131073, 0, 0}, + {131094, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196636, 0, 0}, + {196640, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {262188, 0, 0}, + {262198, 0, 0}, + {327724, 0, 0}, + {393260, 0, 0}, + {524313, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 12, 2}, + {0, 17, 3}, + {0, 5, 18}, + {0, 1, 19}, + {0, 16, 4}, + {0, 21, 20}, + {0, 6, 15}, + {0, 7, 22}, + {0, 24, 23}, + {0, 13, 14}, + {0, 25, 8}, + {0, 26, 11}, + {0, 27, 10}, + {0, 29, 28}, + {0, 30, 9}, + })); + + codecs.emplace(SpvOpConstant, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(33, { + {0, 0, 0}, + {131113, 0, 0}, + {196629, 0, 0}, + {196631, 0, 0}, + {196632, 0, 0}, + {196640, 0, 0}, + {196641, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {262188, 0, 0}, + {262198, 0, 0}, + {327713, 0, 0}, + {327724, 0, 0}, + {393249, 0, 0}, + {393260, 0, 0}, + {524313, 0, 0}, + {524321, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 4}, + {0, 13, 11}, + {0, 16, 15}, + {0, 18, 10}, + {0, 20, 19}, + {0, 21, 2}, + {0, 23, 22}, + {0, 8, 24}, + {0, 9, 25}, + {0, 17, 26}, + {0, 14, 27}, + {0, 12, 28}, + {0, 1, 3}, + {0, 5, 29}, + {0, 30, 7}, + {0, 32, 31}, + })); + + codecs.emplace(SpvOpTypeBool, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {196636, 0, 0}, + {196640, 0, 0}, + {196651, 0, 0}, + {196667, 0, 0}, + {524313, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 5}, + {0, 3, 7}, + {0, 2, 8}, + {0, 6, 9}, + {0, 1, 10}, + })); + + codecs.emplace(SpvOpTypeArray, std::move(codec)); + } + + { + std::unique_ptr> 
codec(new HuffmanCodec(67, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {131319, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262231, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {262334, 0, 0}, + {327692, 0, 0}, + {327737, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393273, 0, 0}, + {393281, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 7, 27}, + {0, 11, 28}, + {0, 35, 21}, + {0, 36, 1}, + {0, 4, 37}, + {0, 39, 38}, + {0, 40, 30}, + {0, 41, 12}, + {0, 19, 42}, + {0, 13, 43}, + {0, 16, 44}, + {0, 45, 22}, + {0, 34, 18}, + {0, 29, 24}, + {0, 46, 25}, + {0, 6, 2}, + {0, 9, 31}, + {0, 17, 47}, + {0, 49, 48}, + {0, 50, 33}, + {0, 51, 26}, + {0, 20, 52}, + {0, 32, 53}, + {0, 3, 54}, + {0, 15, 14}, + {0, 23, 55}, + {0, 8, 56}, + {0, 58, 57}, + {0, 10, 59}, + {0, 5, 60}, + {0, 62, 61}, + {0, 64, 63}, + {0, 66, 65}, + })); + + codecs.emplace(SpvOpExtInst, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(57, { + {0, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196665, 0, 0}, + {196669, 0, 0}, + {196718, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262231, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393273, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393303, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 18, 6}, + {0, 30, 22}, + {0, 31, 25}, + {0, 10, 32}, + {0, 21, 33}, + {0, 3, 34}, + {0, 35, 5}, + {0, 23, 36}, + {0, 14, 17}, + {0, 37, 26}, + {0, 1, 38}, + {0, 29, 39}, + {0, 
13, 40}, + {0, 41, 19}, + {0, 28, 20}, + {0, 16, 42}, + {0, 27, 43}, + {0, 8, 24}, + {0, 7, 44}, + {0, 9, 45}, + {0, 15, 46}, + {0, 12, 47}, + {0, 48, 2}, + {0, 4, 49}, + {0, 51, 50}, + {0, 11, 52}, + {0, 54, 53}, + {0, 56, 55}, + })); + + codecs.emplace(SpvOpVectorTimesScalar, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(67, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262230, 0, 0}, + {262231, 0, 0}, + {262272, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327737, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393273, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393303, 0, 0}, + {393304, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 26, 29}, + {0, 20, 35}, + {0, 12, 36}, + {0, 6, 37}, + {0, 38, 28}, + {0, 30, 5}, + {0, 8, 39}, + {0, 2, 40}, + {0, 41, 21}, + {0, 1, 10}, + {0, 43, 42}, + {0, 23, 16}, + {0, 44, 33}, + {0, 34, 31}, + {0, 14, 45}, + {0, 19, 46}, + {0, 25, 47}, + {0, 49, 48}, + {0, 27, 22}, + {0, 7, 50}, + {0, 17, 32}, + {0, 18, 51}, + {0, 24, 52}, + {0, 54, 53}, + {0, 55, 9}, + {0, 56, 11}, + {0, 57, 4}, + {0, 15, 58}, + {0, 59, 13}, + {0, 60, 3}, + {0, 62, 61}, + {0, 64, 63}, + {0, 66, 65}, + })); + + codecs.emplace(SpvOpVectorShuffle, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(33, { + {0, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {262201, 0, 0}, + {262209, 0, 0}, + {262225, 0, 0}, + {262231, 0, 0}, + {262273, 0, 0}, + {262277, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {327745, 0, 0}, + {393281, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {458831, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 13, 12}, + {0, 1, 18}, + {0, 19, 11}, + {0, 9, 20}, + {0, 10, 
21}, + {0, 22, 15}, + {0, 23, 8}, + {0, 4, 24}, + {0, 25, 7}, + {0, 17, 26}, + {0, 5, 27}, + {0, 14, 3}, + {0, 29, 28}, + {0, 30, 2}, + {0, 6, 31}, + {0, 32, 16}, + })); + + codecs.emplace(SpvOpImageSampleImplicitLod, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(55, { + {0, 0, 0}, + {65785, 0, 0}, + {65790, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {196817, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262292, 0, 0}, + {327692, 0, 0}, + {327745, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393281, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393298, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 2}, + {0, 22, 29}, + {0, 30, 1}, + {0, 6, 31}, + {0, 9, 32}, + {0, 28, 3}, + {0, 27, 33}, + {0, 20, 16}, + {0, 34, 8}, + {0, 10, 35}, + {0, 4, 36}, + {0, 24, 23}, + {0, 21, 13}, + {0, 7, 37}, + {0, 38, 14}, + {0, 25, 39}, + {0, 17, 11}, + {0, 12, 19}, + {0, 41, 40}, + {0, 42, 18}, + {0, 15, 43}, + {0, 45, 44}, + {0, 47, 46}, + {0, 26, 48}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + })); + + codecs.emplace(SpvOpDot, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {131075, 0, 0}, + {131088, 0, 0}, + {196624, 0, 0}, + {196679, 0, 0}, + {262216, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 3}, + {0, 2, 7}, + {0, 1, 8}, + {0, 6, 9}, + {0, 4, 10}, + })); + + codecs.emplace(SpvOpExecutionMode, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {196858, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(SpvOpSelectionMerge, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(23, { + {0, 0, 0}, + {131134, 0, 0}, + {196669, 0, 0}, + {262209, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262277, 0, 0}, + {327745, 0, 0}, + 
{327761, 0, 0}, + {327762, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 12}, + {0, 7, 13}, + {0, 5, 1}, + {0, 4, 10}, + {0, 14, 6}, + {0, 16, 15}, + {0, 17, 11}, + {0, 3, 8}, + {0, 19, 18}, + {0, 9, 20}, + {0, 22, 21}, + })); + + codecs.emplace(SpvOpImageSampleDrefExplicitLod, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {65790, 0, 0}, + {131073, 0, 0}, + {262198, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 1}, + {0, 3, 5}, + {0, 2, 6}, + })); + + codecs.emplace(SpvOpUndef, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(59, { + {0, 0, 0}, + {65785, 0, 0}, + {131134, 0, 0}, + {131319, 0, 0}, + {196669, 0, 0}, + {196735, 0, 0}, + {262209, 0, 0}, + {262221, 0, 0}, + {262224, 0, 0}, + {262225, 0, 0}, + {262230, 0, 0}, + {262273, 0, 0}, + {262275, 0, 0}, + {262277, 0, 0}, + {262280, 0, 0}, + {262286, 0, 0}, + {262288, 0, 0}, + {262292, 0, 0}, + {262334, 0, 0}, + {327692, 0, 0}, + {327760, 0, 0}, + {327761, 0, 0}, + {327762, 0, 0}, + {393228, 0, 0}, + {393295, 0, 0}, + {393296, 0, 0}, + {393298, 0, 0}, + {458764, 0, 0}, + {458831, 0, 0}, + {524367, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 17, 3}, + {0, 5, 31}, + {0, 11, 32}, + {0, 33, 12}, + {0, 34, 20}, + {0, 16, 27}, + {0, 35, 23}, + {0, 37, 36}, + {0, 14, 18}, + {0, 39, 38}, + {0, 7, 30}, + {0, 8, 25}, + {0, 40, 15}, + {0, 13, 2}, + {0, 1, 29}, + {0, 19, 41}, + {0, 43, 42}, + {0, 28, 44}, + {0, 46, 45}, + {0, 22, 21}, + {0, 47, 24}, + {0, 48, 26}, + {0, 10, 6}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + {0, 4, 9}, + {0, 56, 55}, + {0, 58, 57}, + })); + + codecs.emplace(SpvOpCompositeInsert, std::move(codec)); + } + + return codecs; +} + +std::map>> +GetLiteralStringHuffmanCodecs() { + std::map>> codecs; + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {"", 0, 0}, + {"MainPs", 0, 0}, + {"MainVs", 0, 0}, + {"kMarkvNoneOfTheAbove", 0, 0}, + {"main", 0, 0}, + {"", 2, 3}, + {"", 1, 5}, + {"", 4, 6}, + 
})); + + codecs.emplace(SpvOpEntryPoint, std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {"", 0, 0}, + {"GLSL.std.450", 0, 0}, + {"kMarkvNoneOfTheAbove", 0, 0}, + {"", 1, 2}, + })); + + codecs.emplace(SpvOpExtInstImport, std::move(codec)); + } + + return codecs; +} + +std::map, std::unique_ptr>> +GetNonIdWordHuffmanCodecs() { + std::map, std::unique_ptr>> codecs; + { + std::unique_ptr> codec(new HuffmanCodec(33, { + {0, 0, 0}, + {4, 0, 0}, + {8, 0, 0}, + {10, 0, 0}, + {26, 0, 0}, + {29, 0, 0}, + {31, 0, 0}, + {37, 0, 0}, + {40, 0, 0}, + {43, 0, 0}, + {46, 0, 0}, + {49, 0, 0}, + {66, 0, 0}, + {67, 0, 0}, + {68, 0, 0}, + {69, 0, 0}, + {71, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 12, 5}, + {0, 18, 13}, + {0, 3, 7}, + {0, 19, 11}, + {0, 20, 16}, + {0, 14, 17}, + {0, 21, 1}, + {0, 2, 6}, + {0, 23, 22}, + {0, 4, 24}, + {0, 26, 25}, + {0, 28, 27}, + {0, 10, 15}, + {0, 8, 9}, + {0, 30, 29}, + {0, 32, 31}, + })); + + codecs.emplace(std::pair(SpvOpExtInst, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {0, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpMemoryModel, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpMemoryModel, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {0, 0, 0}, + {4, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 3}, + {0, 2, 4}, + })); + + codecs.emplace(std::pair(SpvOpEntryPoint, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {7, 0, 0}, + {8, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 2}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpExecutionMode, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + 
{4, 0, 0}, + {18, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 2}, + {0, 6, 5}, + {0, 7, 1}, + {0, 3, 8}, + {0, 10, 9}, + })); + + codecs.emplace(std::pair(SpvOpExecutionMode, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {1, 0, 0}, + {32, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpCapability, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {32, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeInt, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeInt, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {32, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeFloat, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 4}, + {0, 1, 5}, + {0, 6, 3}, + })); + + codecs.emplace(std::pair(SpvOpTypeVector, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 2, 5}, + {0, 3, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeMatrix, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 2, 5}, + {0, 1, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeImage, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + })); + + 
codecs.emplace(std::pair(SpvOpTypeImage, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {0, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeImage, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {0, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeImage, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeImage, 6), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {0, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeImage, 7), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {6, 0, 0}, + {7, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 7}, + {0, 6, 8}, + {0, 1, 4}, + {0, 2, 9}, + {0, 10, 3}, + {0, 12, 11}, + })); + + codecs.emplace(std::pair(SpvOpTypePointer, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(173, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {6, 0, 0}, + {7, 0, 0}, + {8, 0, 0}, + {9, 0, 0}, + {10, 0, 0}, + {11, 0, 0}, + {12, 0, 0}, + {13, 0, 0}, + {14, 0, 0}, + {15, 0, 0}, + {16, 0, 0}, + {17, 0, 0}, + {18, 0, 0}, + {19, 0, 0}, + {20, 0, 0}, + {21, 0, 0}, + {22, 0, 0}, + {23, 0, 0}, + {24, 0, 0}, + {26, 0, 0}, + {27, 0, 0}, + {28, 0, 0}, + {29, 0, 0}, + {30, 0, 0}, + {31, 0, 0}, + {32, 0, 0}, + {256, 0, 0}, + {507307272, 0, 0}, + {864026611, 0, 0}, + {981668463, 0, 0}, + {997553156, 0, 0}, + {1014330372, 0, 0}, + {1020708227, 0, 0}, + {1028443341, 0, 0}, + {1032953056, 0, 0}, + {1033463938, 0, 0}, + {1033463943, 0, 0}, + {1039998884, 0, 0}, + {1039998950, 0, 0}, + 
{1040187392, 0, 0}, + {1042401985, 0, 0}, + {1044220635, 0, 0}, + {1045622707, 0, 0}, + {1045622740, 0, 0}, + {1048576000, 0, 0}, + {1053609165, 0, 0}, + {1053790359, 0, 0}, + {1054448026, 0, 0}, + {1055437881, 0, 0}, + {1056300230, 0, 0}, + {1056964608, 0, 0}, + {1058056805, 0, 0}, + {1059286575, 0, 0}, + {1061158912, 0, 0}, + {1061997773, 0, 0}, + {1064514355, 0, 0}, + {1064854933, 0, 0}, + {1065353216, 0, 0}, + {1069547520, 0, 0}, + {1073741824, 0, 0}, + {1077936128, 0, 0}, + {1082130432, 0, 0}, + {1091567616, 0, 0}, + {1115422720, 0, 0}, + {1124073472, 0, 0}, + {1132396544, 0, 0}, + {1140850688, 0, 0}, + {1199562752, 0, 0}, + {3179067684, 0, 0}, + {3180973575, 0, 0}, + {3182651297, 0, 0}, + {3196448879, 0, 0}, + {3204448256, 0, 0}, + {3204993516, 0, 0}, + {3205248529, 0, 0}, + {3207137644, 0, 0}, + {3208642560, 0, 0}, + {3211081967, 0, 0}, + {3212836864, 0, 0}, + {3332128768, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 38, 37}, + {0, 42, 39}, + {0, 49, 44}, + {0, 45, 43}, + {0, 26, 50}, + {0, 46, 73}, + {0, 35, 28}, + {0, 32, 65}, + {0, 83, 40}, + {0, 60, 62}, + {0, 27, 54}, + {0, 79, 67}, + {0, 31, 74}, + {0, 51, 12}, + {0, 70, 30}, + {0, 15, 16}, + {0, 88, 25}, + {0, 90, 89}, + {0, 34, 71}, + {0, 72, 29}, + {0, 92, 91}, + {0, 14, 33}, + {0, 94, 93}, + {0, 22, 23}, + {0, 21, 95}, + {0, 19, 24}, + {0, 96, 13}, + {0, 47, 41}, + {0, 53, 48}, + {0, 58, 56}, + {0, 63, 59}, + {0, 76, 75}, + {0, 78, 77}, + {0, 81, 80}, + {0, 84, 82}, + {0, 52, 20}, + {0, 97, 69}, + {0, 99, 98}, + {0, 18, 10}, + {0, 68, 61}, + {0, 17, 100}, + {0, 102, 101}, + {0, 11, 36}, + {0, 104, 103}, + {0, 86, 105}, + {0, 107, 106}, + {0, 109, 108}, + {0, 110, 9}, + {0, 8, 111}, + {0, 113, 112}, + {0, 115, 114}, + {0, 117, 116}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 125, 124}, + {0, 126, 7}, + {0, 127, 85}, + {0, 6, 128}, + {0, 129, 55}, + {0, 130, 5}, + {0, 132, 131}, + {0, 134, 133}, + {0, 136, 135}, + {0, 137, 66}, + {0, 139, 138}, + {0, 141, 140}, + {0, 143, 142}, + {0, 
145, 144}, + {0, 146, 57}, + {0, 147, 64}, + {0, 148, 4}, + {0, 149, 2}, + {0, 151, 150}, + {0, 152, 3}, + {0, 154, 153}, + {0, 156, 155}, + {0, 158, 157}, + {0, 159, 1}, + {0, 160, 87}, + {0, 162, 161}, + {0, 164, 163}, + {0, 166, 165}, + {0, 168, 167}, + {0, 170, 169}, + {0, 172, 171}, + })); + + codecs.emplace(std::pair(SpvOpConstant, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {0, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpFunction, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {6, 0, 0}, + {7, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 7}, + {0, 4, 8}, + {0, 9, 2}, + {0, 1, 5}, + {0, 10, 6}, + {0, 12, 11}, + })); + + codecs.emplace(std::pair(SpvOpVariable, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(15, { + {0, 0, 0}, + {0, 0, 0}, + {2, 0, 0}, + {6, 0, 0}, + {11, 0, 0}, + {30, 0, 0}, + {33, 0, 0}, + {34, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 8}, + {0, 9, 1}, + {0, 3, 10}, + {0, 6, 11}, + {0, 12, 2}, + {0, 7, 5}, + {0, 14, 13}, + })); + + codecs.emplace(std::pair(SpvOpDecorate, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(37, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {6, 0, 0}, + {7, 0, 0}, + {8, 0, 0}, + {9, 0, 0}, + {10, 0, 0}, + {12, 0, 0}, + {13, 0, 0}, + {14, 0, 0}, + {15, 0, 0}, + {16, 0, 0}, + {18, 0, 0}, + {64, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 17, 11}, + {0, 10, 13}, + {0, 12, 14}, + {0, 21, 20}, + {0, 9, 22}, + {0, 19, 15}, + {0, 8, 23}, + {0, 18, 24}, + {0, 25, 7}, + {0, 5, 6}, + {0, 26, 16}, + {0, 27, 4}, + {0, 28, 3}, + {0, 30, 29}, + {0, 31, 2}, + {0, 33, 32}, + {0, 35, 34}, + {0, 1, 36}, + })); + + codecs.emplace(std::pair(SpvOpDecorate, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new 
HuffmanCodec(79, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {6, 0, 0}, + {7, 0, 0}, + {8, 0, 0}, + {9, 0, 0}, + {10, 0, 0}, + {11, 0, 0}, + {12, 0, 0}, + {13, 0, 0}, + {14, 0, 0}, + {15, 0, 0}, + {16, 0, 0}, + {17, 0, 0}, + {18, 0, 0}, + {19, 0, 0}, + {20, 0, 0}, + {21, 0, 0}, + {22, 0, 0}, + {23, 0, 0}, + {24, 0, 0}, + {25, 0, 0}, + {26, 0, 0}, + {27, 0, 0}, + {28, 0, 0}, + {29, 0, 0}, + {30, 0, 0}, + {31, 0, 0}, + {32, 0, 0}, + {33, 0, 0}, + {34, 0, 0}, + {35, 0, 0}, + {36, 0, 0}, + {37, 0, 0}, + {38, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 39, 37}, + {0, 40, 36}, + {0, 34, 35}, + {0, 32, 33}, + {0, 30, 31}, + {0, 27, 29}, + {0, 26, 28}, + {0, 42, 41}, + {0, 23, 25}, + {0, 38, 22}, + {0, 44, 43}, + {0, 46, 45}, + {0, 21, 47}, + {0, 19, 20}, + {0, 17, 18}, + {0, 14, 15}, + {0, 12, 10}, + {0, 16, 13}, + {0, 9, 11}, + {0, 7, 8}, + {0, 6, 5}, + {0, 24, 48}, + {0, 50, 49}, + {0, 3, 4}, + {0, 51, 2}, + {0, 1, 52}, + {0, 54, 53}, + {0, 56, 55}, + {0, 58, 57}, + {0, 60, 59}, + {0, 62, 61}, + {0, 64, 63}, + {0, 66, 65}, + {0, 68, 67}, + {0, 70, 69}, + {0, 72, 71}, + {0, 74, 73}, + {0, 76, 75}, + {0, 78, 77}, + })); + + codecs.emplace(std::pair(SpvOpMemberDecorate, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {4, 0, 0}, + {7, 0, 0}, + {35, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 5, 2}, + {0, 3, 6}, + })); + + codecs.emplace(std::pair(SpvOpMemberDecorate, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(149, { + {0, 0, 0}, + {0, 0, 0}, + {16, 0, 0}, + {28, 0, 0}, + {32, 0, 0}, + {36, 0, 0}, + {40, 0, 0}, + {44, 0, 0}, + {48, 0, 0}, + {60, 0, 0}, + {64, 0, 0}, + {76, 0, 0}, + {80, 0, 0}, + {84, 0, 0}, + {88, 0, 0}, + {92, 0, 0}, + {96, 0, 0}, + {100, 0, 0}, + {108, 0, 0}, + {112, 0, 0}, + {120, 0, 0}, + {124, 0, 0}, + {128, 0, 0}, + {132, 0, 0}, + {136, 0, 0}, + {140, 0, 0}, + {144, 0, 0}, + {148, 0, 0}, + {152, 0, 0}, + {156, 
0, 0}, + {160, 0, 0}, + {172, 0, 0}, + {176, 0, 0}, + {192, 0, 0}, + {204, 0, 0}, + {208, 0, 0}, + {224, 0, 0}, + {236, 0, 0}, + {240, 0, 0}, + {248, 0, 0}, + {256, 0, 0}, + {272, 0, 0}, + {288, 0, 0}, + {292, 0, 0}, + {296, 0, 0}, + {300, 0, 0}, + {304, 0, 0}, + {316, 0, 0}, + {320, 0, 0}, + {332, 0, 0}, + {336, 0, 0}, + {348, 0, 0}, + {352, 0, 0}, + {364, 0, 0}, + {368, 0, 0}, + {372, 0, 0}, + {376, 0, 0}, + {384, 0, 0}, + {392, 0, 0}, + {400, 0, 0}, + {416, 0, 0}, + {424, 0, 0}, + {432, 0, 0}, + {448, 0, 0}, + {460, 0, 0}, + {464, 0, 0}, + {468, 0, 0}, + {472, 0, 0}, + {476, 0, 0}, + {480, 0, 0}, + {488, 0, 0}, + {492, 0, 0}, + {496, 0, 0}, + {512, 0, 0}, + {640, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 14, 17}, + {0, 37, 31}, + {0, 21, 39}, + {0, 24, 23}, + {0, 5, 13}, + {0, 38, 76}, + {0, 51, 77}, + {0, 55, 53}, + {0, 58, 56}, + {0, 64, 61}, + {0, 67, 66}, + {0, 70, 68}, + {0, 54, 71}, + {0, 62, 60}, + {0, 65, 63}, + {0, 73, 72}, + {0, 59, 57}, + {0, 52, 74}, + {0, 50, 69}, + {0, 49, 47}, + {0, 48, 46}, + {0, 45, 43}, + {0, 42, 44}, + {0, 78, 41}, + {0, 20, 18}, + {0, 80, 79}, + {0, 15, 27}, + {0, 7, 34}, + {0, 81, 6}, + {0, 28, 3}, + {0, 35, 82}, + {0, 9, 36}, + {0, 84, 83}, + {0, 86, 85}, + {0, 88, 87}, + {0, 90, 89}, + {0, 92, 91}, + {0, 94, 93}, + {0, 96, 95}, + {0, 98, 97}, + {0, 11, 29}, + {0, 99, 25}, + {0, 100, 40}, + {0, 102, 101}, + {0, 26, 32}, + {0, 19, 30}, + {0, 16, 12}, + {0, 4, 8}, + {0, 104, 103}, + {0, 106, 105}, + {0, 33, 107}, + {0, 109, 108}, + {0, 111, 110}, + {0, 22, 112}, + {0, 113, 10}, + {0, 115, 114}, + {0, 75, 116}, + {0, 118, 117}, + {0, 119, 1}, + {0, 121, 120}, + {0, 123, 122}, + {0, 125, 124}, + {0, 127, 126}, + {0, 129, 128}, + {0, 131, 130}, + {0, 132, 2}, + {0, 134, 133}, + {0, 136, 135}, + {0, 138, 137}, + {0, 140, 139}, + {0, 142, 141}, + {0, 144, 143}, + {0, 146, 145}, + {0, 148, 147}, + })); + + codecs.emplace(std::pair(SpvOpMemberDecorate, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, 
{ + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 6}, + {0, 4, 7}, + {0, 8, 3}, + {0, 9, 5}, + {0, 1, 10}, + })); + + codecs.emplace(std::pair(SpvOpVectorShuffle, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 7}, + {0, 8, 5}, + {0, 9, 1}, + {0, 4, 10}, + {0, 11, 6}, + {0, 2, 12}, + })); + + codecs.emplace(std::pair(SpvOpVectorShuffle, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(15, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {6, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 8}, + {0, 5, 2}, + {0, 10, 9}, + {0, 1, 4}, + {0, 12, 11}, + {0, 7, 13}, + {0, 3, 14}, + })); + + codecs.emplace(std::pair(SpvOpVectorShuffle, 6), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(15, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {6, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 5}, + {0, 9, 7}, + {0, 10, 3}, + {0, 11, 2}, + {0, 6, 1}, + {0, 13, 12}, + {0, 4, 14}, + })); + + codecs.emplace(std::pair(SpvOpVectorShuffle, 7), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(61, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {6, 0, 0}, + {7, 0, 0}, + {8, 0, 0}, + {9, 0, 0}, + {10, 0, 0}, + {11, 0, 0}, + {12, 0, 0}, + {13, 0, 0}, + {14, 0, 0}, + {15, 0, 0}, + {16, 0, 0}, + {17, 0, 0}, + {18, 0, 0}, + {19, 0, 0}, + {20, 0, 0}, + {21, 0, 0}, + {22, 0, 0}, + {23, 0, 0}, + {24, 0, 0}, + {27, 0, 0}, + {28, 0, 0}, + {29, 0, 0}, + {30, 0, 0}, + {31, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 30, 16}, + {0, 26, 27}, + {0, 29, 28}, + {0, 18, 22}, + {0, 12, 19}, + {0, 15, 20}, + {0, 14, 23}, + {0, 32, 7}, + {0, 8, 21}, + {0, 11, 33}, + {0, 
17, 34}, + {0, 25, 13}, + {0, 36, 35}, + {0, 9, 10}, + {0, 38, 37}, + {0, 39, 31}, + {0, 5, 40}, + {0, 42, 41}, + {0, 44, 43}, + {0, 6, 45}, + {0, 46, 24}, + {0, 48, 47}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + {0, 55, 4}, + {0, 56, 3}, + {0, 57, 2}, + {0, 58, 1}, + {0, 60, 59}, + })); + + codecs.emplace(std::pair(SpvOpCompositeExtract, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(63, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {6, 0, 0}, + {7, 0, 0}, + {8, 0, 0}, + {9, 0, 0}, + {10, 0, 0}, + {11, 0, 0}, + {12, 0, 0}, + {13, 0, 0}, + {29, 0, 0}, + {30, 0, 0}, + {31, 0, 0}, + {32, 0, 0}, + {33, 0, 0}, + {34, 0, 0}, + {35, 0, 0}, + {36, 0, 0}, + {37, 0, 0}, + {38, 0, 0}, + {39, 0, 0}, + {40, 0, 0}, + {41, 0, 0}, + {42, 0, 0}, + {43, 0, 0}, + {44, 0, 0}, + {45, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 13, 14}, + {0, 12, 9}, + {0, 11, 25}, + {0, 27, 26}, + {0, 29, 28}, + {0, 31, 30}, + {0, 23, 22}, + {0, 10, 24}, + {0, 8, 21}, + {0, 17, 7}, + {0, 19, 18}, + {0, 15, 20}, + {0, 6, 16}, + {0, 5, 33}, + {0, 35, 34}, + {0, 37, 36}, + {0, 39, 38}, + {0, 41, 40}, + {0, 43, 42}, + {0, 45, 44}, + {0, 47, 46}, + {0, 49, 48}, + {0, 51, 50}, + {0, 32, 52}, + {0, 54, 53}, + {0, 56, 55}, + {0, 58, 57}, + {0, 3, 2}, + {0, 59, 4}, + {0, 60, 1}, + {0, 62, 61}, + })); + + codecs.emplace(std::pair(SpvOpCompositeExtract, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 5}, + {0, 3, 2}, + {0, 6, 4}, + {0, 8, 7}, + })); + + codecs.emplace(std::pair(SpvOpCompositeExtract, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(23, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {4, 0, 0}, + {5, 0, 0}, + {6, 0, 0}, + {7, 0, 0}, + {8, 0, 0}, + {9, 0, 0}, + {10, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 12, 11}, + {0, 10, 13}, + 
{0, 9, 14}, + {0, 7, 5}, + {0, 8, 6}, + {0, 4, 15}, + {0, 17, 16}, + {0, 18, 3}, + {0, 19, 2}, + {0, 20, 1}, + {0, 22, 21}, + })); + + codecs.emplace(std::pair(SpvOpCompositeInsert, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {0, 0, 0}, + {1, 0, 0}, + {2, 0, 0}, + {3, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 2, 6}, + {0, 7, 1}, + {0, 4, 8}, + })); + + codecs.emplace(std::pair(SpvOpCompositeInsert, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleImplicitLod, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {2, 0, 0}, + {10, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleExplicitLod, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {2, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleDrefExplicitLod, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {0, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpSelectionMerge, 1), std::move(codec)); + } + + return codecs; +} + +std::map, std::unique_ptr>> +GetIdDescriptorHuffmanCodecs() { + std::map, std::unique_ptr>> codecs; + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 4, 6}, + {0, 1, 7}, + {0, 2, 8}, + })); + + codecs.emplace(std::pair(SpvOpExtInst, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(63, { + {0, 0, 0}, + {34183582, 0, 0}, + {223800276, 0, 0}, + {295018543, 0, 0}, + {439764402, 0, 0}, + {443558693, 0, 0}, + {583624926, 
0, 0}, + {599185303, 0, 0}, + {779021139, 0, 0}, + {1015552308, 0, 0}, + {1027242654, 0, 0}, + {1077859090, 0, 0}, + {1104362365, 0, 0}, + {1132589448, 0, 0}, + {1236389532, 0, 0}, + {1739837626, 0, 0}, + {1955104493, 0, 0}, + {2161102232, 0, 0}, + {2197874825, 0, 0}, + {2217833278, 0, 0}, + {2244470522, 0, 0}, + {2532518896, 0, 0}, + {2789375411, 0, 0}, + {3061690214, 0, 0}, + {3287039847, 0, 0}, + {3357301402, 0, 0}, + {3365041621, 0, 0}, + {3510257966, 0, 0}, + {3534235309, 0, 0}, + {4018237905, 0, 0}, + {4145966869, 0, 0}, + {4272200782, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 10, 19}, + {0, 6, 1}, + {0, 26, 13}, + {0, 2, 11}, + {0, 15, 22}, + {0, 23, 18}, + {0, 4, 27}, + {0, 28, 12}, + {0, 3, 30}, + {0, 9, 7}, + {0, 20, 14}, + {0, 29, 16}, + {0, 21, 8}, + {0, 34, 33}, + {0, 36, 35}, + {0, 31, 25}, + {0, 37, 24}, + {0, 39, 38}, + {0, 41, 40}, + {0, 43, 42}, + {0, 45, 44}, + {0, 17, 5}, + {0, 47, 46}, + {0, 49, 48}, + {0, 51, 50}, + {0, 53, 52}, + {0, 55, 54}, + {0, 57, 56}, + {0, 59, 58}, + {0, 61, 60}, + {0, 32, 62}, + })); + + codecs.emplace(std::pair(SpvOpExtInst, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {4228502127, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpExtInst, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(113, { + {0, 0, 0}, + {50998433, 0, 0}, + {139011596, 0, 0}, + {181902171, 0, 0}, + {296981500, 0, 0}, + {321630747, 0, 0}, + {416853049, 0, 0}, + {464259778, 0, 0}, + {615982737, 0, 0}, + {669982125, 0, 0}, + {759277550, 0, 0}, + {810488476, 0, 0}, + {870594305, 0, 0}, + {922996215, 0, 0}, + {969500141, 0, 0}, + {1015552308, 0, 0}, + {1139547465, 0, 0}, + {1203545131, 0, 0}, + {1220643281, 0, 0}, + {1220749418, 0, 0}, + {1367301635, 0, 0}, + {1395923345, 0, 0}, + {1554194368, 0, 0}, + {1742737136, 0, 0}, + {1755648697, 0, 0}, + {1962162282, 0, 0}, + {1964254745, 0, 0}, + {2055836767, 0, 0}, + {2096388952, 0, 0}, + 
{2124837447, 0, 0}, + {2161102232, 0, 0}, + {2321729979, 0, 0}, + {2346547796, 0, 0}, + {2399809085, 0, 0}, + {2432827426, 0, 0}, + {2455417440, 0, 0}, + {2572638469, 0, 0}, + {2614879967, 0, 0}, + {2855506940, 0, 0}, + {2919796598, 0, 0}, + {2970183398, 0, 0}, + {2976066508, 0, 0}, + {3044188332, 0, 0}, + {3061690214, 0, 0}, + {3091876332, 0, 0}, + {3104643263, 0, 0}, + {3107165180, 0, 0}, + {3187066832, 0, 0}, + {3413713311, 0, 0}, + {3487022798, 0, 0}, + {3602693817, 0, 0}, + {3678875745, 0, 0}, + {3701632935, 0, 0}, + {3829325073, 0, 0}, + {4040340620, 0, 0}, + {4174489262, 0, 0}, + {4272200782, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 33, 7}, + {0, 13, 34}, + {0, 21, 18}, + {0, 53, 22}, + {0, 39, 1}, + {0, 14, 9}, + {0, 43, 26}, + {0, 51, 35}, + {0, 19, 6}, + {0, 15, 25}, + {0, 55, 29}, + {0, 32, 3}, + {0, 27, 44}, + {0, 10, 46}, + {0, 45, 24}, + {0, 36, 40}, + {0, 47, 8}, + {0, 48, 54}, + {0, 58, 5}, + {0, 60, 59}, + {0, 30, 61}, + {0, 62, 56}, + {0, 64, 63}, + {0, 41, 50}, + {0, 66, 65}, + {0, 68, 67}, + {0, 70, 69}, + {0, 37, 31}, + {0, 4, 17}, + {0, 16, 20}, + {0, 72, 71}, + {0, 73, 52}, + {0, 49, 12}, + {0, 75, 74}, + {0, 76, 11}, + {0, 23, 42}, + {0, 78, 77}, + {0, 80, 79}, + {0, 82, 81}, + {0, 84, 83}, + {0, 85, 28}, + {0, 87, 86}, + {0, 89, 88}, + {0, 91, 90}, + {0, 93, 92}, + {0, 94, 2}, + {0, 96, 95}, + {0, 98, 97}, + {0, 100, 99}, + {0, 102, 101}, + {0, 38, 103}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + {0, 111, 110}, + {0, 57, 112}, + })); + + codecs.emplace(std::pair(SpvOpExtInst, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(127, { + {0, 0, 0}, + {72782198, 0, 0}, + {139011596, 0, 0}, + {296981500, 0, 0}, + {300939750, 0, 0}, + {401211099, 0, 0}, + {429277936, 0, 0}, + {505940164, 0, 0}, + {538168945, 0, 0}, + {603915804, 0, 0}, + {688216667, 0, 0}, + {706016261, 0, 0}, + {790502615, 0, 0}, + {810488476, 0, 0}, + {993150979, 0, 0}, + {1203545131, 0, 0}, + {1206726575, 0, 0}, + {1265796414, 0, 0}, + 
{1314843976, 0, 0}, + {1367301635, 0, 0}, + {1378082995, 0, 0}, + {1410311776, 0, 0}, + {1443829854, 0, 0}, + {1448448666, 0, 0}, + {1468919488, 0, 0}, + {1496351055, 0, 0}, + {1619778288, 0, 0}, + {1684282922, 0, 0}, + {1848784182, 0, 0}, + {1901166356, 0, 0}, + {2095546797, 0, 0}, + {2096388952, 0, 0}, + {2162986400, 0, 0}, + {2197874825, 0, 0}, + {2246405597, 0, 0}, + {2250225826, 0, 0}, + {2282454607, 0, 0}, + {2328748202, 0, 0}, + {2348201466, 0, 0}, + {2597020383, 0, 0}, + {2633682514, 0, 0}, + {2817335337, 0, 0}, + {2855506940, 0, 0}, + {2936040203, 0, 0}, + {2955375511, 0, 0}, + {3122368657, 0, 0}, + {3154597438, 0, 0}, + {3184381405, 0, 0}, + {3187066832, 0, 0}, + {3233393284, 0, 0}, + {3251128023, 0, 0}, + {3260309823, 0, 0}, + {3441531391, 0, 0}, + {3496407048, 0, 0}, + {3582002820, 0, 0}, + {3647586740, 0, 0}, + {3653838348, 0, 0}, + {3730093054, 0, 0}, + {3759072440, 0, 0}, + {3928764629, 0, 0}, + {3969279737, 0, 0}, + {3994511488, 0, 0}, + {4026740269, 0, 0}, + {4274214049, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 43, 23}, + {0, 5, 24}, + {0, 9, 8}, + {0, 36, 21}, + {0, 13, 46}, + {0, 7, 12}, + {0, 35, 20}, + {0, 61, 59}, + {0, 22, 29}, + {0, 38, 62}, + {0, 56, 45}, + {0, 6, 48}, + {0, 33, 30}, + {0, 14, 58}, + {0, 34, 28}, + {0, 51, 40}, + {0, 63, 55}, + {0, 25, 16}, + {0, 17, 11}, + {0, 53, 52}, + {0, 65, 27}, + {0, 39, 41}, + {0, 67, 66}, + {0, 69, 68}, + {0, 10, 4}, + {0, 37, 18}, + {0, 60, 47}, + {0, 1, 32}, + {0, 71, 70}, + {0, 73, 72}, + {0, 57, 26}, + {0, 74, 31}, + {0, 76, 75}, + {0, 77, 44}, + {0, 78, 15}, + {0, 79, 54}, + {0, 81, 80}, + {0, 82, 49}, + {0, 84, 83}, + {0, 86, 85}, + {0, 88, 87}, + {0, 89, 19}, + {0, 91, 90}, + {0, 93, 92}, + {0, 95, 94}, + {0, 2, 96}, + {0, 98, 97}, + {0, 100, 99}, + {0, 102, 101}, + {0, 104, 103}, + {0, 106, 105}, + {0, 3, 107}, + {0, 109, 108}, + {0, 111, 110}, + {0, 113, 112}, + {0, 114, 50}, + {0, 116, 115}, + {0, 118, 117}, + {0, 120, 119}, + {0, 122, 121}, + {0, 124, 123}, + {0, 64, 42}, + {0, 126, 
125}, + })); + + codecs.emplace(std::pair(SpvOpExtInst, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(93, { + {0, 0, 0}, + {99347751, 0, 0}, + {102542696, 0, 0}, + {107497541, 0, 0}, + {112452386, 0, 0}, + {139011596, 0, 0}, + {296981500, 0, 0}, + {429277936, 0, 0}, + {451957774, 0, 0}, + {508217552, 0, 0}, + {573901046, 0, 0}, + {774727851, 0, 0}, + {801484894, 0, 0}, + {920604853, 0, 0}, + {925559698, 0, 0}, + {1022915255, 0, 0}, + {1209418480, 0, 0}, + {1287937401, 0, 0}, + {1319785741, 0, 0}, + {1392080469, 0, 0}, + {1538342947, 0, 0}, + {1541020250, 0, 0}, + {1587209598, 0, 0}, + {1594733696, 0, 0}, + {1631434666, 0, 0}, + {1636389511, 0, 0}, + {1684282922, 0, 0}, + {1859128680, 0, 0}, + {1901166356, 0, 0}, + {2004567202, 0, 0}, + {2119793999, 0, 0}, + {2280400314, 0, 0}, + {2538917932, 0, 0}, + {2677264274, 0, 0}, + {2683080096, 0, 0}, + {2854085372, 0, 0}, + {2879917501, 0, 0}, + {3059119137, 0, 0}, + {3174324790, 0, 0}, + {3194725903, 0, 0}, + {3358097187, 0, 0}, + {3547456240, 0, 0}, + {3614752756, 0, 0}, + {3753486980, 0, 0}, + {3811268385, 0, 0}, + {3953733490, 0, 0}, + {3990925720, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 23, 22}, + {0, 36, 31}, + {0, 17, 40}, + {0, 27, 19}, + {0, 35, 33}, + {0, 30, 38}, + {0, 42, 39}, + {0, 46, 32}, + {0, 13, 12}, + {0, 44, 14}, + {0, 29, 11}, + {0, 10, 18}, + {0, 15, 37}, + {0, 1, 4}, + {0, 45, 2}, + {0, 21, 28}, + {0, 8, 5}, + {0, 49, 48}, + {0, 51, 50}, + {0, 53, 52}, + {0, 54, 16}, + {0, 55, 25}, + {0, 56, 3}, + {0, 58, 57}, + {0, 59, 26}, + {0, 20, 7}, + {0, 61, 60}, + {0, 62, 24}, + {0, 41, 63}, + {0, 65, 64}, + {0, 9, 34}, + {0, 67, 66}, + {0, 69, 68}, + {0, 71, 70}, + {0, 73, 72}, + {0, 75, 74}, + {0, 76, 43}, + {0, 78, 77}, + {0, 80, 79}, + {0, 82, 81}, + {0, 84, 83}, + {0, 86, 85}, + {0, 88, 87}, + {0, 90, 89}, + {0, 47, 91}, + {0, 92, 6}, + })); + + codecs.emplace(std::pair(SpvOpExtInst, 6), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(15, { + {0, 0, 
0}, + {166253838, 0, 0}, + {679771963, 0, 0}, + {1247793383, 0, 0}, + {2261697609, 0, 0}, + {2263349224, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 8}, + {0, 9, 1}, + {0, 3, 5}, + {0, 11, 10}, + {0, 2, 12}, + {0, 7, 6}, + {0, 14, 13}, + })); + + codecs.emplace(std::pair(SpvOpTypeVector, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {789872778, 0, 0}, + {1415510495, 0, 0}, + {1951208733, 0, 0}, + {2430404313, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 5}, + {0, 4, 6}, + {0, 7, 1}, + {0, 3, 8}, + })); + + codecs.emplace(std::pair(SpvOpTypeVector, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(15, { + {0, 0, 0}, + {1389644742, 0, 0}, + {3232633974, 0, 0}, + {3278176820, 0, 0}, + {3648138580, 0, 0}, + {3687777340, 0, 0}, + {3694383800, 0, 0}, + {3697687030, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 4}, + {0, 9, 6}, + {0, 10, 8}, + {0, 2, 11}, + {0, 12, 3}, + {0, 1, 13}, + {0, 14, 7}, + })); + + codecs.emplace(std::pair(SpvOpTypeArray, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {1951208733, 0, 0}, + {2160380860, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 2, 5}, + {0, 3, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeArray, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {144116905, 0, 0}, + {827246872, 0, 0}, + {1545298048, 0, 0}, + {2715370488, 0, 0}, + {2798552666, 0, 0}, + {3812456892, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 8, 6}, + {0, 9, 7}, + {0, 1, 10}, + {0, 11, 4}, + {0, 5, 12}, + })); + + codecs.emplace(std::pair(SpvOpTypeArray, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(67, { + {0, 0, 0}, + {40653745, 0, 0}, + {119981689, 0, 0}, + {153085016, 0, 0}, + {451382997, 0, 0}, + {545678922, 0, 0}, + {899570100, 0, 0}, + {929101967, 0, 0}, + 
{1070791291, 0, 0}, + {1100599986, 0, 0}, + {1103903216, 0, 0}, + {1154919607, 0, 0}, + {1199157863, 0, 0}, + {1258105452, 0, 0}, + {1369578001, 0, 0}, + {1372881231, 0, 0}, + {1674803691, 0, 0}, + {1677700667, 0, 0}, + {1989520052, 0, 0}, + {2593884753, 0, 0}, + {2664825925, 0, 0}, + {2924146124, 0, 0}, + {2926633629, 0, 0}, + {3249265647, 0, 0}, + {3345288309, 0, 0}, + {3410158390, 0, 0}, + {3489360962, 0, 0}, + {3495967422, 0, 0}, + {3504981554, 0, 0}, + {3705139860, 0, 0}, + {3822983876, 0, 0}, + {4141567741, 0, 0}, + {4234287173, 0, 0}, + {4240893633, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 15, 23}, + {0, 20, 17}, + {0, 32, 22}, + {0, 19, 12}, + {0, 13, 3}, + {0, 30, 27}, + {0, 4, 35}, + {0, 24, 36}, + {0, 31, 37}, + {0, 33, 38}, + {0, 39, 7}, + {0, 6, 40}, + {0, 41, 29}, + {0, 14, 42}, + {0, 43, 28}, + {0, 10, 44}, + {0, 45, 18}, + {0, 26, 46}, + {0, 5, 47}, + {0, 48, 2}, + {0, 49, 9}, + {0, 50, 16}, + {0, 34, 25}, + {0, 52, 51}, + {0, 54, 53}, + {0, 56, 55}, + {0, 58, 57}, + {0, 60, 59}, + {0, 8, 21}, + {0, 1, 11}, + {0, 62, 61}, + {0, 64, 63}, + {0, 66, 65}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2160380860, 0, 0}, + {3278176820, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 6}, + {0, 2, 7}, + {0, 3, 8}, + {0, 9, 1}, + {0, 5, 10}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2160380860, 0, 0}, + {2320303498, 0, 0}, + {3232633974, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 7}, + {0, 2, 8}, + {0, 4, 9}, + {0, 10, 3}, + {0, 1, 6}, + {0, 12, 11}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {679771963, 0, 0}, + 
{1951208733, 0, 0}, + {2160380860, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 6}, + {0, 1, 7}, + {0, 3, 4}, + {0, 8, 2}, + {0, 10, 9}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2160380860, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 6}, + {0, 3, 7}, + {0, 5, 4}, + {0, 8, 1}, + {0, 10, 9}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2263349224, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 1, 6}, + {0, 2, 7}, + {0, 8, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 1, 6}, + {0, 2, 7}, + {0, 8, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 6), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 4, 6}, + {0, 7, 1}, + {0, 2, 8}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 7), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 1, 6}, + {0, 7, 4}, + {0, 2, 8}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 8), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + 
{3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 4}, + {0, 3, 5}, + {0, 1, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 9), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 3}, + {0, 1, 6}, + {0, 4, 7}, + {0, 8, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 10), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 1, 6}, + {0, 7, 4}, + {0, 8, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 11), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 4, 1}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 12), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 1, 5}, + {0, 2, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 13), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 5}, + {0, 3, 6}, + {0, 7, 1}, + {0, 8, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 14), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 5, 3}, + {0, 6, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 15), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + 
{679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 2, 5}, + {0, 1, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 16), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 4, 6}, + {0, 7, 1}, + {0, 8, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 17), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 1, 5}, + {0, 2, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 18), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 3, 5}, + {0, 2, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 19), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 1, 5}, + {0, 2, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 20), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 2, 5}, + {0, 3, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 21), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 1}, + {0, 2, 6}, + {0, 3, 7}, + {0, 8, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 22), std::move(codec)); + } + + { + 
std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 5}, + {0, 2, 6}, + {0, 4, 7}, + {0, 8, 3}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 23), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2160380860, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 4}, + {0, 1, 7}, + {0, 2, 8}, + {0, 3, 9}, + {0, 10, 5}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 24), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 5}, + {0, 2, 6}, + {0, 4, 7}, + {0, 8, 3}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 25), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 2, 5}, + {0, 3, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 26), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 2, 5}, + {0, 3, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 27), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 28), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 3}, + {0, 2, 4}, + })); + + 
codecs.emplace(std::pair(SpvOpTypeStruct, 29), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 3}, + {0, 2, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 30), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 3}, + {0, 1, 5}, + {0, 2, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 31), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 1, 5}, + {0, 2, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 32), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 33), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 3}, + {0, 2, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 34), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 35), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 36), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 
2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 37), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1389644742, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 38), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {3697687030, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 39), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 40), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 41), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 42), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 43), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 44), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 45), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 46), std::move(codec)); + } + 
+ { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 47), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 48), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 49), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 50), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {679771963, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpTypeStruct, 51), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(101, { + {0, 0, 0}, + {85880059, 0, 0}, + {135486769, 0, 0}, + {304448521, 0, 0}, + {436416061, 0, 0}, + {440421571, 0, 0}, + {450406196, 0, 0}, + {503094540, 0, 0}, + {543621065, 0, 0}, + {626892406, 0, 0}, + {628544021, 0, 0}, + {827698488, 0, 0}, + {869050696, 0, 0}, + {907126242, 0, 0}, + {908777857, 0, 0}, + {910429472, 0, 0}, + {1113409935, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1297706389, 0, 0}, + {1322549027, 0, 0}, + {1784441183, 0, 0}, + {2080953106, 0, 0}, + {2194691858, 0, 0}, + {2448331885, 0, 0}, + {2466255445, 0, 0}, + {2468230023, 0, 0}, + {2547657777, 0, 0}, + {2549309392, 0, 0}, + {2550961007, 0, 0}, + {2894051250, 0, 0}, + {2929019254, 0, 0}, + {2934934694, 0, 0}, + {2936586309, 0, 0}, + {2938237924, 0, 0}, + {3077271274, 0, 0}, + {3092528578, 0, 0}, + {3094180193, 0, 0}, + {3094857332, 0, 0}, + {3095831808, 0, 0}, + {3183924418, 0, 0}, + 
{3207966516, 0, 0}, + {3282979782, 0, 0}, + {3433956341, 0, 0}, + {3561562003, 0, 0}, + {3563213618, 0, 0}, + {3564865233, 0, 0}, + {3585511591, 0, 0}, + {4028622909, 0, 0}, + {4039938779, 0, 0}, + {4050155669, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 16, 25}, + {0, 50, 1}, + {0, 42, 35}, + {0, 31, 41}, + {0, 4, 43}, + {0, 9, 10}, + {0, 3, 30}, + {0, 52, 47}, + {0, 12, 53}, + {0, 55, 54}, + {0, 36, 56}, + {0, 49, 57}, + {0, 6, 58}, + {0, 34, 33}, + {0, 59, 26}, + {0, 21, 32}, + {0, 60, 15}, + {0, 24, 61}, + {0, 62, 38}, + {0, 22, 2}, + {0, 37, 7}, + {0, 63, 46}, + {0, 14, 13}, + {0, 64, 5}, + {0, 65, 45}, + {0, 66, 19}, + {0, 18, 67}, + {0, 17, 20}, + {0, 68, 11}, + {0, 8, 69}, + {0, 70, 39}, + {0, 72, 71}, + {0, 74, 73}, + {0, 40, 75}, + {0, 76, 23}, + {0, 78, 77}, + {0, 29, 79}, + {0, 28, 80}, + {0, 27, 48}, + {0, 82, 81}, + {0, 51, 83}, + {0, 84, 44}, + {0, 86, 85}, + {0, 88, 87}, + {0, 90, 89}, + {0, 92, 91}, + {0, 94, 93}, + {0, 96, 95}, + {0, 98, 97}, + {0, 100, 99}, + })); + + codecs.emplace(std::pair(SpvOpTypePointer, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(65, { + {0, 0, 0}, + {119981689, 0, 0}, + {162255877, 0, 0}, + {451382997, 0, 0}, + {545678922, 0, 0}, + {679771963, 0, 0}, + {789872778, 0, 0}, + {1100599986, 0, 0}, + {1103903216, 0, 0}, + {1154919607, 0, 0}, + {1343794461, 0, 0}, + {1415510495, 0, 0}, + {1674803691, 0, 0}, + {1951208733, 0, 0}, + {1989520052, 0, 0}, + {2160380860, 0, 0}, + {2263349224, 0, 0}, + {2320303498, 0, 0}, + {2924146124, 0, 0}, + {2984325996, 0, 0}, + {3334207724, 0, 0}, + {3345288309, 0, 0}, + {3410158390, 0, 0}, + {3489360962, 0, 0}, + {3495967422, 0, 0}, + {3504981554, 0, 0}, + {3800912395, 0, 0}, + {3802564010, 0, 0}, + {3866587616, 0, 0}, + {3868239231, 0, 0}, + {3869890846, 0, 0}, + {3998230222, 0, 0}, + {4240893633, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 3}, + {0, 6, 24}, + {0, 11, 7}, + {0, 32, 21}, + {0, 27, 34}, + {0, 35, 25}, + {0, 36, 8}, + {0, 26, 31}, + {0, 14, 15}, + 
{0, 28, 37}, + {0, 1, 23}, + {0, 39, 38}, + {0, 12, 40}, + {0, 22, 41}, + {0, 10, 16}, + {0, 43, 42}, + {0, 29, 44}, + {0, 2, 45}, + {0, 46, 19}, + {0, 48, 47}, + {0, 18, 49}, + {0, 50, 30}, + {0, 9, 33}, + {0, 52, 51}, + {0, 54, 53}, + {0, 13, 55}, + {0, 17, 56}, + {0, 5, 57}, + {0, 59, 58}, + {0, 60, 20}, + {0, 62, 61}, + {0, 64, 63}, + })); + + codecs.emplace(std::pair(SpvOpTypePointer, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(99, { + {0, 0, 0}, + {75986790, 0, 0}, + {95470391, 0, 0}, + {170378107, 0, 0}, + {172029722, 0, 0}, + {204234270, 0, 0}, + {205885885, 0, 0}, + {244668133, 0, 0}, + {265778447, 0, 0}, + {616435646, 0, 0}, + {618087261, 0, 0}, + {753954113, 0, 0}, + {1000070091, 0, 0}, + {1308462133, 0, 0}, + {1671139745, 0, 0}, + {1774874546, 0, 0}, + {1776526161, 0, 0}, + {1887808856, 0, 0}, + {1889460471, 0, 0}, + {1917966999, 0, 0}, + {2044728014, 0, 0}, + {2192810893, 0, 0}, + {2293247016, 0, 0}, + {2503194620, 0, 0}, + {2605012269, 0, 0}, + {2608484640, 0, 0}, + {2615111110, 0, 0}, + {2668769415, 0, 0}, + {2759951687, 0, 0}, + {2761603302, 0, 0}, + {2856623532, 0, 0}, + {2945369269, 0, 0}, + {2956189845, 0, 0}, + {3085119011, 0, 0}, + {3367313400, 0, 0}, + {3447882276, 0, 0}, + {3633746133, 0, 0}, + {3635397748, 0, 0}, + {3710645347, 0, 0}, + {3712296962, 0, 0}, + {3715846592, 0, 0}, + {3727494858, 0, 0}, + {3747079365, 0, 0}, + {3748965853, 0, 0}, + {3750617468, 0, 0}, + {4018820793, 0, 0}, + {4022124023, 0, 0}, + {4024173916, 0, 0}, + {4215670524, 0, 0}, + {4217322139, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 10, 9}, + {0, 31, 24}, + {0, 40, 13}, + {0, 45, 33}, + {0, 34, 46}, + {0, 43, 38}, + {0, 44, 15}, + {0, 11, 30}, + {0, 21, 6}, + {0, 47, 3}, + {0, 51, 16}, + {0, 14, 52}, + {0, 8, 53}, + {0, 35, 5}, + {0, 55, 54}, + {0, 56, 26}, + {0, 20, 57}, + {0, 39, 19}, + {0, 59, 58}, + {0, 61, 60}, + {0, 4, 62}, + {0, 2, 63}, + {0, 25, 7}, + {0, 64, 27}, + {0, 12, 22}, + {0, 65, 48}, + {0, 41, 42}, + {0, 17, 23}, + 
{0, 49, 66}, + {0, 68, 67}, + {0, 70, 69}, + {0, 72, 71}, + {0, 74, 73}, + {0, 18, 75}, + {0, 37, 32}, + {0, 76, 36}, + {0, 78, 77}, + {0, 79, 28}, + {0, 81, 80}, + {0, 82, 29}, + {0, 84, 83}, + {0, 86, 85}, + {0, 88, 87}, + {0, 90, 89}, + {0, 91, 50}, + {0, 93, 92}, + {0, 95, 94}, + {0, 1, 96}, + {0, 98, 97}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(27, { + {0, 0, 0}, + {545678922, 0, 0}, + {679771963, 0, 0}, + {899570100, 0, 0}, + {929101967, 0, 0}, + {1100599986, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3056042030, 0, 0}, + {3334207724, 0, 0}, + {3357250579, 0, 0}, + {3705139860, 0, 0}, + {3800912395, 0, 0}, + {3802564010, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 3}, + {0, 10, 13}, + {0, 4, 15}, + {0, 16, 11}, + {0, 17, 1}, + {0, 14, 12}, + {0, 19, 18}, + {0, 21, 20}, + {0, 7, 6}, + {0, 9, 22}, + {0, 24, 23}, + {0, 25, 2}, + {0, 26, 8}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(57, { + {0, 0, 0}, + {283209196, 0, 0}, + {436416061, 0, 0}, + {679771963, 0, 0}, + {789872778, 0, 0}, + {815757910, 0, 0}, + {827698488, 0, 0}, + {1164221089, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1297706389, 0, 0}, + {1525861001, 0, 0}, + {1579585816, 0, 0}, + {1675764636, 0, 0}, + {1824016656, 0, 0}, + {1951208733, 0, 0}, + {1991787192, 0, 0}, + {2180701723, 0, 0}, + {2194691858, 0, 0}, + {2320303498, 0, 0}, + {2881886868, 0, 0}, + {2926633629, 0, 0}, + {3249265647, 0, 0}, + {3334207724, 0, 0}, + {3472123498, 0, 0}, + {3674863070, 0, 0}, + {4050155669, 0, 0}, + {4141567741, 0, 0}, + {4155122613, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 24, 7}, + {0, 17, 1}, + {0, 4, 15}, + {0, 11, 16}, + {0, 28, 30}, + {0, 25, 20}, + {0, 14, 31}, + {0, 32, 26}, + {0, 12, 5}, + {0, 2, 22}, + {0, 33, 13}, + {0, 35, 34}, + {0, 37, 36}, + {0, 39, 38}, + {0, 40, 21}, + {0, 29, 18}, + {0, 27, 41}, 
+ {0, 43, 42}, + {0, 19, 44}, + {0, 45, 23}, + {0, 6, 3}, + {0, 47, 46}, + {0, 49, 48}, + {0, 51, 50}, + {0, 10, 8}, + {0, 53, 52}, + {0, 9, 54}, + {0, 56, 55}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(17, { + {0, 0, 0}, + {679771963, 0, 0}, + {827698488, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1297706389, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 9}, + {0, 10, 6}, + {0, 1, 5}, + {0, 11, 3}, + {0, 12, 7}, + {0, 13, 2}, + {0, 15, 14}, + {0, 16, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(17, { + {0, 0, 0}, + {679771963, 0, 0}, + {827698488, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1951208733, 0, 0}, + {2194691858, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 5}, + {0, 10, 9}, + {0, 11, 6}, + {0, 7, 12}, + {0, 1, 3}, + {0, 2, 13}, + {0, 15, 14}, + {0, 4, 16}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {827698488, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1297706389, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 6}, + {0, 5, 7}, + {0, 2, 8}, + {0, 1, 9}, + {0, 10, 3}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {827698488, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1951208733, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 6}, + {0, 4, 7}, + {0, 8, 5}, + {0, 3, 9}, + {0, 1, 10}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 6), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {789872778, 0, 0}, + {827698488, 0, 0}, + 
{1951208733, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 1}, + {0, 4, 6}, + {0, 3, 7}, + {0, 2, 8}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 7), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {543621065, 0, 0}, + {827698488, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 1, 5}, + {0, 2, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 8), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {827698488, 0, 0}, + {1951208733, 0, 0}, + {3095831808, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 4}, + {0, 3, 5}, + {0, 1, 6}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 9), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {1296054774, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 2}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 10), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {1296054774, 0, 0}, + {2320303498, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 2}, + {0, 1, 4}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 11), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {789872778, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 2}, + {0, 4, 1}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 12), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {789872778, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 1}, + {0, 4, 3}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 13), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 1}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 14), std::move(codec)); + } + + { + 
std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 1}, + })); + + codecs.emplace(std::pair(SpvOpTypeFunction, 15), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {789872778, 0, 0}, + {1951208733, 0, 0}, + {2430404313, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 1, 5}, + {0, 2, 6}, + })); + + codecs.emplace(std::pair(SpvOpConstant, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(183, { + {0, 0, 0}, + {51041423, 0, 0}, + {52882140, 0, 0}, + {72782198, 0, 0}, + {142465290, 0, 0}, + {144116905, 0, 0}, + {158160339, 0, 0}, + {169135842, 0, 0}, + {210116709, 0, 0}, + {290391815, 0, 0}, + {296981500, 0, 0}, + {385229009, 0, 0}, + {438318340, 0, 0}, + {529742207, 0, 0}, + {628331516, 0, 0}, + {677668732, 0, 0}, + {778500192, 0, 0}, + {825595257, 0, 0}, + {910398460, 0, 0}, + {917019124, 0, 0}, + {959681532, 0, 0}, + {1031290113, 0, 0}, + {1039111164, 0, 0}, + {1064945649, 0, 0}, + {1087394637, 0, 0}, + {1092948665, 0, 0}, + {1156369516, 0, 0}, + {1158021131, 0, 0}, + {1172110445, 0, 0}, + {1304296041, 0, 0}, + {1400019344, 0, 0}, + {1450415100, 0, 0}, + {1452222566, 0, 0}, + {1543646433, 0, 0}, + {1543672828, 0, 0}, + {1612361408, 0, 0}, + {1622381564, 0, 0}, + {1691572958, 0, 0}, + {1755648697, 0, 0}, + {1782996825, 0, 0}, + {1784648440, 0, 0}, + {1930923350, 0, 0}, + {1939359710, 0, 0}, + {1971252067, 0, 0}, + {1979847999, 0, 0}, + {2078849875, 0, 0}, + {2113115132, 0, 0}, + {2135340676, 0, 0}, + {2170273742, 0, 0}, + {2268204687, 0, 0}, + {2285081596, 0, 0}, + {2318200267, 0, 0}, + {2321729979, 0, 0}, + {2326636627, 0, 0}, + {2444465148, 0, 0}, + {2466126792, 0, 0}, + {2490492987, 0, 0}, + {2524697596, 0, 0}, + {2557550659, 0, 0}, + {2678954464, 0, 0}, + {2705477184, 0, 0}, + {2715370488, 0, 0}, + {2732195517, 0, 0}, + {2775815164, 0, 0}, + {2796901051, 0, 0}, + {2798552666, 0, 0}, + {2855506940, 0, 0}, + {2860348412, 
0, 0}, + {2922615804, 0, 0}, + {2937761472, 0, 0}, + {2944827576, 0, 0}, + {3092754101, 0, 0}, + {3107165180, 0, 0}, + {3168953855, 0, 0}, + {3184177968, 0, 0}, + {3202349435, 0, 0}, + {3266548732, 0, 0}, + {3332104493, 0, 0}, + {3362723943, 0, 0}, + {3571454885, 0, 0}, + {3712763835, 0, 0}, + {3743748793, 0, 0}, + {3810805277, 0, 0}, + {3912967080, 0, 0}, + {3929248764, 0, 0}, + {3958731802, 0, 0}, + {3997952447, 0, 0}, + {4016096296, 0, 0}, + {4106658327, 0, 0}, + {4172568578, 0, 0}, + {4198082194, 0, 0}, + {4248015868, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 35, 16}, + {0, 49, 42}, + {0, 86, 69}, + {0, 53, 30}, + {0, 45, 89}, + {0, 50, 68}, + {0, 73, 71}, + {0, 17, 46}, + {0, 14, 81}, + {0, 63, 44}, + {0, 12, 3}, + {0, 72, 31}, + {0, 55, 67}, + {0, 36, 19}, + {0, 22, 88}, + {0, 9, 70}, + {0, 93, 23}, + {0, 95, 94}, + {0, 47, 91}, + {0, 34, 32}, + {0, 97, 96}, + {0, 41, 61}, + {0, 99, 98}, + {0, 37, 1}, + {0, 77, 100}, + {0, 51, 60}, + {0, 101, 79}, + {0, 6, 2}, + {0, 11, 7}, + {0, 24, 21}, + {0, 43, 28}, + {0, 59, 56}, + {0, 75, 62}, + {0, 80, 78}, + {0, 87, 83}, + {0, 18, 15}, + {0, 102, 38}, + {0, 104, 103}, + {0, 85, 90}, + {0, 76, 25}, + {0, 29, 105}, + {0, 107, 106}, + {0, 58, 52}, + {0, 109, 108}, + {0, 57, 110}, + {0, 112, 111}, + {0, 114, 113}, + {0, 115, 33}, + {0, 74, 116}, + {0, 118, 117}, + {0, 120, 119}, + {0, 122, 121}, + {0, 124, 123}, + {0, 126, 125}, + {0, 128, 127}, + {0, 130, 129}, + {0, 131, 13}, + {0, 54, 27}, + {0, 133, 132}, + {0, 48, 40}, + {0, 5, 8}, + {0, 82, 134}, + {0, 26, 135}, + {0, 39, 4}, + {0, 136, 64}, + {0, 138, 137}, + {0, 140, 139}, + {0, 84, 141}, + {0, 143, 142}, + {0, 145, 144}, + {0, 147, 146}, + {0, 149, 148}, + {0, 20, 150}, + {0, 65, 151}, + {0, 66, 152}, + {0, 153, 10}, + {0, 155, 154}, + {0, 157, 156}, + {0, 159, 158}, + {0, 161, 160}, + {0, 163, 162}, + {0, 165, 164}, + {0, 167, 166}, + {0, 169, 168}, + {0, 170, 92}, + {0, 172, 171}, + {0, 174, 173}, + {0, 176, 175}, + {0, 178, 177}, + {0, 180, 179}, + {0, 182, 
181}, + })); + + codecs.emplace(std::pair(SpvOpConstant, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1247793383, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 5}, + {0, 4, 6}, + {0, 1, 3}, + {0, 8, 7}, + })); + + codecs.emplace(std::pair(SpvOpConstantComposite, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(83, { + {0, 0, 0}, + {15502752, 0, 0}, + {46736908, 0, 0}, + {139011596, 0, 0}, + {149720480, 0, 0}, + {249378857, 0, 0}, + {251209228, 0, 0}, + {503145996, 0, 0}, + {836581417, 0, 0}, + {882718761, 0, 0}, + {1289566249, 0, 0}, + {1325348861, 0, 0}, + {1558001705, 0, 0}, + {1646147798, 0, 0}, + {1679946323, 0, 0}, + {1766401548, 0, 0}, + {1992893964, 0, 0}, + {2123388694, 0, 0}, + {2162986400, 0, 0}, + {2580096524, 0, 0}, + {2598189097, 0, 0}, + {2683080096, 0, 0}, + {2698156268, 0, 0}, + {2763960513, 0, 0}, + {3015046341, 0, 0}, + {3133016299, 0, 0}, + {3251128023, 0, 0}, + {3504158761, 0, 0}, + {3535289452, 0, 0}, + {3536941067, 0, 0}, + {3538592682, 0, 0}, + {3540244297, 0, 0}, + {3541895912, 0, 0}, + {3570219049, 0, 0}, + {3653838348, 0, 0}, + {3764205609, 0, 0}, + {3882634684, 0, 0}, + {3913885196, 0, 0}, + {3982047273, 0, 0}, + {4024252457, 0, 0}, + {4243119782, 0, 0}, + {4255182614, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 4}, + {0, 39, 2}, + {0, 38, 10}, + {0, 29, 41}, + {0, 23, 28}, + {0, 9, 24}, + {0, 44, 43}, + {0, 45, 6}, + {0, 20, 12}, + {0, 18, 33}, + {0, 19, 16}, + {0, 7, 46}, + {0, 48, 47}, + {0, 5, 49}, + {0, 13, 11}, + {0, 17, 14}, + {0, 25, 22}, + {0, 40, 36}, + {0, 1, 50}, + {0, 31, 30}, + {0, 51, 32}, + {0, 42, 52}, + {0, 54, 53}, + {0, 55, 15}, + {0, 37, 56}, + {0, 57, 34}, + {0, 59, 58}, + {0, 61, 60}, + {0, 35, 21}, + {0, 62, 26}, + {0, 64, 63}, + {0, 65, 27}, + {0, 3, 66}, + {0, 68, 67}, + {0, 70, 69}, + {0, 72, 71}, + {0, 74, 73}, + {0, 76, 75}, + {0, 78, 77}, + {0, 80, 79}, + {0, 82, 
81}, + })); + + codecs.emplace(std::pair(SpvOpConstantComposite, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(65, { + {0, 0, 0}, + {142465290, 0, 0}, + {158160339, 0, 0}, + {169135842, 0, 0}, + {210116709, 0, 0}, + {296981500, 0, 0}, + {615748604, 0, 0}, + {910398460, 0, 0}, + {959681532, 0, 0}, + {1039111164, 0, 0}, + {1087394637, 0, 0}, + {1156369516, 0, 0}, + {1450415100, 0, 0}, + {1543672828, 0, 0}, + {2100532220, 0, 0}, + {2170273742, 0, 0}, + {2285081596, 0, 0}, + {2326636627, 0, 0}, + {2444465148, 0, 0}, + {2732195517, 0, 0}, + {2763232252, 0, 0}, + {2796901051, 0, 0}, + {2855506940, 0, 0}, + {2922615804, 0, 0}, + {2937761472, 0, 0}, + {3202349435, 0, 0}, + {3362723943, 0, 0}, + {3712763835, 0, 0}, + {3810805277, 0, 0}, + {3929248764, 0, 0}, + {4016096296, 0, 0}, + {4172568578, 0, 0}, + {4248015868, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 12, 23}, + {0, 13, 6}, + {0, 20, 14}, + {0, 15, 24}, + {0, 17, 28}, + {0, 16, 31}, + {0, 7, 34}, + {0, 9, 32}, + {0, 36, 35}, + {0, 38, 37}, + {0, 40, 39}, + {0, 2, 8}, + {0, 10, 3}, + {0, 25, 19}, + {0, 27, 26}, + {0, 33, 30}, + {0, 11, 41}, + {0, 1, 21}, + {0, 18, 42}, + {0, 44, 43}, + {0, 46, 45}, + {0, 48, 47}, + {0, 29, 49}, + {0, 4, 50}, + {0, 52, 51}, + {0, 54, 53}, + {0, 56, 55}, + {0, 58, 57}, + {0, 59, 5}, + {0, 61, 60}, + {0, 62, 22}, + {0, 64, 63}, + })); + + codecs.emplace(std::pair(SpvOpConstantComposite, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(57, { + {0, 0, 0}, + {52882140, 0, 0}, + {210116709, 0, 0}, + {296981500, 0, 0}, + {385229009, 0, 0}, + {615748604, 0, 0}, + {910398460, 0, 0}, + {959681532, 0, 0}, + {1031290113, 0, 0}, + {1039111164, 0, 0}, + {1172110445, 0, 0}, + {1450415100, 0, 0}, + {1543672828, 0, 0}, + {1622381564, 0, 0}, + {1782996825, 0, 0}, + {1971252067, 0, 0}, + {2100532220, 0, 0}, + {2268204687, 0, 0}, + {2326636627, 0, 0}, + {2444465148, 0, 0}, + {2490492987, 0, 0}, + {2678954464, 0, 0}, + {2763232252, 0, 0}, + 
{2855506940, 0, 0}, + {2922615804, 0, 0}, + {3912967080, 0, 0}, + {3929248764, 0, 0}, + {4172568578, 0, 0}, + {4248015868, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 11, 24}, + {0, 12, 5}, + {0, 22, 16}, + {0, 18, 17}, + {0, 30, 27}, + {0, 6, 13}, + {0, 9, 28}, + {0, 32, 31}, + {0, 34, 33}, + {0, 7, 35}, + {0, 4, 1}, + {0, 10, 8}, + {0, 20, 15}, + {0, 25, 21}, + {0, 36, 29}, + {0, 19, 37}, + {0, 39, 38}, + {0, 41, 40}, + {0, 43, 42}, + {0, 26, 44}, + {0, 45, 2}, + {0, 47, 46}, + {0, 49, 48}, + {0, 50, 14}, + {0, 51, 3}, + {0, 53, 52}, + {0, 54, 23}, + {0, 56, 55}, + })); + + codecs.emplace(std::pair(SpvOpConstantComposite, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(39, { + {0, 0, 0}, + {210116709, 0, 0}, + {296981500, 0, 0}, + {615748604, 0, 0}, + {910398460, 0, 0}, + {959681532, 0, 0}, + {1039111164, 0, 0}, + {1092948665, 0, 0}, + {1450415100, 0, 0}, + {1543672828, 0, 0}, + {1612361408, 0, 0}, + {2100532220, 0, 0}, + {2326636627, 0, 0}, + {2444465148, 0, 0}, + {2524697596, 0, 0}, + {2763232252, 0, 0}, + {2855506940, 0, 0}, + {3929248764, 0, 0}, + {4172568578, 0, 0}, + {4248015868, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 7}, + {0, 9, 3}, + {0, 15, 11}, + {0, 10, 21}, + {0, 18, 12}, + {0, 4, 20}, + {0, 22, 19}, + {0, 23, 6}, + {0, 14, 24}, + {0, 5, 25}, + {0, 27, 26}, + {0, 28, 17}, + {0, 30, 29}, + {0, 31, 13}, + {0, 1, 32}, + {0, 34, 33}, + {0, 16, 35}, + {0, 2, 36}, + {0, 38, 37}, + })); + + codecs.emplace(std::pair(SpvOpConstantComposite, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(35, { + {0, 0, 0}, + {296981500, 0, 0}, + {615748604, 0, 0}, + {673708384, 0, 0}, + {959681532, 0, 0}, + {1039111164, 0, 0}, + {1450415100, 0, 0}, + {1543672828, 0, 0}, + {1939359710, 0, 0}, + {2100532220, 0, 0}, + {2113115132, 0, 0}, + {2326636627, 0, 0}, + {2444465148, 0, 0}, + {2763232252, 0, 0}, + {2855506940, 0, 0}, + {3929248764, 0, 0}, + {4172568578, 0, 0}, + {4248015868, 0, 0}, + {1111111111111111111, 0, 
0}, + {0, 18, 3}, + {0, 6, 19}, + {0, 12, 4}, + {0, 17, 2}, + {0, 9, 7}, + {0, 20, 13}, + {0, 11, 8}, + {0, 10, 16}, + {0, 21, 15}, + {0, 5, 22}, + {0, 24, 23}, + {0, 26, 25}, + {0, 28, 27}, + {0, 29, 1}, + {0, 31, 30}, + {0, 33, 32}, + {0, 34, 14}, + })); + + codecs.emplace(std::pair(SpvOpConstantComposite, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(23, { + {0, 0, 0}, + {545678922, 0, 0}, + {679771963, 0, 0}, + {929101967, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3056042030, 0, 0}, + {3334207724, 0, 0}, + {3357250579, 0, 0}, + {3705139860, 0, 0}, + {3800912395, 0, 0}, + {3802564010, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 11}, + {0, 9, 3}, + {0, 1, 13}, + {0, 14, 10}, + {0, 12, 15}, + {0, 17, 16}, + {0, 18, 4}, + {0, 7, 5}, + {0, 20, 19}, + {0, 2, 21}, + {0, 22, 6}, + })); + + codecs.emplace(std::pair(SpvOpFunction, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(89, { + {0, 0, 0}, + {35240468, 0, 0}, + {123060826, 0, 0}, + {184634770, 0, 0}, + {359054425, 0, 0}, + {459968607, 0, 0}, + {619875033, 0, 0}, + {904486530, 0, 0}, + {945128292, 0, 0}, + {950731750, 0, 0}, + {1058429216, 0, 0}, + {1182296898, 0, 0}, + {1238120570, 0, 0}, + {1429389803, 0, 0}, + {1652168174, 0, 0}, + {1717510093, 0, 0}, + {1766422419, 0, 0}, + {1775308984, 0, 0}, + {1776629361, 0, 0}, + {1824526196, 0, 0}, + {1957265068, 0, 0}, + {1998433745, 0, 0}, + {2055664760, 0, 0}, + {2303184249, 0, 0}, + {2451531615, 0, 0}, + {2507457870, 0, 0}, + {2550501832, 0, 0}, + {2590402790, 0, 0}, + {2649103430, 0, 0}, + {2780190687, 0, 0}, + {2831059514, 0, 0}, + {3167253437, 0, 0}, + {3269075805, 0, 0}, + {3323202731, 0, 0}, + {3361419439, 0, 0}, + {3464197236, 0, 0}, + {3472029049, 0, 0}, + {3518630848, 0, 0}, + {3604842236, 0, 0}, + {3653985133, 0, 0}, + {4091916710, 0, 0}, + {4121643374, 0, 0}, + {4185590212, 0, 0}, + {4233562270, 0, 0}, + {4235213885, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 40}, + {0, 14, 31}, + {0, 
7, 9}, + {0, 29, 27}, + {0, 18, 44}, + {0, 8, 5}, + {0, 10, 3}, + {0, 41, 37}, + {0, 42, 35}, + {0, 2, 1}, + {0, 47, 46}, + {0, 48, 4}, + {0, 11, 49}, + {0, 50, 36}, + {0, 19, 51}, + {0, 53, 52}, + {0, 55, 54}, + {0, 15, 12}, + {0, 26, 16}, + {0, 56, 21}, + {0, 25, 33}, + {0, 43, 24}, + {0, 57, 39}, + {0, 59, 58}, + {0, 61, 60}, + {0, 62, 34}, + {0, 64, 63}, + {0, 17, 30}, + {0, 66, 65}, + {0, 20, 67}, + {0, 13, 68}, + {0, 28, 69}, + {0, 70, 32}, + {0, 72, 71}, + {0, 73, 22}, + {0, 75, 74}, + {0, 77, 76}, + {0, 79, 78}, + {0, 80, 23}, + {0, 45, 81}, + {0, 83, 82}, + {0, 85, 84}, + {0, 38, 86}, + {0, 88, 87}, + })); + + codecs.emplace(std::pair(SpvOpFunction, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(87, { + {0, 0, 0}, + {75986790, 0, 0}, + {95470391, 0, 0}, + {170378107, 0, 0}, + {172029722, 0, 0}, + {204234270, 0, 0}, + {205885885, 0, 0}, + {244668133, 0, 0}, + {265778447, 0, 0}, + {753954113, 0, 0}, + {1000070091, 0, 0}, + {1671139745, 0, 0}, + {1774874546, 0, 0}, + {1776526161, 0, 0}, + {1887808856, 0, 0}, + {1889460471, 0, 0}, + {1917966999, 0, 0}, + {2044728014, 0, 0}, + {2192810893, 0, 0}, + {2293247016, 0, 0}, + {2503194620, 0, 0}, + {2608484640, 0, 0}, + {2615111110, 0, 0}, + {2668769415, 0, 0}, + {2759951687, 0, 0}, + {2761603302, 0, 0}, + {2856623532, 0, 0}, + {2956189845, 0, 0}, + {3085119011, 0, 0}, + {3367313400, 0, 0}, + {3447882276, 0, 0}, + {3633746133, 0, 0}, + {3635397748, 0, 0}, + {3710645347, 0, 0}, + {3712296962, 0, 0}, + {3727494858, 0, 0}, + {3747079365, 0, 0}, + {3748965853, 0, 0}, + {3750617468, 0, 0}, + {4018820793, 0, 0}, + {4022124023, 0, 0}, + {4024173916, 0, 0}, + {4215670524, 0, 0}, + {4217322139, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 39, 28}, + {0, 29, 40}, + {0, 37, 33}, + {0, 38, 12}, + {0, 9, 26}, + {0, 18, 6}, + {0, 41, 3}, + {0, 11, 13}, + {0, 5, 8}, + {0, 45, 30}, + {0, 22, 46}, + {0, 48, 47}, + {0, 16, 17}, + {0, 34, 49}, + {0, 51, 50}, + {0, 53, 52}, + {0, 7, 2}, + {0, 23, 21}, + {0, 54, 
10}, + {0, 20, 36}, + {0, 55, 35}, + {0, 56, 4}, + {0, 43, 57}, + {0, 59, 58}, + {0, 60, 42}, + {0, 62, 61}, + {0, 63, 15}, + {0, 64, 31}, + {0, 14, 65}, + {0, 66, 24}, + {0, 67, 32}, + {0, 68, 19}, + {0, 70, 69}, + {0, 71, 27}, + {0, 73, 72}, + {0, 75, 74}, + {0, 77, 76}, + {0, 78, 25}, + {0, 44, 79}, + {0, 81, 80}, + {0, 83, 82}, + {0, 1, 84}, + {0, 86, 85}, + })); + + codecs.emplace(std::pair(SpvOpFunction, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(41, { + {0, 0, 0}, + {436416061, 0, 0}, + {543621065, 0, 0}, + {679771963, 0, 0}, + {815757910, 0, 0}, + {827698488, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1297706389, 0, 0}, + {1579585816, 0, 0}, + {1675764636, 0, 0}, + {1824016656, 0, 0}, + {1951208733, 0, 0}, + {2194691858, 0, 0}, + {2320303498, 0, 0}, + {2926633629, 0, 0}, + {3095831808, 0, 0}, + {3249265647, 0, 0}, + {3334207724, 0, 0}, + {4050155669, 0, 0}, + {4141567741, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 11}, + {0, 19, 16}, + {0, 9, 4}, + {0, 1, 17}, + {0, 22, 10}, + {0, 24, 23}, + {0, 15, 25}, + {0, 13, 26}, + {0, 27, 20}, + {0, 12, 28}, + {0, 30, 29}, + {0, 31, 18}, + {0, 3, 21}, + {0, 32, 14}, + {0, 34, 33}, + {0, 35, 8}, + {0, 5, 6}, + {0, 37, 36}, + {0, 39, 38}, + {0, 40, 7}, + })); + + codecs.emplace(std::pair(SpvOpFunctionParameter, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(41, { + {0, 0, 0}, + {522971108, 0, 0}, + {615341051, 0, 0}, + {718301639, 0, 0}, + {985750227, 0, 0}, + {1395113939, 0, 0}, + {1510333659, 0, 0}, + {1642805350, 0, 0}, + {1846856260, 0, 0}, + {1957218950, 0, 0}, + {1977038330, 0, 0}, + {1978689945, 0, 0}, + {1980341560, 0, 0}, + {2262220987, 0, 0}, + {2674422363, 0, 0}, + {3197739982, 0, 0}, + {3465954368, 0, 0}, + {3941049054, 0, 0}, + {3945795573, 0, 0}, + {4080527786, 0, 0}, + {4154758669, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 17}, + {0, 4, 15}, + {0, 8, 7}, + {0, 2, 20}, + {0, 22, 19}, + {0, 24, 23}, + {0, 14, 25}, + {0, 16, 
26}, + {0, 27, 13}, + {0, 6, 28}, + {0, 30, 29}, + {0, 31, 10}, + {0, 11, 21}, + {0, 32, 12}, + {0, 34, 33}, + {0, 35, 5}, + {0, 9, 18}, + {0, 37, 36}, + {0, 39, 38}, + {0, 40, 1}, + })); + + codecs.emplace(std::pair(SpvOpFunctionParameter, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(27, { + {0, 0, 0}, + {545678922, 0, 0}, + {679771963, 0, 0}, + {899570100, 0, 0}, + {929101967, 0, 0}, + {1100599986, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3056042030, 0, 0}, + {3334207724, 0, 0}, + {3357250579, 0, 0}, + {3705139860, 0, 0}, + {3800912395, 0, 0}, + {3802564010, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 3}, + {0, 10, 13}, + {0, 4, 15}, + {0, 16, 11}, + {0, 17, 1}, + {0, 14, 12}, + {0, 19, 18}, + {0, 21, 20}, + {0, 22, 8}, + {0, 7, 6}, + {0, 23, 9}, + {0, 25, 24}, + {0, 26, 2}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(115, { + {0, 0, 0}, + {57149555, 0, 0}, + {86116519, 0, 0}, + {168339452, 0, 0}, + {181902171, 0, 0}, + {284226441, 0, 0}, + {314809953, 0, 0}, + {330249537, 0, 0}, + {527665290, 0, 0}, + {545363837, 0, 0}, + {707478563, 0, 0}, + {740921498, 0, 0}, + {807276090, 0, 0}, + {824323032, 0, 0}, + {835458563, 0, 0}, + {1162127370, 0, 0}, + {1245448751, 0, 0}, + {1277245109, 0, 0}, + {1375043498, 0, 0}, + {1380991098, 0, 0}, + {1603937321, 0, 0}, + {1708264968, 0, 0}, + {1717555224, 0, 0}, + {1765126703, 0, 0}, + {1838993983, 0, 0}, + {1949856502, 0, 0}, + {2108571893, 0, 0}, + {2110223508, 0, 0}, + {2293637521, 0, 0}, + {2377112119, 0, 0}, + {2378763734, 0, 0}, + {2512398201, 0, 0}, + {2516325050, 0, 0}, + {2645135839, 0, 0}, + {2708915136, 0, 0}, + {2894979602, 0, 0}, + {2903897222, 0, 0}, + {2976581453, 0, 0}, + {3054834317, 0, 0}, + {3075866530, 0, 0}, + {3085157904, 0, 0}, + {3242843022, 0, 0}, + {3266028549, 0, 0}, + {3296691317, 0, 0}, + {3299488628, 0, 0}, + {3322500634, 0, 0}, + {3345707173, 0, 0}, + {3536390697, 0, 
0}, + {3584683259, 0, 0}, + {3647606635, 0, 0}, + {3760372982, 0, 0}, + {3823959661, 0, 0}, + {3839389658, 0, 0}, + {4124281183, 0, 0}, + {4130950286, 0, 0}, + {4169878842, 0, 0}, + {4174489262, 0, 0}, + {4237497041, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 17, 23}, + {0, 37, 8}, + {0, 45, 39}, + {0, 41, 14}, + {0, 48, 43}, + {0, 40, 31}, + {0, 19, 29}, + {0, 53, 26}, + {0, 10, 5}, + {0, 50, 24}, + {0, 27, 3}, + {0, 59, 32}, + {0, 51, 18}, + {0, 52, 55}, + {0, 60, 57}, + {0, 62, 61}, + {0, 36, 33}, + {0, 64, 63}, + {0, 65, 22}, + {0, 66, 46}, + {0, 6, 67}, + {0, 68, 13}, + {0, 21, 44}, + {0, 1, 69}, + {0, 30, 11}, + {0, 71, 70}, + {0, 12, 72}, + {0, 74, 73}, + {0, 76, 75}, + {0, 16, 2}, + {0, 49, 35}, + {0, 77, 9}, + {0, 42, 28}, + {0, 15, 78}, + {0, 80, 79}, + {0, 82, 81}, + {0, 47, 83}, + {0, 85, 84}, + {0, 87, 86}, + {0, 89, 88}, + {0, 20, 38}, + {0, 54, 90}, + {0, 34, 91}, + {0, 93, 92}, + {0, 25, 94}, + {0, 95, 7}, + {0, 97, 96}, + {0, 56, 98}, + {0, 100, 99}, + {0, 102, 101}, + {0, 104, 103}, + {0, 4, 105}, + {0, 107, 106}, + {0, 58, 108}, + {0, 110, 109}, + {0, 112, 111}, + {0, 114, 113}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(81, { + {0, 0, 0}, + {35240468, 0, 0}, + {36096192, 0, 0}, + {123060826, 0, 0}, + {184634770, 0, 0}, + {459968607, 0, 0}, + {619875033, 0, 0}, + {950731750, 0, 0}, + {1058429216, 0, 0}, + {1182296898, 0, 0}, + {1238120570, 0, 0}, + {1271484400, 0, 0}, + {1429389803, 0, 0}, + {1717510093, 0, 0}, + {1766422419, 0, 0}, + {1775308984, 0, 0}, + {1817271123, 0, 0}, + {1917336504, 0, 0}, + {1957265068, 0, 0}, + {1998433745, 0, 0}, + {2055664760, 0, 0}, + {2303184249, 0, 0}, + {2308565678, 0, 0}, + {2451531615, 0, 0}, + {2496297824, 0, 0}, + {2507457870, 0, 0}, + {2550501832, 0, 0}, + {2590402790, 0, 0}, + {2649103430, 0, 0}, + {2831059514, 0, 0}, + {2836440943, 0, 0}, + {3269075805, 0, 0}, + {3361419439, 0, 0}, + {3457269042, 0, 0}, + {3464197236, 0, 
0}, + {3472029049, 0, 0}, + {3518630848, 0, 0}, + {3587381650, 0, 0}, + {3653985133, 0, 0}, + {4185590212, 0, 0}, + {4233562270, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 40, 37}, + {0, 22, 30}, + {0, 2, 7}, + {0, 24, 11}, + {0, 16, 33}, + {0, 6, 34}, + {0, 42, 27}, + {0, 5, 43}, + {0, 4, 44}, + {0, 36, 8}, + {0, 39, 45}, + {0, 46, 1}, + {0, 3, 47}, + {0, 48, 23}, + {0, 49, 9}, + {0, 50, 35}, + {0, 52, 51}, + {0, 32, 53}, + {0, 13, 10}, + {0, 26, 14}, + {0, 19, 54}, + {0, 55, 25}, + {0, 56, 38}, + {0, 17, 57}, + {0, 59, 58}, + {0, 61, 60}, + {0, 62, 29}, + {0, 12, 15}, + {0, 18, 63}, + {0, 28, 64}, + {0, 65, 31}, + {0, 67, 66}, + {0, 20, 41}, + {0, 69, 68}, + {0, 71, 70}, + {0, 21, 72}, + {0, 74, 73}, + {0, 76, 75}, + {0, 78, 77}, + {0, 80, 79}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(61, { + {0, 0, 0}, + {37459569, 0, 0}, + {162167595, 0, 0}, + {535067202, 0, 0}, + {701281393, 0, 0}, + {837715723, 0, 0}, + {1320550031, 0, 0}, + {1630583316, 0, 0}, + {1913735398, 0, 0}, + {1918481917, 0, 0}, + {1955871800, 0, 0}, + {1977038330, 0, 0}, + {2053214130, 0, 0}, + {2443959748, 0, 0}, + {2564745684, 0, 0}, + {2622612602, 0, 0}, + {2677252364, 0, 0}, + {2736026107, 0, 0}, + {2790624748, 0, 0}, + {2882994691, 0, 0}, + {2888125966, 0, 0}, + {2970183398, 0, 0}, + {3253403867, 0, 0}, + {3427283542, 0, 0}, + {3570411982, 0, 0}, + {3619787319, 0, 0}, + {3662767579, 0, 0}, + {3884846406, 0, 0}, + {3910458990, 0, 0}, + {3927915220, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 20}, + {0, 6, 25}, + {0, 23, 3}, + {0, 2, 4}, + {0, 14, 17}, + {0, 11, 8}, + {0, 27, 10}, + {0, 19, 28}, + {0, 12, 16}, + {0, 33, 32}, + {0, 35, 34}, + {0, 37, 36}, + {0, 39, 38}, + {0, 40, 15}, + {0, 41, 7}, + {0, 1, 21}, + {0, 24, 13}, + {0, 29, 42}, + {0, 44, 43}, + {0, 22, 45}, + {0, 47, 46}, + {0, 49, 48}, + {0, 50, 30}, + {0, 31, 51}, + {0, 53, 52}, + {0, 55, 54}, + {0, 56, 9}, + {0, 57, 
26}, + {0, 59, 58}, + {0, 60, 18}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(39, { + {0, 0, 0}, + {744062262, 0, 0}, + {810488476, 0, 0}, + {1040775722, 0, 0}, + {1280126114, 0, 0}, + {1367301635, 0, 0}, + {1684282922, 0, 0}, + {1918481917, 0, 0}, + {1978689945, 0, 0}, + {1980341560, 0, 0}, + {2443959748, 0, 0}, + {2629265310, 0, 0}, + {2790624748, 0, 0}, + {2970183398, 0, 0}, + {3044188332, 0, 0}, + {3496407048, 0, 0}, + {3662767579, 0, 0}, + {3887377256, 0, 0}, + {3971481069, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 2}, + {0, 18, 15}, + {0, 21, 6}, + {0, 13, 11}, + {0, 4, 22}, + {0, 14, 1}, + {0, 24, 23}, + {0, 25, 8}, + {0, 27, 26}, + {0, 20, 17}, + {0, 5, 28}, + {0, 29, 9}, + {0, 16, 10}, + {0, 31, 30}, + {0, 32, 7}, + {0, 19, 33}, + {0, 35, 34}, + {0, 37, 36}, + {0, 38, 12}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(27, { + {0, 0, 0}, + {37459569, 0, 0}, + {837715723, 0, 0}, + {1352628475, 0, 0}, + {1918481917, 0, 0}, + {1978689945, 0, 0}, + {1980341560, 0, 0}, + {2096388952, 0, 0}, + {2622612602, 0, 0}, + {2790624748, 0, 0}, + {2970183398, 0, 0}, + {3510682541, 0, 0}, + {3783543823, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 7, 11}, + {0, 2, 8}, + {0, 15, 12}, + {0, 1, 3}, + {0, 16, 6}, + {0, 18, 17}, + {0, 19, 14}, + {0, 20, 5}, + {0, 10, 21}, + {0, 22, 4}, + {0, 23, 13}, + {0, 25, 24}, + {0, 9, 26}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {1510333659, 0, 0}, + {1684282922, 0, 0}, + {1918481917, 0, 0}, + {2790624748, 0, 0}, + {3662767579, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 1}, + {0, 8, 2}, + {0, 9, 7}, + {0, 3, 10}, + {0, 6, 11}, + {0, 4, 12}, + })); + + 
codecs.emplace(std::pair(SpvOpFunctionCall, 6), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(27, { + {0, 0, 0}, + {161668409, 0, 0}, + {188347929, 0, 0}, + {653708953, 0, 0}, + {976111724, 0, 0}, + {1510333659, 0, 0}, + {1918481917, 0, 0}, + {2790624748, 0, 0}, + {3033873113, 0, 0}, + {3499234137, 0, 0}, + {3525913657, 0, 0}, + {3552593177, 0, 0}, + {3570411982, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 3}, + {0, 2, 9}, + {0, 10, 11}, + {0, 15, 1}, + {0, 17, 16}, + {0, 19, 18}, + {0, 5, 4}, + {0, 20, 6}, + {0, 12, 21}, + {0, 14, 22}, + {0, 24, 23}, + {0, 7, 25}, + {0, 13, 26}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 7), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(31, { + {0, 0, 0}, + {226836633, 0, 0}, + {296981500, 0, 0}, + {718877177, 0, 0}, + {745556697, 0, 0}, + {798915737, 0, 0}, + {1510333659, 0, 0}, + {1684282922, 0, 0}, + {2444465148, 0, 0}, + {2713718873, 0, 0}, + {3495546641, 0, 0}, + {3564402361, 0, 0}, + {4056442905, 0, 0}, + {4083122425, 0, 0}, + {4123141705, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 14, 4}, + {0, 5, 3}, + {0, 9, 8}, + {0, 13, 12}, + {0, 1, 11}, + {0, 18, 17}, + {0, 2, 19}, + {0, 21, 20}, + {0, 23, 22}, + {0, 25, 24}, + {0, 26, 7}, + {0, 27, 16}, + {0, 10, 6}, + {0, 29, 28}, + {0, 15, 30}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 8), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(35, { + {0, 0, 0}, + {161668409, 0, 0}, + {188347929, 0, 0}, + {215027449, 0, 0}, + {296981500, 0, 0}, + {653708953, 0, 0}, + {680388473, 0, 0}, + {1119069977, 0, 0}, + {1510333659, 0, 0}, + {1584774136, 0, 0}, + {2049792025, 0, 0}, + {2444465148, 0, 0}, + {2568512089, 0, 0}, + {3033873113, 0, 0}, + {3499234137, 0, 0}, + {3525913657, 0, 0}, + {3552593177, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 7, 6}, + {0, 10, 12}, + {0, 4, 3}, + {0, 16, 11}, + {0, 19, 14}, + {0, 5, 2}, + {0, 20, 
13}, + {0, 21, 15}, + {0, 1, 22}, + {0, 24, 23}, + {0, 26, 25}, + {0, 28, 27}, + {0, 18, 29}, + {0, 8, 30}, + {0, 32, 31}, + {0, 9, 33}, + {0, 17, 34}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 9), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(25, { + {0, 0, 0}, + {825595257, 0, 0}, + {1064945649, 0, 0}, + {1290956281, 0, 0}, + {1510333659, 0, 0}, + {2096388952, 0, 0}, + {2248357849, 0, 0}, + {2713718873, 0, 0}, + {3187066832, 0, 0}, + {3205759417, 0, 0}, + {4064212479, 0, 0}, + {4163160985, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 3}, + {0, 2, 9}, + {0, 7, 6}, + {0, 5, 14}, + {0, 16, 15}, + {0, 17, 11}, + {0, 19, 18}, + {0, 20, 1}, + {0, 4, 13}, + {0, 22, 21}, + {0, 10, 23}, + {0, 12, 24}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 10), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(27, { + {0, 0, 0}, + {123108003, 0, 0}, + {296981500, 0, 0}, + {595410904, 0, 0}, + {1466938734, 0, 0}, + {1503477720, 0, 0}, + {1816558243, 0, 0}, + {1990431740, 0, 0}, + {2724625059, 0, 0}, + {2790624748, 0, 0}, + {2812498065, 0, 0}, + {3160388974, 0, 0}, + {3745223676, 0, 0}, + {3982311384, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 13}, + {0, 8, 1}, + {0, 12, 11}, + {0, 15, 3}, + {0, 6, 4}, + {0, 16, 7}, + {0, 17, 14}, + {0, 18, 2}, + {0, 19, 10}, + {0, 21, 20}, + {0, 23, 22}, + {0, 25, 24}, + {0, 9, 26}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 11), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(25, { + {0, 0, 0}, + {94145952, 0, 0}, + {1054641568, 0, 0}, + {1269075360, 0, 0}, + {1675922848, 0, 0}, + {2038205856, 0, 0}, + {2433519008, 0, 0}, + {2636942752, 0, 0}, + {2790624748, 0, 0}, + {2840366496, 0, 0}, + {2851900832, 0, 0}, + {2964622752, 0, 0}, + {3654061472, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 7, 1}, + {0, 12, 6}, + {0, 14, 10}, + {0, 13, 4}, + {0, 11, 15}, + {0, 3, 16}, + {0, 2, 17}, + {0, 18, 5}, + {0, 9, 19}, + {0, 21, 
20}, + {0, 23, 22}, + {0, 8, 24}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 12), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(45, { + {0, 0, 0}, + {107544081, 0, 0}, + {125015036, 0, 0}, + {586244865, 0, 0}, + {1033081852, 0, 0}, + {1064945649, 0, 0}, + {1155765244, 0, 0}, + {1304296041, 0, 0}, + {1543646433, 0, 0}, + {1782996825, 0, 0}, + {1941148668, 0, 0}, + {2002490364, 0, 0}, + {2022347217, 0, 0}, + {2063832060, 0, 0}, + {2487708241, 0, 0}, + {2726532092, 0, 0}, + {2849215484, 0, 0}, + {2966409025, 0, 0}, + {3445109809, 0, 0}, + {3458449569, 0, 0}, + {3634598908, 0, 0}, + {3695940604, 0, 0}, + {3923810593, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 7, 2}, + {0, 14, 13}, + {0, 1, 23}, + {0, 6, 5}, + {0, 16, 15}, + {0, 24, 17}, + {0, 12, 25}, + {0, 22, 18}, + {0, 10, 26}, + {0, 28, 27}, + {0, 21, 29}, + {0, 31, 30}, + {0, 9, 8}, + {0, 11, 32}, + {0, 33, 19}, + {0, 3, 34}, + {0, 36, 35}, + {0, 38, 37}, + {0, 20, 39}, + {0, 41, 40}, + {0, 42, 4}, + {0, 44, 43}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 13), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(23, { + {0, 0, 0}, + {247698428, 0, 0}, + {309040124, 0, 0}, + {333554713, 0, 0}, + {572905105, 0, 0}, + {1033081852, 0, 0}, + {2002490364, 0, 0}, + {2009007457, 0, 0}, + {2487708241, 0, 0}, + {3634598908, 0, 0}, + {3695940604, 0, 0}, + {3923810593, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 1}, + {0, 9, 7}, + {0, 5, 12}, + {0, 14, 13}, + {0, 15, 8}, + {0, 3, 16}, + {0, 17, 11}, + {0, 10, 4}, + {0, 2, 18}, + {0, 20, 19}, + {0, 22, 21}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 14), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {247698428, 0, 0}, + {1033081852, 0, 0}, + {2002490364, 0, 0}, + {2910557180, 0, 0}, + {3757282300, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 4}, + {0, 7, 3}, + {0, 2, 8}, + {0, 1, 5}, + {0, 10, 9}, + })); + + 
codecs.emplace(std::pair(SpvOpFunctionCall, 15), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {1033081852, 0, 0}, + {1094423548, 0, 0}, + {2002490364, 0, 0}, + {3757282300, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 6, 2}, + {0, 4, 7}, + {0, 8, 1}, + })); + + codecs.emplace(std::pair(SpvOpFunctionCall, 16), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(57, { + {0, 0, 0}, + {135486769, 0, 0}, + {450406196, 0, 0}, + {503094540, 0, 0}, + {543621065, 0, 0}, + {827698488, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1297706389, 0, 0}, + {1322549027, 0, 0}, + {1784441183, 0, 0}, + {2194691858, 0, 0}, + {2448331885, 0, 0}, + {2468230023, 0, 0}, + {2547657777, 0, 0}, + {2549309392, 0, 0}, + {2550961007, 0, 0}, + {2934934694, 0, 0}, + {2936586309, 0, 0}, + {2938237924, 0, 0}, + {3094180193, 0, 0}, + {3095831808, 0, 0}, + {3183924418, 0, 0}, + {3561562003, 0, 0}, + {3563213618, 0, 0}, + {3564865233, 0, 0}, + {4028622909, 0, 0}, + {4039938779, 0, 0}, + {4050155669, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 27, 28}, + {0, 10, 2}, + {0, 25, 24}, + {0, 1, 12}, + {0, 30, 3}, + {0, 20, 31}, + {0, 9, 32}, + {0, 34, 33}, + {0, 35, 22}, + {0, 26, 15}, + {0, 19, 36}, + {0, 18, 37}, + {0, 38, 16}, + {0, 39, 8}, + {0, 5, 40}, + {0, 6, 41}, + {0, 21, 42}, + {0, 11, 29}, + {0, 4, 43}, + {0, 13, 23}, + {0, 14, 17}, + {0, 7, 44}, + {0, 46, 45}, + {0, 48, 47}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + {0, 56, 55}, + })); + + codecs.emplace(std::pair(SpvOpVariable, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(57, { + {0, 0, 0}, + {37459569, 0, 0}, + {112745085, 0, 0}, + {137840602, 0, 0}, + {565334834, 0, 0}, + {625975427, 0, 0}, + {630964591, 0, 0}, + {680016782, 0, 0}, + {769422756, 0, 0}, + {1009983433, 0, 0}, + {1093210099, 0, 0}, + {1572088444, 0, 0}, + {1584774136, 0, 0}, + {1641565587, 0, 0}, + {1918481917, 0, 0}, + {2190437442, 0, 0}, + {2790624748, 0, 0}, 
+ {3085467405, 0, 0}, + {3181646225, 0, 0}, + {3192069648, 0, 0}, + {3253403867, 0, 0}, + {3390051757, 0, 0}, + {3560665067, 0, 0}, + {3662767579, 0, 0}, + {4053789056, 0, 0}, + {4064212479, 0, 0}, + {4192247221, 0, 0}, + {4224872590, 0, 0}, + {4290024976, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 20}, + {0, 28, 10}, + {0, 13, 8}, + {0, 15, 17}, + {0, 30, 21}, + {0, 19, 31}, + {0, 4, 32}, + {0, 34, 33}, + {0, 35, 5}, + {0, 7, 24}, + {0, 9, 36}, + {0, 3, 37}, + {0, 38, 6}, + {0, 39, 23}, + {0, 27, 40}, + {0, 14, 41}, + {0, 25, 42}, + {0, 1, 29}, + {0, 12, 43}, + {0, 11, 26}, + {0, 18, 22}, + {0, 16, 44}, + {0, 46, 45}, + {0, 48, 47}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + {0, 56, 55}, + })); + + codecs.emplace(std::pair(SpvOpVariable, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(27, { + {0, 0, 0}, + {162255877, 0, 0}, + {679771963, 0, 0}, + {789872778, 0, 0}, + {1154919607, 0, 0}, + {1343794461, 0, 0}, + {1951208733, 0, 0}, + {2263349224, 0, 0}, + {2320303498, 0, 0}, + {2924146124, 0, 0}, + {2984325996, 0, 0}, + {3334207724, 0, 0}, + {3868239231, 0, 0}, + {3869890846, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 5, 3}, + {0, 9, 7}, + {0, 12, 4}, + {0, 16, 15}, + {0, 18, 17}, + {0, 14, 19}, + {0, 13, 10}, + {0, 20, 1}, + {0, 21, 8}, + {0, 2, 22}, + {0, 11, 23}, + {0, 6, 24}, + {0, 26, 25}, + })); + + codecs.emplace(std::pair(SpvOpLoad, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(83, { + {0, 0, 0}, + {169674806, 0, 0}, + {269823086, 0, 0}, + {408465899, 0, 0}, + {451264926, 0, 0}, + {543558236, 0, 0}, + {810488476, 0, 0}, + {850497536, 0, 0}, + {870594305, 0, 0}, + {883854656, 0, 0}, + {1033363654, 0, 0}, + {1069781886, 0, 0}, + {1141965917, 0, 0}, + {1323407757, 0, 0}, + {1570165302, 0, 0}, + {1684282922, 0, 0}, + {1742737136, 0, 0}, + {1901166356, 0, 0}, + {1949759310, 0, 0}, + {2043873558, 0, 0}, + {2087004702, 0, 0}, + {2096388952, 0, 0}, + {2157103435, 0, 0}, + {2219733501, 0, 0}, + 
{2356768706, 0, 0}, + {2443959748, 0, 0}, + {2517964682, 0, 0}, + {2614879967, 0, 0}, + {2622612602, 0, 0}, + {2660843182, 0, 0}, + {2959147533, 0, 0}, + {2970183398, 0, 0}, + {3044188332, 0, 0}, + {3091876332, 0, 0}, + {3187066832, 0, 0}, + {3244209297, 0, 0}, + {3487022798, 0, 0}, + {3496407048, 0, 0}, + {3570411982, 0, 0}, + {3692647551, 0, 0}, + {3713290482, 0, 0}, + {3831290364, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 1}, + {0, 35, 13}, + {0, 25, 11}, + {0, 7, 10}, + {0, 19, 36}, + {0, 43, 27}, + {0, 16, 29}, + {0, 22, 3}, + {0, 41, 30}, + {0, 44, 12}, + {0, 2, 24}, + {0, 40, 32}, + {0, 23, 45}, + {0, 46, 39}, + {0, 17, 33}, + {0, 48, 47}, + {0, 8, 49}, + {0, 51, 50}, + {0, 52, 20}, + {0, 53, 14}, + {0, 31, 54}, + {0, 15, 55}, + {0, 57, 56}, + {0, 59, 58}, + {0, 6, 26}, + {0, 61, 60}, + {0, 34, 62}, + {0, 64, 63}, + {0, 5, 37}, + {0, 9, 65}, + {0, 18, 28}, + {0, 66, 38}, + {0, 68, 67}, + {0, 69, 21}, + {0, 71, 70}, + {0, 73, 72}, + {0, 75, 74}, + {0, 77, 76}, + {0, 79, 78}, + {0, 80, 42}, + {0, 82, 81}, + })); + + codecs.emplace(std::pair(SpvOpLoad, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(83, { + {0, 0, 0}, + {28782128, 0, 0}, + {30433743, 0, 0}, + {37459569, 0, 0}, + {137840602, 0, 0}, + {522971108, 0, 0}, + {565334834, 0, 0}, + {625975427, 0, 0}, + {630964591, 0, 0}, + {680016782, 0, 0}, + {1009983433, 0, 0}, + {1079999262, 0, 0}, + {1395113939, 0, 0}, + {1572088444, 0, 0}, + {1584774136, 0, 0}, + {1649426421, 0, 0}, + {1918481917, 0, 0}, + {1957218950, 0, 0}, + {2311941439, 0, 0}, + {2313593054, 0, 0}, + {2790624748, 0, 0}, + {2838165089, 0, 0}, + {2839816704, 0, 0}, + {2841468319, 0, 0}, + {3085467405, 0, 0}, + {3181646225, 0, 0}, + {3192069648, 0, 0}, + {3253403867, 0, 0}, + {3364388739, 0, 0}, + {3366040354, 0, 0}, + {3367691969, 0, 0}, + {3369343584, 0, 0}, + {3560665067, 0, 0}, + {3662767579, 0, 0}, + {3945795573, 0, 0}, + {4053789056, 0, 0}, + {4064212479, 0, 0}, + {4224872590, 0, 0}, + {4239834800, 0, 0}, + 
{4241486415, 0, 0}, + {4243138030, 0, 0}, + {4244789645, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 27}, + {0, 15, 2}, + {0, 10, 26}, + {0, 7, 24}, + {0, 9, 31}, + {0, 43, 30}, + {0, 29, 12}, + {0, 11, 41}, + {0, 40, 39}, + {0, 44, 23}, + {0, 22, 6}, + {0, 34, 35}, + {0, 18, 45}, + {0, 46, 21}, + {0, 17, 19}, + {0, 48, 47}, + {0, 28, 49}, + {0, 51, 50}, + {0, 52, 38}, + {0, 53, 33}, + {0, 4, 54}, + {0, 13, 55}, + {0, 57, 56}, + {0, 59, 58}, + {0, 37, 8}, + {0, 61, 60}, + {0, 5, 62}, + {0, 64, 63}, + {0, 36, 32}, + {0, 3, 65}, + {0, 14, 16}, + {0, 66, 25}, + {0, 68, 67}, + {0, 69, 20}, + {0, 71, 70}, + {0, 73, 72}, + {0, 75, 74}, + {0, 77, 76}, + {0, 79, 78}, + {0, 80, 42}, + {0, 82, 81}, + })); + + codecs.emplace(std::pair(SpvOpLoad, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(49, { + {0, 0, 0}, + {137840602, 0, 0}, + {522971108, 0, 0}, + {769422756, 0, 0}, + {1009983433, 0, 0}, + {1079999262, 0, 0}, + {1558345254, 0, 0}, + {1572088444, 0, 0}, + {1641565587, 0, 0}, + {1918481917, 0, 0}, + {2311941439, 0, 0}, + {2313593054, 0, 0}, + {2790624748, 0, 0}, + {2838165089, 0, 0}, + {2994529201, 0, 0}, + {2996180816, 0, 0}, + {2997832431, 0, 0}, + {3027538652, 0, 0}, + {3253403867, 0, 0}, + {3364388739, 0, 0}, + {3560665067, 0, 0}, + {3662767579, 0, 0}, + {3945795573, 0, 0}, + {4192247221, 0, 0}, + {4224872590, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 14, 17}, + {0, 16, 15}, + {0, 13, 11}, + {0, 10, 3}, + {0, 22, 18}, + {0, 6, 8}, + {0, 19, 2}, + {0, 27, 26}, + {0, 28, 5}, + {0, 30, 29}, + {0, 32, 31}, + {0, 34, 33}, + {0, 4, 35}, + {0, 37, 36}, + {0, 21, 1}, + {0, 39, 38}, + {0, 40, 24}, + {0, 7, 23}, + {0, 20, 9}, + {0, 42, 41}, + {0, 43, 25}, + {0, 44, 12}, + {0, 46, 45}, + {0, 48, 47}, + })); + + codecs.emplace(std::pair(SpvOpStore, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(59, { + {0, 0, 0}, + {139011596, 0, 0}, + {177111659, 0, 0}, + {296981500, 0, 0}, + {408465899, 0, 0}, + {495107308, 0, 0}, + 
{810488476, 0, 0}, + {870594305, 0, 0}, + {1367301635, 0, 0}, + {1901166356, 0, 0}, + {2055836767, 0, 0}, + {2087004702, 0, 0}, + {2096388952, 0, 0}, + {2204920111, 0, 0}, + {2517964682, 0, 0}, + {2622612602, 0, 0}, + {2660843182, 0, 0}, + {2842919847, 0, 0}, + {2855506940, 0, 0}, + {2959147533, 0, 0}, + {3044188332, 0, 0}, + {3187066832, 0, 0}, + {3504158761, 0, 0}, + {3570411982, 0, 0}, + {3619787319, 0, 0}, + {3653838348, 0, 0}, + {3692647551, 0, 0}, + {3764205609, 0, 0}, + {3831290364, 0, 0}, + {3913885196, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 20, 29}, + {0, 25, 8}, + {0, 5, 1}, + {0, 24, 26}, + {0, 14, 9}, + {0, 27, 16}, + {0, 31, 7}, + {0, 33, 32}, + {0, 17, 34}, + {0, 35, 13}, + {0, 22, 6}, + {0, 3, 2}, + {0, 23, 36}, + {0, 28, 37}, + {0, 19, 4}, + {0, 38, 10}, + {0, 39, 15}, + {0, 40, 18}, + {0, 42, 41}, + {0, 43, 12}, + {0, 44, 21}, + {0, 45, 11}, + {0, 47, 46}, + {0, 49, 48}, + {0, 51, 50}, + {0, 53, 52}, + {0, 55, 54}, + {0, 57, 56}, + {0, 30, 58}, + })); + + codecs.emplace(std::pair(SpvOpStore, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(35, { + {0, 0, 0}, + {440421571, 0, 0}, + {827698488, 0, 0}, + {907126242, 0, 0}, + {908777857, 0, 0}, + {910429472, 0, 0}, + {1294403159, 0, 0}, + {1296054774, 0, 0}, + {1297706389, 0, 0}, + {2080953106, 0, 0}, + {2468230023, 0, 0}, + {2547657777, 0, 0}, + {2549309392, 0, 0}, + {2550961007, 0, 0}, + {3094857332, 0, 0}, + {3561562003, 0, 0}, + {3563213618, 0, 0}, + {3564865233, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 16, 12}, + {0, 17, 13}, + {0, 14, 19}, + {0, 18, 20}, + {0, 5, 21}, + {0, 11, 7}, + {0, 15, 22}, + {0, 9, 8}, + {0, 24, 23}, + {0, 25, 4}, + {0, 27, 26}, + {0, 28, 3}, + {0, 29, 10}, + {0, 6, 1}, + {0, 31, 30}, + {0, 32, 2}, + {0, 34, 33}, + })); + + codecs.emplace(std::pair(SpvOpAccessChain, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(99, { + {0, 0, 0}, + {27130513, 0, 0}, + {28782128, 0, 0}, + {30433743, 0, 0}, + {32085358, 0, 0}, + 
{155458798, 0, 0}, + {157110413, 0, 0}, + {163402553, 0, 0}, + {165054168, 0, 0}, + {213642219, 0, 0}, + {215293834, 0, 0}, + {216945449, 0, 0}, + {221900294, 0, 0}, + {545986953, 0, 0}, + {979993429, 0, 0}, + {1079999262, 0, 0}, + {1302400505, 0, 0}, + {1313182965, 0, 0}, + {1314834580, 0, 0}, + {1315613425, 0, 0}, + {1317265040, 0, 0}, + {1558345254, 0, 0}, + {1649426421, 0, 0}, + {2311941439, 0, 0}, + {2313593054, 0, 0}, + {2602027658, 0, 0}, + {2838165089, 0, 0}, + {2839816704, 0, 0}, + {2841468319, 0, 0}, + {2863084840, 0, 0}, + {2994529201, 0, 0}, + {2996180816, 0, 0}, + {2997832431, 0, 0}, + {3027538652, 0, 0}, + {3187387500, 0, 0}, + {3189039115, 0, 0}, + {3364388739, 0, 0}, + {3366040354, 0, 0}, + {3367691969, 0, 0}, + {3369343584, 0, 0}, + {3716914380, 0, 0}, + {3928842969, 0, 0}, + {3930494584, 0, 0}, + {3932146199, 0, 0}, + {3945482286, 0, 0}, + {4105051793, 0, 0}, + {4239834800, 0, 0}, + {4241486415, 0, 0}, + {4243138030, 0, 0}, + {4244789645, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 29, 10}, + {0, 17, 18}, + {0, 13, 14}, + {0, 44, 25}, + {0, 8, 7}, + {0, 20, 11}, + {0, 33, 19}, + {0, 6, 45}, + {0, 42, 43}, + {0, 40, 5}, + {0, 9, 16}, + {0, 1, 4}, + {0, 35, 34}, + {0, 12, 21}, + {0, 52, 51}, + {0, 31, 30}, + {0, 41, 32}, + {0, 54, 53}, + {0, 55, 2}, + {0, 3, 56}, + {0, 58, 57}, + {0, 60, 59}, + {0, 61, 22}, + {0, 63, 62}, + {0, 65, 64}, + {0, 67, 66}, + {0, 39, 68}, + {0, 38, 69}, + {0, 47, 70}, + {0, 49, 71}, + {0, 28, 48}, + {0, 37, 15}, + {0, 73, 72}, + {0, 74, 27}, + {0, 23, 75}, + {0, 76, 26}, + {0, 24, 77}, + {0, 79, 78}, + {0, 81, 80}, + {0, 82, 46}, + {0, 36, 83}, + {0, 85, 84}, + {0, 87, 86}, + {0, 89, 88}, + {0, 91, 90}, + {0, 93, 92}, + {0, 95, 94}, + {0, 97, 96}, + {0, 50, 98}, + })); + + codecs.emplace(std::pair(SpvOpAccessChain, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(101, { + {0, 0, 0}, + {112745085, 0, 0}, + {116376005, 0, 0}, + {137840602, 0, 0}, + {400248103, 0, 0}, + {406044930, 0, 0}, + 
{468372467, 0, 0}, + {522971108, 0, 0}, + {615341051, 0, 0}, + {625975427, 0, 0}, + {630964591, 0, 0}, + {680016782, 0, 0}, + {763027711, 0, 0}, + {977312655, 0, 0}, + {1009983433, 0, 0}, + {1062250709, 0, 0}, + {1395113939, 0, 0}, + {1410849099, 0, 0}, + {1642805350, 0, 0}, + {1692932387, 0, 0}, + {1698730948, 0, 0}, + {1827244161, 0, 0}, + {1918481917, 0, 0}, + {2096472894, 0, 0}, + {2190437442, 0, 0}, + {2299842241, 0, 0}, + {2433358586, 0, 0}, + {2593325766, 0, 0}, + {2785441472, 0, 0}, + {2790624748, 0, 0}, + {2879917723, 0, 0}, + {2882994691, 0, 0}, + {2902069960, 0, 0}, + {3090408469, 0, 0}, + {3181646225, 0, 0}, + {3255947500, 0, 0}, + {3263901372, 0, 0}, + {3268751013, 0, 0}, + {3347863687, 0, 0}, + {3390051757, 0, 0}, + {3560665067, 0, 0}, + {3617689692, 0, 0}, + {3662767579, 0, 0}, + {3717523241, 0, 0}, + {3854557817, 0, 0}, + {3910458990, 0, 0}, + {3941049054, 0, 0}, + {3945795573, 0, 0}, + {4080527786, 0, 0}, + {4101009465, 0, 0}, + {4290024976, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 32, 44}, + {0, 41, 26}, + {0, 16, 10}, + {0, 27, 45}, + {0, 25, 38}, + {0, 12, 18}, + {0, 6, 35}, + {0, 46, 23}, + {0, 20, 37}, + {0, 52, 19}, + {0, 53, 21}, + {0, 54, 48}, + {0, 33, 55}, + {0, 3, 8}, + {0, 28, 56}, + {0, 13, 57}, + {0, 59, 58}, + {0, 1, 49}, + {0, 47, 60}, + {0, 61, 14}, + {0, 63, 62}, + {0, 64, 43}, + {0, 7, 4}, + {0, 65, 15}, + {0, 67, 66}, + {0, 68, 17}, + {0, 36, 2}, + {0, 30, 69}, + {0, 71, 70}, + {0, 34, 5}, + {0, 73, 72}, + {0, 75, 74}, + {0, 77, 76}, + {0, 24, 78}, + {0, 39, 31}, + {0, 80, 79}, + {0, 9, 11}, + {0, 42, 81}, + {0, 83, 82}, + {0, 29, 50}, + {0, 84, 51}, + {0, 86, 85}, + {0, 22, 40}, + {0, 88, 87}, + {0, 90, 89}, + {0, 92, 91}, + {0, 94, 93}, + {0, 96, 95}, + {0, 98, 97}, + {0, 100, 99}, + })); + + codecs.emplace(std::pair(SpvOpAccessChain, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(69, { + {0, 0, 0}, + {51041423, 0, 0}, + {142465290, 0, 0}, + {144116905, 0, 0}, + {290391815, 0, 0}, + {438318340, 
0, 0}, + {529742207, 0, 0}, + {677668732, 0, 0}, + {917019124, 0, 0}, + {1064945649, 0, 0}, + {1156369516, 0, 0}, + {1158021131, 0, 0}, + {1304296041, 0, 0}, + {1452222566, 0, 0}, + {1543646433, 0, 0}, + {1691572958, 0, 0}, + {1782996825, 0, 0}, + {1784648440, 0, 0}, + {1930923350, 0, 0}, + {2170273742, 0, 0}, + {2318200267, 0, 0}, + {2466126792, 0, 0}, + {2557550659, 0, 0}, + {2705477184, 0, 0}, + {2796901051, 0, 0}, + {2798552666, 0, 0}, + {2944827576, 0, 0}, + {3092754101, 0, 0}, + {3184177968, 0, 0}, + {3332104493, 0, 0}, + {3571454885, 0, 0}, + {3810805277, 0, 0}, + {3958731802, 0, 0}, + {4106658327, 0, 0}, + {4198082194, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 27, 33}, + {0, 21, 5}, + {0, 26, 13}, + {0, 20, 8}, + {0, 15, 7}, + {0, 37, 36}, + {0, 32, 29}, + {0, 38, 4}, + {0, 30, 1}, + {0, 9, 12}, + {0, 39, 18}, + {0, 22, 40}, + {0, 42, 41}, + {0, 44, 43}, + {0, 45, 35}, + {0, 46, 34}, + {0, 6, 14}, + {0, 28, 23}, + {0, 48, 47}, + {0, 49, 31}, + {0, 51, 50}, + {0, 19, 24}, + {0, 52, 10}, + {0, 2, 53}, + {0, 55, 54}, + {0, 25, 56}, + {0, 11, 57}, + {0, 59, 58}, + {0, 3, 17}, + {0, 61, 60}, + {0, 16, 62}, + {0, 64, 63}, + {0, 66, 65}, + {0, 68, 67}, + })); + + codecs.emplace(std::pair(SpvOpAccessChain, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(85, { + {0, 0, 0}, + {142465290, 0, 0}, + {144116905, 0, 0}, + {198967948, 0, 0}, + {290391815, 0, 0}, + {529742207, 0, 0}, + {586244865, 0, 0}, + {677668732, 0, 0}, + {825595257, 0, 0}, + {917019124, 0, 0}, + {973521782, 0, 0}, + {1064945649, 0, 0}, + {1156369516, 0, 0}, + {1158021131, 0, 0}, + {1212872174, 0, 0}, + {1304296041, 0, 0}, + {1452222566, 0, 0}, + {1543646433, 0, 0}, + {1600149091, 0, 0}, + {1782996825, 0, 0}, + {1784648440, 0, 0}, + {1839499483, 0, 0}, + {1930923350, 0, 0}, + {2170273742, 0, 0}, + {2226776400, 0, 0}, + {2318200267, 0, 0}, + {2466126792, 0, 0}, + {2557550659, 0, 0}, + {2614053317, 0, 0}, + {2796901051, 0, 0}, + {2798552666, 0, 0}, + {2853403709, 0, 0}, + 
{2944827576, 0, 0}, + {3184177968, 0, 0}, + {3240680626, 0, 0}, + {3480031018, 0, 0}, + {3571454885, 0, 0}, + {3810805277, 0, 0}, + {3867307935, 0, 0}, + {3958731802, 0, 0}, + {4106658327, 0, 0}, + {4198082194, 0, 0}, + {4254584852, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 7, 11}, + {0, 15, 4}, + {0, 32, 25}, + {0, 44, 39}, + {0, 36, 22}, + {0, 45, 17}, + {0, 24, 46}, + {0, 10, 9}, + {0, 6, 27}, + {0, 28, 18}, + {0, 42, 34}, + {0, 31, 14}, + {0, 41, 38}, + {0, 26, 3}, + {0, 47, 33}, + {0, 21, 8}, + {0, 5, 35}, + {0, 40, 16}, + {0, 37, 23}, + {0, 49, 48}, + {0, 51, 50}, + {0, 53, 52}, + {0, 55, 54}, + {0, 57, 56}, + {0, 59, 58}, + {0, 61, 60}, + {0, 63, 62}, + {0, 65, 64}, + {0, 67, 66}, + {0, 68, 12}, + {0, 29, 69}, + {0, 70, 1}, + {0, 30, 2}, + {0, 43, 71}, + {0, 73, 72}, + {0, 74, 20}, + {0, 75, 19}, + {0, 77, 76}, + {0, 13, 78}, + {0, 80, 79}, + {0, 82, 81}, + {0, 84, 83}, + })); + + codecs.emplace(std::pair(SpvOpAccessChain, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {144116905, 0, 0}, + {1158021131, 0, 0}, + {1784648440, 0, 0}, + {2798552666, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 5}, + {0, 4, 2}, + {0, 6, 3}, + {0, 8, 7}, + })); + + codecs.emplace(std::pair(SpvOpAccessChain, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(5, { + {0, 0, 0}, + {142465290, 0, 0}, + {1782996825, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 1}, + {0, 4, 3}, + })); + + codecs.emplace(std::pair(SpvOpAccessChain, 6), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 2, 5}, + {0, 6, 1}, + })); + + codecs.emplace(std::pair(SpvOpVectorShuffle, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(59, { + {0, 0, 0}, + {177111659, 0, 0}, + {413918748, 0, 0}, + {529383565, 0, 0}, + {646282397, 0, 0}, + {837715723, 0, 0}, + 
{1019457583, 0, 0}, + {1022544883, 0, 0}, + {1054461787, 0, 0}, + {1097775533, 0, 0}, + {1136775085, 0, 0}, + {1191015885, 0, 0}, + {1196280518, 0, 0}, + {1203545131, 0, 0}, + {1352628475, 0, 0}, + {1367301635, 0, 0}, + {1918742169, 0, 0}, + {1922045399, 0, 0}, + {2055836767, 0, 0}, + {2183547611, 0, 0}, + {2204920111, 0, 0}, + {2358141757, 0, 0}, + {2572638469, 0, 0}, + {2597020383, 0, 0}, + {2842919847, 0, 0}, + {3619787319, 0, 0}, + {3701632935, 0, 0}, + {3783543823, 0, 0}, + {4245257809, 0, 0}, + {4265894873, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 11, 23}, + {0, 12, 2}, + {0, 9, 7}, + {0, 21, 19}, + {0, 4, 29}, + {0, 10, 28}, + {0, 17, 16}, + {0, 27, 3}, + {0, 32, 31}, + {0, 33, 22}, + {0, 6, 34}, + {0, 35, 8}, + {0, 36, 24}, + {0, 38, 37}, + {0, 1, 14}, + {0, 39, 20}, + {0, 5, 40}, + {0, 42, 41}, + {0, 43, 26}, + {0, 45, 44}, + {0, 47, 46}, + {0, 48, 18}, + {0, 15, 49}, + {0, 50, 25}, + {0, 51, 13}, + {0, 53, 52}, + {0, 55, 54}, + {0, 57, 56}, + {0, 30, 58}, + })); + + codecs.emplace(std::pair(SpvOpVectorShuffle, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(59, { + {0, 0, 0}, + {236660303, 0, 0}, + {342159236, 0, 0}, + {371428004, 0, 0}, + {373079619, 0, 0}, + {488500848, 0, 0}, + {495107308, 0, 0}, + {864295921, 0, 0}, + {1071164424, 0, 0}, + {1136911283, 0, 0}, + {1178317551, 0, 0}, + {1510422521, 0, 0}, + {1570165302, 0, 0}, + {1822823090, 0, 0}, + {1858116930, 0, 0}, + {1977038330, 0, 0}, + {2096388952, 0, 0}, + {2157103435, 0, 0}, + {2231688008, 0, 0}, + {2604576561, 0, 0}, + {2622612602, 0, 0}, + {2771938750, 0, 0}, + {2777172031, 0, 0}, + {2996594997, 0, 0}, + {3187066832, 0, 0}, + {3496407048, 0, 0}, + {3570411982, 0, 0}, + {3609540589, 0, 0}, + {3713290482, 0, 0}, + {3797761273, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 18, 8}, + {0, 27, 9}, + {0, 21, 10}, + {0, 14, 24}, + {0, 12, 19}, + {0, 11, 15}, + {0, 23, 2}, + {0, 7, 13}, + {0, 31, 22}, + {0, 32, 4}, + {0, 33, 29}, + {0, 34, 1}, + {0, 35, 3}, + {0, 37, 36}, + 
{0, 38, 28}, + {0, 39, 5}, + {0, 41, 40}, + {0, 42, 17}, + {0, 16, 43}, + {0, 45, 44}, + {0, 46, 6}, + {0, 48, 47}, + {0, 50, 49}, + {0, 52, 51}, + {0, 25, 53}, + {0, 54, 20}, + {0, 55, 26}, + {0, 57, 56}, + {0, 30, 58}, + })); + + codecs.emplace(std::pair(SpvOpVectorShuffle, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(47, { + {0, 0, 0}, + {236660303, 0, 0}, + {342159236, 0, 0}, + {488500848, 0, 0}, + {495107308, 0, 0}, + {864295921, 0, 0}, + {1178317551, 0, 0}, + {1510422521, 0, 0}, + {1570165302, 0, 0}, + {1858116930, 0, 0}, + {1977038330, 0, 0}, + {2096388952, 0, 0}, + {2157103435, 0, 0}, + {2231688008, 0, 0}, + {2604576561, 0, 0}, + {2622612602, 0, 0}, + {2771938750, 0, 0}, + {2777172031, 0, 0}, + {2996594997, 0, 0}, + {3496407048, 0, 0}, + {3570411982, 0, 0}, + {3609540589, 0, 0}, + {3713290482, 0, 0}, + {3797761273, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 21, 13}, + {0, 16, 6}, + {0, 14, 9}, + {0, 7, 10}, + {0, 18, 2}, + {0, 17, 5}, + {0, 25, 8}, + {0, 22, 12}, + {0, 26, 23}, + {0, 27, 1}, + {0, 28, 3}, + {0, 30, 29}, + {0, 32, 31}, + {0, 34, 33}, + {0, 35, 11}, + {0, 36, 4}, + {0, 38, 37}, + {0, 40, 39}, + {0, 41, 15}, + {0, 42, 19}, + {0, 20, 43}, + {0, 45, 44}, + {0, 24, 46}, + })); + + codecs.emplace(std::pair(SpvOpVectorShuffle, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(15, { + {0, 0, 0}, + {679771963, 0, 0}, + {1146476634, 0, 0}, + {2160380860, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {3800912395, 0, 0}, + {3802564010, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 3}, + {0, 9, 6}, + {0, 8, 7}, + {0, 11, 10}, + {0, 4, 12}, + {0, 5, 13}, + {0, 14, 1}, + })); + + codecs.emplace(std::pair(SpvOpCompositeConstruct, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(79, { + {0, 0, 0}, + {107497541, 0, 0}, + {289648234, 0, 0}, + {348584153, 0, 0}, + {369686787, 0, 0}, + {429277936, 0, 0}, + {449954059, 0, 0}, + {508217552, 0, 0}, + {742917749, 0, 0}, + 
{1032593647, 0, 0}, + {1158929937, 0, 0}, + {1209418480, 0, 0}, + {1319785741, 0, 0}, + {1321616112, 0, 0}, + {1417363940, 0, 0}, + {1541020250, 0, 0}, + {1564342316, 0, 0}, + {1578775276, 0, 0}, + {1631434666, 0, 0}, + {1636389511, 0, 0}, + {2012838864, 0, 0}, + {2262137600, 0, 0}, + {2281956980, 0, 0}, + {2359973133, 0, 0}, + {2464905186, 0, 0}, + {2613179511, 0, 0}, + {2621255555, 0, 0}, + {2817335337, 0, 0}, + {2881302403, 0, 0}, + {3063300848, 0, 0}, + {3151638847, 0, 0}, + {3233393284, 0, 0}, + {3323682385, 0, 0}, + {3337532056, 0, 0}, + {3456899824, 0, 0}, + {3547456240, 0, 0}, + {3675926744, 0, 0}, + {3753486980, 0, 0}, + {3931641900, 0, 0}, + {3970432934, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 25, 1}, + {0, 6, 4}, + {0, 8, 19}, + {0, 39, 24}, + {0, 3, 2}, + {0, 34, 14}, + {0, 10, 9}, + {0, 18, 38}, + {0, 32, 15}, + {0, 27, 16}, + {0, 28, 35}, + {0, 13, 26}, + {0, 20, 23}, + {0, 21, 11}, + {0, 36, 33}, + {0, 5, 22}, + {0, 42, 41}, + {0, 43, 29}, + {0, 45, 44}, + {0, 7, 46}, + {0, 48, 47}, + {0, 30, 31}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + {0, 55, 17}, + {0, 57, 56}, + {0, 59, 58}, + {0, 61, 60}, + {0, 62, 12}, + {0, 64, 63}, + {0, 66, 65}, + {0, 67, 37}, + {0, 69, 68}, + {0, 71, 70}, + {0, 73, 72}, + {0, 75, 74}, + {0, 77, 76}, + {0, 40, 78}, + })); + + codecs.emplace(std::pair(SpvOpCompositeConstruct, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(87, { + {0, 0, 0}, + {153013225, 0, 0}, + {296836635, 0, 0}, + {296981500, 0, 0}, + {778500192, 0, 0}, + {810488476, 0, 0}, + {848380423, 0, 0}, + {900522183, 0, 0}, + {910398460, 0, 0}, + {959681532, 0, 0}, + {1141965917, 0, 0}, + {1287304304, 0, 0}, + {1323407757, 0, 0}, + {1417363940, 0, 0}, + {1471851763, 0, 0}, + {1526654696, 0, 0}, + {1654776395, 0, 0}, + {1684282922, 0, 0}, + {1739837626, 0, 0}, + {1791352211, 0, 0}, + {2195550588, 0, 0}, + {2319227476, 0, 0}, + {2491124112, 0, 0}, + {2789375411, 0, 0}, + {2807448986, 0, 0}, + {2817579280, 0, 0}, + {2835131395, 
0, 0}, + {2847102741, 0, 0}, + {2855506940, 0, 0}, + {2860348412, 0, 0}, + {3079287749, 0, 0}, + {3091876332, 0, 0}, + {3168953855, 0, 0}, + {3374978006, 0, 0}, + {3399062057, 0, 0}, + {3510257966, 0, 0}, + {3554463148, 0, 0}, + {3579593979, 0, 0}, + {3757851979, 0, 0}, + {3759503594, 0, 0}, + {3761155209, 0, 0}, + {3762806824, 0, 0}, + {3902853271, 0, 0}, + {4140081844, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 38, 42}, + {0, 14, 23}, + {0, 26, 18}, + {0, 39, 35}, + {0, 6, 40}, + {0, 16, 13}, + {0, 33, 34}, + {0, 12, 4}, + {0, 27, 41}, + {0, 25, 21}, + {0, 24, 1}, + {0, 37, 19}, + {0, 32, 22}, + {0, 2, 8}, + {0, 20, 17}, + {0, 43, 36}, + {0, 29, 15}, + {0, 46, 45}, + {0, 48, 47}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + {0, 7, 55}, + {0, 56, 30}, + {0, 57, 5}, + {0, 59, 58}, + {0, 60, 11}, + {0, 9, 61}, + {0, 63, 62}, + {0, 65, 64}, + {0, 66, 31}, + {0, 68, 67}, + {0, 10, 69}, + {0, 71, 70}, + {0, 28, 72}, + {0, 74, 73}, + {0, 76, 75}, + {0, 78, 77}, + {0, 79, 3}, + {0, 81, 80}, + {0, 83, 82}, + {0, 85, 84}, + {0, 44, 86}, + })); + + codecs.emplace(std::pair(SpvOpCompositeConstruct, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(81, { + {0, 0, 0}, + {14244860, 0, 0}, + {150820676, 0, 0}, + {153013225, 0, 0}, + {269823086, 0, 0}, + {289648234, 0, 0}, + {296981500, 0, 0}, + {678695941, 0, 0}, + {810488476, 0, 0}, + {850592577, 0, 0}, + {870594305, 0, 0}, + {910398460, 0, 0}, + {959681532, 0, 0}, + {1206571206, 0, 0}, + {1287304304, 0, 0}, + {1323407757, 0, 0}, + {1471851763, 0, 0}, + {1526654696, 0, 0}, + {1684282922, 0, 0}, + {1734446471, 0, 0}, + {1758530522, 0, 0}, + {2117320444, 0, 0}, + {2118972059, 0, 0}, + {2120623674, 0, 0}, + {2122275289, 0, 0}, + {2219733501, 0, 0}, + {2262321736, 0, 0}, + {2807448986, 0, 0}, + {2817579280, 0, 0}, + {2835131395, 0, 0}, + {2855506940, 0, 0}, + {2860348412, 0, 0}, + {2951272396, 0, 0}, + {3079287749, 0, 0}, + {3168953855, 0, 0}, + {3502816184, 0, 0}, + {3510257966, 0, 0}, + {3554463148, 0, 
0}, + {3997952447, 0, 0}, + {4140081844, 0, 0}, + {4182141402, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 21, 26}, + {0, 29, 16}, + {0, 22, 36}, + {0, 1, 23}, + {0, 20, 5}, + {0, 19, 35}, + {0, 10, 38}, + {0, 13, 24}, + {0, 28, 7}, + {0, 27, 3}, + {0, 40, 2}, + {0, 34, 9}, + {0, 32, 11}, + {0, 33, 18}, + {0, 39, 37}, + {0, 31, 17}, + {0, 43, 42}, + {0, 45, 44}, + {0, 47, 46}, + {0, 49, 48}, + {0, 51, 50}, + {0, 8, 52}, + {0, 15, 53}, + {0, 55, 54}, + {0, 56, 14}, + {0, 58, 57}, + {0, 60, 59}, + {0, 61, 25}, + {0, 63, 62}, + {0, 4, 64}, + {0, 66, 65}, + {0, 68, 67}, + {0, 70, 69}, + {0, 71, 12}, + {0, 6, 72}, + {0, 30, 73}, + {0, 75, 74}, + {0, 77, 76}, + {0, 79, 78}, + {0, 41, 80}, + })); + + codecs.emplace(std::pair(SpvOpCompositeConstruct, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(111, { + {0, 0, 0}, + {34183582, 0, 0}, + {93914936, 0, 0}, + {94303122, 0, 0}, + {117998987, 0, 0}, + {153013225, 0, 0}, + {296981500, 0, 0}, + {451264926, 0, 0}, + {473485679, 0, 0}, + {476788909, 0, 0}, + {478440524, 0, 0}, + {480092139, 0, 0}, + {481743754, 0, 0}, + {810488476, 0, 0}, + {871966503, 0, 0}, + {910398460, 0, 0}, + {918189168, 0, 0}, + {933769938, 0, 0}, + {959681532, 0, 0}, + {1149665466, 0, 0}, + {1166917451, 0, 0}, + {1227221002, 0, 0}, + {1310740861, 0, 0}, + {1323407757, 0, 0}, + {1341516288, 0, 0}, + {1373166395, 0, 0}, + {1445161581, 0, 0}, + {1461645203, 0, 0}, + {1471851763, 0, 0}, + {1526654696, 0, 0}, + {1561718045, 0, 0}, + {1593584949, 0, 0}, + {1684282922, 0, 0}, + {1800404122, 0, 0}, + {1862284649, 0, 0}, + {2213411495, 0, 0}, + {2668680621, 0, 0}, + {2805256437, 0, 0}, + {2807448986, 0, 0}, + {2835131395, 0, 0}, + {2855506940, 0, 0}, + {2860348412, 0, 0}, + {3000904950, 0, 0}, + {3107413701, 0, 0}, + {3168953855, 0, 0}, + {3333131702, 0, 0}, + {3365041621, 0, 0}, + {3456899824, 0, 0}, + {3505028338, 0, 0}, + {3510257966, 0, 0}, + {3554463148, 0, 0}, + {3606320646, 0, 0}, + {3692647551, 0, 0}, + {3861006967, 0, 0}, + 
{4126287524, 0, 0}, + {4140081844, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 14, 33}, + {0, 35, 25}, + {0, 27, 17}, + {0, 8, 20}, + {0, 3, 54}, + {0, 1, 19}, + {0, 10, 46}, + {0, 11, 9}, + {0, 39, 28}, + {0, 53, 49}, + {0, 12, 2}, + {0, 34, 4}, + {0, 47, 36}, + {0, 23, 45}, + {0, 5, 37}, + {0, 24, 38}, + {0, 43, 26}, + {0, 48, 51}, + {0, 44, 32}, + {0, 15, 16}, + {0, 57, 22}, + {0, 55, 50}, + {0, 29, 58}, + {0, 60, 59}, + {0, 41, 61}, + {0, 63, 62}, + {0, 65, 64}, + {0, 67, 66}, + {0, 69, 68}, + {0, 13, 70}, + {0, 71, 7}, + {0, 42, 31}, + {0, 73, 72}, + {0, 75, 74}, + {0, 21, 30}, + {0, 77, 76}, + {0, 79, 78}, + {0, 81, 80}, + {0, 82, 18}, + {0, 84, 83}, + {0, 86, 85}, + {0, 88, 87}, + {0, 90, 89}, + {0, 52, 91}, + {0, 6, 92}, + {0, 94, 93}, + {0, 96, 95}, + {0, 98, 97}, + {0, 99, 40}, + {0, 101, 100}, + {0, 103, 102}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + {0, 56, 110}, + })); + + codecs.emplace(std::pair(SpvOpCompositeConstruct, 4), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(155, { + {0, 0, 0}, + {18776483, 0, 0}, + {37009196, 0, 0}, + {277023757, 0, 0}, + {296981500, 0, 0}, + {348988933, 0, 0}, + {451264926, 0, 0}, + {564884461, 0, 0}, + {804899022, 0, 0}, + {810488476, 0, 0}, + {870594305, 0, 0}, + {876864198, 0, 0}, + {900522183, 0, 0}, + {928261291, 0, 0}, + {959681532, 0, 0}, + {1164724902, 0, 0}, + {1323407757, 0, 0}, + {1332774287, 0, 0}, + {1404739463, 0, 0}, + {1447712361, 0, 0}, + {1450415100, 0, 0}, + {1513770932, 0, 0}, + {1620634991, 0, 0}, + {1692600167, 0, 0}, + {1860649552, 0, 0}, + {1932614728, 0, 0}, + {2087004702, 0, 0}, + {2148510256, 0, 0}, + {2220475432, 0, 0}, + {2388524817, 0, 0}, + {2460489993, 0, 0}, + {2676385521, 0, 0}, + {2748350697, 0, 0}, + {2855506940, 0, 0}, + {2860348412, 0, 0}, + {2916400082, 0, 0}, + {2988365258, 0, 0}, + {3061856840, 0, 0}, + {3063508455, 0, 0}, + {3065160070, 0, 0}, + {3066811685, 0, 0}, + {3068463300, 0, 0}, + {3070114915, 0, 0}, + {3071766530, 0, 0}, + 
{3073418145, 0, 0}, + {3075069760, 0, 0}, + {3076721375, 0, 0}, + {3078372990, 0, 0}, + {3080024605, 0, 0}, + {3081676220, 0, 0}, + {3083327835, 0, 0}, + {3084979450, 0, 0}, + {3086631065, 0, 0}, + {3088282680, 0, 0}, + {3114708520, 0, 0}, + {3116360135, 0, 0}, + {3118011750, 0, 0}, + {3119663365, 0, 0}, + {3121314980, 0, 0}, + {3124618210, 0, 0}, + {3126269825, 0, 0}, + {3127921440, 0, 0}, + {3129573055, 0, 0}, + {3131224670, 0, 0}, + {3132876285, 0, 0}, + {3134527900, 0, 0}, + {3136179515, 0, 0}, + {3204260786, 0, 0}, + {3264086791, 0, 0}, + {3276225962, 0, 0}, + {3444275347, 0, 0}, + {3516240523, 0, 0}, + {3588205699, 0, 0}, + {3732136051, 0, 0}, + {3804101227, 0, 0}, + {3874089391, 0, 0}, + {4044115788, 0, 0}, + {4116080964, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 45, 43}, + {0, 3, 46}, + {0, 71, 36}, + {0, 44, 34}, + {0, 76, 54}, + {0, 73, 55}, + {0, 57, 67}, + {0, 51, 56}, + {0, 31, 27}, + {0, 38, 37}, + {0, 40, 39}, + {0, 42, 41}, + {0, 49, 47}, + {0, 35, 50}, + {0, 21, 70}, + {0, 19, 5}, + {0, 8, 58}, + {0, 17, 11}, + {0, 24, 18}, + {0, 30, 29}, + {0, 52, 9}, + {0, 77, 22}, + {0, 62, 48}, + {0, 25, 53}, + {0, 20, 59}, + {0, 26, 60}, + {0, 72, 6}, + {0, 79, 69}, + {0, 80, 7}, + {0, 81, 2}, + {0, 12, 13}, + {0, 82, 68}, + {0, 65, 61}, + {0, 74, 63}, + {0, 23, 83}, + {0, 64, 10}, + {0, 84, 32}, + {0, 66, 28}, + {0, 15, 85}, + {0, 86, 16}, + {0, 88, 87}, + {0, 90, 89}, + {0, 92, 91}, + {0, 1, 93}, + {0, 95, 94}, + {0, 97, 96}, + {0, 99, 98}, + {0, 100, 75}, + {0, 102, 101}, + {0, 104, 103}, + {0, 106, 105}, + {0, 107, 14}, + {0, 109, 108}, + {0, 111, 110}, + {0, 113, 112}, + {0, 115, 114}, + {0, 117, 116}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 125, 124}, + {0, 127, 126}, + {0, 129, 128}, + {0, 131, 130}, + {0, 133, 132}, + {0, 135, 134}, + {0, 137, 136}, + {0, 139, 138}, + {0, 141, 140}, + {0, 143, 142}, + {0, 145, 144}, + {0, 147, 146}, + {0, 33, 148}, + {0, 4, 149}, + {0, 78, 150}, + {0, 152, 151}, + {0, 154, 153}, + })); + + 
codecs.emplace(std::pair(SpvOpCompositeConstruct, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {679771963, 0, 0}, + {789872778, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 2}, + {0, 4, 7}, + {0, 1, 8}, + {0, 9, 5}, + {0, 3, 10}, + })); + + codecs.emplace(std::pair(SpvOpCompositeExtract, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(49, { + {0, 0, 0}, + {126463145, 0, 0}, + {171307615, 0, 0}, + {342159236, 0, 0}, + {354479447, 0, 0}, + {593829839, 0, 0}, + {743407979, 0, 0}, + {898191441, 0, 0}, + {900522183, 0, 0}, + {1265796414, 0, 0}, + {1287304304, 0, 0}, + {1356063462, 0, 0}, + {1368383673, 0, 0}, + {1526654696, 0, 0}, + {1766994680, 0, 0}, + {1793544760, 0, 0}, + {1811839150, 0, 0}, + {2234361374, 0, 0}, + {2279700640, 0, 0}, + {2383939514, 0, 0}, + {2780898906, 0, 0}, + {2996594997, 0, 0}, + {3413713311, 0, 0}, + {3554463148, 0, 0}, + {3635542517, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 11, 15}, + {0, 20, 14}, + {0, 7, 18}, + {0, 6, 1}, + {0, 12, 10}, + {0, 23, 19}, + {0, 13, 5}, + {0, 24, 17}, + {0, 21, 3}, + {0, 22, 16}, + {0, 26, 2}, + {0, 27, 8}, + {0, 4, 28}, + {0, 29, 9}, + {0, 31, 30}, + {0, 33, 32}, + {0, 35, 34}, + {0, 37, 36}, + {0, 39, 38}, + {0, 41, 40}, + {0, 43, 42}, + {0, 45, 44}, + {0, 47, 46}, + {0, 25, 48}, + })); + + codecs.emplace(std::pair(SpvOpCompositeExtract, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(153, { + {0, 0, 0}, + {13107491, 0, 0}, + {257136089, 0, 0}, + {293528591, 0, 0}, + {321459212, 0, 0}, + {425022309, 0, 0}, + {490769168, 0, 0}, + {495107308, 0, 0}, + {517919178, 0, 0}, + {617312262, 0, 0}, + {708736129, 0, 0}, + {753756604, 0, 0}, + {765238787, 0, 0}, + {796985462, 0, 0}, + {819503463, 0, 0}, + {850497536, 0, 0}, + {948086521, 0, 0}, + {1004589179, 0, 0}, + {1120149824, 0, 0}, + {1165671422, 0, 0}, + {1203545131, 0, 0}, + {1297165140, 0, 
0}, + {1335363438, 0, 0}, + {1351676723, 0, 0}, + {1391866096, 0, 0}, + {1584369690, 0, 0}, + {1631216488, 0, 0}, + {1691646294, 0, 0}, + {1779143013, 0, 0}, + {1858116930, 0, 0}, + {1890300748, 0, 0}, + {1915438939, 0, 0}, + {1918742169, 0, 0}, + {1922045399, 0, 0}, + {1961990747, 0, 0}, + {2037710159, 0, 0}, + {2037814253, 0, 0}, + {2043873558, 0, 0}, + {2096388952, 0, 0}, + {2169307971, 0, 0}, + {2257843797, 0, 0}, + {2262220987, 0, 0}, + {2338272340, 0, 0}, + {2405770322, 0, 0}, + {2498042266, 0, 0}, + {2563789125, 0, 0}, + {2588618056, 0, 0}, + {2645120714, 0, 0}, + {2864863800, 0, 0}, + {2909957084, 0, 0}, + {2975894973, 0, 0}, + {3041450802, 0, 0}, + {3151638847, 0, 0}, + {3187066832, 0, 0}, + {3244716568, 0, 0}, + {3271748023, 0, 0}, + {3304438238, 0, 0}, + {3312467582, 0, 0}, + {3325419312, 0, 0}, + {3370185097, 0, 0}, + {3419674548, 0, 0}, + {3435931956, 0, 0}, + {3504158761, 0, 0}, + {3602522282, 0, 0}, + {3653059026, 0, 0}, + {3716353056, 0, 0}, + {3782099915, 0, 0}, + {3838648480, 0, 0}, + {3847846774, 0, 0}, + {3913593633, 0, 0}, + {3989799199, 0, 0}, + {3997038726, 0, 0}, + {4046301857, 0, 0}, + {4092654294, 0, 0}, + {4176581069, 0, 0}, + {4242327928, 0, 0}, + {4285652249, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 74, 38}, + {0, 12, 56}, + {0, 28, 24}, + {0, 60, 43}, + {0, 65, 72}, + {0, 18, 2}, + {0, 52, 3}, + {0, 19, 10}, + {0, 49, 36}, + {0, 67, 66}, + {0, 41, 17}, + {0, 53, 11}, + {0, 29, 68}, + {0, 26, 55}, + {0, 70, 76}, + {0, 73, 47}, + {0, 51, 22}, + {0, 39, 21}, + {0, 5, 9}, + {0, 40, 48}, + {0, 59, 44}, + {0, 6, 69}, + {0, 32, 31}, + {0, 4, 33}, + {0, 13, 54}, + {0, 14, 50}, + {0, 35, 75}, + {0, 58, 23}, + {0, 16, 34}, + {0, 27, 63}, + {0, 45, 61}, + {0, 20, 46}, + {0, 71, 1}, + {0, 79, 78}, + {0, 81, 80}, + {0, 83, 82}, + {0, 84, 8}, + {0, 86, 85}, + {0, 88, 87}, + {0, 90, 89}, + {0, 92, 91}, + {0, 94, 93}, + {0, 96, 95}, + {0, 98, 97}, + {0, 64, 99}, + {0, 101, 100}, + {0, 103, 102}, + {0, 105, 104}, + {0, 106, 62}, + {0, 108, 107}, + 
{0, 110, 109}, + {0, 7, 111}, + {0, 113, 112}, + {0, 115, 114}, + {0, 117, 116}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 30, 124}, + {0, 126, 125}, + {0, 128, 127}, + {0, 130, 129}, + {0, 132, 131}, + {0, 134, 133}, + {0, 135, 25}, + {0, 57, 136}, + {0, 138, 137}, + {0, 42, 139}, + {0, 37, 140}, + {0, 142, 141}, + {0, 143, 15}, + {0, 145, 144}, + {0, 147, 146}, + {0, 149, 148}, + {0, 151, 150}, + {0, 152, 77}, + })); + + codecs.emplace(std::pair(SpvOpCompositeExtract, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(47, { + {0, 0, 0}, + {545678922, 0, 0}, + {630592085, 0, 0}, + {679771963, 0, 0}, + {899570100, 0, 0}, + {906176560, 0, 0}, + {929101967, 0, 0}, + {1100599986, 0, 0}, + {1103903216, 0, 0}, + {1107206446, 0, 0}, + {1369578001, 0, 0}, + {1372881231, 0, 0}, + {2320303498, 0, 0}, + {2926633629, 0, 0}, + {3249265647, 0, 0}, + {3334207724, 0, 0}, + {3486057732, 0, 0}, + {3674863070, 0, 0}, + {3705139860, 0, 0}, + {3800912395, 0, 0}, + {3802564010, 0, 0}, + {3822983876, 0, 0}, + {4141567741, 0, 0}, + {4292991777, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 9, 17}, + {0, 20, 11}, + {0, 25, 5}, + {0, 2, 14}, + {0, 23, 13}, + {0, 16, 26}, + {0, 27, 24}, + {0, 28, 8}, + {0, 29, 18}, + {0, 22, 30}, + {0, 6, 31}, + {0, 21, 32}, + {0, 3, 33}, + {0, 35, 34}, + {0, 1, 12}, + {0, 10, 36}, + {0, 37, 19}, + {0, 4, 15}, + {0, 39, 38}, + {0, 7, 40}, + {0, 42, 41}, + {0, 44, 43}, + {0, 46, 45}, + })); + + codecs.emplace(std::pair(SpvOpCompositeInsert, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(103, { + {0, 0, 0}, + {125792961, 0, 0}, + {132755933, 0, 0}, + {156014509, 0, 0}, + {436066778, 0, 0}, + {463084678, 0, 0}, + {531559080, 0, 0}, + {565233904, 0, 0}, + {578132535, 0, 0}, + {600906020, 0, 0}, + {602222721, 0, 0}, + {694743357, 0, 0}, + {760554870, 0, 0}, + {996663016, 0, 0}, + {1022309772, 0, 0}, + {1351676723, 0, 0}, + {1496901698, 0, 0}, + {1502470404, 0, 0}, + {1522901980, 0, 0}, + 
{1548254487, 0, 0}, + {1637661947, 0, 0}, + {1788504755, 0, 0}, + {2092468906, 0, 0}, + {2094647776, 0, 0}, + {2127660080, 0, 0}, + {2213946343, 0, 0}, + {2225172640, 0, 0}, + {2259467579, 0, 0}, + {2263866576, 0, 0}, + {2600961503, 0, 0}, + {2727022058, 0, 0}, + {2752967311, 0, 0}, + {2864705739, 0, 0}, + {3021406120, 0, 0}, + {3044723416, 0, 0}, + {3052439312, 0, 0}, + {3136865519, 0, 0}, + {3297860332, 0, 0}, + {3352361837, 0, 0}, + {3670298840, 0, 0}, + {3712946115, 0, 0}, + {3732709413, 0, 0}, + {3764662384, 0, 0}, + {3788324110, 0, 0}, + {3928555688, 0, 0}, + {4083347580, 0, 0}, + {4098876453, 0, 0}, + {4147239510, 0, 0}, + {4199470013, 0, 0}, + {4211577142, 0, 0}, + {4218799564, 0, 0}, + {4290374884, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 2}, + {0, 9, 8}, + {0, 17, 10}, + {0, 20, 18}, + {0, 22, 21}, + {0, 26, 23}, + {0, 31, 29}, + {0, 35, 34}, + {0, 45, 36}, + {0, 5, 3}, + {0, 12, 6}, + {0, 15, 14}, + {0, 25, 19}, + {0, 28, 27}, + {0, 38, 33}, + {0, 43, 39}, + {0, 47, 46}, + {0, 50, 49}, + {0, 7, 51}, + {0, 1, 48}, + {0, 37, 24}, + {0, 44, 42}, + {0, 13, 11}, + {0, 41, 40}, + {0, 54, 53}, + {0, 56, 55}, + {0, 58, 57}, + {0, 60, 59}, + {0, 62, 61}, + {0, 64, 63}, + {0, 66, 65}, + {0, 68, 67}, + {0, 70, 69}, + {0, 72, 71}, + {0, 30, 16}, + {0, 73, 32}, + {0, 75, 74}, + {0, 77, 76}, + {0, 79, 78}, + {0, 81, 80}, + {0, 83, 82}, + {0, 85, 84}, + {0, 87, 86}, + {0, 89, 88}, + {0, 91, 90}, + {0, 93, 92}, + {0, 95, 94}, + {0, 97, 96}, + {0, 99, 98}, + {0, 101, 100}, + {0, 52, 102}, + })); + + codecs.emplace(std::pair(SpvOpCompositeInsert, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(93, { + {0, 0, 0}, + {17185761, 0, 0}, + {117250846, 0, 0}, + {296981500, 0, 0}, + {330388453, 0, 0}, + {346929928, 0, 0}, + {533021259, 0, 0}, + {564302770, 0, 0}, + {680157484, 0, 0}, + {721450866, 0, 0}, + {798549062, 0, 0}, + {853200279, 0, 0}, + {864295921, 0, 0}, + {900522183, 0, 0}, + {973908139, 0, 0}, + {983243705, 0, 0}, + {1033363654, 0, 
0}, + {1037370721, 0, 0}, + {1464587427, 0, 0}, + {1670691893, 0, 0}, + {1686512349, 0, 0}, + {1849065716, 0, 0}, + {1917602962, 0, 0}, + {1965902997, 0, 0}, + {2121980967, 0, 0}, + {2311072371, 0, 0}, + {2339901602, 0, 0}, + {2517964682, 0, 0}, + {2542834724, 0, 0}, + {2558655180, 0, 0}, + {2736881867, 0, 0}, + {2855506940, 0, 0}, + {2888753905, 0, 0}, + {2950446516, 0, 0}, + {3044188332, 0, 0}, + {3079287749, 0, 0}, + {3153451899, 0, 0}, + {3214537066, 0, 0}, + {3234673086, 0, 0}, + {3349230696, 0, 0}, + {3504158761, 0, 0}, + {3570411982, 0, 0}, + {3652695478, 0, 0}, + {3764205609, 0, 0}, + {3940720663, 0, 0}, + {4180570743, 0, 0}, + {4221373527, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 24, 18}, + {0, 4, 2}, + {0, 15, 14}, + {0, 21, 20}, + {0, 29, 26}, + {0, 42, 36}, + {0, 7, 45}, + {0, 37, 9}, + {0, 8, 5}, + {0, 32, 11}, + {0, 39, 38}, + {0, 12, 10}, + {0, 28, 19}, + {0, 1, 46}, + {0, 17, 6}, + {0, 30, 23}, + {0, 44, 33}, + {0, 35, 13}, + {0, 16, 48}, + {0, 50, 49}, + {0, 52, 51}, + {0, 54, 53}, + {0, 55, 40}, + {0, 57, 56}, + {0, 59, 58}, + {0, 61, 60}, + {0, 25, 22}, + {0, 63, 62}, + {0, 3, 64}, + {0, 66, 65}, + {0, 68, 67}, + {0, 70, 69}, + {0, 34, 71}, + {0, 73, 72}, + {0, 75, 74}, + {0, 77, 76}, + {0, 27, 43}, + {0, 79, 78}, + {0, 81, 80}, + {0, 83, 82}, + {0, 84, 31}, + {0, 86, 85}, + {0, 41, 87}, + {0, 89, 88}, + {0, 91, 90}, + {0, 47, 92}, + })); + + codecs.emplace(std::pair(SpvOpCompositeInsert, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(115, { + {0, 0, 0}, + {132755933, 0, 0}, + {156014509, 0, 0}, + {255227811, 0, 0}, + {371186900, 0, 0}, + {371428004, 0, 0}, + {374731234, 0, 0}, + {531559080, 0, 0}, + {565233904, 0, 0}, + {578132535, 0, 0}, + {591140762, 0, 0}, + {600906020, 0, 0}, + {602222721, 0, 0}, + {656610661, 0, 0}, + {760554870, 0, 0}, + {996663016, 0, 0}, + {1022309772, 0, 0}, + {1496901698, 0, 0}, + {1502470404, 0, 0}, + {1522901980, 0, 0}, + {1536350567, 0, 0}, + {1543280290, 0, 0}, + {1548254487, 0, 0}, + 
{1788504755, 0, 0}, + {2064733527, 0, 0}, + {2092468906, 0, 0}, + {2094647776, 0, 0}, + {2162986400, 0, 0}, + {2225172640, 0, 0}, + {2259467579, 0, 0}, + {2263866576, 0, 0}, + {2360004627, 0, 0}, + {2507709226, 0, 0}, + {2600961503, 0, 0}, + {2727022058, 0, 0}, + {2752967311, 0, 0}, + {2864705739, 0, 0}, + {3021406120, 0, 0}, + {3052439312, 0, 0}, + {3136865519, 0, 0}, + {3297860332, 0, 0}, + {3352361837, 0, 0}, + {3598957382, 0, 0}, + {3619787319, 0, 0}, + {3655201337, 0, 0}, + {3670298840, 0, 0}, + {3774892253, 0, 0}, + {3788324110, 0, 0}, + {3808408202, 0, 0}, + {3951925872, 0, 0}, + {3952316364, 0, 0}, + {4098876453, 0, 0}, + {4147239510, 0, 0}, + {4199470013, 0, 0}, + {4211577142, 0, 0}, + {4217306348, 0, 0}, + {4218799564, 0, 0}, + {4290374884, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 43}, + {0, 4, 1}, + {0, 11, 9}, + {0, 13, 12}, + {0, 19, 18}, + {0, 25, 23}, + {0, 28, 26}, + {0, 35, 33}, + {0, 39, 38}, + {0, 2, 49}, + {0, 7, 3}, + {0, 16, 14}, + {0, 29, 22}, + {0, 37, 30}, + {0, 45, 41}, + {0, 51, 47}, + {0, 54, 52}, + {0, 57, 56}, + {0, 53, 8}, + {0, 32, 10}, + {0, 42, 40}, + {0, 24, 46}, + {0, 15, 50}, + {0, 55, 20}, + {0, 59, 44}, + {0, 61, 60}, + {0, 63, 62}, + {0, 65, 64}, + {0, 67, 66}, + {0, 69, 68}, + {0, 71, 70}, + {0, 73, 72}, + {0, 75, 74}, + {0, 77, 76}, + {0, 31, 17}, + {0, 36, 34}, + {0, 79, 78}, + {0, 81, 80}, + {0, 27, 82}, + {0, 5, 21}, + {0, 48, 83}, + {0, 85, 84}, + {0, 87, 86}, + {0, 89, 88}, + {0, 91, 90}, + {0, 93, 92}, + {0, 95, 94}, + {0, 97, 96}, + {0, 99, 98}, + {0, 101, 100}, + {0, 103, 102}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + {0, 111, 110}, + {0, 113, 112}, + {0, 58, 114}, + })); + + codecs.emplace(std::pair(SpvOpCompositeInsert, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {3866587616, 0, 0}, + {3868239231, 0, 0}, + {3869890846, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 4}, + {0, 2, 5}, + {0, 3, 6}, + })); + + 
codecs.emplace(std::pair(SpvOpSampledImage, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {1164218401, 0, 0}, + {2036361232, 0, 0}, + {2637132451, 0, 0}, + {3237903670, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 5}, + {0, 3, 6}, + {0, 1, 7}, + {0, 2, 8}, + })); + + codecs.emplace(std::pair(SpvOpSampledImage, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {543558236, 0, 0}, + {1069781886, 0, 0}, + {1596005536, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 2, 5}, + {0, 1, 6}, + })); + + codecs.emplace(std::pair(SpvOpSampledImage, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1949759310, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpSampledImage, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleImplicitLod, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(87, { + {0, 0, 0}, + {236660303, 0, 0}, + {347505241, 0, 0}, + {426360862, 0, 0}, + {439998433, 0, 0}, + {488500848, 0, 0}, + {495107308, 0, 0}, + {868652905, 0, 0}, + {1191735827, 0, 0}, + {1265998516, 0, 0}, + {1309728002, 0, 0}, + {1365842164, 0, 0}, + {1396344138, 0, 0}, + {1508074873, 0, 0}, + {1553476262, 0, 0}, + {1642818143, 0, 0}, + {1851510470, 0, 0}, + {1858116930, 0, 0}, + {1863199739, 0, 0}, + {1979978194, 0, 0}, + {1986584654, 0, 0}, + {2092100514, 0, 0}, + {2098706974, 0, 0}, + {2231688008, 0, 0}, + {2232491275, 0, 0}, + {2329992200, 0, 0}, + {2637935122, 0, 0}, + {2693892518, 0, 0}, + {2759250216, 0, 0}, + {2839765116, 0, 0}, + {2855895374, 0, 0}, + {2913136690, 0, 0}, + {3012980338, 0, 0}, + {3327770644, 0, 0}, + {3362344229, 0, 0}, + {3398925952, 0, 0}, + {3448018532, 0, 0}, + {3457985288, 0, 0}, + {3566035349, 
0, 0}, + {3657635382, 0, 0}, + {3702405475, 0, 0}, + {3757479030, 0, 0}, + {3797204453, 0, 0}, + {4291477370, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 32, 28}, + {0, 9, 35}, + {0, 31, 11}, + {0, 10, 30}, + {0, 25, 21}, + {0, 40, 2}, + {0, 15, 19}, + {0, 24, 36}, + {0, 42, 4}, + {0, 18, 16}, + {0, 29, 26}, + {0, 43, 7}, + {0, 45, 8}, + {0, 37, 13}, + {0, 47, 46}, + {0, 48, 33}, + {0, 49, 14}, + {0, 3, 22}, + {0, 50, 12}, + {0, 41, 39}, + {0, 51, 34}, + {0, 52, 20}, + {0, 54, 53}, + {0, 56, 55}, + {0, 58, 57}, + {0, 60, 59}, + {0, 61, 23}, + {0, 63, 62}, + {0, 65, 64}, + {0, 27, 66}, + {0, 67, 38}, + {0, 68, 17}, + {0, 70, 69}, + {0, 72, 71}, + {0, 74, 73}, + {0, 76, 75}, + {0, 5, 77}, + {0, 78, 1}, + {0, 80, 79}, + {0, 82, 81}, + {0, 83, 6}, + {0, 85, 84}, + {0, 44, 86}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleImplicitLod, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(15, { + {0, 0, 0}, + {883854656, 0, 0}, + {1962971231, 0, 0}, + {2036361232, 0, 0}, + {2356768706, 0, 0}, + {2637132451, 0, 0}, + {3237903670, 0, 0}, + {3829682756, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 8, 2}, + {0, 6, 9}, + {0, 10, 7}, + {0, 4, 5}, + {0, 12, 11}, + {0, 3, 13}, + {0, 14, 1}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleImplicitLod, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(87, { + {0, 0, 0}, + {150685616, 0, 0}, + {255302575, 0, 0}, + {414620710, 0, 0}, + {557400685, 0, 0}, + {575205902, 0, 0}, + {618761615, 0, 0}, + {646282397, 0, 0}, + {686024761, 0, 0}, + {740921498, 0, 0}, + {921246433, 0, 0}, + {1057578789, 0, 0}, + {1162127370, 0, 0}, + {1329499601, 0, 0}, + {1352628475, 0, 0}, + {1502028603, 0, 0}, + {1519723107, 0, 0}, + {1543798545, 0, 0}, + {1545450160, 0, 0}, + {1570165302, 0, 0}, + {1600392975, 0, 0}, + {1641415225, 0, 0}, + {2204920111, 0, 0}, + {2257971049, 0, 0}, + {2276405827, 0, 0}, + {2339018837, 0, 0}, + {2340670452, 0, 0}, + {2517964682, 0, 0}, + {2532518896, 0, 0}, + 
{2674090849, 0, 0}, + {2754074729, 0, 0}, + {2804281092, 0, 0}, + {2816338013, 0, 0}, + {2841008029, 0, 0}, + {3234673086, 0, 0}, + {3249261197, 0, 0}, + {3619787319, 0, 0}, + {3627739127, 0, 0}, + {3669223677, 0, 0}, + {3787567939, 0, 0}, + {3898287302, 0, 0}, + {4142016703, 0, 0}, + {4237092412, 0, 0}, + {4285779501, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 16, 15}, + {0, 2, 33}, + {0, 41, 35}, + {0, 32, 30}, + {0, 39, 38}, + {0, 5, 1}, + {0, 9, 43}, + {0, 40, 22}, + {0, 29, 12}, + {0, 4, 3}, + {0, 25, 37}, + {0, 34, 26}, + {0, 45, 19}, + {0, 31, 24}, + {0, 47, 46}, + {0, 48, 20}, + {0, 49, 6}, + {0, 8, 21}, + {0, 50, 11}, + {0, 13, 10}, + {0, 51, 42}, + {0, 52, 23}, + {0, 54, 53}, + {0, 56, 55}, + {0, 58, 57}, + {0, 60, 59}, + {0, 61, 28}, + {0, 63, 62}, + {0, 65, 64}, + {0, 17, 66}, + {0, 67, 18}, + {0, 68, 7}, + {0, 70, 69}, + {0, 72, 71}, + {0, 74, 73}, + {0, 76, 75}, + {0, 14, 77}, + {0, 78, 27}, + {0, 80, 79}, + {0, 82, 81}, + {0, 83, 36}, + {0, 85, 84}, + {0, 44, 86}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleImplicitLod, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {2855506940, 0, 0}, + {3266548732, 0, 0}, + {3732640764, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 1}, + {0, 5, 4}, + {0, 3, 6}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleImplicitLod, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleExplicitLod, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(139, { + {0, 0, 0}, + {27177503, 0, 0}, + {30663912, 0, 0}, + {151672195, 0, 0}, + {162608772, 0, 0}, + {180913835, 0, 0}, + {371621315, 0, 0}, + {414444763, 0, 0}, + {421602934, 0, 0}, + {443347828, 0, 0}, + {458937500, 0, 0}, + {587888644, 0, 0}, + {601656217, 0, 0}, + {665789406, 0, 0}, + {712168842, 0, 0}, + {730943059, 0, 0}, + {750870327, 
0, 0}, + {875212982, 0, 0}, + {899320334, 0, 0}, + {973908139, 0, 0}, + {989813600, 0, 0}, + {1057606514, 0, 0}, + {1171541710, 0, 0}, + {1243764146, 0, 0}, + {1310404265, 0, 0}, + {1366337101, 0, 0}, + {1443547269, 0, 0}, + {1472185378, 0, 0}, + {1473799048, 0, 0}, + {1543935193, 0, 0}, + {1572834111, 0, 0}, + {1623013158, 0, 0}, + {1686512349, 0, 0}, + {1705716306, 0, 0}, + {1747355813, 0, 0}, + {1755165354, 0, 0}, + {1781864804, 0, 0}, + {1916983087, 0, 0}, + {1941403425, 0, 0}, + {2023008475, 0, 0}, + {2043684541, 0, 0}, + {2274226560, 0, 0}, + {2285438321, 0, 0}, + {2315690100, 0, 0}, + {2344328209, 0, 0}, + {2414725163, 0, 0}, + {2493146691, 0, 0}, + {2495155989, 0, 0}, + {2558655180, 0, 0}, + {2577859137, 0, 0}, + {2857814560, 0, 0}, + {2895151306, 0, 0}, + {2986830770, 0, 0}, + {3006548167, 0, 0}, + {3127329373, 0, 0}, + {3157581152, 0, 0}, + {3216471040, 0, 0}, + {3296722158, 0, 0}, + {3367298820, 0, 0}, + {3376009661, 0, 0}, + {3450001968, 0, 0}, + {3526837441, 0, 0}, + {3609540589, 0, 0}, + {3743398113, 0, 0}, + {3858973601, 0, 0}, + {3953984401, 0, 0}, + {3999472204, 0, 0}, + {4088613871, 0, 0}, + {4184019303, 0, 0}, + {4258229445, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 31, 16}, + {0, 58, 47}, + {0, 21, 61}, + {0, 6, 14}, + {0, 65, 23}, + {0, 35, 5}, + {0, 2, 7}, + {0, 10, 25}, + {0, 40, 22}, + {0, 9, 50}, + {0, 20, 11}, + {0, 38, 36}, + {0, 13, 12}, + {0, 67, 28}, + {0, 71, 68}, + {0, 73, 72}, + {0, 3, 29}, + {0, 27, 8}, + {0, 44, 37}, + {0, 74, 63}, + {0, 76, 75}, + {0, 18, 1}, + {0, 78, 77}, + {0, 80, 79}, + {0, 82, 81}, + {0, 26, 15}, + {0, 83, 43}, + {0, 85, 84}, + {0, 19, 86}, + {0, 48, 32}, + {0, 33, 46}, + {0, 87, 49}, + {0, 89, 88}, + {0, 91, 90}, + {0, 41, 30}, + {0, 52, 42}, + {0, 64, 55}, + {0, 92, 53}, + {0, 94, 93}, + {0, 51, 39}, + {0, 45, 95}, + {0, 66, 54}, + {0, 97, 96}, + {0, 57, 98}, + {0, 99, 69}, + {0, 101, 100}, + {0, 56, 102}, + {0, 4, 59}, + {0, 34, 17}, + {0, 103, 24}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + 
{0, 60, 110}, + {0, 111, 62}, + {0, 113, 112}, + {0, 115, 114}, + {0, 117, 116}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 125, 124}, + {0, 127, 126}, + {0, 129, 128}, + {0, 70, 130}, + {0, 132, 131}, + {0, 134, 133}, + {0, 136, 135}, + {0, 138, 137}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleExplicitLod, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(11, { + {0, 0, 0}, + {883854656, 0, 0}, + {1962971231, 0, 0}, + {2036361232, 0, 0}, + {2366506734, 0, 0}, + {3829682756, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 2}, + {0, 6, 7}, + {0, 8, 5}, + {0, 3, 9}, + {0, 1, 10}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleExplicitLod, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(73, { + {0, 0, 0}, + {178571546, 0, 0}, + {223310468, 0, 0}, + {388034151, 0, 0}, + {449954059, 0, 0}, + {694743357, 0, 0}, + {797415788, 0, 0}, + {835638766, 0, 0}, + {1002144380, 0, 0}, + {1221183390, 0, 0}, + {1570165302, 0, 0}, + {1663234329, 0, 0}, + {1750829822, 0, 0}, + {1894133125, 0, 0}, + {1967643923, 0, 0}, + {1980341560, 0, 0}, + {2278706468, 0, 0}, + {2326990117, 0, 0}, + {2464905186, 0, 0}, + {2511346984, 0, 0}, + {2517964682, 0, 0}, + {2616085763, 0, 0}, + {2710583246, 0, 0}, + {2745872368, 0, 0}, + {2924263085, 0, 0}, + {3027500544, 0, 0}, + {3044723416, 0, 0}, + {3202324433, 0, 0}, + {3289213933, 0, 0}, + {3323682385, 0, 0}, + {3366848728, 0, 0}, + {3417583519, 0, 0}, + {3732916270, 0, 0}, + {3787909072, 0, 0}, + {3877813395, 0, 0}, + {4028028350, 0, 0}, + {4178218543, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 36, 31}, + {0, 15, 3}, + {0, 17, 1}, + {0, 24, 12}, + {0, 35, 34}, + {0, 28, 27}, + {0, 21, 38}, + {0, 6, 13}, + {0, 14, 7}, + {0, 39, 25}, + {0, 40, 30}, + {0, 42, 41}, + {0, 32, 43}, + {0, 23, 9}, + {0, 11, 44}, + {0, 45, 22}, + {0, 47, 46}, + {0, 2, 16}, + {0, 49, 48}, + {0, 4, 50}, + {0, 51, 18}, + {0, 53, 52}, + {0, 33, 54}, + {0, 26, 55}, + {0, 57, 56}, + {0, 5, 58}, + {0, 59, 
8}, + {0, 19, 60}, + {0, 10, 61}, + {0, 29, 62}, + {0, 37, 63}, + {0, 65, 64}, + {0, 67, 66}, + {0, 20, 68}, + {0, 70, 69}, + {0, 72, 71}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleExplicitLod, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {2855506940, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleExplicitLod, 5), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(13, { + {0, 0, 0}, + {3533637837, 0, 0}, + {3535289452, 0, 0}, + {3536941067, 0, 0}, + {3538592682, 0, 0}, + {3540244297, 0, 0}, + {3541895912, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 7}, + {0, 2, 8}, + {0, 9, 3}, + {0, 4, 10}, + {0, 5, 11}, + {0, 12, 6}, + })); + + codecs.emplace(std::pair(SpvOpImageSampleExplicitLod, 6), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 5}, + {0, 2, 6}, + {0, 1, 3}, + {0, 8, 7}, + })); + + codecs.emplace(std::pair(SpvOpFAdd, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(45, { + {0, 0, 0}, + {328661377, 0, 0}, + {464259778, 0, 0}, + {920941800, 0, 0}, + {969500141, 0, 0}, + {1449907751, 0, 0}, + {1451831482, 0, 0}, + {1543798545, 0, 0}, + {1545450160, 0, 0}, + {1626224034, 0, 0}, + {1669930486, 0, 0}, + {1770165905, 0, 0}, + {2278571792, 0, 0}, + {2432827426, 0, 0}, + {2656211099, 0, 0}, + {2736844435, 0, 0}, + {2870852215, 0, 0}, + {2919626325, 0, 0}, + {2923708820, 0, 0}, + {3325419312, 0, 0}, + {3678875745, 0, 0}, + {4182141402, 0, 0}, + {4241374559, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 2, 6}, + {0, 9, 13}, + {0, 5, 15}, + {0, 4, 11}, + {0, 20, 22}, + {0, 10, 1}, + {0, 18, 14}, + {0, 16, 3}, + {0, 12, 21}, + {0, 8, 7}, + {0, 24, 17}, + {0, 19, 25}, + {0, 27, 26}, + {0, 29, 28}, + {0, 31, 30}, + {0, 33, 32}, + {0, 35, 34}, + {0, 
37, 36}, + {0, 39, 38}, + {0, 41, 40}, + {0, 43, 42}, + {0, 23, 44}, + })); + + codecs.emplace(std::pair(SpvOpFAdd, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(89, { + {0, 0, 0}, + {135920445, 0, 0}, + {176166202, 0, 0}, + {294390719, 0, 0}, + {296981500, 0, 0}, + {743407979, 0, 0}, + {810488476, 0, 0}, + {837715723, 0, 0}, + {885020215, 0, 0}, + {922996215, 0, 0}, + {959681532, 0, 0}, + {963902061, 0, 0}, + {1136775085, 0, 0}, + {1189681639, 0, 0}, + {1203545131, 0, 0}, + {1297294717, 0, 0}, + {1317058015, 0, 0}, + {1352397672, 0, 0}, + {1367301635, 0, 0}, + {1412908157, 0, 0}, + {1570165302, 0, 0}, + {1763758554, 0, 0}, + {1791427568, 0, 0}, + {1992893964, 0, 0}, + {2013867381, 0, 0}, + {2096388952, 0, 0}, + {2219733501, 0, 0}, + {2383939514, 0, 0}, + {2517964682, 0, 0}, + {2555315060, 0, 0}, + {2572638469, 0, 0}, + {2762094724, 0, 0}, + {2770161927, 0, 0}, + {2855506940, 0, 0}, + {3044188332, 0, 0}, + {3187066832, 0, 0}, + {3319278167, 0, 0}, + {3653838348, 0, 0}, + {3675926744, 0, 0}, + {3701632935, 0, 0}, + {3712946115, 0, 0}, + {3732709413, 0, 0}, + {3743748793, 0, 0}, + {3783543823, 0, 0}, + {3930727258, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 15, 12}, + {0, 38, 16}, + {0, 41, 40}, + {0, 1, 33}, + {0, 21, 34}, + {0, 9, 2}, + {0, 24, 7}, + {0, 39, 44}, + {0, 29, 22}, + {0, 17, 19}, + {0, 36, 32}, + {0, 26, 18}, + {0, 30, 3}, + {0, 11, 8}, + {0, 42, 35}, + {0, 46, 31}, + {0, 27, 5}, + {0, 48, 47}, + {0, 28, 49}, + {0, 51, 50}, + {0, 52, 23}, + {0, 54, 53}, + {0, 13, 14}, + {0, 6, 55}, + {0, 57, 56}, + {0, 59, 58}, + {0, 60, 43}, + {0, 62, 61}, + {0, 37, 63}, + {0, 65, 64}, + {0, 67, 66}, + {0, 69, 68}, + {0, 70, 4}, + {0, 10, 71}, + {0, 72, 20}, + {0, 74, 73}, + {0, 76, 75}, + {0, 78, 77}, + {0, 80, 79}, + {0, 81, 25}, + {0, 83, 82}, + {0, 85, 84}, + {0, 87, 86}, + {0, 45, 88}, + })); + + codecs.emplace(std::pair(SpvOpFAdd, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(103, { + {0, 0, 0}, + 
{126463145, 0, 0}, + {220008971, 0, 0}, + {246375791, 0, 0}, + {503145996, 0, 0}, + {628331516, 0, 0}, + {643418617, 0, 0}, + {743407979, 0, 0}, + {837715723, 0, 0}, + {858902117, 0, 0}, + {870594305, 0, 0}, + {939671928, 0, 0}, + {959681532, 0, 0}, + {1051471757, 0, 0}, + {1092948665, 0, 0}, + {1097775533, 0, 0}, + {1136775085, 0, 0}, + {1140367371, 0, 0}, + {1332643570, 0, 0}, + {1367301635, 0, 0}, + {1558001705, 0, 0}, + {1684282922, 0, 0}, + {2096388952, 0, 0}, + {2183547611, 0, 0}, + {2219733501, 0, 0}, + {2358141757, 0, 0}, + {2359973133, 0, 0}, + {2383939514, 0, 0}, + {2444465148, 0, 0}, + {2517964682, 0, 0}, + {2567901801, 0, 0}, + {2598189097, 0, 0}, + {2655147757, 0, 0}, + {2683080096, 0, 0}, + {2705434194, 0, 0}, + {2738307068, 0, 0}, + {2780898906, 0, 0}, + {3030911670, 0, 0}, + {3032677281, 0, 0}, + {3063300848, 0, 0}, + {3277199633, 0, 0}, + {3289969989, 0, 0}, + {3401762422, 0, 0}, + {3436143898, 0, 0}, + {3560552546, 0, 0}, + {3656163446, 0, 0}, + {3675926744, 0, 0}, + {3701632935, 0, 0}, + {3743748793, 0, 0}, + {3752211294, 0, 0}, + {3794803132, 0, 0}, + {4241374559, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 21}, + {0, 17, 11}, + {0, 36, 35}, + {0, 46, 45}, + {0, 50, 49}, + {0, 9, 3}, + {0, 20, 47}, + {0, 37, 31}, + {0, 2, 34}, + {0, 40, 13}, + {0, 51, 32}, + {0, 41, 10}, + {0, 38, 19}, + {0, 18, 44}, + {0, 43, 16}, + {0, 48, 24}, + {0, 26, 5}, + {0, 53, 8}, + {0, 15, 7}, + {0, 25, 23}, + {0, 54, 27}, + {0, 56, 55}, + {0, 58, 57}, + {0, 60, 59}, + {0, 39, 42}, + {0, 62, 61}, + {0, 30, 63}, + {0, 4, 64}, + {0, 65, 28}, + {0, 66, 22}, + {0, 68, 67}, + {0, 69, 14}, + {0, 70, 33}, + {0, 71, 6}, + {0, 73, 72}, + {0, 75, 74}, + {0, 29, 76}, + {0, 78, 77}, + {0, 80, 79}, + {0, 82, 81}, + {0, 84, 83}, + {0, 86, 85}, + {0, 88, 87}, + {0, 90, 89}, + {0, 91, 12}, + {0, 93, 92}, + {0, 95, 94}, + {0, 97, 96}, + {0, 99, 98}, + {0, 101, 100}, + {0, 52, 102}, + })); + + codecs.emplace(std::pair(SpvOpFAdd, 3), std::move(codec)); + } + + { + std::unique_ptr> 
codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 5}, + {0, 4, 6}, + {0, 1, 7}, + {0, 2, 8}, + })); + + codecs.emplace(std::pair(SpvOpFSub, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(159, { + {0, 0, 0}, + {50385656, 0, 0}, + {117250846, 0, 0}, + {171494987, 0, 0}, + {195244192, 0, 0}, + {210754155, 0, 0}, + {265392489, 0, 0}, + {333855951, 0, 0}, + {416853049, 0, 0}, + {529068443, 0, 0}, + {533021259, 0, 0}, + {615982737, 0, 0}, + {660038281, 0, 0}, + {663341511, 0, 0}, + {669812542, 0, 0}, + {716890919, 0, 0}, + {1081536219, 0, 0}, + {1119744229, 0, 0}, + {1123617794, 0, 0}, + {1139547465, 0, 0}, + {1162789888, 0, 0}, + {1178317551, 0, 0}, + {1190147516, 0, 0}, + {1193734351, 0, 0}, + {1215030156, 0, 0}, + {1220749418, 0, 0}, + {1318479490, 0, 0}, + {1461398554, 0, 0}, + {1486207619, 0, 0}, + {1551372768, 0, 0}, + {1763758554, 0, 0}, + {1797960910, 0, 0}, + {1850331254, 0, 0}, + {1894417995, 0, 0}, + {1964254745, 0, 0}, + {1965902997, 0, 0}, + {1989327599, 0, 0}, + {2095027856, 0, 0}, + {2123683379, 0, 0}, + {2124837447, 0, 0}, + {2137526937, 0, 0}, + {2269114589, 0, 0}, + {2269130237, 0, 0}, + {2330636993, 0, 0}, + {2481746922, 0, 0}, + {2503770904, 0, 0}, + {2589449658, 0, 0}, + {2603020391, 0, 0}, + {2604576561, 0, 0}, + {2795773560, 0, 0}, + {2835131395, 0, 0}, + {2852854788, 0, 0}, + {2890638791, 0, 0}, + {2895413148, 0, 0}, + {2950446516, 0, 0}, + {2963744582, 0, 0}, + {3079287749, 0, 0}, + {3088785099, 0, 0}, + {3280064277, 0, 0}, + {3335250889, 0, 0}, + {3510242586, 0, 0}, + {3517169445, 0, 0}, + {3518703473, 0, 0}, + {3536471583, 0, 0}, + {3579593979, 0, 0}, + {3591222197, 0, 0}, + {3673811979, 0, 0}, + {3727034815, 0, 0}, + {3730093054, 0, 0}, + {3898287302, 0, 0}, + {3944781937, 0, 0}, + {3950980241, 0, 0}, + {4033586023, 0, 0}, + {4041974454, 0, 0}, + {4052965752, 0, 0}, + {4083161638, 0, 0}, + 
{4167600590, 0, 0}, + {4185661467, 0, 0}, + {4237092412, 0, 0}, + {4244540017, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 44, 18}, + {0, 69, 57}, + {0, 24, 16}, + {0, 79, 5}, + {0, 59, 4}, + {0, 76, 40}, + {0, 53, 45}, + {0, 14, 2}, + {0, 62, 61}, + {0, 33, 75}, + {0, 38, 37}, + {0, 42, 58}, + {0, 66, 47}, + {0, 63, 67}, + {0, 1, 7}, + {0, 10, 3}, + {0, 13, 12}, + {0, 23, 22}, + {0, 32, 28}, + {0, 36, 35}, + {0, 72, 49}, + {0, 74, 73}, + {0, 77, 55}, + {0, 27, 41}, + {0, 31, 15}, + {0, 6, 54}, + {0, 78, 17}, + {0, 81, 56}, + {0, 83, 82}, + {0, 85, 84}, + {0, 48, 30}, + {0, 71, 60}, + {0, 65, 51}, + {0, 87, 86}, + {0, 50, 34}, + {0, 89, 88}, + {0, 90, 9}, + {0, 25, 8}, + {0, 92, 91}, + {0, 93, 26}, + {0, 95, 94}, + {0, 52, 39}, + {0, 29, 20}, + {0, 97, 96}, + {0, 99, 98}, + {0, 101, 100}, + {0, 64, 102}, + {0, 104, 103}, + {0, 106, 105}, + {0, 21, 107}, + {0, 108, 68}, + {0, 109, 46}, + {0, 110, 11}, + {0, 112, 111}, + {0, 114, 113}, + {0, 116, 115}, + {0, 117, 70}, + {0, 43, 118}, + {0, 120, 119}, + {0, 122, 121}, + {0, 124, 123}, + {0, 126, 125}, + {0, 128, 127}, + {0, 129, 19}, + {0, 131, 130}, + {0, 133, 132}, + {0, 135, 134}, + {0, 137, 136}, + {0, 139, 138}, + {0, 141, 140}, + {0, 143, 142}, + {0, 145, 144}, + {0, 147, 146}, + {0, 149, 148}, + {0, 151, 150}, + {0, 153, 152}, + {0, 155, 154}, + {0, 157, 156}, + {0, 158, 80}, + })); + + codecs.emplace(std::pair(SpvOpFSub, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(103, { + {0, 0, 0}, + {50998433, 0, 0}, + {171494987, 0, 0}, + {249378857, 0, 0}, + {296981500, 0, 0}, + {508007510, 0, 0}, + {610429940, 0, 0}, + {660038281, 0, 0}, + {663341511, 0, 0}, + {836581417, 0, 0}, + {1027242654, 0, 0}, + {1167160774, 0, 0}, + {1191015885, 0, 0}, + {1200870684, 0, 0}, + {1203545131, 0, 0}, + {1265796414, 0, 0}, + {1319785741, 0, 0}, + {1669959736, 0, 0}, + {1684282922, 0, 0}, + {1752686878, 0, 0}, + {1850331254, 0, 0}, + {1901166356, 0, 0}, + {1906988301, 0, 0}, + {2055836767, 0, 0}, + 
{2095027856, 0, 0}, + {2096388952, 0, 0}, + {2144962711, 0, 0}, + {2217833278, 0, 0}, + {2500819054, 0, 0}, + {2525173102, 0, 0}, + {2575525651, 0, 0}, + {2660843182, 0, 0}, + {2855506940, 0, 0}, + {2918750759, 0, 0}, + {2919787747, 0, 0}, + {3091876332, 0, 0}, + {3187066832, 0, 0}, + {3244209297, 0, 0}, + {3423702268, 0, 0}, + {3508792859, 0, 0}, + {3548535223, 0, 0}, + {3619787319, 0, 0}, + {3653838348, 0, 0}, + {3692647551, 0, 0}, + {3713290482, 0, 0}, + {3753486980, 0, 0}, + {3783756895, 0, 0}, + {3797961332, 0, 0}, + {3836822275, 0, 0}, + {4043078107, 0, 0}, + {4052965752, 0, 0}, + {4091394002, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 31, 49}, + {0, 24, 19}, + {0, 46, 45}, + {0, 6, 48}, + {0, 12, 33}, + {0, 17, 21}, + {0, 43, 11}, + {0, 7, 2}, + {0, 9, 8}, + {0, 28, 13}, + {0, 44, 38}, + {0, 30, 50}, + {0, 26, 22}, + {0, 29, 51}, + {0, 34, 37}, + {0, 53, 40}, + {0, 23, 54}, + {0, 55, 25}, + {0, 27, 18}, + {0, 1, 10}, + {0, 57, 56}, + {0, 59, 58}, + {0, 5, 47}, + {0, 60, 20}, + {0, 62, 61}, + {0, 64, 63}, + {0, 66, 65}, + {0, 67, 39}, + {0, 69, 68}, + {0, 16, 70}, + {0, 3, 71}, + {0, 73, 72}, + {0, 41, 15}, + {0, 35, 74}, + {0, 76, 75}, + {0, 78, 77}, + {0, 36, 79}, + {0, 81, 80}, + {0, 83, 82}, + {0, 14, 84}, + {0, 86, 85}, + {0, 88, 87}, + {0, 32, 89}, + {0, 42, 90}, + {0, 92, 91}, + {0, 94, 93}, + {0, 96, 95}, + {0, 98, 97}, + {0, 52, 99}, + {0, 100, 4}, + {0, 102, 101}, + })); + + codecs.emplace(std::pair(SpvOpFSub, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(157, { + {0, 0, 0}, + {49456560, 0, 0}, + {170690025, 0, 0}, + {243178923, 0, 0}, + {295017943, 0, 0}, + {296981500, 0, 0}, + {330249537, 0, 0}, + {435256475, 0, 0}, + {443558693, 0, 0}, + {456043370, 0, 0}, + {470277359, 0, 0}, + {592180731, 0, 0}, + {663258455, 0, 0}, + {706238670, 0, 0}, + {810488476, 0, 0}, + {870594305, 0, 0}, + {877895868, 0, 0}, + {900522183, 0, 0}, + {1077859090, 0, 0}, + {1082941229, 0, 0}, + {1104362365, 0, 0}, + {1132589448, 0, 0}, + 
{1173092699, 0, 0}, + {1203545131, 0, 0}, + {1265796414, 0, 0}, + {1278818058, 0, 0}, + {1285705317, 0, 0}, + {1319785741, 0, 0}, + {1382106590, 0, 0}, + {1461897718, 0, 0}, + {1474506522, 0, 0}, + {1530183840, 0, 0}, + {1558001705, 0, 0}, + {1558990974, 0, 0}, + {1616846013, 0, 0}, + {1633850097, 0, 0}, + {1684282922, 0, 0}, + {1725011064, 0, 0}, + {1767704813, 0, 0}, + {1923453688, 0, 0}, + {1941148668, 0, 0}, + {1955104493, 0, 0}, + {2022961611, 0, 0}, + {2162274327, 0, 0}, + {2212501241, 0, 0}, + {2219733501, 0, 0}, + {2234361374, 0, 0}, + {2272221101, 0, 0}, + {2305269460, 0, 0}, + {2488410748, 0, 0}, + {2566666743, 0, 0}, + {2598189097, 0, 0}, + {2775815164, 0, 0}, + {2793529873, 0, 0}, + {2844616706, 0, 0}, + {2970183398, 0, 0}, + {3103302036, 0, 0}, + {3110479131, 0, 0}, + {3115038057, 0, 0}, + {3116932970, 0, 0}, + {3152745753, 0, 0}, + {3187066832, 0, 0}, + {3244209297, 0, 0}, + {3383007207, 0, 0}, + {3392887901, 0, 0}, + {3508792859, 0, 0}, + {3737376990, 0, 0}, + {3753486980, 0, 0}, + {3765247327, 0, 0}, + {3817149113, 0, 0}, + {3839047923, 0, 0}, + {3886529747, 0, 0}, + {4044928561, 0, 0}, + {4061558677, 0, 0}, + {4069720347, 0, 0}, + {4069810315, 0, 0}, + {4128942283, 0, 0}, + {4164704452, 0, 0}, + {4273793488, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 74, 47}, + {0, 34, 33}, + {0, 36, 14}, + {0, 61, 48}, + {0, 13, 31}, + {0, 39, 25}, + {0, 37, 29}, + {0, 65, 54}, + {0, 4, 73}, + {0, 38, 10}, + {0, 15, 43}, + {0, 6, 35}, + {0, 9, 16}, + {0, 30, 19}, + {0, 49, 44}, + {0, 57, 53}, + {0, 60, 58}, + {0, 72, 66}, + {0, 59, 76}, + {0, 1, 68}, + {0, 70, 42}, + {0, 63, 3}, + {0, 28, 69}, + {0, 17, 55}, + {0, 45, 64}, + {0, 81, 80}, + {0, 7, 82}, + {0, 12, 11}, + {0, 21, 50}, + {0, 83, 18}, + {0, 22, 84}, + {0, 85, 26}, + {0, 20, 86}, + {0, 87, 40}, + {0, 56, 88}, + {0, 90, 89}, + {0, 92, 91}, + {0, 93, 2}, + {0, 95, 94}, + {0, 97, 96}, + {0, 98, 41}, + {0, 100, 99}, + {0, 101, 52}, + {0, 103, 102}, + {0, 77, 71}, + {0, 104, 78}, + {0, 105, 46}, + {0, 32, 
8}, + {0, 106, 51}, + {0, 108, 107}, + {0, 23, 109}, + {0, 110, 27}, + {0, 112, 111}, + {0, 113, 75}, + {0, 115, 114}, + {0, 117, 116}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 124, 62}, + {0, 126, 125}, + {0, 128, 127}, + {0, 67, 129}, + {0, 131, 130}, + {0, 5, 132}, + {0, 134, 133}, + {0, 136, 135}, + {0, 138, 137}, + {0, 139, 24}, + {0, 141, 140}, + {0, 143, 142}, + {0, 145, 144}, + {0, 147, 146}, + {0, 149, 148}, + {0, 151, 150}, + {0, 153, 152}, + {0, 79, 154}, + {0, 156, 155}, + })); + + codecs.emplace(std::pair(SpvOpFSub, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 5}, + {0, 3, 6}, + {0, 1, 7}, + {0, 8, 2}, + })); + + codecs.emplace(std::pair(SpvOpFMul, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(41, { + {0, 0, 0}, + {342197850, 0, 0}, + {885020215, 0, 0}, + {963902061, 0, 0}, + {1041368449, 0, 0}, + {1352397672, 0, 0}, + {1791427568, 0, 0}, + {2013867381, 0, 0}, + {2513230733, 0, 0}, + {2555315060, 0, 0}, + {2562485583, 0, 0}, + {2567901801, 0, 0}, + {2655147757, 0, 0}, + {2680283743, 0, 0}, + {2752766693, 0, 0}, + {2806716850, 0, 0}, + {3030911670, 0, 0}, + {3401762422, 0, 0}, + {3697738938, 0, 0}, + {4164704452, 0, 0}, + {4273793488, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 14, 10}, + {0, 7, 16}, + {0, 1, 15}, + {0, 9, 6}, + {0, 4, 12}, + {0, 18, 5}, + {0, 13, 2}, + {0, 19, 3}, + {0, 17, 20}, + {0, 23, 22}, + {0, 24, 8}, + {0, 26, 25}, + {0, 27, 11}, + {0, 29, 28}, + {0, 31, 30}, + {0, 33, 32}, + {0, 35, 34}, + {0, 37, 36}, + {0, 39, 38}, + {0, 21, 40}, + })); + + codecs.emplace(std::pair(SpvOpFMul, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(129, { + {0, 0, 0}, + {126463145, 0, 0}, + {129135650, 0, 0}, + {200922300, 0, 0}, + {328661377, 0, 0}, + {354479447, 0, 0}, + {360730278, 0, 0}, + {451264926, 0, 
0}, + {529068443, 0, 0}, + {593829839, 0, 0}, + {742917749, 0, 0}, + {761731755, 0, 0}, + {810488476, 0, 0}, + {870594305, 0, 0}, + {894529125, 0, 0}, + {959681532, 0, 0}, + {1054461787, 0, 0}, + {1077859090, 0, 0}, + {1086964761, 0, 0}, + {1158929937, 0, 0}, + {1168927492, 0, 0}, + {1196280518, 0, 0}, + {1203545131, 0, 0}, + {1367301635, 0, 0}, + {1508550646, 0, 0}, + {1618544981, 0, 0}, + {1661163736, 0, 0}, + {1684282922, 0, 0}, + {1766994680, 0, 0}, + {1830851200, 0, 0}, + {1901166356, 0, 0}, + {1955104493, 0, 0}, + {2055836767, 0, 0}, + {2096388952, 0, 0}, + {2100052708, 0, 0}, + {2161102232, 0, 0}, + {2197904616, 0, 0}, + {2262137600, 0, 0}, + {2278571792, 0, 0}, + {2281956980, 0, 0}, + {2438466459, 0, 0}, + {2443959748, 0, 0}, + {2517964682, 0, 0}, + {2557754096, 0, 0}, + {2622612602, 0, 0}, + {2660843182, 0, 0}, + {2736844435, 0, 0}, + {2780898906, 0, 0}, + {3044188332, 0, 0}, + {3059119137, 0, 0}, + {3194725903, 0, 0}, + {3270430997, 0, 0}, + {3337532056, 0, 0}, + {3407526215, 0, 0}, + {3496407048, 0, 0}, + {3504158761, 0, 0}, + {3534518722, 0, 0}, + {3570411982, 0, 0}, + {3701632935, 0, 0}, + {3929248764, 0, 0}, + {3944781937, 0, 0}, + {3970432934, 0, 0}, + {4008405264, 0, 0}, + {4245257809, 0, 0}, + {4253051659, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 6, 26}, + {0, 46, 24}, + {0, 64, 50}, + {0, 7, 17}, + {0, 40, 57}, + {0, 56, 49}, + {0, 34, 10}, + {0, 32, 61}, + {0, 36, 44}, + {0, 8, 43}, + {0, 4, 18}, + {0, 25, 23}, + {0, 9, 54}, + {0, 45, 41}, + {0, 13, 21}, + {0, 47, 31}, + {0, 39, 53}, + {0, 11, 3}, + {0, 29, 20}, + {0, 38, 58}, + {0, 37, 14}, + {0, 66, 52}, + {0, 67, 35}, + {0, 48, 68}, + {0, 1, 69}, + {0, 70, 28}, + {0, 27, 63}, + {0, 72, 71}, + {0, 74, 73}, + {0, 75, 60}, + {0, 77, 76}, + {0, 5, 51}, + {0, 15, 78}, + {0, 30, 79}, + {0, 55, 80}, + {0, 42, 81}, + {0, 83, 82}, + {0, 85, 84}, + {0, 86, 2}, + {0, 19, 16}, + {0, 87, 59}, + {0, 62, 88}, + {0, 90, 89}, + {0, 22, 91}, + {0, 93, 92}, + {0, 95, 94}, + {0, 97, 96}, + {0, 99, 98}, + {0, 
101, 100}, + {0, 12, 102}, + {0, 104, 103}, + {0, 33, 105}, + {0, 107, 106}, + {0, 109, 108}, + {0, 111, 110}, + {0, 113, 112}, + {0, 115, 114}, + {0, 117, 116}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 125, 124}, + {0, 127, 126}, + {0, 65, 128}, + })); + + codecs.emplace(std::pair(SpvOpFMul, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(127, { + {0, 0, 0}, + {13319433, 0, 0}, + {15502752, 0, 0}, + {162608772, 0, 0}, + {171307615, 0, 0}, + {296981500, 0, 0}, + {354479447, 0, 0}, + {413918748, 0, 0}, + {443490822, 0, 0}, + {487719832, 0, 0}, + {593829839, 0, 0}, + {615982737, 0, 0}, + {703543228, 0, 0}, + {810488476, 0, 0}, + {870594305, 0, 0}, + {875212982, 0, 0}, + {959681532, 0, 0}, + {1019457583, 0, 0}, + {1203545131, 0, 0}, + {1278448636, 0, 0}, + {1325348861, 0, 0}, + {1368383673, 0, 0}, + {1400019344, 0, 0}, + {1646147798, 0, 0}, + {1679946323, 0, 0}, + {1684282922, 0, 0}, + {1747355813, 0, 0}, + {1755648697, 0, 0}, + {1793544760, 0, 0}, + {1811839150, 0, 0}, + {1901166356, 0, 0}, + {1947620272, 0, 0}, + {1992893964, 0, 0}, + {2042001863, 0, 0}, + {2096388952, 0, 0}, + {2123388694, 0, 0}, + {2128251367, 0, 0}, + {2130747644, 0, 0}, + {2135340676, 0, 0}, + {2161102232, 0, 0}, + {2443959748, 0, 0}, + {2513230733, 0, 0}, + {2557754096, 0, 0}, + {2580096524, 0, 0}, + {2589449658, 0, 0}, + {2614879967, 0, 0}, + {2698156268, 0, 0}, + {2970183398, 0, 0}, + {3002890475, 0, 0}, + {3133016299, 0, 0}, + {3142155593, 0, 0}, + {3187066832, 0, 0}, + {3266548732, 0, 0}, + {3287039847, 0, 0}, + {3357301402, 0, 0}, + {3413713311, 0, 0}, + {3434076295, 0, 0}, + {3496407048, 0, 0}, + {3504158761, 0, 0}, + {3882634684, 0, 0}, + {3929248764, 0, 0}, + {3987079331, 0, 0}, + {4076840151, 0, 0}, + {4243119782, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 31, 8}, + {0, 14, 56}, + {0, 7, 12}, + {0, 9, 30}, + {0, 42, 36}, + {0, 19, 11}, + {0, 22, 40}, + {0, 15, 3}, + {0, 57, 26}, + {0, 58, 61}, + {0, 55, 51}, + {0, 48, 34}, + {0, 20, 1}, + {0, 
24, 23}, + {0, 46, 35}, + {0, 59, 49}, + {0, 21, 63}, + {0, 62, 44}, + {0, 6, 50}, + {0, 28, 18}, + {0, 66, 65}, + {0, 41, 32}, + {0, 39, 54}, + {0, 53, 67}, + {0, 68, 37}, + {0, 33, 69}, + {0, 43, 70}, + {0, 71, 38}, + {0, 72, 27}, + {0, 13, 47}, + {0, 45, 73}, + {0, 75, 74}, + {0, 76, 5}, + {0, 77, 17}, + {0, 79, 78}, + {0, 52, 80}, + {0, 2, 81}, + {0, 83, 82}, + {0, 85, 84}, + {0, 87, 86}, + {0, 4, 88}, + {0, 16, 29}, + {0, 90, 89}, + {0, 92, 91}, + {0, 94, 93}, + {0, 60, 95}, + {0, 97, 96}, + {0, 98, 10}, + {0, 25, 99}, + {0, 101, 100}, + {0, 103, 102}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + {0, 111, 110}, + {0, 113, 112}, + {0, 115, 114}, + {0, 117, 116}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 125, 124}, + {0, 64, 126}, + })); + + codecs.emplace(std::pair(SpvOpFMul, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(9, { + {0, 0, 0}, + {679771963, 0, 0}, + {1951208733, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 4, 5}, + {0, 3, 6}, + {0, 7, 1}, + {0, 2, 8}, + })); + + codecs.emplace(std::pair(SpvOpFDiv, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(153, { + {0, 0, 0}, + {10142671, 0, 0}, + {27865391, 0, 0}, + {29517006, 0, 0}, + {41739659, 0, 0}, + {97231530, 0, 0}, + {171334650, 0, 0}, + {200553094, 0, 0}, + {257136089, 0, 0}, + {294390719, 0, 0}, + {375530199, 0, 0}, + {380957745, 0, 0}, + {388034151, 0, 0}, + {455591063, 0, 0}, + {462664429, 0, 0}, + {491456522, 0, 0}, + {502863753, 0, 0}, + {626480004, 0, 0}, + {643418617, 0, 0}, + {651464351, 0, 0}, + {701281393, 0, 0}, + {744817486, 0, 0}, + {783918780, 0, 0}, + {862784766, 0, 0}, + {930804377, 0, 0}, + {952536201, 0, 0}, + {955476870, 0, 0}, + {1043738701, 0, 0}, + {1047011733, 0, 0}, + {1080545747, 0, 0}, + {1137442027, 0, 0}, + {1235468610, 0, 0}, + {1412908157, 0, 0}, + {1431749301, 0, 0}, + {1434223270, 0, 0}, + {1440646342, 0, 0}, + {1508570930, 0, 0}, + {1510422521, 0, 
0}, + {1548121999, 0, 0}, + {1582841441, 0, 0}, + {1612225949, 0, 0}, + {1665981878, 0, 0}, + {1680746207, 0, 0}, + {1696076631, 0, 0}, + {1702168830, 0, 0}, + {1761469971, 0, 0}, + {1799299383, 0, 0}, + {1910240213, 0, 0}, + {1917451875, 0, 0}, + {1945006185, 0, 0}, + {1998444837, 0, 0}, + {2045285083, 0, 0}, + {2217966239, 0, 0}, + {2279273489, 0, 0}, + {2289803479, 0, 0}, + {2348676810, 0, 0}, + {2353194283, 0, 0}, + {2403632109, 0, 0}, + {2409539315, 0, 0}, + {2414984922, 0, 0}, + {2477389837, 0, 0}, + {2524531022, 0, 0}, + {2573160348, 0, 0}, + {2639720559, 0, 0}, + {2773229577, 0, 0}, + {2796513469, 0, 0}, + {2881225774, 0, 0}, + {2890570341, 0, 0}, + {2952850186, 0, 0}, + {3023287679, 0, 0}, + {3118548424, 0, 0}, + {3877813395, 0, 0}, + {3931288033, 0, 0}, + {3972309363, 0, 0}, + {4117704995, 0, 0}, + {4140081844, 0, 0}, + {4258414038, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 74, 53}, + {0, 58, 52}, + {0, 65, 60}, + {0, 41, 5}, + {0, 1, 67}, + {0, 24, 28}, + {0, 27, 26}, + {0, 55, 31}, + {0, 36, 61}, + {0, 13, 49}, + {0, 56, 48}, + {0, 16, 64}, + {0, 76, 42}, + {0, 45, 29}, + {0, 23, 6}, + {0, 72, 12}, + {0, 35, 19}, + {0, 20, 7}, + {0, 21, 46}, + {0, 71, 78}, + {0, 80, 79}, + {0, 47, 17}, + {0, 81, 70}, + {0, 34, 25}, + {0, 83, 82}, + {0, 85, 84}, + {0, 37, 86}, + {0, 87, 73}, + {0, 10, 4}, + {0, 40, 30}, + {0, 88, 57}, + {0, 54, 89}, + {0, 50, 90}, + {0, 11, 91}, + {0, 39, 15}, + {0, 59, 44}, + {0, 92, 66}, + {0, 69, 93}, + {0, 95, 94}, + {0, 14, 96}, + {0, 98, 97}, + {0, 62, 51}, + {0, 100, 99}, + {0, 102, 101}, + {0, 104, 103}, + {0, 32, 43}, + {0, 105, 38}, + {0, 107, 106}, + {0, 109, 108}, + {0, 22, 9}, + {0, 33, 110}, + {0, 2, 111}, + {0, 112, 3}, + {0, 114, 113}, + {0, 116, 115}, + {0, 68, 63}, + {0, 118, 117}, + {0, 120, 119}, + {0, 121, 8}, + {0, 123, 122}, + {0, 125, 124}, + {0, 127, 126}, + {0, 129, 128}, + {0, 131, 130}, + {0, 133, 132}, + {0, 75, 18}, + {0, 135, 134}, + {0, 137, 136}, + {0, 139, 138}, + {0, 141, 140}, + {0, 143, 142}, + {0, 
145, 144}, + {0, 147, 146}, + {0, 149, 148}, + {0, 150, 77}, + {0, 152, 151}, + })); + + codecs.emplace(std::pair(SpvOpFDiv, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(131, { + {0, 0, 0}, + {5908395, 0, 0}, + {139011596, 0, 0}, + {296981500, 0, 0}, + {342615870, 0, 0}, + {370232173, 0, 0}, + {492958971, 0, 0}, + {528662843, 0, 0}, + {551924251, 0, 0}, + {604894932, 0, 0}, + {610429940, 0, 0}, + {780957373, 0, 0}, + {810488476, 0, 0}, + {872544165, 0, 0}, + {878733439, 0, 0}, + {918849409, 0, 0}, + {959681532, 0, 0}, + {1013756921, 0, 0}, + {1038982109, 0, 0}, + {1081611718, 0, 0}, + {1125913837, 0, 0}, + {1209418480, 0, 0}, + {1318081294, 0, 0}, + {1367301635, 0, 0}, + {1417425499, 0, 0}, + {1625742020, 0, 0}, + {1684282922, 0, 0}, + {1746004874, 0, 0}, + {1758287856, 0, 0}, + {1777640493, 0, 0}, + {2066323109, 0, 0}, + {2094550054, 0, 0}, + {2096388952, 0, 0}, + {2144962711, 0, 0}, + {2434845539, 0, 0}, + {2480811229, 0, 0}, + {2552825357, 0, 0}, + {2636946065, 0, 0}, + {2651956495, 0, 0}, + {2669086217, 0, 0}, + {2680819379, 0, 0}, + {2709694527, 0, 0}, + {2715304020, 0, 0}, + {2790648021, 0, 0}, + {2802261839, 0, 0}, + {2806296851, 0, 0}, + {2864543087, 0, 0}, + {2952260510, 0, 0}, + {2963184673, 0, 0}, + {3091876332, 0, 0}, + {3098991995, 0, 0}, + {3131890669, 0, 0}, + {3138977758, 0, 0}, + {3198541202, 0, 0}, + {3260579369, 0, 0}, + {3263841912, 0, 0}, + {3335250889, 0, 0}, + {3345856521, 0, 0}, + {3381478137, 0, 0}, + {3489269251, 0, 0}, + {3510242586, 0, 0}, + {3820814597, 0, 0}, + {3900859293, 0, 0}, + {4041974454, 0, 0}, + {4244540017, 0, 0}, + {4265894873, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 15, 52}, + {0, 20, 18}, + {0, 39, 29}, + {0, 9, 43}, + {0, 22, 13}, + {0, 46, 27}, + {0, 51, 48}, + {0, 19, 57}, + {0, 34, 24}, + {0, 64, 59}, + {0, 5, 7}, + {0, 38, 37}, + {0, 45, 47}, + {0, 2, 56}, + {0, 67, 8}, + {0, 17, 68}, + {0, 69, 61}, + {0, 70, 6}, + {0, 55, 54}, + {0, 72, 71}, + {0, 4, 73}, + {0, 74, 40}, + {0, 30, 
11}, + {0, 42, 36}, + {0, 75, 58}, + {0, 31, 76}, + {0, 1, 77}, + {0, 44, 14}, + {0, 78, 50}, + {0, 79, 23}, + {0, 26, 80}, + {0, 81, 12}, + {0, 83, 82}, + {0, 84, 21}, + {0, 32, 85}, + {0, 87, 86}, + {0, 35, 10}, + {0, 88, 62}, + {0, 90, 89}, + {0, 41, 91}, + {0, 92, 53}, + {0, 93, 63}, + {0, 95, 94}, + {0, 33, 96}, + {0, 98, 97}, + {0, 99, 3}, + {0, 100, 28}, + {0, 101, 49}, + {0, 102, 60}, + {0, 104, 103}, + {0, 106, 105}, + {0, 108, 107}, + {0, 110, 109}, + {0, 65, 111}, + {0, 25, 112}, + {0, 114, 113}, + {0, 116, 115}, + {0, 117, 16}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 125, 124}, + {0, 127, 126}, + {0, 128, 66}, + {0, 130, 129}, + })); + + codecs.emplace(std::pair(SpvOpFDiv, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(95, { + {0, 0, 0}, + {116093251, 0, 0}, + {149720480, 0, 0}, + {183103444, 0, 0}, + {251209228, 0, 0}, + {296981500, 0, 0}, + {357505993, 0, 0}, + {394654115, 0, 0}, + {410274915, 0, 0}, + {452208841, 0, 0}, + {788046331, 0, 0}, + {797934924, 0, 0}, + {810488476, 0, 0}, + {1144188012, 0, 0}, + {1220127364, 0, 0}, + {1321616112, 0, 0}, + {1324351672, 0, 0}, + {1348149915, 0, 0}, + {1459457331, 0, 0}, + {1465623797, 0, 0}, + {1531216990, 0, 0}, + {1543672828, 0, 0}, + {1578775276, 0, 0}, + {1738815671, 0, 0}, + {1904128160, 0, 0}, + {2071351379, 0, 0}, + {2119793999, 0, 0}, + {2274779301, 0, 0}, + {2291766425, 0, 0}, + {2357410109, 0, 0}, + {2438466459, 0, 0}, + {2496463830, 0, 0}, + {2630220147, 0, 0}, + {2682510803, 0, 0}, + {3047649911, 0, 0}, + {3085703811, 0, 0}, + {3235459678, 0, 0}, + {3261703164, 0, 0}, + {3331487616, 0, 0}, + {3462674048, 0, 0}, + {3570219049, 0, 0}, + {3585315836, 0, 0}, + {3602108619, 0, 0}, + {3724004880, 0, 0}, + {3931641900, 0, 0}, + {3955205564, 0, 0}, + {4073492988, 0, 0}, + {4127308103, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 24, 37}, + {0, 13, 38}, + {0, 17, 39}, + {0, 35, 23}, + {0, 18, 36}, + {0, 46, 19}, + {0, 20, 33}, + {0, 47, 6}, + {0, 1, 45}, + {0, 3, 
27}, + {0, 8, 49}, + {0, 50, 29}, + {0, 10, 51}, + {0, 43, 31}, + {0, 53, 52}, + {0, 54, 26}, + {0, 7, 55}, + {0, 56, 32}, + {0, 57, 41}, + {0, 59, 58}, + {0, 61, 60}, + {0, 63, 62}, + {0, 64, 25}, + {0, 2, 34}, + {0, 65, 14}, + {0, 67, 66}, + {0, 12, 21}, + {0, 9, 68}, + {0, 69, 16}, + {0, 71, 70}, + {0, 72, 44}, + {0, 11, 73}, + {0, 74, 30}, + {0, 4, 75}, + {0, 28, 15}, + {0, 76, 42}, + {0, 5, 77}, + {0, 78, 40}, + {0, 80, 79}, + {0, 82, 81}, + {0, 22, 83}, + {0, 85, 84}, + {0, 86, 48}, + {0, 88, 87}, + {0, 90, 89}, + {0, 92, 91}, + {0, 94, 93}, + })); + + codecs.emplace(std::pair(SpvOpFDiv, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(7, { + {0, 0, 0}, + {679771963, 0, 0}, + {2320303498, 0, 0}, + {3334207724, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 3, 4}, + {0, 2, 5}, + {0, 1, 6}, + })); + + codecs.emplace(std::pair(SpvOpVectorTimesScalar, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(121, { + {0, 0, 0}, + {14113753, 0, 0}, + {102358168, 0, 0}, + {179458548, 0, 0}, + {330388453, 0, 0}, + {386525753, 0, 0}, + {470277359, 0, 0}, + {497658126, 0, 0}, + {508007510, 0, 0}, + {815034111, 0, 0}, + {826214242, 0, 0}, + {849867303, 0, 0}, + {885645401, 0, 0}, + {939415664, 0, 0}, + {968885186, 0, 0}, + {1105835505, 0, 0}, + {1159301677, 0, 0}, + {1461897718, 0, 0}, + {1482251215, 0, 0}, + {1486206763, 0, 0}, + {1527762373, 0, 0}, + {1558990974, 0, 0}, + {1618754372, 0, 0}, + {1669959736, 0, 0}, + {1752686878, 0, 0}, + {2004567202, 0, 0}, + {2055637638, 0, 0}, + {2113506324, 0, 0}, + {2154320787, 0, 0}, + {2162274327, 0, 0}, + {2306141594, 0, 0}, + {2345566651, 0, 0}, + {2457690657, 0, 0}, + {2473053808, 0, 0}, + {2500422644, 0, 0}, + {2504802016, 0, 0}, + {2506771164, 0, 0}, + {2793529873, 0, 0}, + {2801333547, 0, 0}, + {2879050471, 0, 0}, + {3032677281, 0, 0}, + {3045470312, 0, 0}, + {3181546731, 0, 0}, + {3240977890, 0, 0}, + {3262572726, 0, 0}, + {3307100165, 0, 0}, + {3425841570, 0, 0}, + {3560552546, 0, 
0}, + {3641833815, 0, 0}, + {3652695478, 0, 0}, + {3782362128, 0, 0}, + {3797961332, 0, 0}, + {3837583704, 0, 0}, + {3886529747, 0, 0}, + {3907920335, 0, 0}, + {4043078107, 0, 0}, + {4044928561, 0, 0}, + {4069720347, 0, 0}, + {4180570743, 0, 0}, + {4245743275, 0, 0}, + {4285201458, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 44, 28}, + {0, 13, 45}, + {0, 19, 15}, + {0, 32, 31}, + {0, 43, 42}, + {0, 16, 52}, + {0, 33, 22}, + {0, 57, 55}, + {0, 24, 21}, + {0, 2, 59}, + {0, 10, 3}, + {0, 18, 12}, + {0, 41, 39}, + {0, 60, 46}, + {0, 4, 25}, + {0, 58, 49}, + {0, 14, 1}, + {0, 27, 17}, + {0, 50, 36}, + {0, 23, 54}, + {0, 5, 30}, + {0, 11, 7}, + {0, 38, 29}, + {0, 37, 8}, + {0, 48, 56}, + {0, 20, 6}, + {0, 34, 26}, + {0, 63, 62}, + {0, 65, 64}, + {0, 67, 66}, + {0, 69, 68}, + {0, 71, 70}, + {0, 73, 72}, + {0, 75, 74}, + {0, 9, 76}, + {0, 78, 77}, + {0, 80, 79}, + {0, 82, 81}, + {0, 84, 83}, + {0, 40, 35}, + {0, 85, 47}, + {0, 86, 51}, + {0, 88, 87}, + {0, 90, 89}, + {0, 53, 91}, + {0, 93, 92}, + {0, 95, 94}, + {0, 97, 96}, + {0, 99, 98}, + {0, 101, 100}, + {0, 103, 102}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + {0, 111, 110}, + {0, 113, 112}, + {0, 115, 114}, + {0, 117, 116}, + {0, 119, 118}, + {0, 61, 120}, + })); + + codecs.emplace(std::pair(SpvOpVectorTimesScalar, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(127, { + {0, 0, 0}, + {100979271, 0, 0}, + {269576093, 0, 0}, + {314809953, 0, 0}, + {354479447, 0, 0}, + {497658126, 0, 0}, + {882718761, 0, 0}, + {968885186, 0, 0}, + {973908139, 0, 0}, + {1019457583, 0, 0}, + {1191015885, 0, 0}, + {1266262705, 0, 0}, + {1310404265, 0, 0}, + {1325348861, 0, 0}, + {1367301635, 0, 0}, + {1368383673, 0, 0}, + {1570165302, 0, 0}, + {1618544981, 0, 0}, + {1646147798, 0, 0}, + {1674464100, 0, 0}, + {1679946323, 0, 0}, + {1686512349, 0, 0}, + {1766401548, 0, 0}, + {1774052499, 0, 0}, + {1788301425, 0, 0}, + {2023008475, 0, 0}, + {2055836767, 0, 0}, + {2096388952, 0, 0}, + {2123388694, 0, 0}, + 
{2129301998, 0, 0}, + {2212501241, 0, 0}, + {2274226560, 0, 0}, + {2362972044, 0, 0}, + {2378763734, 0, 0}, + {2506771164, 0, 0}, + {2558655180, 0, 0}, + {2622612602, 0, 0}, + {2660843182, 0, 0}, + {2698156268, 0, 0}, + {2801333547, 0, 0}, + {2850246066, 0, 0}, + {2895151306, 0, 0}, + {2970183398, 0, 0}, + {2986830770, 0, 0}, + {3001444829, 0, 0}, + {3133016299, 0, 0}, + {3152745753, 0, 0}, + {3187066832, 0, 0}, + {3261122899, 0, 0}, + {3496407048, 0, 0}, + {3513669836, 0, 0}, + {3536390697, 0, 0}, + {3570411982, 0, 0}, + {3653838348, 0, 0}, + {3713290482, 0, 0}, + {3858973601, 0, 0}, + {3873587660, 0, 0}, + {3877583949, 0, 0}, + {3882634684, 0, 0}, + {3907920335, 0, 0}, + {3997432565, 0, 0}, + {4169226615, 0, 0}, + {4219766939, 0, 0}, + {4243119782, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 25, 12}, + {0, 41, 29}, + {0, 56, 44}, + {0, 1, 3}, + {0, 48, 24}, + {0, 33, 60}, + {0, 8, 50}, + {0, 35, 21}, + {0, 11, 7}, + {0, 34, 23}, + {0, 59, 57}, + {0, 10, 62}, + {0, 40, 2}, + {0, 5, 49}, + {0, 39, 17}, + {0, 9, 61}, + {0, 30, 6}, + {0, 19, 46}, + {0, 53, 54}, + {0, 31, 52}, + {0, 55, 43}, + {0, 66, 65}, + {0, 16, 67}, + {0, 51, 68}, + {0, 70, 69}, + {0, 26, 36}, + {0, 72, 71}, + {0, 74, 73}, + {0, 76, 75}, + {0, 78, 77}, + {0, 80, 79}, + {0, 82, 81}, + {0, 37, 83}, + {0, 85, 84}, + {0, 13, 86}, + {0, 20, 18}, + {0, 38, 28}, + {0, 58, 45}, + {0, 87, 63}, + {0, 15, 88}, + {0, 32, 22}, + {0, 89, 4}, + {0, 90, 14}, + {0, 91, 42}, + {0, 93, 92}, + {0, 95, 94}, + {0, 97, 96}, + {0, 99, 98}, + {0, 101, 100}, + {0, 103, 102}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + {0, 111, 110}, + {0, 113, 112}, + {0, 115, 114}, + {0, 27, 47}, + {0, 117, 116}, + {0, 119, 118}, + {0, 121, 120}, + {0, 123, 122}, + {0, 125, 124}, + {0, 126, 64}, + })); + + codecs.emplace(std::pair(SpvOpVectorTimesScalar, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(137, { + {0, 0, 0}, + {11698369, 0, 0}, + {146392076, 0, 0}, + {151810803, 0, 0}, + {223800276, 0, 0}, 
+ {227103506, 0, 0}, + {253329281, 0, 0}, + {346929928, 0, 0}, + {461040879, 0, 0}, + {629859130, 0, 0}, + {680157484, 0, 0}, + {783918780, 0, 0}, + {810488476, 0, 0}, + {824323032, 0, 0}, + {870594305, 0, 0}, + {959681532, 0, 0}, + {975807626, 0, 0}, + {1081642571, 0, 0}, + {1084574846, 0, 0}, + {1094817798, 0, 0}, + {1141965917, 0, 0}, + {1164137269, 0, 0}, + {1166917451, 0, 0}, + {1204787336, 0, 0}, + {1232501371, 0, 0}, + {1318479490, 0, 0}, + {1369818198, 0, 0}, + {1372785527, 0, 0}, + {1526654696, 0, 0}, + {1543672828, 0, 0}, + {1548121999, 0, 0}, + {1635292159, 0, 0}, + {1641070431, 0, 0}, + {1684282922, 0, 0}, + {1767704813, 0, 0}, + {1781765116, 0, 0}, + {1838763297, 0, 0}, + {1901166356, 0, 0}, + {1904846533, 0, 0}, + {2011183308, 0, 0}, + {2032069771, 0, 0}, + {2071351379, 0, 0}, + {2087004702, 0, 0}, + {2244928358, 0, 0}, + {2314864456, 0, 0}, + {2374216296, 0, 0}, + {2394332122, 0, 0}, + {2443610186, 0, 0}, + {2524697596, 0, 0}, + {2526961521, 0, 0}, + {2568098594, 0, 0}, + {2807907995, 0, 0}, + {3103302036, 0, 0}, + {3117071189, 0, 0}, + {3188115516, 0, 0}, + {3417584874, 0, 0}, + {3554463148, 0, 0}, + {3561482820, 0, 0}, + {3691770462, 0, 0}, + {3729929345, 0, 0}, + {3733675151, 0, 0}, + {3831290364, 0, 0}, + {3866493821, 0, 0}, + {3929248764, 0, 0}, + {4060703604, 0, 0}, + {4092487128, 0, 0}, + {4167600590, 0, 0}, + {4214779116, 0, 0}, + {4248015868, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 36, 13}, + {0, 49, 60}, + {0, 51, 9}, + {0, 3, 62}, + {0, 67, 41}, + {0, 4, 31}, + {0, 66, 5}, + {0, 55, 32}, + {0, 2, 1}, + {0, 30, 16}, + {0, 7, 38}, + {0, 19, 10}, + {0, 34, 20}, + {0, 45, 46}, + {0, 22, 11}, + {0, 25, 23}, + {0, 40, 39}, + {0, 21, 57}, + {0, 6, 35}, + {0, 61, 8}, + {0, 52, 26}, + {0, 70, 59}, + {0, 71, 14}, + {0, 68, 47}, + {0, 73, 72}, + {0, 29, 74}, + {0, 76, 75}, + {0, 77, 17}, + {0, 79, 78}, + {0, 81, 80}, + {0, 82, 18}, + {0, 83, 42}, + {0, 85, 84}, + {0, 87, 86}, + {0, 27, 37}, + {0, 53, 43}, + {0, 89, 88}, + {0, 64, 54}, + {0, 90, 
65}, + {0, 92, 91}, + {0, 58, 93}, + {0, 56, 48}, + {0, 94, 28}, + {0, 96, 95}, + {0, 98, 97}, + {0, 44, 99}, + {0, 101, 100}, + {0, 15, 12}, + {0, 103, 102}, + {0, 104, 33}, + {0, 106, 105}, + {0, 108, 107}, + {0, 24, 109}, + {0, 111, 110}, + {0, 113, 112}, + {0, 114, 50}, + {0, 116, 115}, + {0, 118, 117}, + {0, 120, 119}, + {0, 122, 121}, + {0, 124, 123}, + {0, 126, 125}, + {0, 128, 127}, + {0, 129, 63}, + {0, 131, 130}, + {0, 133, 132}, + {0, 135, 134}, + {0, 136, 69}, + })); + + codecs.emplace(std::pair(SpvOpVectorTimesScalar, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1951208733, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpDot, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(97, { + {0, 0, 0}, + {78001013, 0, 0}, + {170690025, 0, 0}, + {206688607, 0, 0}, + {443490822, 0, 0}, + {461476226, 0, 0}, + {537830163, 0, 0}, + {669982125, 0, 0}, + {790502615, 0, 0}, + {805072272, 0, 0}, + {1173092699, 0, 0}, + {1220643281, 0, 0}, + {1448448666, 0, 0}, + {1466804584, 0, 0}, + {1473411044, 0, 0}, + {1515695460, 0, 0}, + {1587730355, 0, 0}, + {1625742020, 0, 0}, + {2071351379, 0, 0}, + {2250055803, 0, 0}, + {2291766425, 0, 0}, + {2416108131, 0, 0}, + {2427834344, 0, 0}, + {2436009347, 0, 0}, + {2455417440, 0, 0}, + {2480811229, 0, 0}, + {2654325647, 0, 0}, + {2919796598, 0, 0}, + {3047649911, 0, 0}, + {3088511797, 0, 0}, + {3104643263, 0, 0}, + {3198541202, 0, 0}, + {3204986803, 0, 0}, + {3272233597, 0, 0}, + {3383007207, 0, 0}, + {3602108619, 0, 0}, + {3622349409, 0, 0}, + {3714664910, 0, 0}, + {3717942504, 0, 0}, + {3732000233, 0, 0}, + {3759072440, 0, 0}, + {3765247327, 0, 0}, + {3805423332, 0, 0}, + {3829325073, 0, 0}, + {3866493821, 0, 0}, + {4058280485, 0, 0}, + {4061558677, 0, 0}, + {4148979936, 0, 0}, + {4155586396, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 13, 38}, + {0, 39, 14}, + {0, 44, 9}, + {0, 48, 47}, + {0, 23, 15}, + {0, 33, 
25}, + {0, 1, 42}, + {0, 5, 46}, + {0, 31, 3}, + {0, 36, 28}, + {0, 16, 12}, + {0, 32, 22}, + {0, 41, 21}, + {0, 6, 50}, + {0, 51, 29}, + {0, 45, 34}, + {0, 37, 8}, + {0, 19, 52}, + {0, 11, 4}, + {0, 43, 40}, + {0, 27, 53}, + {0, 54, 10}, + {0, 24, 55}, + {0, 57, 56}, + {0, 58, 26}, + {0, 2, 59}, + {0, 61, 60}, + {0, 63, 62}, + {0, 65, 64}, + {0, 20, 66}, + {0, 30, 35}, + {0, 67, 17}, + {0, 68, 7}, + {0, 70, 69}, + {0, 71, 18}, + {0, 73, 72}, + {0, 75, 74}, + {0, 77, 76}, + {0, 79, 78}, + {0, 81, 80}, + {0, 83, 82}, + {0, 85, 84}, + {0, 87, 86}, + {0, 89, 88}, + {0, 91, 90}, + {0, 93, 92}, + {0, 95, 94}, + {0, 49, 96}, + })); + + codecs.emplace(std::pair(SpvOpDot, 1), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(117, { + {0, 0, 0}, + {50385656, 0, 0}, + {181902171, 0, 0}, + {560078433, 0, 0}, + {615982737, 0, 0}, + {674428451, 0, 0}, + {837715723, 0, 0}, + {886972033, 0, 0}, + {900101778, 0, 0}, + {983299427, 0, 0}, + {1237148906, 0, 0}, + {1364157225, 0, 0}, + {1367301635, 0, 0}, + {1380160211, 0, 0}, + {1451831482, 0, 0}, + {1499923635, 0, 0}, + {1570165302, 0, 0}, + {1735295265, 0, 0}, + {1766401548, 0, 0}, + {1796311149, 0, 0}, + {1826456251, 0, 0}, + {1839669171, 0, 0}, + {2012838864, 0, 0}, + {2024071551, 0, 0}, + {2096388952, 0, 0}, + {2161102232, 0, 0}, + {2197874825, 0, 0}, + {2279700640, 0, 0}, + {2289183712, 0, 0}, + {2351620600, 0, 0}, + {2362972044, 0, 0}, + {2472176885, 0, 0}, + {2477434291, 0, 0}, + {2530899578, 0, 0}, + {2531826164, 0, 0}, + {2558133383, 0, 0}, + {2589449658, 0, 0}, + {2621255555, 0, 0}, + {2622612602, 0, 0}, + {2872580757, 0, 0}, + {2881302403, 0, 0}, + {2891091137, 0, 0}, + {2923708820, 0, 0}, + {2936040203, 0, 0}, + {2970183398, 0, 0}, + {3187066832, 0, 0}, + {3224952074, 0, 0}, + {3244383472, 0, 0}, + {3261122899, 0, 0}, + {3362830643, 0, 0}, + {3538158875, 0, 0}, + {3635542517, 0, 0}, + {3682213068, 0, 0}, + {3721902098, 0, 0}, + {3826846522, 0, 0}, + {3877583949, 0, 0}, + {3997432565, 0, 0}, + 
{4093615095, 0, 0}, + {4106828015, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 52, 28}, + {0, 33, 20}, + {0, 46, 57}, + {0, 47, 54}, + {0, 21, 17}, + {0, 31, 58}, + {0, 12, 53}, + {0, 29, 3}, + {0, 35, 34}, + {0, 48, 41}, + {0, 8, 5}, + {0, 7, 55}, + {0, 37, 32}, + {0, 60, 38}, + {0, 61, 16}, + {0, 14, 62}, + {0, 23, 63}, + {0, 13, 19}, + {0, 64, 9}, + {0, 65, 39}, + {0, 2, 66}, + {0, 67, 42}, + {0, 69, 68}, + {0, 25, 70}, + {0, 1, 49}, + {0, 6, 71}, + {0, 72, 15}, + {0, 73, 11}, + {0, 75, 74}, + {0, 77, 76}, + {0, 4, 78}, + {0, 56, 50}, + {0, 80, 79}, + {0, 10, 81}, + {0, 83, 82}, + {0, 85, 84}, + {0, 86, 27}, + {0, 43, 40}, + {0, 88, 87}, + {0, 44, 24}, + {0, 30, 89}, + {0, 51, 36}, + {0, 45, 90}, + {0, 18, 91}, + {0, 93, 92}, + {0, 22, 94}, + {0, 26, 95}, + {0, 97, 96}, + {0, 99, 98}, + {0, 101, 100}, + {0, 103, 102}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + {0, 111, 110}, + {0, 113, 112}, + {0, 59, 114}, + {0, 116, 115}, + })); + + codecs.emplace(std::pair(SpvOpDot, 2), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(179, { + {0, 0, 0}, + {27177503, 0, 0}, + {50385656, 0, 0}, + {129748122, 0, 0}, + {139011596, 0, 0}, + {162608772, 0, 0}, + {181902171, 0, 0}, + {225200779, 0, 0}, + {342159236, 0, 0}, + {386293029, 0, 0}, + {429023543, 0, 0}, + {443558693, 0, 0}, + {504514034, 0, 0}, + {615982737, 0, 0}, + {669812542, 0, 0}, + {674428451, 0, 0}, + {837715723, 0, 0}, + {861753115, 0, 0}, + {875212982, 0, 0}, + {876867882, 0, 0}, + {899320334, 0, 0}, + {900101778, 0, 0}, + {938517572, 0, 0}, + {1347339159, 0, 0}, + {1356063462, 0, 0}, + {1373856501, 0, 0}, + {1376656865, 0, 0}, + {1451831482, 0, 0}, + {1522979646, 0, 0}, + {1548491889, 0, 0}, + {1570165302, 0, 0}, + {1735295265, 0, 0}, + {1747355813, 0, 0}, + {1766401548, 0, 0}, + {1871105284, 0, 0}, + {1918742169, 0, 0}, + {1922045399, 0, 0}, + {1978689945, 0, 0}, + {2024071551, 0, 0}, + {2059975069, 0, 0}, + {2076833303, 0, 0}, + {2096388952, 0, 0}, + {2181030375, 0, 0}, + 
{2197874825, 0, 0}, + {2362972044, 0, 0}, + {2414725163, 0, 0}, + {2517964682, 0, 0}, + {2564745684, 0, 0}, + {2577387676, 0, 0}, + {2589449658, 0, 0}, + {2604242419, 0, 0}, + {2683080096, 0, 0}, + {2696349144, 0, 0}, + {2763960513, 0, 0}, + {2817823941, 0, 0}, + {2852854788, 0, 0}, + {2891091137, 0, 0}, + {2919626325, 0, 0}, + {2923708820, 0, 0}, + {2936040203, 0, 0}, + {2963744582, 0, 0}, + {2970183398, 0, 0}, + {2984459037, 0, 0}, + {2996594997, 0, 0}, + {3015046341, 0, 0}, + {3055195668, 0, 0}, + {3127329373, 0, 0}, + {3187066832, 0, 0}, + {3193597927, 0, 0}, + {3200890815, 0, 0}, + {3224258475, 0, 0}, + {3224480461, 0, 0}, + {3261122899, 0, 0}, + {3609540589, 0, 0}, + {3619404941, 0, 0}, + {3619626927, 0, 0}, + {3727034815, 0, 0}, + {3742724777, 0, 0}, + {3742946763, 0, 0}, + {3836179806, 0, 0}, + {3913885196, 0, 0}, + {3927338499, 0, 0}, + {3927466635, 0, 0}, + {3997432565, 0, 0}, + {3999472204, 0, 0}, + {4010499223, 0, 0}, + {4032662899, 0, 0}, + {4110915453, 0, 0}, + {4145966869, 0, 0}, + {4228303141, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 23, 87}, + {0, 9, 28}, + {0, 42, 17}, + {0, 74, 70}, + {0, 86, 77}, + {0, 18, 5}, + {0, 31, 32}, + {0, 34, 3}, + {0, 38, 68}, + {0, 50, 29}, + {0, 72, 62}, + {0, 21, 15}, + {0, 14, 54}, + {0, 56, 22}, + {0, 48, 88}, + {0, 2, 76}, + {0, 6, 47}, + {0, 26, 79}, + {0, 65, 12}, + {0, 37, 81}, + {0, 91, 60}, + {0, 30, 92}, + {0, 25, 7}, + {0, 45, 40}, + {0, 66, 52}, + {0, 71, 69}, + {0, 78, 75}, + {0, 84, 82}, + {0, 94, 93}, + {0, 27, 95}, + {0, 97, 96}, + {0, 99, 98}, + {0, 100, 39}, + {0, 55, 101}, + {0, 58, 102}, + {0, 89, 103}, + {0, 35, 11}, + {0, 104, 36}, + {0, 53, 10}, + {0, 1, 64}, + {0, 73, 20}, + {0, 105, 13}, + {0, 107, 106}, + {0, 8, 16}, + {0, 24, 19}, + {0, 85, 63}, + {0, 109, 108}, + {0, 111, 110}, + {0, 4, 112}, + {0, 114, 113}, + {0, 116, 115}, + {0, 118, 117}, + {0, 83, 119}, + {0, 121, 120}, + {0, 123, 122}, + {0, 49, 44}, + {0, 124, 57}, + {0, 125, 59}, + {0, 126, 67}, + {0, 128, 127}, + {0, 130, 129}, 
+ {0, 132, 131}, + {0, 134, 133}, + {0, 135, 51}, + {0, 137, 136}, + {0, 138, 61}, + {0, 43, 41}, + {0, 140, 139}, + {0, 142, 141}, + {0, 144, 143}, + {0, 146, 145}, + {0, 148, 147}, + {0, 149, 33}, + {0, 80, 150}, + {0, 152, 151}, + {0, 154, 153}, + {0, 156, 155}, + {0, 158, 157}, + {0, 160, 159}, + {0, 162, 161}, + {0, 164, 163}, + {0, 166, 165}, + {0, 168, 167}, + {0, 46, 169}, + {0, 171, 170}, + {0, 90, 172}, + {0, 174, 173}, + {0, 176, 175}, + {0, 178, 177}, + })); + + codecs.emplace(std::pair(SpvOpDot, 3), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1036475267, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpLabel, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(3, { + {0, 0, 0}, + {1036475267, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 1, 2}, + })); + + codecs.emplace(std::pair(SpvOpBranch, 0), std::move(codec)); + } + + { + std::unique_ptr> codec(new HuffmanCodec(119, { + {0, 0, 0}, + {57149555, 0, 0}, + {139011596, 0, 0}, + {255835594, 0, 0}, + {330249537, 0, 0}, + {388686774, 0, 0}, + {508217552, 0, 0}, + {550831114, 0, 0}, + {559246409, 0, 0}, + {599185303, 0, 0}, + {649208064, 0, 0}, + {679061455, 0, 0}, + {810488476, 0, 0}, + {951841533, 0, 0}, + {1008886329, 0, 0}, + {1022544883, 0, 0}, + {1215030156, 0, 0}, + {1305703280, 0, 0}, + {1367301635, 0, 0}, + {1453447304, 0, 0}, + {1487177499, 0, 0}, + {1603937321, 0, 0}, + {1617826947, 0, 0}, + {1643868273, 0, 0}, + {1672607981, 0, 0}, + {1681941034, 0, 0}, + {1755165354, 0, 0}, + {1781864804, 0, 0}, + {1795715718, 0, 0}, + {1977038330, 0, 0}, + {2096388952, 0, 0}, + {2204920111, 0, 0}, + {2244470522, 0, 0}, + {2330636993, 0, 0}, + {2400601988, 0, 0}, + {2424848261, 0, 0}, + {2603020391, 0, 0}, + {2622612602, 0, 0}, + {2645135839, 0, 0}, + {2660843182, 0, 0}, + {2708915136, 0, 0}, + {2724166585, 0, 0}, + {2728667725, 0, 0}, + {2890638791, 0, 0}, + {2901034693, 0, 0}, + {2941648648, 0, 0}, 
+ {2970183398, 0, 0}, + {2998120306, 0, 0}, + {3123244280, 0, 0}, + {3187066832, 0, 0}, + {3209399506, 0, 0}, + {3230260738, 0, 0}, + {3344189994, 0, 0}, + {3345707173, 0, 0}, + {3367298820, 0, 0}, + {3397078357, 0, 0}, + {3569736966, 0, 0}, + {3816961131, 0, 0}, + {4091670162, 0, 0}, + {4237497041, 0, 0}, + {1111111111111111111, 0, 0}, + {0, 17, 44}, + {0, 25, 20}, + {0, 29, 34}, + {0, 18, 2}, + {0, 54, 49}, + {0, 28, 7}, + {0, 47, 52}, + {0, 23, 56}, + {0, 55, 26}, + {0, 24, 61}, + {0, 13, 62}, + {0, 63, 45}, + {0, 27, 15}, + {0, 64, 8}, + {0, 65, 59}, + {0, 35, 22}, + {0, 53, 38}, + {0, 58, 51}, + {0, 11, 66}, + {0, 10, 3}, + {0, 46, 67}, + {0, 69, 68}, + {0, 1, 50}, + {0, 42, 19}, + {0, 70, 6}, + {0, 31, 71}, + {0, 16, 72}, + {0, 74, 73}, + {0, 76, 75}, + {0, 78, 77}, + {0, 79, 4}, + {0, 5, 37}, + {0, 14, 36}, + {0, 80, 57}, + {0, 81, 48}, + {0, 83, 82}, + {0, 39, 84}, + {0, 86, 85}, + {0, 40, 87}, + {0, 89, 88}, + {0, 91, 90}, + {0, 93, 92}, + {0, 21, 9}, + {0, 41, 32}, + {0, 12, 43}, + {0, 95, 94}, + {0, 97, 96}, + {0, 99, 98}, + {0, 100, 33}, + {0, 60, 101}, + {0, 103, 102}, + {0, 105, 104}, + {0, 107, 106}, + {0, 109, 108}, + {0, 110, 30}, + {0, 112, 111}, + {0, 114, 113}, + {0, 116, 115}, + {0, 118, 117}, + })); + + codecs.emplace(std::pair(SpvOpReturnValue, 0), std::move(codec)); + } + + return codecs; +} + +std::unordered_set GetDescriptorsWithCodingScheme() { + std::unordered_set descriptors_with_coding_scheme = { + 3816961131, + 3569736966, + 3397078357, + 3344189994, + 3230260738, + 2941648648, + 2901034693, + 2728667725, + 2400601988, + 1795715718, + 1681941034, + 1487177499, + 1453447304, + 679061455, + 649208064, + 559246409, + 388686774, + 4228303141, + 4110915453, + 4010499223, + 3927466635, + 3927338499, + 3836179806, + 3742724777, + 3619404941, + 3224480461, + 3224258475, + 3200890815, + 3742946763, + 3193597927, + 2604242419, + 2577387676, + 2181030375, + 1376656865, + 1347339159, + 938517572, + 876867882, + 429023543, + 129748122, + 
4106828015, + 4093615095, + 3826846522, + 3721902098, + 3244383472, + 2891091137, + 2872580757, + 2558133383, + 2477434291, + 1839669171, + 2059975069, + 1735295265, + 1364157225, + 1237148906, + 886972033, + 674428451, + 4148979936, + 3805423332, + 3732000233, + 3717942504, + 3714664910, + 3622349409, + 3272233597, + 3204986803, + 3088511797, + 1672607981, + 2416108131, + 2250055803, + 1796311149, + 1515695460, + 537830163, + 461476226, + 206688607, + 78001013, + 3866493821, + 3417584874, + 3188115516, + 2526961521, + 2443610186, + 2394332122, + 2374216296, + 2032069771, + 2011183308, + 1904846533, + 1641070431, + 1635292159, + 1372785527, + 1369818198, + 1204787336, + 1826456251, + 1164137269, + 1081642571, + 629859130, + 253329281, + 227103506, + 11698369, + 4219766939, + 4169226615, + 3997432565, + 3873587660, + 3513669836, + 3261122899, + 2129301998, + 1774052499, + 1266262705, + 4285201458, + 4245743275, + 3907920335, + 3837583704, + 3641833815, + 3307100165, + 1232501371, + 3262572726, + 3045470312, + 2879050471, + 2801333547, + 2506771164, + 2504802016, + 2500422644, + 2473053808, + 2457690657, + 2345566651, + 2306141594, + 2154320787, + 2055637638, + 1527762373, + 1486206763, + 1159301677, + 1105835505, + 968885186, + 885645401, + 849867303, + 815034111, + 497658126, + 386525753, + 179458548, + 102358168, + 4127308103, + 4073492988, + 1473411044, + 805072272, + 3724004880, + 3602108619, + 3585315836, + 3331487616, + 3261703164, + 3235459678, + 3085703811, + 3047649911, + 2357410109, + 2291766425, + 2071351379, + 1904128160, + 1738815671, + 1531216990, + 1465623797, + 1324351672, + 1220127364, + 1144188012, + 183103444, + 116093251, + 3900859293, + 3345856521, + 3691770462, + 3263841912, + 3198541202, + 3098991995, + 3682213068, + 2963184673, + 2864543087, + 2802261839, + 2790648021, + 900101778, + 2715304020, + 100979271, + 2709694527, + 2669086217, + 2531826164, + 2651956495, + 2552825357, + 2480811229, + 3138977758, + 2434845539, + 2066323109, + 
1777640493, + 1758287856, + 1746004874, + 3945482286, + 3932146199, + 3129573055, + 3126269825, + 3716914380, + 985750227, + 1543672828, + 3189039115, + 1839499483, + 2696349144, + 1536350567, + 3971481069, + 3001444829, + 4028622909, + 215293834, + 213642219, + 153085016, + 1189681639, + 165054168, + 29517006, + 2614879967, + 27865391, + 1649426421, + 4239834800, + 1947620272, + 28782128, + 3207966516, + 3713290482, + 2042001863, + 2724166585, + 2356768706, + 1793544760, + 4092654294, + 2157103435, + 2087004702, + 2043873558, + 27177503, + 1033363654, + 4214779116, + 408465899, + 451264926, + 2377112119, + 1182296898, + 760554870, + 3566035349, + 2630220147, + 4192247221, + 1572088444, + 3538592682, + 769422756, + 1674803691, + 630964591, + 3458449569, + 565334834, + 137840602, + 3955205564, + 2009007457, + 1258105452, + 333554713, + 3923810593, + 126463145, + 3445109809, + 2966409025, + 2849215484, + 1910240213, + 3131890669, + 586244865, + 2320303498, + 3116932970, + 1317265040, + 2812498065, + 1466938734, + 4064212479, + 2613179511, + 2095546797, + 1671139745, + 2568512089, + 3695940604, + 1119069977, + 215027449, + 4123141705, + 3495546641, + 1978689945, + 3202324433, + 3783543823, + 2674422363, + 1352628475, + 1290956281, + 1894417995, + 740921498, + 4211577142, + 1033081852, + 3884846406, + 3253403867, + 2790624748, + 2538917932, + 2144962711, + 3323202731, + 4290024976, + 2564745684, + 2963744582, + 2443959748, + 354479447, + 750870327, + 1918481917, + 4032662899, + 3587381650, + 2414725163, + 1081611718, + 1625742020, + 2308565678, + 1871105284, + 2807907995, + 2121980967, + 1054641568, + 413918748, + 1917336504, + 1816558243, + 4130950286, + 1522979646, + 1669959736, + 1320550031, + 3104643263, + 3823959661, + 3525913657, + 3584683259, + 2918750759, + 3536390697, + 94303122, + 3296691317, + 801484894, + 2496463830, + 3266028549, + 3085157904, + 973908139, + 3787909072, + 3107413701, + 2378763734, + 920604853, + 2516325050, + 1838993983, + 1603937321, + 
3183924418, + 1945006185, + 3982311384, + 2682510803, + 680388473, + 979993429, + 2405770322, + 461040879, + 2817579280, + 14113753, + 2894979602, + 168339452, + 951841533, + 4154758669, + 2637132451, + 3877583949, + 1949856502, + 922996215, + 3941049054, + 4182141402, + 2262220987, + 1957218950, + 2094550054, + 1846856260, + 3499234137, + 3086631065, + 3054834317, + 593829839, + 522971108, + 1162127370, + 4233562270, + 2780190687, + 1558345254, + 3716353056, + 3518630848, + 1158929937, + 2038205856, + 86116519, + 4185661467, + 975807626, + 3910458990, + 4124281183, + 3361419439, + 171334650, + 2590402790, + 2890570341, + 2303184249, + 385229009, + 1998433745, + 1717510093, + 4022124023, + 1429389803, + 945128292, + 904486530, + 3869890846, + 619875033, + 459968607, + 3743748793, + 359054425, + 1417363940, + 3653985133, + 255835594, + 1047011733, + 2763232252, + 1329499601, + 328661377, + 2162274327, + 2100532220, + 4255182614, + 4243119782, + 3982047273, + 4053789056, + 401211099, + 950731750, + 1319785741, + 32085358, + 3882634684, + 3117071189, + 3554463148, + 3570219049, + 3535289452, + 2314864456, + 3913885196, + 2763960513, + 1079999262, + 27130513, + 3033873113, + 2976581453, + 2598189097, + 595410904, + 1572834111, + 13319433, + 1084574846, + 2123388694, + 560078433, + 1679946323, + 3518703473, + 184634770, + 296981500, + 1646147798, + 455591063, + 1325348861, + 3224952074, + 1027242654, + 2281956980, + 4221373527, + 1289566249, + 4044928561, + 882718761, + 1510333659, + 836581417, + 1901166356, + 2276405827, + 4052965752, + 1155765244, + 503145996, + 251209228, + 495107308, + 3944781937, + 37459569, + 4248015868, + 4198082194, + 1302400505, + 4106658327, + 680016782, + 2319227476, + 2738307068, + 3929248764, + 2850246066, + 1824526196, + 3912967080, + 3044723416, + 3133016299, + 2517964682, + 3647586740, + 3653838348, + 929101967, + 3571454885, + 2806296851, + 977312655, + 646282397, + 3448018532, + 824323032, + 204234270, + 1579585816, + 3712763835, + 
1212872174, + 3953984401, + 3168953855, + 2944827576, + 1582841441, + 2796901051, + 3323682385, + 1317058015, + 2557550659, + 1620634991, + 2986830770, + 2490492987, + 1817271123, + 40653745, + 1696076631, + 2466126792, + 4169878842, + 3251128023, + 2444465148, + 678695941, + 2481746922, + 2836440943, + 774727851, + 2246405597, + 4028028350, + 2524697596, + 1977038330, + 2817823941, + 2219733501, + 688216667, + 3634598908, + 3232633974, + 2724625059, + 3269075805, + 3732640764, + 2263349224, + 1680746207, + 2414984922, + 2507457870, + 50998433, + 3092528578, + 3712946115, + 1543935193, + 807276090, + 1221183390, + 172029722, + 2122275289, + 3990925720, + 2261697609, + 2736881867, + 295017943, + 3278176820, + 3748965853, + 3174324790, + 1103903216, + 3184177968, + 1113409935, + 2299842241, + 2162986400, + 1538342947, + 4056442905, + 1631434666, + 205885885, + 1594733696, + 1955104493, + 1022309772, + 3820814597, + 993150979, + 1209418480, + 1784441183, + 3958731802, + 2250225826, + 3065160070, + 2024071551, + 107497541, + 628544021, + 2732195517, + 4241486415, + 3969279737, + 870594305, + 2916400082, + 1193734351, + 3202349435, + 3831290364, + 3282979782, + 3928764629, + 1308462133, + 3216471040, + 2433519008, + 2022961611, + 3604842236, + 3374978006, + 2855895374, + 3496407048, + 1482251215, + 3994511488, + 2997832431, + 1132589448, + 1348149915, + 2092468906, + 2451531615, + 779021139, + 3730093054, + 3413713311, + 1022915255, + 2204920111, + 2660843182, + 1080545747, + 1642805350, + 1766422419, + 4141567741, + 1558990974, + 4185590212, + 2841468319, + 701281393, + 3325419312, + 451957774, + 357505993, + 1156369516, + 3187387500, + 2259467579, + 2678954464, + 3154597438, + 543558236, + 2359973133, + 1990431740, + 2705477184, + 1041368449, + 3122368657, + 3181646225, + 1094423548, + 2955375511, + 2888125966, + 153013225, + 2936040203, + 1758530522, + 573901046, + 3030911670, + 1675922848, + 4235213885, + 4091916710, + 2633682514, + 4254584852, + 2328748202, + 
3357301402, + 3877813395, + 2004567202, + 2496297824, + 3334207724, + 1600149091, + 293528591, + 1782996825, + 3757282300, + 1107206446, + 1092948665, + 1797960910, + 1206726575, + 1496351055, + 3021406120, + 99347751, + 3797204453, + 1468919488, + 797415788, + 1314843976, + 2934934694, + 490769168, + 1474506522, + 3811268385, + 864295921, + 3081676220, + 151810803, + 2588618056, + 2998120306, + 416853049, + 3495967422, + 3233393284, + 508007510, + 759277550, + 1971252067, + 869050696, + 810488476, + 745556697, + 789872778, + 3362723943, + 1617826947, + 3260309823, + 2197904616, + 1199157863, + 1643868273, + 2430404313, + 321630747, + 2503194620, + 3194725903, + 2881225774, + 3997952447, + 1389644742, + 2713718873, + 3585511591, + 1684282922, + 3366848728, + 284226441, + 1541020250, + 4018237905, + 1369578001, + 2424848261, + 2654325647, + 1626224034, + 1081536219, + 309040124, + 123060826, + 3997038726, + 1670691893, + 1543280290, + 443347828, + 1776629361, + 3118548424, + 478440524, + 679771963, + 3729929345, + 4244789645, + 2366506734, + 2838165089, + 1619778288, + 1313182965, + 3240680626, + 1323407757, + 883854656, + 2194691858, + 15502752, + 3760372982, + 1366337101, + 3656163446, + 295018543, + 825595257, + 57149555, + 2563789125, + 2353194283, + 2636942752, + 4026740269, + 3570411982, + 123108003, + 3782362128, + 1280126114, + 1410849099, + 4228502127, + 3609540589, + 3365041621, + 269823086, + 348988933, + 1636389511, + 2936586309, + 2761603302, + 2318200267, + 449954059, + 2895413148, + 1755165354, + 4274214049, + 778500192, + 3345707173, + 3732136051, + 721450866, + 1600392975, + 2466255445, + 4050155669, + 3541895912, + 1139547465, + 394654115, + 1380991098, + 3516240523, + 2234361374, + 1094817798, + 744817486, + 3564402361, + 1452222566, + 1851510470, + 3619787319, + 4265894873, + 216945449, + 3061690214, + 2910557180, + 255227811, + 4167600590, + 1587209598, + 3157581152, + 3184381405, + 2572638469, + 615748604, + 2532518896, + 1774874546, + 
599185303, + 1561718045, + 1742737136, + 1674464100, + 3136865519, + 706016261, + 2793529873, + 3504981554, + 4155122613, + 2080953106, + 1104362365, + 2879917501, + 850497536, + 1392080469, + 1287937401, + 718877177, + 1917966999, + 1822823090, + 3701632935, + 3591222197, + 2817335337, + 1941148668, + 3110479131, + 3289213933, + 583624926, + 468372467, + 1633850097, + 2110223508, + 898191441, + 112745085, + 4018820793, + 3085119011, + 2919626325, + 3094857332, + 2348201466, + 2192810893, + 4163160985, + 1269075360, + 3952316364, + 2881886868, + 439764402, + 1584774136, + 169674806, + 3759072440, + 102542696, + 2996180816, + 804899022, + 1015552308, + 963902061, + 3504158761, + 2002490364, + 2806716850, + 265778447, + 4083122425, + 181902171, + 1238120570, + 75986790, + 1265796414, + 899570100, + 2988365258, + 3655201337, + 3654061472, + 3061856840, + 1077859090, + 615341051, + 3678875745, + 3349230696, + 3647606635, + 2549309392, + 1508570930, + 1766401548, + 1448448666, + 1499923635, + 2882994691, + 3674863070, + 3056042030, + 4240893633, + 1395113939, + 2964622752, + 1951208733, + 3536941067, + 4176581069, + 1203545131, + 3092754101, + 246375791, + 2736026107, + 1069781886, + 3687777340, + 1564342316, + 535067202, + 1395923345, + 3240977890, + 1447712361, + 2602027658, + 718301639, + 3123244280, + 1032593647, + 2840366496, + 2680819379, + 3839389658, + 277023757, + 1172110445, + 1755648697, + 2472176885, + 223800276, + 625975427, + 976111724, + 4145966869, + 2789375411, + 618087261, + 249378857, + 4058280485, + 827698488, + 1558001705, + 3561482820, + 2562485583, + 4243138030, + 615982737, + 1220643281, + 150685616, + 3091876332, + 1040775722, + 669982125, + 4116080964, + 3582002820, + 910398460, + 1036475267, + 3800912395, + 146392076, + 1686512349, + 2326636627, + 2839816704, + 3502816184, + 226836633, + 3953733490, + 257136089, + 819503463, + 2863084840, + 1949759310, + 210754155, + 1367301635, + 3822983876, + 4273793488, + 3635397748, + 3930494584, + 
3127921440, + 3167253437, + 3868239231, + 1859128680, + 3480031018, + 3810805277, + 2677252364, + 156014509, + 3627739127, + 2321729979, + 1146476634, + 4039938779, + 1964254745, + 2055836767, + 119981689, + 2629265310, + 2448331885, + 3737376990, + 144116905, + 2272221101, + 2197874825, + 1277245109, + 2503770904, + 360730278, + 3489360962, + 1166917451, + 707478563, + 4155586396, + 162255877, + 347505241, + 4215670524, + 3187066832, + 2399809085, + 2754074729, + 4060703604, + 628331516, + 1304296041, + 616435646, + 4080527786, + 1443829854, + 2512398201, + 708736129, + 13107491, + 3794803132, + 2049792025, + 2455417440, + 3367313400, + 3357250579, + 3694383800, + 2339901602, + 3242843022, + 2282454607, + 1243764146, + 835458563, + 1297706389, + 464259778, + 1766994680, + 1294403159, + 2568098594, + 3107165180, + 4040340620, + 3352361837, + 1031290113, + 2903897222, + 1677700667, + 3160388974, + 107544081, + 3044188332, + 2285081596, + 2835131395, + 2984459037, + 4174489262, + 1236389532, + 2938237924, + 321459212, + 3407526215, + 300939750, + 3441531391, + 2909957084, + 3192069648, + 1849065716, + 2524531022, + 505940164, + 4121643374, + 3774892253, + 3197739982, + 2161102232, + 2715370488, + 1992893964, + 1781864804, + 587888644, + 1039111164, + 4237497041, + 451382997, + 969500141, + 1415510495, + 3743398113, + 3027538652, + 2525173102, + 1708264968, + 3366040354, + 1100599986, + 188347929, + 2597020383, + 2705434194, + 2593884753, + 3472123498, + 2975894973, + 3152745753, + 1154919607, + 1930923350, + 3287039847, + 1372881231, + 2280400314, + 3369343584, + 2351620600, + 2645135839, + 2752766693, + 1471851763, + 1989520052, + 1141965917, + 1503477720, + 653708953, + 1765126703, + 2432827426, + 95470391, + 2567901801, + 2589449658, + 4218799564, + 3249265647, + 3673811979, + 210116709, + 1593584949, + 1791352211, + 3457985288, + 3345288309, + 531559080, + 2491124112, + 3410158390, + 4224872590, + 3705139860, + 162608772, + 4258229445, + 925559698, + 3928842969, 
+ 4253051659, + 3633746133, + 3867307935, + 3560665067, + 798915737, + 2945369269, + 2677264274, + 2278571792, + 177111659, + 85880059, + 1297165140, + 1630583316, + 2232491275, + 1848784182, + 2487708241, + 626480004, + 3427283542, + 2108571893, + 304448521, + 3332104493, + 2244470522, + 436416061, + 221900294, + 1502470404, + 3552593177, + 440421571, + 450406196, + 503094540, + 3836822275, + 2708915136, + 3750617468, + 1119744229, + 3614752756, + 921246433, + 2285438321, + 626892406, + 2362972044, + 72782198, + 2929019254, + 2795773560, + 907126242, + 155458798, + 2798552666, + 1404739463, + 4285652249, + 1998444837, + 908777857, + 872544165, + 910429472, + 135486769, + 3457269042, + 426360862, + 1725011064, + 296836635, + 1322549027, + 2044728014, + 1530183840, + 529742207, + 4272200782, + 1341516288, + 2608484640, + 41739659, + 3260579369, + 2745872368, + 2894051250, + 862784766, + 3077271274, + 3094180193, + 3619626927, + 3745223676, + 2976066508, + 2854085372, + 2959147533, + 3266548732, + 1776526161, + 3712296962, + 1955871800, + 2580096524, + 2507709226, + 3564865233, + 948086521, + 1548254487, + 142465290, + 1472185378, + 1459457331, + 2274226560, + 3153451899, + 492958971, + 3563213618, + 1285705317, + 410274915, + 3710645347, + 1309728002, + 2119793999, + 1343794461, + 4024173916, + 2383939514, + 955476870, + 2698156268, + 35240468, + 2655147757, + 3764205609, + 3802564010, + 170690025, + 2311941439, + 3181546731, + 3866587616, + 3648138580, + 93914936, + 170378107, + 2120623674, + 1064945649, + 1618754372, + 244668133, + 247698428, + 3669223677, + 470277359, + 1781765116, + 1691572958, + 1373856501, + 2668769415, + 1087394637, + 1009983433, + 2180701723, + 4008405264, + 2831059514, + 2645120714, + 2649103430, + 2664825925, + 790502615, + 1739837626, + 2293247016, + 1784648440, + 1887808856, + 1788504755, + 112452386, + 1979978194, + 3462674048, + 2170273742, + 538168945, + 753954113, + 374731234, + 3715846592, + 1962971231, + 1860649552, + 1378082995, + 
665789406, + 1717555224, + 139011596, + 1375043498, + 1618544981, + 1889460471, + 2262321736, + 1788301425, + 1652168174, + 2668680621, + 2636946065, + 2856623532, + 2759951687, + 959681532, + 3209399506, + 3055195668, + 1227221002, + 508217552, + 3289969989, + 243178923, + 2956189845, + 3075866530, + 2274779301, + 3940720663, + 3998230222, + 1178317551, + 4016096296, + 1545450160, + 2842919847, + 314809953, + 2952850186, + 3747079365, + 4147239510, + 169135842, + 1332643570, + 2994529201, + 973521782, + 1584369690, + 1043738701, + 2851900832, + 290391815, + 283209196, + 2468230023, + 1164221089, + 1991787192, + 3358097187, + 51041423, + 52882140, + 2339018837, + 2053214130, + 3757479030, + 158160339, + 853200279, + 1986584654, + 438318340, + 827246872, + 3299488628, + 2924263085, + 3472029049, + 2736844435, + 677668732, + 604894932, + 1158021131, + 1400019344, + 2268204687, + 1450415100, + 3854557817, + 1543646433, + 1278448636, + 342615870, + 1554194368, + 3080024605, + 3423702268, + 1675764636, + 1622381564, + 2078849875, + 2113115132, + 1380160211, + 3132876285, + 125015036, + 269576093, + 94145952, + 2777172031, + 2683080096, + 3812456892, + 488500848, + 3270430997, + 2895151306, + 116376005, + 400248103, + 406044930, + 1616846013, + 10142671, + 763027711, + 225200779, + 1062250709, + 2013867381, + 2113506324, + 1692932387, + 1827244161, + 3124618210, + 2096472894, + 2924146124, + 2128251367, + 2433358586, + 1939359710, + 2593325766, + 2879917723, + 694743357, + 2902069960, + 220008971, + 3090408469, + 917019124, + 1705716306, + 3263901372, + 3347863687, + 3447882276, + 1661163736, + 3617689692, + 3928555688, + 1057578789, + 435256475, + 4101009465, + 1941403425, + 198967948, + 3733675151, + 2043684541, + 3517169445, + 2226776400, + 2853403709, + 529383565, + 2807448986, + 4234287173, + 1019457583, + 1022544883, + 2493146691, + 1054461787, + 1008886329, + 1136775085, + 1191015885, + 1196280518, + 1979847999, + 50385656, + 1918742169, + 3999472204, + 
3697687030, + 2220475432, + 2358141757, + 2360004627, + 4245257809, + 236660303, + 429277936, + 342159236, + 2622612602, + 371428004, + 373079619, + 643418617, + 2095027856, + 1071164424, + 1136911283, + 1548491889, + 2169307971, + 375530199, + 1510422521, + 3151638847, + 1698730948, + 2231688008, + 2604576561, + 2771938750, + 2996594997, + 289648234, + 348584153, + 2748350697, + 2926633629, + 2123683379, + 369686787, + 742917749, + 3538158875, + 2937761472, + 1545298048, + 1321616112, + 2855506940, + 900522183, + 1578775276, + 2217833278, + 2012838864, + 3753486980, + 2839765116, + 2464905186, + 2621255555, + 1305703280, + 861753115, + 3319278167, + 3063300848, + 149720480, + 1082941229, + 3337532056, + 2248357849, + 3675926744, + 1508550646, + 2289803479, + 3456899824, + 3931641900, + 3970432934, + 3419674548, + 1093210099, + 456043370, + 848380423, + 1287304304, + 1526654696, + 2055664760, + 1373166395, + 4291477370, + 2195550588, + 2847102741, + 3399062057, + 1641565587, + 2888753905, + 3579593979, + 3653059026, + 3757851979, + 2922615804, + 2919796598, + 1553476262, + 2566666743, + 3759503594, + 550831114, + 3761155209, + 3762806824, + 3902853271, + 4140081844, + 14244860, + 3847846774, + 150820676, + 1278818058, + 850592577, + 1206571206, + 1734446471, + 2117320444, + 1382106590, + 2436009347, + 2118972059, + 2951272396, + 36096192, + 117998987, + 473485679, + 2244928358, + 476788909, + 3489269251, + 610429940, + 480092139, + 481743754, + 871966503, + 918189168, + 601656217, + 933769938, + 939671928, + 1799299383, + 3312467582, + 1149665466, + 3006548167, + 1310740861, + 3602693817, + 1461645203, + 3367691969, + 1800404122, + 3486057732, + 1862284649, + 2076833303, + 2213411495, + 2805256437, + 3927915220, + 3000904950, + 2094647776, + 3333131702, + 1315613425, + 3752211294, + 603915804, + 3505028338, + 663258455, + 3322500634, + 1612225949, + 3606320646, + 157110413, + 1352397672, + 3861006967, + 452208841, + 18776483, + 1058429216, + 37009196, + 564884461, 
+ 876864198, + 2952260510, + 2860348412, + 928261291, + 1164724902, + 2775815164, + 1332774287, + 780957373, + 939415664, + 1513770932, + 788046331, + 1692600167, + 4069810315, + 673708384, + 4024252457, + 1932614728, + 2148510256, + 3131224670, + 2388524817, + 2460489993, + 2676385521, + 826214242, + 3692647551, + 3063508455, + 3071766530, + 2063832060, + 1525861001, + 3073418145, + 837715723, + 3075069760, + 3076721375, + 3078372990, + 983243705, + 3083327835, + 171307615, + 1824016656, + 3084979450, + 1310404265, + 1775308984, + 3114708520, + 3116360135, + 3121314980, + 3134527900, + 1691646294, + 2804281092, + 97231530, + 3136179515, + 3204260786, + 3276225962, + 1220749418, + 3588205699, + 3874089391, + 4044115788, + 3268751013, + 743407979, + 166253838, + 1356063462, + 1368383673, + 2279700640, + 2130747644, + 3945795573, + 2780898906, + 3635542517, + 425022309, + 517919178, + 4061558677, + 2190437442, + 543621065, + 753756604, + 2500819054, + 1004589179, + 1165671422, + 30433743, + 3444275347, + 1335363438, + 1913735398, + 1265998516, + 3829325073, + 3662767579, + 463084678, + 1351676723, + 1391866096, + 3398925952, + 1631216488, + 815757910, + 1915438939, + 2427834344, + 1445161581, + 1890300748, + 2864863800, + 1961990747, + 575205902, + 2037710159, + 2037814253, + 617312262, + 3732916270, + 783918780, + 2257843797, + 2096388952, + 2338272340, + 1434223270, + 578132535, + 1980341560, + 1002144380, + 3244716568, + 4258414038, + 3271748023, + 3304438238, + 3717523241, + 3370185097, + 3435931956, + 1957265068, + 3602522282, + 2547657777, + 439998433, + 3838648480, + 3913593633, + 3989799199, + 906176560, + 1894133125, + 4046301857, + 4242327928, + 630592085, + 2693892518, + 4292991777, + 545678922, + 125792961, + 3015046341, + 132755933, + 2615111110, + 1570165302, + 1440646342, + 436066778, + 565233904, + 600906020, + 602222721, + 3951925872, + 1496901698, + 1522901980, + 2785441472, + 3041450802, + 1637661947, + 2127660080, + 3487022798, + 2269114589, + 
1314834580, + 2315690100, + 3817149113, + 4091670162, + 1431749301, + 1858116930, + 2213946343, + 2225172640, + 2263866576, + 2727022058, + 2752967311, + 2864705739, + 3052439312, + 3510257966, + 2614053317, + 3297860332, + 3670298840, + 3732709413, + 3788324110, + 4098876453, + 4290374884, + 1623013158, + 3381478137, + 17185761, + 3931288033, + 2890638791, + 330388453, + 346929928, + 2022347217, + 4083347580, + 533021259, + 564302770, + 1917602962, + 680157484, + 3264086791, + 3727034815, + 798549062, + 3068463300, + 669812542, + 1965902997, + 2311072371, + 3079287749, + 2542834724, + 1587730355, + 2558655180, + 1838763297, + 4172568578, + 2160380860, + 2950446516, + 1830851200, + 3214537066, + 3234673086, + 3652695478, + 3103302036, + 3465954368, + 4180570743, + 3534518722, + 371186900, + 4091394002, + 1013756921, + 443558693, + 591140762, + 656610661, + 2064733527, + 3808408202, + 983299427, + 4217306348, + 1164218401, + 2036361232, + 3237903670, + 2970183398, + 2293637521, + 135920445, + 1596005536, + 868652905, + 1191735827, + 3987079331, + 1365842164, + 1508074873, + 1642818143, + 3436143898, + 4105051793, + 1863199739, + 3425841570, + 1070791291, + 2135340676, + 2639720559, + 3364388739, + 3797761273, + 2092100514, + 2098706974, + 2329992200, + 414444763, + 2759250216, + 2913136690, + 3012980338, + 3327770644, + 4128942283, + 3362344229, + 161668409, + 3401762422, + 2852854788, + 4237092412, + 1245448751, + 3702405475, + 918849409, + 3829682756, + 1612361408, + 255302575, + 414620710, + 386293029, + 618761615, + 686024761, + 744062262, + 1502028603, + 1543798545, + 1641415225, + 1548121999, + 2257971049, + 2124837447, + 878733439, + 2340670452, + 2674090849, + 3118011750, + 2816338013, + 178571546, + 2841008029, + 3249261197, + 370232173, + 4092487128, + 3787567939, + 3898287302, + 4142016703, + 4285779501, + 30663912, + 151672195, + 180913835, + 3534235309, + 34183582, + 4083161638, + 651464351, + 1410311776, + 371621315, + 421602934, + 458937500, + 
2710583246, + 712168842, + 730943059, + 1519723107, + 875212982, + 1247793383, + 4217322139, + 989813600, + 1057606514, + 3764662384, + 1443547269, + 3066811685, + 3598957382, + 1791427568, + 1171541710, + 3930727258, + 1473799048, + 1296054774, + 1747355813, + 765238787, + 2023008475, + 1190147516, + 2344328209, + 2495155989, + 2577859137, + 2857814560, + 3127329373, + 3296722158, + 2773229577, + 3376009661, + 3450001968, + 920941800, + 3526837441, + 3858973601, + 1702168830, + 4088613871, + 1464587427, + 223310468, + 388034151, + 2346547796, + 1663234329, + 1750829822, + 1967643923, + 2881302403, + 2278706468, + 2326990117, + 2511346984, + 3088785099, + 2616085763, + 3027500544, + 3417583519, + 4178218543, + 1412908157, + 797934924, + 3533637837, + 1449907751, + 3362830643, + 1451831482, + 2637935122, + 3070114915, + 3023287679, + 551924251, + 1669930486, + 46736908, + 2870852215, + 1120149824, + 2923708820, + 3887377256, + 3464197236, + 4241374559, + 527665290, + 996663016, + 885020215, + 1763758554, + 3059119137, + 2555315060, + 2762094724, + 2530899578, + 2770161927, + 2262137600, + 3547456240, + 858902117, + 1140367371, + 1215030156, + 443490822, + 294390719, + 3032677281, + 1917451875, + 4184019303, + 3277199633, + 1271484400, + 1297294717, + 3560552546, + 171494987, + 195244192, + 3002890475, + 1811839150, + 265392489, + 1461398554, + 3205759417, + 333855951, + 529068443, + 660038281, + 557400685, + 663341511, + 930804377, + 1922045399, + 716890919, + 162167595, + 1654776395, + 1779143013, + 1123617794, + 2984325996, + 1162789888, + 1318479490, + 1235468610, + 3561562003, + 1486207619, + 1551372768, + 1850331254, + 3255947500, + 1037370721, + 1989327599, + 2137526937, + 835638766, + 2269130237, + 1962162282, + 3244209297, + 2330636993, + 3095831808, + 1396344138, + 2603020391, + 3434076295, + 3280064277, + 2656211099, + 3335250889, + 2550961007, + 3510242586, + 3536471583, + 3950980241, + 4033586023, + 117250846, + 3088282680, + 4041974454, + 4244540017, + 
1167160774, + 899320334, + 1200870684, + 1752686878, + 1906988301, + 3804101227, + 2575525651, + 2919787747, + 3508792859, + 3548535223, + 3783756895, + 3797961332, + 4043078107, + 3115038057, + 2313593054, + 49456560, + 592180731, + 1051471757, + 1097775533, + 706238670, + 877895868, + 1173092699, + 1461897718, + 1767704813, + 1770165905, + 1923453688, + 2212501241, + 2305269460, + 2488410748, + 3782099915, + 2844616706, + 3383007207, + 3392887901, + 504514034, + 3765247327, + 1000070091, + 3727494858, + 3657635382, + 3839047923, + 3886529747, + 4069720347, + 4164704452, + 342197850, + 3540244297, + 2513230733, + 4117704995, + 3367298820, + 2680283743, + 3119663365, + 3697738938, + 545363837, + 163402553, + 5908395, + 129135650, + 2289183712, + 200922300, + 761731755, + 894529125, + 1086964761, + 1168927492, + 2100052708, + 2438466459, + 3390051757, + 2498042266, + 2557754096, + 2600961503, + 487719832, + 703543228, + 2726532092, + 4199470013, + 3142155593, + 2550501832, + 4076840151, + 200553094, + 380957745, + 572905105, + 462664429, + 1466804584, + 330249537, + 2605012269, + 491456522, + 4126287524, + 502863753, + 952536201, + 3510682541, + 1137442027, + 1665981878, + 1761469971, + 3085467405, + 2045285083, + 796985462, + 3433956341, + 2217966239, + 2183547611, + 2279273489, + 1916983087, + 2348676810, + 2403632109, + 2409539315, + 545986953, + 176166202, + 2477389837, + 2573160348, + 2796513469, + 3972309363, + 528662843, + 1038982109, + 1125913837, + 1318081294, + 1417425499, + }; + return descriptors_with_coding_scheme; +} diff --git a/tools/dis/dis.cpp b/tools/dis/dis.cpp new file mode 100644 index 000000000..6a2e26932 --- /dev/null +++ b/tools/dis/dis.cpp @@ -0,0 +1,209 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) +#include // Need fileno +#include +#endif + +#include +#include +#include +#include + +#include "spirv-tools/libspirv.h" +#include "tools/io.h" + +static void print_usage(char* argv0) { + printf( + R"(%s - Disassemble a SPIR-V binary module + +Usage: %s [options] [] + +The SPIR-V binary is read from . If no file is specified, +or if the filename is "-", then the binary is read from standard input. + +Options: + + -h, --help Print this help. + --version Display disassembler version information. + + -o Set the output filename. + Output goes to standard output if this option is + not specified, or if the filename is "-". + + --color Force color output. The default when printing to a terminal. + Overrides a previous --no-color option. + --no-color Don't print in color. Overrides a previous --color option. + The default when output goes to something other than a + terminal (e.g. a file, a pipe, or a shell redirection). + + --no-indent Don't indent instructions. + + --no-header Don't output the header as leading comments. + + --raw-id Show raw Id values instead of friendly names. + + --offsets Show byte offsets for each instruction. 
+)", + argv0, argv0); +} + +static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_3; + +int main(int argc, char** argv) { + const char* inFile = nullptr; + const char* outFile = nullptr; + + bool color_is_possible = +#if SPIRV_COLOR_TERMINAL + true; +#else + false; +#endif + bool force_color = false; + bool force_no_color = false; + + bool allow_indent = true; + bool show_byte_offsets = false; + bool no_header = false; + bool friendly_names = true; + + for (int argi = 1; argi < argc; ++argi) { + if ('-' == argv[argi][0]) { + switch (argv[argi][1]) { + case 'h': + print_usage(argv[0]); + return 0; + case 'o': { + if (!outFile && argi + 1 < argc) { + outFile = argv[++argi]; + } else { + print_usage(argv[0]); + return 1; + } + } break; + case '-': { + // Long options + if (0 == strcmp(argv[argi], "--no-color")) { + force_no_color = true; + force_color = false; + } else if (0 == strcmp(argv[argi], "--color")) { + force_no_color = false; + force_color = true; + } else if (0 == strcmp(argv[argi], "--no-indent")) { + allow_indent = false; + } else if (0 == strcmp(argv[argi], "--offsets")) { + show_byte_offsets = true; + } else if (0 == strcmp(argv[argi], "--no-header")) { + no_header = true; + } else if (0 == strcmp(argv[argi], "--raw-id")) { + friendly_names = false; + } else if (0 == strcmp(argv[argi], "--help")) { + print_usage(argv[0]); + return 0; + } else if (0 == strcmp(argv[argi], "--version")) { + printf("%s\n", spvSoftwareVersionDetailsString()); + printf("Target: %s\n", + spvTargetEnvDescription(kDefaultEnvironment)); + return 0; + } else { + print_usage(argv[0]); + return 1; + } + } break; + case 0: { + // Setting a filename of "-" to indicate stdin. 
+ if (!inFile) { + inFile = argv[argi]; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + return 1; + } + } break; + default: + print_usage(argv[0]); + return 1; + } + } else { + if (!inFile) { + inFile = argv[argi]; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + return 1; + } + } + } + + uint32_t options = SPV_BINARY_TO_TEXT_OPTION_NONE; + + if (allow_indent) options |= SPV_BINARY_TO_TEXT_OPTION_INDENT; + + if (show_byte_offsets) options |= SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET; + + if (no_header) options |= SPV_BINARY_TO_TEXT_OPTION_NO_HEADER; + + if (friendly_names) options |= SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES; + + if (!outFile || (0 == strcmp("-", outFile))) { + // Print to standard output. + options |= SPV_BINARY_TO_TEXT_OPTION_PRINT; + + if (color_is_possible && !force_no_color) { + bool output_is_tty = true; +#if defined(_POSIX_VERSION) + output_is_tty = isatty(fileno(stdout)); +#endif + if (output_is_tty || force_color) { + options |= SPV_BINARY_TO_TEXT_OPTION_COLOR; + } + } + } + + // Read the input binary. + std::vector contents; + if (!ReadFile(inFile, "rb", &contents)) return 1; + + // If printing to standard output, then spvBinaryToText should + // do the printing. In particular, colour printing on Windows is + // controlled by modifying console objects synchronously while + // outputting to the stream rather than by injecting escape codes + // into the output stream. + // If the printing option is off, then save the text in memory, so + // it can be emitted later in this function. + const bool print_to_stdout = SPV_BINARY_TO_TEXT_OPTION_PRINT & options; + spv_text text = nullptr; + spv_text* textOrNull = print_to_stdout ? 
nullptr : &text; + spv_diagnostic diagnostic = nullptr; + spv_context context = spvContextCreate(kDefaultEnvironment); + spv_result_t error = + spvBinaryToText(context, contents.data(), contents.size(), options, + textOrNull, &diagnostic); + spvContextDestroy(context); + if (error) { + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + return error; + } + + if (!print_to_stdout) { + if (!WriteFile(outFile, "w", text->str, text->length)) { + spvTextDestroy(text); + return 1; + } + } + spvTextDestroy(text); + + return 0; +} diff --git a/tools/emacs/50spirv-tools.el b/tools/emacs/50spirv-tools.el new file mode 100644 index 000000000..1d4fbeef8 --- /dev/null +++ b/tools/emacs/50spirv-tools.el @@ -0,0 +1,40 @@ +;; Copyright (c) 2016 LunarG Inc. +;; +;; Licensed under the Apache License, Version 2.0 (the "License"); +;; you may not use this file except in compliance with the License. +;; You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. + +;; Upon loading a file with the .spv extension into emacs, the file +;; will be disassembled using spirv-dis, and the result colorized with +;; asm-mode in emacs. The file may be edited within the constraints +;; of validity, and when re-saved will be re-assembled using spirv-as. + +;; Note that symbol IDs are not preserved through a load/edit/save operation. +;; This may change if the ability is added to spirv-as. + +;; It is required that those tools be in your PATH. 
If that is not the case
+;; when starting emacs, the path can be modified as in this example:
+;;   (setenv "PATH" (concat (getenv "PATH") ":/path/to/spirv/tools"))
+;;
+;; See https://github.com/KhronosGroup/SPIRV-Tools/issues/359
+
+(require 'jka-compr)
+(require 'asm-mode)
+
+;; Register .spv with jka-compr so that spirv-dis runs on load and spirv-as
+;; runs on save.  NOTE(review): the field meanings given in the trailing
+;; comments follow the documented layout of `jka-compr-compression-info-list'
+;; entries -- confirm against the Emacs version in use.
+(add-to-list 'jka-compr-compression-info-list
+             '["\\.spv\\'"                                  ; file-name regexp
+               "Assembling SPIRV" "spirv-as" ("-o" "-")     ; "compress" message, program, args
+               "Disassembling SPIRV" "spirv-dis" ("--no-color" "--raw-id") ; "uncompress" message, program, args
+               t nil "\003\002\043\007"])                   ; append flag, strip-extension flag, magic bytes 03 02 23 07 (presumably the SPIR-V magic number, little-endian)
+
+;; Show the disassembly produced for .spv files in asm-mode.
+(add-to-list 'auto-mode-alist '("\\.spv\\'" . asm-mode))
+
+;; Make jka-compr pick up the entry added above.
+(jka-compr-update)
diff --git a/tools/emacs/CMakeLists.txt b/tools/emacs/CMakeLists.txt
new file mode 100644
index 000000000..ecd7c277a
--- /dev/null
+++ b/tools/emacs/CMakeLists.txt
@@ -0,0 +1,48 @@
+# Copyright (c) 2016 LunarG Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install a script for use with the auto-compression feature of emacs(1).
+# Upon loading a file with the .spv extension, the file will be disassembled
+# using spirv-dis, and the result colorized with asm-mode in emacs. The file
+# may be edited within the constraints of validity, and when re-saved will be
+# re-assembled using spirv-as.
+
+# It is required that those tools be in your PATH. If that is not the case
+# when starting emacs, the path can be modified as in this example:
+#   (setenv "PATH" (concat (getenv "PATH") ":/path/to/spirv/tools"))
+#
+# See https://github.com/KhronosGroup/SPIRV-Tools/issues/359
+
+# This is an absolute directory, and ignores CMAKE_INSTALL_PREFIX, or
+# it will not be found by emacs upon startup. It is only installed if
+# both of the following are true:
+#   1. SPIRV_TOOLS_INSTALL_EMACS_HELPERS is defined
+#   2. The directory /etc/emacs/site-start.d already exists at the time of
+#      cmake invocation (not at the time of make install). This is
+#      typically true if emacs is installed on the system.
+# In addition, the install() rule below is only emitted when
+# ENABLE_SPIRV_TOOLS_INSTALL is true.
+
+# Note that symbol IDs are not preserved through a load/edit/save operation.
+# This may change if the ability is added to spirv-as.
+
+# The option's default is its own current value, so a value already set by
+# the caller is kept.
+option(SPIRV_TOOLS_INSTALL_EMACS_HELPERS
+  "Install Emacs helper to disassemble/assemble SPIR-V binaries on file load/save."
+  ${SPIRV_TOOLS_INSTALL_EMACS_HELPERS})
+if (${SPIRV_TOOLS_INSTALL_EMACS_HELPERS})
+  if(EXISTS /etc/emacs/site-start.d)
+    if(ENABLE_SPIRV_TOOLS_INSTALL)
+      install(FILES 50spirv-tools.el DESTINATION /etc/emacs/site-start.d)
+    endif(ENABLE_SPIRV_TOOLS_INSTALL)
+  endif()
+endif()
+
diff --git a/tools/io.h b/tools/io.h
new file mode 100644
index 000000000..aaf8fcdd2
--- /dev/null
+++ b/tools/io.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
#ifndef TOOLS_IO_H_
#define TOOLS_IO_H_

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Appends the content of the file named |filename| to |data|, treating the
// file as a sequence of elements of type |T|. The file is opened with the
// given |mode|. If |filename| is nullptr or "-", standard input is read
// instead, after being reopened with |mode|.
//
// Returns true on success. On failure an error message is written to standard
// error and false is returned; elements read before the failure was detected
// remain appended to |data|. A file opened by this function is always closed
// before returning, on both the success and the failure paths.
template <typename T>
bool ReadFile(const char* filename, const char* mode, std::vector<T>* data) {
  const int buf_size = 1024;
  const bool use_file = filename && strcmp("-", filename);
  // A null pointer must never reach a "%s" conversion; standard input is
  // reported as "-" in diagnostics.
  const char* shown_name = filename ? filename : "-";

  FILE* fp = use_file ? fopen(filename, mode) : freopen(nullptr, mode, stdin);
  if (!fp) {
    fprintf(stderr, "error: file does not exist '%s'\n", shown_name);
    return false;
  }

  T buf[buf_size];
  while (size_t len = fread(buf, sizeof(T), buf_size, fp)) {
    data->insert(data->end(), buf, buf + len);
  }

  bool ok = true;
  if (ferror(fp)) {
    // fread() returning 0 does not distinguish end-of-file from an error, so
    // the stream's error flag is checked whether or not it is seekable.
    fprintf(stderr, "error: error reading file '%s'\n", shown_name);
    ok = false;
  } else {
    // ftell() fails on non-seekable streams (e.g. a pipe on stdin); in that
    // case the size check is skipped.
    const long end_pos = ftell(fp);
    if (end_pos != -1L && sizeof(T) != 1 &&
        (static_cast<size_t>(end_pos) % sizeof(T))) {
      fprintf(
          stderr,
          "error: file size should be a multiple of %zu; file '%s' corrupt\n",
          sizeof(T), shown_name);
      ok = false;
    }
  }

  // stdin is left open; only a stream opened here is closed.
  if (use_file) fclose(fp);
  return ok;
}

// Writes |count| elements of type |T| starting at |data| into the file named
// |filename|, opened with the given |mode|. If |filename| is nullptr or "-",
// writes to standard output instead.
//
// Returns true on success. On failure an error message is written to standard
// error and false is returned. A file opened by this function is always closed
// before returning, and a failure reported by fclose() (for example when
// flushing buffered data fails) is treated as a write failure.
template <typename T>
bool WriteFile(const char* filename, const char* mode, const T* data,
               size_t count) {
  const bool use_stdout =
      !filename || (filename[0] == '-' && filename[1] == '\0');
  // A null pointer must never reach a "%s" conversion.
  const char* shown_name = filename ? filename : "-";

  FILE* fp = use_stdout ? stdout : fopen(filename, mode);
  if (!fp) {
    fprintf(stderr, "error: could not open file '%s'\n", shown_name);
    return false;
  }

  const size_t written = fwrite(data, sizeof(T), count, fp);
  bool ok = (written == count);

  // stdout is left open; a file opened here is closed on every path.
  if (!use_stdout && fclose(fp) != 0) ok = false;

  if (!ok) {
    fprintf(stderr, "error: could not write to file '%s'\n", shown_name);
  }
  return ok;
}

#endif  // TOOLS_IO_H_
+ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# A script for automatically disassembling a .spv file +# for less(1). This assumes spirv-dis is on our PATH. +# +# See https://github.com/KhronosGroup/SPIRV-Tools/issues/359 + +case "$1" in + *.spv) spirv-dis "$@" 2>/dev/null;; + *) exit 1;; +esac + +exit $? + diff --git a/tools/link/linker.cpp b/tools/link/linker.cpp new file mode 100644 index 000000000..fb44a37ad --- /dev/null +++ b/tools/link/linker.cpp @@ -0,0 +1,159 @@ +// Copyright (c) 2017 Pierre Moreau +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "source/spirv_target_env.h" +#include "source/table.h" +#include "spirv-tools/libspirv.hpp" +#include "spirv-tools/linker.hpp" +#include "tools/io.h" + +void print_usage(char* argv0) { + printf( + R"(%s - Link SPIR-V binary files together. + +USAGE: %s [options] [ ...] + +The SPIR-V binaries are read from the different . + +NOTE: The linker is a work in progress. 
+ +Options: + -h, --help Print this help. + -o Name of the resulting linked SPIR-V binary. + --create-library Link the binaries into a library, keeping all exported symbols. + --allow-partial-linkage Allow partial linkage by accepting imported symbols to be unresolved. + --verify-ids Verify that IDs in the resulting modules are truly unique. + --version Display linker version information + --target-env {vulkan1.0|spv1.0|spv1.1|spv1.2|opencl2.1|opencl2.2} + Use Vulkan1.0/SPIR-V1.0/SPIR-V1.1/SPIR-V1.2/OpenCL-2.1/OpenCL2.2 validation rules. +)", + argv0, argv0); +} + +int main(int argc, char** argv) { + std::vector inFiles; + const char* outFile = nullptr; + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_0; + spvtools::LinkerOptions options; + bool continue_processing = true; + int return_code = 0; + + for (int argi = 1; continue_processing && argi < argc; ++argi) { + const char* cur_arg = argv[argi]; + if ('-' == cur_arg[0]) { + if (0 == strcmp(cur_arg, "-o")) { + if (argi + 1 < argc) { + if (!outFile) { + outFile = argv[++argi]; + } else { + fprintf(stderr, "error: More than one output file specified\n"); + continue_processing = false; + return_code = 1; + } + } else { + fprintf(stderr, "error: Missing argument to %s\n", cur_arg); + continue_processing = false; + return_code = 1; + } + } else if (0 == strcmp(cur_arg, "--create-library")) { + options.SetCreateLibrary(true); + } else if (0 == strcmp(cur_arg, "--verify-ids")) { + options.SetVerifyIds(true); + } else if (0 == strcmp(cur_arg, "--allow-partial-linkage")) { + options.SetAllowPartialLinkage(true); + } else if (0 == strcmp(cur_arg, "--version")) { + printf("%s\n", spvSoftwareVersionDetailsString()); + // TODO(dneto): Add OpenCL 2.2 at least. 
+ printf("Targets:\n %s\n %s\n %s\n", + spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_1), + spvTargetEnvDescription(SPV_ENV_VULKAN_1_0), + spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_2)); + continue_processing = false; + return_code = 0; + } else if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) { + print_usage(argv[0]); + continue_processing = false; + return_code = 0; + } else if (0 == strcmp(cur_arg, "--target-env")) { + if (argi + 1 < argc) { + const auto env_str = argv[++argi]; + if (!spvParseTargetEnv(env_str, &target_env)) { + fprintf(stderr, "error: Unrecognized target env: %s\n", env_str); + continue_processing = false; + return_code = 1; + } + } else { + fprintf(stderr, "error: Missing argument to --target-env\n"); + continue_processing = false; + return_code = 1; + } + } + } else { + inFiles.push_back(cur_arg); + } + } + + // Exit if command line parsing was not successful. + if (!continue_processing) { + return return_code; + } + + if (inFiles.empty()) { + fprintf(stderr, "error: No input file specified\n"); + return 1; + } + + std::vector> contents(inFiles.size()); + for (size_t i = 0u; i < inFiles.size(); ++i) { + if (!ReadFile(inFiles[i], "rb", &contents[i])) return 1; + } + + const spvtools::MessageConsumer consumer = [](spv_message_level_t level, + const char*, + const spv_position_t& position, + const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + std::cerr << "error: " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_WARNING: + std::cout << "warning: " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + std::cout << "info: " << position.index << ": " << message << std::endl; + break; + default: + break; + } + }; + spvtools::Context context(target_env); + context.SetMessageConsumer(consumer); + + std::vector linkingResult; + spv_result_t status = Link(context, contents, &linkingResult, options); + + if 
(!WriteFile(outFile, "wb", linkingResult.data(), + linkingResult.size())) + return 1; + + return status == SPV_SUCCESS ? 0 : 1; +} diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp new file mode 100644 index 000000000..fb794270e --- /dev/null +++ b/tools/opt/opt.cpp @@ -0,0 +1,699 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/opt/log.h" +#include "source/spirv_target_env.h" +#include "source/util/string_utils.h" +#include "spirv-tools/libspirv.hpp" +#include "spirv-tools/optimizer.hpp" +#include "tools/io.h" +#include "tools/util/cli_consumer.h" + +namespace { + +// Status and actions to perform after parsing command-line arguments. +enum OptActions { OPT_CONTINUE, OPT_STOP }; + +struct OptStatus { + OptActions action; + int code; +}; + +// Message consumer for this tool. Used to emit diagnostics during +// initialization and setup. Note that |source| and |position| are irrelevant +// here because we are still not processing a SPIR-V input file. 
+void opt_diagnostic(spv_message_level_t level, const char* /*source*/, + const spv_position_t& /*positon*/, const char* message) { + if (level == SPV_MSG_ERROR) { + fprintf(stderr, "error: "); + } + fprintf(stderr, "%s\n", message); +} + +std::string GetListOfPassesAsString(const spvtools::Optimizer& optimizer) { + std::stringstream ss; + for (const auto& name : optimizer.GetPassNames()) { + ss << "\n\t\t" << name; + } + return ss.str(); +} + +const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_3; + +std::string GetLegalizationPasses() { + spvtools::Optimizer optimizer(kDefaultEnvironment); + optimizer.RegisterLegalizationPasses(); + return GetListOfPassesAsString(optimizer); +} + +std::string GetOptimizationPasses() { + spvtools::Optimizer optimizer(kDefaultEnvironment); + optimizer.RegisterPerformancePasses(); + return GetListOfPassesAsString(optimizer); +} + +std::string GetSizePasses() { + spvtools::Optimizer optimizer(kDefaultEnvironment); + optimizer.RegisterSizePasses(); + return GetListOfPassesAsString(optimizer); +} + +std::string GetWebGPUPasses() { + spvtools::Optimizer optimizer(SPV_ENV_WEBGPU_0); + optimizer.RegisterWebGPUPasses(); + return GetListOfPassesAsString(optimizer); +} + +void PrintUsage(const char* program) { + // NOTE: Please maintain flags in lexicographical order. + printf( + R"(%s - Optimize a SPIR-V binary file. + +USAGE: %s [options] [] -o + +The SPIR-V binary is read from . If no file is specified, +or if is "-", then the binary is read from standard input. +if is "-", then the optimized output is written to +standard output. + +NOTE: The optimizer is a work in progress. + +Options (in lexicographical order): + --ccp + Apply the conditional constant propagation transform. This will + propagate constant values throughout the program, and simplify + expressions and conditional jumps with known predicate + values. Performed on entry point call tree functions and + exported functions. + --cfg-cleanup + Cleanup the control flow graph. 
This will remove any unnecessary + code from the CFG like unreachable code. Performed on entry + point call tree functions and exported functions. + --combine-access-chains + Combines chained access chains to produce a single instruction + where possible. + --compact-ids + Remap result ids to a compact range starting from %%1 and without + any gaps. + --convert-local-access-chains + Convert constant index access chain loads/stores into + equivalent load/stores with inserts and extracts. Performed + on function scope variables referenced only with load, store, + and constant index access chains in entry point call tree + functions. + --copy-propagate-arrays + Does propagation of memory references when an array is a copy of + another. It will only propagate an array if the source is never + written to, and the only store to the target is the copy. + --eliminate-common-uniform + Perform load/load elimination for duplicate uniform values. + Converts any constant index access chain uniform loads into + its equivalent load and extract. Some loads will be moved + to facilitate sharing. Performed only on entry point + call tree functions. + --eliminate-dead-branches + Convert conditional branches with constant condition to the + indicated unconditional brranch. Delete all resulting dead + code. Performed only on entry point call tree functions. + --eliminate-dead-code-aggressive + Delete instructions which do not contribute to a function's + output. Performed only on entry point call tree functions. + --eliminate-dead-const + Eliminate dead constants. + --eliminate-dead-functions + Deletes functions that cannot be reached from entry points or + exported functions. + --eliminate-dead-inserts + Deletes unreferenced inserts into composites, most notably + unused stores to vector components, that are not removed by + aggressive dead code elimination. + --eliminate-dead-variables + Deletes module scope variables that are not referenced. 
+ --eliminate-insert-extract + DEPRECATED. This pass has been replaced by the simplification + pass, and that pass will be run instead. + See --simplify-instructions. + --eliminate-local-multi-store + Replace stores and loads of function scope variables that are + stored multiple times. Performed on variables referenceed only + with loads and stores. Performed only on entry point call tree + functions. + --eliminate-local-single-block + Perform single-block store/load and load/load elimination. + Performed only on function scope variables in entry point + call tree functions. + --eliminate-local-single-store + Replace stores and loads of function scope variables that are + only stored once. Performed on variables referenceed only with + loads and stores. Performed only on entry point call tree + functions. + --flatten-decorations + Replace decoration groups with repeated OpDecorate and + OpMemberDecorate instructions. + --fold-spec-const-op-composite + Fold the spec constants defined by OpSpecConstantOp or + OpSpecConstantComposite instructions to front-end constants + when possible. + --freeze-spec-const + Freeze the values of specialization constants to their default + values. + --if-conversion + Convert if-then-else like assignments into OpSelect. + --inline-entry-points-exhaustive + Exhaustively inline all function calls in entry point call tree + functions. Currently does not inline calls to functions with + early return in a loop. + --legalize-hlsl + Runs a series of optimizations that attempts to take SPIR-V + generated by an HLSL front-end and generates legal Vulkan SPIR-V. + The optimizations are: + %s + + Note this does not guarantee legal code. This option passes the + option --relax-logical-pointer to the validator. + --local-redundancy-elimination + Looks for instructions in the same basic block that compute the + same value, and deletes the redundant ones. 
+ --loop-fission + Splits any top level loops in which the register pressure has + exceeded a given threshold. The threshold must follow the use of + this flag and must be a positive integer value. + --loop-fusion + Identifies adjacent loops with the same lower and upper bound. + If this is legal, then merge the loops into a single loop. + Includes heuristics to ensure it does not increase number of + registers too much, while reducing the number of loads from + memory. Takes an additional positive integer argument to set + the maximum number of registers. + --loop-invariant-code-motion + Identifies code in loops that has the same value for every + iteration of the loop, and move it to the loop pre-header. + --loop-unroll + Fully unrolls loops marked with the Unroll flag + --loop-unroll-partial + Partially unrolls loops marked with the Unroll flag. Takes an + additional non-0 integer argument to set the unroll factor, or + how many times a loop body should be duplicated + --loop-peeling + Execute few first (respectively last) iterations before + (respectively after) the loop if it can elide some branches. + --loop-peeling-threshold + Takes a non-0 integer argument to set the loop peeling code size + growth threshold. The threshold prevents the loop peeling + from happening if the code size increase created by + the optimization is above the threshold. + --max-id-bound= + Sets the maximum value for the id bound for the moudle. The + default is the minimum value for this limit, 0x3FFFFF. See + section 2.17 of the Spir-V specification. + --merge-blocks + Join two blocks into a single block if the second has the + first as its only predecessor. Performed only on entry point + call tree functions. + --merge-return + Changes functions that have multiple return statements so they + have a single return statement. + + For structured control flow it is assumed that the only + unreachable blocks in the function are trivial merge and continue + blocks. 
+ + A trivial merge block contains the label and an OpUnreachable + instructions, nothing else. A trivial continue block contain a + label and an OpBranch to the header, nothing else. + + These conditions are guaranteed to be met after running + dead-branch elimination. + --loop-unswitch + Hoists loop-invariant conditionals out of loops by duplicating + the loop on each branch of the conditional and adjusting each + copy of the loop. + -O + Optimize for performance. Apply a sequence of transformations + in an attempt to improve the performance of the generated + code. For this version of the optimizer, this flag is equivalent + to specifying the following optimization code names: + %s + -Os + Optimize for size. Apply a sequence of transformations in an + attempt to minimize the size of the generated code. For this + version of the optimizer, this flag is equivalent to specifying + the following optimization code names: + %s + + NOTE: The specific transformations done by -O and -Os change + from release to release. + -Oconfig= + Apply the sequence of transformations indicated in . + This file contains a sequence of strings separated by whitespace + (tabs, newlines or blanks). Each string is one of the flags + accepted by spirv-opt. Optimizations will be applied in the + sequence they appear in the file. This is equivalent to + specifying all the flags on the command line. For example, + given the file opts.cfg with the content: + + --inline-entry-points-exhaustive + --eliminate-dead-code-aggressive + + The following two invocations to spirv-opt are equivalent: + + $ spirv-opt -Oconfig=opts.cfg program.spv + + $ spirv-opt --inline-entry-points-exhaustive \ + --eliminate-dead-code-aggressive program.spv + + Lines starting with the character '#' in the configuration + file indicate a comment and will be ignored. + + The -O, -Os, and -Oconfig flags act as macros. 
Using one of them + is equivalent to explicitly inserting the underlying flags at + that position in the command line. For example, the invocation + 'spirv-opt --merge-blocks -O ...' applies the transformation + --merge-blocks followed by all the transformations implied by + -O. + --print-all + Print SPIR-V assembly to standard error output before each pass + and after the last pass. + --private-to-local + Change the scope of private variables that are used in a single + function to that function. + --reduce-load-size + Replaces loads of composite objects where not every component is + used by loads of just the elements that are used. + --redundancy-elimination + Looks for instructions in the same function that compute the + same value, and deletes the redundant ones. + --relax-struct-store + Allow store from one struct type to a different type with + compatible layout and members. This option is forwarded to the + validator. + --remove-duplicates + Removes duplicate types, decorations, capabilities and extension + instructions. + --replace-invalid-opcode + Replaces instructions whose opcode is valid for shader modules, + but not for the current shader stage. To have an effect, all + entry points must have the same execution model. + --ssa-rewrite + Replace loads and stores to function local variables with + operations on SSA IDs. + --scalar-replacement[=] + Replace aggregate function scope variables that are only accessed + via their elements with new function variables representing each + element. is a limit on the size of the aggragates that will + be replaced. 0 means there is no limit. The default value is + 100. + --set-spec-const-default-value ": ..." + Set the default values of the specialization constants with + : pairs specified in a double-quoted + string. : pairs must be separated by + blank spaces, and in each pair, spec id and default value must + be separated with colon ':' without any blank spaces in between. 
+ e.g.: --set-spec-const-default-value "1:100 2:400" + --simplify-instructions + Will simplify all instructions in the function as much as + possible. + --skip-validation + Will not validate the SPIR-V before optimizing. If the SPIR-V + is invalid, the optimizer may fail or generate incorrect code. + This options should be used rarely, and with caution. + --strength-reduction + Replaces instructions with equivalent and less expensive ones. + --strip-debug + Remove all debug instructions. + --strip-reflect + Remove all reflection information. For now, this covers + reflection information defined by SPV_GOOGLE_hlsl_functionality1. + --target-env= + Set the target environment. Without this flag the target + enviroment defaults to spv1.3. + must be one of vulkan1.0, vulkan1.1, opencl2.2, spv1.0, + spv1.1, spv1.2, spv1.3, or webgpu0. + --time-report + Print the resource utilization of each pass (e.g., CPU time, + RSS) to standard error output. Currently it supports only Unix + systems. This option is the same as -ftime-report in GCC. It + prints CPU/WALL/USR/SYS time (and RSS if possible), but note that + USR/SYS time are returned by getrusage() and can have a small + error. + --upgrade-memory-model + Upgrades the Logical GLSL450 memory model to Logical VulkanKHR. + Transforms memory, image, atomic and barrier operations to conform + to that model's requirements. + --vector-dce + This pass looks for components of vectors that are unused, and + removes them from the vector. Note this would still leave around + lots of dead code that a pass of ADCE will be able to remove. + --webgpu-mode + Turns on the prescribed passes for WebGPU and sets the target + environmet to webgpu0. Other passes may be turned on via + additional flags, but such combinations are not tested. + Using --target-env with this flag is not allowed. 
+ + This flag is the equivalent of passing in --target-env=webgpu0 + and specifying the following optimization code names: + %s + + NOTE: This flag is a WIP and its behaviour is subject to change. + --workaround-1209 + Rewrites instructions for which there are known driver bugs to + avoid triggering those bugs. + Current workarounds: Avoid OpUnreachable in loops. + --unify-const + Remove the duplicated constants. + -h, --help + Print this help. + --version + Display optimizer version information. +)", + program, program, GetLegalizationPasses().c_str(), + GetOptimizationPasses().c_str(), GetSizePasses().c_str(), + GetWebGPUPasses().c_str()); +} + +// Reads command-line flags the file specified in |oconfig_flag|. This string +// is assumed to have the form "-Oconfig=FILENAME". This function parses the +// string and extracts the file name after the '=' sign. +// +// Flags found in |FILENAME| are pushed at the end of the vector |file_flags|. +// +// This function returns true on success, false on failure. +bool ReadFlagsFromFile(const char* oconfig_flag, + std::vector* file_flags) { + const char* fname = strchr(oconfig_flag, '='); + if (fname == nullptr || fname[0] != '=') { + spvtools::Errorf(opt_diagnostic, nullptr, {}, "Invalid -Oconfig flag %s", + oconfig_flag); + return false; + } + fname++; + + std::ifstream input_file; + input_file.open(fname); + if (input_file.fail()) { + spvtools::Errorf(opt_diagnostic, nullptr, {}, "Could not open file '%s'", + fname); + return false; + } + + std::string line; + while (std::getline(input_file, line)) { + // Ignore empty lines and lines starting with the comment marker '#'. + if (line.length() == 0 || line[0] == '#') { + continue; + } + + // Tokenize the line. Add all found tokens to the list of found flags. This + // mimics the way the shell will parse whitespace on the command line. NOTE: + // This does not support quoting and it is not intended to. 
+ std::istringstream iss(line); + while (!iss.eof()) { + std::string flag; + iss >> flag; + file_flags->push_back(flag); + } + } + + return true; +} + +OptStatus ParseFlags(int argc, const char** argv, + spvtools::Optimizer* optimizer, const char** in_file, + const char** out_file, + spvtools::ValidatorOptions* validator_options, + spvtools::OptimizerOptions* optimizer_options); + +// Parses and handles the -Oconfig flag. |prog_name| contains the name of +// the spirv-opt binary (used to build a new argv vector for the recursive +// invocation to ParseFlags). |opt_flag| contains the -Oconfig=FILENAME flag. +// |optimizer|, |in_file|, |out_file|, |validator_options|, and +// |optimizer_options| are as in ParseFlags. +// +// This returns the same OptStatus instance returned by ParseFlags. +OptStatus ParseOconfigFlag(const char* prog_name, const char* opt_flag, + spvtools::Optimizer* optimizer, const char** in_file, + const char** out_file, + spvtools::ValidatorOptions* validator_options, + spvtools::OptimizerOptions* optimizer_options) { + std::vector flags; + flags.push_back(prog_name); + + std::vector file_flags; + if (!ReadFlagsFromFile(opt_flag, &file_flags)) { + spvtools::Error(opt_diagnostic, nullptr, {}, + "Could not read optimizer flags from configuration file"); + return {OPT_STOP, 1}; + } + flags.insert(flags.end(), file_flags.begin(), file_flags.end()); + + const char** new_argv = new const char*[flags.size()]; + for (size_t i = 0; i < flags.size(); i++) { + if (flags[i].find("-Oconfig=") != std::string::npos) { + spvtools::Error( + opt_diagnostic, nullptr, {}, + "Flag -Oconfig= may not be used inside the configuration file"); + return {OPT_STOP, 1}; + } + new_argv[i] = flags[i].c_str(); + } + + auto ret_val = + ParseFlags(static_cast(flags.size()), new_argv, optimizer, in_file, + out_file, validator_options, optimizer_options); + delete[] new_argv; + return ret_val; +} + +// Canonicalize the flag in |argv[argi]| of the form '--pass arg' into +// 
// Rewrites the legacy two-token form '--pass arg' found at |argv[*argi]| into
// the single-token form '--pass=arg', which is the only form the optimizer
// accepts for pass arguments.
//
// Only the flags listed below take an argument this way. A flag that needs
// several values still receives them as one string and is expected to split
// it itself (as --set-spec-const-default-value does).
//
// When |argv[*argi]| is one of those flags and a following argument exists,
// |*argi| is advanced past that argument and "flag=argument" is returned.
// Otherwise |*argi| is left alone and the flag is returned unchanged.
std::string CanonicalizeFlag(const char** argv, int argc, int* argi) {
  // NOTE: DO NOT ADD NEW FLAGS HERE.
  //
  // This list exists for backwards compatibility only. New passes that need
  // an argument should use the syntax "--pass_name[=pass_arg]".
  static const char* const kLegacyTwoTokenFlags[] = {
      "--set-spec-const-default-value",
      "--loop-fission",
      "--loop-fusion",
      "--loop-unroll-partial",
      "--loop-peeling-threshold",
  };

  const std::string flag = argv[*argi];

  bool takes_separate_arg = false;
  for (const char* legacy : kLegacyTwoTokenFlags) {
    if (flag == legacy) {
      takes_separate_arg = true;
      break;
    }
  }

  const int value_index = *argi + 1;
  const char* value = (value_index < argc) ? argv[value_index] : nullptr;
  if (!takes_separate_arg || value == nullptr) {
    return flag;
  }

  *argi = value_index;
  return flag + "=" + value;
}
+// The name of the output file in |out_file|. The return value indicates whether +// optimization should continue and a status code indicating an error or +// success. +OptStatus ParseFlags(int argc, const char** argv, + spvtools::Optimizer* optimizer, const char** in_file, + const char** out_file, + spvtools::ValidatorOptions* validator_options, + spvtools::OptimizerOptions* optimizer_options) { + std::vector pass_flags; + bool target_env_set = false; + bool webgpu_mode_set = false; + for (int argi = 1; argi < argc; ++argi) { + const char* cur_arg = argv[argi]; + if ('-' == cur_arg[0]) { + if (0 == strcmp(cur_arg, "--version")) { + spvtools::Logf(opt_diagnostic, SPV_MSG_INFO, nullptr, {}, "%s\n", + spvSoftwareVersionDetailsString()); + return {OPT_STOP, 0}; + } else if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) { + PrintUsage(argv[0]); + return {OPT_STOP, 0}; + } else if (0 == strcmp(cur_arg, "-o")) { + if (!*out_file && argi + 1 < argc) { + *out_file = argv[++argi]; + } else { + PrintUsage(argv[0]); + return {OPT_STOP, 1}; + } + } else if ('\0' == cur_arg[1]) { + // Setting a filename of "-" to indicate stdin. 
+ if (!*in_file) { + *in_file = cur_arg; + } else { + spvtools::Error(opt_diagnostic, nullptr, {}, + "More than one input file specified"); + return {OPT_STOP, 1}; + } + } else if (0 == strncmp(cur_arg, "-Oconfig=", sizeof("-Oconfig=") - 1)) { + OptStatus status = + ParseOconfigFlag(argv[0], cur_arg, optimizer, in_file, out_file, + validator_options, optimizer_options); + if (status.action != OPT_CONTINUE) { + return status; + } + } else if (0 == strcmp(cur_arg, "--skip-validation")) { + optimizer_options->set_run_validator(false); + } else if (0 == strcmp(cur_arg, "--print-all")) { + optimizer->SetPrintAll(&std::cerr); + } else if (0 == strcmp(cur_arg, "--time-report")) { + optimizer->SetTimeReport(&std::cerr); + } else if (0 == strcmp(cur_arg, "--relax-struct-store")) { + validator_options->SetRelaxStructStore(true); + } else if (0 == strncmp(cur_arg, "--max-id-bound=", + sizeof("--max-id-bound=") - 1)) { + auto split_flag = spvtools::utils::SplitFlagArgs(cur_arg); + // Will not allow values in the range [2^31,2^32). + uint32_t max_id_bound = + static_cast(atoi(split_flag.second.c_str())); + + // That SPIR-V mandates the minimum value for max id bound but + // implementations may allow higher minimum bounds. 
+ if (max_id_bound < kDefaultMaxIdBound) { + spvtools::Error(opt_diagnostic, nullptr, {}, + "The max id bound must be at least 0x3FFFFF"); + return {OPT_STOP, 1}; + } + optimizer_options->set_max_id_bound(max_id_bound); + validator_options->SetUniversalLimit(spv_validator_limit_max_id_bound, + max_id_bound); + } else if (0 == strncmp(cur_arg, + "--target-env=", sizeof("--target-env=") - 1)) { + if (webgpu_mode_set) { + spvtools::Error(opt_diagnostic, nullptr, {}, + "Cannot use both --webgpu-mode and --target-env at " + "the same time"); + return {OPT_STOP, 1}; + } + const auto split_flag = spvtools::utils::SplitFlagArgs(cur_arg); + const auto target_env_str = split_flag.second.c_str(); + spv_target_env target_env; + if (!spvParseTargetEnv(target_env_str, &target_env)) { + spvtools::Error(opt_diagnostic, nullptr, {}, + "Invalid value passed to --target-env"); + return {OPT_STOP, 1}; + } + optimizer->SetTargetEnv(target_env); + } else if (0 == strcmp(cur_arg, "--webgpu-mode")) { + if (target_env_set) { + spvtools::Error(opt_diagnostic, nullptr, {}, + "Cannot use both --webgpu-mode and --target-env at " + "the same time"); + return {OPT_STOP, 1}; + } + + optimizer->SetTargetEnv(SPV_ENV_WEBGPU_0); + optimizer->RegisterWebGPUPasses(); + } else { + // Some passes used to accept the form '--pass arg', canonicalize them + // to '--pass=arg'. + pass_flags.push_back(CanonicalizeFlag(argv, argc, &argi)); + + // If we were requested to legalize SPIR-V generated from the HLSL + // front-end, skip validation. 
+ if (0 == strcmp(cur_arg, "--legalize-hlsl")) { + validator_options->SetRelaxLogicalPointer(true); + } + } + } else { + if (!*in_file) { + *in_file = cur_arg; + } else { + spvtools::Error(opt_diagnostic, nullptr, {}, + "More than one input file specified"); + return {OPT_STOP, 1}; + } + } + } + + if (!optimizer->RegisterPassesFromFlags(pass_flags)) { + return {OPT_STOP, 1}; + } + + return {OPT_CONTINUE, 0}; +} + +} // namespace + +int main(int argc, const char** argv) { + const char* in_file = nullptr; + const char* out_file = nullptr; + + spv_target_env target_env = kDefaultEnvironment; + + spvtools::Optimizer optimizer(target_env); + optimizer.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); + + spvtools::ValidatorOptions validator_options; + spvtools::OptimizerOptions optimizer_options; + OptStatus status = ParseFlags(argc, argv, &optimizer, &in_file, &out_file, + &validator_options, &optimizer_options); + optimizer_options.set_validator_options(validator_options); + + if (status.action == OPT_STOP) { + return status.code; + } + + if (out_file == nullptr) { + spvtools::Error(opt_diagnostic, nullptr, {}, "-o required"); + return 1; + } + + std::vector binary; + if (!ReadFile(in_file, "rb", &binary)) { + return 1; + } + + // By using the same vector as input and output, we save time in the case + // that there was no change. + bool ok = + optimizer.Run(binary.data(), binary.size(), &binary, optimizer_options); + + if (!WriteFile(out_file, "wb", binary.data(), binary.size())) { + return 1; + } + + return ok ? 0 : 1; +} diff --git a/tools/reduce/reduce.cpp b/tools/reduce/reduce.cpp new file mode 100644 index 000000000..65325f745 --- /dev/null +++ b/tools/reduce/reduce.cpp @@ -0,0 +1,242 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "source/opt/build_module.h" +#include "source/opt/ir_context.h" +#include "source/opt/log.h" +#include "source/reduce/operand_to_const_reduction_pass.h" +#include "source/reduce/operand_to_dominating_id_reduction_pass.h" +#include "source/reduce/operand_to_undef_reduction_pass.h" +#include "source/reduce/reducer.h" +#include "source/reduce/remove_opname_instruction_reduction_pass.h" +#include "source/reduce/remove_unreferenced_instruction_reduction_pass.h" +#include "source/reduce/structured_loop_to_selection_reduction_pass.h" +#include "source/spirv_reducer_options.h" +#include "source/util/make_unique.h" +#include "source/util/string_utils.h" +#include "spirv-tools/libspirv.hpp" +#include "tools/io.h" +#include "tools/util/cli_consumer.h" + +using namespace spvtools::reduce; + +namespace { + +using ErrorOrInt = std::pair; + +// Check that the std::system function can actually be used. +bool CheckExecuteCommand() { + int res = std::system(nullptr); + return res != 0; +} + +// Execute a command using the shell. +// Returns true if and only if the command's exit status was 0. +bool ExecuteCommand(const std::string& command) { + errno = 0; + int status = std::system(command.c_str()); + assert(errno == 0 && "failed to execute command"); + // The result returned by 'system' is implementation-defined, but is + // usually the case that the returned value is 0 when the command's exit + // code was 0. We are assuming that here, and that's all we depend on. 
+ return status == 0; +} + +// Status and actions to perform after parsing command-line arguments. +enum ReduceActions { REDUCE_CONTINUE, REDUCE_STOP }; + +struct ReduceStatus { + ReduceActions action; + int code; +}; + +void PrintUsage(const char* program) { + // NOTE: Please maintain flags in lexicographical order. + printf( + R"(%s - Reduce a SPIR-V binary file with respect to a user-provided + interestingness test. + +USAGE: %s [options] + +The SPIR-V binary is read from . + +Whether a binary is interesting is determined by , which +is typically a script. + +NOTE: The reducer is a work in progress. + +Options (in lexicographical order): + -h, --help + Print this help. + --step-limit + 32-bit unsigned integer specifying maximum number of + steps the reducer will take before giving up. + --version + Display reducer version information. +)", + program, program); +} + +// Message consumer for this tool. Used to emit diagnostics during +// initialization and setup. Note that |source| and |position| are irrelevant +// here because we are still not processing a SPIR-V input file. 
+void ReduceDiagnostic(spv_message_level_t level, const char* /*source*/, + const spv_position_t& /*position*/, const char* message) { + if (level == SPV_MSG_ERROR) { + fprintf(stderr, "error: "); + } + fprintf(stderr, "%s\n", message); +} + +ReduceStatus ParseFlags(int argc, const char** argv, const char** in_file, + const char** interestingness_test, + spvtools::ReducerOptions* reducer_options) { + uint32_t positional_arg_index = 0; + + for (int argi = 1; argi < argc; ++argi) { + const char* cur_arg = argv[argi]; + if ('-' == cur_arg[0]) { + if (0 == strcmp(cur_arg, "--version")) { + spvtools::Logf(ReduceDiagnostic, SPV_MSG_INFO, nullptr, {}, "%s\n", + spvSoftwareVersionDetailsString()); + return {REDUCE_STOP, 0}; + } else if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) { + PrintUsage(argv[0]); + return {REDUCE_STOP, 0}; + } else if ('\0' == cur_arg[1]) { + // We do not support reduction from standard input. We could support + // this if there was a compelling use case. 
+ PrintUsage(argv[0]); + return {REDUCE_STOP, 0}; + } else if (0 == strncmp(cur_arg, + "--step-limit=", sizeof("--step-limit=") - 1)) { + const auto split_flag = spvtools::utils::SplitFlagArgs(cur_arg); + char* end = nullptr; + errno = 0; + const auto step_limit = + static_cast(strtol(split_flag.second.c_str(), &end, 10)); + assert(end != split_flag.second.c_str() && errno == 0); + reducer_options->set_step_limit(step_limit); + } + } else if (positional_arg_index == 0) { + // Input file name + assert(!*in_file); + *in_file = cur_arg; + positional_arg_index++; + } else if (positional_arg_index == 1) { + assert(!*interestingness_test); + *interestingness_test = cur_arg; + positional_arg_index++; + } else { + spvtools::Error(ReduceDiagnostic, nullptr, {}, + "Too many positional arguments specified"); + return {REDUCE_STOP, 1}; + } + } + + if (!*in_file) { + spvtools::Error(ReduceDiagnostic, nullptr, {}, "No input file specified"); + return {REDUCE_STOP, 1}; + } + + if (!*interestingness_test) { + spvtools::Error(ReduceDiagnostic, nullptr, {}, + "No interestingness test specified"); + return {REDUCE_STOP, 1}; + } + + return {REDUCE_CONTINUE, 0}; +} + +} // namespace + +const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_3; + +int main(int argc, const char** argv) { + const char* in_file = nullptr; + const char* interestingness_test = nullptr; + + spv_target_env target_env = kDefaultEnvironment; + spvtools::ReducerOptions reducer_options; + + ReduceStatus status = + ParseFlags(argc, argv, &in_file, &interestingness_test, &reducer_options); + + if (status.action == REDUCE_STOP) { + return status.code; + } + + if (!CheckExecuteCommand()) { + std::cerr << "could not find shell interpreter for executing a command" + << std::endl; + return 2; + } + + Reducer reducer(target_env); + + reducer.SetInterestingnessFunction( + [interestingness_test](std::vector binary, + uint32_t reductions_applied) -> bool { + std::stringstream ss; + ss << "temp_" << std::setw(4) << 
std::setfill('0') << reductions_applied + << ".spv"; + const auto spv_file = ss.str(); + const std::string command = + std::string(interestingness_test) + " " + spv_file; + auto write_file_succeeded = + WriteFile(spv_file.c_str(), "wb", &binary[0], binary.size()); + (void)(write_file_succeeded); + assert(write_file_succeeded); + return ExecuteCommand(command); + }); + + reducer.AddReductionPass( + spvtools::MakeUnique(target_env)); + reducer.AddReductionPass( + spvtools::MakeUnique(target_env)); + reducer.AddReductionPass( + spvtools::MakeUnique(target_env)); + reducer.AddReductionPass( + spvtools::MakeUnique(target_env)); + reducer.AddReductionPass( + spvtools::MakeUnique( + target_env)); + reducer.AddReductionPass( + spvtools::MakeUnique(target_env)); + + reducer.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); + + std::vector binary_in; + if (!ReadFile(in_file, "rb", &binary_in)) { + return 1; + } + + std::vector binary_out; + const auto reduction_status = + reducer.Run(std::move(binary_in), &binary_out, reducer_options); + + if (reduction_status == + Reducer::ReductionResultStatus::kInitialStateNotInteresting || + !WriteFile("_reduced_final.spv", "wb", binary_out.data(), + binary_out.size())) { + return 1; + } + + return 0; +} diff --git a/tools/stats/spirv_stats.cpp b/tools/stats/spirv_stats.cpp new file mode 100644 index 000000000..609a6c972 --- /dev/null +++ b/tools/stats/spirv_stats.cpp @@ -0,0 +1,165 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/stats/spirv_stats.h" + +#include + +#include +#include +#include + +#include "source/diagnostic.h" +#include "source/enum_string_mapping.h" +#include "source/extensions.h" +#include "source/id_descriptor.h" +#include "source/instruction.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/val/instruction.h" +#include "source/val/validate.h" +#include "source/val/validation_state.h" +#include "spirv-tools/libspirv.h" + +namespace spvtools { +namespace stats { +namespace { + +// Helper class for stats aggregation. Receives as in/out parameter. +// Constructs ValidationState and updates it by running validator for each +// instruction. +class StatsAggregator { + public: + StatsAggregator(SpirvStats* in_out_stats, const val::ValidationState_t* state) + : stats_(in_out_stats), vstate_(state) {} + + // Processes the instructions to collect stats. + void aggregate() { + const auto& instructions = vstate_->ordered_instructions(); + + ++stats_->version_hist[vstate_->version()]; + ++stats_->generator_hist[vstate_->generator()]; + + for (size_t i = 0; i < instructions.size(); ++i) { + const auto& inst = instructions[i]; + + ProcessOpcode(&inst, i); + ProcessCapability(&inst); + ProcessExtension(&inst); + ProcessConstant(&inst); + } + } + + // Collects OpCapability statistics. + void ProcessCapability(const val::Instruction* inst) { + if (inst->opcode() != SpvOpCapability) return; + const uint32_t capability = inst->word(inst->operands()[0].offset); + ++stats_->capability_hist[capability]; + } + + // Collects OpExtension statistics. + void ProcessExtension(const val::Instruction* inst) { + if (inst->opcode() != SpvOpExtension) return; + const std::string extension = GetExtensionString(&inst->c_inst()); + ++stats_->extension_hist[extension]; + } + + // Collects OpCode statistics. 
+ void ProcessOpcode(const val::Instruction* inst, size_t idx) { + const SpvOp opcode = inst->opcode(); + ++stats_->opcode_hist[opcode]; + + if (idx == 0) return; + + --idx; + + const auto& instructions = vstate_->ordered_instructions(); + + auto step_it = stats_->opcode_markov_hist.begin(); + for (; step_it != stats_->opcode_markov_hist.end(); --idx, ++step_it) { + auto& hist = (*step_it)[instructions[idx].opcode()]; + ++hist[opcode]; + + if (idx == 0) break; + } + } + + // Collects OpConstant statistics. + void ProcessConstant(const val::Instruction* inst) { + if (inst->opcode() != SpvOpConstant) return; + + const uint32_t type_id = inst->GetOperandAs(0); + const auto type_decl_it = vstate_->all_definitions().find(type_id); + assert(type_decl_it != vstate_->all_definitions().end()); + + const val::Instruction& type_decl_inst = *type_decl_it->second; + const SpvOp type_op = type_decl_inst.opcode(); + if (type_op == SpvOpTypeInt) { + const uint32_t bit_width = type_decl_inst.GetOperandAs(1); + const uint32_t is_signed = type_decl_inst.GetOperandAs(2); + assert(is_signed == 0 || is_signed == 1); + if (bit_width == 16) { + if (is_signed) + ++stats_->s16_constant_hist[inst->GetOperandAs(2)]; + else + ++stats_->u16_constant_hist[inst->GetOperandAs(2)]; + } else if (bit_width == 32) { + if (is_signed) + ++stats_->s32_constant_hist[inst->GetOperandAs(2)]; + else + ++stats_->u32_constant_hist[inst->GetOperandAs(2)]; + } else if (bit_width == 64) { + if (is_signed) + ++stats_->s64_constant_hist[inst->GetOperandAs(2)]; + else + ++stats_->u64_constant_hist[inst->GetOperandAs(2)]; + } else { + assert(false && "TypeInt bit width is not 16, 32 or 64"); + } + } else if (type_op == SpvOpTypeFloat) { + const uint32_t bit_width = type_decl_inst.GetOperandAs(1); + if (bit_width == 32) { + ++stats_->f32_constant_hist[inst->GetOperandAs(2)]; + } else if (bit_width == 64) { + ++stats_->f64_constant_hist[inst->GetOperandAs(2)]; + } else { + assert(bit_width == 16); + } + } + } + + 
private: + SpirvStats* stats_; + const val::ValidationState_t* vstate_; + IdDescriptorCollection id_descriptors_; +}; + +} // namespace + +spv_result_t AggregateStats(const spv_context context, const uint32_t* words, + const size_t num_words, spv_diagnostic* pDiagnostic, + SpirvStats* stats) { + std::unique_ptr vstate; + spv_validator_options_t options; + spv_result_t result = ValidateBinaryAndKeepValidationState( + context, &options, words, num_words, pDiagnostic, &vstate); + if (result != SPV_SUCCESS) return result; + + StatsAggregator stats_aggregator(stats, vstate.get()); + stats_aggregator.aggregate(); + return SPV_SUCCESS; +} + +} // namespace stats +} // namespace spvtools diff --git a/tools/stats/spirv_stats.h b/tools/stats/spirv_stats.h new file mode 100644 index 000000000..757695775 --- /dev/null +++ b/tools/stats/spirv_stats.h @@ -0,0 +1,93 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOOLS_STATS_SPIRV_STATS_H_ +#define TOOLS_STATS_SPIRV_STATS_H_ + +#include +#include +#include +#include +#include + +#include "spirv-tools/libspirv.hpp" + +namespace spvtools { +namespace stats { + +struct SpirvStats { + // Version histogram, version_word -> count. + std::unordered_map version_hist; + + // Generator histogram, generator_word -> count. + std::unordered_map generator_hist; + + // Capability histogram, SpvCapabilityXXX -> count. 
+ std::unordered_map capability_hist; + + // Extension histogram, extension_string -> count. + std::unordered_map extension_hist; + + // Opcode histogram, SpvOpXXX -> count. + std::unordered_map opcode_hist; + + // OpConstant u16 histogram, value -> count. + std::unordered_map u16_constant_hist; + + // OpConstant u32 histogram, value -> count. + std::unordered_map u32_constant_hist; + + // OpConstant u64 histogram, value -> count. + std::unordered_map u64_constant_hist; + + // OpConstant s16 histogram, value -> count. + std::unordered_map s16_constant_hist; + + // OpConstant s32 histogram, value -> count. + std::unordered_map s32_constant_hist; + + // OpConstant s64 histogram, value -> count. + std::unordered_map s64_constant_hist; + + // OpConstant f32 histogram, value -> count. + std::unordered_map f32_constant_hist; + + // OpConstant f64 histogram, value -> count. + std::unordered_map f64_constant_hist; + + // Used to collect statistics on opcodes triggering other opcodes. + // Container scheme: gap between instructions -> cue opcode -> later opcode + // -> count. + // For example opcode_markov_hist[2][OpFMul][OpFAdd] corresponds to + // the number of times an OpFMul appears, followed by 2 other instructions, + // followed by OpFAdd. + // opcode_markov_hist[0][OpFMul][OpFAdd] corresponds to how many times + // OpFMul appears, directly followed by OpFAdd. + // The size of the outer std::vector also serves as an input parameter, + // determining how many steps will be collected. + // I.e. do opcode_markov_hist.resize(1) to collect data for one step only. + std::vector< + std::unordered_map>> + opcode_markov_hist; +}; + +// Aggregates existing |stats| with new stats extracted from the binary in |words|.
+spv_result_t AggregateStats(const spv_context context, const uint32_t* words, + const size_t num_words, spv_diagnostic* pDiagnostic, + SpirvStats* stats); + +} // namespace stats +} // namespace spvtools + +#endif // TOOLS_STATS_SPIRV_STATS_H_ diff --git a/tools/stats/stats.cpp b/tools/stats/stats.cpp new file mode 100644 index 000000000..30e3bccd6 --- /dev/null +++ b/tools/stats/stats.cpp @@ -0,0 +1,173 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "spirv-tools/libspirv.h" +#include "tools/io.h" +#include "tools/stats/spirv_stats.h" +#include "tools/stats/stats_analyzer.h" + +namespace { + +void PrintUsage(char* argv0) { + printf( + R"(%s - Collect statistics from one or more SPIR-V binary file(s). + +USAGE: %s [options] [] + +TIP: In order to collect statistics from all .spv files under current dir use +find . -name "*.spv" -print0 | xargs -0 -s 2000000 %s + +Options: + -h, --help + Print this help. + + -v, --verbose + Print additional info to stderr. 
+)", + argv0, argv0, argv0); +} + +void DiagnosticsMessageHandler(spv_message_level_t level, const char*, + const spv_position_t& position, + const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + std::cerr << "error: " << position.index << ": " << message << std::endl; + break; + case SPV_MSG_WARNING: + std::cout << "warning: " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + std::cout << "info: " << position.index << ": " << message << std::endl; + break; + default: + break; + } +} + +} // namespace + +int main(int argc, char** argv) { + bool continue_processing = true; + int return_code = 0; + + bool expect_output_path = false; + bool verbose = false; + + std::vector paths; + const char* output_path = nullptr; + + for (int argi = 1; continue_processing && argi < argc; ++argi) { + const char* cur_arg = argv[argi]; + if ('-' == cur_arg[0]) { + if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) { + PrintUsage(argv[0]); + continue_processing = false; + return_code = 0; + } else if (0 == strcmp(cur_arg, "--verbose") || + 0 == strcmp(cur_arg, "-v")) { + verbose = true; + } else if (0 == strcmp(cur_arg, "--output") || + 0 == strcmp(cur_arg, "-o")) { + expect_output_path = true; + } else { + PrintUsage(argv[0]); + continue_processing = false; + return_code = 1; + } + } else { + if (expect_output_path) { + output_path = cur_arg; + expect_output_path = false; + } else { + paths.push_back(cur_arg); + } + } + } + + // Exit if command line parsing was not successful. + if (!continue_processing) { + return return_code; + } + + std::cerr << "Processing " << paths.size() << " files..." 
<< std::endl; + + spvtools::Context ctx(SPV_ENV_UNIVERSAL_1_1); + ctx.SetMessageConsumer(DiagnosticsMessageHandler); + + spvtools::stats::SpirvStats stats; + stats.opcode_markov_hist.resize(1); + + for (size_t index = 0; index < paths.size(); ++index) { + const size_t kMilestonePeriod = 1000; + if (verbose) { + if (index % kMilestonePeriod == kMilestonePeriod - 1) + std::cerr << "Processed " << index + 1 << " files..." << std::endl; + } + + const char* path = paths[index]; + std::vector contents; + if (!ReadFile(path, "rb", &contents)) return 1; + + if (SPV_SUCCESS != + spvtools::stats::AggregateStats(ctx.CContext(), contents.data(), + contents.size(), nullptr, &stats)) { + std::cerr << "error: Failed to aggregate stats for " << path << std::endl; + return 1; + } + } + + spvtools::stats::StatsAnalyzer analyzer(stats); + + std::ofstream fout; + if (output_path) { + fout.open(output_path); + if (!fout.is_open()) { + std::cerr << "error: Failed to open " << output_path << std::endl; + return 1; + } + } + + std::ostream& out = fout.is_open() ? fout : std::cout; + out << std::endl; + analyzer.WriteVersion(out); + analyzer.WriteGenerator(out); + + out << std::endl; + analyzer.WriteCapability(out); + + out << std::endl; + analyzer.WriteExtension(out); + + out << std::endl; + analyzer.WriteOpcode(out); + + out << std::endl; + analyzer.WriteOpcodeMarkov(out); + + out << std::endl; + analyzer.WriteConstantLiterals(out); + + return 0; +} diff --git a/tools/stats/stats_analyzer.cpp b/tools/stats/stats_analyzer.cpp new file mode 100644 index 000000000..6d4cabbf6 --- /dev/null +++ b/tools/stats/stats_analyzer.cpp @@ -0,0 +1,235 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/stats/stats_analyzer.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "source/comp/markv_model.h" +#include "source/enum_string_mapping.h" +#include "source/latest_version_spirv_header.h" +#include "source/opcode.h" +#include "source/operand.h" +#include "source/spirv_constant.h" + +namespace spvtools { +namespace stats { +namespace { + +// Signals that the value is not in the coding scheme and a fallback method +// needs to be used. +const uint64_t kMarkvNoneOfTheAbove = + comp::MarkvModel::GetMarkvNoneOfTheAbove(); + +std::string GetVersionString(uint32_t word) { + std::stringstream ss; + ss << "Version " << SPV_SPIRV_VERSION_MAJOR_PART(word) << "." 
+ << SPV_SPIRV_VERSION_MINOR_PART(word); + return ss.str(); +} + +std::string GetGeneratorString(uint32_t word) { + return spvGeneratorStr(SPV_GENERATOR_TOOL_PART(word)); +} + +std::string GetOpcodeString(uint32_t word) { + return spvOpcodeString(static_cast(word)); +} + +std::string GetCapabilityString(uint32_t word) { + return CapabilityToString(static_cast(word)); +} + +template +std::string KeyIsLabel(T key) { + std::stringstream ss; + ss << key; + return ss.str(); +} + +template +std::unordered_map GetRecall( + const std::unordered_map& hist, uint64_t total) { + std::unordered_map freq; + for (const auto& pair : hist) { + const double frequency = + static_cast(pair.second) / static_cast(total); + freq.emplace(pair.first, frequency); + } + return freq; +} + +template +std::unordered_map GetPrevalence( + const std::unordered_map& hist) { + uint64_t total = 0; + for (const auto& pair : hist) { + total += pair.second; + } + + return GetRecall(hist, total); +} + +// Writes |freq| to |out| sorted by frequency in the following format: +// LABEL3 70% +// LABEL1 20% +// LABEL2 10% +// |label_from_key| is used to convert |Key| to label. 
+template +void WriteFreq(std::ostream& out, const std::unordered_map& freq, + std::string (*label_from_key)(Key)) { + std::vector> sorted_freq(freq.begin(), freq.end()); + std::sort(sorted_freq.begin(), sorted_freq.end(), + [](const std::pair& left, + const std::pair& right) { + return left.second > right.second; + }); + + for (const auto& pair : sorted_freq) { + if (pair.second < 0.001) break; + out << label_from_key(pair.first) << " " << pair.second * 100.0 << "%" + << std::endl; + } +} + +} // namespace + +StatsAnalyzer::StatsAnalyzer(const SpirvStats& stats) : stats_(stats) { + num_modules_ = 0; + for (const auto& pair : stats_.version_hist) { + num_modules_ += pair.second; + } + + version_freq_ = GetRecall(stats_.version_hist, num_modules_); + generator_freq_ = GetRecall(stats_.generator_hist, num_modules_); + capability_freq_ = GetRecall(stats_.capability_hist, num_modules_); + extension_freq_ = GetRecall(stats_.extension_hist, num_modules_); + opcode_freq_ = GetPrevalence(stats_.opcode_hist); +} + +void StatsAnalyzer::WriteVersion(std::ostream& out) { + WriteFreq(out, version_freq_, GetVersionString); +} + +void StatsAnalyzer::WriteGenerator(std::ostream& out) { + WriteFreq(out, generator_freq_, GetGeneratorString); +} + +void StatsAnalyzer::WriteCapability(std::ostream& out) { + WriteFreq(out, capability_freq_, GetCapabilityString); +} + +void StatsAnalyzer::WriteExtension(std::ostream& out) { + WriteFreq(out, extension_freq_, KeyIsLabel); +} + +void StatsAnalyzer::WriteOpcode(std::ostream& out) { + out << "Total unique opcodes used: " << opcode_freq_.size() << std::endl; + WriteFreq(out, opcode_freq_, GetOpcodeString); +} + +void StatsAnalyzer::WriteConstantLiterals(std::ostream& out) { + out << "Constant literals" << std::endl; + + out << "Float 32" << std::endl; + WriteFreq(out, GetPrevalence(stats_.f32_constant_hist), KeyIsLabel); + + out << std::endl << "Float 64" << std::endl; + WriteFreq(out, GetPrevalence(stats_.f64_constant_hist), KeyIsLabel); + + 
out << std::endl << "Unsigned int 16" << std::endl; + WriteFreq(out, GetPrevalence(stats_.u16_constant_hist), KeyIsLabel); + + out << std::endl << "Signed int 16" << std::endl; + WriteFreq(out, GetPrevalence(stats_.s16_constant_hist), KeyIsLabel); + + out << std::endl << "Unsigned int 32" << std::endl; + WriteFreq(out, GetPrevalence(stats_.u32_constant_hist), KeyIsLabel); + + out << std::endl << "Signed int 32" << std::endl; + WriteFreq(out, GetPrevalence(stats_.s32_constant_hist), KeyIsLabel); + + out << std::endl << "Unsigned int 64" << std::endl; + WriteFreq(out, GetPrevalence(stats_.u64_constant_hist), KeyIsLabel); + + out << std::endl << "Signed int 64" << std::endl; + WriteFreq(out, GetPrevalence(stats_.s64_constant_hist), KeyIsLabel); +} + +void StatsAnalyzer::WriteOpcodeMarkov(std::ostream& out) { + if (stats_.opcode_markov_hist.empty()) return; + + const std::unordered_map>& + cue_to_hist = stats_.opcode_markov_hist[0]; + + // Sort by prevalence of the opcodes in opcode_freq_ (descending). 
+ std::vector>> + sorted_cue_to_hist(cue_to_hist.begin(), cue_to_hist.end()); + std::sort( + sorted_cue_to_hist.begin(), sorted_cue_to_hist.end(), + [this](const std::pair>& + left, + const std::pair>& + right) { + const double lf = opcode_freq_[left.first]; + const double rf = opcode_freq_[right.first]; + if (lf == rf) return right.first > left.first; + return lf > rf; + }); + + for (const auto& kv : sorted_cue_to_hist) { + const uint32_t cue = kv.first; + const double kFrequentEnoughToAnalyze = 0.0001; + if (opcode_freq_[cue] < kFrequentEnoughToAnalyze) continue; + + const std::unordered_map& hist = kv.second; + + uint32_t total = 0; + for (const auto& pair : hist) { + total += pair.second; + } + + std::vector> sorted_hist(hist.begin(), + hist.end()); + std::sort(sorted_hist.begin(), sorted_hist.end(), + [](const std::pair& left, + const std::pair& right) { + if (left.second == right.second) + return right.first > left.first; + return left.second > right.second; + }); + + for (const auto& pair : sorted_hist) { + const double prior = opcode_freq_[pair.first]; + const double posterior = + static_cast(pair.second) / static_cast(total); + out << GetOpcodeString(cue) << " -> " << GetOpcodeString(pair.first) + << " " << posterior * 100 << "% (base rate " << prior * 100 + << "%, pair occurrences " << pair.second << ")" << std::endl; + } + } +} + +} // namespace stats +} // namespace spvtools diff --git a/tools/stats/stats_analyzer.h b/tools/stats/stats_analyzer.h new file mode 100644 index 000000000..f1c37bfaa --- /dev/null +++ b/tools/stats/stats_analyzer.h @@ -0,0 +1,58 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOOLS_STATS_STATS_ANALYZER_H_ +#define TOOLS_STATS_STATS_ANALYZER_H_ + +#include +#include + +#include "tools/stats/spirv_stats.h" + +namespace spvtools { +namespace stats { + +class StatsAnalyzer { + public: + explicit StatsAnalyzer(const SpirvStats& stats); + + // Writes respective histograms to |out|. + void WriteVersion(std::ostream& out); + void WriteGenerator(std::ostream& out); + void WriteCapability(std::ostream& out); + void WriteExtension(std::ostream& out); + void WriteOpcode(std::ostream& out); + void WriteConstantLiterals(std::ostream& out); + + // Writes first order Markov analysis to |out|. + // stats_.opcode_markov_hist needs to contain raw data for at least one + // level. + void WriteOpcodeMarkov(std::ostream& out); + + private: + const SpirvStats& stats_; + + uint32_t num_modules_; + + std::unordered_map version_freq_; + std::unordered_map generator_freq_; + std::unordered_map capability_freq_; + std::unordered_map extension_freq_; + std::unordered_map opcode_freq_; +}; + +} // namespace stats +} // namespace spvtools + +#endif // TOOLS_STATS_STATS_ANALYZER_H_ diff --git a/tools/util/cli_consumer.cpp b/tools/util/cli_consumer.cpp new file mode 100644 index 000000000..77db734e8 --- /dev/null +++ b/tools/util/cli_consumer.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/util/cli_consumer.h" + +#include + +namespace spvtools { +namespace utils { + +void CLIMessageConsumer(spv_message_level_t level, const char*, + const spv_position_t& position, const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + std::cerr << "error: line " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_WARNING: + std::cout << "warning: line " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + std::cout << "info: line " << position.index << ": " << message + << std::endl; + break; + default: + break; + } +} + +} // namespace utils +} // namespace spvtools diff --git a/tools/util/cli_consumer.h b/tools/util/cli_consumer.h new file mode 100644 index 000000000..ca3d91b95 --- /dev/null +++ b/tools/util/cli_consumer.h @@ -0,0 +1,31 @@ +// Copyright (c) 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SOURCE_UTIL_CLI_CONSUMMER_H_ +#define SOURCE_UTIL_CLI_CONSUMMER_H_ + +#include + +namespace spvtools { +namespace utils { + +// A message consumer that can be used by command line tools like spirv-opt and +// spirv-val to display messages. +void CLIMessageConsumer(spv_message_level_t level, const char*, + const spv_position_t& position, const char* message); + +} // namespace utils +} // namespace spvtools + +#endif // SOURCE_UTIL_CLI_CONSUMMER_H_ diff --git a/tools/val/val.cpp b/tools/val/val.cpp new file mode 100644 index 000000000..8b1d048fd --- /dev/null +++ b/tools/val/val.cpp @@ -0,0 +1,182 @@ +// Copyright (c) 2015-2016 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "source/spirv_target_env.h" +#include "source/spirv_validator_options.h" +#include "spirv-tools/libspirv.hpp" +#include "tools/io.h" +#include "tools/util/cli_consumer.h" + +void print_usage(char* argv0) { + printf( + R"(%s - Validate a SPIR-V binary file. + +USAGE: %s [options] [] + +The SPIR-V binary is read from . If no file is specified, +or if the filename is "-", then the binary is read from standard input. + +NOTE: The validator is a work in progress. + +Options: + -h, --help Print this help. 
+ --max-struct-members + --max-struct-depth + --max-local-variables + --max-global-variables + --max-switch-branches + --max-function-args + --max-control-flow-nesting-depth + --max-access-chain-indexes + --max-id-bound + --relax-logical-pointer Allow allocating an object of a pointer type and returning + a pointer value from a function in logical addressing mode + --relax-block-layout Enable VK_KHR_relaxed_block_layout when checking standard + uniform, storage buffer, and push constant layouts. + This is the default when targeting Vulkan 1.1 or later. + --scalar-block-layout Enable VK_EXT_scalar_block_layout when checking standard + uniform, storage buffer, and push constant layouts. Scalar layout + rules are more permissive than relaxed block layout so in effect + this will override the --relax-block-layout option. + --skip-block-layout Skip checking standard uniform/storage buffer layout. + Overrides any --relax-block-layout or --scalar-block-layout option. + --relax-struct-store Allow store from one struct type to a + different type with compatible layout and + members. + --version Display validator version information. + --target-env {vulkan1.0|vulkan1.1|opencl2.2|spv1.0|spv1.1|spv1.2|spv1.3|webgpu0} + Use Vulkan 1.0, Vulkan 1.1, OpenCL 2.2, SPIR-V 1.0, + SPIR-V 1.1, SPIR-V 1.2, SPIR-V 1.3 or WIP WebGPU validation rules. 
+)", + argv0, argv0); +} + +int main(int argc, char** argv) { + const char* inFile = nullptr; + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_3; + spvtools::ValidatorOptions options; + bool continue_processing = true; + int return_code = 0; + + for (int argi = 1; continue_processing && argi < argc; ++argi) { + const char* cur_arg = argv[argi]; + if ('-' == cur_arg[0]) { + if (0 == strncmp(cur_arg, "--max-", 6)) { + if (argi + 1 < argc) { + spv_validator_limit limit_type; + if (spvParseUniversalLimitsOptions(cur_arg, &limit_type)) { + uint32_t limit = 0; + if (sscanf(argv[++argi], "%u", &limit)) { + options.SetUniversalLimit(limit_type, limit); + } else { + fprintf(stderr, "error: missing argument to %s\n", cur_arg); + continue_processing = false; + return_code = 1; + } + } else { + fprintf(stderr, "error: unrecognized option: %s\n", cur_arg); + continue_processing = false; + return_code = 1; + } + } else { + fprintf(stderr, "error: Missing argument to %s\n", cur_arg); + continue_processing = false; + return_code = 1; + } + } else if (0 == strcmp(cur_arg, "--version")) { + printf("%s\n", spvSoftwareVersionDetailsString()); + printf("Targets:\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n", + spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_0), + spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_1), + spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_2), + spvTargetEnvDescription(SPV_ENV_UNIVERSAL_1_3), + spvTargetEnvDescription(SPV_ENV_OPENCL_2_2), + spvTargetEnvDescription(SPV_ENV_VULKAN_1_0), + spvTargetEnvDescription(SPV_ENV_VULKAN_1_1), + spvTargetEnvDescription(SPV_ENV_WEBGPU_0)); + continue_processing = false; + return_code = 0; + } else if (0 == strcmp(cur_arg, "--help") || 0 == strcmp(cur_arg, "-h")) { + print_usage(argv[0]); + continue_processing = false; + return_code = 0; + } else if (0 == strcmp(cur_arg, "--target-env")) { + if (argi + 1 < argc) { + const auto env_str = argv[++argi]; + if (!spvParseTargetEnv(env_str, &target_env)) { + fprintf(stderr, "error: Unrecognized 
target env: %s\n", env_str); + continue_processing = false; + return_code = 1; + } + } else { + fprintf(stderr, "error: Missing argument to --target-env\n"); + continue_processing = false; + return_code = 1; + } + } else if (0 == strcmp(cur_arg, "--relax-logical-pointer")) { + options.SetRelaxLogicalPointer(true); + } else if (0 == strcmp(cur_arg, "--relax-block-layout")) { + options.SetRelaxBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--scalar-block-layout")) { + options.SetScalarBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--skip-block-layout")) { + options.SetSkipBlockLayout(true); + } else if (0 == strcmp(cur_arg, "--relax-struct-store")) { + options.SetRelaxStructStore(true); + } else if (0 == cur_arg[1]) { + // Setting a filename of "-" to indicate stdin. + if (!inFile) { + inFile = cur_arg; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + continue_processing = false; + return_code = 1; + } + } else { + print_usage(argv[0]); + continue_processing = false; + return_code = 1; + } + } else { + if (!inFile) { + inFile = cur_arg; + } else { + fprintf(stderr, "error: More than one input file specified\n"); + continue_processing = false; + return_code = 1; + } + } + } + + // Exit if command line parsing was not successful. + if (!continue_processing) { + return return_code; + } + + std::vector contents; + if (!ReadFile(inFile, "rb", &contents)) return 1; + + spvtools::SpirvTools tools(target_env); + tools.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); + + bool succeed = tools.Validate(contents.data(), contents.size(), options); + + return !succeed; +} diff --git a/utils/check_code_format.sh b/utils/check_code_format.sh new file mode 100755 index 000000000..a6a58796a --- /dev/null +++ b/utils/check_code_format.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) 2017 Google Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Script to determine if source code in Pull Request is properly formatted.
# Exits with non 0 exit code if formatting is needed.
#
# This script assumes to be invoked at the project root directory.

# Names of the files that differ from the master branch, restricted to the
# C/C++ source and header extensions listed in the pattern.
FILES_TO_CHECK=$(git diff --name-only master | grep -E ".*\.(cpp|cc|c\+\+|cxx|c|h|hpp)$")

# Nothing to do when no matching file was changed.
if [ -z "${FILES_TO_CHECK}" ]; then
  echo "No source code to check for formatting."
  exit 0
fi

# ${FILES_TO_CHECK} is deliberately left unquoted so that each file name is
# passed to git as a separate argument. The zero-context diff is piped through
# clang-format-diff.py, whose output is captured for the check below.
FORMAT_DIFF=$(git diff -U0 master -- ${FILES_TO_CHECK} | python ./utils/clang-format-diff.py -p1 -style=file)

# Empty output means no reformatting was suggested; otherwise print the
# suggested changes and fail.
if [ -z "${FORMAT_DIFF}" ]; then
  echo "All source code in PR properly formatted."
  exit 0
else
  echo "Found formatting errors!"
  echo "${FORMAT_DIFF}"
  exit 1
fi
import argparse
import fileinput
import fnmatch
import inspect
import os
import re
import sys

# List of designated copyright owners.
AUTHORS = ['The Khronos Group Inc.',
           'LunarG Inc.',
           'Google Inc.',
           'Google LLC',
           'Pierre Moreau']
CURRENT_YEAR = '2019'

YEARS = '(2014-2016|2015-2016|2016|2016-2017|2017|2018|2019)'
# Author names are escaped so that the '.' in e.g. 'Google Inc.' matches only
# a literal period.
COPYRIGHT_RE = re.compile(
    r'Copyright \(c\) {} ({})'.format(
        YEARS, '|'.join(re.escape(author) for author in AUTHORS)))

MIT_BEGIN_RE = re.compile('Permission is hereby granted, '
                          'free of charge, to any person obtaining a')
MIT_END_RE = re.compile('MATERIALS OR THE USE OR OTHER DEALINGS IN '
                        'THE MATERIALS.')
APACHE2_BEGIN_RE = re.compile(r'Licensed under the Apache License, '
                              r'Version 2.0 \(the "License"\);')
APACHE2_END_RE = re.compile('limitations under the License.')

LICENSED = """Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License."""
# Distance, in lines, from the first line of LICENSED to its last line.
LICENSED_LEN = LICENSED.count('\n')


def find(top, filename_glob, skip_glob_dir_list, skip_glob_files_list):
    """Returns the files under top whose names match filename_glob.

    Directories whose name matches any glob in skip_glob_dir_list are not
    descended into. Files whose path (as joined from the walk, e.g.
    './utils/foo.py' when top is '.') appears in skip_glob_files_list are
    left out; these entries are compared as exact strings, not as globs.
    """
    file_list = []
    for path, dirs, files in os.walk(top):
        # Prune in place so that os.walk does not visit the skipped
        # directories.
        dirs[:] = [d for d in dirs
                   if not any(fnmatch.fnmatch(d, glob)
                              for glob in skip_glob_dir_list)]
        for filename in fnmatch.filter(files, filename_glob):
            full_file = os.path.join(path, filename)
            if full_file not in skip_glob_files_list:
                file_list.append(full_file)
    return file_list


def filtered_descendants(glob):
    """Returns glob-matching filenames under the current directory, but skips
    some irrelevant paths."""
    return find('.', glob, ['third_party', 'external', 'CompilerIdCXX',
                            'build*', 'out*'],
                ['./utils/clang-format-diff.py'])


def skip(line):
    """Returns true if line is all whitespace or shebang."""
    stripped = line.lstrip()
    return stripped == '' or stripped.startswith('#!')


def comment(text, prefix):
    """Returns commented-out text.

    Each line of text will be prefixed by prefix and a space character.  Any
    trailing whitespace will be trimmed.
    """
    accum = ['{} {}'.format(prefix, line).rstrip()
             for line in text.split('\n')]
    return '\n'.join(accum)


def insert_copyright(author, glob, comment_prefix):
    """Finds all glob-matching files under the current directory and inserts
    the copyright message, and license notice.  An MIT license or Khronos free
    use license (modified MIT) is replaced with an Apache 2 license.

    The copyright message goes into the first non-whitespace, non-shebang line
    in a file.  The license notice follows it.  Both are prefixed on each line
    by comment_prefix and a space.

    Files are rewritten in place.
    """
    copyright_notice = comment(
        'Copyright (c) {} {}'.format(CURRENT_YEAR, author),
        comment_prefix) + '\n\n'
    licensed = comment(LICENSED, comment_prefix) + '\n\n'
    for path in filtered_descendants(glob):
        # Parsing states are:
        #   0 Initial: Have not seen a copyright declaration.
        #   1 Seen a copyright line and no other interesting lines
        #   2 In the middle of an MIT or Khronos free use license
        #   9 Exited any of the above
        state = 0
        try:
            # With inplace=True, sys.stdout is redirected into the file being
            # read, so everything written below becomes the new contents.
            for line in fileinput.input(path, inplace=True):
                emit = True
                if state == 0:
                    if COPYRIGHT_RE.search(line):
                        state = 1
                    elif skip(line):
                        pass
                    else:
                        # Didn't see a copyright. Inject copyright and
                        # license.
                        sys.stdout.write(copyright_notice)
                        sys.stdout.write(licensed)
                        # Assume there isn't a previous license notice.
                        state = 1
                elif state == 1:
                    if MIT_BEGIN_RE.search(line):
                        state = 2
                        emit = False
                    elif APACHE2_BEGIN_RE.search(line):
                        # Assume an Apache license is preceded by a copyright
                        # notice.  So just emit it like the rest of the file.
                        state = 9
                elif state == 2:
                    # Replace the MIT license with Apache 2
                    emit = False
                    if MIT_END_RE.search(line):
                        state = 9
                        sys.stdout.write(licensed)
                if emit:
                    sys.stdout.write(line)
        finally:
            # Restores sys.stdout even if processing a line raised.
            fileinput.close()


def alert_if_no_copyright(glob, comment_prefix):
    """Prints names of all files missing either a copyright or Apache 2
    license.

    Finds all glob-matching files under the current directory and checks if
    they contain the copyright message and license notice.  Prints the names
    of all the files that don't meet both criteria.

    comment_prefix is accepted for symmetry with insert_copyright and is not
    used.

    Returns the total number of file names printed.
    """
    printed_count = 0
    for path in filtered_descendants(glob):
        has_copyright = False
        has_apache2 = False
        line_num = 0
        apache_expected_end = 0
        with open(path) as contents:
            for line in contents:
                line_num += 1
                if COPYRIGHT_RE.search(line):
                    has_copyright = True
                if APACHE2_BEGIN_RE.search(line):
                    apache_expected_end = line_num + LICENSED_LEN
                # Compare by value: integer identity is only reliable for
                # small numbers, so 'is' fails on long headers.
                if (line_num == apache_expected_end
                        and APACHE2_END_RE.search(line)):
                    has_apache2 = True
        if not (has_copyright and has_apache2):
            message = path
            if not has_copyright:
                message += ' has no copyright'
            if not has_apache2:
                message += ' has no Apache 2 license notice'
            print(message)
            printed_count += 1
    return printed_count


class ArgParser(argparse.ArgumentParser):
    """Command line parser; its description is this module's docstring."""

    def __init__(self):
        super(ArgParser, self).__init__(
            description=inspect.getdoc(sys.modules[__name__]))
        self.add_argument('--update', dest='author', action='store',
                          help='For files missing a copyright notice, insert '
                               'one for the given author, and add a license '
                               'notice. The author must be in the AUTHORS '
                               'list in the script.')


def main():
    """Checks, or with --update inserts, copyright and license notices.

    Exits with status 0 after an update, or when every checked file carries
    both notices; exits with status 1 otherwise.
    """
    glob_comment_pairs = [('*.h', '//'), ('*.hpp', '//'), ('*.sh', '#'),
                          ('*.py', '#'), ('*.cpp', '//'),
                          ('CMakeLists.txt', '#')]
    argparser = ArgParser()
    args = argparser.parse_args()

    if args.author:
        if args.author not in AUTHORS:
            print('error: --update argument must be in the AUTHORS list in '
                  'check_copyright.py: {}'.format(AUTHORS))
            sys.exit(1)
        for pair in glob_comment_pairs:
            insert_copyright(args.author, *pair)
        sys.exit(0)
    else:
        count = sum(alert_if_no_copyright(*p) for p in glob_comment_pairs)
        sys.exit(1 if count > 0 else 0)


if __name__ == '__main__':
    main()
import os.path
import re
import subprocess
import sys


PROG = 'check_symbol_exports'

# The pattern for a global function symbol record in 'objdump -t' output.
# The symbol name is captured in group 1.
SYMBOL_PATTERN = re.compile(
    r'^[0-9A-Fa-f]+ g *F \.text.*[0-9A-Fa-f]+ +(.*)')

# Ok patterns are as follows, assuming Itanium name mangling:
#   spv[A-Z]          :  extern "C" symbol starting with spv
#   _ZN               :  something in a namespace
#   _Z[0-9]+spv[A-Z_] :  C++ symbol starting with spv[A-Z_]
SYMBOL_OK_PATTERN = re.compile(r'^(spv[A-Z]|_ZN|_Z[0-9]+spv[A-Z_])')


def command_output(cmd, directory):
    """Runs a command in a directory and returns its standard output stream.

    Captures the standard error stream.

    Raises a RuntimeError if the command exits with a non-zero status.
    """
    p = subprocess.Popen(cmd,
                         cwd=directory,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         universal_newlines=True)
    (stdout, _) = p.communicate()
    if p.returncode != 0:
        raise RuntimeError('Failed to run %s in %s' % (cmd, directory))
    return stdout


def unexpected_exports(symbol_table):
    """Returns the exported symbols in symbol_table that are not allowed.

    symbol_table is the text printed by 'objdump -t'.  A global function
    symbol is allowed when it is namespaced or begins with spv (in either C
    or C++ styles).  Each offending symbol is listed once, in order of first
    appearance.
    """
    seen = set()
    offenders = []
    for line in symbol_table.split('\n'):
        match = SYMBOL_PATTERN.search(line)
        if not match:
            continue
        symbol = match.group(1)
        if symbol in seen:
            continue
        seen.add(symbol)
        if not SYMBOL_OK_PATTERN.match(symbol):
            offenders.append(symbol)
    return offenders


def check_library(library):
    """Scans the given library file for global exports.  If all such
    exports are namespaced or begin with spv (in either C or C++ styles)
    then return 0.  Otherwise emit a message and return 1."""
    offenders = unexpected_exports(
        command_output(['objdump', '-t', library], '.'))
    for symbol in offenders:
        print('{}: error: Unescaped exported symbol: {}'.format(PROG, symbol))
    return 1 if offenders else 0


def main():
    """Checks the library named on the command line and exits with the
    result.  The check only runs on Posix systems; elsewhere it passes."""
    import argparse
    parser = argparse.ArgumentParser(
        description='Check global names exported from a library')
    parser.add_argument('library', help='The static library to examine')
    args = parser.parse_args()

    if not os.path.isfile(args.library):
        print('{}: error: {} does not exist'.format(PROG, args.library))
        sys.exit(1)

    # Strings are compared by value; 'is' against a literal only works by
    # accident of interning.
    if os.name == 'posix':
        status = check_library(args.library)
        sys.exit(status)
    else:
        print('Passing test since not on Posix')
        sys.exit(0)


if __name__ == '__main__':
    main()
import sys


def chop_to_word_multiple(data):
    """Returns data with trailing bytes dropped so its length is a multiple
    of 4.

    Input whose length is already a multiple of 4 is returned unchanged.
    """
    return data[:len(data) - (len(data) % 4)]


def main(argv):
    """Writes the file named by argv[1], chopped to a multiple of 4 bytes, to
    standard output.

    Returns 0 on success, or 1 when no file name was given.
    """
    # argv[0] is the script name, so a file argument needs at least two
    # entries.
    if len(argv) < 2:
        sys.stderr.write("Need file to chop\n")
        return 1

    with open(argv[1], mode='rb') as file:
        content = chop_to_word_multiple(file.read())

    # On Python 3 bytes must go to the underlying binary stream; Python 2's
    # sys.stdout has no 'buffer' attribute and accepts the bytes directly.
    out = getattr(sys.stdout, 'buffer', sys.stdout)
    out.write(content)
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))
EXTENSIONS_FROM_SPIRV_REGISTRY_AND_NOT_FROM_GRAMMARS = """
SPV_AMD_gcn_shader
SPV_AMD_gpu_shader_half_float
SPV_AMD_gpu_shader_int16
SPV_AMD_shader_trinary_minmax
"""

# Operand kinds that are named differently in the JSON grammar and in
# spirv-tools, keyed by the grammar's name.
_OPERAND_KIND_RENAMES = {
    'IdResultType': 'TypeId',
    'IdResult': 'ResultId',
    'IdMemorySemantics': 'MemorySemanticsId',
    'MemorySemantics': 'MemorySemanticsId',
    'IdScope': 'ScopeId',
    'Scope': 'ScopeId',
    'IdRef': 'Id',

    'ImageOperands': 'Image',
    'Dim': 'Dimensionality',
    'ImageFormat': 'SamplerImageFormat',
    'KernelEnqueueFlags': 'KernelEnqFlags',

    'LiteralExtInstInteger': 'ExtensionInstructionNumber',
    'LiteralSpecConstantOpInteger': 'SpecConstantOpNumber',
    'LiteralContextDependentNumber': 'TypedLiteralNumber',

    'PairLiteralIntegerIdRef': 'LiteralIntegerId',
    'PairIdRefLiteralInteger': 'IdLiteralInteger',
    'PairIdRefIdRef': 'Id',  # Used by OpPhi in the grammar

    'FPRoundingMode': 'FpRoundingMode',
    'FPFastMathMode': 'FpFastMathMode',
}


def make_path_to_file(f):
    """Makes all ancestor directories to the given file, if they
    don't yet exist.

    Arguments:
        f: The file whose ancestor directories are to be created.
    """
    parent = os.path.dirname(os.path.abspath(f))
    try:
        os.makedirs(parent)
    except OSError as e:
        # An already existing directory is fine; anything else is an error.
        if not (e.errno == errno.EEXIST and os.path.isdir(parent)):
            raise


def convert_min_required_version(version):
    """Converts the minimal required SPIR-V version encoded in the
    grammar to the symbol in SPIRV-Tools.

    A missing version (None) maps to version 1.0, the string 'None' maps to
    0xffffffffu, and any other string has its '.' replaced by ','.
    """
    if version is None:
        return 'SPV_SPIRV_VERSION_WORD(1, 0)'
    if version == 'None':
        return '0xffffffffu'
    return 'SPV_SPIRV_VERSION_WORD({})'.format(version.replace('.', ','))


def compose_capability_list(caps):
    """Returns a string containing a braced list of capabilities as enums.

    Arguments:
      - caps: a sequence of capability names

    Returns:
      a string containing the braced list of SpvCapability* enums named by
      caps.
    """
    return '{' + ', '.join('SpvCapability{}'.format(c) for c in caps) + '}'


def get_capability_array_name(caps):
    """Returns the name of the array containing all the given capabilities.

    Returns 'nullptr' when caps is empty.

    Args:
      - caps: a sequence of capability names
    """
    if not caps:
        return 'nullptr'
    return '{}_caps_{}'.format(PYGEN_VARIABLE_PREFIX, ''.join(caps))


def generate_capability_arrays(caps):
    """Returns the arrays of capabilities.

    Empty sequences are skipped and duplicates are emitted once, in sorted
    order, one array definition per line.

    Arguments:
      - caps: a sequence of sequence of capability names
    """
    unique_caps = sorted({tuple(c) for c in caps if c})
    arrays = [
        'static const SpvCapability {}[] = {};'.format(
            get_capability_array_name(c), compose_capability_list(c))
        for c in unique_caps]
    return '\n'.join(arrays)


def compose_extension_list(exts):
    """Returns a string containing a braced list of extensions as enums.

    Arguments:
      - exts: a sequence of extension names

    Returns:
      a string containing the braced list of extensions named by exts.
    """
    return '{' + ', '.join(
        'spvtools::Extension::k{}'.format(e) for e in exts) + '}'


def get_extension_array_name(extensions):
    """Returns the name of the array containing all the given extensions.

    Returns 'nullptr' when extensions is empty.

    Args:
      - extensions: a sequence of extension names
    """
    if not extensions:
        return 'nullptr'
    return '{}_exts_{}'.format(PYGEN_VARIABLE_PREFIX, ''.join(extensions))


def generate_extension_arrays(extensions):
    """Returns the arrays of extensions.

    Empty sequences are skipped and duplicates are emitted once, in sorted
    order, one array definition per line.

    Arguments:
      - extensions: a sequence of sequence of extension names
    """
    unique_exts = sorted({tuple(e) for e in extensions if e})
    arrays = [
        'static const spvtools::Extension {}[] = {};'.format(
            get_extension_array_name(e), compose_extension_list(e))
        for e in unique_exts]
    return '\n'.join(arrays)


def convert_operand_kind(operand_tuple):
    """Returns the corresponding operand type used in spirv-tools for
    the given operand kind and quantifier used in the JSON grammar.

    Arguments:
      - operand_tuple: a tuple of two elements:
          - operand kind: used in the JSON grammar
          - quantifier: '', '?', or '*'

    Returns:
      a string of the enumerant name in spv_operand_type_t
    """
    kind, quantifier = operand_tuple
    # Apply the cases where we differ between the JSON grammar and
    # spirv-tools; every other kind keeps its grammar name.
    kind = _OPERAND_KIND_RENAMES.get(kind, kind)

    if quantifier == '?':
        kind = 'Optional{}'.format(kind)
    elif quantifier == '*':
        kind = 'Variable{}'.format(kind)

    # CamelCase -> UPPER_SNAKE_CASE.
    return 'SPV_OPERAND_TYPE_{}'.format(
        re.sub(r'([a-z])([A-Z])', r'\1_\2', kind).upper())
+ self.num_caps = len(caps) + self.caps_mask = get_capability_array_name(caps) + self.num_exts = len(exts) + self.exts = get_extension_array_name(exts) + self.operands = [convert_operand_kind(o) for o in operands] + + self.fix_syntax() + + operands = [o[0] for o in operands] + self.ref_type_id = 'IdResultType' in operands + self.def_result_id = 'IdResult' in operands + + self.version = convert_min_required_version(version) + + def fix_syntax(self): + """Fix an instruction's syntax, adjusting for differences between + the officially released grammar and how SPIRV-Tools uses the grammar. + + Fixes: + - ExtInst should not end with SPV_OPERAND_VARIABLE_ID. + https://github.com/KhronosGroup/SPIRV-Tools/issues/233 + """ + if (self.opname == 'ExtInst' + and self.operands[-1] == 'SPV_OPERAND_TYPE_VARIABLE_ID'): + self.operands.pop() + + def __str__(self): + template = ['{{"{opname}"', 'SpvOp{opname}', + '{num_caps}', '{caps_mask}', + '{num_operands}', '{{{operands}}}', + '{def_result_id}', '{ref_type_id}', + '{num_exts}', '{exts}', + '{min_version}}}'] + return ', '.join(template).format( + opname=self.opname, + num_caps=self.num_caps, + caps_mask=self.caps_mask, + num_operands=len(self.operands), + operands=', '.join(self.operands), + def_result_id=(1 if self.def_result_id else 0), + ref_type_id=(1 if self.ref_type_id else 0), + num_exts=self.num_exts, + exts=self.exts, + min_version=self.version) + + +class ExtInstInitializer(object): + """Instances holds a SPIR-V extended instruction suitable for printing as + the initializer for spv_ext_inst_desc_t.""" + + def __init__(self, opname, opcode, caps, operands): + """Initialization. 
+ + Arguments: + - opname: opcode name + - opcode: enumerant value for this opcode + - caps: a sequence of capability names required by this opcode + - operands: a sequence of (operand-kind, operand-quantifier) tuples + """ + self.opname = opname + self.opcode = opcode + self.num_caps = len(caps) + self.caps_mask = get_capability_array_name(caps) + self.operands = [convert_operand_kind(o) for o in operands] + self.operands.append('SPV_OPERAND_TYPE_NONE') + + def __str__(self): + template = ['{{"{opname}"', '{opcode}', '{num_caps}', '{caps_mask}', + '{{{operands}}}}}'] + return ', '.join(template).format( + opname=self.opname, + opcode=self.opcode, + num_caps=self.num_caps, + caps_mask=self.caps_mask, + operands=', '.join(self.operands)) + + +def generate_instruction(inst, is_ext_inst): + """Returns the C initializer for the given SPIR-V instruction. + + Arguments: + - inst: a dict containing information about a SPIR-V instruction + - is_ext_inst: a bool indicating whether |inst| is an extended + instruction. + + Returns: + a string containing the C initializer for spv_opcode_desc_t or + spv_ext_inst_desc_t + """ + opname = inst.get('opname') + opcode = inst.get('opcode') + caps = inst.get('capabilities', []) + exts = inst.get('extensions', []) + operands = inst.get('operands', {}) + operands = [(o['kind'], o.get('quantifier', '')) for o in operands] + min_version = inst.get('version', None) + + assert opname is not None + + if is_ext_inst: + return str(ExtInstInitializer(opname, opcode, caps, operands)) + else: + return str(InstInitializer(opname, caps, exts, operands, min_version)) + + +def generate_instruction_table(inst_table): + """Returns the info table containing all SPIR-V instructions, + sorted by opcode, and prefixed by capability arrays. + + Note: + - the built-in sorted() function is guaranteed to be stable. + https://docs.python.org/3/library/functions.html#sorted + + Arguments: + - inst_table: a list containing all SPIR-V instructions. 
def generate_extended_instruction_table(inst_table, set_name):
    """Returns the info table containing all SPIR-V extended instructions,
    sorted by opcode, and prefixed by capability arrays.

    Arguments:
      - inst_table: a list containing all SPIR-V instructions.
      - set_name: the name of the extended instruction set; it becomes the
        prefix of the generated "<set_name>_entries" array.
    """
    ordered = sorted(inst_table, key=lambda inst: inst['opcode'])

    # The capability arrays are emitted first so the table can refer to them.
    caps_arrays = generate_capability_arrays(
        [inst.get('capabilities', []) for inst in ordered])

    initializers = [generate_instruction(inst, True) for inst in ordered]
    table = ('static const spv_ext_inst_desc_t {}_entries[] = {{\n'
             ' {}\n}};').format(set_name, ',\n '.join(initializers))

    return '{}\n\n{}'.format(caps_arrays, table)
+ + Arguments: + - enumerant: enumerant name + - value: enumerant value + - caps: a sequence of capability names required by this enumerant + - exts: a sequence of names of extensions enabling this enumerant + - parameters: a sequence of (operand-kind, operand-quantifier) tuples + - version: minimal SPIR-V version required for this opcode + """ + self.enumerant = enumerant + self.value = value + self.num_caps = len(caps) + self.caps = get_capability_array_name(caps) + self.num_exts = len(exts) + self.exts = get_extension_array_name(exts) + self.parameters = [convert_operand_kind(p) for p in parameters] + self.version = convert_min_required_version(version) + + def __str__(self): + template = ['{{"{enumerant}"', '{value}', '{num_caps}', + '{caps}', '{num_exts}', '{exts}', + '{{{parameters}}}', '{min_version}}}'] + return ', '.join(template).format( + enumerant=self.enumerant, + value=self.value, + num_caps=self.num_caps, + caps=self.caps, + num_exts=self.num_exts, + exts=self.exts, + parameters=', '.join(self.parameters), + min_version=self.version) + + +def generate_enum_operand_kind_entry(entry): + """Returns the C initializer for the given operand enum entry. 
def generate_operand_kind_table(enums):
    """Returns the info table containing all SPIR-V operand kinds.

    Arguments:
      - enums: a list of operand-kind dicts from the grammar. Only those
        whose 'category' is 'ValueEnum' or 'BitEnum' produce output.

    Returns:
      a string holding, in order: the capability arrays, the extension
      arrays, one spv_operand_desc_t array per enum operand kind, and the
      spv_operand_desc_group_t table that indexes them.
    """
    # We only need to output info tables for those operand kinds that are
    # enums.
    enums = [e for e in enums if e.get('category') in ['ValueEnum', 'BitEnum']]

    caps = [entry.get('capabilities', [])
            for enum in enums
            for entry in enum.get('enumerants', [])]
    caps_arrays = generate_capability_arrays(caps)

    exts = [entry.get('extensions', [])
            for enum in enums
            for entry in enum.get('enumerants', [])]
    exts_arrays = generate_extension_arrays(exts)

    # Each element is a (kind, array name, array definition) tuple.
    enums = [generate_enum_operand_kind(e) for e in enums]

    # Some operand kinds require their optional counterpart to exist in the
    # operand info table. The optional copies are appended after the regular
    # ones and reuse the regular array definitions.
    optional_kinds = ['ImageOperands', 'AccessQualifier', 'MemoryAccess']
    optional_enums = [e for e in enums if e[0] in optional_kinds]
    # Count what is really there instead of assuming exactly three copies,
    # so a grammar lacking one of these kinds still yields a correct table.
    num_regular = len(enums)
    enums.extend(optional_enums)

    enum_kinds, enum_names, enum_entries = zip(*enums)
    # Mark only the appended copies as optional.
    enum_quantifiers = [''] * num_regular + ['?'] * len(optional_enums)
    # And we don't want redefinition of their entry arrays.
    enum_entries = enum_entries[:num_regular]
    enum_kinds = [convert_operand_kind(e)
                  for e in zip(enum_kinds, enum_quantifiers)]
    table_entries = zip(enum_kinds, enum_names, enum_names)
    table_entries = [' {{{}, ARRAY_SIZE({}), {}}}'.format(*e)
                     for e in table_entries]

    template = [
        'static const spv_operand_desc_group_t {p}_OperandInfoTable[] = {{',
        '{enums}', '}};']
    table = '\n'.join(template).format(
        p=PYGEN_VARIABLE_PREFIX, enums=',\n'.join(table_entries))

    return '\n\n'.join((caps_arrays,) + (exts_arrays,) + enum_entries + (table,))
def generate_extension_to_string_mapping(extensions):
    """Returns the C++ source of ExtensionToString().

    The generated function switches over the Extension enum and returns the
    extension's name as a string literal, with one case per element of
    |extensions|; it falls back to the empty string.
    """
    case_template = (' case Extension::k{extension}:\n'
                     ' return "{extension}";\n')

    pieces = ['const char* ExtensionToString(Extension extension) {\n',
              ' switch (extension) {\n']
    for name in extensions:
        pieces.append(case_template.format(extension=name))
    pieces.append(' };\n\n return "";\n}')
    return ''.join(pieces)
def _load_json(path):
    """Return the parsed contents of the JSON file at |path|."""
    with open(path) as json_file:
        return json.load(json_file)


def _write_output(path, text):
    """Write |text| plus a trailing newline to |path|.

    Ancestor directories are created first, and the file is closed before
    returning so the generated table is fully flushed to disk.
    """
    make_path_to_file(path)
    with open(path, 'w') as out:
        out.write(text)
        out.write('\n')


def main():
    """Parse the command line and write the requested SPIR-V info tables.

    Inconsistent option combinations are reported on standard output and
    terminate the process with exit status 1.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='Generate SPIR-V info tables')

    # (flag, help text), in the order the options appear in --help. Every
    # option is an optional path defaulting to None.
    path_options = [
        ('--spirv-core-grammar',
         'input JSON grammar file for core SPIR-V instructions'),
        ('--extinst-debuginfo-grammar',
         'input JSON grammar file for DebugInfo extended instruction set'),
        ('--extinst-glsl-grammar',
         'input JSON grammar file for GLSL extended instruction set'),
        ('--extinst-opencl-grammar',
         'input JSON grammar file for OpenCL extended instruction set'),
        ('--core-insts-output',
         'output file for core SPIR-V instructions'),
        ('--glsl-insts-output',
         'output file for GLSL extended instruction set'),
        ('--opencl-insts-output',
         'output file for OpenCL extended instruction set'),
        ('--operand-kinds-output',
         'output file for operand kinds'),
        ('--extension-enum-output',
         'output file for extension enumeration'),
        ('--enum-string-mapping-output',
         'output file for enum-string mappings'),
        ('--extinst-vendor-grammar',
         'input JSON grammar file for vendor extended instruction set'),
        ('--vendor-insts-output',
         'output file for vendor extended instruction set'),
    ]
    for flag, help_text in path_options:
        parser.add_argument(flag, metavar='', type=str, required=False,
                            default=None, help=help_text)
    args = parser.parse_args()

    def fail(message):
        # sys.exit is used instead of the site-provided exit() builtin,
        # which is not guaranteed to exist (e.g. under "python -S").
        print(message)
        sys.exit(1)

    def require_together(first_flag, first, second_flag, second):
        if (first is None) != (second is None):
            fail('error: {} and {} should be specified together.'.format(
                first_flag, second_flag))

    require_together('--core-insts-output', args.core_insts_output,
                     '--operand-kinds-output', args.operand_kinds_output)
    if args.operand_kinds_output and not (args.spirv_core_grammar and
                                          args.extinst_debuginfo_grammar):
        fail('error: --operand-kinds-output requires --spirv-core-grammar '
             'and --extinst-debuginfo-grammar')
    require_together('--glsl-insts-output', args.glsl_insts_output,
                     '--extinst-glsl-grammar', args.extinst_glsl_grammar)
    require_together('--opencl-insts-output', args.opencl_insts_output,
                     '--extinst-opencl-grammar', args.extinst_opencl_grammar)
    require_together('--vendor-insts-output', args.vendor_insts_output,
                     '--extinst-vendor-grammar', args.extinst_vendor_grammar)
    if all([args.core_insts_output is None,
            args.glsl_insts_output is None,
            args.opencl_insts_output is None,
            args.vendor_insts_output is None,
            args.extension_enum_output is None,
            args.enum_string_mapping_output is None]):
        fail('error: at least one output should be specified.')
    # The core tables always merge in the DebugInfo grammar, so report a
    # missing one clearly instead of failing later inside open(None).
    if (args.spirv_core_grammar is not None and
            args.extinst_debuginfo_grammar is None):
        fail('error: --spirv-core-grammar requires '
             '--extinst-debuginfo-grammar')

    if args.spirv_core_grammar is not None:
        core_grammar = _load_json(args.spirv_core_grammar)
        debuginfo_grammar = _load_json(args.extinst_debuginfo_grammar)
        instructions = []
        instructions.extend(core_grammar['instructions'])
        instructions.extend(debuginfo_grammar['instructions'])
        operand_kinds = []
        operand_kinds.extend(core_grammar['operand_kinds'])
        operand_kinds.extend(debuginfo_grammar['operand_kinds'])
        extensions = get_extension_list(instructions, operand_kinds)

        if args.core_insts_output is not None:
            # Only core instructions go into the opcode table; the operand
            # kind table covers both grammars.
            _write_output(
                args.core_insts_output,
                generate_instruction_table(core_grammar['instructions']))
            _write_output(args.operand_kinds_output,
                          generate_operand_kind_table(operand_kinds))
        if args.extension_enum_output is not None:
            _write_output(args.extension_enum_output,
                          generate_extension_enum(extensions))
        if args.enum_string_mapping_output is not None:
            _write_output(
                args.enum_string_mapping_output,
                generate_all_string_enum_mappings(extensions, operand_kinds))

    extended_sets = [
        (args.extinst_glsl_grammar, args.glsl_insts_output, "glsl"),
        (args.extinst_opencl_grammar, args.opencl_insts_output, "opencl"),
    ]
    for grammar_path, output_path, set_name in extended_sets:
        if grammar_path is not None:
            grammar = _load_json(grammar_path)
            _write_output(output_path,
                          generate_extended_instruction_table(
                              grammar['instructions'], set_name))

    if args.extinst_vendor_grammar is not None:
        grammar = _load_json(args.extinst_vendor_grammar)
        # The set name is the part of the file name between "extinst." and
        # ".grammar.json", with dashes turned into underscores so it can be
        # used in a C identifier.
        name = args.extinst_vendor_grammar
        start = name.find("extinst.") + len("extinst.")
        name = name[start:-len(".grammar.json")].replace("-", "_")
        _write_output(args.vendor_insts_output,
                      generate_extended_instruction_table(
                          grammar['instructions'], name))
class LangGenerator:
    """A language-specific generator.

    This base class supplies neutral defaults (no comment prefix, no include
    guards, empty preamble/postamble). Subclasses override the hooks below
    and must provide const_definition, enum_prefix and enum_end, which
    generate() calls when the grammar has the corresponding content.
    """

    def __init__(self):
        # Enumerant names not starting with an upper-case ASCII letter get a
        # leading underscore (see enum_value).
        self.upper_case_initial = re.compile('^[A-Z]')

    def comment_prefix(self):
        """Returns the text that starts a single-line comment."""
        return ""

    def namespace_prefix(self):
        """Returns the namespace prefix; empty by default."""
        return ""

    def uses_guards(self):
        """Returns True if the output should be wrapped in include guards."""
        return False

    def cpp_guard_preamble(self):
        """Returns the text emitted before the definitions."""
        return ""

    def cpp_guard_postamble(self):
        """Returns the text emitted after the definitions."""
        return ""

    def const_definition(self, prefix, var, value):
        """Returns the definition of a named constant. Subclass hook."""
        raise NotImplementedError

    def enum_prefix(self, prefix, name):
        """Returns the line that opens an enumeration. Subclass hook."""
        raise NotImplementedError

    def enum_end(self, prefix, enum):
        """Returns the text that closes an enumeration. Subclass hook."""
        raise NotImplementedError

    def enum_value(self, prefix, name, value):
        """Returns one enumerator line "<prefix><name> = <value>,"."""
        if self.upper_case_initial.match(name):
            use_name = name
        else:
            use_name = '_' + name

        return " {}{} = {},".format(prefix, use_name, value)

    def generate(self, grammar):
        """Returns a string that is the language-specific header for the given grammar"""

        parts = []
        if grammar.copyright:
            parts.extend(["{}{}".format(self.comment_prefix(), f) for f in grammar.copyright])
        parts.append('')

        guard = 'SPIRV_EXTINST_{}_H_'.format(grammar.name)
        # uses_guards is a method: it has to be called. Testing the bound
        # method itself is always true and would emit guards unconditionally.
        if self.uses_guards():
            parts.append('#ifndef {}'.format(guard))
            parts.append('#define {}'.format(guard))
        parts.append('')

        parts.append(self.cpp_guard_preamble())

        if grammar.version:
            parts.append(self.const_definition(grammar.name, 'Version', grammar.version))

        if grammar.revision is not None:
            parts.append(self.const_definition(grammar.name, 'Revision', grammar.revision))

        parts.append('')

        if grammar.instructions:
            parts.append(self.enum_prefix(grammar.name, 'Instructions'))
            for inst in grammar.instructions:
                parts.append(self.enum_value(grammar.name, inst['opname'], inst['opcode']))
            parts.append(self.enum_end(grammar.name, 'Instructions'))
            parts.append('')

        if grammar.operand_kinds:
            for kind in grammar.operand_kinds:
                parts.append(self.enum_prefix(grammar.name, kind['kind']))
                for e in kind['enumerants']:
                    parts.append(self.enum_value(grammar.name, e['enumerant'], e['value']))
                parts.append(self.enum_end(grammar.name, kind['kind']))
                parts.append('')

        parts.append(self.cpp_guard_postamble())

        if self.uses_guards():
            parts.append('#endif // {}'.format(guard))

        return '\n'.join(parts)
def main():
    """Parse the command line and write the C header for one extended
    instruction set grammar to "<extinst-output-base>.h".
    """
    import argparse
    parser = argparse.ArgumentParser(description='Generate language headers from a JSON grammar')

    parser.add_argument('--extinst-name',
                        type=str, required=True,
                        help='The name to use in tokens')
    parser.add_argument('--extinst-grammar', metavar='',
                        type=str, required=True,
                        help='input JSON grammar file for extended instruction set')
    parser.add_argument('--extinst-output-base', metavar='',
                        type=str, required=True,
                        help='Basename of the language-specific output file.')
    args = parser.parse_args()

    with open(args.extinst_grammar) as json_file:
        grammar_json = json.load(json_file)
    grammar = ExtInstGrammar(name=args.extinst_name,
                             copyright=grammar_json['copyright'],
                             instructions=grammar_json['instructions'],
                             operand_kinds=grammar_json['operand_kinds'],
                             version=grammar_json['version'],
                             revision=grammar_json['revision'])
    make_path_to_file(args.extinst_output_base)
    # Open the header in a with-block so it is closed, and therefore fully
    # flushed, before the script returns.
    with open(args.extinst_output_base + '.h', 'w') as header:
        print(CGenerator().generate(grammar), file=header)
def generate_vendor_table(registry):
    """Returns the C style initializers for the registered vendors and
    their tools, one initializer per line.

    Each line has the form {value, "vendor", "tool", "vendor tool"},
    where the tool is the empty string, and the last field is just the
    vendor, when the registry entry has no 'tool' attribute.

    Args:
      registry: The SPIR-V XML registry as an xml.ElementTree
    """
    entries = []
    for group in registry.iter('ids'):
        # Only the <ids type="vendor"> groups describe generators.
        if group.attrib['type'] != 'vendor':
            continue
        for node in group.iter('id'):
            attrs = node.attrib
            value = attrs['value']
            vendor = attrs['vendor']
            has_tool = 'tool' in attrs
            tool = attrs['tool'] if has_tool else ''
            vendor_tool = vendor + ' ' + tool if has_tool else vendor
            entries.append('{{{}, "{}", "{}", "{}"}},'.format(
                value, vendor, tool, vendor_tool))
    return '\n'.join(entries)
print(generate_vendor_table(registry), file=open(args.generator_output, 'w')) + + +if __name__ == '__main__': + main() diff --git a/utils/generate_vim_syntax.py b/utils/generate_vim_syntax.py new file mode 100755 index 000000000..03c0b478b --- /dev/null +++ b/utils/generate_vim_syntax.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python +# Copyright (c) 2016 Google Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generates Vim syntax rules for SPIR-V assembly (.spvasm) files""" + +from __future__ import print_function + +import json + +PREAMBLE="""" Vim syntax file +" Language: spvasm +" Generated by SPIRV-Tools + +if version < 600 + syntax clear +elseif exists("b:current_syntax") + finish +endif + +syn case match +""" + +POSTAMBLE=""" + +syntax keyword spvasmTodo TODO FIXME contained + +syn match spvasmIdNumber /%\d\+\>/ + +" The assembler treats the leading minus sign as part of the number token. +" This applies to integers, and to floats below. +syn match spvasmNumber /-\?\<\d\+\>/ + +" Floating point literals. +" In general, C++ requires at least digit in the mantissa, and the +" floating point is optional. This applies to both the regular decimal float +" case and the hex float case. + +" First case: digits before the optional decimal, no trailing digits. 
+syn match spvasmFloat /-\?\d\+\.\?\(e[+-]\d\+\)\?/ +" Second case: optional digits before decimal, trailing digits +syn match spvasmFloat /-\?\d*\.\d\+\(e[+-]\d\+\)\?/ + +" First case: hex digits before the optional decimal, no trailing hex digits. +syn match spvasmFloat /-\?0[xX]\\x\+\.\?p[-+]\d\+/ +" Second case: optional hex digits before decimal, trailing hex digits +syn match spvasmFloat /-\?0[xX]\\x*\.\\x\+p[-+]\d\+/ + +syn match spvasmComment /;.*$/ contains=spvasmTodo +syn region spvasmString start=/"/ skip=/\\\\"/ end=/"/ +syn match spvasmId /%[a-zA-Z_][a-zA-Z_0-9]*/ + +" Highlight unknown constants and statements as errors +syn match spvasmError /[a-zA-Z][a-zA-Z_0-9]*/ + + +if version >= 508 || !exists("did_c_syn_inits") + if version < 508 + let did_c_syn_inits = 1 + command -nargs=+ HiLink hi link + else + command -nargs=+ HiLink hi def link + endif + + HiLink spvasmStatement Statement + HiLink spvasmNumber Number + HiLink spvasmComment Comment + HiLink spvasmString String + HiLink spvasmFloat Float + HiLink spvasmConstant Constant + HiLink spvasmIdNumber Identifier + HiLink spvasmId Identifier + HiLink spvasmTodo Todo + + delcommand HiLink +endif + +let b:current_syntax = "spvasm" +""" + +# This list is taken from the description of OpSpecConstantOp in SPIR-V 1.1. +# TODO(dneto): Propose that this information be embedded in the grammar file. 
def _load_grammar(path):
    """Return the parsed JSON grammar stored in the file at |path|."""
    with open(path) as grammar_file:
        return json.load(grammar_file)


def main():
    """Parses arguments, then generates the Vim syntax rules for SPIR-V assembly
    on stdout."""
    import argparse
    parser = argparse.ArgumentParser(description='Generate SPIR-V info tables')
    parser.add_argument('--spirv-core-grammar', metavar='',
                        type=str, required=True,
                        help='input JSON grammar file for core SPIR-V '
                        'instructions')
    parser.add_argument('--extinst-glsl-grammar', metavar='',
                        type=str, required=False, default=None,
                        help='input JSON grammar file for GLSL extended '
                        'instruction set')
    parser.add_argument('--extinst-opencl-grammar', metavar='',
                        type=str, required=False, default=None,
                        help='input JSON grammar file for OpenCL extended '
                        'instruction set')
    parser.add_argument('--extinst-debuginfo-grammar', metavar='',
                        type=str, required=False, default=None,
                        help='input JSON grammar file for DebugInfo extended '
                        'instruction set')
    args = parser.parse_args()

    # Generate the syntax rules.
    print(PREAMBLE)

    core = _load_grammar(args.spirv_core_grammar)
    print('\n" Core instructions')
    for inst in core["instructions"]:
        EmitAsStatement(inst['opname'])
    print('\n" Core operand enums')
    for operand_kind in core["operand_kinds"]:
        # Only enum-like operand kinds carry an 'enumerants' list.
        if 'enumerants' in operand_kind:
            for e in operand_kind['enumerants']:
                EmitAsEnumerant(e['enumerant'])

    if args.extinst_glsl_grammar is not None:
        print('\n" GLSL.std.450 extended instructions')
        glsl = _load_grammar(args.extinst_glsl_grammar)
        # These opcodes are really enumerant operands for the OpExtInst
        # instruction.
        for inst in glsl["instructions"]:
            EmitAsEnumerant(inst['opname'])

    if args.extinst_opencl_grammar is not None:
        print('\n" OpenCL.std extended instructions')
        opencl = _load_grammar(args.extinst_opencl_grammar)
        for inst in opencl["instructions"]:
            EmitAsEnumerant(inst['opname'])

    if args.extinst_debuginfo_grammar is not None:
        print('\n" DebugInfo extended instructions')
        debuginfo = _load_grammar(args.extinst_debuginfo_grammar)
        for inst in debuginfo["instructions"]:
            EmitAsEnumerant(inst['opname'])
        print('\n" DebugInfo operand enums')
        for operand_kind in debuginfo["operand_kinds"]:
            if 'enumerants' in operand_kind:
                for e in operand_kind['enumerants']:
                    EmitAsEnumerant(e['enumerant'])

    print('\n" OpSpecConstantOp opcodes')
    # Opcode names are separated by commas and/or arbitrary whitespace.
    # Turning commas into spaces and splitting on whitespace tokenizes them
    # correctly whether or not the continuation lines are indented.
    for opcode in SPEC_CONSTANT_OP_OPCODES.replace(',', ' ').split():
        # Treat as an enumerant, but without the leading "Op"
        EmitAsEnumerant(opcode[2:])
    print(POSTAMBLE)
def mkdir_p(directory):
    """Create |directory| together with any missing ancestors.

    Directories that already exist are accepted silently. An empty string
    denotes the current directory, for which there is nothing to do.
    """
    if directory == "":
        # We're being asked to make the current directory.
        return

    try:
        os.makedirs(directory)
    except OSError as error:
        # Tolerate only "already exists, and it is a directory"; every other
        # failure (including a regular file in the way) is propagated.
        already_present = (error.errno == errno.EEXIST
                           and os.path.isdir(directory))
        if not already_present:
            raise
+ + Raises a RuntimeError if the command fails to launch or otherwise fails. + """ + p = subprocess.Popen(cmd, + cwd=directory, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + (stdout, _) = p.communicate() + if p.returncode != 0: + raise RuntimeError('Failed to run %s in %s' % (cmd, directory)) + return stdout + + +def deduce_software_version(directory): + """Returns a software version number parsed from the CHANGES file + in the given directory. + + The CHANGES file describes most recent versions first. + """ + + # Match the first well-formed version-and-date line. + # Allow trailing whitespace in the checked-out source code has + # unexpected carriage returns on a linefeed-only system such as + # Linux. + pattern = re.compile(r'^(v\d+\.\d+(-dev)?) \d\d\d\d-\d\d-\d\d\s*$') + changes_file = os.path.join(directory, 'CHANGES') + with open(changes_file, mode='rU') as f: + for line in f.readlines(): + match = pattern.match(line) + if match: + return match.group(1) + raise Exception('No version number found in {}'.format(changes_file)) + + +def describe(directory): + """Returns a string describing the current Git HEAD version as descriptively + as possible. + + Runs 'git describe', or alternately 'git rev-parse HEAD', in directory. If + successful, returns the output; otherwise returns 'unknown hash, '.""" + try: + # decode() is needed here for Python3 compatibility. In Python2, + # str and bytes are the same type, but not in Python3. + # Popen.communicate() returns a bytes instance, which needs to be + # decoded into text data first in Python3. And this decode() won't + # hurt Python2. + return command_output(['git', 'describe'], directory).rstrip().decode() + except: + try: + return command_output( + ['git', 'rev-parse', 'HEAD'], directory).rstrip().decode() + except: + # This is the fallback case where git gives us no information, + # e.g. because the source tree might not be in a git tree. + # In this case, usually use a timestamp. 
However, to ensure + # reproducible builds, allow the builder to override the wall + # clock time with enviornment variable SOURCE_DATE_EPOCH + # containing a (presumably) fixed timestamp. + timestamp = int(os.environ.get('SOURCE_DATE_EPOCH', time.time())) + formatted = datetime.date.fromtimestamp(timestamp).isoformat() + return 'unknown hash, {}'.format(formatted) + + +def main(): + if len(sys.argv) != 3: + print('usage: {} '.format(sys.argv[0])) + sys.exit(1) + + output_file = sys.argv[2] + mkdir_p(os.path.dirname(output_file)) + + software_version = deduce_software_version(sys.argv[1]) + new_content = '"{}", "SPIRV-Tools {} {}"\n'.format( + software_version, software_version, + describe(sys.argv[1]).replace('"', '\\"')) + + if os.path.isfile(output_file): + with open(output_file, 'r') as f: + if new_content == f.read(): + return + + with open(output_file, 'w') as f: + f.write(new_content) + +if __name__ == '__main__': + main() -- 2.11.0